#### Code to train models

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
cd ../src

## Imports

In [None]:
import os
import cv2
import ast
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from collections import Counter
from tqdm.notebook import tqdm

In [None]:
from params import *

In [None]:
from data.preparation import prepare_dataframe, handle_duplicates, add_additional_boxes
from data.dataset import CovidDetDataset, CovidClsDataset
from data.transforms import get_transfos_det, get_transfos_cls

from model_zoo.models import get_model
from model_zoo.encoders import get_encoder

from utils.plot import plot_sample
from utils.boxes import treat_boxes
from utils.logger import prepare_log_folder, save_config, create_logger, update_overall_logs

from training.main import k_fold

## Data

### Load

In [None]:
# Build the base dataframe via the project helper; .copy() gives this notebook
# its own object to rebind and modify in the cells below.
df = prepare_dataframe().copy()

In [None]:
# Load precomputed arrays from ../output/ (they are produced elsewhere; their
# contents are not visible from this notebook).
# NOTE(review): allow_pickle=True unpickles arbitrary objects - only load files you trust.
clusts = np.load("../output/clusts.npy", allow_pickle=True)
# NOTE(review): `found` is not used by any later cell visible in this notebook.
found = np.load("../output/found.npy")
transpositions = np.load("../output/transpositions.npy", allow_pickle=True)

# Rebind df to the result of handle_duplicates; plot=False presumably disables
# its plotting - confirm in data.preparation.
df = handle_duplicates(df, clusts, transpositions, plot=False)

In [None]:
# Rebind df to the result of add_additional_boxes (helper from data.preparation).
df = add_additional_boxes(df)

In [None]:
# Two count plots of the same data: `label` counts split by `study_label`,
# then `study_label` counts split by `label`. Each gets its own figure.
for x_col, hue_col in (("label", "study_label"), ("study_label", "label")):
    plt.figure(figsize=(15, 5))
    sns.countplot(x=x_col, hue=hue_col, data=df)
    plt.show()

### Dataset

In [None]:
# Detection transforms with augmentation turned off, using the "yolo" box format.
# NOTE(review): the dataset built two cells below does not receive these
# (its `transforms=` argument is commented out).
transforms = get_transfos_det(augment=False, bbox_format="yolo")

In [None]:
# Keep only the rows whose largest `crop_starts` value exceeds 500.
# reset_index() (without drop=True) renumbers the rows from 0 and keeps the
# previous index as a column.
large_crop_start = df["crop_starts"].apply(lambda starts: np.max(starts) > 500)
df_ = df.loc[large_crop_start].reset_index()

In [None]:
# Detection dataset over the filtered rows, reading from the train_{SIZE}/ folder
# under DATA_PATH. The transforms argument is deliberately left commented out.
dataset = CovidDetDataset(df_, DATA_PATH + f"train_{SIZE}/", bbox_format="yolo") #, transforms=transforms)

In [None]:
# Preview 10 randomly drawn dataset indices (np.random.choice samples with
# replacement by default, so an index can repeat). Samples without boxes are
# skipped; the others are drawn with a title built from their targets.
for i in np.random.choice(len(dataset), 10):
    img, mask, y, y_img, boxes = dataset[i]

    if isinstance(img, torch.Tensor):
        # Move the first axis of the image last and give the mask a trailing axis.
        img, mask = (
            img.cpu().numpy().transpose(1, 2, 0),
            mask.cpu().numpy()[:, :, None],
        )

    if len(boxes) == 0:
        continue

    plt.figure(figsize=(9, 9))
    plot_sample(img, boxes, bbox_format="yolo")

    stem = df_["save_name"][i][:-4]  # name without its last 4 characters
    study_target = CLASSES[int(y)]
    img_target = CLASSES_IMG[int(y_img)]
    plt.title(f'{stem}  -  Study target : {study_target} - Img target : {img_target}')

## Model

In [None]:
# model = get_encoder('tf_efficientnet_b4_ns')

In [None]:
# Build the model by name with 4 output classes requested.
# NOTE(review): Config below selects a different model name and uses
# len(CLASSES) for num_classes - this cell is only for the manual checks that follow.
model = get_model('tf_efficientnetv2_m_in21ft1k', num_classes=4)

In [None]:
# Classification dataset with augment=False; take its first sample.
transforms = get_transfos_cls(augment=False)
dataset = CovidClsDataset(df, DATA_PATH + f"train_{SIZE}/", transforms=transforms)

x, m, y, y_img = dataset[0]

# Show the image (first axis moved last) and the mask (with a trailing axis
# added) side by side, both with their axes hidden.
plt.figure(figsize=(9, 9))
panels = (
    x.cpu().numpy().transpose(1, 2, 0),
    m.cpu().numpy()[:, :, None],
)
for position, panel in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(panel)
    plt.axis(False)

# Add a leading dimension of size 1 to every tensor so the sample can be fed
# as a batch; image and mask are also cast to float.
x, m = x.unsqueeze(0).float(), m.unsqueeze(0).float()
y, y_img = y.unsqueeze(0), y_img.unsqueeze(0)

In [None]:
# Display the encoder's `nb_fts` attribute as the cell output.
model.encoder.nb_fts

In [None]:
# Forward pass on the single-sample batch `x`; the result is iterated and
# indexed (pred[0], pred[1], pred[2]) by the cells below.
pred = model(x)

In [None]:
# Print the size of every model output. An entry that has no .size() method
# is treated as an iterable of tensors and the size of each item is printed,
# indented, instead.
for p in pred:
    try:
        print(p.size())
    except AttributeError:
        # Only the "no .size()" case is handled here; anything else
        # (including KeyboardInterrupt) is no longer silently swallowed.
        for p_ in p:
            print(' ', p_.size())

In [None]:
from training.losses import CovidLoss
# Instantiate the loss with its default arguments.
loss = CovidLoss()

In [None]:
# Display the target `y` next to the first model output as the cell output.
y, pred[0]

In [None]:
# NOTE(review): `loss` is already created in an earlier cell; re-creating it
# here looks redundant.
loss = CovidLoss()
# Evaluate the loss once on the single-sample batch. The meaning of each
# argument is defined by CovidLoss (not shown here); `y` is passed twice as
# [y, y] and the last positional value is the literal 1.
loss(pred[0], pred[1], pred[2], [y, y], y_img, m, 1)

## Training

In [None]:
# Training batch size for each supported model name. Config looks up its
# `selected_model` here, so every selectable model needs an entry.
# NOTE(review): the v2 "s" model is given a smaller batch than the v2 "m"
# model (8 vs 12) - confirm this is intended.
BATCH_SIZES = {
    "resnext50_32x4d": 16,
    "tf_efficientnetv2_s_in21ft1k": 8,
    "tf_efficientnetv2_m_in21ft1k": 12,
    "tf_efficientnet_b2_ns": 32,
    "tf_efficientnet_b3_ns": 16,
    "tf_efficientnet_b4_ns": 12,
    "tf_efficientnet_b5_ns": 8,
}

In [None]:
class Config:
    """
    Parameters used for training
    """
    # NOTE(review): the class itself (not an instance) is handed to save_config
    # and k_fold in the final cell; how each attribute is consumed is defined
    # in those helpers, which are not shown here.

    # General
    seed = 42
    verbose = 1

    size = SIZE
    bbox_format = "yolo"
    root_dir = DATA_PATH + f"train_{SIZE}/"  # same folder the dataset cells above read from

    device = "cuda" if torch.cuda.is_available() else "cpu"  # GPU when CUDA is available
    save_weights = True

    # k-fold
    k = 5
    folds_col = "kfold"  # presumably the dataframe column holding fold ids - confirm in k_fold
    selected_folds = [0, 1, 2, 3, 4]

    # Model
    selected_model = 'tf_efficientnetv2_s_in21ft1k'  # must be a key of BATCH_SIZES (looked up below)
    use_unet = False
    num_classes = len(CLASSES)

    # Training
    use_fp16 = False
    samples_per_patient = 1
    # NOTE(review): optimizer and warmup_prop hold 3 entries while epochs and lr
    # hold 1 (their remaining entries are commented out) - confirm that k_fold
    # tolerates lists of different lengths.
    optimizer = ["Adam", "RAdam", "Adam"]
    batch_size = BATCH_SIZES[selected_model]
    epochs = [10]  # , 5, 5]

    lr = [1e-3]  # , 1e-4, 1e-5]
    warmup_prop = [0.05, 0.5, 0.5]
    val_bs = batch_size * 2  # twice the training batch size

    first_epoch_eval = 0

    mix = "cutmix"
    mix_proba = 0.5
    mix_alpha = 0.4

    name = "model"

In [None]:
# With DEBUG set, the next cell skips the log-folder setup, so log_folder
# stays None when it is passed to k_fold.
DEBUG = True
log_folder = None

In [None]:
# Outside debug mode: create a log folder, then save the config, the dataframe
# and set up a log file inside it.
if not DEBUG:
    log_folder = prepare_log_folder(LOG_PATH)
    print(f'Logging results to {log_folder}')
    # NOTE(review): paths are built by plain string concatenation, which assumes
    # log_folder already ends with a path separator - confirm in prepare_log_folder.
    save_config(Config, log_folder + 'config')
    df.to_csv(log_folder + 'data.csv', index=False)
    create_logger(directory=log_folder, name="logs.txt")

# Run k_fold with the Config class and the full dataframe; log_folder is None
# in debug mode. The two returned values are presumably out-of-fold predictions
# at study and image level - confirm in training.main.
pred_oof_study, pred_oof_img = k_fold(
    Config,
    df,
    df_extra=None,
    log_folder=log_folder
)