In [None]:
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torchvision import transforms

from data_handler import DataHandler
from ddpm import DDPM
from sample import ImgDataTransformer, sample_images
from train import train
from unet import UNet

In [None]:
# Hyperparameters shared by the data, training, plotting and sampling cells,
# gathered in one place so later cells read `hyperparameters` instead of
# repeating literals.
@dataclass
class Hyperparameters:
    """Notebook configuration; ``Hyperparameters()`` gives the defaults used here."""

    width: int = 32  # image width used for the sampled image size
    height: int = 32  # image height used for the sampled image size
    eval_interval: int = 10  # passed to train(); the loss plot spaces points by this many epochs
    checkpoint_interval: int = 10  # passed to train() as checkpoint_interval
    batch_size: int = 64  # passed to DataHandler
    epochs: int = 300  # passed to train() as the number of epochs
    n_classes: int = 100  # passed to UNet and sample_images
    seed: int = 42  # torch RNG seed so Restart & Run All reproduces the same run


hyperparameters = Hyperparameters()

# Seed torch's RNG on all devices before any model weights are initialised or
# torch-based randomness is drawn, so re-running the notebook is reproducible.
# Trailing `;` suppresses the Generator repr in the cell output.
torch.manual_seed(hyperparameters.seed);

In [None]:
# Input preprocessing: convert each image to a float tensor, then standardise
# every channel with the widely used ImageNet mean/std values.
# NOTE(review): DDPMs are often trained on data scaled to [-1, 1]; confirm these
# statistics are the intended choice for CIFAR-100.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Train / validation splits read from the CIFAR-100 directories, batched with
# the configured batch size and passed through `transform`.
data_handler = DataHandler(
    train_path="../cifar100/train",
    val_path="../cifar100/test",
    transform=transform,
    batch_size=hyperparameters.batch_size,
)

In [None]:
# Class-conditional U-Net backbone. timesteps=1000 is presumably the length of
# the diffusion schedule (TODO confirm in unet.py); n_classes comes from the
# hyperparameters.
unet_model = UNet(
    timesteps=1000,
    n_classes=hyperparameters.n_classes,
)

# DDPM wraps the U-Net; this `model` is what the training and sampling cells use.
model = DDPM(unet_model)

In [None]:
# Output locations for the final weights and the periodic checkpoints.
WEIGHTS_PATH = "../params/weights/weights.pth"
CHECKPOINT_PATH = "../params/checkpoints/checkpoint.pth"

# Create the output directories before training. It is not visible from here
# whether train() creates them itself; if it does not, a missing directory would
# only surface at the first save, after the training time has already been spent.
for output_path in (WEIGHTS_PATH, CHECKPOINT_PATH):
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

# Train with Adam (lr=3e-4); train() returns the recorded train / validation losses.
train_losses, val_losses = train(
    model=model,
    optimizer=torch.optim.Adam(params=model.parameters(), lr=3e-4),
    data_handler=data_handler,
    epochs=hyperparameters.epochs,
    eval_interval=hyperparameters.eval_interval,
    weights_save_path=WEIGHTS_PATH,
    checkpoint_path=CHECKPOINT_PATH,
    checkpoint_interval=hyperparameters.checkpoint_interval,
    # False presumably means "start from scratch" rather than resuming from
    # CHECKPOINT_PATH — confirm in train.py.
    from_checkpoint=False,
)

In [None]:
sns.set_theme(style="darkgrid", font_scale=1.4)


def _eval_epochs(losses):
    """Return the epoch of each recorded loss: 0, eval_interval, 2*eval_interval, ...

    Losses are assumed to be recorded every `eval_interval` epochs starting at
    epoch 0 (TODO confirm against train()). Deriving the positions from
    len(losses), rather than from `epochs`, keeps x and y the same length, so
    the plot cannot fail with a shape mismatch if train() records one more or
    one fewer evaluation than expected.
    """
    return [i * hyperparameters.eval_interval for i in range(len(losses))]


fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(_eval_epochs(train_losses), train_losses, label="Train", linewidth=2)
ax.plot(_eval_epochs(val_losses), val_losses, label="Valid", linewidth=2)
ax.set(title="Loss", xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()

In [None]:
# ImgDataTransformer gets the same mean/std as the training Normalize step
# (presumably so it can undo that normalisation for display — confirm in
# sample.py). The values are read from `transform` instead of being repeated
# here, so the two cells can never drift out of sync.
_normalize_steps = [
    step for step in transform.transforms if isinstance(step, transforms.Normalize)
]
assert len(_normalize_steps) == 1, "expected exactly one Normalize step in `transform`"
normalize = _normalize_steps[0]

img_data_transformer = ImgDataTransformer(
    mean=normalize.mean,
    std=normalize.std,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [None]:
# Class index to condition generation on, and how many images to draw.
SAMPLE_CLASS = 7
N_SAMPLES = 6

# Image size in PyTorch's (channels, height, width) layout, the layout ToTensor
# produces for the training images. Height must come before width; swapping them
# only goes unnoticed while the images are square.
# NOTE(review): assumes sample_images expects (C, H, W) — confirm in sample.py.
sample_images(
    model,
    n_samples=N_SAMPLES,
    img_size=(3, hyperparameters.height, hyperparameters.width),
    img_data_transformer=img_data_transformer,
    n_classes=hyperparameters.n_classes,
    cls=SAMPLE_CLASS,
)