# Imports

In [None]:
import os
import cv2
import stat
import dill
import wandb
import shutil
import tempfile
import pandas as pd

from pathlib import Path
from fastai.torch_core import set_seed
from fastai.learner import Learner

import project_config as pc
import evaluation_config as ec
from training_config import config as tc

# Turn off OpenCV's internal thread pool (0 threads = run sequentially).
# NOTE(review): presumably done to avoid thread oversubscription with DataLoader workers — confirm.
cv2.setNumThreads(0)

# Retrieve dataset

In [None]:
# Init run
# NOTE: wandb copies `tc` into the run config here; later changes to `tc` attributes are not
# reflected in the logged config unless pushed explicitly via `run.config.update`.
run = wandb.init(project=pc.WANDB_PROJECT,
                 entity=pc.WANDB_ENTITY,
                 dir=pc.WANDB_LOCAL_LOGS_PATH,
                 job_type='model_training',
                 config=tc)

# Download latest dataset version (if not already downloaded)
dataset_artifact = run.use_artifact(f'{pc.DATASET_ARTIFACT_NAME}:latest')
# The last component of wandb's (private) default download root identifies the resolved
# artifact version, presumably "<name>:vN" — confirm. Use `.name`, not `.stem`: `.stem` would
# cut an artifact name containing a dot at its last dot. `os.path.join` keeps the result
# correct whether or not the configured base path ends with a separator.
# `dataset_dir` stays a str because later cells concatenate onto it with '+'.
dataset_dir = os.path.join(pc.WANDB_LOCAL_ARTIFACTS_PATH,
                           Path(dataset_artifact._default_root()).name)
if not os.path.exists(dataset_dir):
    dataset_artifact.download(root=dataset_dir)

In [None]:
# Load the dataset index and prefix every file path with the dataset directory.
path_prefix = dataset_dir + '/'
df = pd.read_csv(path_prefix + 'dataset.csv')
df['file_path'] = df['file_path'].map(lambda rel_path: path_prefix + rel_path)
df

# Data setup

In [None]:
# Seed the global RNGs before any data pipeline object is built, so the splits and the
# initial shuffling are repeatable. `reproducible=True` is forwarded to fastai's set_seed
# (presumably also requesting deterministic cuDNN behaviour — confirm).
set_seed(tc.SEED, reproducible=True)

# Datablock
# The DataBlock definition comes from training_config.
block = tc.DATABLOCK

# Dataloaders
# Built from the dataframe of the previous cell (its file_path column now holds full paths).
dls = block.dataloaders(df, bs=tc.BS, shuffle=True)
# Re-seed the RNG exposed as `dls.rng` (presumably the training DataLoader's shuffle RNG)
# so that batch order is repeatable as well.
dls.rng.seed(tc.SEED)

# Sanity check
# `dls.c` is the number of model outputs fastai infers from the data.
n_out = dls.c
print('Number of outputs: ', n_out)

In [None]:
# Show train batch
# Visual sanity check: up to 16 samples drawn from the training DataLoader.
dls.train.show_batch(max_n=16, figsize=(15,12))

In [None]:
# Show transforms
# `unique=True` makes fastai repeat a single training sample, so the differences between
# tiles come from the random transforms applied to it (see fastai `show_batch` docs).
dls.train.show_batch(max_n=16, unique=True, figsize=(15,12))

In [None]:
# Show valid batch
# Visual sanity check: up to 16 samples drawn from the validation DataLoader.
dls.valid.show_batch(max_n=16, figsize=(15,12))

# Model setup

In [None]:
from DLOlympus.fastai.imbalanced import get_class_weights, set_controlled_oversampling

# Class weights: a non-None entry of tc.CLASS_WEIGHTS_CONFIGS holds keyword arguments for
# get_class_weights, and the result is set on the loss function at the same index.
for loss_idx, weights_cfg in enumerate(tc.CLASS_WEIGHTS_CONFIGS):
    if weights_cfg is None:
        continue
    tc.LOSS.loss_functions[loss_idx].weights = get_class_weights(dls, **weights_cfg)

# Build the learner entirely from the training config and switch it to mixed precision.
learn = Learner(
    dls=dls,
    model=tc.MODEL,
    loss_func=tc.LOSS,
    opt_func=tc.OPTIMIZER,
    splitter=tc.SPLITTER,
    metrics=tc.METRICS,
    wd=tc.WD,
).to_fp16()

# Freeze the model when it starts from pretrained weights.
if tc.PRETRAINED:
    learn.freeze()

# Optional oversampling; the helper hands back the learner to keep using.
if tc.OVERSAMPLING_LABEL is not None:
    learn = set_controlled_oversampling(learn, col=tc.OVERSAMPLING_LABEL)

# Training

In [None]:
# Find LR
# Learning-rate range test; the value is then picked by hand and set in the next cell.
# NOTE(review): when tc.PRETRAINED is set, this runs on the frozen model, while training
# below calls learn.unfreeze() first — the suggestion may not transfer directly.
learn.lr_find()

In [10]:
# Set LR
# Chosen manually from the lr_find plot.
tc.LR = 1e-3
# wandb.init(config=tc) copied the config when the run started, so push the override to the
# run as well; otherwise the logged config keeps the stale value. allow_val_change is needed
# because LR is presumably already present in the run config.
run.config.update({'LR': tc.LR}, allow_val_change=True)

In [None]:
# Train
# Make every parameter group trainable (undoes the optional freeze from the model setup).
learn.unfreeze()
# One-cycle schedule with discriminative learning rates: slice(LR/100, LR) spreads the LR
# across the parameter groups produced by tc.SPLITTER (earliest group LR/100, last group LR).
# div=5.0 starts the cycle at LR/5; pct_start=0.3 spends 30% of the iterations warming up.
learn.fit_one_cycle(tc.EPOCHS, slice(tc.LR/100, tc.LR), pct_start=0.3, div=5.0, cbs=tc.CALLBACKS)

# Evaluation

In [None]:
# Create predictions dataframes and confusion matrices
# `dataset_dir + '/'` is passed to the project helper, presumably so it can strip that prefix
# from the full file paths (the table cell later re-prefixes file_path with dataset_dir) — confirm.
valid_preds = ec.create_predictions_df(learn, learn.dls.valid, dataset_dir+'/')
# NOTE(review): learn.dls.train is the shuffled, training-time DataLoader; whether these
# predictions include augmentation depends on how ec.create_predictions_df iterates it — confirm.
train_preds = ec.create_predictions_df(learn, learn.dls.train, dataset_dir+'/')
# Confusion matrices from validation predictions only. The three returned sequences are consumed
# in parallel below: figures saved with .savefig, wandb-loggable objects, and their names.
plt_cms, wandb_cms, names_cms = ec.create_confusion_matrices(valid_preds, dls.vocab)

# Logs

In [13]:
def remove_readonly(func, path, _):
    """Clear the readonly bit and reattempt the removal.

    Error callback for ``shutil.rmtree`` (``onexc=`` or ``onerror=``). It adds the owner-write
    permission to *path* and calls *func* (the operation that failed) on it again. The third
    argument (exception / exc_info) is ignored.

    The owner-write bit is OR-ed into the existing mode instead of replacing it. On POSIX,
    ``os.chmod(path, stat.S_IWRITE)`` would reset the mode to 0o200, and if the retry still fails
    the entry would be left unreadable. On Windows both forms clear the read-only attribute.
    Any error raised by the retry propagates to ``rmtree``'s caller.
    """
    os.chmod(path, os.stat(path).st_mode | stat.S_IWRITE)
    func(path)

In [None]:
# Log final metrics
# learn.validate() returns the validation loss first, then one value per metric in
# learn.metrics order, hence the [1:].
names = [m.name for m in learn.metrics]
values = learn.validate()[1:]
for n, v in zip(names, values):
    run.summary[n] = v

# Log model
# Learner.export saves to learn.path/fname, so logging and cleanup use that same location
# (it equals 'models/model.pkl' for the default learn.path of '.'). The folder is created up
# front because the export (torch.save) does not create missing directories.
model_file = learn.path/'models'/'model.pkl'
model_file.parent.mkdir(parents=True, exist_ok=True)
learn.export('models/model.pkl', pickle_module=dill)
run.log_model(str(model_file), 'model')
# NOTE: removes the whole models/ folder, including anything else saved there during training.
shutil.rmtree(model_file.parent, onexc=remove_readonly)

# Log wandb confusion matrices
for cm, n in zip(wandb_cms, names_cms):
    run.log({n: cm})

In [None]:
# Evaluation artifact: prediction CSVs, confusion-matrix images and a table of validation
# predictions shown next to their images.
evaluation_artifact = wandb.Artifact('evaluation', type='evaluation')

# Files go to a throwaway directory; everything is added to the artifact before the
# directory is removed at the end of the `with` block.
with tempfile.TemporaryDirectory() as temp_dir:
    csv_exports = (('valid_preds.csv', valid_preds), ('train_preds.csv', train_preds))
    for csv_name, preds_df in csv_exports:
        preds_df.to_csv(f'{temp_dir}/{csv_name}', index=False)
    for csv_name, _ in csv_exports:
        evaluation_artifact.add_file(f'{temp_dir}/{csv_name}', name=csv_name)
    for fig, cm_name in zip(plt_cms, names_cms):
        png_name = f'{cm_name}.png'
        fig.savefig(f'{temp_dir}/{png_name}', bbox_inches='tight')
        evaluation_artifact.add_file(f'{temp_dir}/{png_name}', name=png_name)

# One table row per validation prediction: the image (loaded from dataset_dir + file_path)
# followed by every column of valid_preds.
table = wandb.Table(columns=['image', *valid_preds.columns.values])
for _, pred_row in valid_preds.iterrows():
    image_path = dataset_dir + '/' + pred_row['file_path']
    table.add_data(wandb.Image(image_path), *pred_row.values)
evaluation_artifact.add(table, 'evaluation_table')

# Upload the artifact and close the run
run.log_artifact(evaluation_artifact)
run.finish()