# Sampling from a model trained on the ACDC dataset

**Authors:**
* _Louis Lacroix_
* _Benjamin Ternot_

## I. Importing Libraries and Global Settings

In [None]:
import datetime
import gc
import os
from functools import partial

import psutil

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim import Adam

from data_manager.datamanager import DataLoader
from models.model import Unet
from models.modeltrainer import ModelPreprocessor, DiffusionModelTrainer, Diffusion
from utils.utils import VerboseLevel

In [None]:
# --- Execution settings ---
# Index of the CUDA device to run on
CUDA_DEVICE = 0

# Verbosity of the notebook outputs:
#   VerboseLevel.NONE    -> no output
#   VerboseLevel.TQDM    -> tqdm progress bars
#   VerboseLevel.PRINT   -> printed information
#   VerboseLevel.DISPLAY -> displayed images
VERBOSE = VerboseLevel.DISPLAY

# --- Model selection ---
# Fill in both identifiers below; they are interpolated into the file paths that follow.
MODEL_DATETIME = ""  # Date and time of the model, formatted as "YYYY-MM-DD_HH-MM-SS"
MODEL_NAME = ""  # Name of the model (e.g. 'best-epoch-46')

# Locations of the saved weights and of the parameter file written for that run
PARAMS_LOAD_PATH = f"models/parameters/{MODEL_DATETIME}_4-channels_params.txt"
MODEL_LOAD_PATH = f"models/trained_models/{MODEL_DATETIME}_4-channels_{MODEL_NAME}_unet.pt"

In [None]:
# Load parameters from file.
# Every non-empty line is expected to have the form "KEY = value".
loaded_params = {}
with open(PARAMS_LOAD_PATH, "r") as file:
    for line_number, line in enumerate(file, start=1):
        line = line.strip()
        # Blank lines (e.g. a trailing newline) carry no parameter
        if not line:
            continue
        # Split on the first " = " only, so a value that itself contains " = " stays intact
        key, separator, value = line.partition(" = ")
        if not separator:
            raise ValueError(
                f"Malformed line {line_number} in '{PARAMS_LOAD_PATH}': "
                f"expected 'KEY = value', got {line!r}"
            )
        # Convert the value back to its original type (e.g., int, float, list, etc.)
        # NOTE(review): eval() executes arbitrary code from the parameter file. Only load
        # files you produced yourself; if every value is a plain literal, consider
        # switching to ast.literal_eval.
        try:
            loaded_params[key] = eval(value)
        except Exception:
            # The value is not a valid Python expression: keep it as a raw string
            loaded_params[key] = value

# Assign loaded parameters back to your variables
# (dict.get returns None for any key missing from the file)
IMAGE_SIZE = loaded_params.get("IMAGE_SIZE")
MULTI_CHANNEL = loaded_params.get("MULTI_CHANNEL")
BATCH_SIZE = loaded_params.get("BATCH_SIZE")
EPOCHS = loaded_params.get("EPOCHS")
T = loaded_params.get("T")
DIM_MULTS = loaded_params.get("DIM_MULTS")

print(f"Parameters loaded from '{PARAMS_LOAD_PATH}'\n")

# Display loaded parameters
print("Loaded Parameters:")
for key, value in loaded_params.items():
    print(f"- {key}: {value} {type(value)}")

## II. Loading the model

In [None]:
# Run on the configured GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{CUDA_DEVICE}")
else:
    DEVICE = torch.device("cpu")

In [None]:
# Build the U-Net from the loaded hyper-parameters
n_channels = 4 if MULTI_CHANNEL else 1
model = Unet(
    dim=IMAGE_SIZE,
    init_dim=None,
    out_dim=None,
    dim_mults=DIM_MULTS,
    channels=n_channels,
    with_time_emb=True,
    convnext_mult=2,
)
model = model.to(DEVICE)

# Restore the saved weights, mapping their storages onto the selected device
model.load_state_dict(
    torch.load(MODEL_LOAD_PATH, map_location=DEVICE, weights_only=True)
)

# Switch to evaluation mode (last expression of the cell, so the notebook displays the model)
model.eval()