**About**: This notebook visualizes the augmented training/validation data and trains models.


In [None]:
# %load_ext nb_black  # optional: uncomment to auto-format cells with black (needs nb_black installed)
# Reload imported project modules automatically before executing cells, so
# edits to the source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Work from the source folder: the project imports below (params, utils, data,
# training, main_training) and the relative "configs/..." paths presumably
# assume ../src/ is the current directory.
cd ../src/

## Initialization

### Imports

In [None]:
import os
import torch
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
# Hide every UserWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
from params import *

from utils.plots import *
from utils.logger import prepare_log_folder, create_logger, save_config

from data.preparation import prepare_data
from data.dataset import SartoriusDataset
from data.transforms import define_pipelines

from training.main import k_fold
from main_training import BATCH_SIZES

## Data

In [None]:
# Load the sample table and build the named data pipelines used below
# ('val_viz', 'train_viz').
# NOTE(review): this path is hard-coded, so it does not follow Config.data_config,
# which switches to "configs/config_aug_semantic.py" for the "htc" model.
df = prepare_data()
pipelines = define_pipelines("configs/config_aug.py")

In [None]:
dataset = SartoriusDataset(df, pipelines['val_viz'], precompute_masks=False)

# Show randomly drawn samples after the validation pipeline, with their
# ground-truth masks and boxes. Increase the range to display more samples.
for _ in range(1):
    idx = np.random.choice(len(dataset))
    data = dataset[idx]

    img = data['img']
    boxes = data['gt_bboxes']

    plt.figure(figsize=(15, 15))
    plot_sample(img, data['gt_masks'], boxes, plotly=False)
    # `idx` is a position in the dataset, so look the sample id up by position.
    # A plain `df["sample_id"][idx]` is a *label* lookup and would pick the
    # wrong row (or raise KeyError) whenever `df` does not carry a default
    # RangeIndex. This assumes SartoriusDataset indexes df rows positionally;
    # TODO confirm.
    plt.title(df["sample_id"].iloc[idx])
    plt.axis(False)
    plt.show()


In [None]:
dataset = SartoriusDataset(df, pipelines["train_viz"], precompute_masks=False)

# Draw a 2x2 grid of random samples as produced by the training pipeline,
# showing only the masks. Raise the outer range to get several grids.
for _ in range(1):
    plt.figure(figsize=(15, 15))

    for panel in range(1, 5):
        plt.subplot(2, 2, panel)
        idx = np.random.choice(len(dataset))
        data = dataset[idx]
        plot_sample(data["img"], data["gt_masks"], plotly=False)
        plt.axis(False)

    plt.show()

## Training

In [None]:
class Config:
    """
    Parameters used for training.

    The class itself (not an instance) is handed to `save_config` and `k_fold`.
    Attributes are evaluated once, when this class body runs. Derived values
    (`data_config`, `model_config`, `weight_decay`, `batch_size`, `val_bs`,
    `epochs`) are therefore NOT recomputed if `name`, `encoder`, `optimizer`,
    `use_pl` or `use_extra_samples` are reassigned later. Edit them here and
    re-run the cell instead.
    """

    # General
    seed = 42
    verbose = 1
    first_epoch_eval = 5  # presumably the first epoch at which evaluation runs; confirm in training.main
    compute_val_loss = False
    verbose_eval = 5  # presumably the evaluation/logging period in epochs; confirm

    device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is visible
    save_weights = True

    # Images
    fix = True  # NOTE(review): meaning not visible here; check the data/training code
    remove_anomalies = True

    extra_name = "livecell_no_shsy5y"  # presumably only used when use_extra_samples is True; confirm
    use_extra_samples = False
    use_pl = True  # presumably "use pseudo-labels"; confirm

    num_classes = 3

    data_config = "configs/config_aug.py"  # overridden below for the "htc" model

    # k-fold
    split = "gkf"  # presumably GroupKFold; confirm in the splitting code
    k = 5
    random_state = 0
    selected_folds = [0]  # presumably restricts training to fold 0 of the k folds; confirm

    # Model
    name = "cascade"  # "cascade", "maskrcnn" or "htc"; must be a key of BATCH_SIZES
    encoder = "resnext101"
    model_config = f"configs/config_{name}.py"  # model config file is derived from `name`
    pretrained_livecell = True
    freeze_bn = True

    # The "htc" model needs a different data config
    # (the file name suggests semantic-segmentation targets; confirm).
    if name == "htc":
        data_config = "configs/config_aug_semantic.py"

    # Training
    optimizer = "AdamW"
    scheduler = "linear"
    weight_decay = 0.01 if optimizer == "AdamW" else 0  # weight decay only applies with AdamW
    # Batch size is looked up per (model name, encoder) in main_training.BATCH_SIZES.
    # A missing pair fails when this class is defined (KeyError if it is a plain dict).
    batch_size = BATCH_SIZES[name][encoder]
    val_bs = batch_size
    loss_decay = True

    # The epoch count grows linearly with the batch size, presumably to keep the
    # number of optimizer steps roughly constant across models (confirm).
    epochs = 10 * batch_size
    
    # Halve the schedule when pseudo-labels or extra samples are enabled,
    # presumably because each epoch then contains more samples.
    if use_pl or use_extra_samples:
        epochs = epochs // 2

    lr = 2e-4
    warmup_prop = 0.05

    use_fp16 = False  # TODO: presumably mixed-precision training is not wired up yet

In [None]:
# Set DEBUG = False to create a log folder, save the config and log to a file
# when training is launched.
DEBUG = True
log_folder = None  # stays None while DEBUG is True and is passed to k_fold as-is

In [None]:
# Outside debug mode, do the following before training:
#   - create a fresh log folder under LOG_PATH (presumably defined in `params`);
#   - save the config there;
#   - start file logging to logs.txt.
# The order matters: the folder must exist before saving, and the logger is
# set up last.
if not DEBUG:
    log_folder = prepare_log_folder(LOG_PATH)
    print(f"Logging results to {log_folder}")
    save_config(Config, log_folder)
    create_logger(directory=log_folder, name="logs.txt")

# Run k-fold training with the Config class and keep whatever k_fold returns.
# NOTE(review): Config.save_weights is True, so confirm k_fold skips saving
# weights when log_folder is None (debug mode).
results = k_fold(Config, log_folder=log_folder)