**About**: This notebook converts the LIVECell annotations into CSV metadata files and pretrains models on the LIVECell data.

In [None]:
# Optional: uncomment to auto-format cells with nb_black.
# %load_ext nb_black
# Automatically reload imported modules before each cell runs, so edits to the
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Work from the project's source folder so that `params`, `utils`, `data` and
# `training` are importable. Guarded so that re-running this cell does not keep
# moving one directory further up the tree (plain `cd ../src/` is not idempotent).
import os

if os.path.basename(os.path.normpath(os.getcwd())) != "src":
    os.chdir("../src/")
print(os.getcwd())

## Initialization

### Imports

In [None]:
import os
import gc
import ast
import cv2
import glob
import json
import torch
import warnings
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from pycocotools.coco import COCO

# Silence every UserWarning for the rest of the kernel session. The filter is
# process-wide, so it also hides UserWarnings raised inside third-party libraries.
warnings.simplefilter("ignore", UserWarning)

In [None]:
from params import *

from utils.plots import *
from utils.logger import prepare_log_folder, create_logger, save_config

from data.preparation import prepare_extra_data
from data.dataset import SartoriusDataset
from data.transforms import define_pipelines

from main_training import BATCH_SIZES
from training.main import pretrain

## Preprocessing

In [None]:
# Set to True to re-parse the raw LIVECell single-cell annotation files.
# When False the list is empty, so the preprocessing loop below processes
# nothing and no CSV is (re)written.
RECOMPUTE_ANNOTATIONS = False

annotations = (
    # Sorted so that the generated CSV has a deterministic row order.
    sorted(glob.glob(DATA_PATH + "LIVECell_dataset_2021/annotations/LIVECell_single_cells/*/*.json"))
    if RECOMPUTE_ANNOTATIONS
    else []
)

In [None]:
# Which subset of the LIVECell annotation files to export.
SHSY5Y_ONLY = False  # keep only the files whose path contains "shsy5y"
NO_SHSY5Y = True  # drop the files whose path contains "shsy5y"
SINGLE_CLASS = False  # give every exported cell the same label (requires NO_SHSY5Y)

# With both flags set, the elif below would silently ignore NO_SHSY5Y.
assert not (SHSY5Y_ONLY and NO_SHSY5Y), "SHSY5Y_ONLY and NO_SHSY5Y are mutually exclusive"

name = "livecell.csv"
classes = LIVECELL_CLASSES

if SHSY5Y_ONLY:
    annotations = [a for a in annotations if "shsy5y" in a]
    name = "livecell_shsy5y.csv"
elif NO_SHSY5Y:
    annotations = [a for a in annotations if "shsy5y" not in a]
    name = "livecell_no_shsy5y.csv"
    # Shift the class indices by 3 empty slots and drop the last LIVECell class.
    # NOTE(review): presumably the 3 leading slots are reserved for the competition's
    # own cell types and the dropped class is shsy5y - confirm against params.LIVECELL_CLASSES.
    classes = ['', '', ''] + LIVECELL_CLASSES[:-1]

if SINGLE_CLASS:
    assert NO_SHSY5Y, "SINGLE_CLASS is only supported together with NO_SHSY5Y"
    name = "livecell_no_shsy5y_single.csv"

In [None]:
# Convert every LIVECell COCO-style annotation file into one metadata record per
# image (boxes as [x0, y0, x1, y1], RLE masks, labels), then save them as a CSV.
metas = []
for path in tqdm(annotations):
    # File names are expected to look like "<prefix>_<cell_type>_<split>.json".
    filename = os.path.basename(path)
    _, cell_type, split = filename.split('.')[0].split('_')
    print(f"\n -> Processing {cell_type}_{split}")

    # Close the file handle as soon as the JSON has been parsed.
    with open(path, 'r') as annot_file:
        annots = json.load(annot_file)

    # These files store annotations as a dict; pycocotools expects a list.
    annots["annotations"] = list(annots["annotations"].values())
    coco = COCO()
    coco.dataset = annots
    coco.createIndex()

    cell_index = -1 if SHSY5Y_ONLY else classes.index(cell_type.lower())
    cell_index = 3 if SINGLE_CLASS else cell_index

    for image in annots['images']:
        # Use the per-image index built by createIndex() instead of scanning every
        # annotation of the file for every image (which was O(images * annotations)).
        # The annotation order is the same as in the file.
        image_annots = coco.loadAnns(coco.getAnnIds(imgIds=[image['id']]))

        boxes, rles = [], []
        for annot in image_annots:
            rles.append(coco.annToRLE(annot))

            # COCO boxes are [x, y, w, h]; convert to [x0, y0, x1, y1].
            box = np.array(annot["bbox"])
            box[2] += box[0]
            box[3] += box[1]
            boxes.append(box)

        meta = {
            'filename': image['file_name'],
            'width': image['width'],
            'height': image['height'],
            'cell_type': cell_type,
            'split': split,
            'ann': {
                'bboxes': np.array(boxes).astype(int).tolist(),
                'labels': [cell_index] * len(boxes),
                'masks': rles
            }
        }
        metas.append(meta)

# Only write (and plot) when something was processed, so a no-op run leaves the
# existing CSV untouched.
if len(metas):
    meta_df = pd.DataFrame.from_dict(metas)
    meta_df.to_csv(OUT_PATH + name, index=False)

    print(f' -> Saved to "{OUT_PATH + name}"')

    sns.countplot(x=meta_df['cell_type'])

## Data

In [None]:
class Config:
    """
    Parameters used for training

    In this section, only `data_config` is used, to build the data pipelines
    for visualizing the LIVECell samples below.
    NOTE: a second `Config` class defined in the Training section replaces
    (shadows) this one once that cell runs.
    """
    # Images
    use_mosaic = False  # pick the mosaic-augmentation data config when True
    use_tta = False  # TODO
    # Data pipeline config file, chosen according to `use_mosaic`.
    data_config = "configs/config_aug_mosaic.py" if use_mosaic else "configs/config_aug.py"

In [None]:
# Name of the preprocessed LIVECell subset to load. Presumably one of the CSV
# names written by the preprocessing cell without the ".csv" extension
# (e.g. "livecell", "livecell_shsy5y", "livecell_no_shsy5y") - confirm in
# data.preparation.prepare_extra_data.
LIVECELL_SUBSET = "livecell_shsy5y"
df = prepare_extra_data(name=LIVECELL_SUBSET)

In [None]:
# Build the data pipelines described by the data config file. The result is
# indexed by pipeline name ('val_viz' is used for the visualization dataset below).
pipelines = define_pipelines(Config.data_config)

In [None]:
# Dataset over the loaded LIVECell dataframe, using the 'val_viz' pipeline.
# precompute_masks=False: presumably masks are decoded per item instead of up
# front - confirm in data.dataset.SartoriusDataset.
dataset = SartoriusDataset(df, pipelines['val_viz'], precompute_masks=False)

In [None]:
# Plot a few random samples of `dataset` together with their ground-truth masks.
N_VIZ_SAMPLES = 1  # number of distinct random samples to plot
VIZ_SEED = 42  # fixed seed so the same samples are shown on every run
viz_rng = np.random.default_rng(VIZ_SEED)

# min() keeps choice(replace=False) valid for small or empty datasets.
n_viz = min(N_VIZ_SAMPLES, len(dataset))
for idx in viz_rng.choice(len(dataset), size=n_viz, replace=False):
    data = dataset[idx]

    plt.figure(figsize=(15, 15))
    plot_sample(data['img'], data['gt_masks'], plotly=False)
    # Positional lookup: assumes dataset item i corresponds to row i of df
    # (TODO confirm in SartoriusDataset); label-based `df[...][idx]` would break
    # on a non-default index.
    plt.title(df['cell_folder'].iloc[idx])
    plt.axis(False)
    plt.show()

## Training

In [None]:
class Config:
    """
    Parameters used for training

    Passed as-is to `save_config` and `pretrain` in the cells below.
    """
    # General
    seed = 42
    verbose = 1
    # NOTE(review): presumably the first epoch at which evaluation runs - confirm in training.main.
    first_epoch_eval = 10
    compute_val_loss = False
    verbose_eval = 5

    # Train on GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    save_weights = True

    # Images
    fix = True
    extra_name = ""
    use_extra_samples = False
    num_classes = 8

    use_mosaic = False
    use_pl = False  # NOTE(review): presumably "use pseudo-labels" - confirm in training.main
    # Data pipeline config file; may be overridden below for the "htc" model.
    data_config = "configs/config_aug_mosaic.py" if use_mosaic else "configs/config_aug.py"

    # k-fold
    k = 50
    random_state = 0
    selected_folds = [0]  # presumably only these folds out of `k` are trained - confirm

    # Model
    name = "cascade"  # other option handled below: "htc"
    encoder = "resnext50"
    # The model config file is derived from the model name.
    model_config = f"configs/config_{name}.py"
    pretrained_livecell = False

    # The HTC model uses its own data config file.
    if name == "htc":
        data_config = "configs/config_aug_semantic.py"

    # Training
    optimizer = "AdamW"  # alternative: "Adam"
    scheduler = "linear"
    weight_decay = 0.01  # alternative tried: 0
    # Batch size looked up per model name and encoder (main_training.BATCH_SIZES).
    batch_size = BATCH_SIZES[name][encoder]
    val_bs = batch_size
    # Freeze BatchNorm for very small batches (< 3 samples).
    freeze_bn = batch_size < 3
    loss_decay = False

    # NOTE(review): the epoch count is tied to the batch size (2 epochs per sample in a batch) - confirm intent.
    epochs = 2 * batch_size

    lr = 3e-4
    warmup_prop= 0.05  # presumably the fraction of training spent in LR warm-up - confirm

    use_fp16 = False  # TODO

In [None]:
# Root folder for pretraining logs. With DEBUG on, the next cell skips creating a
# log folder, saving the config and file logging, and `log_folder` stays None.
LOG_PATH = "../logs/pretrain/"
DEBUG, log_folder = False, None

In [None]:
def _setup_pretrain_logging(config, log_root):
    """Prepare a log folder under `log_root`, save `config` there and log to its logs.txt.

    Returns the folder produced by `prepare_log_folder`.
    """
    folder = prepare_log_folder(log_root)
    print(f"Logging results to {folder}\n")
    save_config(config, folder)
    create_logger(directory=folder, name="logs.txt")
    return folder


# In DEBUG mode the logging setup is skipped and `log_folder` keeps the value
# assigned in the previous cell.
if not DEBUG:
    log_folder = _setup_pretrain_logging(Config, LOG_PATH)

results = pretrain(Config, log_folder=log_folder)