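# Train Diffusion Model
Trains a U-Net based diffusion model on pre-computed training data, resuming from an existing checkpoint when one is found.
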
### Imports

In [9]:
# Put the parent of the current working directory on the import path.
# Presumably this is the project root, so that the project-local imports
# further down in this cell (Libraries.*, Conf) resolve - confirm if the
# notebook is ever moved to a different folder depth.
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.optim as optim
# Utils
import numpy as np
from numpy import ndarray
import logging

# Base Scripts
from Libraries.U_Net import *
from Libraries.Diffusion import *
from Libraries.Utils import *
from Conf import *

### Config
General

In [10]:
# --- Run identity and persistence ---
model_name: str = "diffusion_v1"
# Checkpoint file: checked for / loaded in the model creation cell and handed
# to the training call.
model_path: str = f"{MODEL_PATH}/{model_name}.pth"
training_data_name: str = "training_1280"
checkpoint_freq: int = 5  # 0 for no checkpoint saving

# --- Hardware ---
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"

# --- Training loop ---
epochs: int = 100
batch_size: int = 2
diffusion_timesteps: int = 500

# --- Logging ---
logging_level: int = LIGHT_DEBUG  # alternative: logging.INFO
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging_level)
logger: logging.Logger = logging.getLogger(__name__)

U-Net

In [11]:
# --- Optimiser / learning-rate schedule ---
learning_rate: float = 1e-5  # initial learning rate handed to AdamW
lr_gamma: float = 0.1        # `gamma` of the StepLR scheduler
lr_decay: int = 40           # `step_size` of the StepLR scheduler

# --- Network architecture ---
n_blocks: int = 2            # each block samples down by a factor of 2
n_starting_filters: int = 32
n_groups: int = 8            # group count for group norm
time_embed_dim: int = 128    # passed to U_NET as `time_emb_dim`

### Data Loading

In [12]:
# Number of leading samples (first axis) kept for training.
# The stored array is much larger than what is used here, so this acts as a
# quick-run limit; set it to None to train on every sample in the file.
n_training_samples = 4
file = load_training_data(f"{DATA_PATH}/{training_data_name}.npy")[:n_training_samples, ...]

2025-02-28 22:27:52,266 - LIGHT_DEBUG - Ndarray loaded from ../Data/training_1280.npy of shape: (1280, 1024, 672)


In [13]:
# Wrap the loaded array in the project's Audio_Data dataset and batch it.
data_loader = create_dataloader(
    Audio_Data(file),
    batch_size,
)
logger.info(f"Data loaded with shape: {file.shape}")

2025-02-28 22:27:52,458 - INFO - Data loaded with shape: (4, 1024, 672)


### Model Creation
U-Net

In [14]:

# Build the U-Net for single-channel input. The last two entries of
# `input_shape` are taken from the last two axes of the training array; the
# two leading zeros are kept exactly as in the original call.
u_net = U_NET(
    in_channels=1,
    device=device,
    input_shape=[0, 0, file.shape[-2], file.shape[-1]],
    n_res_layers=n_blocks,
    n_starting_filters=n_starting_filters,
    n_groups=n_groups,
    time_emb_dim=time_embed_dim,
)
u_net = u_net.to(device)

# Resume from an existing checkpoint when there is one; otherwise start fresh.
if not os.path.exists(model_path):
    logger.info(f"Model {model_name} created with {count_parameters(u_net)} Parameters")
else:
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only point model_path at checkpoints you trust. If the file
    # holds a plain state dict, weights_only=True should be sufficient - confirm.
    u_net.load_state_dict(
        torch.load(model_path, map_location=device, weights_only=False)
    )
    logger.info(f"Model {model_name} loaded with {count_parameters(u_net)} Parameters")

# The optimizer is created from the unwrapped model's parameters before the
# DataParallel wrap below.
optimizer = optim.AdamW(u_net.parameters(), weight_decay=0.05, lr=learning_rate)
# NOTE(review): `scheduler` is not referenced again in the cells visible here
# (it is not passed to Diffusion(...) or diffusion.train(...)) - confirm that
# it is actually stepped somewhere, otherwise the learning rate never decays.
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=lr_gamma, step_size=lr_decay)

# Wrap for (multi-)GPU execution and make sure the wrapper sits on `device`.
u_net = nn.DataParallel(u_net).to(device)

2025-02-28 22:27:52,894 - INFO - Model diffusion_v1 loaded with 833281 Parameters


Diffusion

In [15]:
# Bundle the network and its optimizer into the project's Diffusion wrapper.
diffusion = Diffusion(
    u_net=u_net,
    u_net_optimizer=optimizer,
    diffusion_timesteps=diffusion_timesteps,
)
# Noise schedule built for the same number of timesteps as the wrapper above.
noise_schedule = cosine_noise(T=diffusion_timesteps)

### Train

In [None]:
# Start training; `checkpoint_freq` and `model_path` are forwarded so the
# training routine can write checkpoints.
diffusion.train(
    data_loader,
    device,
    epochs=epochs,
    loss_function=U_Net_loss,
    noise_schedule=noise_schedule,
    checkpoint_freq=checkpoint_freq,
    model_path=model_path,
)

2025-02-28 22:29:15,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:30:14,547 - INFO - Epoch 01: Avg. Loss: 5.88422e-01 Remaining Time: 03h 53min 31s


2025-02-28 22:30:14,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:31:07,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:32:05,563 - INFO - Epoch 02: Avg. Loss: 5.78866e-01 Remaining Time: 03h 26min 14s


2025-02-28 22:32:05,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:33:02,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:34:09,107 - INFO - Epoch 03: Avg. Loss: 5.80952e-01 Remaining Time: 03h 22min 39s


2025-02-28 22:34:09,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:35:08,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:36:04,272 - INFO - Epoch 04: Avg. Loss: 5.70159e-01 Remaining Time: 03h 16min 29s


2025-02-28 22:36:04,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:36:59,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:38:00,050 - INFO - Epoch 05: Avg. Loss: 5.59024e-01 Remaining Time: 03h 12min 13s
2025-02-28 22:38:00,162 - LIGHT_DEBUG - Checkpoint saved model to ../Models/diffusion_v1.pth


2025-02-28 22:38:00,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:38:53,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:39:49,353 - INFO - Epoch 06: Avg. Loss: 5.58179e-01 Remaining Time: 03h 07min 00s


2025-02-28 22:39:49,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:40:44,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:41:44,408 - INFO - Epoch 07: Avg. Loss: 5.45393e-01 Remaining Time: 03h 04min 03s


2025-02-28 22:41:44,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:42:37,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:43:45,164 - INFO - Epoch 08: Avg. Loss: 5.46728e-01 Remaining Time: 03h 02min 27s


2025-02-28 22:43:45,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:44:50,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:45:53,022 - INFO - Epoch 09: Avg. Loss: 5.43807e-01 Remaining Time: 03h 01min 58s


2025-02-28 22:45:53,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:46:51,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:47:58,718 - INFO - Epoch 10: Avg. Loss: 5.41212e-01 Remaining Time: 03h 00min 49s
2025-02-28 22:47:58,749 - LIGHT_DEBUG - Checkpoint saved model to ../Models/diffusion_v1.pth


2025-02-28 22:47:58,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:48:50,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:49:49,879 - INFO - Epoch 11: Avg. Loss: 5.29968e-01 Remaining Time: 02h 57min 33s


2025-02-28 22:49:49,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:50:50,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:51:42,822 - INFO - Epoch 12: Avg. Loss: 5.31602e-01 Remaining Time: 02h 54min 43s


2025-02-28 22:51:42,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:52:42,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:53:34,561 - INFO - Epoch 13: Avg. Loss: 5.26413e-01 Remaining Time: 02h 51min 55s


2025-02-28 22:53:34,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:54:26,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:55:25,964 - INFO - Epoch 14: Avg. Loss: 5.25041e-01 Remaining Time: 02h 49min 12s


2025-02-28 22:55:25,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:56:29,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:57:28,793 - INFO - Epoch 15: Avg. Loss: 5.21282e-01 Remaining Time: 02h 47min 41s
2025-02-28 22:57:28,861 - LIGHT_DEBUG - Checkpoint saved model to ../Models/diffusion_v1.pth


2025-02-28 22:57:28,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 22:58:30,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 22:59:39,202 - INFO - Epoch 16: Avg. Loss: 5.20425e-01 Remaining Time: 02h 46min 46s


2025-02-28 22:59:39,000 - LIGHT_DEBUG - Batch 02/02
2025-02-28 23:00:37,000 - LIGHT_DEBUG - Batch 01/02

2025-02-28 23:01:29,372 - INFO - Epoch 17: Avg. Loss: 5.14987e-01 Remaining Time: 02h 44min 03s


2025-02-28 23:01:29,000 - LIGHT_DEBUG - Batch 02/02
