# Dataset manipulations

In [1]:
!pip install torch==2.3.0
!pip install torchvision==0.18.0
!pip install torchtext==0.18.0
!pip install kagglehub

Collecting torch==2.3.0
  Downloading torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.3.0)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.3.0)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.3.0)
  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.3.0)
  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylin

In [2]:
import kagglehub
import pandas as pd
# Fetch the CelebA dataset via kagglehub. `path` is used further down as the
# root folder for the attribute CSV and the aligned-image directory.
path = kagglehub.dataset_download("jessicali9530/celeba-dataset")

In [3]:
from torchvision.io import read_image
import torchvision.transforms as transforms
from torchvision.transforms import ToTensor
import torch
from torch.utils.data import Dataset, DataLoader
import os

# Maps every CelebA attribute column to the word used in the text description
# when that attribute is NOT set for an image (see CustomDataset.__getitem__).
# The keys and values together also form the tokenizer vocabulary.
# NOTE(review): 'Wavy_Hair' maps to 'Straight_Hair', which is also a positive
# attribute key, so the same word can appear twice (or contradict 'Curly_Hair')
# in one description. Renaming it adds a vocabulary entry, so the vocabulary
# size hard-coded in TextEncoder (81) would have to be updated at the same time.
opposites = {
  '5_o_Clock_Shadow': 'Clean_Shaven',
  'Arched_Eyebrows': 'Flat_Eyebrows',
  'Attractive': 'Unattractive',
  'Bags_Under_Eyes': 'Clean_Eyes',
  'Bald': 'Has_Hair',
  'Bangs': 'No_Bangs',
  'Big_Lips': 'Thin_Lips',
  'Big_Nose': 'Small_Nose',
  'Black_Hair': 'Bright_Hair',
  'Blond_Hair': 'Non_Blond_Hair',
  'Blurry': 'Clear',
  'Brown_Hair': 'Non_Brown_Hair',
  'Bushy_Eyebrows': 'Thin_Eyebrows',
  'Chubby': 'Slim',
  'Double_Chin': 'Single_Chin',
  'Eyeglasses': 'No_Eyeglasses',
  'Goatee': 'No_Goatee',
  'Gray_Hair': 'Non_Gray_Hair',
  'Heavy_Makeup': 'No_Makeup',
  'High_Cheekbones': 'Low_Cheekbones',
  'Male': 'Female',
  'Mouth_Slightly_Open': 'Mouth_Closed',
  'Mustache': 'No_Mustache',
  'Narrow_Eyes': 'Wide_Eyes',
  'No_Beard': 'Beard',
  'Oval_Face': 'Round_Face',
  'Pale_Skin': 'Tan_Skin',
  'Pointy_Nose': 'Flat_Nose',
  'Receding_Hairline': 'Full_Hairline',
  'Rosy_Cheeks': 'Pale_Cheeks',
  'Sideburns': 'No_Sideburns',
  'Smiling': 'Not_Smiling',
  'Straight_Hair': 'Curly_Hair',
  'Wavy_Hair': 'Straight_Hair',  # collides with the 'Straight_Hair' key above
  'Wearing_Earrings': 'No_Earrings',
  'Wearing_Hat': 'No_Hat',
  'Wearing_Lipstick': 'No_Lipstick',
  'Wearing_Necklace': 'No_Necklace',
  'Wearing_Necktie': 'No_Necktie',
  'Young': 'Old'
}

class CustomDataset(Dataset):
    """Dataset yielding ``(image_tensor, description)`` pairs.

    The annotations CSV is expected to hold the image file name in its first
    column and one attribute per remaining column. An attribute contributes
    its own column name to the description when its value equals 1, otherwise
    the word looked up in ``opposites``.

    Args:
        annotations_file: Path of the CSV file read with ``pd.read_csv``.
        img_dir: Directory that contains the image files named in the CSV.
        opposites: Mapping from attribute column name to its "negative" word.
        image_transform: Callable applied to the tensor returned by ``read_image``.
    """

    def __init__(self, annotations_file, img_dir, opposites,image_transform):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.opposites = opposites
        self.image_transform = image_transform

    def __len__(self):
        """Number of rows (images) in the annotations table."""
        return self.img_labels.shape[0]

    def __getitem__(self, image_num):
        """Return the transformed image and its space-separated description."""
        file_name = self.img_labels.iloc[image_num, 0]
        raw_image = read_image(os.path.join(self.img_dir, file_name))
        image_tensor = self.image_transform(raw_image)

        row = self.img_labels.iloc[image_num]
        words = []
        # The first column is the file name; every other column is an attribute.
        for attribute in self.img_labels.columns[1:]:
            if row[attribute] == 1:
                words.append(attribute)
            else:
                words.append(self.opposites[attribute])

        return image_tensor, ' '.join(words)

# Preprocessing applied to every image tensor produced by read_image.
image_transform = transforms.Compose([
    transforms.ConvertImageDtype(torch.float),  # Convert to float in range [0, 1]
    # transforms.CenterCrop((128, 128)),
    transforms.Resize((256,256)),  # Aspect ratio is not preserved: (178*216)->(256*256)
    transforms.Normalize([0.5], [0.5])  # (x - 0.5) / 0.5
])
# Dataset over the downloaded attribute CSV and aligned-image folder,
# served in shuffled batches of 32.
dataset=CustomDataset(path+'/list_attr_celeba.csv',path+'/img_align_celeba/img_align_celeba',opposites,image_transform)
dataloader=DataLoader(dataset,batch_size=32,shuffle=True)

In [4]:
import torch
from torch import nn
from torch.nn import functional as F
import math
import numpy as np

# Noising image

In [5]:
class Scheduler:
    """DDPM noise scheduler with a linear beta schedule.

    Precomputes ``betas``, ``alphas`` and the cumulative product of the alphas,
    and provides the closed-form forward (noising) process together with one
    reverse (denoising) step.

    Args:
        num_steps: Number of diffusion timesteps in the schedule.
        step_size: Stride between the timesteps visited by ``denoising_step``
            (1 visits every timestep; a larger stride skips timesteps to make
            inference faster).
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
        img_size: Stored on the instance; not used by the methods of this class.
    """

    def __init__(self, num_steps=1000, step_size=1,beta_start=0.0001, beta_end=0.008, img_size=128):
        # linear schedule
        self.betas=torch.linspace(beta_start,beta_end,num_steps,dtype=torch.float32)
        self.alphas=1.0-self.betas
        self.alphas_cumprod=torch.cumprod(self.alphas,0)

        self.img_size=img_size
        self.amount_steps=num_steps
        self.step_size=step_size # =20 on inference mode
        # Timesteps in descending order: num_steps-1, ..., 0
        self.steps_tensor = torch.from_numpy(np.arange(0, num_steps)[::-1].copy())
        self.current_step=num_steps-1

    def _alpha_cumprod_like(self, timestep, reference):
        """Look up alpha_cumprod[timestep], shaped and typed to broadcast against ``reference``."""
        if torch.is_tensor(timestep):
            # Index on the device where the schedule lives, then move the result.
            timestep = timestep.to(device=self.alphas_cumprod.device, dtype=torch.long)
        coef = self.alphas_cumprod[timestep].to(device=reference.device, dtype=reference.dtype)
        # A per-sample vector of shape (B,) becomes (B, 1, 1, ...) so it scales whole samples.
        while coef.dim() < reference.dim():
            coef = coef.unsqueeze(-1)
        return coef

    def add_noise(self, original_images, timestep, noise=None):
        """Sample x_t from q(x_t | x_0) in closed form.

        ``x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise``

        Args:
            original_images: Clean images or latents; the leading dimension is
                the batch when ``timestep`` is a 1-D tensor.
            timestep: Python int, 0-dim tensor, or 1-D tensor with one timestep
                per sample.
            noise: Optional pre-sampled noise with the shape of
                ``original_images``. Standard normal noise is drawn when omitted.

        Returns:
            The noised images, same shape as ``original_images``.
        """
        alpha_cumprod=self._alpha_cumprod_like(timestep, original_images)
        signal_coef=alpha_cumprod**0.5
        # The noise weight is the standard deviation sqrt(1 - alpha_cumprod), not the variance.
        noise_coef=(1.0-alpha_cumprod)**0.5
        if noise is None:
            noise=torch.randn(original_images.shape, dtype=original_images.dtype, device=original_images.device)
        return original_images*signal_coef+noise_coef*noise

    def denoising_step(self, timestep_current, noised_image_latent_current, noise_latent_current):
        """Predict the latent at the previous timestep from the current one.

        Args:
            timestep_current: Current timestep (int or 0-dim tensor).
            noised_image_latent_current: Noisy latent at ``timestep_current``.
            noise_latent_current: Noise predicted by the UNet for that latent.

        Returns:
            A sample of the latent at ``timestep_current - step_size``.
        """
        timestep_current=int(timestep_current)
        # Move back by the configured stride. The previous expression,
        # t - amount_steps // step_size, was negative for every t at the
        # default step_size=1.
        timestep_previous=timestep_current-self.step_size

        # get precomputed cumulative alpha - coefficients of original image
        alpha_cumprod_current=self.alphas_cumprod[timestep_current]
        if timestep_previous>=0:
            alpha_cumprod_previous=self.alphas_cumprod[timestep_previous]
        else:
            # Before the first timestep the image is noise-free.
            alpha_cumprod_previous=torch.ones_like(alpha_cumprod_current)
        # get cumulative betas - coefficients of noise
        beta_cumprod_current=1.0-alpha_cumprod_current
        beta_cumprod_previous=1.0-alpha_cumprod_previous
        # effective alpha/beta of the (possibly strided) step, from the cumulative values
        alpha_current=alpha_cumprod_current/alpha_cumprod_previous
        beta_current=1.0-alpha_current

        # Estimate x_0 by inverting the forward process with the predicted noise.
        predicted_original_image_latent=(noised_image_latent_current-(beta_cumprod_current**0.5)*noise_latent_current)/(alpha_cumprod_current**0.5)
        # Posterior mean = weighted sum of the x_0 estimate and the current latent.
        coef_original=((alpha_cumprod_previous**0.5)*beta_current)/beta_cumprod_current
        coef_current=(alpha_current**0.5)*beta_cumprod_previous/beta_cumprod_current
        predicted_previous_image_latent=coef_current*noised_image_latent_current+coef_original*predicted_original_image_latent

        # The last step is deterministic; no noise is added.
        if timestep_current<=0:
            return predicted_previous_image_latent

        # Posterior variance; the sampled noise is scaled by its square root.
        variance=beta_cumprod_previous/beta_cumprod_current*beta_current
        variance=torch.clamp(variance,min=1e-20)
        uncertainty=torch.randn(noise_latent_current.shape,dtype=noise_latent_current.dtype,device=noise_latent_current.device)
        return predicted_previous_image_latent+(variance**0.5)*uncertainty

    def set_inference_mode(self, step_size=20):
        """Switch to a larger timestep stride for faster sampling."""
        self.step_size=step_size

    def reset_parameters(self,step_size=1):
        """Restore the stride and rewind ``current_step`` to the last timestep."""
        self.step_size=step_size
        self.current_step=self.amount_steps-1

In [6]:
# import matplotlib.pyplot as plt
# ddpmsampler=Scheduler()

# for images_batch,descriptions_batch in dataloader:
#   print(images_batch.shape)
#   print(descriptions_batch)
#   noised_images_batch=ddpmsampler.add_noise(images_batch,399)
#   batch_size = noised_images_batch.size(0)
#   fig, axes = plt.subplots(1, batch_size, figsize=(15, 5))

#   for i, image in enumerate(noised_images_batch):
#       axes[i].imshow(image.permute(1, 2, 0).cpu().numpy())
#       axes[i].axis('off')

#   plt.show()
#   break
  # TODO: fix clipping for better output

# Variational Autoencoder

First, the high-resolution input is encoded into a lower-dimensional (latent) space, so that the matrices are smaller and less computation is required. Thus, we need an Encoder part and a Decoder part.

In [7]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Args:
        amount_heads: Number of attention heads; must divide ``dim_q``.
        dim_q: Embedding size of the query (and of the output).
        dim_kv: Embedding size of key/value inputs. ``None`` means the same
            size as the query.
        qkv_bias: Whether the Q/K/V projections have a bias term.
        o_bias: Whether the output projection has a bias term.

    Raises:
        ValueError: If ``dim_q`` is not divisible by ``amount_heads``.
    """

    def __init__(self, amount_heads, dim_q, dim_kv=None,qkv_bias=True,o_bias=True):
        super().__init__()
        if dim_q % amount_heads != 0:
            raise ValueError(
                f"dim_q ({dim_q}) must be divisible by amount_heads ({amount_heads})"
            )
        # Multihead to capture different relationships
        self.amount_heads=amount_heads
        self.head_dim=dim_q//amount_heads

        kv_in_features=dim_q if dim_kv is None else dim_kv
        # Learnable projections of the embeddings into query/key/value roles
        self.Q_matrix=nn.Linear(dim_q,dim_q,bias=qkv_bias)
        self.K_matrix=nn.Linear(kv_in_features,dim_q,bias=qkv_bias)
        self.V_matrix=nn.Linear(kv_in_features,dim_q,bias=qkv_bias)
        # Project back to initial embedding space
        self.O_matrix=nn.Linear(dim_q,dim_q,bias=o_bias)

    def forward(self, query, key=None, value=None, selfattention_mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: Tensor of shape (batch, seq_len, dim_q).
            key: Tensor whose last dimension is ``dim_kv``; defaults to
                ``query`` (self-attention).
            value: Same layout as ``key``; defaults to ``key``.
            selfattention_mask: When set to anything other than ``None`` or
                ``False``, a causal (upper-triangular) mask is applied so a
                position cannot attend to later positions.

        Returns:
            Tensor with the same shape as ``query``.
        """
        input_shape=query.shape
        batch_size, sequence_length, d_embed = query.shape
        if key is None:
            key=query
        if value is None:
            value=key

        q_shape=(batch_size, sequence_length, self.amount_heads, self.head_dim)
        # NOTE(review): the key/value length is inferred with -1, so a context whose
        # batch size differs from the query's is silently folded into the sequence
        # dimension instead of raising — confirm this is intended by the callers.
        kv_shape=(batch_size, -1, self.amount_heads, self.head_dim)

        # Projections, split into heads: (batch, heads, seq, head_dim)
        q_proj=self.Q_matrix(query).view(q_shape).transpose(1, 2)
        k_proj=self.K_matrix(key).view(kv_shape).transpose(1, 2)
        v_proj=self.V_matrix(value).view(kv_shape).transpose(1, 2)

        attention_logits=q_proj @ k_proj.transpose(-1, -2)
        if selfattention_mask is not None and selfattention_mask is not False:
            # True above the principal diagonal -> those (future) positions are hidden
            mask=torch.ones_like(attention_logits, dtype=torch.bool).triu(1)
            attention_logits=attention_logits.masked_fill(mask, -torch.inf)

        # Scale by sqrt(head_dim) to keep the softmax input in a stable range
        attention_logits=attention_logits/math.sqrt(self.head_dim)
        attention_weights=F.softmax(attention_logits, dim=-1)

        context_aware_x=attention_weights @ v_proj
        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, dim_q)
        context_aware_x=context_aware_x.transpose(1, 2).contiguous().view(input_shape)
        return self.O_matrix(context_aware_x)

In [8]:
class VAEMultiheadAttention(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The (batch, channels, height, width) input is group-normalised, flattened
    into a sequence of height*width tokens with ``channels`` features each,
    passed through self-attention, reshaped back and added to the input.
    """

    def __init__(self,channels):
        super().__init__()
        # Normalise the convolutional features in 16 channel groups
        self.groupnorm=nn.GroupNorm(16,channels)
        self.multiheadattention=MultiheadAttention(1,channels) #TODO: check self/cross

    def forward(self,x):
        skip=x
        normed=self.groupnorm(x)

        n,c,h,w=normed.shape
        # (n, c, h, w) -> (n, c, h*w) -> (n, h*w, c): one token per pixel
        tokens=normed.view(n,c,h*w).transpose(-1,-2)

        attended=self.multiheadattention(tokens)

        # (n, h*w, c) -> (n, c, h*w) -> (n, c, h, w)
        restored=attended.transpose(-1,-2).view(n,c,h,w)
        return restored+skip

class VAEResidualBlock(nn.Module):
    """Two (GroupNorm -> SiLU -> 3x3 conv) stages with a skip connection.

    When the channel count changes, the skip path uses a 1x1 convolution so
    that it matches the output; otherwise it is the identity.
    """

    def __init__(self,in_channels,out_channels):
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(16, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.groupnorm_2 = nn.GroupNorm(16, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        self.residual_layer = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self,x):
        # Skip path, projected to the output channel count if needed
        shortcut = self.residual_layer(x)

        hidden = self.conv_1(F.silu(self.groupnorm_1(x)))
        hidden = self.conv_2(F.silu(self.groupnorm_2(hidden)))

        return hidden + shortcut

In [9]:
import torch.nn as nn
class VAEEncoder(nn.Module):
    """Convolutional VAE encoder that maps an image to a sampled latent.

    The stack reduces the spatial size by 8 (three stride-2 convolutions) and
    ends with 8 channels, which are split into a 4-channel mean and a
    4-channel log-variance. The returned latent is a reparameterised sample
    scaled by 0.18215.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            # (B, 3, H, W) -> (B, 32, H, W)
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            VAEResidualBlock(32, 32),
            VAEResidualBlock(32, 32),
            # downsample: (B, 32, H, W) -> (B, 32, H/2, W/2)
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            VAEResidualBlock(32, 64),
            VAEResidualBlock(64, 64),
            # downsample: (B, 64, H/2, W/2) -> (B, 64, H/4, W/4)
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            VAEResidualBlock(64, 128),
            VAEResidualBlock(128, 128),
            # downsample: (B, 128, H/4, W/4) -> (B, 128, H/8, W/8)
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            VAEResidualBlock(128, 128),
            VAEResidualBlock(128, 128),
            VAEResidualBlock(128, 128),
            # attention keeps the shape: (B, 128, H/8, W/8)
            VAEMultiheadAttention(128),
            VAEResidualBlock(128, 128),
            nn.GroupNorm(16, 128),
            nn.SiLU(),
            # 8 output channels = 4 for the mean + 4 for the log-variance
            nn.Conv2d(128, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 8, kernel_size=1, padding=0)
        ])

    def forward(self, x):
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden)

        mean, log_variance = hidden.chunk(2, dim=1)
        # Clamp before exponentiating to keep the variance in a finite range
        stdev = log_variance.clamp(-30, 20).exp().sqrt()
        # Reparameterisation: mean + stdev * eps, with eps ~ N(0, I)
        noise = torch.randn(stdev.shape, dtype=stdev.dtype, device=stdev.device)
        latent = mean + stdev * noise
        return latent * 0.18215


class VAEDecoder(nn.Module):
    """Convolutional VAE decoder that maps a 4-channel latent back to an image.

    The stack upsamples the spatial size by 8 (three ``nn.Upsample`` stages
    with scale factor 2) and ends with 3 output channels.
    """

    # Must equal the factor VAEEncoder.forward multiplies its latent by.
    LATENT_SCALE = 0.18215

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(4, 4, kernel_size=1, padding=0),
            nn.Conv2d(4, 128, kernel_size=3, padding=1),
            VAEResidualBlock(128, 128),
            VAEMultiheadAttention(128),
            VAEResidualBlock(128, 128),
            VAEResidualBlock(128, 128),
            VAEResidualBlock(128, 128),
            VAEResidualBlock(128, 128),
            # (B, 128, H/8, W/8) -> (B, 128, H/4, W/4)
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            VAEResidualBlock(128, 128),
            VAEResidualBlock(128, 128),
            VAEResidualBlock(128, 128),
            # (B, 128, H/4, W/4) -> (B, 128, H/2, W/2)
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            VAEResidualBlock(128, 64),
            VAEResidualBlock(64, 64),
            VAEResidualBlock(64, 64),
            # (B, 64, H/2, W/2) -> (B, 64, H, W)
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            VAEResidualBlock(64, 32),
            VAEResidualBlock(32, 32),
            VAEResidualBlock(32, 32),
            nn.GroupNorm(4, 32),
            nn.SiLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1)
        ])

    def forward(self, x):
        """Decode a latent of shape (B, 4, h, w) into an image of shape (B, 3, 8h, 8w)."""
        # Undo the encoder's latent scaling BEFORE decoding. The previous code
        # divided the decoded image by a mistyped constant (0.18125) instead.
        x = x / self.LATENT_SCALE
        for module in self.layers:
            x = module(x)
        return x

#TODO: channels*=2

In [10]:
# encoder=VAEEncoder()
# decoder=VAEDecoder()

# for images_batch,descriptions_batch in dataloader:
#   encoded_images=encoder(images_batch)
#   decoded_images=decoder(encoded_images)
#   print('decoded shape',decoded_images.shape)

#   batch_size = encoded_images.size(0)
#   fig, axes = plt.subplots(1, batch_size, figsize=(15, 5))

#   for i, image in enumerate(encoded_images):
#       axes[i].imshow(image.permute(1, 2, 0).cpu().detach().numpy())
#       axes[i].axis('off')
#   break

# Text Encoder

At each step of the reverse diffusion, CLIP can compare the embedding of the currently generated image with the target text embedding and adjust the denoising direction accordingly. To compare the input prompt with the denoised image at the current timestep, CLIP must learn projections that map the initial image and text embeddings into the same space.

Once the similarity score is received, subtract it from the loss: `MSELoss(real_noise, predicted_unet_noise) - weight * similarity`

This trades the smooth loss landscape of the vanilla objective for the new text-guidance feature.

In [11]:
# from transformers import AutoModel, AutoTokenizer, BertTokenizer
# from torchvision import models
# from dataclasses import dataclass

# dataclass
# class Config:
#     embed_dim: int = 512
#     transformer_embed_dim: int = 768
#     max_len: int = 32
#     text_model: str = "distilbert-base-multilingual-cased"
#     epochs: int = 5
#     batch_size: int = 128

# class Projection(nn.Module):
#     # Project initial embedding into common space
#     def __init__(self, d_in: int, d_out: int, p: float=0.5) -> None:
#         super().__init__()
#         self.linear1 = nn.Linear(d_in, d_out, bias=False)
#         self.linear2 = nn.Linear(d_out, d_out, bias=False)
#         self.layer_norm = nn.LayerNorm(d_out)
#         self.drop = nn.Dropout(p)

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         embed1 = self.linear1(x)
#         embed2 = self.drop(self.linear2(F.gelu(embed1))) # non-linearity and small gradient fix
#         embeds = self.layer_norm(embed1 + embed2)
#         return embeds

# class VisionEncoder(nn.Module):
#     def __init__(self, d_out: int) -> None:
#         super().__init__()
#         self.base = models.resnet34(pretrained=True)
#         d_in = self.base.fc.in_features
#         self.base.fc = nn.Identity() #classification layer=I, resnet should return pure embeddings
#         self.projection = Projection(d_in, d_out)
#         for p in self.base.parameters():
#             p.requires_grad = False

#     def forward(self, x):
#         projected_vec = self.projection(self.base(x))
#         projection_len = torch.norm(projected_vec, dim=-1, keepdim=True)
#         return projected_vec / projection_len # normalize to len==1

# class TextEncoder(nn.Module):
#     def __init__(self, d_out: int) -> None:
#       super().__init__()
#       self.base = AutoModel.from_pretrained(Config.text_model)
#       self.projection = Projection(Config.transformer_embed_dim, d_out)
#       for p in self.base.parameters():
#           p.requires_grad = False

#     def forward(self, x):
#       #expects {input_ids:,attention_mask:} input from tokenizer
#       out = self.base(x)[0]#? [0]
#       out = out[:, 0, :]  # get CLS token output = embedding #?[:,0,:]
#       projected_vec = self.projection(out)
#       projection_len = torch.norm(projected_vec, dim=-1, keepdim=True)
#       return projected_vec / projection_len

# def CLIP_loss(logits: torch.Tensor) -> torch.Tensor:
#     n = logits.shape[1]      # number of samples
#     labels = torch.arange(n) # Create labels tensor
#     #  image-to-caption similarity
#     loss_i = F.cross_entropy(logits.transpose(0, 1), labels, reduction="mean")
#     # caption-to-image similarity
#     loss_t = F.cross_entropy(logits, labels, reduction="mean")
#     # Calculate the final loss
#     loss = (loss_i + loss_t) / 2

#     return loss

# def metrics(similarity: torch.Tensor):
#     y = torch.arange(len(similarity)).to(similarity.device)
#     img2cap_match_idx = similarity.argmax(dim=1)
#     cap2img_match_idx = similarity.argmax(dim=0)

#     img_acc = (img2cap_match_idx == y).float().mean()
#     cap_acc = (cap2img_match_idx == y).float().mean()

#     return img_acc, cap_acc

# class Tokenizer:
#     def __init__(self, tokenizer: BertTokenizer) -> None:
#         self.tokenizer = tokenizer

#     def __call__(self, x: str) -> AutoTokenizer:
#         return self.tokenizer(
#             x,
#             max_length=Config.max_len,
#             truncation=True,
#             padding=True,
#             return_tensors="pt",
#         )

# class CLIP(nn.Module):
#   def __init__(self, lr: float = 1e-3) -> None:
#     super().__init__()
#     self.vision_encoder = VisionEncoder(Config.embed_dim)
#     self.caption_encoder = TextEncoder(Config.embed_dim)
#     self.tokenizer = Tokenizer(AutoTokenizer.from_pretrained(Config.text_model))
#     self.lr = lr
#     self.device = "cuda" if torch.cuda.is_available() else "cpu"

#   def forward(self, images, text):
#     # tensor image, tuple of texts as input
#     text_tokenized = self.tokenizer(text).to(self.device)
#     caption_embed = self.caption_encoder(text_tokenized["input_ids"])
#     print(caption_embed.shape)
#     image_embed = self.vision_encoder(images)
#     print(image_embed.shape)
#     similarity = caption_embed.T @ image_embed

#     return similarity

# clip=CLIP()

# optimizer = torch.optim.Adam(
#     [
#         {"params": clip.vision_encoder.parameters()},
#         {"params": clip.caption_encoder.parameters()},
#     ],
#     lr=clip.lr,
# )

# for epoch in range(1,2):
#     clip.train()
#     for batch in dataloader:
#         image, text = batch
#         print(len(text),text)
#         similarity = clip(image, text)
#         print(similarity.shape)
#         loss = CLIP_loss(similarity)
#         img_acc, cap_acc = metrics(similarity)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         break

#     print(f"Epoch [epoch+1/num_epochs], Batch Loss: {loss.item()}, img_acc,cap_acc: {img_acc,cap_acc}")

In [12]:
from torchtext.vocab import build_vocab_from_iterator
from collections import Counter

class TextEmbedding(nn.Module):
    """Token embedding plus a learned, zero-initialised positional embedding.

    Args:
        n_vocab: Number of entries in the token embedding table.
        n_embd: Size of each embedding vector.
        n_token: Number of positions in the positional embedding.
    """

    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_embd)
        # One learnable vector per position, added to the token embedding
        self.position_embedding = nn.Parameter(torch.zeros(n_token, n_embd))

    def forward(self, tokens):
        """Embed ``tokens`` and add the positional embedding by broadcasting."""
        return self.token_embedding(tokens) + self.position_embedding

class EncoderLayer(nn.Module):
    """Pre-norm transformer layer: causal self-attention, then a QuickGELU MLP.

    Both sub-blocks are wrapped in residual connections. Input and output
    have shape (Batch_Size, Seq_Len, Dim).
    """

    def __init__(self, n_head: int, n_embd: int):
        super().__init__()
        self.layernorm_1 = nn.LayerNorm(n_embd)
        self.attention = MultiheadAttention(n_head, n_embd)
        self.layernorm_2 = nn.LayerNorm(n_embd)
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):
        # Attention sub-block; the mask flag makes the attention causal
        attended = self.attention(self.layernorm_1(x), selfattention_mask=True)
        x = attended + x

        # Feed-forward sub-block with a 4x hidden expansion
        hidden = self.linear_1(self.layernorm_2(x))
        hidden = hidden * torch.sigmoid(1.702 * hidden)   # QuickGELU activation function
        return self.linear_2(hidden) + x

class TextEncoder(nn.Module):
    """Transformer encoder that turns token ids into contextual embeddings.

    All sizes are parameters whose defaults reproduce the previously
    hard-coded configuration, so ``TextEncoder()`` builds the same model.

    Args:
        n_vocab: Vocabulary size. The default of 81 appears to match the
            vocabulary ``Tokenizer`` builds from ``opposites`` (79 distinct
            words plus ``<pad>`` and ``<unk>``) — update it if that mapping changes.
        n_embd: Embedding size; must be divisible by ``n_head``.
        n_token: Number of positions in the positional embedding, i.e. the
            length of every description.
        n_head: Attention heads per encoder layer.
        n_layers: Number of stacked encoder layers.
    """

    def __init__(self, n_vocab: int = 81, n_embd: int = 32, n_token: int = 40,
                 n_head: int = 8, n_layers: int = 8):
        super().__init__()
        self.embedding = TextEmbedding(n_vocab, n_embd, n_token)

        self.layers = nn.ModuleList([
            EncoderLayer(n_head, n_embd) for _ in range(n_layers)
        ])

        self.layernorm = nn.LayerNorm(n_embd)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        """Encode token ids of shape (batch, n_token) into (batch, n_token, n_embd)."""
        # Expect tokenized input; embedding lookup requires integer (long) indices
        tokens = tokens.type(torch.long)
        state = self.embedding(tokens)
        for layer in self.layers:
            state = layer(state)
        output = self.layernorm(state)
        return output

class Tokenizer:
    """Whitespace tokenizer over the attribute words found in ``opposites``.

    The vocabulary contains every key and every value of the mapping plus
    the specials ``<pad>`` and ``<unk>``; unknown words resolve to ``<unk>``.
    """

    def __init__(self, opposites):
        all_words = set(opposites.keys()) | set(opposites.values())
        self.build_vocab(all_words)

    def build_vocab(self, data):
        """Build ``self.vocab`` from an iterable of words."""
        # build_vocab_from_iterator expects an iterable of token lists,
        # so every word is wrapped in its own one-element list.
        single_word_docs = [[word] for word in data]
        self.vocab = build_vocab_from_iterator(single_word_docs, specials=["<pad>", "<unk>"])
        unk_index = self.vocab["<unk>"]
        self.vocab.set_default_index(unk_index)  # Handle unknown tokens

    def tokenize(self, batch_text_string):
        """Convert each text of the batch into a list of vocabulary indices."""
        return [
            [self.vocab[token] for token in text.split()]
            for text in batch_text_string
        ]



In [13]:
# Only 40 words in dictionary, build custom tokenizer and learn custom embeddings

# tokenizer = Tokenizer(opposites)
# textencoder = TextEncoder()

# for images_batch, descriptions_batch in dataloader:
#     input_ids = tokenizer.tokenize(descriptions_batch)
#     input_ids = torch.tensor(input_ids, dtype=torch.long)
#     embeddings = textencoder(input_ids)
#     print(embeddings)
#     break

# Unet

In [14]:
class UNETResidualConvolution(nn.Module):
    """Residual convolution block conditioned on the timestep embedding.

    The timestep embedding is projected to ``out_channels`` and added to the
    feature map between the two (GroupNorm -> SiLU -> 3x3 conv) stages, so
    the block's output depends on the current diffusion timestep.

    Args:
        in_channels: Channels of the input feature map (divisible by 32).
        out_channels: Channels of the output feature map (divisible by 32).
        timestamp_dim: Size of the incoming timestep embedding.
    """

    def __init__(self,in_channels,out_channels, timestamp_dim=1280):
        super().__init__()
        # split the channels into 32 groups and normalize each separately
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # transform initial timestamp embedding into out_channels
        self.linear1 = nn.Linear(timestamp_dim, out_channels)

        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self,x, timestamp_embedded):
        """Apply the block.

        Args:
            x: Feature map of shape (Batch_Size, In_Channels, Height, Width).
            timestamp_embedded: Timestep embedding whose last dimension is
                ``timestamp_dim``; its leading dimension must broadcast
                against the batch.

        Returns:
            Feature map of shape (Batch_Size, Out_Channels, Height, Width).
        """
        # match residue to output
        residue=self.residual_layer(x)

        timestamp_embedded=F.silu(timestamp_embedded)
        timestamp_embedded=self.linear1(timestamp_embedded)

        x=self.groupnorm_1(x)
        x=F.silu(x)
        x=self.conv_1(x)

        # Add the projected time embedding to every spatial position:
        # (Batch_Size, Out_Channels, Height, Width) + (1, Out_Channels, 1, 1)
        merged=x+timestamp_embedded.unsqueeze(-1).unsqueeze(-1)

        merged=self.groupnorm_2(merged)
        # Activate the time-conditioned tensor. The previous code applied SiLU
        # to `x`, which threw away the time embedding and the second GroupNorm.
        merged=F.silu(merged)
        merged=self.conv_2(merged)

        return merged+residue

class UNETAttention(nn.Module):
    """Transformer block for UNet feature maps with text cross-attention.

    Pipeline: GroupNorm + 1x1 conv, flatten the pixels into tokens, image
    self-attention, cross-attention with the text context, GeGLU
    feed-forward, reshape back, 1x1 conv, and a long skip connection from
    the block input. Every inner stage has its own short skip connection.

    Args:
        n_head: Number of attention heads.
        head_dim: Features per head; the block works on ``n_head * head_dim`` channels.
        context_dim: Feature size of the text context.
    """

    def __init__(self,n_head,head_dim,context_dim=32):
        super().__init__()
        channels = n_head * head_dim
        self.channels=channels

        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        # Stage 1: self-attention within the image
        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = MultiheadAttention(n_head, channels, qkv_bias=False)

        # Stage 2: cross-attention, image as Query and text as Key/Value
        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = MultiheadAttention(n_head, channels, dim_kv=context_dim, qkv_bias=False)

        # Stage 3: GeGLU feed-forward (first linear yields both value and gate)
        self.layernorm_3 = nn.LayerNorm(channels)
        self.linear_geglu_1  = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self,image,text):
        """Apply the block to ``image`` (B, C, H, W), guided by ``text`` (B, seq, context_dim)."""
        x = self.conv_input(self.groupnorm(image))
        n, c, h, w = x.shape
        # (n, c, h, w) -> (n, h*w, c): one token per pixel
        x = x.view((n, c, h * w)).transpose(-1, -2)

        # self attention for image
        x = self.attention_1(self.layernorm_1(x)) + x

        # cross attention in order to guide the model with the text
        x = self.attention_2(self.layernorm_2(x), text, text) + x

        # GeGLU: half of the projection is gated by GELU of the other half
        value, gate = self.linear_geglu_1(self.layernorm_3(x)).chunk(2, dim=-1)
        x = self.linear_geglu_2(value * F.gelu(gate)) + x

        # (n, h*w, c) -> (n, c, h, w)
        x = x.transpose(-1, -2).view((n, c, h, w))

        return self.conv_output(x) + image

class Upsample(nn.Module):
    """Doubles height and width by nearest-neighbour interpolation, then applies a 3x3 conv.

    Used on the decoding path of the UNet to grow the latent back towards
    its initial spatial size.
    """
    #? Isnt latent represenation too small? use celeba256
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * 2, Width * 2)
        enlarged = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(enlarged)

class SwitchSequential(nn.Sequential):
    """Sequential container that routes extra inputs by layer type.

    ``UNETAttention`` layers additionally receive ``context``,
    ``UNETResidualConvolution`` layers additionally receive ``time``, and any
    other layer (e.g. Conv2d or Upsample) is called with ``x`` only.
    """

    def forward(self, x, context, time):
        for layer in self:
            if isinstance(layer, UNETAttention):
                extra_args = (context,)
            elif isinstance(layer, UNETResidualConvolution):
                extra_args = (time,)
            else:
                # conv2d or upsample
                extra_args = ()
            x = layer(x, *extra_args)
        return x

class UNET(nn.Module):
    """Encoder / bottleneck / decoder network with one skip connection per encoder stage.

    Every encoder stage output is stored; each decoder stage concatenates the
    most recently stored output (deepest first) onto its input along the channel
    axis, which is why the decoder input widths are sums such as 1280 + 640 = 1920.
    There are six encoder stages and six decoder stages, so every stored output
    is consumed exactly once.
    """

    def __init__(self):
        super().__init__()

        def res_attn(in_channels, out_channels, attn_size, upsample=False):
            # Residual block followed by an 8-head attention block and, optionally,
            # a 2x upsample. Modules are created in the order they are applied.
            layers = [
                UNETResidualConvolution(in_channels, out_channels),
                UNETAttention(8, attn_size),
            ]
            if upsample:
                layers.append(Upsample(out_channels))
            return SwitchSequential(*layers)

        def downsample(channels):
            # Stride-2 convolution: halves height and width, keeps the channel count.
            return SwitchSequential(nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1))

        self.encoders = nn.ModuleList([
            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),
            res_attn(320, 320, 40),
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
            downsample(320),
            res_attn(320, 640, 80),
            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
            downsample(640),
            res_attn(640, 1280, 160),
        ])

        self.bottleneck = nn.ModuleList([
            SwitchSequential(UNETResidualConvolution(1280, 1280)),
            SwitchSequential(UNETAttention(8, 160)),
            SwitchSequential(UNETResidualConvolution(1280, 1280)),
        ])

        self.decoders = nn.ModuleList([
            # Input widths are (running channels + skip channels).
            # 1280 + 1280 at Height / 32
            res_attn(2560, 1280, 160),
            # 1280 + 640, then Height / 32 -> Height / 16
            res_attn(1920, 1280, 160, upsample=True),
            # 1280 + 640 at Height / 16
            res_attn(1920, 640, 80),
            # 640 + 320, then Height / 16 -> Height / 8
            res_attn(960, 640, 80, upsample=True),
            # 640 + 320 at Height / 8
            res_attn(960, 320, 40),
            # 320 + 320 at Height / 8
            res_attn(640, 320, 40),
        ])

    def forward(self,x,context,timestep_embedded):
        # x: noisy latent (Batch_Size, 4, Height / 8, Width / 8)
        # context: (Batch_Size, Seq_Len, Dim)
        # timestep_embedded: (1, 1280)
        skips = []
        for encoder_stage in self.encoders:
            x = encoder_stage(x, context, timestep_embedded)
            skips.append(x)

        for bottleneck_stage in self.bottleneck:
            x = bottleneck_stage(x, context, timestep_embedded)

        for decoder_stage in self.decoders:
            # pop() returns the encoder outputs deepest-first, matching the decoder order.
            x = torch.cat((x, skips.pop()), dim=1)
            x = decoder_stage(x, context, timestep_embedded)

        return x

class UNET_OutputLayer(nn.Module):
    """Output head: GroupNorm (32 groups) -> SiLU -> 3x3 convolution.

    Maps ``in_channels`` feature channels to ``out_channels`` without changing
    the spatial size (padding=1 on a 3x3 kernel).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        normalized = self.groupnorm(x)
        return self.conv(F.silu(normalized))

class TimeEmbedding(nn.Module):
    """Two-layer MLP that widens a timestep encoding to four times its input size."""

    def __init__(self, timestamp_dim):
        super().__init__()
        hidden_dim = 4 * timestamp_dim
        self.linear1 = nn.Linear(timestamp_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, time):
        # (..., timestamp_dim) -> (..., 4 * timestamp_dim), e.g. (1, 320) -> (1, 1280)
        hidden = F.silu(self.linear1(time))
        return self.linear2(hidden)

In [15]:
class Diffusion(nn.Module):
    """Noise-prediction model: time-embedding MLP -> UNet -> 4-channel output head."""

    def __init__(self):
        super().__init__()
        # 320 is the width of the sinusoidal encoding built by get_time_embedding
        # (160 cosines followed by 160 sines).
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, timestep):
        # latent:   (Batch_Size, 4, Height / 8, Width / 8)
        # context:  (Batch_Size, Seq_Len, Dim)
        # timestep: (1, 320) sinusoidal encoding

        # (1, 320) -> (1, 1280)
        timestep_embedded = self.time_embedding(timestep)

        # (Batch, 4, Height / 8, Width / 8) -> (Batch, 320, Height / 8, Width / 8)
        features = self.unet(latent, context, timestep_embedded)

        # (Batch, 320, Height / 8, Width / 8) -> (Batch, 4, Height / 8, Width / 8)
        return self.final(features)

# Inference Pipeline

In [16]:
def rescale(x, old_range, new_range, clamp=False):
    """Linearly map values from ``old_range`` to ``new_range``.

    The input is never modified: a new tensor is returned. (The previous
    implementation used in-place ``-=``/``*=``/``+=``, which silently changed
    the caller's tensor and failed outright on integer tensors.)

    Args:
        x: tensor of values to rescale.
        old_range: ``(old_min, old_max)`` interval the values currently live in.
        new_range: ``(new_min, new_max)`` interval to map them onto.
        clamp: when True, clip the result to ``[new_min, new_max]``.

    Returns:
        A new tensor with the rescaled (and optionally clamped) values.
    """
    old_min, old_max = old_range
    new_min, new_max = new_range
    scale = (new_max - new_min) / (old_max - old_min)
    # Out-of-place arithmetic, same order of operations as before.
    x = (x - old_min) * scale + new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x

def get_time_embedding(timestep, half_dim=160):
    """Build a sinusoidal encoding of a single scalar timestep.

    Args:
        timestep: scalar timestep (anything ``torch.tensor([timestep])`` accepts).
        half_dim: number of frequencies. The result has ``2 * half_dim`` features.
            The default of 160 yields the 320-wide encoding that
            ``Diffusion`` feeds into ``TimeEmbedding(320)``.

    Returns:
        float32 tensor of shape ``(1, 2 * half_dim)``: the cosines of
        ``timestep * freqs`` followed by the sines of ``timestep * freqs``.
    """
    # Shape: (half_dim,). Geometric sequence from 1 down to (just above) 1e-4.
    freqs = torch.pow(10000, -torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim)
    # Shape: (1, half_dim): the timestep multiplied by every frequency.
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    # Shape: (1, 2 * half_dim): [cos(...), sin(...)]
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)

def image_context_similarity_loss(image, context):
    """Scalar loss ``1 - mean(dot(resized image rows, normalised context))``.

    The image batch is bilinearly resized to 32 x 2048. The context is flattened
    per sample, L2-normalised, zero-padded to 2048 features and dotted with every
    resized image row. The mean of those dot products is subtracted from 1.

    Args:
        image: 4-D image batch ``(Batch, Channels, Height, Width)``.
        context: per-sample context tensor whose leading dimension is the batch
            (e.g. ``(Batch, 40, 32)`` in this notebook, i.e. 1280 features once
            flattened).

    Returns:
        0-dim tensor suitable for ``.backward()``.
    """
    resized_images = F.interpolate(
        image, size=(32, 2048), mode="bilinear", align_corners=False
    )  # Shape: [Batch, Channels, 32, 2048]

    # reshape (rather than view) also accepts non-contiguous context tensors.
    context_flattened = context.reshape(context.size(0), -1)  # Shape: [Batch, Features]
    # F.normalize divides by max(norm, eps), so an all-zero context stays zero
    # instead of turning into NaN through 0 / 0.
    context_flattened = F.normalize(context_flattened, dim=-1)
    # NOTE(review): if Features > 2048 the pad amount is negative, which trims the
    # context instead of padding it -- confirm that is acceptable for larger contexts.
    context_padded = F.pad(context_flattened, (0, 2048 - context_flattened.size(1)))  # Shape: [Batch, 2048]

    context_padded_expanded = context_padded.view(context_padded.size(0), 1, 1, -1)  # Shape: [Batch, 1, 1, 2048]

    # Dot product along the feature dimension (2048)
    dot_product = torch.sum(
        resized_images * context_padded_expanded, dim=-1
    )  # Shape: [Batch, Channels, 32]

    # Reduce to a scalar so that .backward() can be called on it
    similarity_loss = 1 - torch.mean(dot_product)
    return similarity_loss

In [17]:
from tqdm.autonotebook import tqdm
def evaluate(diffusion,tokenizer,text_encoder,image_encoder,image_decoder,n_inference_steps,descriptions_batch,scheduler,do_cfg):
    """Generate images from pure latent noise, conditioned on text descriptions.

    Puts all four models into eval mode (they are left in eval mode on return),
    switches the scheduler to inference mode, runs the denoising loop under
    ``torch.no_grad()`` and decodes the final latents with ``image_decoder``.

    Args:
        diffusion: noise-prediction model called as ``diffusion(latents, context, time_embedding)``.
        tokenizer: object whose ``tokenize(descriptions_batch)`` result can be turned into a long tensor.
        text_encoder: maps token ids to the context tensor.
        image_encoder: only switched to eval mode here; not otherwise used.
        image_decoder: maps latents back to images.
        n_inference_steps: number of denoising steps; the scheduler stride is
            ``scheduler.amount_steps // n_inference_steps``.
        descriptions_batch: batch of descriptions to condition on.
        scheduler: provides ``amount_steps``, ``step_size``, ``set_inference_mode`` and ``denoising_step``.
        do_cfg: when True, a second copy of the context is concatenated on the batch axis.
            NOTE(review): the "unconditional" branch tokenizes the same descriptions and
            the two halves are never recombined with a guidance scale, so this path only
            doubles the batch; it is not classifier-free guidance yet.

    Returns:
        ``uint8`` NumPy array of decoded images with values in [0, 255], in the
        layout produced by ``image_decoder`` (the caller transposes each image
        to H, W, C before plotting).

    Relies on the notebook globals ``device``, ``LATENTS_HEIGHT`` and ``LATENTS_WIDTH``.
    """
    diffusion.eval()
    text_encoder.eval()
    image_encoder.eval()
    image_decoder.eval()

    with torch.no_grad():
        cond_tokens = tokenizer.tokenize(descriptions_batch)
        cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
        cond_context = text_encoder(cond_tokens)
        if do_cfg:
            uncond_tokens = tokenizer.tokenize(descriptions_batch)
            uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
            uncond_context = text_encoder(uncond_tokens)
            context = torch.cat([cond_context, uncond_context]) # B*2, 40, 32
        else:
            context = cond_context # B, 40, 32

        # Stride between visited timesteps
        scheduler.set_inference_mode(scheduler.amount_steps//n_inference_steps)

        # Start from pure Gaussian noise in latent space (4 latent channels, as
        # expected by the first UNet convolution).
        latents_shape = (context.shape[0], 4, LATENTS_HEIGHT, LATENTS_WIDTH) # B, 4, 32, 32
        latents = torch.randn(latents_shape, device=device)

        context=context.to(device)

        timesteps = tqdm(range(scheduler.amount_steps-1,0,-scheduler.step_size))
        for timestep in timesteps:
            scheduler.current_step=timestep
            # for each denoising step, create timestep embedding from cosine/sine waves
            time_embedding = get_time_embedding(timestep).to(device)
            predicted_latent_noise = diffusion(latents, context, time_embedding)
            latents = scheduler.denoising_step(timestep, latents, predicted_latent_noise)

        # From latent, go back to image space. Use the decoder that was passed in
        # (and switched to eval mode above), not the notebook-global `decoder`.
        images = image_decoder(latents)

        # Map [-1, 1] to [0, 255] and convert to uint8 on the CPU
        images = rescale(images, (-1, 1), (0, 255), clamp=True)
        images = images.to("cpu", torch.uint8).numpy()
    return images

# Training pipeline

In [18]:
import torch.optim as optim
# Instantiate every pipeline component with its default constructor arguments.
# The order of construction determines the order in which random initial weights
# are drawn, so keep it stable if runs are meant to be comparable.
# TODO: take configs from the Stable Diffusion GitHub repository
scheduler=Scheduler() # NOTE(review): flagged by the author ("#!"); presumably the defaults still need tuning -- confirm
encoder=VAEEncoder()
tokenizer = Tokenizer(opposites)
textencoder = TextEncoder()
diffusion=Diffusion()
decoder=VAEDecoder()
# Multi-GPU wrappers, currently disabled:
# textencoder=nn.DataParallel(textencoder)
# diffusion=nn.parallel.DistributedDataParallel(diffusion)
# encoder=nn.DataParallel(encoder)
# decoder=nn.DataParallel(decoder)

In [19]:
# Parameter counts, one per line, in the order:
# image encoder, text encoder, diffusion model, image decoder.
print(
    *(
        sum(weights.numel() for weights in model.parameters())
        for model in (encoder, textencoder, diffusion, decoder)
    ),
    sep="\n",
)

2148144
105568
408647044
3104343


In [20]:
# Image resolution; latents are 8x smaller along each spatial side.
WIDTH = 256
HEIGHT = 256
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

# Devices: `device` runs all models, `idle_device` is the CPU fallback name.
device="cuda" if torch.cuda.is_available() else "cpu"
idle_device="cpu"

diffusion.to(device)
encoder.to(device)
decoder.to(device)
textencoder.to(device)

# Sampling settings. `strength` and `cfg_scale` are not read by the training or
# evaluation code in the cells shown here; `do_cfg` is likewise passed as a
# literal False by the epoch loop.
strength=0.8
do_cfg=False
cfg_scale=7.5
n_inference_steps=50 # denoising steps used by evaluate()

models={"image_encoder":encoder,
        "tokenizer":tokenizer,
        "text_encoder":textencoder,
        "diffusion":diffusion,
        "image_decoder":decoder}

# One Adam optimizer per trainable component, keyed like `models`.
optimizers={
    "image_encoder":optim.Adam(encoder.parameters(), lr=4.5e-6),
    "text_encoder":optim.Adam(textencoder.parameters(), lr=5e-4),
    "diffusion":optim.Adam(diffusion.parameters(), lr=2e-6),
    "image_decoder":optim.Adam(decoder.parameters(), lr=4.5e-6)
}

In [21]:
# Keep the descriptions of the first batch as a fixed prompt set, so that the
# previews generated after every epoch are comparable with each other.
try:
    original_images_batch, descriptions_batch = next(iter(dataloader))
except StopIteration:
    # Fail here with a clear message instead of leaving `descriptions_evaluate`
    # undefined and hitting a NameError in a later cell.
    raise ValueError("dataloader is empty; cannot pick an evaluation batch") from None
descriptions_evaluate=descriptions_batch

In [22]:
# if torch.cuda.device_count() > 1:
#     print("Let's use", torch.cuda.device_count(), "GPUs!")
#     # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
#     diffusion=nn.DataParallel(diffusion)
# diffusion.to(device)

In [23]:
from tqdm.autonotebook import tqdm
def train_one_epoch(diffusion,tokenizer,text_encoder,image_encoder,image_decoder,optimizers,loss_fn,dataloader,scheduler,do_cfg,epoch):
    """Run one training epoch over ``dataloader`` and return the accumulated loss.

    For every batch the text context is encoded once, then every scheduler
    timestep (from ``scheduler.amount_steps - 1`` down to 1 in strides of
    ``scheduler.step_size``) performs two updates:

    1. text encoder, image encoder and diffusion model are stepped on
       ``loss_fn(decoded predicted noise, decoded sample noise)``;
    2. the image decoder is stepped on ``image_context_similarity_loss``.

    Args:
        diffusion, text_encoder, image_encoder, image_decoder: models to train
            (all are put into train mode).
        tokenizer: object providing ``tokenize(descriptions_batch)``.
        optimizers: dict with keys ``'text_encoder'``, ``'image_encoder'``,
            ``'diffusion'`` and ``'image_decoder'``.
        loss_fn: callable ``(prediction, target) -> scalar tensor``.
        dataloader: yields ``(images, descriptions)`` batches.
        scheduler: provides ``amount_steps``, ``step_size``, ``add_noise`` and
            ``denoising_step``.
        do_cfg: when True, a second copy of the context is concatenated on the
            batch axis. NOTE(review): the image batch is not duplicated to match,
            so this path looks unfinished -- confirm before enabling.
        epoch: epoch index, used only for the progress-bar label.

    Returns:
        float: sum of both losses over every timestep of every batch. Losses are
        accumulated with ``.item()`` so that no autograd graph is kept alive
        between steps.

    Relies on the notebook global ``device``.

    NOTE(review): the regression target (``sample_latent_noise``) is derived from
    the model's own prediction via ``scheduler.denoising_step`` rather than from
    the noise added by ``scheduler.add_noise`` -- confirm this is the intended
    objective.
    """
    loop = tqdm(
        dataloader,
        total=len(dataloader),
        desc=f"Epoch {epoch}: train",
        leave=True,
    )
    diffusion.train()
    text_encoder.train()
    image_encoder.train()
    image_decoder.train()
    train_loss = 0.0

    # Iterate through the progress bar itself so that it actually advances.
    for step, (original_images_batch, descriptions_batch) in enumerate(loop, 1):
        # Text embeddings: build the context from the prompts
        cond_tokens = tokenizer.tokenize(descriptions_batch)
        cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
        cond_context = text_encoder(cond_tokens)
        if do_cfg:
            uncond_tokens = tokenizer.tokenize(descriptions_batch)
            uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
            uncond_context = text_encoder(uncond_tokens)
            context = torch.cat([cond_context, uncond_context]) # B*2, 40, 32
        else:
            context = cond_context # B, 40, 32

        original_images_batch=original_images_batch.to(device)
        context=context.to(device)

        for timestep in range(scheduler.amount_steps-1,0,-scheduler.step_size):
            scheduler.current_step=timestep
            # for each denoising step, create timestep embedding from cosine/sine waves
            time_embedding = get_time_embedding(scheduler.current_step).to(device)

            # Noise the images for this timestep and encode them into latents
            batch_noised_images=scheduler.add_noise(original_images_batch,scheduler.current_step)
            batch_latent_noised_images=image_encoder(batch_noised_images)

            # Predicted latent noise and the latent one denoising step earlier
            predicted_latent_noise = diffusion(batch_latent_noised_images, context, time_embedding)
            batch_images_previous_prediction_latent = scheduler.denoising_step(timestep, batch_latent_noised_images, predicted_latent_noise)
            sample_latent_noise=batch_images_previous_prediction_latent-batch_latent_noised_images

            predicted_noise=image_decoder(predicted_latent_noise)
            sample_noise=image_decoder(sample_latent_noise)

            # Loss between decoded sample noise and decoded predicted noise
            loss=loss_fn(predicted_noise,sample_noise)

            optimizers['text_encoder'].zero_grad()
            optimizers['image_encoder'].zero_grad()
            optimizers['diffusion'].zero_grad()

            loss.backward()
            optimizers['text_encoder'].step()
            optimizers['image_encoder'].step()
            optimizers['diffusion'].step()
            # .item() stores a plain float; adding the tensor itself would keep the
            # whole autograd graph of every step alive and exhaust GPU memory.
            train_loss+=loss.item()

            # The text-encoder graph is consumed by the first backward pass of this
            # batch; later timesteps reuse the context as a constant.
            context=context.detach()

            # Train the decoder with its own loss, on a detached latent so that no
            # gradient reaches the other models.
            batch_images_previous_prediction=image_decoder(batch_images_previous_prediction_latent.detach())
            decoder_loss=image_context_similarity_loss(batch_images_previous_prediction,context)
            optimizers['image_decoder'].zero_grad()
            # This graph is used exactly once, so it does not need to be retained.
            decoder_loss.backward()
            optimizers['image_decoder'].step()
            train_loss+=decoder_loss.item()

        # Running mean of the per-batch accumulated loss
        loop.set_postfix(loss=train_loss / step)
    return train_loss

In [24]:
from torch.nn import MSELoss
import matplotlib.pyplot as plt
import torch

# MSELoss holds no state, so one instance serves every epoch.
loss_fn = MSELoss()
# Number of epochs between checkpoint + preview rounds.
checkpoint_every = 1

for epoch in range(10):
    loss=train_one_epoch(diffusion, tokenizer, textencoder, encoder, decoder, optimizers, loss_fn, dataloader, scheduler, False, epoch)
    scheduler.reset_parameters()

    if epoch % checkpoint_every == 0:
        # Save checkpoints
        torch.save(diffusion.state_dict(), f"diffusion_epoch_{epoch}.pth")
        torch.save(textencoder.state_dict(), f"textencoder_epoch_{epoch}.pth")
        torch.save(encoder.state_dict(), f"encoder_epoch_{epoch}.pth")
        torch.save(decoder.state_dict(), f"decoder_epoch_{epoch}.pth")

        denoised_images = evaluate(diffusion, tokenizer, textencoder, encoder, decoder, n_inference_steps, descriptions_evaluate, scheduler, False)
        # evaluate() puts the scheduler into inference mode, and train_one_epoch
        # reads scheduler.step_size, so without a reset here the next epoch would
        # train with the inference stride.
        # NOTE(review): presumably reset_parameters() restores the training
        # defaults -- confirm against the Scheduler implementation.
        scheduler.reset_parameters()
        print(epoch, loss)

        batch_size = denoised_images.shape[0]
        # squeeze=False always returns a 2-D array of axes, including for batch_size == 1.
        fig, axes = plt.subplots(1, batch_size, figsize=(15, 5), squeeze=False)
        for ax, image in zip(axes[0], denoised_images):
            # Transpose each image to (H, W, C) for imshow
            ax.imshow(image.transpose(1, 2, 0))
            ax.axis('off')
        plt.show()
        # Release the figure; otherwise one figure per epoch stays in memory.
        plt.close(fig)

Epoch 0: train:   0%|          | 0/101300 [00:00<?, ?it/s]

torch.Size([2, 4, 32, 32]) torch.Size([2, 40, 32]) torch.Size([1, 320])
torch.Size([2, 4, 32, 32]) torch.Size([2, 4, 32, 32])
torch.Size([2, 4, 32, 32]) torch.Size([2, 40, 32]) torch.Size([1, 320])
torch.Size([2, 4, 32, 32]) torch.Size([2, 4, 32, 32])
torch.Size([2, 4, 32, 32]) torch.Size([2, 40, 32]) torch.Size([1, 320])
torch.Size([2, 4, 32, 32]) torch.Size([2, 4, 32, 32])
torch.Size([2, 4, 32, 32]) torch.Size([2, 40, 32]) torch.Size([1, 320])
torch.Size([2, 4, 32, 32]) torch.Size([2, 4, 32, 32])
torch.Size([2, 4, 32, 32]) torch.Size([2, 40, 32]) torch.Size([1, 320])
torch.Size([2, 4, 32, 32]) torch.Size([2, 4, 32, 32])
torch.Size([2, 4, 32, 32]) torch.Size([2, 40, 32]) torch.Size([1, 320])
torch.Size([2, 4, 32, 32]) torch.Size([2, 4, 32, 32])
torch.Size([2, 4, 32, 32]) torch.Size([2, 40, 32]) torch.Size([1, 320])
torch.Size([2, 4, 32, 32]) torch.Size([2, 4, 32, 32])
torch.Size([2, 4, 32, 32]) torch.Size([2, 40, 32]) torch.Size([1, 320])
torch.Size([2, 4, 32, 32]) torch.Size([2, 4, 3

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 

In [None]:
        # # Create initial noised latent image
#         if not training_mode:
#             # evaluation mode, generate image from initial pure noise
#             scheduler.set_inference_mode(scheduler.step_size)
#             latents_shape = (context.shape[0], 4, LATENTS_HEIGHT, LATENTS_WIDTH) # B, 4, 16, 16 #? what is 4?
#             batch_latent_noised_images = torch.randn(latents_shape, device=device)
# if do_cfg:
#                 # Copy same latent noised image for conditional and uncoditional
#                 batch_latent_noised_images = batch_latent_noised_images.repeat(2, 1, 1, 1)
#             if do_cfg:
#                 #?
#                 output_cond, output_uncond = predicted_noise.chunk(2)
#                 predicted_noise = cfg_scale * (output_cond - output_uncond) + output_uncond