In [1]:
import torch
from torch import nn
import torchvision
from torchvision import transforms

from utils import *
from learner import *
from unet import *
from attention import *
from time_embedding import *
from vae import *
from ddpm import *


In [2]:
BATCH_SIZE = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# transform = transforms.Compose([transforms.ToTensor()])
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))])
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)


In [4]:
def data_collate(batch):
    X, y = zip(*batch)
    ddpm = DDPM(0, 1, 5)
    X = torch.stack(X, dim=0)
    y = torch.tensor(y)
    (X, t), noise = ddpm.schedule(X)
    return (X, t, y), noise


In [5]:
def dataloader(train_dataset, test_dataset, batch_size, collate_fn=None):
    return (
        DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn), 
        DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    )
train_dataloader, test_dataloader = dataloader(train_dataset, test_dataset, BATCH_SIZE, data_collate)

In [6]:
unet = UNET(10, 1, 1, (32,64,128,256))

In [7]:
lr = 1e-3
epochs = 25
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(unet.parameters(), lr=lr)
schedular = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, total_steps=epochs*len(train_dataloader))


In [8]:
from tqdm import trange

In [9]:
def train(model, epochs=1):
    model.to(device)
    for epoch in trange(epochs):
        model.train()
        loss_calc = 0
        for batch, (image, noise) in enumerate(train_dataloader):
            X, t, c = image
            X, t, c = X.to(device), t.to(device), c.to(device)
            image = X, t, c
            noise = noise.to(device)
            pred = model(image)
            loss = loss_fn(pred, noise)
            loss_calc += loss.item()
            # print(f"Batch: {loss_calc}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch: {epoch} = Loss: {loss_calc/len(train_dataloader)}")

In [11]:
train(unet)

100%|██████████| 1/1 [21:54<00:00, 1314.39s/it]

Epoch: 0 = Loss: 0.22196614021658898





In [13]:
torch.save(unet.state_dict(), "./Saved Models/unet_model.pth")