In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm

In [2]:
def get_data(path):
    image_extensions = ['.jpg']
    image_names = []
    for filename in os.listdir(path):
        if any(filename.lower().endswith(ext) for ext in image_extensions):
            image_names.append(filename)
    return image_names

In [3]:
path = '/Users/ayanfe/Documents/Datasets/animefaces256cleaner'
path1 = '/Users/ayanfe/Documents/Code/Diffusion Model/model/model.pth'
image_names = get_data(path)
print(len(image_names))

92219


In [4]:
timesteps = 200

# create a fixed beta schedule
beta = np.linspace(0.0001, 0.02, timesteps)

# this will be used as discussed in the reparameterization trick
alpha = 1 - beta
alpha_bar = np.cumprod(alpha, 0)
alpha_bar = np.concatenate((np.array([1.]), alpha_bar[:-1]), axis=0)
sqrt_alpha_bar = np.sqrt(alpha_bar)
one_minus_sqrt_alpha_bar = np.sqrt(1-alpha_bar)

# this function will help us set the RNG key for Numpy
def set_key(key):
    np.random.seed(key)

# this function will add noise to the input as per the given timestamp
def forward_noise(key, x_0, t):
    set_key(key)
    noise = np.random.normal(size=x_0.shape)
    reshaped_sqrt_alpha_bar_t = np.reshape(np.take(sqrt_alpha_bar, t), (-1, 1, 1, 1))
    reshaped_one_minus_sqrt_alpha_bar_t = np.reshape(np.take(one_minus_sqrt_alpha_bar, t), (-1, 1, 1, 1))
    noisy_image = reshaped_sqrt_alpha_bar_t  * x_0/255 + reshaped_one_minus_sqrt_alpha_bar_t  * noise
    return noisy_image, noise

# this function will be used to create sample timestamps between 0 & T
def generate_timestamp(key, num):
    set_key(key)
    return torch.randint(0, timesteps,(num,), dtype=torch.int32)

def reshape_img(img):
    data = cv2.resize(img,(64,64))
    data = np.transpose(img,(2,0,1))
    return data

In [5]:
def ddim(x_t, pred_noise, t, sigma_t):
    #alpha_t_bar = np.take(alpha_bar, t)
    #alpha_t_minus_one = np.take(alpha, t-1)
    alpha_t_bar = np.take(alpha_bar, t.astype(int))  # Cast t to integer before using np.take
    alpha_t_minus_one = np.take(alpha, (t - 1).astype(int))  # Similarly, cast t-1 to integer

    #alpha_t_bar = torch.from_numpy(alpha_t_bar)
    #alpha_t_minus_one = torch.from_numpy(alpha_t_minus_one)

    pred = (x_t - ((1 - alpha_t_bar) ** 0.5) * pred_noise)/ (alpha_t_bar ** 0.5)
    pred = (alpha_t_minus_one ** 0.5) * pred

    pred = pred + ((1 - alpha_t_minus_one - (sigma_t ** 2)) ** 0.5) * pred_noise
    eps_t = np.random.normal(size=x_t.shape)
    pred = pred+(sigma_t * eps_t)

    return pred

In [6]:
def inference(model, device, num_samples=5):
    # Define number of inference loops to run
    inference_timesteps = 10
    
    # Create a range of inference steps that the output should be sampled at
    inference_range = range(0, timesteps, timesteps // inference_timesteps)
    
    x = torch.randn(1, 3, 64, 64).to(device)  # Initialize a random input image on GPU
    img_list = []
    img_list.append(np.squeeze(x.cpu().numpy(), 0))  # Append the initial image
    
    # Iterate over inference_timesteps
    for i in range(inference_timesteps):
        t = np.expand_dims(inference_range[i], 0)
        t = torch.from_numpy(t).type(torch.float32).to(device)
    
        pred_noise = model(x, t).detach().to(device)  # Obtain predicted noise
        
        x = ddim(x.cpu().detach().numpy(), pred_noise.cpu().detach().numpy(), t.cpu().detach().numpy(), 0)  # Perform denoising using DDIM
        x = torch.from_numpy(x).type(torch.float32).to(device)  # Transfer denoised image back to GPU
        img_list.append(np.squeeze(x.cpu().detach().numpy(), 0))  # Append the denoised image
    
    # Visualize the final denoised image
    plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
    plt.show()

In [12]:
from torch.utils.data import DataLoader

def training_loop(n_epochs, optimizer, model, loss_fn, device, batch_size=32):
    dataset = MyDataset(image_names, path)  # Define your custom dataset class to load images on-the-fly
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0

        for batch_idx, (imgs, noise, t) in enumerate(dataloader):
            imgs = imgs.to(device)
            noise = noise.to(device)
            t = t.to(device)

            imgs = imgs.view(-1, 3, 256, 256)
            
            optimizer.zero_grad()
            outputs = model(imgs, t)
            loss = loss_fn(outputs, noise)
            loss.backward()
            optimizer.step()

            loss_train += loss.item()

            print("Epoch: {}, Batch: {}/{}, Loss: {:.6f}".format(epoch, batch_idx+1, len(dataloader), loss.item()))

        print('Epoch {}, Training loss: {:.6f}'.format(epoch, loss_train / len(dataloader)))

# Define your custom dataset class to load images on-the-fly
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, image_names, path):
        self.image_names = image_names
        self.path = path

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img = plt.imread(os.path.join(self.path, self.image_names[idx]))
        img = reshape_img(img)
        img = np.expand_dims(img, 0)

        t = generate_timestamp(0, timesteps)
        t = torch.reshape(t, (-1, 1)).type(torch.float32)

        imgs, noise = forward_noise(0, img, t)

        imgs = torch.from_numpy(imgs).type(torch.float32)
        noise = torch.from_numpy(noise).type(torch.float32)
        t = t.to(device)

        return imgs, noise, t

In [13]:
from unet import Unet


model = Unet()

optimizer = optim.Adam(model.parameters())  #  <3>
loss_fn = nn.MSELoss()  #  <4>
#device = torch.device("mps")  # Assuming Metal GPU is your target
device = torch.device("mps")
model.to(device)
'''
for name, param in model.named_parameters():
    print(name, param.dtype)
'''
training_loop(  # <5>
    n_epochs = 1000,
    optimizer = optimizer,
    model = model,
    loss_fn = loss_fn,
    device = device
)


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [32, 200, 3, 256, 256]

In [None]:
inference(model, device)