In [None]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

import torch
import matplotlib.pyplot as plt
import numpy as np

import himyb.models.ddpmpp as ddpmpp
import himyb.training.save_load as save_load
import himyb.models.preconditioning as preconditioning
import himyb.sampler.sampler as sampler
import himyb.sampler.generate as generate
import himyb.misc_utils as misc_utils

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
root_dir = ... # set this to the root directory of the model save
# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load the configurations
config_filename = ... # set this to the config filename
states_filename = ... # set this to the training-state (model weights) filename
dataset_conf, model_conf, optim_config, cur_img = save_load.load_training_configs(os.path.join(root_dir, "checkpoints", config_filename))

# create the model: the DDPM++ network wrapped in the EDM preconditioning,
# switched to eval mode and moved to the selected device
internal_model = ddpmpp.DDPMPP(**model_conf)
model = preconditioning.EDMPrecond(
        model=internal_model,
).eval().to(device)

# load the weights of the model from the checkpoint file
save_load.load_training_state(file_name=os.path.join(root_dir, "checkpoints", states_filename), model=model)

In [None]:
def sample_n_imgs(
    n_imgs_per_class,
    num_steps,
    model,
    device,
    guidance_scale=0.0,
    use_unconditional_model=True,
    s_churn=0,
    sample_uncond=False,
    return_class_labels=False,
    s_min=0.01,
    s_max=80,
    s_noise=1.003,
):
    """
    Sample a specified number of images per class from the model.

    All classes are sampled in a single batch of size
    ``n_imgs_per_class * number_of_sampled_classes``; the labels are laid out in
    contiguous groups (all images of the first class, then the second, ...).

    Args:
        n_imgs_per_class (int): Number of images to sample per class.
        num_steps (int): Number of sampling steps.
        model (torch.nn.Module): The model to sample from. Its ``label_dim``,
            ``in_channels`` and ``img_resolution`` attributes define the labels
            and the shape of the generated batch.
        device (torch.device): The device to run the sampling on.
        guidance_scale (float): Guidance scale of CFG.
        use_unconditional_model (bool): Whether to use the unconditional model to
            obtain the unconditional score in CFG.
        s_churn (float): S_churn parameter for the stochastic sampler (controls the
            variance of the fresh noise added during sampling; S_churn=0 means
            deterministic sampling, S_churn>0 means stochastic sampling).
        sample_uncond (bool): If True, labels start at 0, so the unconditional
            class (class 0) is sampled in addition to the other classes.
            If False, only labels ``1 .. model.label_dim - 1`` are sampled.
        return_class_labels (bool): Whether to return the class labels along with
            the generated images.
        s_min (float): Forwarded as ``s_min`` to ``sampler.stoch_edm_sampler``.
        s_max (float): Forwarded as ``s_max`` to ``sampler.stoch_edm_sampler``.
        s_noise (float): Forwarded as ``s_noise`` to ``sampler.stoch_edm_sampler``.

    Returns:
        The ``"generated_imgs"`` entry of the sampler result, or the tuple
        ``(generated_imgs, class_labels)`` when ``return_class_labels`` is True.
        ``class_labels`` is a 1-D ``torch.long`` tensor on ``device`` holding the
        label fed to the model for every image (i.e. 0 is the unconditional class).
    """
    # Label 0 is the unconditional class: skip it unless it was explicitly requested.
    first_label = 0 if sample_uncond else 1
    n_classes = model.label_dim - first_label
    batch_size = n_imgs_per_class * n_classes

    # Integer division turns 0..batch_size-1 into n_imgs_per_class copies of each class index.
    class_labels = torch.arange(batch_size, device=device, dtype=torch.long) // n_imgs_per_class + first_label

    shape = (batch_size, model.in_channels, model.img_resolution, model.img_resolution)
    result = sampler.stoch_edm_sampler(
        model=model,
        class_labels=class_labels,
        shape=shape,
        s_churn=s_churn,
        s_min=s_min,
        s_max=s_max,
        s_noise=s_noise,
        num_steps=num_steps,
        return_history=False,
        device=device,
        guidance_scale=guidance_scale,
        use_unconditional_model=use_unconditional_model)
    if return_class_labels:
        return result["generated_imgs"], class_labels
    return result["generated_imgs"]

In [None]:
def plot_imgs(imgs, n_rows, n_cols, scale_f = 1.) :
    """
    Show the first ``n_rows * n_cols`` images of ``imgs`` on a grid of subplots.

    Args:
        imgs: Indexable collection of image tensors. Each image is permuted with
            ``permute(1, 2, 0)`` before display, so it is expected to be
            channel-first; values are mapped with ``(x + 1) / 2``, i.e. the
            range [-1, 1] is displayed as [0, 1].
        n_rows (int): Number of rows of the grid.
        n_cols (int): Number of columns of the grid.
        scale_f (float): Size of one grid cell (figure size is
            ``(n_cols * scale_f, n_rows * scale_f)``).

    The figure is created on the current pyplot state and is not returned;
    call ``plt.show()`` afterwards to display it.
    """
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 (or single row/column) grid,
    # so flattening it below always works.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*scale_f, n_rows*scale_f), squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        # Hide ticks and frame for every cell, including cells left empty
        # when there are fewer images than grid cells.
        ax.axis('off')
        if i >= len(imgs):
            continue
        # NOTE(review): DisableImshowWarning presumably silences matplotlib's
        # out-of-range clipping warning - confirm in misc_utils.
        with misc_utils.DisableImshowWarning():
            ax.imshow((imgs[i].permute(1, 2, 0).cpu().numpy()+1)/2)
    plt.tight_layout()

In [None]:
n_imgs = 80   # number of images generated for every class
n_steps = 10  # number of sampler steps

# Deterministic (s_churn=0), unguided sampling of all the conditional classes.
imgs, cl_labels = sample_n_imgs(
    n_imgs_per_class=n_imgs,
    num_steps=n_steps,
    model=model,
    device=device,
    guidance_scale=0.,
    use_unconditional_model=False,
    s_churn=0,
    sample_uncond=False,
    return_class_labels=True,
)

# The model reserves token 0 for the unconditional class, so the returned labels
# are shifted by one with respect to the dataset labels: shift them back.
cl_labels = cl_labels - 1

In [None]:
# Preview the first samples on a grid of 8 rows and n_imgs // 4 columns.
plot_imgs(imgs, n_rows=8, n_cols=n_imgs // 4, scale_f=2)
plt.show()

In [None]:
def save_imgs(imgs, class_labels, folder) :
    """
    Save every image of ``imgs`` as a PNG file inside ``folder``.

    Files are named ``img_<index>_<label>.png``. Numbering continues after the
    highest ``<index>`` already present in ``folder``, so repeated calls append
    to the folder instead of reusing an index, even if the folder contains
    unrelated files or some earlier images were deleted.

    Args:
        imgs: Iterable of image tensors. Each image is permuted with
            ``permute(1, 2, 0)`` before saving, so it is expected to be
            channel-first; values are mapped with ``(x + 1) / 2`` and clipped
            to [0, 1].
        class_labels: Indexable collection (tensor or sequence) with one integer
            label per image; ``class_labels[i]`` is written in the name of image ``i``.
        folder (str): Destination directory, created if it does not exist.
    """
    os.makedirs(folder, exist_ok=True)

    # Collect the indices of the images already stored in the folder.
    used_indices = []
    for name in os.listdir(folder):
        parts = name.split("_")
        if len(parts) >= 3 and parts[0] == "img" and parts[1].isdecimal():
            used_indices.append(int(parts[1]))
    cur_idx = max(used_indices, default=-1) + 1

    for i, img in enumerate(imgs):
        img = ((img.permute(1, 2, 0).cpu().numpy()+1)/2).clip(0, 1)
        file_name = f"img_{i+cur_idx:05d}_{int(class_labels[i])}.png"
        plt.imsave(os.path.join(folder, file_name), img)

In [None]:
# Store the samples in the "generated_imgs" sub-folder of the model root directory.
SAVE_FOLDER = os.path.join(root_dir, "generated_imgs")
save_imgs(imgs=imgs, class_labels=cl_labels, folder=SAVE_FOLDER)
print(f"Saved generated images to {SAVE_FOLDER}")