<a href="https://colab.research.google.com/github/Rohit1217/DDPM/blob/main/Diffusion1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import random
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset,DataLoader

from datasets_gen import get_mnist_,get_cifar10
from noise_scheduler import noise_scheduler,get_at
import unet

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
Device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def trainload(dataset='mnist'):
    """Return a training DataLoader for `dataset`.

    Args:
        dataset: 'mnist' or 'cifar10'.

    Returns:
        For 'mnist': a shuffled DataLoader (batch_size=64) over the images returned
        by get_mnist_(normalize=True), padded by 1 pixel top/left and 2 pixels
        bottom/right with value -1, and given a channel axis. Each item is an
        (image, image) pair, so callers unpacking `images, _` get the images.
        For 'cifar10': whatever get_cifar10() returns.

    Raises:
        ValueError: if `dataset` is not one of the supported names (previously this
        fell through to `return trainloader` and raised UnboundLocalError).
    """
    if dataset == 'mnist':
        im_data = get_mnist_(normalize=True)
        # Pad the last two dims (left/right, top/bottom) with -1, presumably the
        # background value after normalization; the samplers use 31x31 images.
        im_data = F.pad(im_data, (1, 2, 1, 2), mode='constant', value=-1)
        im_data = im_data.unsqueeze(1)  # add the channel axis
        im_tensor = TensorDataset(im_data, im_data)
        return DataLoader(im_tensor, batch_size=64, shuffle=True)
    if dataset == 'cifar10':
        return get_cifar10()
    raise ValueError(f"unknown dataset {dataset!r}; expected 'mnist' or 'cifar10'")

In [None]:
# Build the CIFAR-10 training loader (the first run downloads the archive into ./data).
trainloader = trainload(dataset='cifar10')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 48268414.36it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [None]:
# Noise-prediction U-Net; the positional args (3, 16) are forwarded unchanged
# (presumably input channels and base width — TODO confirm in unet.py).
model = unet.Unet(3, 16).to(Device)
# NOTE(review): this rebinds `optim`, shadowing the `torch.optim` module imported
# as `optim`; the training cell calls `optim.step()`, so the name is kept as-is.
optim = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = F.mse_loss  # MSE between the predicted and the true noise

In [None]:
# Train the noise-prediction network: for a random timestep t the model sees
# noise_scheduler(images, noise, t) and is regressed (MSE) onto `noise`.
num_epochs = 100
# The sampling loop calls the model with t = 1000 .. 1, so train on the same inclusive range.
NUM_TIMESTEPS = 1000
loss_list = []  # mean training loss of each epoch

# The samplers switch the model to eval mode; make sure re-runs train in training mode.
model.train()

for epoch in range(num_epochs):
    count = 0
    total_loss = 0.0
    for images, _ in trainloader:
        images = images.to(Device)
        # One timestep shared by the whole batch, drawn uniformly from 1..NUM_TIMESTEPS.
        t = random.randint(1, NUM_TIMESTEPS)
        # randn_like matches the batch's full shape (also for non-square images),
        # dtype and device, avoiding a host-to-device copy.
        noise = torch.randn_like(images)

        noised_image = noise_scheduler(images, noise, t)
        pred_noise = model(noised_image, t)
        loss = criterion(pred_noise, noise)

        # Clear gradients first so an interrupted previous run cannot leak stale grads.
        optim.zero_grad()
        loss.backward()  # the graph is rebuilt every step; retain_graph is not needed
        optim.step()

        count += 1
        total_loss += loss.item()

    avg_loss = total_loss / max(count, 1)  # guard against an empty loader
    loss_list.append(avg_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], Avg_Loss: {avg_loss:.4f}, Sum_loss:{total_loss:.4f}")

Epoch [1/100], Avg_Loss: 0.2689, Sum_loss:56.7331
Epoch [2/100], Avg_Loss: 0.1529, Sum_loss:32.2695
Epoch [3/100], Avg_Loss: 0.1405, Sum_loss:29.6448
Epoch [4/100], Avg_Loss: 0.1168, Sum_loss:24.6463
Epoch [5/100], Avg_Loss: 0.1210, Sum_loss:25.5407


KeyboardInterrupt: 

In [None]:
def DDPM_sampler(model,images,t,n=0.5):
    """Perform one ancestral DDPM reverse step, x_t -> x_{t-1}.

    Args:
        model: noise-prediction network, called as model(images, t).
        images: the current noisy batch x_t (fresh noise is drawn on Device).
        t: integer timestep of `images`; get_at(t) and get_at(t - 1) must be valid.
        n: unused; kept for call compatibility.

    Returns:
        x_{t-1} = (x_t - c * eps_hat) / sqrt(alpha_t) + sqrt(1 - alpha_t) * z,
        where alpha_t = get_at(t) / get_at(t - 1). The noise term is added only when t > 1.
    """
    with torch.no_grad():
        # The original `model.eval` (without a call) was a no-op.
        model.eval()
        pred_noise = model(images, t)
        at = get_at(t)        # presumably the cumulative alpha-bar at t — confirm in noise_scheduler
        at1 = get_at(t - 1)
        base_at = at / at1    # per-step alpha_t under that assumption
        c = (1 - base_at) / math.sqrt(1 - at)
        mean = math.sqrt(1 / base_at) * (images - c * pred_noise)
        if t > 1:
            # Variance choice sigma_t^2 = 1 - alpha_t.
            return mean + torch.randn(images.shape, device=Device) * math.sqrt(1 - base_at)
        return mean

In [None]:
def DDIM_sampler(model,images,t,n=1):
    """Perform one DDIM reverse step, x_t -> x_{t-1}.

    Args:
        model: noise-prediction network, called as model(images, t).
        images: the current noisy batch x_t.
        t: integer timestep of `images`; get_at(t) and get_at(t - 1) must be valid.
        n: the DDIM eta. n=0 gives the deterministic DDIM update, and n=1 gives the
            DDPM posterior variance. The variance scales with n**2.

    Returns:
        The batch at step t-1. At t <= 1 no noise/direction term is added.
    """
    with torch.no_grad():
        model.eval()
        shape = images.shape
        pred_noise = model(images, t)
        at = get_at(t)        # presumably cumulative alpha-bar — confirm in noise_scheduler
        at1 = get_at(t - 1)
        # sqrt(at1) * x0_hat with x0_hat = (x_t - sqrt(1 - at) * eps_hat) / sqrt(at).
        pred_x0 = math.sqrt(at1 / at) * (images - math.sqrt(1 - at) * pred_noise)
        if t > 1:
            # `sigma` is the variance sigma_t^2 = eta^2 * (1-at1)/(1-at) * (1 - at/at1).
            # It was previously scaled by n instead of n**2, which was wrong for 0 < n != 1.
            sigma = (n ** 2) * ((1 - at1) / (1 - at)) * (1 - (at / at1))
            dir_xt = math.sqrt(1 - at1 - sigma) * pred_noise
            random_noise = torch.randn(shape, device=Device) * math.sqrt(sigma)
            images = pred_x0 + dir_xt + random_noise
        else:
            images = pred_x0
        return images

In [None]:
# Generate 64 samples with DDIM (n=1), walking t from 1000 down to 1, and show
# an 8x8 grid every 100 steps and at the final step.
# NOTE(review): 31x31 matches the padded-MNIST pipeline; confirm the image size get_cifar10 yields.
x = torch.randn(64, 3, 31, 31).to(Device)
timesteps = 1000
for step in range(timesteps):
    t = timesteps - step
    x = DDIM_sampler(model, x, t, 1)
    # (The former unconditional `break` here stopped sampling after one step.)

    if t % 100 == 0 or t == 1:
        # Map presumably [-1, 1] outputs to [0, 1] and clamp so imshow does not clip.
        y = (x[:64] * 0.5 + 0.5).clamp(0, 1).permute(0, 2, 3, 1)
        # (64, H, W, C) -> (8, 8, H, W, C) grid; move to CPU before converting to numpy.
        y = y.reshape(8, 8, *y.shape[1:]).cpu().numpy()
        fig, axes = plt.subplots(8, 8, figsize=(8, 8))
        for row in range(8):
            for col in range(8):
                axes[row, col].imshow(y[row, col])
                axes[row, col].axis('off')
        fig.suptitle(f"DDIM samples at t={t}")
        # Overwritten at each checkpoint; the last save holds the final (t=1) samples.
        fig.savefig('samples.jpeg')
        plt.show()
        plt.close(fig)

1000 tensor([[[ 9.2544e+12,  3.6232e+12,  3.3041e+12,  ...,  2.2205e+12,
          -5.1878e+12,  7.4701e+12],
         [-2.3733e+12,  6.3732e+12,  2.9314e+12,  ..., -2.3822e+12,
           2.1448e+12,  3.1022e+11],
         [ 4.0354e+12,  1.5897e+12, -2.0048e+12,  ..., -1.2672e+12,
          -2.2812e+12, -2.4225e+11],
         ...,
         [ 4.1347e+12,  4.0728e+12,  1.0447e+13,  ...,  6.0252e+11,
          -2.9996e+12, -7.3098e+12],
         [ 4.8388e+12,  7.8723e+12,  1.2225e+12,  ...,  5.3567e+11,
          -2.8956e+12, -8.1949e+12],
         [-9.9464e+12,  5.8371e+12,  6.0576e+12,  ...,  2.0097e+12,
          -1.8868e+12, -2.2143e+13]],

        [[-4.8718e+12,  3.5043e+12, -3.5116e+11,  ..., -2.7097e+12,
           5.9637e+11, -9.9221e+12],
         [-7.3348e+12,  3.7595e+12, -9.4107e+11,  ..., -1.3905e+12,
          -1.8483e+12,  4.6977e+12],
         [ 1.8731e+12,  4.7821e+12,  3.6384e+12,  ..., -4.5283e+11,
          -1.1617e+12, -2.2008e+12],
         ...,
         [ 6.8828e+1

<Figure size 640x480 with 0 Axes>

In [None]:
def DDIM_sampler2(model,at=0.97,ts=650,shape=(64,1,31,31)):
    """Sample a batch from pure noise with a constant per-step alpha schedule.

    The cumulative schedule is alpha-bar_t = at ** t. Each step applies the
    noise-free reverse mean
        x_{t-1} = (x_t - (1 - at) / sqrt(1 - at**t) * eps_hat) / sqrt(at)
    for t = ts-1 down to 1. Despite the name, this is the DDPM posterior mean
    without the added noise term.

    Args:
        model: noise-prediction network, called as model(image, t).
        at: per-step alpha (constant).
        ts: one more than the first timestep visited.
        shape: shape of the sampled batch. The default reproduces the original
            hard-coded (64, 1, 31, 31).

    Returns:
        The final batch tensor on Device.
    """
    with torch.no_grad():
        model.eval()  # consistent with DDIM_sampler: inference mode for the network
        image = torch.randn(shape).to(Device)
        for step in range(1, ts):
            t = ts - step  # counts down from ts-1 to 1
            pred_noise = model(image, t)
            vart = math.sqrt(1 - math.pow(at, t))  # sqrt(1 - alpha-bar_t)
            alpha = (1 - at) / vart
            # The previous unused noise draw, the duplicate `vart1`, and the
            # identical if/else branches were removed; the result is unchanged.
            image = (image - alpha * pred_noise) / math.sqrt(at)
    return image


In [None]:
# Visualize the forward (noising) process on one training image.
# `im_data` only exists inside trainload(), so take the image from the loader instead.
# As in the training loop, noise_scheduler is applied to the clean image for every t,
# not cumulatively to its own output, and t stays within the 1..1000 training range.
x0 = next(iter(trainloader))[0][:1].cpu()
vis_steps = [1] + list(range(100, 1001, 100))
fig, axes = plt.subplots(1, len(vis_steps), figsize=(2 * len(vis_steps), 2))
for ax, t in zip(axes, vis_steps):
    x_t = noise_scheduler(x0, torch.randn_like(x0), t).to('cpu')
    # (C, H, W) -> (H, W, C), mapped from presumably [-1, 1] to [0, 1] for display.
    img = (x_t[0] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()
    if img.shape[-1] == 1:
        ax.imshow(img[..., 0], cmap='gray')
    else:
        ax.imshow(img)
    ax.set_title(f"t={t}")
    ax.axis('off')
plt.show()

In [None]:
# Draw 64 samples with the constant-alpha sampler and lay them out on an 8x8 grid.
im = DDIM_sampler2(model, ts=450)
# (64, 1, 31, 31) -> (8, 8, 31, 31): one grayscale image per grid cell, on the CPU.
x = im[:64].view(8, 8, 31, 31).to('cpu')
image_np = x.numpy()

fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for (row, col), ax in np.ndenumerate(axes):
    ax.imshow(image_np[row, col], cmap='gray')
    ax.axis('off')  # hide ticks and labels
plt.savefig('samples.jpeg')
plt.show()