# Training a Model on the ACDC Dataset

**Authors:**
* _Louis Lacroix_
* _Benjamin Ternot_

## I. Importing Libraries and Global Settings

In [None]:
import datetime
import gc
import os
from functools import partial

import psutil

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim import Adam

from data_manager.datamanager import DataLoader
from models.model import Unet
from models.modeltrainer import ModelPreprocessor, DiffusionModelTrainer, Diffusion
from utils.utils import VerboseLevel

In [None]:
# ---------------------------------------------------------------------------
# Preprocessing parameters (forwarded to ModelPreprocessor.preprocess_data)
# ---------------------------------------------------------------------------
IMAGE_SIZE = 64  # images are brought to target_shape=(IMAGE_SIZE, IMAGE_SIZE)
PADDING = 0.2
IMAGE_NAMES = ["ED_gt", "ES_gt"]  # ground-truth images to use (presumably end-diastole / end-systole)
LINK_GT_TO_DATA = False
KEEP_3D_CONSISTENCY = False
CHANNELS = 4  # U-Net channels; any value > 1 also enables create_channels_from_gt
RESCALE_OUTPUT_KEY = "rescaled_image_data"
MAX_ANGLE = 45  # presumably bounds of the rotation augmentation — see preprocess_data
NB_ROTATIONS = 7
# Verbosity levels:
#   VerboseLevel.NONE    -> no output
#   VerboseLevel.TQDM    -> tqdm progress bars
#   VerboseLevel.PRINT   -> printed information
#   VerboseLevel.DISPLAY -> displayed images
VERBOSE = VerboseLevel.DISPLAY

# ---------------------------------------------------------------------------
# Execution parameters
# ---------------------------------------------------------------------------
LIBERATE_MEMORY = True  # drop the data loader / preprocessor once preprocessing is done
CUDA_DEVICE = 0  # GPU index used when CUDA is available

# ---------------------------------------------------------------------------
# Training parameters
# ---------------------------------------------------------------------------
SAVE_MODEL = True  # NOTE(review): not consulted by the training call in this notebook
SAVE_INTERMEDIATE_MODELS = {"toggle": False, "frequency": 20}

# ---------------------------------------------------------------------------
# Model parameters
# ---------------------------------------------------------------------------
BATCH_SIZE = 16
EPOCHS = 100
T = 1000  # number of diffusion timesteps (passed to train() as `timesteps`)
DIM_MULTS = (1, 2, 4, 8)  # U-Net `dim_mults`

In [None]:
# Adapt the colour of texts and axes to the Jupyter theme:
# set DARK_BG = True when the notebook is displayed on a dark background.
DARK_BG = False

# All foreground elements share one colour. 'text.color' must follow the
# theme as well, otherwise free text (e.g. figure titles, annotations) stays
# black and becomes unreadable on a dark background.
FOREGROUND_COLOR = "white" if DARK_BG else "black"
for rc_key in (
    "text.color",
    "axes.labelcolor",
    "xtick.color",
    "ytick.color",
    "axes.titlecolor",
):
    plt.rcParams[rc_key] = FOREGROUND_COLOR

## II. Data Loading and Preprocessing

In [None]:
# The ACDC data lives in a 'database' directory that is a sibling of the
# current working directory.
root_data_folder = os.path.join(os.path.dirname(os.getcwd()), "database")

# Split name -> sub-folder (relative to root_data_folder) holding that split
data_sub_folders = dict(
    train="training",
    test="testing",
)

# Group label found in the data -> human-readable diagnostic class
group_map = dict(
    NOR="Healthy control",
    MINF="Myocardial infarction",
    DCM="Dilated cardiomyopathy",
    HCM="Hypertrophic cardiomyopathy",
    RV="Abnormal right ventricle",
)

In [None]:
# Load every split listed in data_sub_folders into one DataLoader.
# store=True presumably keeps the loaded data inside the loader — confirm in
# data_manager.datamanager.DataLoader.load_data.
data_loader = DataLoader(root_folder=root_data_folder)
for key, sub_folder in data_sub_folders.items():
    data_loader.load_data(
        sub_folder,
        name=key,
        store=True,
        verbose=VERBOSE,
    )

# Wrap the loaded data in the model preprocessor (training happens later)
model_preprocessor = ModelPreprocessor(data_loader, group_map)

In [None]:
# Run the preprocessing pipeline with the parameters defined at the top of
# the notebook; channels are built from the ground truth whenever the model
# uses more than one channel.
preprocessing_kwargs = {
    "target_shape": (IMAGE_SIZE, IMAGE_SIZE),
    "padding": PADDING,
    "image_names": IMAGE_NAMES,
    "link_gt_to_data": LINK_GT_TO_DATA,
    "keep_3d_consistency": KEEP_3D_CONSISTENCY,
    "create_channels_from_gt": CHANNELS > 1,
    "rescale_output_key": RESCALE_OUTPUT_KEY,
    "max_angle": MAX_ANGLE,
    "nb_rotations": NB_ROTATIONS,
    "verbose": VERBOSE,
}
preprocessed_data = model_preprocessor.preprocess_data(**preprocessing_kwargs)

In [None]:
# Free memory if requested: the raw data loader and the preprocessor are no
# longer needed once `preprocessed_data` exists.
def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MB (1 MB = 1e6 bytes)."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1e6

if LIBERATE_MEMORY:
    memory_before = get_memory_usage()
    # Remove the names with globals().pop instead of `del` so that re-running
    # this cell (after the objects are already gone) does not raise NameError.
    for _name in ("data_loader", "model_preprocessor"):
        globals().pop(_name, None)
    gc.collect()
    memory_after = get_memory_usage()
    if VERBOSE >= VerboseLevel.PRINT:
        # NOTE: RSS does not necessarily shrink after gc.collect(), since freed
        # memory may not be returned to the OS; the difference can be <= 0.
        print(f"Memory usage before: {memory_before/1000:.4f} GB"
              f"\nMemory usage after: {memory_after/1000:.4f} GB"
              f"\nMemory liberated: {memory_before - memory_after:.2f} MB"
        )

## III. Model Training

In [None]:
# Root folder for all outputs (training images and trained models): a
# 'resources' directory that is a sibling of the current working directory.
ROOT_RES_FOLDER = os.path.join(os.path.dirname(os.getcwd()), 'resources')

# Each run gets one timestamped name, shared by the images folder and the
# model folder so the two outputs of a run are easy to match.
current_datetime = datetime.datetime.now()
RUN_NAME = (
    f"{current_datetime.strftime('%Y-%m-%d-%H-%M')}_"
    + (f'{CHANNELS}-channels' if CHANNELS > 1 else '1-channel')
)

# Images produced during training; "{}" is filled in later with a file name
IMAGES_SAVE_FOLDER = os.path.join(ROOT_RES_FOLDER, "images", RUN_NAME, "training")
IMAGES_SAVE_PATH = os.path.join(IMAGES_SAVE_FOLDER, "{}")

# Trained weights ("{}" is filled in later to build "<...>_unet.pt") and the
# hyper-parameter file written next to them
MODEL_SAVE_FOLDER = os.path.join(ROOT_RES_FOLDER, "trained_models", RUN_NAME)
MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_FOLDER, "{}_unet.pt")
PARAMS_SAVE_PATH = os.path.join(MODEL_SAVE_FOLDER, "params.txt")

# Create the output folders (no error if they already exist)
os.makedirs(IMAGES_SAVE_FOLDER, exist_ok=True)
os.makedirs(MODEL_SAVE_FOLDER, exist_ok=True)

In [None]:
# Record the hyper-parameters of this run next to the model weights, one
# "NAME = value" line per entry, in the order listed below.
SAVE_PARAMS = dict(
    IMAGE_SIZE=IMAGE_SIZE,
    CHANNELS=CHANNELS,
    BATCH_SIZE=BATCH_SIZE,
    EPOCHS=EPOCHS,
    T=T,
    DIM_MULTS=DIM_MULTS,
)
params_lines = [f"{name} = {val}\n" for name, val in SAVE_PARAMS.items()]
with open(PARAMS_SAVE_PATH, "w") as params_file:
    params_file.writelines(params_lines)

if VERBOSE >= VerboseLevel.PRINT:
    print(f"Parameters saved to '{PARAMS_SAVE_PATH}'")

In [None]:
# Use the configured GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device(f"cuda:{CUDA_DEVICE}")
else:
    device = torch.device("cpu")

In [None]:
# Denoising network: a U-Net with time embedding, moved to the selected device.
# NOTE(review): `dim` is tied to IMAGE_SIZE; it is presumably the base feature
# width of the U-Net rather than a spatial size — confirm in models.model.Unet.
model = Unet(
    dim=IMAGE_SIZE,
    init_dim=None,
    out_dim=None,
    dim_mults=DIM_MULTS,
    channels=CHANNELS,
    with_time_emb=True,
    convnext_mult=2,
)
model = model.to(device)

# Loss and optimiser used by the trainer
criterion = nn.SmoothL1Loss()
optimizer = Adam(model.parameters(), lr=1e-4)

# The trainer holds out 10% of preprocessed_data for validation
model_trainer = DiffusionModelTrainer(
    data_set=preprocessed_data,
    val_split=0.1,
    model=model,
    batch_size=BATCH_SIZE,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    verbose=VERBOSE,
    image_filename=IMAGES_SAVE_PATH,
)

In [None]:
# Train the diffusion model over T timesteps, using
# Diffusion.cosine_beta_schedule (offset s=0.008) as the constants scheduler.
# NOTE(review): SAVE_MODEL is defined in the parameters cell but is not
# consulted here — MODEL_SAVE_PATH is always passed. Confirm whether train()
# should receive it conditionally.
constants_scheduler = partial(Diffusion.cosine_beta_schedule, s=0.008)
training_options = dict(
    epochs=EPOCHS,
    timesteps=T,
    constants_scheduler=constants_scheduler,
    save_model_path=MODEL_SAVE_PATH,
    save_images_path=IMAGES_SAVE_PATH,
    save_intermediate_models=SAVE_INTERMEDIATE_MODELS,
    verbose=VERBOSE,
)
losses_history = model_trainer.train(**training_options)