# Generating Images of Digits from Text Prompts

This notebook provides you with a complete code example to generate MNIST digits from text prompts. The `train_text_prompts.txt` and `test_text_prompts.txt` files provide the text data used for training and testing.

## Loading the MNIST Dataset with PyTorch

In [None]:
import os

if not os.path.exists("MNIST_dataset"):
    os.system("git clone https://github.com/DeepTrackAI/MNIST_dataset")

train_path = os.path.join("MNIST_dataset", "mnist", "train")
train_images_files = sorted(os.listdir(train_path))

print(len(train_images_files))

Implement the normalization of the images ...

In [2]:
from torchvision.transforms import Compose, Normalize, ToTensor

trans = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5], inplace=True)])

...create the dataset...

In [3]:
import json
import matplotlib.pyplot as plt


MNIST_images, MNIST_labels, MNIST_sentences = [], [],[]

# Load mnist_sentences.json dictionary
with open('mnist_sentences.json', 'r') as file:
    mnist_sentences = json.load(file)

for file in train_images_files:
    image = plt.imread(os.path.join(train_path, file))
    MNIST_images.append(trans(image))
    
    filename = os.path.basename(file)
    label = int(filename[0])
    MNIST_labels.append(label)

    sentence = mnist_sentences.get(filename, "")
    MNIST_sentences.append(sentence)

MNIST_set = list(zip(MNIST_images, MNIST_labels, MNIST_sentences))


... and plot some of the transformed MNIST digits.

In [None]:
import torch

fig, axs = plt.subplots(1, 8, figsize=(15, 3))
for ax in axs.ravel():
    img, label, _ = MNIST_set[torch.randint(0, len(MNIST_set), (1,)).squeeze()]
    ax.imshow(img.squeeze(), cmap="gray")
    ax.set_title(f"Label: {label}", fontsize=16)
    ax.axis("off")
plt.tight_layout()
plt.show()


## Implementing the Forward Process

Define the device on which the computations are performed ...

In [5]:
import torch

def get_device():
    """Select device where to perform computations."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    # elif torch.backends.mps.is_available():
    #    return torch.device("mps")
    else:
        return torch.device("cpu")

In [6]:
device = get_device()

In [None]:
print(device)

... implement the forward diffusion process ...

In [8]:
class Diffusion:
    """Denoising diffusion probabilstic model (DDPM)."""
    
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02,
                 device=device):
        """Initialize the diffusion model."""
        self.noise_steps, self.beta_start, self.beta_end, self.device = \
            noise_steps, beta_start, beta_end, device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1.0 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Prepare the noise schedule."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def forward_diffusion(self, x, t):
        """Implement the forward diffusion process."""
        sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t])[:, None, None, None]
        sqrt_one_minus_alpha_bar = \
            torch.sqrt(1 - self.alpha_bar[t])[:, None, None, None]
        noise = torch.randn_like(x)

        return sqrt_alpha_bar * x + sqrt_one_minus_alpha_bar * noise, noise

... sampling images in the forward diffusion process ..

In [9]:
diffusion = Diffusion(noise_steps=401, beta_start=0.0001, beta_end=0.02)

clean_image, label, _ = MNIST_set[torch.randint(0, len(MNIST_set), (1,)).squeeze()]

time_steps = [0, 100, 200, 300, 400]
noisy_images = []
for i in range(len(time_steps)):
    noisy_image, noise = diffusion.forward_diffusion(
        x=clean_image[None, ...].to(device), 
        t=torch.tensor([time_steps[i]]).to(device)
    )
    noisy_images.append(noisy_image)

... and visualize the noisy digits generated in the forward diffusion process.

In [None]:
fig, axs = plt.subplots(1, len(time_steps))
for i, ax in enumerate(axs.flatten()):
    ax.imshow(noisy_images[i].cpu().numpy().squeeze(), cmap="gray")
    ax.set_title(f"t = {time_steps[i]}", fontsize=10)
    ax.axis("off")
plt.tight_layout()
plt.show()

## Implementing the Reverse Diffusion Process

Update the `Diffusion` class to implement the reverse diffusion process.

In [11]:
from tqdm import tqdm

class Diffusion:
    """Denoising diffusion probabilstic model (DDPM)."""
    
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02,
                 img_size=28, device=device):
        """Initialize the diffusion model."""
        self.noise_steps, self.beta_start, self.beta_end, self.device = \
            noise_steps, beta_start, beta_end, device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1.0 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        
        self.img_size = img_size

    def prepare_noise_schedule(self):
        """Prepare the noise schedule."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def forward_diffusion(self, x, t):
        """Implement the forward diffusion process."""
        sqrt_alpha_bar = torch.sqrt(self.alpha_bar[t])[:, None, None, None]
        sqrt_one_minus_alpha_bar = \
            torch.sqrt(1 - self.alpha_bar[t])[:, None, None, None]
        noise = torch.randn_like(x)

        return sqrt_alpha_bar * x + sqrt_one_minus_alpha_bar * noise, noise
    
    def reverse_diffusion(self, model, n_images, n_channels, 
                          position_encoding_dim, position_encoding_function,
                          fix_noise=None, save_time_steps=None, 
                          context=None, guidance_strength=None):
        """Reverse diffusion process"""
        with torch.no_grad():
            if fix_noise is not None:
                x = fix_noise.to(self.device)
            else:
                x = torch.randn(
                    (n_images, n_channels, self.img_size, self.img_size)
                ).to(self.device)
            
            denoised_images = []
            for i in tqdm(reversed(range(0, self.noise_steps)),
                          desc="U-Net inference", total=self.noise_steps):
                t = (torch.ones(n_images) * i).long()
                t_pos_enc = position_encoding_function(
                    t.unsqueeze(1), position_encoding_dim
                ).to(self.device)

                if context is None:
                    predicted_noise = model(x=x, t=t_pos_enc)
                else:
                    conditional_pred = model(x=x, t=t_pos_enc, context=context)
                    unconditional_pred = model(x=x, t=t_pos_enc, context=None)
                    predicted_noise = torch.lerp(unconditional_pred,
                                                 conditional_pred,
                                                 guidance_strength)

                alpha = self.alpha[t][:, None, None, None]
                alpha_bar = self.alpha_bar[t][:, None, None, None]
                
                noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
                    
                x = (1 / torch.sqrt(alpha) * (x - ((1 - alpha) 
                    / torch.sqrt(1 - alpha_bar)) * predicted_noise) 
                    + torch.sqrt(1-alpha) * noise)

                if i in save_time_steps: denoised_images.append(x)

            denoised_images = torch.stack(denoised_images)
            denoised_images = denoised_images.swapaxes(0, 1)
            return denoised_images

## Defining the Position Encoding Function

Implement the position encoding function ...

In [12]:
def positional_encoding(t, enc_dim):
    """Encode position information with a sinusoid."""
    scaled_positions = torch.arange(0, enc_dim, 2).float() / enc_dim
    frequency = 10000 ** scaled_positions
    inverse_frequency = (1.0 / frequency).to(t.device)
    x = t.repeat(1, enc_dim // 2) * inverse_frequency
    pos_enc_a, pos_enc_b = torch.sin(x), torch.cos(x)
    pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
    return pos_enc

... sinusoidal position encodings for different time steps ...

In [13]:
position_encoding_dim = 256

pos_encs = []
for i in range(0, 100):
    t = torch.tensor([i])
    pos_enc = positional_encoding(t, position_encoding_dim)
    pos_encs.append(pos_enc.squeeze())
pos_encs = torch.stack(pos_encs)

... and visualize the position encodings.

In [None]:
fig = plt.figure()
plt.imshow(pos_encs.cpu().numpy())
plt.xlabel("Encoding dimension")
plt.ylabel("Time step (t)")
plt.show()

## Defining a Custom Tokenizer and Text Encoder

Implement the function to tokenize the text ...

In [15]:
import contractions, re, spacy, unicodedata

tokenizers = {"eng": spacy.blank("en"), "spa": spacy.blank("es")}

regular_expression = r"^[a-zA-Z0-9áéíóúüñÁÉÍÓÚÜÑ.,!?¡¿/:()]+$"
pattern = re.compile(unicodedata.normalize("NFC", regular_expression))

def tokenize(text, lang="eng"):
    """Tokenize text."""
    swaps = {"’": "'", "‘": "'", "“": '"', "”": '"', "´": "'", "´´": '"'}
    for old, new in swaps.items():
        text = text.replace(old, new)
    text = contractions.fix(text) if lang == "eng" else text
    tokens = tokenizers[lang](text)
    return [token.text for token in tokens if pattern.match(token.text)]

In [16]:
class Vocab:
    """Wrapper around a dictionary to make it callable like torchtext's Vocab."""
    def __init__(self, vocab_dict, unk_token="<unk>"):
        self.vocab_dict = vocab_dict
        self.unk_token = unk_token
        self.default_index = vocab_dict.get(unk_token, -1)
        self.index_to_token = {idx: token for token, idx in vocab_dict.items()}
        
    def __call__(self, token_or_tokens):
        """Make the vocab callable to return the index for a given token or list of tokens."""
        if isinstance(token_or_tokens, list):
            return [self.vocab_dict.get(token, self.default_index) for token in token_or_tokens]
        return self.vocab_dict.get(token_or_tokens, self.default_index)
    
    def set_default_index(self, index):
        """Set default index for unknown tokens."""
        self.default_index = index

    def lookup_token(self, index_or_indices):
        """Retrieve the token corresponding to a given index or list of indices."""
        if isinstance(index_or_indices, list):
            return [self.index_to_token.get(int(index), self.unk_token) for index in index_or_indices]
        return self.index_to_token.get(int(index_or_indices), self.unk_token)

    def get_itos(self):
        """Return a list of tokens ordered by their index."""
        itos = [None] * len(self.index_to_token)
        for index, token in self.index_to_token.items():
            itos[index] = token
        return itos
        
    def __iter__(self):
        """Iterate over the tokens in the vocabulary."""
        return iter(self.vocab_dict)

    def __len__(self):
        """Return the number of tokens in the vocabulary."""
        return len(self.vocab_dict)
    
    def __contains__(self, token):
        """Check if a token is in the vocabulary."""
        return token in self.vocab_dict


In [17]:
from collections import Counter

def build_vocab_from_iterator(iterator, specials=None, min_freq=1):
    """Build vocabulary from an iterator over tokenized sentences."""
    token_freq = Counter(token for tokens in iterator for token in tokens)
    vocab, index = {}, 0
    if specials: 
        for token in specials: 
            vocab[token] = index
            index += 1
    for token, freq in token_freq.items():
        if freq >= min_freq:
            vocab[token] = index
            index += 1
    return vocab

... build a custom vocabulary ...

In [18]:
def sentence_iterator(dataset):
    """Iterate over the IMBD dataset."""
    for sample in dataset:
        yield tokenize(sample)

with open('test_sentences.txt', 'r') as file:
    test_sentences = [line.strip() for line in file.readlines()]

vocab_dict = build_vocab_from_iterator(sentence_iterator(MNIST_sentences+test_sentences),
                                  specials=["<unk>"], min_freq=1)

vocab = Vocab(vocab_dict, unk_token="<unk>")
vocab.set_default_index(vocab(vocab.unk_token))

... check the vocabulary and string-to-numerical indices mapping ...

In [None]:
for word, index in vocab_dict.items(): print(f"{word}: {index}")

... write a function to implemet a function to pad and process ...

In [20]:
def pad_and_process(texts, vocab=vocab, max_token_length=77):
    """Tokenize a sentence."""
    batch_tokens = []
    for text_prompt in texts:
        tokens = ([vocab("<sos>")] + 
                  [vocab(token) for token in tokenize(text_prompt)] + 
                  [vocab("<eos>")])

        if len(tokens) > max_token_length:
            tokens = tokens[:max_token_length]
        else:
            tokens += [vocab("<pad>")] * (max_token_length - len(tokens))
        
        tokens = torch.tensor(tokens, dtype=torch.long)
        batch_tokens.append(tokens)
    return torch.stack(batch_tokens)

... observe a tokenized version of an example text prompt ...

In [87]:
_, _, example_text = random.choice(MNIST_set)
tokens = pad_and_process([example_text])
print(example_text), print(tokens.shape), print(tokens);

Near the waterfall, there are eight moss-covered rocks, 4 ferns, and two butterflies. How many insects are here?
torch.Size([1, 77])
tensor([[  0, 108,   2, 464,   4,   5,   6,   9, 465, 466, 467,   4,  60, 468,
           4,  11, 343, 448,  14,  15,  16, 469,   6, 470,  19,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0]])


... implement a class for a custom text encoder ...

In [22]:
import torch.nn as nn

class TextEncoder(nn.Module):
    """Text encoder."""

    def __init__(self, max_token_length, vocab_size, embedding_dim, num_heads):
        """Initialize the text encoder module."""
        super(TextEncoder, self).__init__()
        
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.position_encoding = positional_encoding(
            torch.arange(0, max_token_length).unsqueeze(1), embedding_dim,
        )

        self.self_attention = nn.MultiheadAttention(
            embedding_dim, num_heads=num_heads, batch_first=True,
        )
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim), nn.GELU(),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, tokens):
        """Forward pass of text encoder module."""
        token_embeddings = self.token_embedding(tokens)
        position_encodings = (self.position_encoding
                              .repeat(tokens.size(0), 1, 1).to(tokens.device))

        token_embeddings_with_pos = token_embeddings + position_encodings

        normalized_embeddings = self.layer_norm1(token_embeddings_with_pos)
        attention_output, _ = self.self_attention(
            query=normalized_embeddings, 
            key=normalized_embeddings, 
            value= normalized_embeddings,
        )
        attention_output = attention_output + normalized_embeddings

        residual_input = attention_output
        normalized_attention_output = self.layer_norm2(attention_output)
        feed_forward_output = self.feed_forward(normalized_attention_output)
        residual_output = feed_forward_output + residual_input

        return residual_output

... instantiate a custom text encoder ...

In [23]:
custom_text_encoder = TextEncoder(max_token_length=77, vocab_size=len(vocab),
                                  embedding_dim=768, num_heads=4).to(device)

... and check the shape of example text prompts.

In [88]:
_, _, example_text = random.choice(MNIST_set)
tokens = pad_and_process([example_text])
text_embedding = custom_text_encoder(tokens.to(device))

print(f"tokens shape: {tokens.shape}")
print(f"text embeddings shape: {text_embedding.shape}")

tokens shape: torch.Size([1, 77])
text embeddings shape: torch.Size([1, 77, 768])


## Defining the Conditional Attention U-Net

In [None]:
import deeplay as dl

position_encoding_dim = 256

unet_template = dl.AttentionUNet(  ###
    in_channels=1, channels=[32, 64, 128], base_channels=[256, 256], 
    channel_attention=[True, True, True], out_channels=1, 
    position_embedding_dim=position_encoding_dim, num_classes=10,
    context_embedding_dim=768,
)
unet = unet_template.create()  ### unet.build()
unet.to(device);

print(unet)

## Training the Diffusion Model

Define the data loader ...

In [66]:
from torch.utils.data import DataLoader


loader = DataLoader(dataset=MNIST_set, batch_size=128, shuffle=True)

... define the loss function ...

In [49]:
criterion = torch.nn.MSELoss()

... define the optimizer ...

In [50]:
optimizer = torch.optim.AdamW(list(unet.parameters())
                              + list(custom_text_encoder.parameters()), 
                              lr=1e-4)

...  instantiate the diffusion class for training ...

In [29]:
diffusion = Diffusion(
    noise_steps=1000, img_size=28, beta_start=1e-4, beta_end=0.02,
)

...implement a function to prepare the data ...

In [30]:
def prepare_data(image, noise_steps=1000, device=device):
    """Prepare data."""
    batch_size = image.shape[0]
    t = torch.randint(low=0, high=noise_steps, size=(batch_size,)).to(device)
    image = image.to(device)
    x_t, noise = diffusion.forward_diffusion(image, t)
    t = positional_encoding(t.unsqueeze(1), position_encoding_dim)
    return x_t.to(device), t.to(device), noise.to(device)

... implement the training cycle ...

In [None]:
import numpy as np
import time
from datetime import timedelta
import random

epochs = 20
n_images = 5

example_texts = random.sample(test_sentences, n_images)

save_time_steps = [999, 900, 800, 700, 600, 500, 400, 300, 200, 100, 0]

train_loss = []
for epoch in range(epochs):
    start_time = time.time()
    num_batches = len(loader)

    print("\n" + f"Epoch {epoch + 1}/{epochs}" + "\n" + "_" * 10)
    
    unet.train(), custom_text_encoder.train()
    
    running_loss = 0.0
    for batch_idx, (images, class_labels, text_inputs) in enumerate(loader, start=0):
        x_t, t, noise = prepare_data(images)

        # Custom Tokenizer.
        tokens = pad_and_process(text_inputs)
        text_embeddings = custom_text_encoder(tokens.to(device))

        # CLIP.
        """tokens = CLIP_tokenizer(
            text_inputs, padding="max_length",
            max_length=CLIP_tokenizer.model_max_length, truncation=True,
            return_tensors="pt",
        )
        text_embeddings = CLIP_text_encoder(tokens.input_ids.to(device))[0]"""
        
        context = None if np.random.rand() < 0.1 else text_embeddings

        outputs = unet(x=x_t, t=t, context=context)

        optimizer.zero_grad()
        loss = criterion(outputs, noise)
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx + 1}/{num_batches}: "
                  + f"Train loss: {loss.item():.4f}")
        running_loss += loss.item()

    train_loss.append(running_loss / len(loader))
    end_time = time.time()

    print("-" * 10 + "\n" + f"Epoch {epoch + 1}/{epochs} : " 
          + f"Train loss: {train_loss[-1]:.4f}, " 
          + f"Time taken: {timedelta(seconds=end_time - start_time)}")
    
    unet.eval(), custom_text_encoder.eval()

    # example_texts = random.sample(test_sentences, n_images)
        
    # Custom Tokenizer.
    tokens = pad_and_process(example_texts)
    text_embeddings = custom_text_encoder(tokens.to(device))
    
    # CLIP.
    """tokens = CLIP_tokenizer(
        example_texts, padding="max_length",
        max_length=CLIP_tokenizer.model_max_length, truncation=True,
        return_tensors="pt",
    )
    text_embeddings = CLIP_text_encoder(tokens.input_ids.to(device))[0]"""
    
    generated_images = diffusion.reverse_diffusion(
        model=unet, n_images=n_images, n_channels=1, 
        position_encoding_dim=position_encoding_dim,
        position_encoding_function=positional_encoding,
        save_time_steps=save_time_steps, 
        context=text_embeddings, guidance_strength=3,
    )

    fig = plt.figure(figsize=(len(save_time_steps), 1.2 * n_images))
    for idx in range(n_images):
        image_reverse_diff_traj = generated_images[idx]
        for j in range(len(image_reverse_diff_traj)):
            plt.subplot(n_images, len(image_reverse_diff_traj), 
                        idx * len(image_reverse_diff_traj) + j + 1)
            plt.imshow(image_reverse_diff_traj[j]
                       .permute(1, 2, 0).cpu().numpy(), cmap="gray")
            if j == 5: plt.title(example_texts[idx], fontsize=10)
            plt.axis("off")
    plt.show()
    plt.close()

In [None]:
unet.eval(), custom_text_encoder.eval()

''' class_labels = random.sample(list(range(10)), n_images)
example_texts = generate_text_prompts(class_labels, test_text_prompts)'''

example_texts = random.sample(test_sentences, n_images)


# Custom Tokenizer.
tokens = pad_and_process(example_texts)
text_embeddings = custom_text_encoder(tokens.to(device))

generated_images = diffusion.reverse_diffusion(
        model=unet, n_images=n_images, n_channels=1, 
        position_encoding_dim=position_encoding_dim,
        position_encoding_function=positional_encoding,
        save_time_steps=save_time_steps, 
        context=text_embeddings, guidance_strength=3,
    )

fig = plt.figure(figsize=(len(save_time_steps), 1.2 * n_images))
for idx in range(n_images):
    image_reverse_diff_traj = generated_images[idx]
    for j in range(len(image_reverse_diff_traj)):
        plt.subplot(n_images, len(image_reverse_diff_traj), 
                    idx * len(image_reverse_diff_traj) + j + 1)
        plt.imshow(image_reverse_diff_traj[j]
                    .permute(1, 2, 0).cpu().numpy(), cmap="gray")
        if j == 5: plt.title(example_texts[idx], fontsize=10)
        plt.axis("off")
plt.show()
# plt.savefig(f"fig_10_B1.pdf", bbox_inches="tight")  
# plt.close()

## Defining CLIP Tokenizer and CLIP Text Encoder

Import the CLIP tokenizer and CLIP text encoder ...

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer

CLIP_tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=torch.float32,
)
CLIP_text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=torch.float32,
).to(device)

... check the tokens and the text embeddings of an example text prompt generated by CLIP ...

In [None]:
example_text = random.choice(test_sentences)

tokens = CLIP_tokenizer(
    example_text, padding="max_length", 
    max_length=CLIP_tokenizer.model_max_length,
    truncation=True, return_tensors="pt",
)
text_embedding = CLIP_text_encoder(tokens.input_ids.to(device))[0]

print(f"text: {example_text}")
print(f"tokens: {tokens.input_ids}")
print(f"tokens shape: {tokens.input_ids.shape}")
print(f"text embeddings: {text_embedding.shape}")

... and freeze the weights of CLIP.

In [None]:
for param in CLIP_text_encoder.parameters():
    param.requires_grad = False

### Defining the Conditional Attention U-Net for CLIP

In [None]:
import deeplay as dl

position_encoding_dim = 256

unet_template = dl.AttentionUNet(  ###
    in_channels=1, channels=[32, 64, 128], base_channels=[256, 256], 
    channel_attention=[True, True, True], out_channels=1, 
    position_embedding_dim=position_encoding_dim, num_classes=10,
    context_embedding_dim=768,
)
unet_clip = unet_template.create()  ### unet.build()
unet_clip.to(device);  ###

print(unet_clip)  ###

## Training the Diffusion Model with CLIP

Remove the `custom_text_encoder` parameters from the optimizer ...

In [None]:
optimizer_clip = torch.optim.AdamW(unet_clip.parameters(), lr=1e-4)  ###

... remove `custom_text_encoder` and add CLIP to training cycle ...

In [None]:
import numpy as np
import time
from datetime import timedelta

epochs = 20
n_images = 5

example_texts = random.sample(test_sentences, n_images)

save_time_steps = [999, 900, 800, 700, 600, 500, 400, 300, 200, 100, 0]

train_loss = []
for epoch in range(epochs):
    start_time = time.time()
    num_batches = len(loader)

    print("\n" + f"Epoch {epoch + 1}/{epochs}" + "\n" + "_" * 10)
    
    unet_clip.train() ###, custom_text_encoder.train()
    
    running_loss = 0.0
    for batch_idx, (images, class_labels, text_inputs) in enumerate(loader, start=0):
        x_t, t, noise = prepare_data(images)

        # Custom Tokenizer.
        """tokens = pad_and_process(text_inputs)
        text_embeddings = custom_text_encoder(tokens.to(device))"""

        # CLIP.
        tokens = CLIP_tokenizer(
            text_inputs, padding="max_length",
            max_length=CLIP_tokenizer.model_max_length, truncation=True,
            return_tensors="pt",
        )
        text_embeddings = CLIP_text_encoder(tokens.input_ids.to(device))[0]
        
        context = None if np.random.rand() < 0.1 else text_embeddings
        
        outputs = unet_clip(x=x_t, t=t, context=context)  ###

        optimizer_clip.zero_grad()
        loss = criterion(outputs, noise)
        loss.backward()
        optimizer_clip.step()
        
        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx + 1}/{num_batches}: "
                  + f"Train loss: {loss.item():.4f}")
        running_loss += loss.item()

    train_loss.append(running_loss / len(loader))
    end_time = time.time()

    print("-" * 10 + "\n" + f"Epoch {epoch + 1}/{epochs} : " 
          + f"Train loss: {train_loss[-1]:.4f}, " 
          + f"Time taken: {timedelta(seconds=end_time - start_time)}")
    
    unet_clip.eval() ###, custom_text_encoder.eval()

    # example_texts = random.sample(test_sentences, n_images)


    # Custom Tokenizer.
    """tokens = pad_and_process(example_texts)
    text_embeddings = custom_text_encoder(tokens.to(device))"""
    
    # CLIP.
    tokens = CLIP_tokenizer(
        example_texts, padding="max_length",
        max_length=CLIP_tokenizer.model_max_length, truncation=True,
        return_tensors="pt",
    )
    text_embeddings = CLIP_text_encoder(tokens.input_ids.to(device))[0]
    
    generated_images = diffusion.reverse_diffusion(
        model=unet_clip, n_images=n_images, n_channels=1,  ###
        position_encoding_dim=position_encoding_dim,
        position_encoding_function=positional_encoding,
        save_time_steps=save_time_steps, 
        context=text_embeddings, guidance_strength=3,
    )

    fig = plt.figure(figsize=(len(save_time_steps), 1.2 * n_images))
    for idx in range(n_images):
        image_reverse_diff_traj = generated_images[idx]
        for j in range(len(image_reverse_diff_traj)):
            plt.subplot(n_images, len(image_reverse_diff_traj), 
                        idx * len(image_reverse_diff_traj) + j + 1)
            plt.imshow(image_reverse_diff_traj[j]
                       .permute(1, 2, 0).cpu().numpy(), cmap="gray")
            if j == 5: plt.title(example_texts[idx], fontsize=10)
            plt.axis("off")
    plt.show()
    plt.close()

In [None]:
unet_clip.eval()

''' class_labels = random.sample(list(range(10)), n_images)
example_texts = generate_text_prompts(class_labels, test_text_prompts)'''

example_texts = random.sample(test_sentences, n_images)

# Custom Tokenizer.
# tokens = pad_and_process(example_texts)
# text_embeddings = custom_text_encoder(tokens.to(device))

tokens = CLIP_tokenizer(
    example_texts, padding="max_length",
    max_length=CLIP_tokenizer.model_max_length, truncation=True,
    return_tensors="pt",
)
text_embeddings = CLIP_text_encoder(tokens.input_ids.to(device))[0]

generated_images = diffusion.reverse_diffusion(
        model=unet_clip, n_images=n_images, n_channels=1, 
        position_encoding_dim=position_encoding_dim,
        position_encoding_function=positional_encoding,
        save_time_steps=save_time_steps, 
        context=text_embeddings, guidance_strength=3,
    )

fig = plt.figure(figsize=(len(save_time_steps), 1.2 * n_images))
for idx in range(n_images):
    image_reverse_diff_traj = generated_images[idx]
    for j in range(len(image_reverse_diff_traj)):
        plt.subplot(n_images, len(image_reverse_diff_traj), 
                    idx * len(image_reverse_diff_traj) + j + 1)
        plt.imshow(image_reverse_diff_traj[j]
                    .permute(1, 2, 0).cpu().numpy(), cmap="gray")
        if j == 5: plt.title(example_texts[idx], fontsize=10)
        plt.axis("off")
plt.show()
# plt.savefig(f"fig_10_B2.pdf", bbox_inches="tight")
# plt.close()