## Functions for forward diffusion

# Diffusion Model

In [14]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset
from glob import glob
import matplotlib.pyplot as plt
from skimage import io
import numpy as np

IMG_SIZE = 64
BATCH_SIZE = 32

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """ 
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """ 
    Takes an image and a timestep as input and 
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)

## Using 300 time steps

In [15]:
# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

In [16]:
class CustomDataset(Dataset):
  def __init__(self, root_dir, transform=transforms.Compose([transforms.ToTensor(),
                                                             transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                                             transforms.RandomHorizontalFlip(),
                                                             transforms.Lambda(lambda t: (t * 2) - 1)
                                                             ])):
    self.root_dir = root_dir
    self.transform = transform
    self.img_list = glob(self.root_dir+"/*")

  def __len__(self):
    return len(self.img_list)

  def __getitem__(self, idx):
    image = io.imread(self.img_list[idx])

    if self.transform:
      image = self.transform(image)

    return image

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t+1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),
        transforms.Lambda(lambda t: t * 255),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])
    
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))
    plt.axis("off")

In [17]:
# # Testing the forward diffusion step
# image = next(iter(train_loader))[0]

# plt.figure(figsize = (15, 15))
# plt.axis("off")
# num_images = 10
# stepsize = int(T / num_images)

# for idx in range(0, T, stepsize):
#     t = torch.tensor([idx]).type(torch.int64)
#     plt.subplot(1, num_images+1, (idx // stepsize) + 1)
#     show_tensor_image(image)
#     image, noise = forward_diffusion_sample(image, t)

In [18]:
data = CustomDataset('/home/maxim/Documents/TestProject/maxim-lightning/archive (4)/bitmojis/')
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [19]:
from torch import nn
import math

import pytorch_lightning as pl


class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()
        
    def forward(self, x, t, ):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings


class SimpleUnet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1 
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )
        
        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])

        self.output = nn.Conv2d(up_channels[-1], 3, out_dim)

    def forward(self, x, timestep):
        # Embedd time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)           
            x = up(x, t)
        return self.output(x)
    
    def prepare_data(self):
        transform = transforms.Compose(
            [
                transforms.Resize((IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
            ]
        )

        end_train_idx = int(len(dataloader.dataset) - len(dataloader.dataset) / 5)
        end_val_idx = int(len(dataloader.dataset) - len(dataloader.dataset) / 7)
        end_test_idx = len(dataloader.dataset)

        self.train_dataset = Subset(dataloader.dataset, range(0, end_train_idx))
        self.val_dataset = Subset(dataloader.dataset, range(end_train_idx + 1, end_val_idx))
        self.test_dataset = Subset(dataloader.dataset, range(end_val_idx + 1, end_test_idx))

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=20
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=1, num_workers=20
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, batch_size=1, num_workers=20
        )
    
    def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=0.0007)

    def training_step(self, batch, batch_idx):
        x0 = batch[0]
        n = len(x0)
        
        t = torch.randint(0, T, (n,), device="cuda")
        
        loss = get_loss(self, batch[0], t)

        self.step += 1

        if self.step % 500 == 0:
            sample_plot_image(self)
            print(f"Step {self.step:03d} | Loss: {loss.item()}")
            torch.save(self.state_dict(), "diffusion_model_celeba.pth")

        logs = {"loss": loss}
        return {"loss": loss, "log": logs}
    
    def validation_step(self, batch, batch_idx):
        x = batch

        t = torch.randint(T - 5, T, (1,), device="cuda").long()

        output = self(x, t)

        loss = get_loss(self, batch[0], t)

        reverse_transforms = v2.Compose([
            transforms.Lambda(lambda t: (t + 1) / 2),
            transforms.Lambda(lambda t: t.permute(1, 2, 0)),
            transforms.Lambda(lambda t: t * 255.),
            transforms.Lambda(lambda t: np.absolute(t.numpy()).astype(np.uint8)),
            transforms.ToTensor()
        ])

        output_img = output[0].cpu()

        self.val_outputs.append(reverse_transforms(output_img))

        self.log("val loss", loss)

        logs = {"loss": loss}
        return {"loss": loss, "log": logs}
    
    def on_validation_epoch_end(self):
        grid = vutils.make_grid(self.val_outputs)
        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)

        self.val_outputs.clear()

In [20]:
def get_loss(model, x_0, t):
    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    noise_pred = model(x_noisy, t)
    print((torch.isnan(noise_pred.view(-1)).sum().item()==0).__str__())
    return F.l1_loss(noise, noise_pred) # can use l2 loss

In [21]:
@torch.no_grad()
def sample_timestep(x, t, model):
    """
    Calls the model to predict the noise in the image and returns 
    the denoised image. 
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    
    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    
    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

@torch.no_grad()
def sample_plot_image(model):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(20,20))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t, model)
        if i % stepsize == 0:
            plt.subplot(1, num_images, i//stepsize+1)
            show_tensor_image(img.detach().cpu())
    plt.show()            

In [22]:
celeba_dataset = CustomDataset(root_dir = "/home/maxim/Documents/TestProject/maxim-lightning/archive (4)/bitmojis")
train_loader = DataLoader(dataset=celeba_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last = True)

In [23]:
model = SimpleUnet()
model

SimpleUnet(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): ReLU()
  )
  (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (downs): ModuleList(
    (0): Block(
      (time_mlp): Linear(in_features=32, out_features=128, bias=True)
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (1): Block(
      (time_mlp): Linear(in_features=32, out_features=256, bias=True)
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transfor

In [24]:
from pytorch_lightning.callbacks import ModelCheckpoint

In [25]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [26]:
checkpoint_callback = ModelCheckpoint(dirpath='/home/maxim/Documents/TestProject/maxim-lightning/checkpoints/', every_n_epochs = 5)

In [27]:
trainer = pl.Trainer(max_epochs=2, precision='16-mixed', accelerator="cuda", callbacks=[checkpoint_callback])

with torch.autograd.detect_anomaly():
    trainer.fit(model)
    trainer.validate(model)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  with torch.autograd.detect_anomaly():
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | time_mlp | Sequential | 1.1 K 
1 | conv0    | Conv2d     | 1.8 K 
2 | downs    | ModuleList | 41.2 M
3 | ups      | ModuleList | 21.3 M
4 | output   | Conv2d     | 195   
----------------------------------------
62.4 M    Trainable params
0         Non-trainable params
62.4 M    Total params
249.756   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



ValueError: expected 4D input (got 3D input)

In [None]:
# from torch.optim import Adam
# from tqdm import tqdm

# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = SimpleUnet()
# model.to(device)
# optimizer = Adam(model.parameters(), lr=0.001)
# epochs = 4 # Try more!

# print("# of steps:", len(train_loader))
# for epoch in range(epochs):
#     for step, batch in tqdm(enumerate(train_loader), position=0, leave=True):
#         optimizer.zero_grad()
#         t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
#         batch = batch.to(device)
#         loss = get_loss(model, batch, t)
#         loss.backward()
#         optimizer.step()
#         if step % 3000 == 0:
#             sample_plot_image(model)
#             print(f"Epoch {epoch} | Step {step:03d} | Loss: {loss.item()}")
#             torch.save(model.state_dict(), "diffusion_model_celeba.pth")

In [None]:
from torch.optim import Adam
from tqdm import tqdm

# BATCH_SIZE = 32
bitmoji_dataset = CustomDataset(root_dir = "/kaggle/input/bitmojis/bitmojis/")
train_loader = DataLoader(dataset=bitmoji_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last = True)

device = "cuda" if torch.cuda.is_available() else "cpu"
# bitmoji_model = SimpleUnet()
bitmoji_model.to(device)
optimizer = Adam(bitmoji_model.parameters(), lr=0.001)
epochs = 4 # Try more!

print("# of batches:", len(train_loader))
for epoch in range(epochs):
    for step, batch in tqdm(enumerate(train_loader), position=0, leave=True):
        optimizer.zero_grad()
        t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
        batch = batch.to(device)
        loss = get_loss(bitmoji_model, batch, t)
        loss.backward()
        optimizer.step()
        if step % 1500 == 0:
            sample_plot_image(bitmoji_model)
            print(f"Epoch {epoch} | Step {step:03d} | Loss: {loss.item()}")
            torch.save(bitmoji_model.state_dict(), "diffusion_model_bitmoji.pth")

In [None]:
# Creating directory of 1000 original images
from PIL import Image
import shutil
import os

path1 = "/kaggle/output/og_imgs/"
path2 = "/kaggle/output/gen_imgs/"

path = "/kaggle/input/bitmojis/bitmojis/*"
img_list = glob(path)

os.makedirs(path1, exist_ok = True)
os.makedirs(path2, exist_ok = True)

for i in range(1000):
    shutil.copy(img_list[i], path1+f'{i}.jpeg')

In [None]:
gen_imgs = np.zeros((1000, 64, 64, 3))
start = 0
batch_size = 20

@torch.no_grad()
def sample_plot_image(model):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    plt.figure(figsize=(20,20))
    plt.axis('off')

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t, model)
        if i  == 0:
            show_tensor_image(img.detach().cpu())
    plt.show()            

for j in tqdm(range(0, 1000)):
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t, model)
        if i  == 0:
#             show_tensor_image(img.detach().cpu())
#             plt.show()
            gen_imgs[j] = img[0].permute(1, 2, 0).detach().cpu().numpy()

for i in range(1000):
    Image.fromarray((gen_imgs[i]*255).astype(np.uint8)).save(path2 + f"{i}.jpeg")

!pip install pytorch-fid
!python3 -m pytorch_fid ../output/og_imgs/ ../output/gen_imgs/ --device cpu

In [None]:
# Creating directory of 1000 original images
from PIL import Image
import shutil
import os

path1 = "/kaggle/celeba_output/og_imgs/"
path2 = "/kaggle/celeba_output/gen_imgs/"

path = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba/*"
img_list = glob(path)

os.makedirs(path1, exist_ok = True)
os.makedirs(path2, exist_ok = True)

for i in range(1000):
    shutil.copy(img_list[i], path1+f'{i}.jpeg')
    
gen_imgs = np.zeros((1000, 64, 64, 3))
start = 0
batch_size = 20

@torch.no_grad()
def sample_plot_image(model):
    # Sample noise
    img_size = IMG_SIZE
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    plt.figure(figsize=(20,20))
    plt.axis('off')

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t, model)
        if i  == 0:
            show_tensor_image(img.detach().cpu())
    plt.show()            

for j in tqdm(range(0, 1000)):
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t, bitmoji_model)
        if i  == 0:
            gen_imgs[j] = img[0].permute(1, 2, 0).detach().cpu().numpy()

for i in range(1000):
    Image.fromarray((gen_imgs[i]*255).astype(np.uint8)).save(path2 + f"{i}.jpeg")

!pip install pytorch-fid
!python3 -m pytorch_fid ../celeba_output/og_imgs/ ../celeba_output/gen_imgs/ --device cpu

In [None]:
img_size = IMG_SIZE
num_images = 10
stepsize = int(T/num_images)

for k in range(10):
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    plt.figure(figsize=(20,20))
    plt.axis('off')
    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t, model)
        if i % stepsize == 0:
            plt.subplot(k+1, num_images, i//stepsize+1)
            show_tensor_image(img.detach().cpu())
    plt.show()

In [None]:
img_size = IMG_SIZE
num_images = 10
stepsize = int(T/num_images)

for k in range(10):
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    plt.figure(figsize=(20,20))
    plt.axis('off')
    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t, bitmoji_model)
        if i % stepsize == 0:
            plt.subplot(k+1, num_images, i//stepsize+1)
            show_tensor_image(img.detach().cpu())
    plt.show()

In [None]:
img_size = IMG_SIZE
num_images = 10
stepsize = int(T/num_images)

plt.figure(figsize=(20,20))
plt.axis('off')
for k in range(10):
    for j in range(10):
        img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
        for i in range(0,T)[::-1]:
            t = torch.full((1,), i, device=device, dtype=torch.long)
            img = sample_timestep(img, t, bitmoji_model)
            if i == 0:
                plt.subplot(10, 10, 10*k + j + 1)
                show_tensor_image(img.detach().cpu())
plt.show()

In [None]:
img_size = IMG_SIZE
num_images = 10
stepsize = int(T/num_images)

plt.figure(figsize=(20,20))
plt.axis('off')
for k in range(10):
    for j in range(10):
        img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
        for i in range(0,T)[::-1]:
            t = torch.full((1,), i, device=device, dtype=torch.long)
            img = sample_timestep(img, t, model)
            if i == 0:
                plt.subplot(10, 10, 10*k + j + 1)
                show_tensor_image(img.detach().cpu())
plt.show()