In [None]:
import os
import time
import numpy as np

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from tqdm import tqdm 
from ema_pytorch import EMA
from modules import *
from diffusion import GaussianDiffusion

from diffusers import AutoencoderKL
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp

In [None]:
class DiTBlock(nn.Module):
    """Transformer block whose two residual branches are conditioned on ``c``.

    The conditioning vector is passed through SiLU + Linear and split into six
    equal chunks along dim 1: shift / scale / gate for the attention branch,
    followed by shift / scale / gate for the MLP branch. Shift and scale are
    applied through the imported ``modulate`` helper to the (non-affine)
    layer-normalised input; each gate multiplies its branch output before the
    residual addition.

    Args:
        hidden_size: feature width of the tokens and of ``c``.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        **block_kwargs: forwarded unchanged to ``Attention``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()

        # Both norms are parameter-free; scale/shift come from the conditioning.
        norm_options = dict(elementwise_affine=False, eps=1e-6)
        self.norm1 = nn.LayerNorm(hidden_size, **norm_options)
        self.norm2 = nn.LayerNorm(hidden_size, **norm_options)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)

        def make_activation():
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=make_activation,
            drop=0,
        )

        # One projection yields all six modulation tensors at once.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Apply the gated attention and MLP branches to ``x`` conditioned on ``c``.

        The gates are unsqueezed at dim 1 so they broadcast over that axis of
        ``x`` (presumably the token axis — confirm against the caller).
        """
        (
            shift_msa, scale_msa, gate_msa,
            shift_mlp, scale_mlp, gate_mlp,
        ) = self.adaLN_modulation(c).chunk(6, dim=1)

        # Attention branch
        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa.unsqueeze(1) * attn_out

        # MLP branch
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x + gate_mlp.unsqueeze(1) * mlp_out

In [None]:
class DiT(nn.Module):
    """Patch-based transformer that maps ``(x, t, y)`` to an image-shaped output.

    The input is patch-embedded and offset by a fixed (non-trainable)
    positional table, conditioned on the sum of a timestep embedding and a
    label embedding, run through ``depth`` ``DiTBlock`` layers and a
    ``FinalLayer``, and finally folded back from patches into an image.

    Args:
        input_size: spatial size handed to ``PatchEmbed``.
        patch_size: patch edge length.
        in_channels: number of input channels.
        hidden_size: token feature width.
        depth: number of ``DiTBlock`` layers.
        num_heads: attention heads per block.
        mlp_ratio: MLP expansion factor per block.
        class_dropout_prob: dropout probability handed to ``LabelEmbedder``.
        num_classes: number of label classes.
        learn_sigma: when true the output has ``2 * in_channels`` channels,
            otherwise ``in_channels``.
    """

    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=768,
        depth=12,
        num_heads=8,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=25,
        learn_sigma=True
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        if learn_sigma:
            self.out_channels = in_channels * 2
        else:
            self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_classes = num_classes

        # Embeddings for the image patches, the class label and the timestep.
        self.x_emb = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.y_emb = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        time_feature_dim = 256
        self.t_emb = nn.Sequential(
            SinusoidalPosEmb(time_feature_dim),
            nn.Linear(time_feature_dim, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

        # Frozen positional table; filled in by initialize_weights().
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.x_emb.num_patches, hidden_size),
            requires_grad=False,
        )

        # Transformer trunk and output head.
        trunk = [
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(trunk)
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Fill the positional table and zero the modulation / output layers."""
        grid_side = int(self.x_emb.num_patches ** 0.5)
        table = sincos_pos_embed(self.pos_embed.shape[-1], grid_side)
        self.pos_embed.data.copy_(torch.from_numpy(table).float().unsqueeze(0))

        # Zero the last adaLN projection of every block and of the final layer,
        # plus the final output projection.
        zero_layers = [block.adaLN_modulation[-1] for block in self.blocks]
        zero_layers.append(self.final_layer.adaLN_modulation[-1])
        zero_layers.append(self.final_layer.linear)
        for layer in zero_layers:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def unpatchify(self, x):
        """Fold ``(N, T, p*p*C)`` patch tokens into ``(N, C, side*p, side*p)``.

        ``side`` is the integer square root of ``T``; the patch grid is taken
        to be square.
        """
        batch = x.shape[0]
        channels = self.out_channels
        patch = self.x_emb.patch_size[0]
        side = int(x.shape[1] ** 0.5)

        grid = x.reshape(batch, side, side, patch, patch, channels)
        # (n, h, w, p, q, c) -> (n, c, h, p, w, q)
        grid = grid.permute(0, 5, 1, 3, 2, 4)
        return grid.reshape(batch, channels, side * patch, side * patch)

    def forward(self, x, t, y):
        """Run the network on input ``x`` at timesteps ``t`` with labels ``y``."""
        tokens = self.x_emb(x) + self.pos_embed
        cond = self.t_emb(t) + self.y_emb(y, self.training)

        for block in self.blocks:
            tokens = block(tokens, cond)
        return self.unpatchify(self.final_layer(tokens, cond))

In [None]:
class Trainer():
    """Training loop, checkpointing and sampling for a latent diffusion model.

    Each step encodes an image batch with the notebook-level ``vae`` (a module
    global that must exist before ``train`` / ``generate`` are called), adds
    noise with ``GaussianDiffusion.q_sample`` and fits the model to predict
    that noise with an MSE loss. An EMA wrapper around the model is updated
    every step; it is what gets checkpointed and sampled from.

    Args:
        model: network to train; must expose ``num_classes`` and be callable
            as ``model(noisy_latent, t, label)``.
        dataloader: iterable yielding ``(img, label)`` batches.
        ckpt_dir: directory for checkpoints (created if missing).
        load_path: optional checkpoint path to resume from.
        total_step: last optimisation step, inclusive.
        save_n_step: interval (in steps) between checkpoints / sample grids.
        lr: Adam learning rate.
        timestep: number of diffusion timesteps.
    """

    def __init__(
            self,
            model,
            dataloader,
            ckpt_dir,
            load_path=None,
            total_step=400000,
            save_n_step=10000,
            lr=1e-4,
            timestep=1000
        ):
        os.makedirs(ckpt_dir, exist_ok=True)
        # NOTE(review): this directory is created but nothing below writes to it.
        os.makedirs('./results/', exist_ok=True)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = nn.DataParallel(model).to(self.device)

        self.timestep = timestep
        self.dataloader = dataloader
        # Persistent iterator over `dataloader`; created lazily by _next_batch().
        self._batch_iter = None
        self.ckpt_dir = ckpt_dir

        self.step = 1
        self.n_classes = model.num_classes
        self.total_step = total_step
        self.save_n_step = save_n_step
        self.ema = EMA(self.model, beta = 0.995, update_every = 1)
        self.ema.to(self.device)

        self.loss_fn = nn.MSELoss()
        #self.loss_fn = nn.L1Loss()
        self.loss_history = []  # List to store loss values
        self.scaler = torch.amp.GradScaler(self.device)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.diff_method = GaussianDiffusion(self.device, timestep=timestep)

        if load_path is not None:
            self.load_state_dict(load_path)
            self.ema.copy_params_from_ema_to_model()
            print("sucessful load state dict !!!!!!")
            print(f"start from step {self.step}")

    def _next_batch(self):
        """Return the next ``(img, label)`` batch, restarting the loader when exhausted.

        Keeping one iterator alive avoids building a new DataLoader iterator
        (and its worker processes) on every step and lets training walk
        through whole epochs instead of only ever reading a first batch.
        """
        if getattr(self, "_batch_iter", None) is None:
            self._batch_iter = iter(self.dataloader)
        try:
            return next(self._batch_iter)
        except StopIteration:
            # End of an epoch: start a fresh pass over the data.
            self._batch_iter = iter(self.dataloader)
            return next(self._batch_iter)

    def state_dict(self, step):
        """Build the checkpoint payload for a completed ``step``.

        Only the step counter and the EMA wrapper state are stored; optimizer
        and grad-scaler state are not part of the checkpoint.
        """
        return {
            "step": step,
            "ema": self.ema.state_dict(),
        }

    def load_state_dict(self, path):
        """Restore EMA state from ``path`` and resume after the saved step."""
        # map_location lets a checkpoint written on GPU be loaded on a CPU host.
        state_dict = torch.load(path, map_location=self.device)
        self.ema.load_state_dict(state_dict['ema'])
        # The checkpoint is written after its step has completed, so resume at
        # the following step instead of repeating (and re-saving) it.
        self.step = state_dict['step'] + 1

    def train(self):
        """Run optimisation from ``self.step`` through ``self.total_step``.

        Every ``save_n_step`` steps a checkpoint is written, a sample grid is
        generated, and the EMA weights are copied into the online model.
        """
        start = time.time()
        print(f'Start of step {self.step}')

        for step in tqdm(range(self.step, self.total_step+1), desc=f"Training progress"):
            self.optimizer.zero_grad()
            img, label = self._next_batch()
            img = img.to(self.device)
            label = label.to(self.device)

            with torch.no_grad():
                # encode the image into a latent with the pre-trained VAE
                # (AutoencoderKL.from_pretrained) and apply the latent scale factor
                latent = vae.encode(img).latent_dist.sample().mul_(0.18215)

            noise = torch.randn_like(latent)
            t = torch.randint(0, self.timestep, (latent.shape[0],), device=self.device)
            noisy_latent = self.diff_method.q_sample(latent, t, noise)
            pred_noise = self.model(noisy_latent, t, label)

            # MSE between true noise (eps) and predicted noise; nn.MSELoss
            # already reduces to a scalar mean.
            loss = self.loss_fn(pred_noise, noise)

            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            # The huge max-norm means this effectively only records the gradient norm.
            self.grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), 1e9)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.ema.update()

            self.loss_history.append(loss.item())
            if step % self.save_n_step == 0:
                #clear_output(wait=True)
                epoch = step // self.save_n_step
                time_minutes = (time.time() - start) / 60
                torch.save(self.state_dict(step), f"{self.ckpt_dir}/weight_epoch{epoch}.pt")

                print(f"epoch: {epoch}, loss: {loss.data} ~~~~~~")
                print (f'Time taken for epoch {epoch} is {time_minutes:.3f} min\n')
                print(f"sucessful saving epoch {epoch} state dict !!!!!!!")
                start = time.time()

                self.generate(epoch)
                # NOTE(review): this overwrites the online weights with the EMA
                # weights while the optimizer state is kept — confirm that this
                # periodic reset is intended.
                self.ema.copy_params_from_ema_to_model()

            # Remember progress so a later train() call continues from here.
            self.step = step + 1
        print("finish training: ~~~~~~~~~~~~~~~~~~~~~~~~~~")

    def plot_loss(self):
        """Plot the per-step loss history and save it to ``./training_loss.jpg``."""
        fig = plt.figure(figsize=(10, 5))
        plt.plot(self.loss_history, label='Loss')
        plt.xlabel('Steps')
        plt.ylabel('Loss')
        plt.title('Training Loss over Steps')
        plt.legend()
        plt.savefig(f'./training_loss.jpg')
        plt.show()
        plt.close(fig)

    @torch.inference_mode()
    def generate(self, epoch):
        """Sample up to nine images with the EMA model and save a 3x3 grid.

        Labels are taken from the next data batch; the grid is written to
        ``./{epoch}_result_img.png``.
        """
        _, label = self._next_batch()
        # Sample with the EMA copy in eval mode: the model passes `self.training`
        # to its label embedder, so train mode would alter conditioning. Only the
        # EMA copy is switched; the online model stays in train mode.
        self.ema.ema_model.eval()
        x_0 = self.diff_method.ddim_sample(self.ema,label[:9], self.n_classes)
        fake_img = vae.decode(x_0 / 0.18215).sample

        num_rows = 3
        num_columns = 3

        fig, axs = plt.subplots(num_rows, num_columns, figsize=(6, 6))
        for index, ax in enumerate(axs.flat):
            # A short batch can yield fewer than nine samples; leave those cells blank.
            if index < len(fake_img):
                img = unnormalize(fake_img[index], device=self.device)
                img = img.clamp(0, 1)

                # Display the image
                ax.imshow(img.permute(1, 2, 0).cpu().detach().numpy())
            ax.axis('off')
        plt.savefig(f'./{epoch}_result_img.png')
        plt.show()
        # Release the figure so repeated calls during training do not accumulate.
        plt.close(fig)

In [None]:
# Root of the ImageNet-100 dataset (Kaggle input mount).
imagenet_dir = "/kaggle/input/imagenet100/"

# Resize to 256x256 and normalise with the ImageNet channel statistics.
# NOTE(review): Stable Diffusion VAEs are commonly fed inputs scaled to [-1, 1]
# (mean = std = 0.5) — confirm this normalisation, together with the imported
# `unnormalize` helper, is what the VAE below expects before changing it.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(root=os.path.join(imagenet_dir, "train.X1"), transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)

# Pre-trained VAE used to move between pixel space and latent space. It is
# placed with the same CUDA-or-CPU rule the Trainer uses, so the notebook does
# not crash on a host without a GPU.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(
    "cuda" if torch.cuda.is_available() else "cpu"
)

In [None]:
# Run configuration: optimiser step size, schedule length, checkpoint cadence
# and checkpoint locations (LOAD_PATH = None starts from scratch).
LR = 1e-4
TOTAL_ITERATION = 200000
SAVE_N_ITERATION = 5000
CKPT_DIR = './model_weight/'
LOAD_PATH = None

# Network: noise-only output (learn_sigma=False), so the output has as many
# channels as the input.
model = DiT(
    input_size=32, patch_size=2, in_channels=4,
    hidden_size=768, depth=12, num_heads=8, mlp_ratio=4.0,
    class_dropout_prob=0.1, num_classes=25,
    learn_sigma=False,
)

# Training driver wired to the data loader and the configuration above.
trainer = Trainer(
    model=model,
    dataloader=train_loader,
    ckpt_dir=CKPT_DIR,
    load_path=LOAD_PATH,
    total_step=TOTAL_ITERATION,
    save_n_step=SAVE_N_ITERATION,
    lr=LR,
)

In [None]:
# Run the training loop. Every SAVE_N_ITERATION steps a checkpoint is written
# under CKPT_DIR and a 3x3 grid of generated samples is saved and displayed.
trainer.train()