# Train Model 

In [None]:
import os
import itertools
from functools import partial
from concurrent.futures import ProcessPoolExecutor
from typing import List
from typing_extensions import override

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, EarlyStopping, RichProgressBar
from lightning.pytorch.loggers import WandbLogger

import subprocess
import timm

from tqdm import tqdm

import monai as mn
from transforms.Transform4ClassifierBase import Transform4ClassifierBase
from models.ClassifierBase import Classifier

# Fixed seed passed to Lightning's seed_everything so random draws are repeatable between runs.
SEED = 5566
pl.seed_everything(SEED)
# 'medium' selects the lower-precision float32 matmul mode
# (see torch.set_float32_matmul_precision for the exact trade-off).
torch.set_float32_matmul_precision('medium')

In [None]:
def get_data_dict_part(df_part):
    "Important! Modify this function"
    
    BASE_PATH =  #edit
    IMG_PATH_COLUMN_NAME = # edit
    
    data_dict = list()
    for i in tqdm(range(len(df_part)), desc="Processing part"):
        row = df_part.iloc[i]

        data_dict.append({
            'img':f'{BASE_PATH}/'+row[f"{IMG_PATH_COLUMN_NAME}"],
            #"label": row[f"{LABEL_COLUMN_NAME}"],
            "paths": f'{BASE_PATH}/'+row[f"{IMG_PATH_COLUMN_NAME}"],
            ** {l: np.array([row[f'{l}']]) for l in LABEL_COLUMN_NAMES}
        })
    
    return data_dict

def get_data_dict(df, num_cores=2):
    """Build the sample dictionaries for ``df`` using worker processes.

    The rows are cut into ``num_cores`` contiguous chunks, each chunk is
    converted by ``get_data_dict_part`` in its own process, and the per-chunk
    lists are concatenated back in the original row order.

    Args:
        df: DataFrame with one row per image.
        num_cores: Number of chunks and of worker processes.

    Returns:
        A flat list with one dictionary per row of ``df``.
    """
    # Split row *positions* instead of the frame itself: np.array_split(df, n)
    # goes through DataFrame.swapaxes, which pandas has deprecated. The chunk
    # boundaries are identical.
    position_chunks = np.array_split(np.arange(len(df)), num_cores)
    parts = [df.iloc[chunk] for chunk in position_chunks]

    with ProcessPoolExecutor(num_cores) as executor:
        # Collect the results while the pool is still open so any worker
        # exception surfaces here.
        data_dicts = list(executor.map(get_data_dict_part, parts))

    return list(itertools.chain.from_iterable(data_dicts))

def split_data(df, group_column, n_splits):
    """Split ``df`` into train and validation frames without sharing groups.

    A random ``1 / n_splits`` fraction of the distinct values in
    ``group_column`` is drawn; every row belonging to a drawn value goes to
    the validation frame and all remaining rows go to the training frame.

    Args:
        df: Source DataFrame.
        group_column: Column whose values must not be split across the two
            frames (e.g. a patient identifier).
        n_splits: Denominator of the validation fraction.

    Returns:
        ``(df_train, df_val)``. In both frames ``group_column`` ends up as the
        first column because it is round-tripped through the index.
    """
    distinct_groups = df[group_column].drop_duplicates()
    held_out_groups = distinct_groups.sample(frac=1 / n_splits)

    # Index by group so whole groups can be selected or dropped by label.
    by_group = df.set_index(group_column)
    validation_rows = by_group.loc[held_out_groups, :].reset_index()
    training_rows = by_group.drop(index=held_out_groups).reset_index()

    return training_rows, validation_rows

### Set parameters

In [None]:
# IMPORTANT BEFORE PROCEEDING --> DO YOU WANT TO DELETE CACHE???
# When True, the next cell removes MONAI_CACHE_DIR before training.
DELETE_CACHE = False

# CSV file that is loaded into `df` with pd.read_csv.
INPUT = './Train.csv' # #edit 

# Model identifier handed to Classifier as TIMM_MODEL.
TIMM_MODEL = "hf-hub:timm/convnext_base.fb_in22k_ft_in1k"

# Label columns of the CSV; their count is used as the number of classes.
LABEL_COLUMN_NAMES = #edit ex: ['Pneumothorax','Cardiomegaly']

# PROJECT is the W&B project passed to WandbLogger. TEST_NAME is the run name and is
# also embedded in the cache dir, the weights dir, the validation CSV name and the
# checkpoint filename.
PROJECT =  #edit 
TEST_NAME =  #edit 
MONAI_CACHE_DIR = f'./cache/{TEST_NAME}' #edit 
# Passed to Transform4ClassifierBase.
IMG_SIZE = 256 #edit 
# Used for both DataLoaders and passed to Classifier.
BATCH_SIZE = 16 #edit 
# Passed to pl.Trainer(precision=...).
PRECISION = 'bf16-mixed' 
LEARNING_RATE = 1e-5 #edit 
# Upper bound on epochs (Trainer max_epochs); early stopping may end training sooner.
EPOCHS = 300 #edit 
# Used as the checkpoint directory, the W&B save_dir and the Trainer default_root_dir.
WEIGHT_PATH = f'./weights/{TEST_NAME}' 

# W&B entity passed to WandbLogger.
ENTITY =  #edit, wandb id 

# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid — confirm it is
# wanted for full training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# NOTE(review): set after torch is imported; presumably still effective as long as
# CUDA has not been initialised yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = '1' #edit 
# NOTE(review): avoid committing a real API key in this file.
os.environ['WANDB_API_KEY']= '' #edit
os.environ['WANDB_SILENT']='true'

In [None]:
import shutil

# Optionally wipe the MONAI cache directory so cached samples are rebuilt.
if DELETE_CACHE:
    if os.path.exists(MONAI_CACHE_DIR):
        # shutil.rmtree is portable and raises on failure, so the success
        # message is only printed when the directory was really removed
        # (the exit status of the former `rm -rf` subprocess was ignored).
        shutil.rmtree(MONAI_CACHE_DIR)
        print(f"MONAI's {MONAI_CACHE_DIR} cache directory removed successfully!")
    else:
        print(f"MONAI's {MONAI_CACHE_DIR} cache directory does not exist!")

### Read input file

In [None]:
# Load the CSV named by INPUT; the bare `df` below displays it as the cell output.
df = pd.read_csv(INPUT)
df

In [None]:
# Split train and val data

# Column that identifies a patient; rows sharing a value are kept in the same split.
PATIENT_ID_COLUMN = #edit ex:'empi_anon'

# n_splits=10 -> one tenth of the distinct patient ids is held out for validation.
train_df, val_df = split_data(df, n_splits=10, group_column=PATIENT_ID_COLUMN)

# Only the validation rows are written to disk; the training rows are not saved.
val_df.to_csv(f"val_data_{TEST_NAME}.csv", index=False)

print(len(train_df), len(val_df))

# Convert both frames into lists of per-sample dictionaries (default: 2 worker processes).
train_dict = get_data_dict(train_df)
val_dict = get_data_dict(val_df)

### Model setup

In [None]:
# Transform pipelines: the `.train` and `.val` attributes of the project transform class.
train_transforms = Transform4ClassifierBase(IMG_SIZE, LABEL_COLUMN_NAMES).train
val_transforms = Transform4ClassifierBase(IMG_SIZE, LABEL_COLUMN_NAMES).val

# Datasets: each split gets its own sub-directory of MONAI_CACHE_DIR as cache_dir.
train_ds = mn.data.PersistentDataset(
    data=train_dict,
    transform=train_transforms,
    cache_dir=f"{MONAI_CACHE_DIR}/train",
)
val_ds = mn.data.PersistentDataset(
    data=val_dict,
    transform=val_transforms,
    cache_dir=f"{MONAI_CACHE_DIR}/val",
)

# Data loaders: training is shuffled and drops the last incomplete batch,
# validation keeps every sample in order.
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    persistent_workers=True,
    drop_last=True,
)
val_dl = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
)

# Model: one output per label column.
model = Classifier(
    TIMM_MODEL=TIMM_MODEL,
    num_classes=len(LABEL_COLUMN_NAMES),
    LEARNING_RATE=LEARNING_RATE,
    BATCH_SIZE=BATCH_SIZE,
    use_ema=True,
)

In [None]:
# SPOT CHECK
# Show three randomly chosen training samples after the training transforms.
# A plain (non-persistent) Dataset is used for this check.
test_ds=mn.data.Dataset(data=train_dict, transform=train_transforms)

for _ in range(3):
    random_i = np.random.randint(0, len(test_ds))
    # A one-element slice is iterated, so the inner loop body runs for the sample at random_i.
    for data_ in test_ds[random_i:random_i+1]:

        print(f"{data_['paths']}")
        # Squeeze, rotate 90 degrees and flip vertically before plotting
        # (presumably to undo the loader's axis convention — TODO confirm).
        plt.imshow(np.flipud(np.rot90(np.squeeze(np.array(data_['img'])))), cmap='gray')
        plt.show()

### Define Callbacks

In [None]:
# Learning-rate logging with logging_interval="step".
lr_monitor = LearningRateMonitor(logging_interval="step")

# Keep only the single best checkpoint, judged by the lowest "valid_loss".
checkpoint_callback = ModelCheckpoint(
    dirpath=f"{WEIGHT_PATH}",
    filename=f'{TEST_NAME}_{{epoch}}_{{valid_loss:0.4F}}',
    monitor="valid_loss",
    mode="min",
    save_last=False,
    save_top_k=1,
)

# Early stopping on "valid_loss" with patience=5 and min_delta=0.00001.
early_stop_callback = EarlyStopping(
    monitor='valid_loss',
    min_delta=0.00001,
    patience=5,
    verbose=False,
    mode='min',
)

# Online Weights & Biases run; model files are not uploaded (log_model=False).
wandb_logger = WandbLogger(
    save_dir=f"{WEIGHT_PATH}",
    name=f'{TEST_NAME}',
    project=PROJECT,
    entity=ENTITY,
    offline=False,
    log_model=False,
    config={"Creator": "HITI"},
)

# csv_logger = CSVLogger("logs", name="demo", flush_logs_every_n_steps=10)

progress_bar = RichProgressBar()

### Training

In [None]:
# Build the Lightning trainer: single GPU, gradient clipping at 1.0,
# metrics logged every step, at most EPOCHS epochs.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=PRECISION,
    max_epochs=EPOCHS,
    gradient_clip_val=1.0,
    log_every_n_steps=1,
    logger=wandb_logger,
    callbacks=[progress_bar, lr_monitor, checkpoint_callback, early_stop_callback],
    default_root_dir=WEIGHT_PATH,
)

In [None]:
# train the model
# Runs training on train_dl with validation on val_dl, using the callbacks
# and logger configured on `trainer`.

trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)