# Model Training

In [None]:
import sys

# Add the parent directory (relative to the working directory) to the module
# search path; the `src.*` imports further down rely on it being importable.
_PARENT_DIR = '../'
sys.path.append(_PARENT_DIR)

In [None]:
import os
import torch
import numpy as np
import pandas as pd

from tqdm import tqdm

from torch.nn import BCEWithLogitsLoss
from torch.optim import AdamW
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter

from monai.data import DataLoader
from monai.networks.nets import DenseNet
from monai.utils import set_determinism

from src.data.dataset import BrainMriDataset
from src.data.transforms import Transforms
from src.utils.meter import AverageMetricsMeter
from src.utils.model import calculate_metrics, save_model
from src.utils.log import log_metrics

In [None]:
# Report the runtime environment before any heavy work starts.
print(f'PyTorch Version: {torch.__version__}')
print(f'Is CUDA Available: {torch.cuda.is_available()}')

In [None]:
# --- Compute device: use the GPU when one is visible, otherwise the CPU. ---
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# --- Filesystem locations (relative to the working directory). ---
DATASET_CSV = '../data/processed/dataset_nifti.csv'
OUTPUT_PATH = '../models/'
LOGS_PATH = '../logs/'

# --- Data loading. ---
NUM_WORKERS = 8
BATCH_SIZE = 16

# --- Optimisation schedule and AdamW hyper-parameters. ---
EPOCHS = 130
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

In [None]:
# Fail fast, with a readable message, when the dataset index is missing.
# An explicit raise is used instead of `assert` because assertions are
# stripped when Python runs with -O and would let the problem surface later.
if not os.path.exists(DATASET_CSV):
    raise FileNotFoundError(f'Dataset CSV not found: {DATASET_CSV}')

# Directory that receives the saved model checkpoints.
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [None]:
# Seed NumPy, PyTorch and MONAI with one shared value so runs are repeatable.
SEED = 1234

np.random.seed(SEED)
torch.random.manual_seed(SEED)
set_determinism(seed=SEED)

In [None]:
# Read the dataset index and partition its rows by the 'split' column.
dataset = pd.read_csv(DATASET_CSV)
train_data = dataset.loc[dataset['split'].eq('train')]
valid_data = dataset.loc[dataset['split'].eq('valid')]

In [None]:
# Each split gets its own transform object, built by the same factory call.
train_transform = Transforms.get_data_loading()
valid_transform = Transforms.get_data_loading()

# Wrap the per-split rows of the dataset index together with their transform.
train_dataset = BrainMriDataset(dataset_df=train_data, transform=train_transform)
valid_dataset = BrainMriDataset(dataset_df=valid_data, transform=valid_transform)

In [None]:
# Inverse-frequency weighting: every training row is weighted by
# 1 / (number of training rows sharing its diagnosis), so rarer diagnoses
# are drawn more often by a weighted sampler.
class_counts = train_data['diagnosis'].value_counts()
class_weights = 1.0 / class_counts

# `train_data` is a filtered frame and keeps the row labels of the full
# dataset. Export the weights as a plain float64 array so that converting
# them to a tensor is purely positional and cannot depend on those labels.
sample_weights = train_data['diagnosis'].map(class_weights).to_numpy(dtype=np.float64)

In [None]:
# Settings shared by both loaders. Workers can only be kept alive between
# epochs when there are workers at all: DataLoader rejects
# persistent_workers=True together with num_workers=0, so derive the flag
# from NUM_WORKERS instead of hard-coding it.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    persistent_workers=NUM_WORKERS > 0,
    pin_memory=True,
)

# Training batches are drawn with replacement according to the per-sample
# weights; one pass over the sampler yields as many indices as there are
# weights.
train_sampler = WeightedRandomSampler(
    sample_weights,
    len(sample_weights),
    replacement=True,
)

train_loader = DataLoader(
    dataset=train_dataset,
    sampler=train_sampler,
    **_loader_kwargs,
)

# Validation is iterated in dataset order.
valid_loader = DataLoader(
    dataset=valid_dataset,
    shuffle=False,
    **_loader_kwargs,
)

# Lookup used by the training loop to pick the loader for the current phase.
loaders = {
    'train': train_loader,
    'valid': valid_loader,
}

In [None]:
# Loss applied to the raw (pre-sigmoid) network output.
criterion = BCEWithLogitsLoss()

# 3D DenseNet taking one input channel and producing a single output value.
model = DenseNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    dropout_prob=0.2,
)
model = model.to(DEVICE)

# AdamW over all model parameters with the configured hyper-parameters.
optimizer = AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

In [None]:
# Main training loop.
#
# Every epoch runs a 'train' phase followed by a 'valid' phase. The same
# meter accumulates the metrics of the current phase, the phase averages are
# written to TensorBoard, and the model is saved whenever the validation loss
# improves on the best value seen so far.
best_valid_loss = float('inf')
meter = AverageMetricsMeter()

for epoch in range(EPOCHS):
    for mode in ['train', 'valid']:
        is_training = mode == 'train'

        # The meter is shared between phases, so start each one from scratch.
        meter.reset()
        # train(True) enables training behaviour, train(False) is eval().
        model.train(is_training)
        description = f'Epoch [{epoch}] in [{mode}]'

        for batch in tqdm(loaders[mode], description):
            # Gradients are only tracked while training.
            with torch.set_grad_enabled(is_training):

                # NOTE(review): autocast is used without a gradient scaler;
                # PyTorch's AMP guide pairs float16 autocast with GradScaler
                # to avoid gradient underflow. Confirm this is intentional.
                with torch.autocast(DEVICE):
                    y_true = batch['label'].to(DEVICE).float().unsqueeze(1)
                    y_pred = model(batch['image'].to(DEVICE))

                loss, performance = calculate_metrics(y_true, y_pred, criterion, DEVICE)

                # `batch` is a mapping ('label', 'image', ...), so len(batch)
                # would count its keys. Use the number of samples actually in
                # this batch (the last batch may be smaller).
                # Presumably the meter uses this as the averaging weight;
                # verify against AverageMetricsMeter.add.
                meter.add(loss.item(), performance, y_true.size(0))

                if is_training:
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

        with SummaryWriter(LOGS_PATH) as writer:
            log_metrics(writer, mode, epoch, meter.loss_value(), meter.performance_value())

    torch.cuda.empty_cache()

    # 'valid' is the last phase of the epoch, so the meter now holds the
    # validation statistics.
    valid_loss = meter.loss_value()
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        save_model(model, optimizer, OUTPUT_PATH)