In [None]:
import os

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from src.unet_model import Unet
from src.ddim import DDIM
from src.ddpm import DDPM
from src.pokemon_dataset import PokemonDataset

In [None]:
def _pick_device():
    """Return ``(device, message)`` for the first available backend.

    Backends are tried in the order MPS, CUDA (device 0), CPU. When CUDA is
    chosen, the CUDA cache is emptied before the device name is looked up.
    """
    if torch.backends.mps.is_available():
        return torch.device('mps'), "Device set to : mps"

    if torch.cuda.is_available():
        cuda_device = torch.device('cuda:0')
        torch.cuda.empty_cache()
        return cuda_device, "Device set to : " + str(torch.cuda.get_device_name(cuda_device))

    return torch.device('cpu'), "Device set to : cpu"


# Resolve the compute device once and report which one was selected.
device, _device_message = _pick_device()
print(_device_message)

In [None]:
IMG_SIZE = 64


def _rescale_to_signed_range(t):
    """Compute ``t * 2 - 1``, i.e. move values from [0, 1] into [-1, 1]."""
    return (t * 2) - 1


# Preprocessing steps, applied to every image in the listed order.
_transform_steps = [
    transforms.ToTensor(),  # [0, 255] -> [0.0, 1.0]
    transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True),  # square IMG_SIZE x IMG_SIZE
    transforms.RandomHorizontalFlip(),  # random left/right mirroring
    transforms.Lambda(_rescale_to_signed_range),  # [0.0, 1.0] -> [-1, 1]
]
transform = transforms.Compose(_transform_steps)

In [None]:
# Inputs for the dataset: an image directory, a stats CSV, and the
# preprocessing pipeline defined above.
_dataset_kwargs = {
    'imgs_path': './data/pokemon_jpg',
    'data_path': './data/pokemon_stats.csv',
    'transform': transform,
}
dataset = PokemonDataset(**_dataset_kwargs)

In [None]:
# Shuffled mini-batches of 32 samples; the last, possibly smaller, batch is kept.
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=32,
    drop_last=False,
)

In [None]:
# --- Diffusion settings (handed to DDPM below) ---
beta_start, beta_end = 1e-4, 0.02
time_step = 500  # also the exclusive upper bound for the timesteps drawn during training

# --- Optimisation ---
lr = 2e-4  # Adam learning rate (1e-3 was noted as an alternative)
epochs = 500

# --- Checkpointing ---
save_epoch = 50  # a checkpoint is written whenever epoch % save_epoch == 0
save_dir = './weights/'
model_name = 'test'  # filename prefix of the saved checkpoints

In [None]:
# Channel widths for the down path; the up path uses the same widths reversed.
_down_widths = [64, 128, 256, 512]

# Network whose output is regressed against the noise in the training loop.
unet = Unet(
    image_channels=3,
    down_channels=_down_widths,
    up_channels=_down_widths[::-1],
    out_dim=64,
    time_emb_dim=32,
    context_dim=18,
)
unet = unet.to(device)

# Diffusion helper (used for add_noise / sample) and the optimizer over the
# network's parameters.
ddpm = DDPM(beta_start, beta_end, time_step, device)
optimizer = optim.Adam(unet.parameters(), lr=lr)

In [None]:
# Training loop.
#
# For every batch: draw one random timestep per image, noise the images with
# `ddpm.add_noise`, let the network predict that noise, and minimise the MSE
# between predicted and true noise.
#
# The epoch range is inclusive of `epochs`, so with epochs=500 and
# save_epoch=50 the final epoch (500) is also checkpointed.
for epoch in range(0, epochs + 1):
    for step, batch in enumerate(dataloader):
        imgs, contexts = batch[0].to(device), batch[1].to(device)
        batch_size = len(imgs)
        # One integer timestep in [0, time_step) per image in the batch.
        t = torch.randint(0, time_step, (batch_size,), device=device).long()

        noise_imgs, noise = ddpm.add_noise(imgs, t)
        pred_noise = unet(noise_imgs, t, contexts)

        optimizer.zero_grad()
        # loss is mean squared error between the predicted and true noise
        loss = F.mse_loss(pred_noise, noise)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch} | step {step+1:03d}/{len(dataloader)} Loss: {loss.item()}")

        # Every 5th epoch, right after the first optimisation step, draw and
        # plot a few samples without tracking gradients.
        # NOTE(review): the positional arguments 3, 64, 18 and 40 are defined
        # by DDPM.sample; 64 and 18 match IMG_SIZE and context_dim above --
        # TODO confirm and replace with named constants.
        # NOTE(review): the network is not switched to eval() here -- confirm
        # whether that is intended.
        if epoch % 5 == 0 and step == 0:
            with torch.no_grad():
                ddpm.sample(unet, 3, 64, 18, time_step, 40, plot=True)

    # save model periodically
    if epoch % save_epoch == 0:
        # makedirs(exist_ok=True) also creates missing parents and does not
        # fail when the directory already exists.
        os.makedirs(save_dir, exist_ok=True)
        checkpoint_path = os.path.join(save_dir, f"{model_name}_{epoch}.pth")
        torch.save(unet.state_dict(), checkpoint_path)
        print('saved model at ' + checkpoint_path)