In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
#from diffusers import DDIMScheduler,DDPMScheduler
import os
from unet import UNet

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")  # Use CUDA
elif torch.backends.mps.is_available():
    device = torch.device("mps")  # Use MPS
else:
    device = torch.device("cpu")  # Fallback to CPU

In [3]:
class catDataset(Dataset):
    def __init__(self, img_dir, second_img_dir):
        self.img_dir = img_dir
        self.img_names = []

        for filename in os.listdir(img_dir):
            file_path = os.path.join(img_dir, filename)
            if os.path.isfile(file_path):
                self.img_names.append(file_path)
        
        for filename in os.listdir(second_img_dir):
            file_path = os.path.join(second_img_dir, filename)
            if os.path.isfile(file_path):
                self.img_names.append(file_path)

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        image = read_image(self.img_names[idx])
        image = image.to(torch.float32)/255.0
        return image


In [4]:
def get_model_size(model):
    return sum(p.numel() for p in model.parameters())


In [5]:
ds=catDataset("./cats","./mycat_64x64")
batch_size = 64
train_dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)

In [6]:
class DDPMScheduler(nn.Module):
    def __init__(self, num_time_steps: int=1000):
        super().__init__()
        self.beta = torch.linspace(1e-4, 0.02, num_time_steps, requires_grad=False)
        alpha = 1 - self.beta
        self.alpha = torch.cumprod(alpha, dim=0).requires_grad_(False)

    def forward(self, t):
        return self.beta[t], self.alpha[t]

In [7]:
total_timesteps = 1000
sch = DDPMScheduler(total_timesteps)
#sch.set_timesteps(50)
# Inference and training are totally seperate
# train always is t -> t-1 
#sch.step(img1,25,)
# Use gaussian noise 
# img, gaussian, timestep
# sch.add_noise(img1,torch.randn_like(img1),torch.tensor(999))
# plt.imshow can work with floats just complains a bit
# .permute(1, 2, 0)

In [None]:
class EMA:
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        # Initialize EMA parameters with the same values as model parameters
        self.ema_params = {name: param.clone().detach() for name, param in model.named_parameters()}

    def update(self):
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                self.ema_params[name] = self.decay * self.ema_params[name] + (1 - self.decay) * param

    def apply(self):
        for name, param in self.model.named_parameters():
            param.data = self.ema_params[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            param.data = self.ema_params[name]


In [None]:
def train_loop(dataloader, model, loss_fn, optimizer, ema):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X) in enumerate(dataloader):
        x_0 = X.to(device)
        curr_bs = X.shape[0]
        t = torch.randint(0,total_timesteps,(curr_bs,))
        noise = torch.randn_like(x_0,requires_grad=False)
        a = sch.alpha[t].view(curr_bs,1,1,1).to(device)
        x = (torch.sqrt(a)*x_0) + (torch.sqrt(1-a)*noise)

        pred = model(x,t[0])
        loss = loss_fn(noise,pred)

        # Backpropagation
        loss.backward()
        optimizer.step()
        ema.update()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * batch_size + len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            #print(torch.mps.current_allocated_memory()/1e9)


In [None]:
losf = nn.MSELoss()
model = UNet(3,3).to(device)
ema = EMA(model,decay=0.999)
print(f"Model has {get_model_size(model)} parameters")
#print(torch.mps.current_allocated_memory()/1e9)
epochs = 5
opt = torch.optim.AdamW(model.parameters(),2e-5)
for i in range(epochs):
    print(f"Epoch {i}")
    train_loop(train_dataloader,model,losf,opt,ema)

Model has 35455555 parameters
Epoch 0
loss: 1.253199  [   64/29854]
loss: 1.000623  [ 6464/29854]
loss: 0.998806  [12864/29854]
loss: 0.998822  [19264/29854]
loss: 1.000440  [25664/29854]
Epoch 1
loss: 0.998288  [   64/29854]
loss: 0.995701  [ 6464/29854]
loss: 0.999130  [12864/29854]
loss: 1.002276  [19264/29854]


KeyboardInterrupt: 

In [13]:
images = []
with torch.no_grad():
    # Initialize z
    z = torch.randn_like(ds[0]).to(device).unsqueeze(0)
    
    # Loop through timesteps in reverse order
    for t in reversed(range(1, total_timesteps)):
        print(z)
        print(model(z,t))
        break
        x = i.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        plt.imshow(x)
        plt.show()

tensor([[[[-0.1692, -0.9671,  1.1251,  ..., -1.0551, -0.7499, -1.6292],
          [ 0.6439,  0.2862,  0.5339,  ...,  1.0847,  0.1038,  0.2279],
          [ 0.6312,  0.3369,  0.4150,  ...,  1.7995,  2.3065, -0.8919],
          ...,
          [-0.4395,  0.4910,  0.1439,  ..., -0.2693, -1.0838,  0.1068],
          [-1.4824,  0.3024,  0.9480,  ...,  0.7726,  0.8500, -0.2259],
          [-0.1618,  1.5281, -0.1929,  ..., -1.0863,  1.5868,  1.3455]],

         [[-0.0499,  0.7305,  0.3193,  ..., -0.0198, -1.0959,  2.1072],
          [ 0.5687, -0.1067,  0.8327,  ..., -0.2662,  0.1121,  0.2600],
          [ 2.1679, -1.5367, -0.5600,  ...,  0.1887, -0.6503, -0.5361],
          ...,
          [ 0.6410, -0.7314,  0.5745,  ...,  0.6057,  0.7800, -0.5803],
          [ 0.7338,  0.5559,  1.6077,  ...,  1.3705, -1.3864, -1.5120],
          [-1.3322, -0.7367, -0.5126,  ...,  0.0476, -1.2847, -1.5150]],

         [[ 1.6886, -0.2865, -0.1043,  ...,  2.0151, -0.2910,  0.1858],
          [ 0.0808,  1.0636, -

In [None]:
torch.mps.current_allocated_memory()/1e9

0.564282368