# Train Diffusion Model
### Imports

In [1]:
#!pip install librosa

#Set Dir 
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.optim as optim
#from torchsummary import summary
# Utils
import numpy as np
from numpy import ndarray
import logging

# Base Scripts
from Libraries.U_Net import *
from Libraries.Diffusion import *
from Libraries.Utils import *
from MainScripts.Conf import *

### Config
General

In [2]:
remote_kernel: bool = True

logging_level: int = logging.INFO
model_name: str = "diffusion_v3"
model_path: str = path_to_remote_path(f"{MODEL_PATH}/{model_name}.pth", remote_kernel)
checkpoint_freq: int = 5 #0 for no checkpoint saving
training_data_name: str = "training_full_low_res"

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size: int = 8
epochs: int = 100
diffusion_timesteps: int = 500
n_training_samples: int = 1280

logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')
logger: logging.Logger = logging.getLogger(__name__)

U-Net

In [3]:
learning_rate: float = 1e-4
n_starting_filters: int = 24
n_downsamples: int = 3
time_embed_dim: int = 128
n_starting_attention_size: int = 32


### Data Loading

In [4]:
file: ndarray = load_training_data(path_to_remote_path(f"{DATA_PATH}/{training_data_name}.npy", remote_kernel))[:n_training_samples, ...]

In [5]:
data_loader = create_dataloader(Audio_Data(file), batch_size)
logger.info(f"Data loaded with shape: {file.shape}")

2025-03-06 20:10:06,807 - INFO - Data loaded with shape: (1280, 224, 416)


### Model Creation
U-Net

In [6]:
u_net = Conv_U_NET(in_channels=1, time_embed_dim=time_embed_dim, n_starting_filters=n_starting_filters, n_downsamples=n_downsamples, activation=nn.GELU(), device=device).to(device)
if os.path.exists(model_path):
    u_net.load_state_dict(torch.load(model_path, weights_only=False, map_location=device))
    logger.info(f"Model {model_name} loaded with {count_parameters(u_net)} Parameters")
else: 
    logger.info(f"Model {model_name} created with {count_parameters(u_net)} Parameters")

optimizer = optim.AdamW(u_net.parameters(), lr=learning_rate, weight_decay=0.001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-8)


2025-03-06 20:10:11,531 - INFO - Model diffusion_v3 created with 9743665 Parameters


In [7]:
for name, module in u_net.named_modules():
    if module.__class__.__module__ != 'torch.nn.modules':  # Ignore standard PyTorch modules
        devices = {p.device for p in module.parameters()} | {b.device for b in module.buffers()}
        print(f"{name} ({module.__class__.__name__}): {devices}")


#summary(u_net, [(1, file.shape[-2], file.shape[-1]), (1,)], batch_size=2)

 (Conv_U_NET): {device(type='cuda', index=0)}
inp_lay (Conv2d): {device(type='cuda', index=0)}
encoder (ModuleList): {device(type='cuda', index=0)}
encoder.0 (ModuleList): {device(type='cuda', index=0)}
encoder.0.0 (Down): {device(type='cuda', index=0)}
encoder.0.0.seq (Sequential): {device(type='cuda', index=0)}
encoder.0.0.seq.0 (MaxPool2d): set()
encoder.0.0.seq.1 (DoubleConv): {device(type='cuda', index=0)}
encoder.0.0.seq.1.activation (GELU): set()
encoder.0.0.seq.1.seq (Sequential): {device(type='cuda', index=0)}
encoder.0.0.seq.1.seq.0 (Conv2d): {device(type='cuda', index=0)}
encoder.0.0.seq.1.seq.1 (GroupNorm): {device(type='cuda', index=0)}
encoder.0.0.seq.1.seq.3 (Conv2d): {device(type='cuda', index=0)}
encoder.0.0.seq.1.seq.4 (GroupNorm): {device(type='cuda', index=0)}
encoder.0.0.seq.2 (DoubleConv): {device(type='cuda', index=0)}
encoder.0.0.seq.2.seq (Sequential): {device(type='cuda', index=0)}
encoder.0.0.seq.2.seq.0 (Conv2d): {device(type='cuda', index=0)}
encoder.0.0.se

Diffusion

In [8]:
diffusion = Diffusion(model=u_net, noise_steps=diffusion_timesteps, noise_schedule="linear", input_dim=[8, 1, file.shape[-2], file.shape[-1]], device=device)

#diffusion.visualize_diffusion_steps(x=torch.Tensor(file[:1]), noise_schedule=noise_schedule, device=device, n_spectograms=5)

### Train

In [9]:
x = diffusion.train(epochs=epochs, data_loader=data_loader,loss_function=nn.MSELoss(),optimizer=optimizer, lr_scheduler=scheduler, gradient_accum=4, checkpoint_freq=checkpoint_freq, model_path=model_path)
scatter_plot(x)

2025-03-06 20:10:12,551 - INFO - Training started on cuda
2025-03-06 20:10:16,151 - INFO - Epoch 1, Timestep 398:
2025-03-06 20:10:16,153 - INFO -   True Noise min/max: -4.442, 4.357
2025-03-06 20:10:16,154 - INFO -   Pred Noise min/max: -0.441, 2.256
2025-03-06 20:10:52,946 - INFO - Epoch 01: Avg. Loss: 1.03646e+00 Remaining Time: 01h 06min 38s LR: 9.99753e-05
2025-03-06 20:10:53,190 - INFO - Epoch 2, Timestep 210:
2025-03-06 20:10:53,192 - INFO -   True Noise min/max: -4.396, 4.450
2025-03-06 20:10:53,194 - INFO -   Pred Noise min/max: -0.615, 0.699
2025-03-06 20:11:30,070 - INFO - Epoch 02: Avg. Loss: 9.36455e-01 Remaining Time: 01h 03min 18s LR: 9.99013e-05
2025-03-06 20:11:30,341 - INFO - Epoch 3, Timestep 226:
2025-03-06 20:11:30,342 - INFO -   True Noise min/max: -4.423, 4.259
2025-03-06 20:11:30,344 - INFO -   Pred Noise min/max: -1.377, 1.632
2025-03-06 20:12:07,069 - INFO - Epoch 03: Avg. Loss: 8.73159e-01 Remaining Time: 01h 01min 42s LR: 9.97781e-05
2025-03-06 20:12:07,270 

KeyboardInterrupt: 