**TODO:**
- Recheck augmentations
- Recheck LAB normalization
- Add a sampler for faster convergence?

In [1]:
# Notebook setup: reload project modules (added to sys.path from ../code/ in
# the next cell) automatically, so edits to them are picked up without a
# kernel restart. nb_black formatting is disabled.
# %load_ext nb_black
%load_ext autoreload
%autoreload 2

## Initialization

### Imports

In [2]:
import os
import sys
import torch
import zipfile
import numpy as np
import pandas as pd
import plotly.express as px

from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

# Make the project packages (params, data, model_zoo, training, utils)
# importable from this notebook.
sys.path.append("../code/")

In [3]:
# NOTE(review): star import — IMG_PATH, MASK_PATH and LOG_PATH used in later
# cells are not imported explicitly anywhere else, so they presumably come
# from here. Prefer explicit names once the notebook's needs are known.
from params import *

from data.transforms import HE_preprocess
from data.dataset import TileDataset

from model_zoo.models import define_model

from training.main import k_fold
from utils.logger import (
    prepare_log_folder,
    save_config,
    create_logger,
    update_overall_logs,
)

from utils.plots import plot_contours

# NOTE(review): redundant with the star import above; kept for explicitness.
from params import DATA_PATH, OUT_PATH



### Load

In [4]:
# Load the dataset information table, the training mask annotations and the
# image table written to OUT_PATH (presumably by a preprocessing step).
# os.path.join works whether or not the base paths end with a separator.
df_info = pd.read_csv(os.path.join(DATA_PATH, "HuBMAP-20-dataset_information.csv"))
df_mask = pd.read_csv(os.path.join(DATA_PATH, "train.csv"))
df = pd.read_csv(os.path.join(OUT_PATH, "df_images.csv"))

## Model

## Training

In [5]:
# Training batch size per encoder architecture (looked up by Config.batch_size).
# Heavier encoders get smaller batches, presumably to fit in GPU memory.
BATCH_SIZES = dict(
    [
        ("resnet18", 64),
        ("resnet34", 32),
        ("resnext50_32x4d", 32),
        ("se_resnext50_32x4d", 32),
        ("efficientnet-b4", 32),
        ("efficientnet-b5", 16),
        ("efficientnet-b6", 8),
    ]
)

In [6]:
class Config:
    """
    Parameters used for training.

    Settings are class attributes: the class itself (not an instance) is
    passed to ``save_config`` and ``k_fold``. Some attributes are derived at
    class-definition time (activation, batch_size, epochs, val_bs).
    """

    # General
    seed = 42
    verbose = 1
    img_dir = IMG_PATH
    mask_dir = MASK_PATH
    device = "cuda" if torch.cuda.is_available() else "cpu"
    save_weights = True
    iter_per_epoch = 5000  # 10000

    # Image size
    train_tile_size = 256
    reduce_factor = 4
    on_spot_sampling = 0.9

    # k-fold
    cv_column = "5fold"
    random_state = 0
    selected_folds = [0, 1, 2, 3, 4]  # [0]

    # Model
    encoder = "efficientnet-b5"  # "resnet18" "resnext50_32x4d", "resnet34", "efficientnet-b5"
    decoder = "Unet"  # "Unet", "DeepLabV3Plus"
    encoder_weights = "imagenet"
    num_classes = 1

    # Training
    loss = "BCEWithLogitsLoss"  # "SoftDiceLoss" / "BCEWithLogitsLoss"  / "lovasz"
    # No output activation with the lovasz loss, sigmoid otherwise.
    activation = "none" if loss == "lovasz" else "sigmoid"

    optimizer = "Adam"
    batch_size = BATCH_SIZES[encoder]

    # Epoch budget depends on the encoder's base batch size. Every value in
    # BATCH_SIZES must be handled here; previously batch size 64 (resnet18)
    # fell through and `epochs` was never defined.
    if batch_size == 64:
        # NOTE(review): extrapolated from the 8/16/32 -> 20/30/40 pattern;
        # also equals swa_first_epoch below — tune/confirm.
        epochs = 50
    elif batch_size == 32:
        epochs = 40
    elif batch_size == 16:
        epochs = 30
    elif batch_size == 8:
        epochs = 20
    else:
        # Fail at definition time instead of with an AttributeError in training.
        raise ValueError(
            f"No epoch budget defined for batch size {batch_size} (encoder {encoder!r})"
        )

    # 512px tiles hold 4x the pixels of 256px tiles, so shrink the batch by 4
    # (presumably to keep memory use comparable). This runs after the epoch
    # choice, so epochs still follow the encoder's base batch size.
    if train_tile_size == 512:
        batch_size = max(1, batch_size // 4)

    lr = 1e-3
    swa_first_epoch = 50

    warmup_prop = 0.05
    val_bs = batch_size * 2  # derived from the (possibly reduced) batch size

    first_epoch_eval = 0

    # Inference
    overlap_factor = 2

In [7]:
# Set DEBUG to True to skip the log-folder setup in the next cell; training
# then runs with log_folder left as None.
DEBUG, log_folder = False, None

In [8]:
# Unless debugging, create a fresh log folder and save the config, the image
# table and a text log there before training starts.
if not DEBUG:
    log_folder = prepare_log_folder(LOG_PATH)
    print(f"Logging results to {log_folder}")
    # log_folder ends with "/" (see the printed output), hence plain `+` joins.
    # NOTE(review): config_df is not used in the visible cells.
    config_df = save_config(Config, log_folder + "config.json")
    df.to_csv(log_folder + "data.csv", index=False)
    create_logger(directory=log_folder, name="logs.txt")

# Run k-fold training (presumably over Config.selected_folds); log_folder is
# None in DEBUG mode.
metrics = k_fold(Config, df, log_folder=log_folder)

Logging results to ../logs/2021-03-28/16/
Creating in memory dataset (once)...
Took 257.946297 seconds.

-------------   Fold 1 / 5  -------------

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth" to /root/.cache/torch/checkpoints/efficientnet-b5-b6417697.pth


HBox(children=(FloatProgress(value=0.0, max=122410125.0), HTML(value='')))


    -> 31216001 trainable parameters

train tiles : ['8242609fa', 'aaa6a05cc', 'cb2d976f4', 'b9a3865fc', '0486052bb', 'e79de561c', '095bf7a1f', '54f2eec69', '26dc41664', 'c68fe75ea', 'afa5e8098', '1e2425f28']
valid tiles : ['2f6ecfcdf', 'b2dc8411c', '4ef6695ce']
Epoch 01/30 	 lr=6.7e-04	 t=126s	loss=0.531	val_loss=0.114 	 dice=0.7727
Epoch 02/30 	 lr=9.8e-04	 t=127s	loss=0.052	val_loss=0.028 	 dice=0.8462
Epoch 03/30 	 lr=9.5e-04	 t=128s	loss=0.027	val_loss=0.014 	 dice=0.9290
Epoch 04/30 	 lr=9.1e-04	 t=128s	loss=0.021	val_loss=0.011 	 dice=0.9170
Epoch 05/30 	 lr=8.8e-04	 t=128s	loss=0.019	val_loss=0.011 	 dice=0.9241
Epoch 06/30 	 lr=8.4e-04	 t=128s	loss=0.018	val_loss=0.010 	 dice=0.9359
Epoch 07/30 	 lr=8.1e-04	 t=128s	loss=0.016	val_loss=0.011 	 dice=0.9312
Epoch 08/30 	 lr=7.7e-04	 t=128s	loss=0.016	val_loss=0.012 	 dice=0.9251
Epoch 09/30 	 lr=7.4e-04	 t=128s	loss=0.015	val_loss=0.010 	 dice=0.9352
Epoch 10/30 	 lr=7.0e-04	 t=128s	loss=0.014	val_loss=0.014 	 dice=0.9261
Epoch 