## Info
This notebook is designed for model training. The experimental configurations are loaded from the `config.yaml` file, and the training setup is initialized accordingly.

During the training process:
- Metrics are logged using `TensorboardLogger` and saved under the specified `output_path`.
- The training configuration file (`config.yaml`) is copied to the output directory for reference.
- Model weights are also saved in the output directory after training.

The setup ensures that key information related to training is easily accessible and logged for future analysis and model comparison.

In [None]:
import datetime
import os
import random
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as tt

from data import ClassificationDataset
from data import Transforms as T
from models import Evaluator, GMIC, GMICLoss, Trainer
from visualization import TensorboardLogger
from utils import Config

### Experiment configuration

In [None]:
# Seed both torch and Python's `random` module with the same fixed value so that
# repeated runs draw the same random numbers.
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)

# Today's date written as YYYY_MM_DD; it becomes the prefix of the experiment folder name.
date = datetime.date.today().isoformat().replace('-', '_')

# Folder that receives the experiment logs and outputs, plus its weights subfolder.
output_path = '../../models/Experiment1/{}_HospitalA'.format(date)
weight_path = os.path.join(output_path, "weights/")

In [None]:
# Load the experiment configuration from `config.yaml` into a `Config` object.
# Parameters can be read with dot notation, for example:
# print(cfg.data.batch_size)
# > 8
# Later cells also write entries with item access (e.g. cfg['data'][...]).

cfg_path = 'config.yaml'
cfg = Config(cfg_path)

### Dataset

In [None]:
# Transformation pipelines for the training and validation splits.
#   'dicom'   -> custom transformations developed for mammography images (`data.Transforms`)
#   'pytorch' -> transformations taken from torchvision
# DICOM preprocessing steps that are currently switched off for both splits:
#   T.FlipToLeft(), T.CropBreastRegion(),
#   T.Resize(height=cfg.data.inp_height, width=cfg.data.inp_width)
crop_size = (cfg.data.inp_height, cfg.data.inp_width)

transform_train = {
    'dicom': [
        T.UIntToFloat32(),
        T.StandardScoreNormalization(),
        T.RandomGaussianNoise(mean=.0, std=.005),
    ],
    'pytorch': tt.Compose([
        tt.RandomHorizontalFlip(p=0.5),
        tt.RandomRotation([-15, +15]),
        tt.RandomAffine(degrees=0, translate=(0, 0.1), shear=(-25, +25)),
        tt.RandomResizedCrop(crop_size, scale=(0.8, 1.6)),
    ]),
}

# Validation only converts and normalizes the images; no torchvision augmentation is applied.
transform_val = {
    'dicom': [
        T.UIntToFloat32(),
        T.StandardScoreNormalization(),
    ],
    'pytorch': None,
}

# Record the applied transformations in the config so they are saved with the experiment.
train_dicom_names = [str(step) for step in transform_train['dicom']]
train_torch_names = [str(step) for step in transform_train['pytorch'].transforms]
val_dicom_names = [str(step) for step in transform_val['dicom']]
cfg['data']['transforms'] = (f"training={train_dicom_names}"
                             ' * ' f"{train_torch_names}"
                             ' | ' f"validation={val_dicom_names}")

# Build the dataset objects with the `ClassificationDataset` class.
# Per the notebook author's notes, each item yields breast_id, image, label and
# optionally a domain_label, and `print(dataset.metadata)` shows the dataset metadata.
train = ClassificationDataset(metadata_path=cfg.data.train_xlsx_path,
                              transform=transform_train)
val = ClassificationDataset(metadata_path=cfg.data.val_xlsx_path,
                            transform=transform_val)

### Dataloader

In [None]:
# Define dataloaders for the training and validation sets.
# Training batches are reshuffled every epoch. The validation set is read in a
# fixed order: shuffling it brings no benefit and would make every evaluation
# pass see the samples in a different order.
train_loader = DataLoader(train, batch_size=cfg.data.batch_size,
                          shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val, batch_size=cfg.data.batch_size,
                        shuffle=False, num_workers=4, pin_memory=True)

### Model

In [None]:
# Create the TensorBoard logger used to record the experiment outputs.
# It is constructed with `output_path`, the experiment folder defined above.
tb = TensorboardLogger(output_path)

In [None]:
# Initialize the GMIC model from the `gmic_parameters` section of the config.
model = GMIC(cfg.gmic_parameters)

In [None]:
# If pretrained model weights are defined in the config, load them.
if cfg.model.weight_path:
    # NOTE(security): torch.load unpickles the file, so `cfg.model.weight_path`
    # must only point at checkpoints from a trusted source.
    weights = torch.load(cfg.model.weight_path, map_location=torch.device('cpu'))
    # strict=False lets loading continue when the checkpoint and the model do not
    # have exactly the same keys (the 'shared_rep_filter' key is skipped on purpose).
    load_result = model.load_state_dict(weights, strict=False)
    print('Model weights are loaded!')
    # Because strict=False hides every mismatch, report them explicitly so that a
    # wrong or partially matching checkpoint does not go unnoticed.
    if load_result.missing_keys:
        print('Model keys not found in the checkpoint: {}'.format(load_result.missing_keys))
    if load_result.unexpected_keys:
        print('Checkpoint keys not used by the model: {}'.format(load_result.unexpected_keys))
# Send the model to the device specified in `config.yaml`, either 'cuda' or 'cpu'.
model = model.to(cfg.gmic_parameters.device_type)

# Log the model architecture to TensorBoard's graph.
tb.add_graph(model, (1, cfg.data.inp_height, cfg.data.inp_width), cfg.gmic_parameters.device_type)
tb.flush()

### Training Setup

In [None]:
# Loss function: GMICLoss, with `beta` read from the training section of the config.
criterion = GMICLoss(beta=cfg.train.beta)

# Optimizer: Adam over all model parameters, with the learning rate from the config
# and a fixed weight decay.
optimizer = optim.Adam(
    model.parameters(),
    lr=cfg.train.lr,
    weight_decay=1e-3,
)

# Store the string form of both objects in the config so that the saved
# experiment records which loss and optimizer settings were used.
cfg['train']['LossFunction'], cfg['train']['Optimizer'] = str(criterion), str(optimizer)

In [None]:
# Trainer: wraps the training strategy and loop. It receives the model, loss,
# optimizer, training dataloader, and the total number of epochs from the config.
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    data_loader=train_loader,
    total_epochs=cfg.train.epoch,
)

# Evaluator: wraps the evaluation strategy and loop. It runs the same model
# object over the validation dataloader.
evaluator = Evaluator(
    model=model,
    data_loader=val_loader,
)

### Training

In [None]:
# Create the weight folder (and the output folder above it) if they do not exist.
# `exist_ok=True` makes this a single call that is safe to re-run, with no
# separate existence check.
os.makedirs(weight_path, exist_ok=True)

# Save the `config.yaml` file into the experiment folder.
# This file can be used for evaluating and investigating the experiment later.
cfg.save(os.path.join(output_path, 'config.yaml'))

# `cfg.to_markdown()` renders the configuration as markdown, which is added to TensorBoard.
# You can view it under the Text tab in TensorBoard.
tb.add_text('HyperParameters', cfg.to_markdown())
tb.flush()

In [None]:
# Best validation PR AUC (Precision-Recall AUC) seen so far; starts at zero so
# that the first positive score is recorded as the best one.
prev_pr_auc = 0.0

In [None]:
# This is the highest level of the training loop, defining the number of epochs.
# At each epoch, the model is trained for one pass, and the trained model is evaluated.
# Collected metrics are logged to TensorBoard, and the model weights are saved to the output folder.

# Evaluate on the validation set every `eval_every` epochs (1 = after every epoch).
eval_every = 1

for epoch in range(0, cfg.train.epoch):
    # Explicitly assign the current epoch to the trainer object.
    # This is used to monitor the progress bar and is also helpful for schedulers
    # that need to track the current epoch during training.
    trainer.curr_epoch = epoch
    # Read the learning rate before calling `fit()`: a scheduler may update it
    # during training, and the logged value should be the one used for this epoch.
    curr_lr = optimizer.param_groups[0]['lr']
    train_metrics = trainer.fit()

    # Add the training results to TensorBoard (steps are 1-based).
    tb.add_scalars(step=epoch+1, lr=curr_lr, train_loss=train_metrics['total_loss'],
                   roc_auc=train_metrics['roc']['auc'], pr_auc=train_metrics['pr']['auc'],
                   data_split='Train')
    tb.flush()

    # Evaluate the model on the validation set at the configured frequency.
    if epoch % eval_every == 0:
        val_metrics = evaluator.evaluate()
        tb.add_scalars(step=epoch+1, roc_auc=val_metrics['roc']['auc'],
                       pr_auc=val_metrics['pr']['auc'], data_split='Val')
        tb.flush()

        # The best-model check lives inside the evaluation branch so that it only
        # ever looks at metrics computed in this epoch, never at stale ones.
        # If the current model yields a better PR-AUC, save it as 'best_model.pth'.
        if val_metrics['pr']['auc'] > prev_pr_auc:
            prev_pr_auc = val_metrics['pr']['auc']
            trainer.save_model(os.path.join(weight_path, 'best_model.pth'))

    # Save the currently trained model as 'last_model'.
    # If any issue occurs during training: set 'prev_pr_auc' to the last PR-AUC score,
    # start 'range(0, cfg.train.epoch)' from the interrupted epoch, point
    # 'cfg.model.weight_path' at the last saved model, and rerun the notebook.
    trainer.save_model(os.path.join(weight_path, 'last_model.pth'))

    # Save the model as a checkpoint every 10 epochs.
    # The file name uses the 0-based epoch index (0, 10, 20, ...).
    if epoch % 10 == 0:
        trainer.save_model(os.path.join(weight_path, f'{epoch}_model.pth'))

# Save the last model too.
trainer.save_model(os.path.join(weight_path, 'last_model.pth'))
tb.close()

In [None]:
# Ask PyTorch to release its cached GPU memory now that training has ended.
# Note that this may not release all GPU memory.
torch.cuda.empty_cache()