In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader


from models.unet import Unet

import matplotlib.pyplot as plt
import numpy as np

from utils.utils import show_image, CFG
from utils.noise import CosineNoiseAdder

from data.dataset import MNIST_Dataset, CIFAR10_Dataset, Dataset

from tqdm import tqdm
import os

In [2]:
# testing net shapes
net = Unet(64, depth=3, embed_dim=64, initial_channels=3, conv_layers=3)
test_img = torch.randn(1, 3, 128, 128)
test_time = torch.tensor([1])
test_label = torch.tensor([1])
print(net(test_img, test_time, test_label, verbose=True).shape)

start with shape torch.Size([1, 3, 128, 128])
after concatenating the timestep (& labels) embedds : torch.Size([1, 67, 128, 128])
down block 0, with shape torch.Size([1, 67, 128, 128])
down block 1, with shape torch.Size([1, 64, 64, 64])
down block 2, with shape torch.Size([1, 128, 32, 32])
after bottleneck : shape = torch.Size([1, 256, 32, 32])
up block 0, with shape torch.Size([1, 256, 32, 32]), and skip shape : torch.Size([1, 256, 32, 32])
up block 1, with shape torch.Size([1, 128, 64, 64]), and skip shape : torch.Size([1, 128, 64, 64])
after final : shape = torch.Size([1, 3, 128, 128])
torch.Size([1, 3, 128, 128])


## Noise Dataset

In [3]:
class NoiseDataset():
    def __init__(self, imgs_dataset, noise_schedule = None):
        self.imgs_dataset = imgs_dataset
        self.noise_schedule = noise_schedule if noise_schedule else CosineNoiseAdder()       

    def __getitem__(self, idx):
        img, label = self.imgs_dataset[idx]
        t = torch.randint(self.noise_schedule.T, (1, )).squeeze()
        noisy_img, noise = self.noise_schedule.image_at_time_step(img, t)
        return noisy_img, noise, t, label

    def __len__(self):
        return len(self.imgs_dataset)

## Training

### Weight Initialization

In [4]:
def init_weights(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

In [5]:
def eval_model(model:nn.Module, _test_loader, criterion, device):
    model.eval()
    losses = []
    with torch.no_grad():
        for i, (noisy_imgs, noises, time_steps, labels) in enumerate(_test_loader):
            noisy_imgs, noises, time_steps, labels = noisy_imgs.to(device), noises.to(device), time_steps.to(device), labels.to(device)
            outputs = model(noisy_imgs, time_steps, labels)
            loss = criterion(outputs, noises)
            losses.append(loss.item())
    return sum(losses)/len(losses)


### Training Loop

In [7]:
def train_model(cfg:CFG, model, train, test, device=None):
    if device is None:
        # Set the device to GPU if available, else CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"Using device: {device}")
    
    model.apply(init_weights)  # Initialize model weights

    train_loader = DataLoader(train, shuffle=True, batch_size=cfg.batch_size)
    test_loader = DataLoader(test, shuffle=False, batch_size=cfg.batch_size)

    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.n_epochs_lr * len(train_loader), eta_min=cfg.final_lr)
    
    
    
    # keep track of the loss
    train_losses = []
    val_losses = []
    best_loss = np.inf
    best_loss_i = 0
    stoping = False
    eval_every = int(len(train_loader) * cfg.eval_frequency) # evaluate every n% of the training set

    # Compile the model for better performance
    if cfg.use_compile:
        print("Compiling the model...")
        torch.compile(model, mode='default', dynamic=True)

    # training loop
    for epoch in range(cfg.n_epochs):
        if stoping: # if the early stopping is triggered, we stop the training
            break
        train_loader_tqdm = tqdm(train_loader, desc=f'Epoch {epoch+1}/{cfg.n_epochs}', leave=True)
        for i, batch in enumerate(train_loader_tqdm):
            noisy_imgs, noises, time_steps, labels = batch
            noisy_imgs, noises, time_steps, labels = noisy_imgs.to(device, dtype=cfg.images_precision), noises.to(device, dtype=cfg.images_precision), time_steps.to(device, dtype=torch.int16), labels.to(device, dtype=torch.int32)
            optimizer.zero_grad()

            predicted_noise = model(noisy_imgs, time_steps, labels, verbose=False)
            loss = criterion(predicted_noise, noises)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_losses.append(loss.item())
            
            # evaluate the model every eval_every steps
            if i % eval_every == 0 and i > 0:
                
                val_loss = eval_model(model, test_loader, criterion, device)
                val_losses.append(val_loss)
                # print(f'Epoch [{epoch+1}/{n_epochs}], Step [{i}/{len(train_loader)}], Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}, lr: {scheduler.get_last_lr()[0]}')

                if val_loss < best_loss:
                    best_loss = val_loss
                    best_loss_i = epoch * len(train_loader) + i
                    torch.save(model.state_dict(), cfg.save_path)

                train_loader_tqdm.set_postfix({'Loss': loss.item(), 'Val Loss': val_loss, 'best_loss' : best_loss, 'lr': scheduler.get_last_lr()[0]})
                
                if epoch * len(train_loader) + i - best_loss_i > cfg.patience:
                    print("Stopping early")
                    stoping = True
                    break
                
            if epoch < cfg.n_epochs_lr:
                scheduler.step()
        
        if stoping:
            break


    # plot the losses
    plt.figure(figsize=(10, 5))
    train_losses_resized = [np.mean(train_losses[i * eval_every:(i + 1) * eval_every]) for i in range(len(val_losses))]
    plt.plot(np.arange(0, len(train_losses_resized)) * eval_every / len(train_loader), train_losses_resized, label='Training Loss (averaged)')
    plt.plot(np.arange(0, len(val_losses)) * eval_every / len(train_loader), val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

In [8]:
config_CELEBA = CFG(
    dataset_name='CELEBA',
    save_path = 'best_model_CELEBA.pth',
    initial_channels=3,
    image_size=64,
    n_epochs=30,
    num_labels=0,
    
    lr=1e-3,
    final_lr=1e-5,
    n_epochs_lr=25,
    batch_size=32,
    patience=10000,
    
    depth=3,
    conv_layers=2,
    first_hidden=64,
    embedding_dim=64,
    max_time_steps=400,
)

CosineNoise = CosineNoiseAdder(config_CELEBA.max_time_steps)
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((config_CELEBA.image_size, config_CELEBA.image_size)),
    transforms.Normalize((0.5, 0.5, 0.5, ), (0.5, 0.5, 0.5, ))
])

train = NoiseDataset(Dataset(config_CELEBA.dataset_name, transform=trans), CosineNoise)
test = NoiseDataset(Dataset(config_CELEBA.dataset_name, 'test', transform=trans), CosineNoise)

model = Unet(
    first_hidden=config_CELEBA.first_hidden, depth=config_CELEBA.depth, embed_dim=config_CELEBA.embedding_dim, 
    num_label=config_CELEBA.num_labels, initial_channels=config_CELEBA.initial_channels, 
    conv_layers=config_CELEBA.conv_layers, dropout=config_CELEBA.dropout
)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
config_CELEBA.to_yaml('configs/config_CELEBA.yaml')

AttributeError: 'CFG' object has no attribute 'to_yaml'

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_model(config_CELEBA, model, train, test, device=device)

Using device: cuda


Epoch 1/30:   8%|â–Š         | 418/5087 [04:54<54:45,  1.42it/s]  


KeyboardInterrupt: 

In [None]:
net.load_state_dict(torch.load(config_CELEBA.save_path))
net.to(device)
net.eval()

Unet(
  (time_emb): TimeEmbedding(
    (time_mlp): Sequential(
      (0): SinusoidalPositionEmbeddings()
      (1): Linear(in_features=64, out_features=64, bias=True)
      (2): SiLU()
      (3): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (label_emb): LabelEmbedding(
    (emb): Embedding(10, 64)
    (proj): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): SiLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (down_blocks): ModuleList(
    (0): DownBlock(
      (silu): SiLU()
      (convs): ModuleList(
        (0): Conv2d(67, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1-2): 2 x Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (norms): ModuleList(
        (0-2): 3 x BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (dropouts): ModuleList(
        (0-2): 3 x Dropout(p=0.2, inplace=False)
      )
      (max_poolin

In [None]:
def sample(model, max_time_steps, n_samples=1, config=None, device=None):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    CosineNoise = CosineNoiseAdder(max_time_steps, s=0.008)
    all_samples = []
    all_vis = []

    for i in range(n_samples):
        with torch.no_grad():
            full_img = torch.tensor([], device=device)
            full_predicted_noise = torch.tensor([], device=device)
        
            xt = torch.randn((1, config.initial_channels, 32, 32)).to(device)
            for t in torch.arange(max_time_steps-1, -1, -1):
                t = t.expand((1)).to(device)
                a_t = CosineNoise.get_alpha_t(t)
                alpha_t_barre = CosineNoise.get_alpha_t_barre(t)
                sigma = torch.sqrt(1-a_t).view(1, 1, 1, 1)
                noise = torch.randn_like(xt)
                # print(xt.shape)
                label = torch.tensor([i%10], device=device)
                epsilon = model(xt, t, label)
                
                a = ((1 - a_t)/(torch.sqrt(1 - alpha_t_barre))).view(1, 1, 1, 1)
                b = (1/torch.sqrt(a_t)).view(1, 1, 1, 1)
                
                if t.item() % (max_time_steps / 10) == 0 or t.item() == max_time_steps-1:
                    # print(t.item())
                    full_img = torch.cat((full_img, xt), 3)
                    full_predicted_noise = torch.cat((full_predicted_noise, epsilon), 3)
                    # print(xt.shape)
                    # show_image(xt[0], f'{t.item()}%')
                    # show_image(full_img[0])
                
                xt = b*(xt - a*epsilon) + sigma*noise

                # xt = torch.sqrt(1 - a_t).view(1, 1, 1, 1) * noise_predicted + sigma * noise
                
                # xt = b * (xt - torch.sqrt(1-alpha_t_barre)*noise_predicted) + sigma*noise

        
            all_samples.append(xt[0].cpu())
            all_vis.append(full_img[0].cpu())
        
            # show_image(xt[0], title=f'Final Image of {label.item()}')
            # show_image(full_img[0])
            
    return all_samples, all_vis

In [None]:
max_time_steps = 200
n_samples = 10

imgs, full_imgs = sample(model, max_time_steps, n_samples, config_CELEBA, device=device)
for i, (xt, full_img) in enumerate(zip(imgs, full_imgs)):
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(xt.permute(1, 2, 0).cpu().numpy(), cmap='gray')
    plt.title(f"Final Image of {i}")
    plt.subplot(1, 2, 2)
    plt.imshow(full_img.permute(1, 2, 0).cpu().numpy(), cmap='gray')
    plt.title(f"Sampling process of {i}")
    plt.show()