# Train Diffusion Model
### Imports

In [1]:
try: 
    import librosa
except:
    !pip install librosa


#Set Dir 
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Utils
import numpy as np
from numpy import ndarray
import logging

# Base Scripts
from Libraries.U_Net import *
from Libraries.Diffusion import *
from Libraries.Utils import *
from MainScripts.Conf import conf

### Config
General

In [2]:
remote_kernel: bool = True

logging_level: int = logging.INFO
model_name: str = "diffusion_cifar10"
full_model_path: str = path_to_remote_path("{}/{}".format(conf["paths"].model_path, model_name + ".pth"), remote_kernel)
checkpoint_freq: int = 10 #0 for no checkpoint saving
training_data_name: str = "training_full_low_res"

device = "cuda" if torch.cuda.is_available() else "cpu"
restart_training: bool = True #If True and model already exists optimizer and lr_scheduler are reset
learning_rate: float = 5e-4 #Starting lr/first lr for Threshold Scheduler
epochs: int = 200
n_training_samples: int = 2000

logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')
logger: logging.Logger = logging.getLogger(__name__)

### Data Loading

In [6]:
file: ndarray = load_training_data(path_to_remote_path("{}/{}".format(conf["paths"].data_path, training_data_name + ".npy"), remote_kernel))[:n_training_samples, ...]
data_loader = create_dataloader(Audio_Data(file), conf["model"].batch_size)
logger.info(f"Data loaded with shape: {file.shape}")

2025-04-28 20:00:14,141 - LIGHT_DEBUG - Ndarray loaded from ../Data/training_full_low_res.npy of shape: (7087, 224, 416)
2025-04-28 20:00:15,027 - INFO - Data loaded with shape: (2000, 224, 416)


In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(num_output_channels=1), 
])

cifar10_train = torchvision.datasets.CIFAR10(
    root=path_to_remote_path(conf["paths"].data_path, remote_kernel), 
    train=True, 
    download=True, 
    transform=transform
)
indices = np.random.choice(len(cifar10_train), size=5000, replace=False)
cifar10_subset = Subset(cifar10_train, indices)
data_loader = DataLoader(cifar10_train, batch_size=32, shuffle=True)

Files already downloaded and verified


### Model Creation
#### U-Net

Small U-Net

In [4]:
u_net = U_NET(in_channels=1,
            channels=[16, 32],
            res_blocks=[2, 4],
            factors=[2, 2],
            attentions=[0, 1], 
            attention_heads=8,
            attention_features=48,
            activation=nn.GELU(), 
            embeding_dim=conf["model"].time_embed_dim, 
            device=device
            ).to(device)

Large U-Net

In [None]:
u_net = U_NET(in_channels=1,
            channels=[64, 128, 256, 512],
            res_blocks=[12, 2, 2, 2],
            factors=[ 2, 2, 2, 2],
            attentions=[ 0, 0, 0, 1], 
            attention_heads=8,
            attention_features=48,
            activation=nn.GELU(), 
            embeding_dim=conf["model"].time_embed_dim, 
            device=device
            ).to(device)

2025-04-22 12:16:42,340 - INFO - Model diffusion2_v0 created with 50226237 Parameters


U-Net loading if possible

In [5]:
ema = EMA(u_net, decay = 0.99)
optimizer = optim.Adam(u_net.parameters(), lr=learning_rate)
#scheduler = Threshold_LR(optimizer, [1, 0.1, 0.09, 0.85, 0.08, 0.07], [learning_rate, 1e-4, 1e-5, 5e-6, 1e-6, 1e-7])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
start_epoch: int = 0
if os.path.exists(full_model_path):
    model = torch.load(full_model_path, map_location=device)
    u_net.load_state_dict(model["model"])
    if 'ema_state' in model:
        for name, param in model.named_parameters():
            if param.requires_grad and name in model['ema_state']:
                ema.shadow[name] = model['ema_state'][name].clone()
    if not restart_training:
        optimizer.load_state_dict(model["optim"])
        scheduler.load_state_dict(model["scheduler"])
        start_epoch = model.get("epoch", 0)
    logger.info(f"Model {model_name} loaded with {count_parameters(u_net)} Parameters")
else: 
    logger.info(f"Model {model_name} created with {count_parameters(u_net)} Parameters")

2025-04-28 19:47:24,327 - INFO - Model diffusion_cifar10 created with 312229 Parameters


Diffusion

In [7]:
diffusion = Diffusion(model=u_net, 
                        noise_steps=conf["model"].diffusion_timesteps, 
                        noise_schedule="cosine", 
                        input_dim= [32, 1, 32, 32],
                        ema = None,
                        device=device
)

### Train

In [None]:
x = diffusion.train(epochs=epochs, 
                    data_loader=data_loader, 
                    loss_function=nn.MSELoss(),
                    optimizer=optimizer, 
                    lr_scheduler=scheduler, 
                    gradient_accum=conf["model"].gradient_accum,
                    checkpoint_freq=checkpoint_freq, 
                    model_path=full_model_path, 
                    start_epoch=start_epoch,
                    patience=100,
                    ema_freq=0
                )
scatter_plot(x)

2025-04-28 19:47:36,596 - INFO - Training started on cuda
