In [106]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import MNIST, Flowers102, StanfordCars, CIFAR10
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import os
from torchvision.io import read_image
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from datetime import datetime
import pytorch_lightning as pl
pl.seed_everything(42)

Seed set to 42


42

In [107]:
batch_size = 32
n_T = 1000
device = "cuda"
n_classes = 2
n_feat = 128
l_rate = 0.00008
save_model = False
save_dir = './imgs/'
img_size = 128

In [108]:
class ResidualConvBlock(nn.Module):
    """
    A residual convolutional block with two convolutinal layers and GELU activation.
    """
    def __init__(
        self, in_channels, out_channels, is_res = False
    ) -> None:
        super().__init__()
        # Check if input and output channels are same for resiual connectoin
        self.same_channels = in_channels==out_channels
        # Whether to perform residual connection
        self.is_res = is_res

        # First convolutoinal layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

        # Second convolutoinal layer
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, x):
        if self.is_res:
            # Residual
            x1 = self.conv1(x)
            # Second conv layer
            x2 = self.conv2(x1)

            # Add residual connection based on channels
            if self.same_channels:
                out = x + x2
            else:
                out = x1 + x2

            return out
        else:
            # Non-residual
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2

In [109]:
class UnetDown(nn.Module):
    """
    A downscampling block for my cnn model.
    """
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()

        layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [110]:
class UnetUp(nn.Module):
    """
    An upsampling block for my cnn model.
    """
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()

        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        # Concatenate the input with the skip connection of resnet
        x = torch.cat((x, skip), 1)

        # Run the concatanated tensor throgh the model
        x = self.model(x)

        return x

In [111]:
class EmbedFC(nn.Module):
    """
    A network for embedding features.
    """
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()

        self.input_dim = input_dim

        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Reshape the input tensor to a single dimensoin
        x = x.view(-1, self.input_dim)

        # pass the input through the embeding network
        return self.model(x)

In [112]:
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat = 128, n_classes = 10):
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes

        # Initial conv
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Downsampling blocks
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        # Feacter vect extract for context and time embeddings
        self.to_vec = nn.Sequential(nn.AvgPool2d(8), nn.GELU())

        # Embeddings for context and time info
        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_classes, 2*n_feat)
        self.contextembed2 = EmbedFC(n_classes, 1*n_feat)

        # Upscaling blocks
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 8, 8),
            nn.GroupNorm(8, 2 * n_feat),
            nn.GELU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Output conv
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.GELU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t, context_mask):
        x = self.init_conv(x)

        # Downsampling
        down1 = self.down1(x)
        down2 = self.down2(down1)

        # print("hi1")

        # Feature vector for context and time embedding
        hiddenvec = self.to_vec(down2)

        # print(c)

        # print('hi2')

        # Convert context information to one-hot encoding and apply context mask
        c = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)
        
        # print(c)

        # print('hi3')
        # print(c)

        # print(c)
        # print(c.shape)

        context_mask = context_mask[:, None].repeat(1, self.n_classes) * -1 + 1
        c = c * context_mask

        # print('hi4')

        # Generate context and time embeddings
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)


        # print('hi5')

        # Upsampling path with context and time embeddings
        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1*up1+ temb1, down2)
        up3 = self.up2(cemb2*up2+ temb2, down1)
        out = self.out(torch.cat((up3, x), 1))

        # print('hi6')

        return out

In [113]:
def ddpm_schedules(beta1, beta2, T):
    """
    This function computes and returns pre-computed schedules for DDPM sampling and training.
    """

    assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

    # Compute values for nois
    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()
    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)
    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab
    
    return {
        "alpha_t": alpha_t,
        "oneover_sqrta": oneover_sqrta,
        "sqrt_beta_t": sqrt_beta_t,
        "alphabar_t": alphabar_t,
        "sqrtab": sqrtab,
        "sqrtmab": sqrtmab,
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,
    }


In [114]:
tf = transforms.Compose([transforms.ToTensor()])

In [115]:
def cvtImg(img, channels=3):
    if channels == 3:
        img = img
    elif channels == 4:
        img = img[0]
    img = img - img.min()
    img = (img / img.max())
    return tf(img.numpy().astype(float))

In [116]:
class DDPM(pl.LightningModule):
    def __init__(self, nn_model=ContextUnet(in_channels=3, n_feat=n_feat, n_classes=n_classes),
                 betas=(1e-4, 0.02), n_T=n_T, device=device, timesteps=n_T,
                 device_=device, drop_prob=0.1, learning_rate=1e-4, num_classes=n_classes):
        super(DDPM, self).__init__()

        # Initialize the neural network model
        self.nn_model = nn_model.to(device_)

        # Store the num of class
        self.num_classes = num_classes

        # Register pre-computed schedules for diffusion process (noise gen)
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.learning_rate = learning_rate
        self.step = 0
        self.n_T = timesteps
        self.drop_prob = drop_prob
        self.loss_mse = nn.SmoothL1Loss()

    def forward(self, x, c):
        # Sample a random diffusion timestep
        _ts = torch.randint(1, self.n_T+1, (x.shape[0],)).to(self.device)

        # Generate noise for diffusion process
        noise = torch.randn_like(x)

        # Compute the denoised image at the chosen timestep
        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * noise
        )

        # Apply context mask with dropout
        context_mask = torch.bernoulli(torch.zeros_like(c, dtype=torch.float) + self.drop_prob).to(self.device)

        # Compute the loss between the prediced noise and the actual noise
        return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))

    def sample(self, n_sample, size, device, guide_w = 1.0):
        # Initialize the starting noise
        x_i = torch.randn(n_sample, *size).to(device)

        # Generate class labels for conditional image gen
        c_i = torch.arange(0, self.num_classes).to(device)
        c_i = c_i.repeat(int(n_sample/c_i.shape[0]))

        # Initialize the context mask for conditional generation
        context_mask = torch.zeros_like(c_i).to(device)

        # Duplicate the inputs for guide comparison
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1.

        x_i_store = []

        print()
        for i in range(self.n_T, 0, -1):
            print(f'sampling timestep {i}',end='\r')
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(n_sample,1,1,1)

            x_i = x_i.repeat(2,1,1,1)
            t_is = t_is.repeat(2,1,1,1)

            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            eps = self.nn_model(x_i, c_i, t_is, context_mask)
            eps1 = eps[:n_sample]
            eps2 = eps[n_sample:]
            eps = (1+guide_w)*eps1 - guide_w*eps2
            x_i = x_i[:n_sample]
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )

            if i%20==0 or i==self.n_T:
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return cvtImg(x_i.detach().cpu(), 4), x_i_store
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), self.learning_rate)
    
    def training_step(self, batch, batch_idx):
        x, c = batch
        loss_ema = None
        
        loss = self(x, c)

        if loss_ema is None:
            loss_ema = loss.item()
        else:
            loss_ema = 0.95 * loss_ema + 0.05 * loss.item()

        if self.step % 5000 == 0 and self.step != 0:
            print(f"Step {self.step:03d} | Loss: {loss.item()}")
            
            dict = self.state_dict()

            dict["pytorch-lightning_version"] = '2.1.0'
            dict["global_step"] = self.step
            dict["epoch"] = self.current_epoch
            dict["state_dict"] = self.state_dict()

            now = datetime.now()

            dt_string = now.strftime("%d|%m|%Y %H:%M:%S")
            
            torch.save(dict, "saves/diffusion_model_step: " + self.step.__str__() + "|time: " + dt_string + ".ckpt")

        self.step += 1

        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        x, c = batch

        n_sample = 4*n_classes
        x_gen, x_gen_store = self.sample(n_sample, (3, img_size, img_size), device, guide_w=1.0)

        x_gen = x_gen.permute(1,2,0)
        
        # remove *-1

        grid = make_grid(transforms.ToTensor()(transforms.ToPILImage()(x_gen)), nrow=10)
        save_image(grid, save_dir + f"image_{self.step}.png")
        print('saved image at ' + save_dir + f"image{self.step}.png")

In [117]:
class CustomImageDataset(Dataset):
    def __init__(self, annotations, attr, img_dir, transform=None):
        self.img_labels = annotations
        self.attr = attr
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx])
        
        image = read_image(img_path)
        label = self.img_labels[idx]

        if self.transform:
            image = self.transform(image)

        val = ""
        for i in range(self.attr[idx].__len__()):
            val = val + self.attr[idx][i].__str__()

        return image, int(val)

In [118]:
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

all_attr = pd.read_csv("/home/maxim/Downloads/archive (2)/list_attr_celeba.csv", index_col="image_id")
all_attr[all_attr.columns] = (all_attr[all_attr.columns] + 1) / 2

id = pd.read_csv("/home/maxim/Downloads/archive (2)/list_attr_celeba.csv").pop("image_id")

attr_list = ["Male"]

data = CustomImageDataset(id, all_attr[
                                        attr_list
                                      ].astype(int).values,
                                               "/home/maxim/Downloads/archive (2)/img_align_celeba/img_align_celeba",
                                                 transform)

data = torch.utils.data.ConcatDataset([data])


trainloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=20)

In [119]:


# tf = transforms.Compose([
#                                transforms.Grayscale(3),
#                                transforms.Resize(28),
#                                transforms.ToTensor(),
#                                transforms.Normalize((0.5,), (0.5,)),])

In [120]:
torch.cuda.empty_cache()

In [121]:
torch.set_float32_matmul_precision('medium')

In [122]:
# dataset = CIFAR10("./data", download=True, transform=tf)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=20)

ddpm = DDPM(nn_model=ContextUnet(in_channels=3, n_feat=n_feat, n_classes=2),
             betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1, learning_rate=l_rate, num_classes=n_classes)

trainer = pl.Trainer(accelerator=device, max_epochs=500, precision='16-mixed')

# ddpm = DDPM.load_from_checkpoint('saves/diffusion_model_step: 5000|time: 14|12|2023 10:08:22.ckpt')

trainer.fit(ddpm, train_dataloaders=trainloader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/maxim/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type         | Params
------------------------------------------
0 | nn_model | ContextUnet  | 7.6 M 
1 | loss_mse | SmoothL1Loss | 0     
------------------------------------------
7.6 M     Trainable params
0         Non-trainable params
7.6 M     Total params
30.281    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

/home/maxim/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
trainer.validate(ddpm, dataloaders=trainloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/maxim/anaconda3/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Validation: |          | 0/? [00:00<?, ?it/s]


saved image at ./imgs/image24192.png

saved image at ./imgs/image24192.png

saved image at ./imgs/image24192.png

saved image at ./imgs/image24192.png

saved image at ./imgs/image24192.png

sampling timestep 8770