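# Train Diffusion Model
Trains the U-Net of the `diffusion_v1` diffusion model on a subset of the audio training data, saving periodic checkpoints to the model path.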
### Imports

In [1]:
# %pip install librosa  # uncomment on a fresh environment; %pip targets this kernel's env (unlike !pip)
# Make the parent directory importable so the project packages (Libraries, MainScripts) resolve.
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.optim as optim
# Utils
import numpy as np
from numpy import ndarray
import logging

# Base Scripts
# NOTE(review): later cells rely on names such as U_NET, Diffusion, cosine_noise, U_Net_loss,
# create_dataloader, Audio_Data, load_training_data, count_parameters, path_to_remote_path,
# MODEL_PATH and DATA_PATH arriving via these star imports; which module defines each is not
# visible here. Explicit imports would make the dependencies traceable.
from Libraries.U_Net import *
from Libraries.Diffusion import *
from Libraries.Utils import *
from MainScripts.Conf import *

### Config
#### General

In [None]:
# Passed to path_to_remote_path when building file paths (presumably remaps them for a remote kernel).
remote_kernel: bool = True

logging_level: int = logging.INFO
model_name: str = "diffusion_v1"
model_path: str = path_to_remote_path(f"{MODEL_PATH}/{model_name}.pth", remote_kernel)
checkpoint_freq: int = 5  # save a checkpoint every N epochs; 0 for no checkpoint saving
training_data_name: str = "training_2000_1_1"

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size: int = 8
epochs: int = 100
diffusion_timesteps: int = 500

# Seed the RNGs early so weight init and any downstream torch/numpy sampling
# (e.g. DataLoader shuffling, diffusion noise) repeat across runs.
# GPU kernels may still be non-deterministic.
seed: int = 42
torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
np.random.seed(seed)

logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')
logger: logging.Logger = logging.getLogger(__name__)

#### U-Net

In [3]:
# Optimizer step size
learning_rate: float = 0.0001

# U-Net architecture settings (consumed when the model is built)
n_starting_filters: int = 24  # base filter count, passed as U_NET's n_starting_filters
n_blocks: int = 2             # passed as n_res_layers; each samples down by a factor of 2
n_groups: int = 8             # number of groups for group norm
time_embed_dim: int = 64      # passed as time_emb_dim

### Data Loading

In [4]:
n_training_samples: int = 200  # number of leading samples to train on
full_data: ndarray = load_training_data(path_to_remote_path(f"{DATA_PATH}/{training_data_name}.npy", remote_kernel))
# Basic slicing returns a view whose .base pins the entire loaded array in memory.
# Copying the subset and dropping the full array lets the unused samples be freed.
file: ndarray = full_data[:n_training_samples].copy()
del full_data

2025-03-03 13:07:32,038 - LIGHT_DEBUG - Ndarray loaded from Data/training_2000_1_1.npy of shape: (2000, 1024, 672)


In [5]:
# Wrap the loaded array in the project's dataset class, then batch it with `batch_size`.
audio_dataset = Audio_Data(file)
data_loader = create_dataloader(audio_dataset, batch_size)
logger.info("Data loaded with shape: %s", file.shape)

2025-03-03 13:07:32,364 - INFO - Data loaded with shape: (200, 1024, 672)


### Model Creation
#### U-Net

In [6]:
weight_decay: float = 0.001  # AdamW weight decay

# Build the U-Net for inputs with the spatial size of the loaded data; it is moved to `device` here once.
u_net = U_NET(
    in_channels=1,
    device=device,
    input_shape=[0, 0, file.shape[-2], file.shape[-1]],
    n_res_layers=n_blocks,
    n_starting_filters=n_starting_filters,
    n_groups=n_groups,
    time_emb_dim=time_embed_dim,
).to(device)
# Wrap before loading: the checkpoint is loaded into the wrapped model, so its keys are expected
# to carry DataParallel's "module." prefix (presumably saved from the wrapped model — TODO confirm).
u_net = nn.DataParallel(u_net)
if os.path.exists(model_path):
    # A state dict only needs tensors and plain containers; weights_only=True refuses to
    # unpickle arbitrary objects (weights_only=False would execute any pickled code in the file).
    # load_state_dict copies into the existing parameters, so the model stays on `device`.
    u_net.load_state_dict(torch.load(model_path, weights_only=True, map_location=device))
    logger.info(f"Model {model_name} loaded with {count_parameters(u_net)} Parameters")
else:
    logger.info(f"Model {model_name} created with {count_parameters(u_net)} Parameters")

# Only the weights are restored above; the optimizer (and its moment estimates) always start fresh.
optimizer = optim.AdamW(u_net.parameters(), lr=learning_rate, weight_decay=weight_decay)
# NOTE(review): `scheduler` is passed neither to Diffusion nor to diffusion.train in this notebook,
# so nothing visible calls scheduler.step() and the LR likely stays at `learning_rate` — TODO confirm.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

2025-03-03 13:07:37,018 - INFO - Model diffusion_v1 created with 452065 Parameters


#### Diffusion

In [10]:
# Diffusion wrapper around the U-Net and its optimizer, over `diffusion_timesteps` steps.
diffusion = Diffusion(
    u_net=u_net,
    u_net_optimizer=optimizer,
    diffusion_timesteps=diffusion_timesteps,
)
# Cosine noise schedule with T = diffusion_timesteps, placed on the training device.
noise_schedule = cosine_noise(T=diffusion_timesteps).to(device)
# Optional sanity check (uncomment to run): visualize a few diffusion steps for one sample.
#diffusion.visualize_diffusion_steps(x=torch.Tensor(file[:1]), noise_schedule=noise_schedule, device=device, n_spectograms=5)

### Train

In [None]:
# Run the full training loop; every `checkpoint_freq` epochs the model is written to `model_path`.
# NOTE(review): the LR scheduler created alongside the optimizer is not handed to train().
diffusion.train(
    data_loader,
    device,
    epochs=epochs,
    loss_function=U_Net_loss,
    noise_schedule=noise_schedule,
    checkpoint_freq=checkpoint_freq,
    model_path=model_path,
)

2025-03-03 13:08:52,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:08:52,699 - INFO - Epoch 01: Avg. Loss: 1.36904e-01 Remaining Time: 00h 47min 27s


2025-03-03 13:09:15,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:09:15,568 - INFO - Epoch 02: Avg. Loss: 1.27605e-01 Remaining Time: 00h 42min 09s


2025-03-03 13:09:38,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:09:38,595 - INFO - Epoch 03: Avg. Loss: 1.26486e-01 Remaining Time: 00h 40min 13s


2025-03-03 13:10:01,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:10:01,574 - INFO - Epoch 04: Avg. Loss: 1.26444e-01 Remaining Time: 00h 39min 03s


2025-03-03 13:10:24,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:10:24,526 - INFO - Epoch 05: Avg. Loss: 1.25058e-01 Remaining Time: 00h 38min 11s
2025-03-03 13:10:24,538 - LIGHT_DEBUG - Checkpoint saved model to Models/diffusion_v1.pth


2025-03-03 13:10:47,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:10:47,537 - INFO - Epoch 06: Avg. Loss: 1.22650e-01 Remaining Time: 00h 37min 29s


2025-03-03 13:11:10,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:11:10,634 - INFO - Epoch 07: Avg. Loss: 1.21126e-01 Remaining Time: 00h 36min 54s


2025-03-03 13:11:33,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:11:33,750 - INFO - Epoch 08: Avg. Loss: 1.20936e-01 Remaining Time: 00h 36min 22s


2025-03-03 13:11:56,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:11:56,882 - INFO - Epoch 09: Avg. Loss: 1.18851e-01 Remaining Time: 00h 35min 52s


2025-03-03 13:12:20,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:12:20,019 - INFO - Epoch 10: Avg. Loss: 1.18987e-01 Remaining Time: 00h 35min 24s
2025-03-03 13:12:20,026 - LIGHT_DEBUG - Checkpoint saved model to Models/diffusion_v1.pth


2025-03-03 13:12:43,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:12:43,102 - INFO - Epoch 11: Avg. Loss: 1.18454e-01 Remaining Time: 00h 34min 56s


2025-03-03 13:13:06,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:13:06,069 - INFO - Epoch 12: Avg. Loss: 1.17463e-01 Remaining Time: 00h 34min 28s


2025-03-03 13:13:29,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:13:29,037 - INFO - Epoch 13: Avg. Loss: 1.17039e-01 Remaining Time: 00h 34min 01s


2025-03-03 13:13:52,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:13:52,132 - INFO - Epoch 14: Avg. Loss: 1.16837e-01 Remaining Time: 00h 33min 35s


2025-03-03 13:14:15,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:14:15,208 - INFO - Epoch 15: Avg. Loss: 1.15361e-01 Remaining Time: 00h 33min 10s
2025-03-03 13:14:15,220 - LIGHT_DEBUG - Checkpoint saved model to Models/diffusion_v1.pth


2025-03-03 13:14:38,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:14:38,339 - INFO - Epoch 16: Avg. Loss: 1.15182e-01 Remaining Time: 00h 32min 45s


2025-03-03 13:15:01,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:15:01,443 - INFO - Epoch 17: Avg. Loss: 1.15095e-01 Remaining Time: 00h 32min 20s


2025-03-03 13:15:24,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:15:24,516 - INFO - Epoch 18: Avg. Loss: 1.14780e-01 Remaining Time: 00h 31min 55s


2025-03-03 13:15:47,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:15:47,662 - INFO - Epoch 19: Avg. Loss: 1.13220e-01 Remaining Time: 00h 31min 31s


2025-03-03 13:16:10,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:16:10,680 - INFO - Epoch 20: Avg. Loss: 1.11409e-01 Remaining Time: 00h 31min 06s
2025-03-03 13:16:10,693 - LIGHT_DEBUG - Checkpoint saved model to Models/diffusion_v1.pth


2025-03-03 13:16:33,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:16:33,726 - INFO - Epoch 21: Avg. Loss: 1.11331e-01 Remaining Time: 00h 30min 42s


2025-03-03 13:16:56,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:16:56,879 - INFO - Epoch 22: Avg. Loss: 1.10994e-01 Remaining Time: 00h 30min 18s


2025-03-03 13:17:19,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:17:19,944 - INFO - Epoch 23: Avg. Loss: 1.10534e-01 Remaining Time: 00h 29min 54s


2025-03-03 13:17:43,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:17:43,013 - INFO - Epoch 24: Avg. Loss: 1.09875e-01 Remaining Time: 00h 29min 30s


2025-03-03 13:18:06,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:18:06,115 - INFO - Epoch 25: Avg. Loss: 1.09507e-01 Remaining Time: 00h 29min 06s
2025-03-03 13:18:06,125 - LIGHT_DEBUG - Checkpoint saved model to Models/diffusion_v1.pth


2025-03-03 13:18:29,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:18:29,205 - INFO - Epoch 26: Avg. Loss: 1.08196e-01 Remaining Time: 00h 28min 42s


2025-03-03 13:18:52,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:18:52,216 - INFO - Epoch 27: Avg. Loss: 1.07108e-01 Remaining Time: 00h 28min 18s


2025-03-03 13:19:15,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:19:15,216 - INFO - Epoch 28: Avg. Loss: 1.03795e-01 Remaining Time: 00h 27min 54s


2025-03-03 13:19:38,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:19:38,245 - INFO - Epoch 29: Avg. Loss: 1.02396e-01 Remaining Time: 00h 27min 30s


2025-03-03 13:20:01,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:20:01,374 - INFO - Epoch 30: Avg. Loss: 1.01444e-01 Remaining Time: 00h 27min 07s
2025-03-03 13:20:01,382 - LIGHT_DEBUG - Checkpoint saved model to Models/diffusion_v1.pth


2025-03-03 13:20:24,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:20:24,446 - INFO - Epoch 31: Avg. Loss: 1.00699e-01 Remaining Time: 00h 26min 43s


2025-03-03 13:20:47,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:20:47,522 - INFO - Epoch 32: Avg. Loss: 9.83503e-02 Remaining Time: 00h 26min 19s


2025-03-03 13:21:10,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:21:10,501 - INFO - Epoch 33: Avg. Loss: 9.86524e-02 Remaining Time: 00h 25min 56s


2025-03-03 13:21:33,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:21:33,557 - INFO - Epoch 34: Avg. Loss: 9.87503e-02 Remaining Time: 00h 25min 32s


2025-03-03 13:21:56,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:21:56,590 - INFO - Epoch 35: Avg. Loss: 9.62701e-02 Remaining Time: 00h 25min 09s
2025-03-03 13:21:56,599 - LIGHT_DEBUG - Checkpoint saved model to Models/diffusion_v1.pth


2025-03-03 13:22:19,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:22:19,615 - INFO - Epoch 36: Avg. Loss: 9.67543e-02 Remaining Time: 00h 24min 45s


2025-03-03 13:22:42,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:22:42,702 - INFO - Epoch 37: Avg. Loss: 9.56318e-02 Remaining Time: 00h 24min 22s


2025-03-03 13:23:05,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:23:05,794 - INFO - Epoch 38: Avg. Loss: 9.44131e-02 Remaining Time: 00h 23min 58s


2025-03-03 13:23:28,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:23:28,846 - INFO - Epoch 39: Avg. Loss: 9.37745e-02 Remaining Time: 00h 23min 35s


2025-03-03 13:23:51,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:23:51,910 - INFO - Epoch 40: Avg. Loss: 9.44696e-02 Remaining Time: 00h 23min 11s
2025-03-03 13:23:51,919 - LIGHT_DEBUG - Checkpoint saved model to Models/diffusion_v1.pth


2025-03-03 13:24:15,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:24:15,015 - INFO - Epoch 41: Avg. Loss: 9.42144e-02 Remaining Time: 00h 22min 48s


2025-03-03 13:24:38,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:24:38,089 - INFO - Epoch 42: Avg. Loss: 9.40920e-02 Remaining Time: 00h 22min 25s


2025-03-03 13:25:01,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:25:01,099 - INFO - Epoch 43: Avg. Loss: 9.33565e-02 Remaining Time: 00h 22min 01s


2025-03-03 13:25:24,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:25:24,172 - INFO - Epoch 44: Avg. Loss: 9.28516e-02 Remaining Time: 00h 21min 38s


2025-03-03 13:25:47,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:25:47,234 - INFO - Epoch 45: Avg. Loss: 9.17860e-02 Remaining Time: 00h 21min 14s
2025-03-03 13:25:47,244 - LIGHT_DEBUG - Checkpoint saved model to Models/diffusion_v1.pth


2025-03-03 13:26:10,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:26:10,255 - INFO - Epoch 46: Avg. Loss: 9.14034e-02 Remaining Time: 00h 20min 51s


2025-03-03 13:26:33,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:26:33,242 - INFO - Epoch 47: Avg. Loss: 9.07275e-02 Remaining Time: 00h 20min 28s


2025-03-03 13:26:56,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:26:56,334 - INFO - Epoch 48: Avg. Loss: 9.07133e-02 Remaining Time: 00h 20min 04s


2025-03-03 13:27:19,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:27:19,338 - INFO - Epoch 49: Avg. Loss: 9.11190e-02 Remaining Time: 00h 19min 41s


2025-03-03 13:27:42,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:27:42,522 - INFO - Epoch 50: Avg. Loss: 8.99571e-02 Remaining Time: 00h 19min 18s
2025-03-03 13:27:42,532 - LIGHT_DEBUG - Checkpoint saved model to Models/diffusion_v1.pth


2025-03-03 13:28:05,000 - LIGHT_DEBUG - Batch 25/25


2025-03-03 13:28:05,631 - INFO - Epoch 51: Avg. Loss: 9.46466e-02 Remaining Time: 00h 18min 55s


2025-03-03 13:28:16,000 - LIGHT_DEBUG - Batch 12/25