# Train Diffusion Model
### Imports

In [1]:
#Set Dir 
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.optim as optim
# Utils
import numpy as np
from numpy import ndarray
from torchsummary import summary
import logging

# Base Scripts
from Libraries.U_Net import *
from Libraries.Diffusion import *
from Libraries.Utils import *
from Conf import *

### Config
General

In [2]:
# --- Run identity and artefact locations ---------------------------------
model_name: str = "diffusion_v1"
training_data_name: str = "training_1280"
# Weights for this run are read from / written to this path.
model_path: str = f"{MODEL_PATH}/{model_name}.pth"
# Handed to diffusion.train as checkpoint_freq; 0 for no checkpoint saving.
checkpoint_freq: int = 5

# --- Compute device -------------------------------------------------------
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- Training sizes -------------------------------------------------------
batch_size: int = 2
epochs: int = 100
diffusion_timesteps: int = 500

# --- Logging --------------------------------------------------------------
# LIGHT_DEBUG is a project-defined level; swap in logging.INFO for less output.
logging_level: int = LIGHT_DEBUG
logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')
logger: logging.Logger = logging.getLogger(__name__)

U-Net

In [3]:
# --- Architecture ---------------------------------------------------------
n_starting_filters: int = 24
n_blocks: int = 2  # each block samples down by a factor of 2
n_groups: int = 8  # group count for group norm
time_embed_dim: int = 64

# --- Optimisation ---------------------------------------------------------
learning_rate: float = 1e-5
lr_decay: int = 40  # used as step_size of the StepLR scheduler
lr_gamma: float = 0.1  # used as gamma of the StepLR scheduler

### Data Loading

In [4]:
# Number of leading samples kept from the loaded array. This used to be a bare
# literal in the slice below, which made it easy to overlook that only a small
# subset of the data set is trained on. Set to None to keep every sample.
n_training_samples = 4

# Load the training array and keep the first n_training_samples entries along
# axis 0; all remaining axes are kept whole.
file = load_training_data(f"{DATA_PATH}/{training_data_name}.npy")[:n_training_samples, ...]

2025-03-02 22:50:12,950 - LIGHT_DEBUG - Ndarray loaded from ../Data/training_1280.npy of shape: (1280, 1024, 672)


In [5]:
# Wrap the loaded array in the project's Audio_Data container, then hand it to
# the project helper that builds the loader with the configured batch size.
audio_dataset = Audio_Data(file)
data_loader = create_dataloader(audio_dataset, batch_size)

logger.info(f"Data loaded with shape: {file.shape}")

2025-03-02 22:50:12,971 - INFO - Data loaded with shape: (4, 1024, 672)


### Model Creation
U-Net

In [None]:
# Build the U-Net. The last two entries of input_shape are taken from the
# training array; the first two are passed as 0.
u_net = U_NET(
    in_channels=1,
    device=device,
    input_shape=[0, 0, file.shape[-2], file.shape[-1]],
    n_res_layers=n_blocks,
    n_starting_filters=n_starting_filters,
    n_groups=n_groups,
    time_emb_dim=time_embed_dim,
).to(device)

# Resume from an existing checkpoint when one is present at model_path.
if os.path.exists(model_path):
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints produced by this project. Consider switching to
    # weights_only=True once it is confirmed the file is a plain state dict.
    state_dict = torch.load(model_path, weights_only=False, map_location=device)
    # A state dict saved from the nn.DataParallel wrapper has every key
    # prefixed with "module."; strip that prefix so the checkpoint also loads
    # into the bare model. This is a no-op when no key carries the prefix.
    nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")
    u_net.load_state_dict(state_dict)
    logger.info(f"Model {model_name} loaded with {count_parameters(u_net)} Parameters")
else:
    logger.info(f"Model {model_name} created with {count_parameters(u_net)} Parameters")

optimizer = optim.AdamW(u_net.parameters(), lr=learning_rate, weight_decay=0.05)
# NOTE(review): the scheduler is created here but is neither stepped nor passed
# to Diffusion / diffusion.train in the cells visible in this notebook, so
# lr_decay and lr_gamma may currently have no effect - confirm.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay, gamma=lr_gamma)

# Wrap for multi-GPU execution after the optimizer has been bound to the
# parameters; the wrapped module shares those same parameter tensors.
u_net = nn.DataParallel(u_net).to(device)
# Optional layer-by-layer overview:
# summary(u_net, [(1, 1024, 672), (1, 1, 1)])

2025-03-02 22:50:12,999 - INFO - Model diffusion_v1 created with 452065 Parameters


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
 TimestepEmbedding-1                   [-1, 64]               0
            Conv2d-2        [-1, 24, 1024, 672]             240
         GroupNorm-3        [-1, 24, 1024, 672]              48
         LeakyReLU-4        [-1, 24, 1024, 672]               0
         LeakyReLU-5        [-1, 24, 1024, 672]               0
         LeakyReLU-6        [-1, 24, 1024, 672]               0
         LeakyReLU-7        [-1, 24, 1024, 672]               0
         LeakyReLU-8        [-1, 24, 1024, 672]               0
            Conv2d-9        [-1, 48, 1024, 672]          10,416
        LeakyReLU-10        [-1, 48, 1024, 672]               0
        LeakyReLU-11        [-1, 48, 1024, 672]               0
        LeakyReLU-12        [-1, 48, 1024, 672]               0
        LeakyReLU-13        [-1, 48, 1024, 672]               0
        LeakyReLU-14        [-1, 48, 10

Diffusion

In [7]:
# Bundle the (wrapped) U-Net and its optimizer into the project's Diffusion
# object, configured for the chosen number of timesteps.
diffusion = Diffusion(
    u_net=u_net,
    u_net_optimizer=optimizer,
    diffusion_timesteps=diffusion_timesteps,
)

# Noise schedule from the project's cosine_noise helper, built for the same
# number of timesteps as the Diffusion object above.
noise_schedule = cosine_noise(T=diffusion_timesteps)

### Train

In [8]:
# Start training with the configuration assembled in the cells above.
# checkpoint_freq and model_path are forwarded so the training routine can
# write checkpoints for this run.
diffusion.train(
    data_loader,
    device,
    epochs=epochs,
    loss_function=U_Net_loss,
    noise_schedule=noise_schedule,
    checkpoint_freq=checkpoint_freq,
    model_path=model_path,
)

2025-03-02 22:51:55,000 - LIGHT_DEBUG - Batch 02/02

2025-03-02 22:51:55,897 - INFO - Epoch 01: Avg. Loss: 6.07030e-01 Remaining Time: 02h 16min 34s



2025-03-02 22:53:26,000 - LIGHT_DEBUG - Batch 02/02

2025-03-02 22:53:26,020 - INFO - Epoch 02: Avg. Loss: 6.12806e-01 Remaining Time: 02h 21min 11s



2025-03-02 22:53:56,000 - LIGHT_DEBUG - Batch 01/02

KeyboardInterrupt: 