# Training a model on the ACDC dataset

**Authors:**
* _Louis Lacroix_
* _Benjamin Ternot_

## I. Importing Libraries and Global Settings

In [1]:
import datetime
import gc
import os
import psutil

import matplotlib.pyplot as plt

from data_manager.datamanager import DataLoader
from models.modeltrainer import ModelPreprocessor
from utils.utils import VerboseLevel

In [2]:
# ----- Preprocessing parameters (all passed to `model_preprocessor.preprocess_data`) -----
IMAGE_SIZE = 128                   # target_shape is (IMAGE_SIZE, IMAGE_SIZE)
PADDING = 0.2                      # `padding` argument (presumably a fraction of the image size — TODO confirm)
IMAGE_NAMES = ["ED_gt", "ES_gt"]   # entries of each patient's `image_data` to preprocess
LINK_GT_TO_DATA = False
KEEP_3D_CONSISTENCY = False
MULTI_CHANNEL = True               # `create_channels_from_gt`; also picks the '4-channels' / '1-channel' run tag
RESCALE_OUTPUT_KEY = "rescaled_image_data"  # key under which the rescaled images are stored per patient
MAX_ANGLE = 45                     # rotation bound (presumably degrees — TODO confirm)
NB_ROTATIONS = 7                   # forwarded as `nb_rotations`

# Verbosity of the helpers:
#   VerboseLevel.NONE    -> no output
#   VerboseLevel.TQDM    -> tqdm progress bars
#   VerboseLevel.PRINT   -> printed information
#   VerboseLevel.DISPLAY -> displayed images
VERBOSE = VerboseLevel.PRINT

# ----- Execution parameters -----
LIBERATE_MEMORY = True             # drop the loader/preprocessor once preprocessing is done
CUDA_DEVICE = 0                    # presumably the GPU index used for training — TODO confirm

# ----- Training / saving parameters -----
SAVE_MODEL = True
SAVE_INTERMEDIATE_MODELS = dict(toggle=True, frequency=20)  # presumably a checkpoint every `frequency` epochs — TODO confirm

# ----- Model parameters -----
BATCH_SIZE = 16
EPOCHS = 100
T = 1000                           # presumably the number of diffusion time steps — TODO confirm
DIM_MULTS = (1, 2, 4, 8)           # presumably per-level channel multipliers of the network — TODO confirm

In [3]:
# Adapt text and axis colors to the Jupyter theme: white on a dark background, black otherwise.
DARK_BG = True

# Both themes set exactly the same five rcParams keys and differ only in the color,
# so the keys are listed once and all receive the selected color.
plt.rcParams.update(dict.fromkeys(
    ('text.color', 'axes.labelcolor', 'xtick.color', 'ytick.color', 'axes.titlecolor'),
    'white' if DARK_BG else 'black',
))

## II. Data Loading and Preprocessing

In [4]:
# The ACDC data lives in a `database` folder inside the parent of the current
# working directory, i.e. `<parent of cwd>/database`.
root_data_folder = os.path.join(os.path.split(os.getcwd())[0], "database")

# Split name -> sub-folder of `root_data_folder` holding that split
data_sub_folders = dict(
    train="training",
    test="testing",
)

# Group label found in the data -> human-readable diagnostic class
group_map = dict(
    NOR="Healthy control",
    MINF="Myocardial infarction",
    DCM="Dilated cardiomyopathy",
    HCM="Hypertrophic cardiomyopathy",
    RV="Abnormal right ventricle",
)

In [5]:
# Create the data loader
data_loader = DataLoader(root_folder=root_data_folder)

# Load the data
for key, sub_folder in data_sub_folders.items():
    data_loader.load_data(sub_folder, name=key, store=True, verbose=VERBOSE)

# Create the model trainer
model_preprocessor = ModelPreprocessor(data_loader, group_map)

Loading data in 'C:\Users\benji\Documents\Git-repositories\Telecom-Paris\3A\PRIM-AI-Diffusion-Models-for-Cardi…

Loading data in 'C:\Users\benji\Documents\Git-repositories\Telecom-Paris\3A\PRIM-AI-Diffusion-Models-for-Cardi…

In [6]:
# Run the full preprocessing pipeline with the parameters from the configuration cell
# (the recorded output shows it transforming, extracting, then rotating the images).
preprocessed_data = model_preprocessor.preprocess_data(
    # geometry
    target_shape=(IMAGE_SIZE,) * 2,
    padding=PADDING,
    # which images are used and how they relate to each other
    image_names=IMAGE_NAMES,
    link_gt_to_data=LINK_GT_TO_DATA,
    keep_3d_consistency=KEEP_3D_CONSISTENCY,
    # output layout
    create_channels_from_gt=MULTI_CHANNEL,
    rescale_output_key=RESCALE_OUTPUT_KEY,
    # rotation augmentation
    max_angle=MAX_ANGLE,
    nb_rotations=NB_ROTATIONS,
    verbose=VERBOSE,
)

Transforming images in 'train':   0%|          | 0/100 [00:00<?, ?it/s]

Transforming images in 'test':   0%|          | 0/50 [00:00<?, ?it/s]

data_loader.data
├── train
│	├── patient001
│	│	├── image_data
│	│	│	├── ED
│	│	│	├── ED_gt
│	│	│	├── ES
│	│	│	├── ES_gt
│	│	├── height
│	│	├── weight
│	│	├── group
│	│	├── nb_frames
│	│	├── rescaled_image_data
│	│	│	├── ED_gt
│	│	│	├── ES_gt
│	├── patient002
│	│	├── image_data
│	│	│	├── ED
│	│	│	├── ED_gt
│	│	│	├── ES
│	│	│	├── ES_gt
│	│	├── height
│	│	├── weight
│	│	├── group
│	│	├── nb_frames
│	│	├── rescaled_image_data
│	│	│	├── ED_gt
│	│	│	├── ES_gt
│	├── patient003
│	│	├── image_data
│	│	│	├── ED
│	│	│	├── ED_gt
│	│	│	├── ES
│	│	│	├── ES_gt
│	│	├── height
│	│	├── weight
│	│	├── group
│	│	├── nb_frames
│	│	├── rescaled_image_data
│	│	│	├── ED_gt
│	│	│	├── ES_gt
│	├── patient004
│	│	├── image_data
│	│	│	├── ED
│	│	│	├── ED_gt
│	│	│	├── ES
│	│	│	├── ES_gt
│	│	├── height
│	│	├── weight
│	│	├── group
│	│	├── nb_frames
│	│	├── rescaled_image_data
│	│	│	├── ED_gt
│	│	│	├── ES_gt
│	├── patient005
│	│	├── image_data
│	│	│	├── ED
│	│	│	├── ED_gt
│	│	│	├── ES
│	│	│	├── ES_gt
│	│	├── height


Extracting images in 'train':   0%|          | 0/100 [00:00<?, ?it/s]

Extracting images in 'test':   0%|          | 0/50 [00:00<?, ?it/s]

Rotating images:   0%|          | 0/7 [00:00<?, ?it/s]

Number of images after rotation: 2100


In [7]:
# Liberate memory if needed
def get_memory_usage():
    """Resident set size (RSS) of the current Python process, in MB (1 MB = 10**6 bytes)."""
    # psutil.Process() with no argument targets the calling process.
    rss_in_bytes = psutil.Process().memory_info().rss
    return rss_in_bytes / 1e6

if LIBERATE_MEMORY:
    memory_before = get_memory_usage()
    # Drop the loader and preprocessor references. `globals().pop(name, None)` is used
    # instead of `del` so the cell stays idempotent: re-running it after the names are
    # already gone no longer raises NameError.
    globals().pop("data_loader", None)
    globals().pop("model_preprocessor", None)
    gc.collect()
    # NOTE(review): RSS may not shrink much if `preprocessed_data` still references the
    # loaded arrays, or if the allocator keeps freed pages — TODO confirm.
    memory_after = get_memory_usage()
    if VERBOSE >= VerboseLevel.PRINT:
        # Totals are reported in GB, the freed amount in MB.
        print(f"Memory usage before: {memory_before/1000:.4f} GB"
              f"\nMemory usage after: {memory_after/1000:.4f} GB"
              f"\nMemory liberated: {memory_before - memory_after:.2f} MB"
        )

Memory usage before: 5264.88 MB
Memory usage after: 5239.44 MB
Memory liberated: 25.43 MB


## III. Model Training

In [None]:
# Saving paths: all artifacts of this run share one "<YYYY-mm-dd-HH-MM>_<channels>" prefix,
# so images, trained models and the parameter file of a run can be matched afterwards.
current_datetime = datetime.datetime.now()

RUN_PREFIX = (f"{current_datetime.strftime('%Y-%m-%d-%H-%M')}_"
              + ('4-channels' if MULTI_CHANNEL else '1-channel'))

# The trailing "{}" is kept as a str.format placeholder, presumably filled later
# (e.g. with an epoch or suffix) — TODO confirm against the training cells.
IMAGES_SAVE_PATH = f"images/{RUN_PREFIX}" + "{}"
MODEL_SAVE_PATH = f"models/trained_models/{RUN_PREFIX}" + "{}"
PARAMS_SAVE_PATH = f"models/parameters/{RUN_PREFIX}_params.txt"

In [None]:
# Write the hyper-parameters of this run to PARAMS_SAVE_PATH, one "NAME = value" line each.
SAVE_PARAMS = {
    "IMAGE_SIZE": IMAGE_SIZE,
    "MULTI_CHANNEL": MULTI_CHANNEL,
    "BATCH_SIZE": BATCH_SIZE,
    "EPOCHS": EPOCHS,
    "T": T,
    "DIM_MULTS": DIM_MULTS
}
# open(..., "w") raises FileNotFoundError if the destination folder is missing,
# so create it first (`or "."` guards against a bare filename with no folder).
os.makedirs(os.path.dirname(PARAMS_SAVE_PATH) or ".", exist_ok=True)
with open(PARAMS_SAVE_PATH, "w", encoding="utf-8") as file:
    for key, value in SAVE_PARAMS.items():
        file.write(f"{key} = {value}\n")
if VERBOSE >= VerboseLevel.PRINT:
    print(f"Parameters saved to '{PARAMS_SAVE_PATH}'")