In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader 
from timm.utils import ModelEmaV3 #pip install timm 

import matplotlib.pyplot as plt #pip install matplotlib
import matplotlib.font_manager as font_manager
plt.rcParams["font.family"] = "serif"
plt.style.use('classic')
font = font_manager.FontProperties(family='serif', size=16)

import numpy as np
import math, os, random, h5py
from einops import rearrange #pip install einops
from typing import List

from pathlib import Path
from PIL import Image

import tqdm
from tqdm.notebook import tqdm #pip install tqdm

In [2]:
class UnetLayer(nn.Module):
    """One U-Net stage: ResBlock -> (optional attention) -> ResBlock -> stride-2 conv.

    Args:
        upscale: if True the final conv is a ConvTranspose2d mapping C -> C//2
            channels; otherwise it is a Conv2d mapping C -> 2*C channels.
        attention: if True an Attention layer is inserted between the two
            residual blocks.
        num_groups: number of groups forwarded to the residual blocks.
        dropout_prob: dropout probability forwarded to the residual blocks and
            to the attention layer.
        num_heads: number of heads forwarded to the attention layer.
        C: number of channels entering this stage.
    """
    def __init__(self, 
            upscale: bool, 
            attention: bool, 
            num_groups: int, 
            dropout_prob: float,
            num_heads: int,
            C: int):
        super().__init__()
        block_kwargs = dict(C=C, num_groups=num_groups, dropout_prob=dropout_prob)
        self.ResBlock1 = ResBlock(**block_kwargs)
        self.ResBlock2 = ResBlock(**block_kwargs)
        # Resampling conv applied after the residual blocks (both variants use stride 2).
        self.conv = (
            nn.ConvTranspose2d(C, C//2, kernel_size=4, stride=2, padding=1)
            if upscale
            else nn.Conv2d(C, C*2, kernel_size=3, stride=2, padding=1)
        )
        # The attention submodule only exists when requested; forward() checks for it.
        if attention:
            self.attention_layer = Attention(C, num_heads=num_heads, dropout_prob=dropout_prob)

    def forward(self, x, embeddings):
        """Return ``(resampled, features)``.

        ``features`` is the output of the second residual block (before the
        resampling conv); ``resampled`` is ``self.conv(features)``.
        """
        features = self.ResBlock1(x, embeddings)
        attn = getattr(self, 'attention_layer', None)
        if attn is not None:
            features = attn(features)
        features = self.ResBlock2(features, embeddings)
        resampled = self.conv(features)
        return resampled, features

In [3]:
class ResBlock(nn.Module):
    """Residual block with two GroupNorm -> ReLU -> 3x3 circular-padded conv stages.

    The embedding (truncated to the input's channel count) is added to the
    input first; that sum is both the branch input and the skip connection.
    Dropout is applied between the two conv stages.
    """
    def __init__(self, C: int, num_groups: int, dropout_prob: float):
        super().__init__()
        # Submodules are created in a fixed order so parameter ordering stays stable.
        self.relu = nn.ReLU(inplace=True)
        self.gnorm1 = nn.GroupNorm(num_groups=num_groups, num_channels=C)
        self.gnorm2 = nn.GroupNorm(num_groups=num_groups, num_channels=C)
        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode='circular')
        self.conv1 = nn.Conv2d(C, C, **conv_kwargs)
        self.conv2 = nn.Conv2d(C, C, **conv_kwargs)
        self.dropout = nn.Dropout(p=dropout_prob, inplace=True)

    def forward(self, x, embeddings):
        """Add the embedding to ``x``, run the two conv stages and add the skip."""
        n_channels = x.shape[1]
        # Only the first n_channels embedding channels are used.
        skip = x + embeddings[:, :n_channels, :, :]
        branch = self.gnorm1(skip)
        branch = self.conv1(self.relu(branch))
        branch = self.dropout(branch)
        branch = self.gnorm2(branch)
        branch = self.conv2(self.relu(branch))
        return branch + skip

class MLPBlock(nn.Module):
    """Fully connected head: ReLU hidden layers followed by a softmax output layer.

    Args:
        h_sizes: layer widths. ``h_sizes[0]`` is the number of input features and
            each consecutive pair ``(h_sizes[k], h_sizes[k+1])`` defines one hidden
            ``nn.Linear``. A bare int is treated as ``[h_sizes]`` (no hidden
            layers, the output layer maps ``h_sizes -> out_size``).
        out_size: number of output units.
    """
    def __init__(self, h_sizes, out_size):
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        if isinstance(h_sizes, int):
            h_sizes = [h_sizes]
        h_sizes = list(h_sizes)
        # Hidden layers
        self.hidden = nn.ModuleList(
            [nn.Linear(h_sizes[k], h_sizes[k+1]) for k in range(len(h_sizes)-1)]
        )
        # Output layer
        self.out = nn.Linear(h_sizes[-1], out_size)

    def forward(self, x):
        """Return softmax probabilities over ``out_size`` units along dim 1."""
        # Feedforward
        for layer in self.hidden:
            x = F.relu(layer(x))
        # NOTE(review): this returns probabilities, not logits. nn.CrossEntropyLoss
        # applies log-softmax itself, so feeding it this output normalises twice —
        # confirm whether the softmax should stay here.
        return F.softmax(self.out(x), dim=1)

In [4]:
class SinusoidalEmbeddings(nn.Module):
    """Fixed (non-trainable) sinusoidal lookup table indexed by time step.

    Even columns hold ``sin(position * div)`` and odd columns hold
    ``cos(position * div)``, where ``div`` decays geometrically with the column
    index.

    Args:
        time_steps: number of rows in the table (valid indices are
            ``0 .. time_steps - 1``).
        embed_dim: number of columns in the table; odd values are supported.
    """
    def __init__(self, time_steps:int, embed_dim: int):
        super().__init__()
        position = torch.arange(time_steps).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        angles = position * div
        embeddings = torch.zeros(time_steps, embed_dim, requires_grad=False)
        embeddings[:, 0::2] = torch.sin(angles)
        # For odd embed_dim there is one fewer cosine column than sine columns.
        embeddings[:, 1::2] = torch.cos(angles[:, :embed_dim // 2])
        # Non-persistent buffer: follows the module across devices but is kept out
        # of state_dict, so existing checkpoints keep loading unchanged.
        self.register_buffer('embeddings', embeddings, persistent=False)

    def forward(self, x, t):
        """Look up the rows for ``t`` and return them shaped ``(len(t), embed_dim, 1, 1)``.

        ``x`` is only used to select the device of the result.
        """
        table = self.embeddings
        # Index on the table's own device so a GPU-resident ``t`` also works.
        index = torch.as_tensor(t, device=table.device)
        embeds = table[index].to(x.device)
        return embeds[:, :, None, None]

In [7]:
class UNET_classifier(nn.Module):
    """U-Net backbone followed by a flatten + MLP classification head.

    The first ``len(Channels)//2`` layers are run as the encoder (their
    pre-resampling features are kept as residuals); the remaining layers are
    run as the decoder, each output being concatenated with the matching
    encoder residual along the channel dimension.

    Args:
        Channels: channel count entering each UnetLayer.
        Attentions: per-layer flag enabling the attention sub-layer.
        Upscales: per-layer flag selecting the upscaling (True) or
            downscaling (False) resampling conv.
        num_groups: forwarded to every UnetLayer.
        dropout_prob: forwarded to every UnetLayer.
        num_heads: forwarded to every UnetLayer.
        input_channels: channels of the input image.
        output_channels: channels produced by the final 1x1 conv.
        time_steps: number of rows in the time-step embedding table.
        MLP_h_sizes: layer widths for the MLP head; an int is treated as a
            single-element list. The first width must equal the number of
            features produced by flattening the output conv's result.
        MLP_out_size: number of outputs of the MLP head.
    """
    def __init__(self,
            Channels: List = [64, 128, 256, 512, 512, 384],
            Attentions: List = [False, True, False, False, False, True],
            Upscales: List = [False, False, False, True, True, True],
            num_groups: int = 32,
            dropout_prob: float = 0.1,
            num_heads: int = 8,
            input_channels: int = 1,
            output_channels: int = 1,
            time_steps: int = 1000,
            MLP_h_sizes: int=50,
            MLP_out_size: int=3,
                ):
        super().__init__()
        self.num_layers = len(Channels)
        self.shallow_conv = nn.Conv2d(input_channels, Channels[0], kernel_size=3, padding=1)
        # Channels after the last decoder layer is concatenated with the first residual.
        out_channels = (Channels[-1]//2)+Channels[0]
        self.late_conv = nn.Conv2d(out_channels, out_channels//2, kernel_size=3, padding=1)
        self.output_conv = nn.Conv2d(out_channels//2, output_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.embeddings = SinusoidalEmbeddings(time_steps=time_steps, embed_dim=max(Channels))
        self.flatten = nn.Flatten()
        # The default is a bare int; the MLP head iterates over a list of widths.
        mlp_sizes = [MLP_h_sizes] if isinstance(MLP_h_sizes, int) else list(MLP_h_sizes)
        self.MLP_output = MLPBlock(mlp_sizes, MLP_out_size)
        # Layers are registered as Layer1..LayerN; these names are the state_dict keys.
        for i in range(self.num_layers):
            layer = UnetLayer(
                upscale=Upscales[i],
                attention=Attentions[i],
                num_groups=num_groups,
                dropout_prob=dropout_prob,
                C=Channels[i],
                num_heads=num_heads
            )
            setattr(self, f'Layer{i+1}', layer)

    def forward(self, x, t):
        """Run the U-Net on ``x`` at time steps ``t`` and return the MLP head output."""
        x = self.shallow_conv(x)
        # The embedding lookup depends only on t (x supplies the device), so do it once.
        embeddings = self.embeddings(x, t)
        residuals = []
        for i in range(self.num_layers//2):
            layer = getattr(self, f'Layer{i+1}')
            x, r = layer(x, embeddings)
            residuals.append(r)
        for i in range(self.num_layers//2, self.num_layers):
            layer = getattr(self, f'Layer{i+1}')
            # Decoder layers consume the encoder residuals in reverse order.
            x = torch.concat((layer(x, embeddings)[0], residuals[self.num_layers-i-1]), dim=1)
        x = self.output_conv(self.relu(self.late_conv(x)))
        x = self.flatten(x)
        return self.MLP_output(x)

In [8]:
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also sets the cuDNN ``deterministic`` flag and clears the ``benchmark`` flag.
    """
    # Python and NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    
def npy_loader(paths, input_dim=64, verbose=False):
    """Load several ``.npy`` files and concatenate them along the first dimension.

    Args:
        paths: iterable of paths readable by ``np.load``.
        input_dim: arrays whose last dimension exceeds this value are
            center-cropped to ``input_dim``.
        verbose: if True, print the path and resulting shape of every file.

    Returns:
        A single tensor made by concatenating all loaded arrays on dim 0.
    """
    samples = []
    for path in paths:
        sample = torch.from_numpy(np.load(path))
        # 3-D arrays get a singleton dimension inserted at position 1.
        if sample.dim() == 3:
            sample = sample[:, None, :, :]
        # NOTE(review): only the last dimension is checked before cropping;
        # confirm that inputs are square.
        if sample.shape[-1] > input_dim:
            sample = transforms.CenterCrop(input_dim)(sample)
        if verbose:
            print(f"{path}: {tuple(sample.shape)}")
        samples.append(sample)
    return torch.cat(samples, dim=0)

In [None]:
def train(
    train_dataset, batch_size: int=32, 
    num_time_steps: int=1000, 
    num_epochs: int=15, 
    save_patience: int=1,
    seed: int=-1, 
    ema_decay: float=0.9999,  
    lr=2e-5, 
    checkpoint_folder_path: str=None, 
    checkpoint_filename: str=None, input_dim: int=32,
    unet_Channels: List=[64, 128, 256, 512, 512, 384],
    unet_Attentions: List = [False, True, False, False, False, True],
    unet_Upscales: List = [False, False, False, True, True, True],
    unet_num_groups: int = 32,
    unet_dropout_prob: float = 0.1,
    unet_num_heads: int = 8,
    unet_input_channels: int = 1,
    unet_output_channels: int = 1,
    unet_MLP_h_sizes: int=50,
    unet_MLP_out_size: int=3,
         ):
    """Train a UNET_classifier on noised inputs, with EMA weights and checkpoints.

    Args:
        train_dataset: dataset wrapped in a shuffling DataLoader; each batch is
            used directly as the input tensor.
        batch_size: DataLoader batch size (incomplete last batches are dropped).
        num_time_steps: number of diffusion time steps; used for the scheduler,
            for sampling ``t`` and for the model's time-embedding table.
        num_epochs: number of passes over the DataLoader.
        save_patience: a checkpoint is written when ``epoch % save_patience == 0``
            and after the last epoch.
        seed: RNG seed; ``-1`` draws a random seed.
        ema_decay: decay passed to ModelEmaV3.
        lr: Adam learning rate.
        checkpoint_folder_path: checkpoint directory. If this or
            ``checkpoint_filename`` is None, checkpoints are neither loaded nor saved.
        checkpoint_filename: checkpoint base file name.
        input_dim: currently unused by this function.
        unet_*: forwarded to the UNET_classifier argument of the same name
            (without the ``unet_`` prefix).
    """
    set_seed(random.randint(0, 2**32-1) if seed == -1 else seed)
    # Use the GPU when present; otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
    
    # define scheduler
    scheduler = DDPM_Scheduler(num_time_steps=num_time_steps)
    
    # define models
    # Keyword arguments keep every value bound to the intended parameter
    # (time_steps sits between output_channels and MLP_h_sizes in the signature).
    model = UNET_classifier(
            Channels=unet_Channels,
            Attentions=unet_Attentions,
            Upscales=unet_Upscales,
            num_groups=unet_num_groups,
            dropout_prob=unet_dropout_prob,
            num_heads=unet_num_heads,
            input_channels=unet_input_channels,
            output_channels=unet_output_channels,
            time_steps=num_time_steps,
            MLP_h_sizes=unet_MLP_h_sizes,
            MLP_out_size=unet_MLP_out_size,
    ).to(device)
    # exponential moving average (ema)
    # When training a model, it is often beneficial to maintain moving averages of the trained parameters. 
    # Evaluations that use averaged parameters sometimes produce significantly better results than the final trained values.
    ema = ModelEmaV3(model, decay=ema_decay)
    
    # define optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # load checkpoint if present
    checkpoint_path = None
    if checkpoint_folder_path is not None and checkpoint_filename is not None:
        checkpoint_path = checkpoint_folder_path+'/'+checkpoint_filename
        if os.path.exists(checkpoint_path):
            print('Loading previous checkpoint')
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['weights'])
            ema.load_state_dict(checkpoint['ema'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            os.makedirs(checkpoint_folder_path, exist_ok=True)
    # define loss
    criterion = nn.CrossEntropyLoss(reduction='mean')

    # Guard against an empty loader when averaging the epoch loss.
    num_batches = max(len(train_loader), 1)

    # start the training
    for i in range(num_epochs):
        total_loss = 0
        for x in tqdm(train_loader, desc=f"Epoch {i+1}/{num_epochs}"):
            x = x.to(device)
            # NOTE(review): fixed 2-pixel padding on every side; presumably meant to
            # reach `input_dim` from smaller images — confirm against the dataset.
            x = F.pad(x, (2,2,2,2))
            n = x.shape[0]
            
            # FORWARD PROCESS
            # pick a random time step (t)
            t = torch.randint(0,num_time_steps,(n,))
            # generate noise (e)
            e = torch.randn_like(x, requires_grad=False)
            # define the scheduler at time t (alpha_t)
            a = scheduler.alpha[t].view(n,1,1,1).to(device)
            # update x: x(0) --> x(t)
            x = (torch.sqrt(a)*x) + (torch.sqrt(1-a)*e)

            optimizer.zero_grad()
            output = model(x, t)
            # NOTE(review): the target is the sampled noise, as in the original code.
            # The model ends in an MLP head, so CrossEntropyLoss will most likely need
            # class labels here instead — confirm the intended target.
            loss = criterion(output, e)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            ema.update(model)
        print(f'Epoch {i+1} | Loss {total_loss / num_batches:.5f}')
        torch.cuda.empty_cache()

        if checkpoint_path is not None and ((not i%save_patience) or (i==(num_epochs-1))):
            print("Save checkpoint %i"%(i))
            # NOTE(review): files are written with an "_<epoch>" suffix, but resuming
            # looks for the un-suffixed name — confirm the intended resume workflow.
            checkpoint = {
                'weights': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'ema': ema.state_dict()
            }
            torch.save(checkpoint, checkpoint_path+'_%i'%(i))