In [None]:
from fastai.callback.mixup import *
from fastai.vision.all import *

from src.learner import get_learner_task2
from src.utils import (
    load_configuration,
    create_submission,
    save_clean_labels,
    do_fit
)

### ! Reproducibility is ensured by **get_learner_task2** – it sets the random seed for the learner

In [None]:
import warnings
warnings.filterwarnings('ignore')

# Initial training + cleaning

Train a ResNet50-based classifier on the noisy dataset (224x224 image resolution).
 - First, train only the final linear layers using the flat-cosine (flat cos) policy
 - Then unfreeze and train the entire network using the one-cycle policy
 - Use the trained model to detect erroneous labels and clean the dataset

In [None]:
# Learner for the original (noisy) labels at 224x224.
# `cbs` holds the MixUp callback; it is only active in fits that pass `cbs=cbs`.
NOISY_CONFIG_PATH = 'configs/config_task2_224.yml'
config = load_configuration(NOISY_CONFIG_PATH)
learn = get_learner_task2(config)
cbs = [MixUp()]

In [None]:
# Stage 1: head-only training with MixUp.
# Assumes get_learner_task2 returns a frozen learner; fit_type is do_fit's default
# (flat cos per the description above — confirm in src.utils).
do_fit(
    learn,
    'task2_resnet50',
    epochs=15,
    lr=1e-3,
    pct_start=0.75,
    cbs=cbs,
)

In [None]:
# Unfreeze so the next fit trains the entire network, not just the head.
learn.unfreeze()

In [None]:
# Stage 2: short full-network fine-tune with the one-cycle policy; this model is
# the one used for label cleaning.
# NOTE(review): MixUp (`cbs`) is not passed here, unlike the clean-data
# unfrozen fit — confirm this is intentional.
do_fit(
    learn,
    'task2_resnet50_unfrozen',
    epochs=3,
    lr=1e-5,
    fit_type='one_cycle',
)

In [None]:
%time save_clean_labels(learn, config)

# Training with cleaned labels 

In [None]:
# Fresh learner on the cleaned labels; rebinds `config`, `learn` and `cbs`,
# so the noisy-label learner is no longer referenced after this cell.
CLEAN_CONFIG_PATH = 'configs/config_task2_224_clean.yml'
config = load_configuration(CLEAN_CONFIG_PATH)
learn = get_learner_task2(config)
cbs = [MixUp()]

The cleaned dataset is expected to have **~40.5k** samples (about 80% of the original size).

In [None]:
# Total number of samples after cleaning: training items plus validation items.
n_clean_samples = learn.dls.train.n + learn.dls.valid.n
n_clean_samples

In [None]:
# Stage 1 on cleaned labels: head-only training with MixUp
# (assumes the fresh learner starts frozen — confirm in get_learner_task2).
do_fit(
    learn,
    'task2_resnet50_clean',
    epochs=20,
    lr=1e-3,
    pct_start=0.75,
    cbs=cbs,
)

In [None]:
# Unfreeze so the next fit trains the entire network on the cleaned labels.
learn.unfreeze()

In [None]:
# Stage 2 on cleaned labels: full-network fine-tune with MixUp; save_state_dict=True
# is forwarded to do_fit.
# NOTE(review): the submission cell loads 'task2_resnet50_unfrozen_clean.pkl' —
# presumably exported by this call; confirm in src.utils.do_fit.
do_fit(
    learn,
    'task2_resnet50_unfrozen_clean',
    epochs=10,
    lr=1e-5,
    pct_start=0.75,
    cbs=cbs,
    save_state_dict=True,
)

# Creating the submission

In [None]:
%%time

create_submission(
    path_learn='task2_resnet50_unfrozen_clean.pkl',
    path_test_images='data/task2/val_data',
    submission_name='task2.csv'
)