In [None]:
import os
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from PIL import Image
import logging
from tqdm import tqdm
from torch import optim
from torch.utils.data import Dataset, DataLoader

# Configure the root logger once for the whole notebook: INFO level and above,
# each record rendered as "<time> - <LEVEL>: <message>" with a 12-hour clock
# timestamp (%I:%M:%S).
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

In [None]:
# Mount Google Drive at /content/drive so the dataset stored on Drive can be
# read through an ordinary filesystem path. `google.colab` is supplied by the
# Colab runtime, so this cell is expected to run only there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
class CustomImageDataset(Dataset):
    """Unlabelled image dataset backed by a single flat directory.

    Every entry of ``image_dir`` whose name ends in one of
    ``IMAGE_EXTENSIONS`` (compared case-insensitively) becomes one sample.
    File names are kept in sorted order so that a given index always refers
    to the same file, independent of the order ``os.listdir`` happens to use.

    Args:
        image_dir: Directory that directly contains the image files
            (sub-directories are not searched).
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # Lower-case the name before matching so "photo.JPG" is not silently skipped.
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, 0)``.

        The second element is a constant dummy label so the dataset can be
        consumed by code that unpacks ``(image, label)`` pairs.
        """
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        # Image.open() is lazy and keeps the file open; convert() decodes the
        # pixels into a new image, so the handle can be released right away.
        with Image.open(img_name) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, 0  # 0 is a placeholder label

In [None]:
class Diffusion:
    """Forward noising and reverse sampling for a DDPM with a cosine beta schedule.

    Args:
        noise_steps: Number of diffusion steps; also the length of the schedule tensors.
        img_size: Side length in pixels of the square images produced by ``sample``.
        device: Device the schedule tensors are moved to and samples are generated on.

    Attributes:
        beta: Per-step noise variances from ``cosine_beta_schedule``.
        alpha: ``1 - beta``.
        alpha_hat: Cumulative product of ``alpha`` over the steps.
    """

    def __init__(self, noise_steps = 1000, img_size = 256, device = "cuda"):
        self.img_size = img_size
        self.device = device
        self.noise_steps = noise_steps
        self.beta = Diffusion.cosine_beta_schedule(self.noise_steps).to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim = 0)

    @staticmethod
    def cosine_beta_schedule(timesteps = 1000, s = 0.008):
        """Return ``timesteps`` betas derived from a squared-cosine alpha_hat curve.

        alpha_hat(t) is proportional to cos(((t / T) + s) / (1 + s) * pi / 2) ** 2,
        normalised so alpha_hat(0) == 1, and beta_t = 1 - alpha_hat(t) / alpha_hat(t - 1).
        The result is clipped to [0.0001, 0.9999] to keep every step well defined.
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999)

    def noise_images(self, x, t):
        """Noise ``x`` to step ``t`` in closed form.

        Args:
            x: Batch of images; the indexing below assumes 4 dimensions with the batch first.
            t: 1-D integer tensor of per-sample step indices.

        Returns:
            ``(x_t, noise)`` where ``x_t = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * noise``
            and ``noise`` is the standard-normal tensor that was mixed in.
        """
        # [:, None, None, None] turns the per-sample scalars into (B, 1, 1, 1) so they broadcast over the image.
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noise = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def sample_timesteps(self, n):
        """Draw ``n`` step indices uniformly from ``1 .. noise_steps - 1`` (inclusive).

        The tensor is created without an explicit device; callers move it as needed.
        """
        return torch.randint(low = 1, high = self.noise_steps, size=(n,))

    def sample(self, model, n):
        """Generate ``n`` images by iteratively denoising Gaussian noise.

        Args:
            model: Callable ``model(x, t)`` returning the predicted noise for ``x`` at steps ``t``.
                It must provide ``eval()`` and ``train()``.
            n: Number of images to generate.

        Returns:
            ``uint8`` tensor of shape ``(n, 3, img_size, img_size)`` with values in ``[0, 255]``.

        The model is put in eval mode for the duration of sampling and afterwards
        returned to whatever mode it was in before the call (train mode when the
        model does not expose a ``training`` flag).
        """
        logging.info(f"Sampling {n} new images....")

        # Remember the caller's mode so sampling does not silently flip an eval-mode model to train mode.
        was_training = getattr(model, "training", True)
        model.eval()

        try:
            # No gradients are needed for sampling; disabling them saves memory and compute.
            with torch.no_grad():
                # Start from pure Gaussian noise.
                x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)

                # Walk the chain backwards from step noise_steps - 1 down to step 1.
                for i in tqdm(reversed(range(1, self.noise_steps)), total=self.noise_steps - 1, position=0):
                    t = (torch.ones(n) * i).long().to(self.device)

                    predicted_noise = model(x, t)
                    alpha = self.alpha[t][:, None, None, None]
                    alpha_hat = self.alpha_hat[t][:, None, None, None]
                    beta = self.beta[t][:, None, None, None]

                    # The final step (i == 1) is deterministic: no fresh noise is added.
                    if i > 1:
                        noise = torch.randn_like(x)
                    else:
                        noise = torch.zeros_like(x)

                    # x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_hat_t) * eps) + sqrt(beta_t) * z
                    x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

                # Clamp to [-1, 1] to drop outliers, shift/scale to [0, 1], then quantise to 8-bit pixel values.
                x = (x.clamp(-1, 1) + 1) / 2
                x = (x * 255).type(torch.uint8)
        finally:
            if was_training:
                model.train()
            else:
                model.eval()

        return x

In [None]:
# from torch.utils.data import DataLoader
# from tqdm import
def setup_logging(run_name):
    """Create the output folders used by one training run.

    Ensures that ``models/``, ``results/``, ``models/<run_name>/`` and
    ``results/<run_name>/`` exist relative to the current working directory.
    Directories that already exist are left as they are.
    """
    roots = ("models", "results")
    wanted = list(roots) + [os.path.join(root, run_name) for root in roots]
    for directory in wanted:
        os.makedirs(directory, exist_ok=True)

import numpy as np
def save_images(images, path, **kwargs):
    """Tile a batch of images into one grid, save it to ``path`` and show it inline.

    Args:
        images: Image batch accepted by ``torchvision.utils.make_grid``
            (assumed to be a (B, C, H, W) tensor — TODO confirm against callers).
            ``uint8`` data is written as-is; any other dtype is treated as
            floats in [0, 1] and scaled to 0..255.
        path: Destination file path handed to ``PIL.Image.save``.
        **kwargs: Forwarded unchanged to ``torchvision.utils.make_grid``.
    """
    grid = torchvision.utils.make_grid(images, **kwargs)
    # Channels-first -> channels-last, which is the layout Image.fromarray expects.
    ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
    if ndarr.dtype != np.uint8:
        # Clip before scaling so values slightly outside [0, 1] saturate
        # instead of wrapping around when cast to uint8.
        ndarr = (np.clip(ndarr, 0.0, 1.0) * 255).astype(np.uint8)
    im = Image.fromarray(ndarr)
    im.save(path)

    # `display` is the notebook's rich-output helper; it renders the grid under the cell.
    display(im)


# class Gray(object):
#     def __call__(self,img):
#         return img.convert("RGB")

import torchvision

# Pre-processing applied to every image before it reaches the model.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((128, 128)),  # Resize every image to 128x128
    torchvision.transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
    # Map [0, 1] to [-1, 1]. Diffusion.sample() converts its output back with
    # (x.clamp(-1, 1) + 1) / 2, so the training data has to live in [-1, 1];
    # mean=0 / std=1 would be a no-op and leave the data in [0, 1].
    torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]),
])
dataset_path = '/content/drive/My Drive/Colab Notebooks/face_dataset/Humans'
dataset = CustomImageDataset(image_dir = dataset_path, transform= transforms)

dataloader = DataLoader(dataset, batch_size = 32, shuffle=True)
# A DataLoader is falsy when it yields zero batches, so this only reports on a non-empty dataset.
if dataloader:
    print("DataLoader successfully created!")
    # Note: this is the number of images in the dataset, not the number of batches.
    print("Length of DataLoader:", len(dataloader.dataset))
    print("Batch size:", dataloader.batch_size)


DataLoader successfully created!
Length of DataLoader: 7211
Batch size: 32


In [None]:
# Peek at one batch to sanity-check shapes, dtypes and the pixel value range.
# A summary is printed instead of the raw tensors, which would flood the cell
# output without showing whether the data is in the expected range.
for images, labels in dataloader:
    print("images:", tuple(images.shape), images.dtype)
    print("pixel range:", images.min().item(), "to", images.max().item())
    print("labels:", tuple(labels.shape), labels.dtype)
    break

[tensor([[[[ 0.5843,  0.5922,  0.6000,  ...,  0.6235,  0.6157,  0.6000],
          [ 0.5922,  0.6000,  0.6078,  ...,  0.6314,  0.6235,  0.6078],
          [ 0.6000,  0.6078,  0.6235,  ...,  0.6314,  0.6314,  0.6157],
          ...,
          [ 0.4588,  0.4667,  0.4667,  ...,  0.1451,  0.2078,  0.2471],
          [ 0.4353,  0.4353,  0.4431,  ...,  0.0588,  0.0902,  0.1059],
          [ 0.4118,  0.4196,  0.4196,  ..., -0.0745, -0.0275, -0.0431]],

         [[ 0.5843,  0.5922,  0.6000,  ...,  0.6235,  0.6157,  0.6000],
          [ 0.5922,  0.6000,  0.6078,  ...,  0.6314,  0.6235,  0.6078],
          [ 0.6000,  0.6078,  0.6235,  ...,  0.6314,  0.6314,  0.6157],
          ...,
          [ 0.4588,  0.4667,  0.4667,  ...,  0.1451,  0.2078,  0.2471],
          [ 0.4353,  0.4353,  0.4431,  ...,  0.0588,  0.0902,  0.1059],
          [ 0.4118,  0.4196,  0.4196,  ..., -0.0745, -0.0275, -0.0431]],

         [[ 0.5843,  0.5922,  0.6000,  ...,  0.6235,  0.6157,  0.6000],
          [ 0.5922,  0.6000, 

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [None]:
# Training configuration read by train().
epochs = 50

run_name = 'bhumika'
lr = 1e-4
def train():
    """Train ``UNet1`` to predict diffusion noise on the images in ``dataloader``.

    Reads the module-level ``epochs``, ``run_name``, ``lr``, ``device`` and
    ``dataloader``. After every epoch it samples a grid of images into
    ``results/<run_name>/<epoch>.jpg`` and overwrites the checkpoint
    ``models/<run_name>/ckpt.pt`` with the model's ``state_dict``.
    """
    setup_logging(run_name)
    model = UNet1().to(device)
    # To resume a previous run, load its state_dict into `model` here.
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    # NOTE(review): with step_size=1 and gamma=0.6 the learning rate is
    # lr * 0.6**epoch, i.e. about 3.7e-5 of its initial value after 20 epochs.
    # Confirm such an aggressive decay is intended for a 50-epoch run.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.6)

    mse = nn.MSELoss()

    # Sample at the same resolution the model is trained on, instead of a
    # hard-coded size that can drift away from the dataset's Resize transform.
    # Assumes the dataset yields square (C, H, W) tensors.
    first_image, _ = dataloader.dataset[0]
    diffusion = Diffusion(noise_steps=1000, img_size=first_image.shape[-1], device=device)

    for epoch in range(epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, _) in enumerate(pbar):
            images = images.to(device)
            # One random diffusion step per image in the batch.
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_postfix(MSE=loss.item())
        scheduler.step()

        # Generate as many samples as there were images in the epoch's last batch.
        sampled_images = diffusion.sample(model, n=images.shape[0])
        save_images(sampled_images, os.path.join("results", run_name, f"{epoch}.jpg"))
        torch.save(model.state_dict(), os.path.join("models", run_name, "ckpt.pt"))

In [None]:
# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image


# %%
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> GroupNorm with 16 groups -> GELU) stages.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        mid_channels: Channels between the two stages; falls back to
            ``out_channels`` when not given.
        residual: When true, ``forward`` returns ``gelu(x + double_conv(x))``,
            which requires ``in_channels == out_channels``.
    """

    def __init__ (self, in_channels, out_channels, mid_channels = None, residual = False):
        super().__init__()
        self.residual = residual
        hidden = mid_channels or out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(16, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(16, out_channels),
            nn.GELU(),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages, adding the skip connection when ``residual`` is set."""
        out = self.double_conv(x)
        if not self.residual:
            return out
        return F.gelu(x + out)


# %%
class Down(nn.Module):
    """Encoder stage: halve the spatial size, convolve, then add a time embedding.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned feature map.
        emb_dim: Size of the time-embedding vector passed to ``forward``.
    """

    def __init__ (self, in_channels, out_channels, emb_dim = 256):
        super(). __init__()
        self.down_conv = nn.Sequential(
            nn.MaxPool2d(kernel_size = 2),                        # halves height and width
            DoubleConv(in_channels, in_channels, residual=True),  # residual block keeps the channel count
            DoubleConv(in_channels, out_channels),
        )

        # Projects the time embedding to one value per output channel.
        self.down_emb = nn.Sequential(nn.SiLU(),
                                      nn.Linear(emb_dim,out_channels))

    def forward(self, x, t):
        """Return ``down_conv(x)`` plus the projected time embedding ``t``.

        ``t`` is expected to be (batch, emb_dim); its projection is added to
        every spatial position of the corresponding channel.
        """
        x = self.down_conv(x)
        # (B, C) -> (B, C, 1, 1); broadcasting adds it to every pixel without
        # materialising a full (B, C, H, W) copy of the embedding.
        emb = self.down_emb(t)[:, :, None, None]

        return x + emb

# %%
class Up(nn.Module):
    """Decoder stage: upsample, concatenate the skip connection, convolve, add a time embedding.

    Args:
        in_channels: Channels after concatenating the upsampled input with the skip tensor.
        out_channels: Channels of the returned feature map.
        emb_dim: Size of the time-embedding vector passed to ``forward``.
    """

    def __init__(self, in_channels,out_channels,emb_dim = 256):
         super(). __init__()
         # Doubles height and width by bilinear interpolation.
         self.unsample = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
         # The second DoubleConv narrows through in_channels // 2 on its way to out_channels.
         self.up_conv = nn.Sequential( DoubleConv(in_channels, in_channels, residual=True),
                                  DoubleConv(in_channels, out_channels, in_channels // 2))

         # Projects the time embedding to one value per output channel.
         self.up_emb = nn.Sequential(nn.SiLU(),
                                      nn.Linear(emb_dim,out_channels))

    def forward(self, x, skip_x, t):
        """Return the convolved ``cat(upsample(x), skip_x)`` plus the projected time embedding.

        ``skip_x`` must have the same height and width as the upsampled ``x``;
        the two are joined along the channel dimension (dim=1).
        """
        x = self.unsample(x)
        x = torch.cat([x,skip_x], dim = 1)
        x = self.up_conv(x)
        # (B, C) -> (B, C, 1, 1); broadcasting adds it to every pixel without
        # materialising a full (B, C, H, W) copy of the embedding.
        emb = self.up_emb(t)[:, :, None, None]
        return  x + emb

# %%
class Attention_block(nn.Module):
    """Additive attention gate that rescales ``x`` by a per-pixel weight.

    ``g`` and ``x`` are each projected to ``F_int`` channels with a 1x1
    convolution followed by batch norm, summed and passed through ReLU; a
    further 1x1 convolution, batch norm and sigmoid reduce the result to a
    single-channel map that multiplies ``x``.

    Args:
        F_g: Channels of the gating input ``g``.
        F_l: Channels of the input ``x`` being gated.
        F_int: Channels of the intermediate projection.
    """

    def __init__(self,F_g,F_l,F_int):
        super().__init__()

        def conv1x1_bn(c_in, c_out):
            # 1x1 convolution followed by batch norm, as a list of layers.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(c_out),
            ]

        self.W_g = nn.Sequential(*conv1x1_bn(F_g, F_int))
        self.W_x = nn.Sequential(*conv1x1_bn(F_l, F_int))
        self.psi = nn.Sequential(*conv1x1_bn(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self,g,x):
        """Return ``x`` multiplied by the attention map computed from ``g`` and ``x``."""
        gate = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(gate)

# %%
class UNet1(nn.Module):
    """Time-conditioned U-Net with attention gates, used as the noise predictor.

    Channel flow: ``c_in -> 64 -> 128 -> 256 -> 256`` on the way down, a
    ``256 -> 512 -> 512 -> 256`` bottleneck, then ``128 -> 64 -> 64 -> c_out``
    on the way up. Each ``Up`` stage consumes the matching encoder output as a
    skip connection. The three ``Down`` stages each halve the resolution, so
    the input height and width should be divisible by 8.

    Args:
        c_in: Channels of the input image.
        c_out: Channels of the predicted output.
        time_dim: Size of the sinusoidal time embedding; must match the
            ``emb_dim`` of the ``Down``/``Up`` stages (256 by default).
        device: Stored on the instance for reference. The module itself runs
            on whatever device its parameters and inputs live on.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        # Encoder
        self.input_conv = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 256)
        self.sa3 = Attention_block(256,256,256)

        # Bottleneck
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # Decoder; each Up takes (upsampled + skip) channels as its input width.
        self.up1 = Up(512, 128)
        self.sa4 = Attention_block(128,128,128)
        self.up2 = Up(256, 64)
        self.sa5 = Attention_block(64,64,64)
        self.up3 = Up(128, 64)
        self.sa6 = Attention_block(64,64,64)
        self.outc = nn.Conv2d(64, c_out, kernel_size = 1)


    def pos_encoding(self,t):
        """Sinusoidal embedding of the time steps.

        Args:
            t: Float tensor of shape (batch, 1) holding the step indices.

        Returns:
            Tensor of shape (batch, time_dim): ``sin(t * f)`` for each of the
            ``time_dim // 2`` frequencies ``f``, followed by ``cos(t * f)``.
        """
        channels = self.time_dim
        # Build the frequencies on the same device as `t` so the model works
        # on CPU as well as on GPU (a hard-coded "cuda" fails without a GPU).
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        # (batch, 1) * (channels // 2,) broadcasts to (batch, channels // 2).
        angles = t * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at diffusion steps ``t``.

        Args:
            x: Image batch with ``c_in`` channels.
            t: 1-D tensor of step indices, one per image.

        Returns:
            Tensor with ``c_out`` channels and the same height and width as ``x``.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t)

        # Encoder
        x1 = self.input_conv(x)
        x2 = self.down1(x1, t)
        x3 = self.down2(x2, t)
        x4 = self.down3(x3, t)
        # Each attention gate is applied with the feature map as both inputs
        # and its output is added back to that feature map.
        x4_sa = self.sa3(x4,x4)
        x4 = x4_sa+x4

        # Bottleneck
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        # Decoder with skip connections from the encoder
        x = self.up1(x4, x3, t)
        x = self.sa4(x,x)+x

        x = self.up2(x, x2, t)
        x = self.sa5(x,x)+x
        x = self.up3(x, x1, t)
        x = self.sa6(x,x)+x
        output = self.outc(x)
        return output


In [None]:
# Start training with the module-level settings (epochs, run_name, lr). After
# each epoch this writes a sample grid under results/<run_name>/ and overwrites
# the checkpoint models/<run_name>/ckpt.pt.
train()