# Train Diffusion Model
### Imports

In [1]:
#!pip install librosa

# Add the parent directory to sys.path so the project-local packages imported below can be found
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.optim as optim
#from torchsummary import summary
# Utils
import numpy as np
from numpy import ndarray
import logging

# Base Scripts
from Libraries.U_Net import *
from Libraries.Diffusion import *
from Libraries.Utils import *
from MainScripts.Conf import *

Collecting librosa
  Downloading librosa-0.10.2.post1-py3-none-any.whl.metadata (8.6 kB)
Collecting audioread>=2.1.9 (from librosa)
  Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting numba>=0.51.0 (from librosa)
  Downloading numba-0.61.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)
Collecting soundfile>=0.12.1 (from librosa)
  Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)
Collecting pooch>=1.1 (from librosa)
  Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Collecting soxr>=0.3.2 (from librosa)
  Downloading soxr-0.5.0.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Collecting msgpack>=1.0 (from librosa)
  Downloading msgpack-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)
Collecting llvmlite<0.45,>=0.44.0dev0 (from numba>=0.51.0->librosa)
  Downloading llvmlite-0.44.0-cp311-cp311-manylinux_2_17_x86_64.manylinux

### Config
General

In [2]:
# ---- Run identity and storage -------------------------------------------
# `remote_kernel` is forwarded to path_to_remote_path whenever a project path is resolved.
remote_kernel: bool = True
model_name: str = "diffusion_v4"
training_data_name: str = "training_full_low_res"
# Checkpoint location derived from the project's MODEL_PATH and the model name.
model_path: str = path_to_remote_path(f"{MODEL_PATH}/{model_name}.pth", remote_kernel)
checkpoint_freq: int = 10  # 0 for no checkpoint saving

# ---- Hardware -----------------------------------------------------------
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---- Training schedule --------------------------------------------------
epochs: int = 100
batch_size: int = 16
diffusion_timesteps: int = 500
# Only referenced by the disabled slice in the data-loading cell.
n_training_samples: int = 2000

# ---- Logging ------------------------------------------------------------
logging_level: int = logging.INFO
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging_level)
logger: logging.Logger = logging.getLogger(__name__)

U-Net

In [3]:
# Architecture settings handed to Conv_U_NET when the model is built.
n_starting_filters: int = 32
n_downsamples: int = 3
time_embed_dim: int = 128

# Learning rate handed to the AdamW optimizer.
learning_rate: float = 1e-4


### Data Loading

In [5]:
# Load the full training array from "{DATA_PATH}/{training_data_name}.npy" (resolved through
# path_to_remote_path). The trailing commented-out slice would limit the array to the first
# `n_training_samples` entries; it is currently disabled, so the whole file is used.
# NOTE(review): the last recorded run of this cell failed with
# "ValueError: cannot reshape array of size 633602016 into shape (7087,224,416)".
# 633602016 is not divisible by 224*416, so the file on disk and the shape expected by the
# loader disagree (possibly a truncated or different file) — confirm before training.
file: ndarray = load_training_data(path_to_remote_path(f"{DATA_PATH}/{training_data_name}.npy", remote_kernel))#[:n_training_samples, ...]
# Display the loaded array's shape as the cell output.
file.shape

ValueError: cannot reshape array of size 633602016 into shape (7087,224,416)

In [None]:
# Wrap the loaded array in the project's Audio_Data dataset and batch it with `batch_size`.
data_loader = create_dataloader(
    Audio_Data(file),
    batch_size,
)
# Record the shape of the array that feeds the loader.
logger.info(f"Data loaded with shape: {file.shape}")

2025-03-07 09:08:10,935 - INFO - Data loaded with shape: (7087, 224, 416)


### Model Creation
U-Net

In [None]:
# Build the U-Net and place it on the target device *before* the optimizer is created, so
# that optimizer state restored from a checkpoint below ends up on the same device as the
# parameters it belongs to.
u_net = Conv_U_NET(in_channels=1, time_embed_dim=time_embed_dim, n_starting_filters=n_starting_filters, n_downsamples=n_downsamples, activation=nn.GELU(), device=device)
u_net = u_net.to(device)

optimizer = optim.AdamW(u_net.parameters(), lr=learning_rate)
# NOTE(review): T_max is fixed at 200 and is not derived from `epochs` — confirm this is intended.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200, eta_min=1e-6)

# Epoch to resume from; stays 0 for a freshly created model.
start_epoch: int = 0
if os.path.exists(model_path):
    # Resume from an existing checkpoint. `map_location` is an argument of torch.load (it
    # remaps the stored tensors onto `device`); nn.Module.load_state_dict does not accept it.
    # `model` holds the checkpoint dict with the keys "model", "optimizer", "scheduler" and,
    # optionally, "epoch".
    model = torch.load(model_path, map_location=device)
    u_net.load_state_dict(model["model"])
    optimizer.load_state_dict(model["optimizer"])
    scheduler.load_state_dict(model["scheduler"])
    start_epoch = model.get("epoch", 0)
    logger.info(f"Model {model_name} loaded with {count_parameters(u_net)} Parameters")
else:
    logger.info(f"Model {model_name} created with {count_parameters(u_net)} Parameters")

Diffusion

In [None]:
# Wrap the U-Net in the project's Diffusion helper, using a linear noise schedule with
# `diffusion_timesteps` noise steps.
diffusion = Diffusion(
    model=u_net,
    noise_steps=diffusion_timesteps,
    noise_schedule="linear",
    # Last two entries are taken from the trailing two axes of the loaded training array.
    # NOTE(review): the leading 8 is hard-coded and differs from batch_size (16) — confirm its meaning.
    input_dim=[8, 1, file.shape[-2], file.shape[-1]],
    device=device,
)

# Disabled visualisation call, kept for reference. It refers to `noise_schedule`, which is not
# defined in the visible cells, so it would need adjusting before being re-enabled.
#diffusion.visualize_diffusion_steps(x=torch.Tensor(file[:1]), noise_schedule=noise_schedule, device=device, n_spectograms=5)

### Train

In [None]:
# Run training from `start_epoch` with an MSE loss, the AdamW optimizer and the cosine LR
# scheduler created above; `model_path` and `checkpoint_freq` are passed through for checkpointing.
x = diffusion.train(
    epochs=epochs,
    data_loader=data_loader,
    loss_function=nn.MSELoss(),
    optimizer=optimizer,
    lr_scheduler=scheduler,
    gradient_accum=2,  # presumably the number of batches accumulated per optimizer step — TODO confirm
    checkpoint_freq=checkpoint_freq,
    model_path=model_path,
    start_epoch=start_epoch,
)
# Plot whatever train() returned (presumably the per-epoch losses — verify in Libraries.Diffusion).
scatter_plot(x)

2025-03-07 09:08:23,303 - INFO - Training started on cuda
2025-03-07 09:18:45,050 - INFO - Epoch 01: Avg. Loss: 1.09582e-01 Remaining Time: 17h 05min 52s LR: 9.99993e-05
2025-03-07 09:29:04,600 - INFO - Epoch 02: Avg. Loss: 9.44737e-02 Remaining Time: 16h 53min 42s LR: 9.99973e-05
2025-03-07 09:39:25,461 - INFO - Epoch 03: Avg. Loss: 9.40548e-02 Remaining Time: 16h 43min 28s LR: 9.99939e-05
2025-03-07 09:49:46,412 - INFO - Epoch 04: Avg. Loss: 9.36952e-02 Remaining Time: 16h 33min 13s LR: 9.99891e-05
2025-03-07 10:00:02,536 - INFO - Epoch 05: Avg. Loss: 9.08883e-02 Remaining Time: 16h 21min 24s LR: 9.99830e-05
2025-03-07 10:10:21,395 - INFO - Epoch 06: Avg. Loss: 9.02535e-02 Remaining Time: 16h 10min 48s LR: 9.99756e-05
2025-03-07 10:20:39,608 - INFO - Epoch 07: Avg. Loss: 8.88492e-02 Remaining Time: 16h 00min 09s LR: 9.99668e-05
2025-03-07 10:30:59,135 - INFO - Epoch 08: Avg. Loss: 8.65391e-02 Remaining Time: 15h 49min 50s LR: 9.99566e-05
2025-03-07 10:41:22,499 - INFO - Epoch 09: Avg

: 