In [1]:
import torch
from torch import nn
import torchvision
from torchvision import transforms

from tqdm import trange
from utils import *
from learner import *
from unet import *
from attention import *
from time_embedding import *
from vae import *
from ddpm import *


In [2]:
# transform = transforms.Compose([transforms.ToTensor()])
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))])
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)


In [3]:
BATCH_SIZE = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# def data_collate(batch):
#     X, y = zip(*batch)
#     ddpm = DDPM(0, 1, 5)
#     X = torch.stack(X, dim=0)
#     y = torch.tensor(y)
#     (X, t), noise = ddpm.schedule(X)
#     return (X, t, y), noise


In [5]:
train_dataloader, test_dataloader, classes = dataloader(train_dataset, test_dataset, BATCH_SIZE)

In [6]:
unet = UNET(10, 1, 1, (32,64,128,256))

In [7]:
lr = 1e-3
epochs = 25
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(unet.parameters(), lr=lr)
schedular = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, total_steps=epochs*len(train_dataloader))


In [8]:
import matplotlib.pyplot as plt

def show_image(image, msg=None):
    single_image = image[0].detach().cpu().permute(1, 2, 0).numpy()
    plt.imshow(single_image)
    plt.title(f'{msg}Image')
    plt.axis('off')
    plt.show()


In [9]:
class DDPM:
    def __init__(self, beta_min, beta_max, n_steps, device):
        init_attr(self, locals())
        self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps)
        self.alpha = 1 - self.beta
        self.alpha_bar = self.alpha.cumprod(dim=0).to(self.device)
    
    def schedule(self, x):
        t = torch.randint(0, self.n_steps, (len(x),), dtype=torch.long, device=self.device)
        noise = torch.randn_like(x).to(self.device)
        alpha_bar_t = self.alpha_bar[t].reshape(-1, 1, 1, 1)
        mean = alpha_bar_t.sqrt()*x
        variance = (1 - alpha_bar_t).sqrt()*noise
        x = mean + variance
        return x, t, noise


In [10]:
ddpm = DDPM(0.0001, 0.02, 1000, device)

In [11]:
# for X, y in train_dataloader:
#     X = X.to(device)
#     xt, t, noise = ddpm.schedule(X)
#     print(xt.shape)
#     show_image(xt, "Noised ")
#     print(noise.shape)
#     show_image(noise, "Actual Noise ")
#     print(t.shape)
#     break


In [12]:
def train(model, epochs=1):
    model.to(device)
    for epoch in trange(epochs):
        model.train()
        loss_calc = 0
        for X, y in train_dataloader:
            X, y = X.to(device), y.to(device)
            xt, t, noise = ddpm.schedule(X)
            pred = model(xt, t, y)
            loss = loss_fn(pred, noise)
            loss_calc += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch: {epoch} = Loss: {loss_calc/len(train_dataloader)}")

In [13]:
train(unet)

  0%|          | 0/1 [00:00<?, ?it/s]



Epoch: 0 = Loss: 0.09340511672298113


In [14]:
torch.save(unet.state_dict(), "./Saved Models/unet_model_2023_12_31_8_11.pth")