In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

import plotly.graph_objects as go
import numpy as np

from PIL import Image
from matplotlib import pyplot as plt

from tqdm.auto import tqdm
from loguru import logger

# Use the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#####################
## UNet Components ##
#####################

from ddpm_components import DoubleConv, Down, SelfAttention, Up

###############
## Utilities ##
###############

from utils import get_data, create_diffusion_animation

In [2]:
class UNet(nn.Module):
    """Time-conditioned U-Net built from the project's DDPM components.

    The network is an initial ``DoubleConv``, three ``Down`` stages, a
    three-layer ``DoubleConv`` bottleneck, three ``Up`` stages with skip
    connections, and a final 1x1 convolution. Every ``Down``/``Up`` stage
    is followed by a ``SelfAttention`` layer, and receives the timestep
    embedding produced by :meth:`pos_encoding`.

    The spatial sizes quoted in the comments below are the original
    author's notes and assume a 64x64 input, with each ``Down`` halving
    and each ``Up`` doubling the resolution -- TODO confirm against
    ``ddpm_components``.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels produced by the output convolution.
        time_dim: Width of the timestep embedding passed to Down/Up stages.
        device: Stored as ``self.device`` for backward compatibility. The
            constructor does not move any parameters; call ``.to(...)``
            for that.
    """

    def __init__(self, in_channels=3, out_channels=3, time_dim=256, device=device):
        super().__init__()
        self.device = device
        self.time_dim = time_dim

        # NOTE: layers are created in the same order as before so that a
        # seeded run initialises every parameter identically.

        # Initial conv (no size change)
        self.initial_conv = DoubleConv(in_channels, 32)

        # Encoder (Down) - each Down has MaxPool2d that halves spatial size
        self.down1 = Down(32, 64)              # 64 -> 32
        self.sa1 = SelfAttention(64, 32)       # spatial size: 32x32
        self.down2 = Down(64, 128)             # 32 -> 16
        self.sa2 = SelfAttention(128, 16)      # spatial size: 16x16
        self.down3 = Down(128, 128)            # 16 -> 8
        self.sa3 = SelfAttention(128, 8)       # spatial size: 8x8

        # Bottleneck (no size change)
        self.bot1 = DoubleConv(128, 256)       # spatial size: 8x8
        self.bot2 = DoubleConv(256, 256)       # spatial size: 8x8
        self.bot3 = DoubleConv(256, 128)       # spatial size: 8x8

        # Decoder (Up) - each Up has Upsample that doubles spatial size.
        # The first argument is the channel count after concatenating the
        # skip connection with the incoming feature map.
        self.up1 = Up(256, 64)                 # 8 -> 16 (concat with x3: 128ch)
        self.sa4 = SelfAttention(64, 16)       # spatial size: 16x16
        self.up2 = Up(128, 32)                 # 16 -> 32 (concat with x2: 64ch)
        self.sa5 = SelfAttention(32, 32)       # spatial size: 32x32
        self.up3 = Up(64, 32)                  # 32 -> 64 (concat with x1: 32ch)
        self.sa6 = SelfAttention(32, 64)       # spatial size: 64x64

        # Output conv: 1x1 projection to the requested channel count
        self.out_conv = nn.Conv2d(32, out_channels, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Return the sinusoidal embedding of the timesteps ``t``.

        Args:
            t: Float tensor holding one timestep per row; ``forward``
                passes it with a trailing singleton dimension
                (assumed shape ``(batch, 1)`` -- TODO confirm with callers).
            channels: Width of the embedding. Expected to be even; the
                sine and cosine halves each take ``channels // 2`` columns.

        Returns:
            Tensor whose last dimension is the sine half concatenated
            with the cosine half.
        """
        # Build the frequency table on the device of ``t`` rather than on
        # ``self.device``: the stored device is fixed at construction time
        # and goes stale once the module is moved with ``.to(...)``, which
        # made the multiplication below fail with a device mismatch.
        inv_freq = 1. / (
            10000
            ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        """Run the U-Net on ``x`` conditioned on timesteps ``t``.

        Args:
            x: Input image batch (the size comments assume 64x64 inputs).
            t: Timestep tensor, one entry per sample (assumed 1-D --
                TODO confirm); it is cast to float and embedded with
                :meth:`pos_encoding`.

        Returns:
            Output of the final 1x1 convolution (``out_channels`` channels).
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        # Encoder (Down path); x1..x3 are kept as skip connections
        x1 = self.initial_conv(x)      # 64x64, 32 ch
        x2 = self.down1(x1, t)         # 32x32, 64 ch
        x2 = self.sa1(x2)              # 32x32, 64 ch
        x3 = self.down2(x2, t)         # 16x16, 128 ch
        x3 = self.sa2(x3)              # 16x16, 128 ch
        x4 = self.down3(x3, t)         # 8x8, 128 ch
        x4 = self.sa3(x4)              # 8x8, 128 ch

        # Bottleneck
        x4 = self.bot1(x4)             # 8x8, 256 ch
        x4 = self.bot2(x4)             # 8x8, 256 ch
        x4 = self.bot3(x4)             # 8x8, 128 ch

        # Decoder (Up path); each Up also receives the matching skip tensor
        x = self.up1(x4, x3, t)        # 16x16, 64 ch
        x = self.sa4(x)                # 16x16, 64 ch
        x = self.up2(x, x2, t)         # 32x32, 32 ch
        x = self.sa5(x)                # 32x32, 32 ch
        x = self.up3(x, x1, t)         # 64x64, 32 ch
        x = self.sa6(x)                # 64x64, 32 ch
        out = self.out_conv(x)         # 64x64, 3 ch
        return out

model = UNet()

# Add up the element count of every parameter tensor registered on the model.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"Total parameters: {total_params}")

Total parameters: 5897155


In [3]:
###################
## Flower Images ##
###################

import gc
from datasets import load_dataset

# Collect unreachable Python objects first, then ask PyTorch to release
# its unused cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

# Load the "test" split and keep only its "image" column.
flowers = load_dataset(path="nkirschi/oxford-flowers", split="test")["image"]

In [4]:
####################
## Training loop  ##
####################

from ddpm_components import train

IMG_SIZE = 64
BATCH_SIZE = 6

# Seed before building the model so its initial weights are reproducible.
torch.manual_seed(42)
model = UNet().to(device)

#-------
## Train
#-------

# NOTE(review): the recorded run of this cell stops during sampling with
# "NameError: name 'create_diffusion_animation' is not defined". The name
# is imported into the notebook namespace by the first cell, so the failure
# is presumably raised inside `ddpm_components.train` -- confirm that the
# module imports it from `utils`.
train(
    model=model,
    data=flowers,
    epochs=4000,
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    report_interval=1000,
    visualize=True,
)

Training:   0%|          | 0/4000 [00:00<?, ?it/s]

++++++++++++++++++++++++++++++++++++++++++++++++++
Epoch: 1 | Loss: 0.34301681417752716 | Current LR: 0.0002


Sampling: 0it [00:00, ?it/s]

NameError: name 'create_diffusion_animation' is not defined