# SET UP DDPM_Conditional

In [None]:
!pip3 install --upgrade pip
!pip3 install einops
!pip3 install transformers
!pip3 install sentencepiece

In [None]:
import os
import math
import torch
import torchvision
import copy
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
from einops.layers.torch import Rearrange
from PIL import Image
from torch.utils.data import DataLoader
import random

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# connect to Google Drive and unzip the folder for training;
!unzip "./drive/MyDrive/{dataset_folder}.zip"

In [None]:
# Training hyper-parameters and paths shared by the cells below.
epochs = 100  # the training loop iterates over range(epochs + 1)
batch_size = 32  # DataLoader batch size
image_size = 64  # side of the random crop produced by the data pipeline and of the sampled images
num_classes = 21  # size of the label embedding table of the U-Net
dataset_path = "./{dataset_folder}/"  # ImageFolder root; presumably {dataset_folder} is a placeholder to replace by hand
lr = 3e-4  # AdamW learning rate
time_embed_dim = 256  # width of the timestep embedding used by the U-Net blocks
path_drive = "./drive/My Drive/{model_folder}/"  # checkpoints go to <path_drive>models, sample grids to <path_drive>results

In [None]:
# Setting reproducibility: seed every random number generator this notebook
# draws from. torch drives weight initialisation and noise sampling, NumPy
# drives the label-dropout draw in the training loop, and Python's `random`
# module is imported alongside them.
SEED = 3
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# When True, train() restores the model, EMA model and optimizer state from the
# checkpoint files under path_drive before training starts; when False training
# starts from freshly initialised weights. Only the training loop reads this flag.
EXIST_MODEL = False

In [None]:
# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Report the choice, followed by the name of CUDA device 0 (or "CPU").
print(f"Using device: {device}\t" + (f"{torch.cuda.get_device_name(0)}" if torch.cuda.is_available() else "CPU"))

# Utils for DDPM_Conditional

In [None]:
def plot_images(images):
    """Show a batch of images side by side in a single matplotlib figure.

    :param images: batch tensor indexed as (image, channel, height, width);
        the images are laid out left to right along the width axis.
    """
    plt.figure(figsize=(32, 32))
    tiles = list(images.cpu())
    # Concatenating along the last axis puts every image in one horizontal strip.
    strip = torch.cat(tiles, dim=-1)
    # imshow expects channels last.
    plt.imshow(strip.permute(1, 2, 0).cpu())
    plt.show()


def save_images(images, path, **kwargs):
    """Arrange a batch of images into a grid and write it to ``path``.

    :param images: batch of image tensors accepted by ``torchvision.utils.make_grid``.
    :param path: destination file; the format is chosen by PIL from the extension.
    :param kwargs: forwarded unchanged to ``torchvision.utils.make_grid``.
    """
    grid = torchvision.utils.make_grid(images, **kwargs)
    # PIL wants a channels-last NumPy array living on the CPU.
    pixels = grid.permute(1, 2, 0).to('cpu').numpy()
    Image.fromarray(pixels).save(path)


# Build the training DataLoader: read the images from dataset_path, apply the
# augmentation/normalisation pipeline and serve them in shuffled batches.
def get_data():
    """Return a shuffled ``DataLoader`` over the ``ImageFolder`` at ``dataset_path``.

    Images are resized so that their shorter side is ``image_size`` plus one
    quarter (80 for the default 64), randomly cropped back to ``image_size``
    and normalised with mean and std 0.5 per channel.

    :return: ``DataLoader`` yielding ``(images, class_indices)`` batches of
        size ``batch_size``.
    """
    preprocessing = torchvision.transforms.Compose([
        # data augmentation: upscale by 25% and take a random crop of the target size;
        # the resize target follows image_size instead of being hard-coded to 80
        torchvision.transforms.Resize(image_size + image_size // 4),
        torchvision.transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = torchvision.datasets.ImageFolder(dataset_path, transform=preprocessing)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Map a (possibly unseen) label to the index of a similar class known from training.
# This makes it possible to condition the model on a text/label that is not among the
# class labels of the training dataset.
# CLASS_LABEL_TRAIN lists the conditional labels used during training, in class-index
# order; here these are the 21 labels of the dataset the model was trained on.

# if you have to test the model use this:
CLASS_LABEL_TRAIN = ['00000', '00001', '00010', '00011', '00100', '00110', '01000', '01010', '01011', '01100', '01110', '10000', '10001', '10010', '10011', '10100', '10110', '11000', '11010', '11100', '11110']

# else use this for train the model with another dataset and sets of labels:
#CLASS_LABEL_TRAIN = []

def get_similar_class(label):
    """Return the index of the training class that best matches ``label``.

    The known class names come from ``CLASS_LABEL_TRAIN`` when it is non-empty,
    otherwise from the sorted entries of ``dataset_path``. An exact match wins;
    failing that, the last character is dropped from ``label`` and from every
    class name, and the comparison is repeated until something matches or
    ``label`` is down to a single character. When several classes match, the
    one with the lowest index is returned.

    :param label: label string, e.g. ``'00111'``.
    :return: integer class index.
    :raises ValueError: if no class matches even after truncation.
    """
    if CLASS_LABEL_TRAIN:
        known_classes = CLASS_LABEL_TRAIN
    else:
        known_classes = sorted(os.listdir(dataset_path))

    candidates = list(known_classes)
    query = label
    while True:
        for index, candidate in enumerate(candidates):
            if candidate == query:
                return index
        # Same stop condition as the truncation loop: never shorten below one character.
        if len(query) <= 1:
            break
        query = query[:-1]
        candidates = [candidate[:-1] for candidate in candidates]

    raise ValueError(f"no training class is similar to label {label!r}")



# The DataLoader yields class indices (0 .. num_classes - 1). This helper maps every
# index back to its class folder name and then to the descriptive text built by
# label_to_text, so that each sample gets a text description; TRAIN ONLY.
def get_label_index(labels):
    """Return the text description of every class index in ``labels``.

    :param labels: tensor of integer class indices as produced by the DataLoader.
    :return: list of description strings, one per entry of ``labels``.
    """
    # Class index i corresponds to the i-th folder name in sorted order.
    class_names = sorted(os.listdir(dataset_path))
    indices = labels.detach().cpu().numpy()
    return [label_to_text(class_names[index]) for index in indices]

# Build the text prompts for n samples of the same label; used when generating
# images from text.
def get_text_for_sample(label,n):
    """Return a list holding the description of ``label`` repeated ``n`` times."""
    return [label_to_text(label) for _ in range(n)]

FEATURE_LABEL = ['frangetta','occhiali','barba','sorridente','giovane']

# Turn a binary feature label into the text that is later passed to the text encoder.
def label_to_text(label):
    """Return the description of ``label`` as space-separated words.

    Position ``i`` of ``label`` switches ``FEATURE_LABEL[i]`` on when it is the
    character ``'1'``. The text starts with ``'viso'`` and lists the active
    features from the last position to the first.

    :param label: string with at least ``len(FEATURE_LABEL)`` characters.
    :return: e.g. ``'viso barba frangetta'`` for ``'10100'``.
    """
    active = [
        feature
        for position, feature in enumerate(FEATURE_LABEL)
        if label[position] == '1'
    ]
    words = ['viso'] + active[::-1]
    return " ".join(words)

In [None]:
#import matplotlib.pyplot as plt

# plot the loss function for the training;

#def plot_loss(losses):
#    plt.plot(losses)
#    plt.title('Loss Graph')
#    plt.xlabel('Epoch')
#    plt.ylabel('Loss')
#    plt.show()
#    plt.savefig("Loss")

# T5 Text Encoder
Load a pretrained Text encoder model

In [None]:
from einops import rearrange
from transformers import T5Tokenizer, T5EncoderModel

# Default token limit passed to the tokenizer by t5_encode_text.
MAX_LENGTH = 256

# Key of T5_VERSIONS naming the default encoder variant.
DEFAULT_T5_NAME = 't5_small'

# Variants: https://huggingface.co/docs/transformers/model_doc/t5v1.1. 1.1 versions must be finetuned.
# Registry of the supported encoders, one entry per variant:
#   'tokenizer' / 'model' - start as None and are filled in by _check_downloads on first use;
#   'handle'              - identifier handed to from_pretrained;
#   'dim'                 - encoding width returned by get_encoded_dim;
#   'size'                - presumably the approximate download size in GB (matches the
#                           figures quoted in the t5_encode_text docstring).
T5_VERSIONS = {
    't5_small': {'tokenizer': None, 'model': None, 'handle': 't5-small', 'dim': 512, 'size': .24},
    't5_base': {'tokenizer': None, 'model': None, 'handle': 't5-base', 'dim': 768, 'size': .890},
    't5_large': {'tokenizer': None, 'model': None, 'handle': 't5-large', 'dim': 1024, 'size': 2.75},
    't5_3b': {'tokenizer': None, 'model': None, 'handle': 't5-3b', 'dim': 1024, 'size': 10.6},
    't5_11b': {'tokenizer': None, 'model': None, 'handle': 't5-11b', 'dim': 1024, 'size': 42.1},
    'small1.1': {'tokenizer': None, 'model': None, 'handle': 'google/t5-v1_1-small', 'dim': 512, 'size': .3},
    'base1.1': {'tokenizer': None, 'model': None, 'handle': 'google/t5-v1_1-base', 'dim': 768, 'size': .99},
    'large1.1': {'tokenizer': None, 'model': None, 'handle': 'google/t5-v1_1-large', 'dim': 1024, 'size': 3.13},
    'xl1.1': {'tokenizer': None, 'model': None, 'handle': 'google/t5-v1_1-xl', 'dim': 2048, 'size': 11.4},
    'xxl1.1': {'tokenizer': None, 'model': None, 'handle': 'google/t5-v1_1-xxl', 'dim': 4096, 'size': 44.5},
}

# Fast tokenizers: https://huggingface.co/docs/transformers/main_classes/tokenizer
def _check_downloads(name):
    """Load the tokenizer and encoder of variant ``name`` unless already cached.

    The loaded objects are stored back into ``T5_VERSIONS[name]`` so that later
    calls reuse them.
    """
    entry = T5_VERSIONS[name]
    handle_loaders = (('tokenizer', T5Tokenizer), ('model', T5EncoderModel))
    for slot, loader in handle_loaders:
        if entry[slot] is None:
            entry[slot] = loader.from_pretrained(entry['handle'])


def t5_encode_text(text, name: str = DEFAULT_T5_NAME, max_length=MAX_LENGTH):
    """
    Encodes a sequence of text with a T5 text encoder.
    :param text: List of text to encode.
    :param name: Name of T5 model to use (a key of ``T5_VERSIONS``). Options include:
        - :code:`'t5_small'` (~0.24 GB, 512 encoding dim),
        - :code:`'t5_base'` (~0.89 GB, 768 encoding dim),
        - :code:`'t5_large'` (~2.75 GB, 1024 encoding dim),
        - :code:`'t5_3b'` (~10.6 GB, 1024 encoding dim),
        - :code:`'t5_11b'` (~42.1 GB, 1024 encoding dim),
    :param max_length: Maximum number of tokens per text; longer inputs are truncated.
        Shorter batches are only padded up to their longest element.
    :return: Returns encodings and attention mask (as booleans). Element **[i,j,k]** of the final
        encoding corresponds to the **k**-th encoding component of the **j**-th token in the
        **i**-th input list element. Encodings of padding tokens are zeroed.
    """
    _check_downloads(name)
    tokenizer = T5_VERSIONS[name]['tokenizer']
    model = T5_VERSIONS[name]['model']

    # Run on CUDA when available. A local name is used on purpose so that the
    # notebook-level `device` variable is not shadowed.
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(target_device)

    # Tokenize text. Calling the tokenizer directly is the supported batch API;
    # `batch_encode_plus` is deprecated.
    tokenized = tokenizer(
        text,
        padding='longest',
        max_length=max_length,
        truncation=True,
        return_tensors="pt",  # Returns torch.tensor instead of python integers
    )

    input_ids = tokenized.input_ids.to(target_device)
    attention_mask = tokenized.attention_mask.to(target_device)

    model.eval()

    # Don't need gradient - the T5 encoder is frozen
    with torch.no_grad():
        t5_output = model(input_ids=input_ids, attention_mask=attention_mask)
        final_encoding = t5_output.last_hidden_state

    # Wherever the encoding is masked, make equal to zero
    keep = attention_mask.bool()
    final_encoding = final_encoding.masked_fill(~keep.unsqueeze(-1), 0.)

    return final_encoding, keep


def get_encoded_dim(name: str) -> int:
    """
    Look up the encoding width registered in ``T5_VERSIONS`` for encoder ``name``.
    """
    variant = T5_VERSIONS[name]
    return variant['dim']

# Unet Architecture
- 256-dimensional time embedding;
- 4 levels of depth;
- 4 cross-attention blocks;
- 512-channel bottleneck.

In [None]:
# Exponential moving average (EMA) of the model weights.
class EMA:
    """Keeps a smoothed copy of a model's parameters.

    ``beta`` is the weight given to the running average; ``step`` counts how
    many times ``step_ema`` has been called.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of ``current_model`` into ``ma_model`` in place."""
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for live_param, averaged_param in pairs:
            averaged_param.data = self.update_average(averaged_param.data, live_param.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` when there is no old value."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance the average by one training step.

        During the first ``step_start_ema`` calls the EMA model is simply a
        copy of ``model``; afterwards the weights are averaged.
        """
        warming_up = self.step < step_start_ema
        if warming_up:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Overwrite the EMA model's state with the current model's state."""
        ema_model.load_state_dict(model.state_dict())


# Cross attention: the queries come from the image features, keys and values
# from the text embedding.
class CrossAttention(nn.Module):
    """Attend from a ``size`` x ``size`` feature map with ``channels`` channels
    to a sequence of conditioning tokens, followed by a residual feed-forward
    block. The output has the same shape as the input feature map."""

    def __init__(self, channels, size):
        super().__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        feed_forward = [
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.ff_self = nn.Sequential(*feed_forward)

    def forward(self, x, te):
        # Flatten the spatial grid into a token sequence: (batch, size*size, channels).
        tokens_per_image = self.size * self.size
        tokens = x.view(-1, self.channels, tokens_per_image).swapaxes(1, 2)

        attended, _ = self.mha(self.ln(tokens), te, te)
        residual = attended + tokens
        out = self.ff_self(residual) + residual

        # Back to (batch, channels, size, size).
        return out.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)


class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by a single-group GroupNorm, with a
    GELU in between. With ``residual=True`` the input is added to the result
    and a final GELU is applied (input and output channels must then match)."""

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        # The hidden width defaults to the output width.
        hidden_channels = mid_channels if mid_channels else out_channels
        layers = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden_channels),
            nn.GELU(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        if not self.residual:
            return self.double_conv(x)
        return F.gelu(x + self.double_conv(x))


class Down(nn.Module):
    """Encoder stage: halve the resolution with max pooling, apply two
    ``DoubleConv`` blocks and add a per-channel shift computed from the
    time embedding."""

    def __init__(self, in_channels, out_channels, emb_dim=time_embed_dim):
        super().__init__()
        stages = [
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*stages)

        # Projects the time embedding to one value per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        features = self.maxpool_conv(x)
        height, width = features.shape[-2], features.shape[-1]
        # Tile the (batch, channels) projection over the whole feature map.
        shift = self.emb_layer(t)[:, :, None, None].repeat(1, 1, height, width)
        return features + shift


class Up(nn.Module):
    """Decoder stage: reduce the channels with a transposed convolution,
    double the resolution, concatenate the skip connection, apply two
    ``DoubleConv`` blocks and add a per-channel shift computed from the
    time embedding."""

    def __init__(self, in_channels, out_channels, emb_dim=time_embed_dim):
        super().__init__()

        self.ct = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        refinement = [
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        ]
        self.conv = nn.Sequential(*refinement)

        # Projects the time embedding to one value per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        upsampled = self.up(self.ct(x))

        # Skip features go first along the channel axis.
        merged = torch.cat([skip_x, upsampled], dim=1)
        features = self.conv(merged)

        height, width = features.shape[-2], features.shape[-1]
        shift = self.emb_layer(t)[:, :, None, None].repeat(1, 1, height, width)
        return features + shift

In [None]:
class UNet_conditional(nn.Module):
    """U-Net noise predictor conditioned on timestep, class label and text.

    The network has four down stages, a bottleneck and four up stages, with
    cross-attention to the text embedding at the 128- and 256-channel levels.
    The cross-attention blocks are built for 16x16 and 8x8 feature maps, i.e.
    for 64x64 inputs. Text encodings must have 512 features per token.

    :param c_in: number of input image channels.
    :param c_out: number of output channels.
    :param time_dim: width of the sinusoidal timestep embedding.
    :param num_classes: when given, a label embedding of this many classes is
        created and added to the time embedding in ``forward``.
    :param device: kept as ``self.device`` for callers that read it.
    """

    # Channel widths for which conditioning layers exist.
    _SUPPORTED_COND_DIMS = (128, 256)

    def __init__(self, c_in=3, c_out=3, time_dim=time_embed_dim, num_classes=None, device=device):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        # These layers turn the time embedding into two extra tokens per sample so that
        # it can be concatenated to the text tokens used by the cross-attention layers.
        self.to_time_tokens_128 = nn.Sequential(
            nn.Linear(self.time_dim, 128 * 2),
            Rearrange('b (r d) -> b r d', r=2)
        )
        self.to_time_tokens_256 = nn.Sequential(
            nn.Linear(self.time_dim, 256 * 2),
            Rearrange('b (r d) -> b r d', r=2)
        )

        # The text encoder returns 512-dimensional token embeddings; these layers
        # project them to the channel width of each cross-attention level.
        self.te_to_dim_128 = nn.Linear(512, 128)
        self.te_to_dim_256 = nn.Linear(512, 256)

        self.norm_128 = nn.LayerNorm(128)
        self.norm_256 = nn.LayerNorm(256)

        self.inc = DoubleConv(c_in, 32)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.ca2 = CrossAttention(128, 16)
        self.down3 = Down(128, 256)
        self.ca3 = CrossAttention(256, 8)
        self.down4 = Down(256, 512)
        self.bot1 = DoubleConv(512, 512)
        self.up1 = Up(512, 256)
        self.ca4 = CrossAttention(256, 8)
        self.up2 = Up(256, 128)
        self.ca5 = CrossAttention(128, 16)
        self.up3 = Up(128, 64)
        self.up4 = Up(64, 32)
        self.outc = nn.Conv2d(32, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def _check_cond_dim(self, dim):
        """Raise ``ValueError`` unless conditioning layers exist for ``dim``."""
        if dim not in self._SUPPORTED_COND_DIMS:
            raise ValueError(
                f"unsupported conditioning dimension {dim}; "
                f"expected one of {self._SUPPORTED_COND_DIMS}"
            )

    def pos_encoding(self, t, channels):
        """Return the sinusoidal embedding of ``t`` (shape ``(batch, 1)``) with ``channels`` features."""
        # Built on the device of `t` so the encoding always matches its input.
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    # generate tokens from the time embedding, to concatenate them to the text embedding later
    def generate_t_tokens(self, t, cond_dim):
        """Project the time embedding ``t`` to two tokens of width ``cond_dim``.

        :raises ValueError: if ``cond_dim`` is neither 128 nor 256.
        """
        self._check_cond_dim(cond_dim)
        if cond_dim == 256:
            return self.to_time_tokens_256(t)
        return self.to_time_tokens_128(t)

    # concat time embedding and text embeddings
    def concat_time_text(self, te, t, out_dim):
        """Append the time tokens to the text tokens ``te`` and layer-normalise the result.

        :raises ValueError: if ``out_dim`` is neither 128 nor 256.
        """
        tt = self.generate_t_tokens(t, out_dim)
        te_t = torch.cat((te, tt), dim=-2)
        if out_dim == 256:
            return self.norm_256(te_t)
        return self.norm_128(te_t)

    def reduce_dim(self, te, dim_out):
        """Project the 512-dimensional text tokens ``te`` to ``dim_out`` features.

        :raises ValueError: if ``dim_out`` is neither 128 nor 256.
        """
        self._check_cond_dim(dim_out)
        if dim_out == 256:
            return self.te_to_dim_256(te)
        return self.te_to_dim_128(te)

    def forward(self, x, t, y, te):
        """Predict the noise contained in ``x``.

        :param x: noisy image batch.
        :param t: one integer timestep per sample, shape ``(batch,)``.
        :param y: class indices per sample, or ``None`` to skip label conditioning.
        :param te: text encodings with 512 features per token.
        :return: tensor with ``c_out`` channels and the spatial size of ``x``.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)
        if y is not None:
            t += self.label_emb(y)

        te_256 = self.reduce_dim(te, 256)
        te_128 = self.reduce_dim(te, 128)

        tte_256 = self.concat_time_text(te_256, t, 256)
        tte_128 = self.concat_time_text(te_128, t, 128)

        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x3 = self.down2(x2, t)
        x3 = self.ca2(x3, tte_128)
        x4 = self.down3(x3, t)
        x4 = self.ca3(x4, tte_256)
        x5 = self.down4(x4, t)
        x6 = self.bot1(x5)
        x = self.up1(x6, x4, t)
        x = self.ca4(x, tte_256)
        x = self.up2(x, x3, t)
        x = self.ca5(x, tte_128)
        x = self.up3(x, x2, t)
        x = self.up4(x, x1, t)
        output = self.outc(x)
        return output

# Define the model : train and sample

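A checkpoint is stored every time the epoch loss improves on the best value so far; every 20 epochs a grid of sample images is generated and saved.
N.B. The files are saved on Google Drive.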

In [None]:
class Diffusion:
    """DDPM forward (noising) process and ancestral sampler.

    :param noise_steps: number of diffusion timesteps.
    :param beta_start: first value of the linear beta schedule.
    :param beta_end: last value of the linear beta schedule.
    :param img_size: side of the square images that are sampled.
    :param device: device holding the schedule tensors and the samples.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=64, device=device):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

        self.img_size = img_size
        self.device = device

    # cosine schedule;
    def betas_cosine(self, num_diffusion_timesteps, alpha_bar, max_beta=0.999):
        """Return betas derived from the cumulative schedule ``alpha_bar(t)``, ``t`` in [0, 1].

        Each beta is ``1 - alpha_bar(t2) / alpha_bar(t1)`` for consecutive
        steps, capped at ``max_beta``.
        """
        betas = []
        for i in range(num_diffusion_timesteps):
            t1 = i / num_diffusion_timesteps
            t2 = (i + 1) / num_diffusion_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return torch.tensor(betas)

    # if cosine is False return the linear noising schedule;
    def prepare_noise_schedule(self, cosine=False):
        """Return the beta schedule: cosine when ``cosine`` is true, linear otherwise."""
        if cosine:
            return self.betas_cosine(
                self.noise_steps,
                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
            )
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    # noising the image at timestep t;
    def noise_images(self, x, t):
        """Return ``(x_t, noise)``: ``x`` noised to timesteps ``t`` and the noise that was added."""
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noise = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def sample_timesteps(self, n):
        """Draw ``n`` timesteps uniformly from ``1 .. noise_steps - 1``."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def _encode_prompts(self, descriptions):
        """Encode a list of texts with T5 and return the encodings on ``self.device``."""
        encodings = t5_encode_text(descriptions)[0]
        return encodings.to(self.device)

    def _reverse_step(self, x, t, predicted_noise, add_noise):
        """Apply one reverse-diffusion update to ``x`` at timesteps ``t``.

        Fresh Gaussian noise is added unless ``add_noise`` is false (final step).
        """
        alpha = self.alpha[t][:, None, None, None]
        alpha_hat = self.alpha_hat[t][:, None, None, None]
        beta = self.beta[t][:, None, None, None]
        noise = torch.randn_like(x) if add_noise else torch.zeros_like(x)
        return 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

    @staticmethod
    def _to_uint8(x):
        """Map samples from [-1, 1] to uint8 pixel values in [0, 255]."""
        x = (x.clamp(-1, 1) + 1) / 2
        return (x * 255).type(torch.uint8)

    # sampling;
    def sample(self, model, n, labels, cfg_scale=3):
        """Generate ``n`` images conditioned on the class indices ``labels``.

        With ``cfg_scale > 0`` the label-conditioned and label-free predictions
        are combined by linear extrapolation (classifier-free guidance). The
        model is put in eval mode while sampling and in train mode afterwards.

        :return: uint8 tensor of shape ``(n, 3, img_size, img_size)``.
        """
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            # The text conditioning does not depend on the timestep: encode it once
            # instead of listing the dataset folder and running T5 at every step.
            te = self._encode_prompts(get_label_index(labels))
            # range(noise_steps - 1, 0, -1) visits the same steps as
            # reversed(range(1, noise_steps)) but has a length for the progress bar.
            for i in tqdm(range(self.noise_steps - 1, 0, -1), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t, labels, te)
                if cfg_scale > 0:
                    uncond_predicted_noise = model(x, t, None, te)
                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                x = self._reverse_step(x, t, predicted_noise, add_noise=i > 1)
        model.train()
        return self._to_uint8(x)

    # text to image generation, this method is used to generate the final images;
    def generate_image_from_label(self, model, n, label):
        """Generate ``n`` images for the textual ``label`` (e.g. ``'00111'``).

        The class index is the one of the most similar training label and the
        text prompt is built from ``label`` itself. The model is left in eval mode.

        :return: uint8 tensor of shape ``(n, 3, img_size, img_size)``.
        """
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            y_index = get_similar_class(label)
            y = torch.full((n,), y_index, dtype=torch.long, device=self.device)
            # Loop-invariant: the prompt is the same at every timestep.
            te = self._encode_prompts(get_text_for_sample(label, n))
            for i in tqdm(range(self.noise_steps - 1, 0, -1), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t, y, te)
                x = self._reverse_step(x, t, predicted_noise, add_noise=i > 1)

        return self._to_uint8(x)



# Start the train loop

In [None]:
# train the model; every 20 epochs sample and plot one image for every class label;
def train():
    """Train the conditional U-Net and keep the best checkpoint on Drive.

    For every batch the class indices are turned into text, encoded with T5,
    and the model learns to predict the noise added at a random timestep. The
    class label is dropped for roughly 10% of the batches so the model also
    learns the label-free prediction used by classifier-free guidance.

    The model, EMA model and optimizer states are written to
    ``<path_drive>models`` whenever the epoch loss improves; every 20 epochs a
    grid sampled from the EMA model is shown and saved to ``<path_drive>results``.
    With ``EXIST_MODEL`` set, training resumes from those checkpoint files.
    """
    dataloader = get_data()
    model = UNet_conditional(num_classes=num_classes).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=image_size, device=device)
    ema = EMA(0.995)
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    models_dir = path_drive + "models"
    results_dir = path_drive + "results"
    # torch.save / Image.save do not create missing folders.
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    if EXIST_MODEL:
        print("Recover model...")
        # map_location lets a checkpoint saved on GPU be resumed on another device.
        model.load_state_dict(torch.load(os.path.join(models_dir, "ckpt.pt"), map_location=device))
        # Training continues below, so the model goes back to train mode.
        model.train()
        ema_model.load_state_dict(torch.load(os.path.join(models_dir, "ema_ckpt.pt"), map_location=device))
        ema_model.eval()
        optimizer.load_state_dict(torch.load(os.path.join(models_dir, "optim.pt"), map_location=device))

    losses = []
    best_loss = float("inf")
    for epoch in range(epochs+1):
        print("Starting epoch :"+str(epoch))
        epoch_loss = 0.0
        pbar = tqdm(dataloader)
        for i, (images, labels) in enumerate(pbar):
            # The text is built before the labels may be dropped below.
            te_labels = get_label_index(labels)
            te = t5_encode_text(te_labels)
            te = te[0].to(device)
            images = images.to(device)
            labels = labels.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            if np.random.random() < 0.1:
                labels = None
            predicted_noise = model(x_t, t, labels, te)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema.step_ema(ema_model, model)

            pbar.set_postfix(MSE=loss.item())
            # Sample-weighted mean of the batch losses over the whole dataset.
            epoch_loss += loss.item() * len(images) / len(dataloader.dataset)

        losses.append(epoch_loss)
        log_string = f"Loss at epoch {epoch}: {epoch_loss:.3f}"
        print(log_string)

        # storing the model;
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), os.path.join(models_dir, "ckpt.pt"))
            torch.save(ema_model.state_dict(), os.path.join(models_dir, "ema_ckpt.pt"))
            torch.save(optimizer.state_dict(), os.path.join(models_dir, "optim.pt"))
            print(" --> Best model ever (stored)")

        if epoch != 0 and epoch % 20 == 0:
            #plot_loss(losses)
            # One sample per class.
            labels = torch.arange(num_classes).long().to(device)
            ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels)
            plot_images(ema_sampled_images)
            save_images(ema_sampled_images, os.path.join(results_dir, f"{epoch}_ema.jpg"))

In [None]:
train()

# Generate 500 new images
This section loads the model stored under {path_model} and then generates images for the labels listed in "test.txt".

In [None]:
# Folder whose "models" sub-folder holds the EMA checkpoint loaded by prepare_model.
path_model = "./drive/My Drive/{model_folder}/"
# Folder that receives the generated .jpg files.
path_test_generate = "./drive/My Drive/{test_folder}/"
# Text file read by start_generate_new_image; each line is split on ';' and
# the first two fields are used as file name and label.
path_test_text = "./drive/My Drive/test.txt"

In [None]:
import io


def parse(s):
    """Return ``s`` without leading and trailing whitespace."""
    return s.strip()
def separate(s):
    """Split ``s`` on ``';'`` and return the list of fields."""
    return s.split(';')

def prepare_model():
    """Load the EMA checkpoint and build the objects needed for generation.

    Runs on CUDA when available and falls back to the CPU otherwise.

    :return: ``(diffusion, model)`` with the model in eval mode.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet_conditional(num_classes=21, device=target_device).to(target_device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    ckpt = torch.load(os.path.join(path_model+"models", "ema_ckpt.pt"), map_location=target_device)
    model.load_state_dict(ckpt)
    model.eval()
    diffusion = Diffusion(img_size=64, device=target_device)

    return diffusion, model

def _iteration_suffix(index):
    """Return a spreadsheet-style letter suffix: 0 -> 'A', 25 -> 'Z', 26 -> 'AA', ..."""
    letters = []
    index += 1
    while index > 0:
        index, remainder = divmod(index - 1, 26)
        letters.append(chr(ord('A') + remainder))
    return "".join(reversed(letters))


def generate(name,label,diffusion,model,iter = 10,n = 1):
    """Generate and save ``iter`` image files for ``label``.

    Each file holds a grid of ``n`` freshly sampled images and is written to
    ``path_test_generate`` as ``<name>_<suffix>.jpg``, where the suffix is
    ``A``, ``B``, ... (continuing with ``AA``, ``AB``, ... past ``Z``).

    :param name: base name of the output files.
    :param label: label string used to condition the generation.
    :param diffusion: object providing ``generate_image_from_label``.
    :param model: trained noise-prediction model.
    :param iter: number of files to produce.
    :param n: number of images per file.
    """
    # Image.save does not create missing folders.
    os.makedirs(path_test_generate, exist_ok=True)

    for i in range(iter):
        x = diffusion.generate_image_from_label(model, n, label)
        name_iter = name + '_' + _iteration_suffix(i)
        save_images(x, os.path.join(path_test_generate, f"{name_iter}.jpg"))

    print("Generate images for label "+ label + " name "+name)


def start_generate_new_image():
    """Generate images for every entry of the text file at ``path_test_text``.

    Each non-blank line must look like ``<file name>;<label>``; fields after
    the second are ignored. All lines are validated before the model is
    loaded, so a malformed file fails fast.

    :raises ValueError: if a non-blank line has fewer than two ``;``-separated fields.
    """
    with open(path_test_text, "r") as f:
        lines = [parse(raw_line) for raw_line in f]

    jobs = []
    for line_number, line in enumerate(lines, start=1):
        # Blank lines (e.g. a trailing newline at the end of the file) carry no job.
        if not line:
            continue
        fields = separate(line)
        if len(fields) < 2:
            raise ValueError(
                f"{path_test_text}, line {line_number}: expected '<name>;<label>', got {line!r}"
            )
        jobs.append((fields[0], fields[1]))

    diffusion, model = prepare_model()

    for file_name, file_label in jobs:
        generate(file_name, file_label, diffusion, model)

In [None]:
start_generate_new_image()