# Sampling from a model trained on the ACDC dataset

**Authors:**
* _Louis Lacroix_
* _Benjamin Ternot_

## I. Importing Libraries and Global Settings

In [None]:
import datetime
from functools import partial
import os

import torch

from data_manager.datamanager import DataDisplayer
from models.model import Unet
from models.modeltrainer import DiffusionModelSampler, Diffusion
from utils.utils import VerboseLevel

In [None]:
# Resources live in a 'resources' folder that is a sibling of the current working directory
_PARENT_OF_CWD = os.path.dirname(os.getcwd())
ROOT_RES_FOLDER = os.path.join(_PARENT_OF_CWD, 'resources')

# Which trained model to load (both values must be filled in before running)
MODEL_DATETIME = ""  # Date and time of the model, formatted "YYYY-mm-dd-HH-MM"
MODEL_NAME = ""  # Name of the model (e.g. 'best-epoch-46')

# Folder name shared by the checkpoint directory and the image output directory
_RUN_FOLDER_NAME = f"{MODEL_DATETIME}_4-channels"
_MODEL_FOLDER = os.path.join(ROOT_RES_FOLDER, "trained_models", _RUN_FOLDER_NAME)
MODEL_LOAD_PATH = os.path.join(_MODEL_FOLDER, f"{MODEL_NAME}_unet.pt")
PARAMS_LOAD_PATH = os.path.join(_MODEL_FOLDER, "params.txt")

# Timestamp of this run, used as the last component of the output folder
current_datetime = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")

# Where the images are saved; the "{}" in SAVE_IMAGES_PATH is filled with a file name via str.format
SAVE_IMAGES_FOLDER = os.path.join(
    ROOT_RES_FOLDER,
    "images",
    _RUN_FOLDER_NAME,
    "sampling",
    MODEL_NAME,
    current_datetime,
)
SAVE_IMAGES_PATH = os.path.join(SAVE_IMAGES_FOLDER, "{}")

# Make sure the output folder exists before anything is written to it
os.makedirs(SAVE_IMAGES_FOLDER, exist_ok=True)

# Verbosity of the notebook:
#   VerboseLevel.NONE    to avoid outputs
#   VerboseLevel.TQDM    to use tqdm progress bars
#   VerboseLevel.PRINT   to print information
#   VerboseLevel.DISPLAY to display images
VERBOSE = VerboseLevel.DISPLAY

# Execution parameters: index of the CUDA device to use
CUDA_DEVICE = 0

In [None]:
# Use the configured CUDA device when CUDA is available; otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device(f"cuda:{CUDA_DEVICE}")
else:
    DEVICE = torch.device("cpu")

## II. Loading the model

In [None]:
# Values handed to the sampler as `model_params`, alongside the Unet model class
unet_params = {
    'init_dim': None,
    'out_dim': None,
    'with_time_emb': True,
    'convnext_mult': 2,
}

# Cosine beta schedule with its `s` argument fixed to 0.008
beta_schedule = partial(Diffusion.cosine_beta_schedule, s=0.008)

# Build the sampler from the saved parameters file and the saved model file
diffusion_model_sampler = DiffusionModelSampler(
    path_params=PARAMS_LOAD_PATH,
    path_model=MODEL_LOAD_PATH,
    device=DEVICE,
    model_class=Unet,
    model_params=unet_params,
    constants_scheduler=beta_schedule,
    verbose=VERBOSE,
)

In [None]:
# Run the sampler. The cells below display generated_images[-1] as a batch and pass the whole
# object to make_gif as its frame list — presumably one entry per sampling step (TODO confirm).
generated_images = diffusion_model_sampler.sample_images()

In [None]:
# Display (and save) the last entry of the generated images, first as-is and then one-hot encoded
final_batch = generated_images[-1]
show_figures = VERBOSE >= VerboseLevel.DISPLAY
base_title = f"Generated sample from model\n{MODEL_DATETIME}-{MODEL_NAME}\n"

batch_variants = (
    # (file name, figure title, one_hot_encode flag)
    ("generated-batch.jpg", base_title, False),
    ("generated-batch-one-hot.jpg", base_title + "One-hot encoded\n", True),
)

for image_name, figure_title, use_one_hot in batch_variants:
    DataDisplayer.display_batch(
        batch=final_batch,
        show=show_figures,
        filename=SAVE_IMAGES_PATH.format(image_name),
        title=figure_title,
        one_hot_encode=use_one_hot,
    )

In [None]:
# Build the gifs of the generated images, first as-is and then one-hot encoded.
#
# The step aims at about 50 frames per gif (presumably `step` is the stride between
# kept frames — TODO confirm against DataDisplayer.make_gif). With fewer than 50
# generated frames the integer division yields 0, which is not a usable step, so it
# is clamped to 1 (every frame is kept in that case).
gif_step = max(1, len(generated_images) // 50)

# NOTE(review): the ".gif.png" suffix is kept exactly as in the original notebook;
# confirm which extension make_gif expects.
DataDisplayer.make_gif(
    frame_list=generated_images,
    filename=SAVE_IMAGES_PATH.format("generated.gif.png"),
    step=gif_step,
    one_hot_encode=False,
    verbose=VERBOSE
)

# Same gif with the frames one-hot encoded
DataDisplayer.make_gif(
    frame_list=generated_images,
    filename=SAVE_IMAGES_PATH.format("generated-one-hot.gif.png"),
    step=gif_step,
    one_hot_encode=True,
    verbose=VERBOSE
)