In [2]:
%%writefile utils.py
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import math

def generate_betas(beta_0, beta_T, num_steps): 
    delta = beta_T - beta_0 
    return [delta*i/(num_steps-1) + beta_0 for i in range(num_steps)] 

def generate_alpha(beta_0, beta_T, num_steps): 
    betas = generate_betas(beta_0, beta_T, num_steps)
    alphas = [1-b for b in betas]
    return alphas

def generate_alpha_bar(beta_0, beta_T, num_steps): 
    alphas = generate_alpha(beta_0, beta_T, num_steps)
    alpha_bar = [alphas[0]]
    for i in range(1, len(alphas)): alpha_bar.append(alpha_bar[i-1] * alphas[i])
    return alpha_bar

def count_parameters(model: nn.Module):
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total:,}")
    return total

Overwriting utils.py


In [30]:
%%writefile guided_unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def count_parameters(model: nn.Module):
    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total:,}")
    return total

# interleaved = torch.stack((a, b), dim=2).reshape(a.shape[0], -1)

def get_pos_emb(positions, emb_dim):
    assert emb_dim % 2 == 0, "time embedding dimension must be divisible by 2"
    factor = 100 ** ((torch.arange(
        start=0, end=emb_dim // 2, dtype=torch.float32, device=positions.device) / (emb_dim // 2))
    )
    t_emb = positions[:, None]
    t_emb = t_emb.repeat(1, emb_dim // 2)
    t_emb = t_emb / factor
    t_emb = torch.stack((torch.sin(t_emb), torch.cos(t_emb)), dim=2).reshape(t_emb.shape[0], -1)
    return t_emb

def get_time_embedding(time_steps, temb_dim):
    # factor = 10000^(2i/d_model)
    factor = 100 ** ((torch.arange(
        start=0, end=temb_dim // 2, dtype=torch.float32, device=time_steps.device) / (temb_dim // 2))
    )
    
    # pos / factor
    # timesteps B -> B, 1 -> B, temb_dim
    t_emb = time_steps[:, None].repeat(1, temb_dim // 2)
    t_emb = t_emb / factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb


class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, t_emb_dim, down_sample=True, num_heads=4, num_layers=1, X_att_dim=768):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for i in range(num_layers)
            ]
        )
        
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            )
            for _ in range(num_layers)
        ])
        self.X_att_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(X_att_dim, out_channels)
            )
            for _ in range(num_layers)
        ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels,
                              kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )
        
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, out_channels)
             for _ in range(num_layers)]
        )
        
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)]
        )
        self.Xatts = nn.ModuleList(
            [
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1) if self.down_sample else nn.Identity()
        self.lamb_pre = nn.Parameter(torch.tensor([0.5]))

    def forward(self, x, t_emb, X_att_emb, guided):
        out = x
        for i in range(self.num_layers):
            
            # Resnet block of Unet
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)
            
            # Attention block of Unet
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)

            X_att_emb_i = self.X_att_emb_layers[i](X_att_emb)
            out_attn_self, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn_cross, _ = self.Xatts[i](in_attn, X_att_emb_i, X_att_emb_i)
            out_attn = out_attn_self + F.tanh(self.lamb_pre) * out_attn_cross * guided.view(-1, 1, 1)

            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            
        out = self.down_sample_conv(out)
        return out


class MidBlock(nn.Module):
    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads=4, num_layers=1, X_att_dim=768):
        super().__init__()
        self.num_layers = num_layers
        
        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers+1)
            ]
        )
        
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            )
            for _ in range(num_layers + 1)
        ])
        self.X_att_emb_layer = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(X_att_dim, out_channels)
            )
            for _ in range(num_layers)
        ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers+1)
            ]
        )
        
        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(8, out_channels)
                for _ in range(num_layers)]
        )
        
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)]
        )
        self.Xatts = nn.ModuleList(
            [
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers+1)
            ]
        )

        self.lamb_pre = nn.Parameter(torch.tensor([0.5]))
    
    def forward(self, x, t_emb, X_att_emb, guided):
        out = x
        
        # First resnet block
        resnet_input = out
        out = self.resnet_conv_first[0](out)
        out = out + self.t_emb_layers[0](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[0](out)
        out = out + self.residual_input_conv[0](resnet_input)
        
        for i in range(self.num_layers):
            
            # Attention Block
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)

            X_att_emb_i = self.X_att_emb_layer[i](X_att_emb)
            out_attn_self, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn_cross, _ = self.Xatts[i](in_attn, X_att_emb_i, X_att_emb_i)
            out_attn = out_attn_self + F.tanh(self.lamb_pre) * out_attn_cross * guided.view(-1, 1, 1)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn

            # Resnet Block
            resnet_input = out
            out = self.resnet_conv_first[i+1](out)
            out = out + self.t_emb_layers[i+1](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i+1](out)
            out = out + self.residual_input_conv[i+1](resnet_input)
        
        return out


class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample=True, num_heads=4, num_layers=1, X_att_dim=768):
        super().__init__()
        self.num_layers = num_layers
        self.up_sample = up_sample

        self.resnet_conv_first = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, in_channels if i == 0 else out_channels),
                    nn.SiLU(),
                    nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=3, stride=1,
                              padding=1),
                )
                for i in range(num_layers)
            ]
        )
        self.t_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            )
            for _ in range(num_layers)
        ])
        self.X_att_emb_layers = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(X_att_dim, out_channels)
            )
            for _ in range(num_layers)
        ])
        self.resnet_conv_second = nn.ModuleList(
            [
                nn.Sequential(
                    nn.GroupNorm(8, out_channels),
                    nn.SiLU(),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
                )
                for _ in range(num_layers)
            ]
        )       
        self.attention_norms = nn.ModuleList(
            [
                nn.GroupNorm(8, out_channels)
                for _ in range(num_layers)
            ]
        )   
        self.attentions = nn.ModuleList(
            [
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )
        self.Xatts = nn.ModuleList(
            [
                nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )
        self.residual_input_conv = nn.ModuleList(
            [
                nn.Conv2d(in_channels if i == 0 else out_channels, out_channels, kernel_size=1)
                for i in range(num_layers)
            ]
        )
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                                 4, 2, 1) \
            if self.up_sample else nn.Identity()
        
        self.lamb_pre = nn.Parameter(torch.tensor([0.5]))
    
    def forward(self, x, out_down, t_emb, X_att_emb, guided):
        x = self.up_sample_conv(x)
        x = torch.cat([x, out_down], dim=1)

        out = x
        for i in range(self.num_layers):
            resnet_input = out
            out = self.resnet_conv_first[i](out)
            out = out + self.t_emb_layers[i](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[i](out)
            out = out + self.residual_input_conv[i](resnet_input)
            
            batch_size, channels, h, w = out.shape
            in_attn = out.reshape(batch_size, channels, h * w)
            in_attn = self.attention_norms[i](in_attn)
            in_attn = in_attn.transpose(1, 2)

            X_att_emb_i = self.X_att_emb_layers[i](X_att_emb)
            out_attn_self, _ = self.attentions[i](in_attn, in_attn, in_attn)
            out_attn_cross, _ = self.Xatts[i](in_attn, X_att_emb_i, X_att_emb_i)
            out_attn = out_attn_self + F.tanh(self.lamb_pre) * out_attn_cross * guided.view(-1, 1, 1)

            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn

        return out


class GuidedUnet(nn.Module):
    def __init__(self, X_att_dim):
        super().__init__()
        im_channels = 3
        self.down_channels = [32, 64, 128, 128]
        self.mid_channels = [128, 128, 128]
        self.t_emb_dim = 128
        self.down_sample = [True, True, False]
        self.num_down_layers = 2
        self.num_mid_layers = 2
        self.num_up_layers = 2
        self.X_att_dim = X_att_dim
        self.max_tokens = 20
             
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        
        # Initial projection from sinusoidal time embedding
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        self.up_sample = list(reversed(self.down_sample))
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=3, padding=(1, 1))
        
        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels)-1):
            self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i+1], self.t_emb_dim,
                                        down_sample=self.down_sample[i], num_layers=self.num_down_layers, X_att_dim=X_att_dim))
        
        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels)-1):
            self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i+1], self.t_emb_dim,
                                      num_layers=self.num_mid_layers, X_att_dim=X_att_dim))
        
        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels)-1)):
            self.ups.append(UpBlock(self.down_channels[i] * 2, self.down_channels[i-1] if i != 0 else 16,
                                    self.t_emb_dim, up_sample=self.down_sample[i], num_layers=self.num_up_layers, X_att_dim=X_att_dim))
        
        self.norm_out = nn.GroupNorm(8, 16)
        self.conv_out = nn.Conv2d(16, im_channels, kernel_size=3, padding=1)
        self.pos_enc = nn.Parameter(torch.tensor(get_pos_emb(torch.tensor(range(self.max_tokens)), self.X_att_dim)))

    def forward(self, x, t, xa_emb, guided):
        out = self.conv_in(x)
        # B x C1 x H x W

        xa_emb = xa_emb + self.pos_enc[:xa_emb.shape[1]]
        # w_emb = self.text_enc(w_emb)
        
        # t_emb -> B x t_emb_dim
        t_emb = get_time_embedding(torch.as_tensor(t).long(), self.t_emb_dim)
        t_emb = self.t_proj(t_emb)
        
        down_outs = []
        
        for idx, down in enumerate(self.downs):
            down_outs.append(out)
            out = down(out, t_emb, xa_emb, guided)
        # down_outs  [B x C1 x H x W, B x C2 x H/2 x W/2, B x C3 x H/4 x W/4]
        # out B x C4 x H/4 x W/4
            
        for mid in self.mids:
            out = mid(out, t_emb, xa_emb, guided)
        # out B x C3 x H/4 x W/4
        
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb, xa_emb, guided)
            # out [B x C2 x H/4 x W/4, B x C1 x H/2 x W/2, B x 16 x H x W]

        out = self.norm_out(out)
        out = nn.SiLU()(out)
        out = self.conv_out(out)
        # out B x C x H x W
        return out



Overwriting guided_unet.py


In [33]:
%%writefile runme.py
import torch
import torch.nn as nn 
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from torch.utils.data import Subset
from tqdm import tqdm
from guided_unet import GuidedUnet, get_pos_emb
from utils import count_parameters
device = torch.device("cuda")

# IMPORT WORD EMBEDDING MATRIX
dim = 768
mym = GuidedUnet(dim).to(device)
count_parameters(mym)

# IMPORT DATASET
latent_set = torch.load("/kaggle/input/fruit-diffusion-dataset/z_set.pth")
time_set = torch.load("/kaggle/input/fruit-diffusion-dataset/t_set.pth")
label_set = torch.load("/kaggle/input/fruit-diffusion-dataset/l_set.pth")
noise_set = torch.load("/kaggle/input/fruit-diffusion-dataset/n_set.pth")
print(latent_set.shape, time_set.shape,label_set.shape, noise_set.shape)

dataset = TensorDataset(latent_set, time_set, label_set, noise_set)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
print("train_loader created")

max_label = int(torch.max(label_set).item())
print("max_label:", max_label)
emb_mat = get_pos_emb(torch.tensor(range(max_label+1)), dim).to(device)

print("beginning training")
criterion = nn.MSELoss()
mym = mym.to(device)
optimizer = torch.optim.Adam(mym.parameters(), lr=5e-5)

num_epochs = 50
for epoch in range(num_epochs):
    total_loss = 0.0
    iter = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    last_loss = 0.0  # To keep track of the last batch's loss
    running_loss = 0

    for latent, time, label, noise in pbar:
        latent, time, label, noise = latent.to(device), time.to(device), label.to(device), noise.to(device)
        B = latent.shape[0]
        iter += 1

        guided = torch.randint(0,2,(B,)).cuda()

        x_att_in = emb_mat[label.long()].unsqueeze(1).repeat(1,5,1)

        # print(latent.shape, time.shape, x_att_in.shape, guided.shape)
        ns_hat = mym(latent, time, x_att_in, guided)
        loss = criterion(ns_hat, noise)
        
        # BATCH STUFF
        optimizer.zero_grad()
        loss.backward()    
        optimizer.step()
        total_loss += loss.item()
        last_loss = loss.item() 
        running_loss = 0.98 * running_loss + 0.02 * last_loss
        if iter % 100 == 0: torch.save(mym, "noise_predictor_b.pth")
        pbar.set_postfix(bl=1000*last_loss, rl=1000*running_loss)

    print(f"Epoch {epoch}, Last Batch Loss: {last_loss}, Total Loss: {total_loss:.4f}")
    torch.save(mym, "noise_predictor_e.pth")


Overwriting runme.py


In [34]:
!python runme.py

  self.pos_enc = nn.Parameter(torch.tensor(get_pos_emb(torch.tensor(range(self.max_tokens)), self.X_att_dim)))
Total trainable parameters: 6,485,883
  latent_set = torch.load("/kaggle/input/fruit-diffusion-dataset/z_set.pth")
  time_set = torch.load("/kaggle/input/fruit-diffusion-dataset/t_set.pth")
  label_set = torch.load("/kaggle/input/fruit-diffusion-dataset/l_set.pth")
  noise_set = torch.load("/kaggle/input/fruit-diffusion-dataset/n_set.pth")
torch.Size([170000, 3, 32, 32]) torch.Size([170000]) torch.Size([170000]) torch.Size([170000, 3, 32, 32])
train_loader created
max_label: 5
beginning training
Epoch 0:  13%|█▉             | 694/5313 [02:09<14:15,  5.40it/s, bl=215, rl=245]^C
Epoch 0:  13%|█▉             | 695/5313 [02:10<14:25,  5.34it/s, bl=187, rl=244]
Traceback (most recent call last):
  File "/kaggle/working/runme.py", line 60, in <module>
    loss.backward()    
    ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/_tensor.py", line 581, in backward
