# Train Diffusion Model
### Imports

In [None]:
try: 
    import librosa
except:
    !pip install librosa


# Add the parent directory to the module search path
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.optim as optim

# Utils
import numpy as np
from numpy import ndarray
import logging

# Base Scripts
from Libraries.U_Net import *
from Libraries.Diffusion import *
from Libraries.Utils import *
from MainScripts.Conf import conf

No module named 'librosa'
Collecting librosa
  Downloading librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)
Collecting audioread>=2.1.9 (from librosa)
  Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting numba>=0.51.0 (from librosa)
  Downloading numba-0.61.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)
Collecting soundfile>=0.12.1 (from librosa)
  Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)
Collecting pooch>=1.1 (from librosa)
  Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Collecting soxr>=0.3.2 (from librosa)
  Downloading soxr-0.5.0.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Collecting msgpack>=1.0 (from librosa)
  Downloading msgpack-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)
Collecting llvmlite<0.45,>=0.44.0dev0 (from numba>=0.51.0->librosa)
  Downloading llvmlite-0.44.0-cp311-cp311-manylinux_2

### Config
General

In [7]:
# ---- Run environment ----
remote_kernel: bool = True  # forwarded to path_to_remote_path when resolving paths

# ---- Model / checkpoint naming ----
model_name: str = "diffusion_v4"
training_data_name: str = "training_full_low_res"
checkpoint_freq: int = 10  # 0 for no checkpoint saving
logging_level: int = logging.INFO

# Location of the "<model_name>.pth" checkpoint, resolved through the project helper.
full_model_path: str = path_to_remote_path(
    "{}/{}".format(conf["paths"].model_path, model_name + ".pth"),
    remote_kernel,
)

# ---- Training hyper-parameters ----
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# If True and a saved model already exists, the optimizer and lr_scheduler are reset.
restart_training: bool = True
learning_rate: float = 1e-5
total_epochs: int = 300  # total training epochs
epochs: int = 100  # epochs for the current run
n_training_samples: int = 2000

# ---- Logging ----
logging.basicConfig(
    level=logging_level,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger: logging.Logger = logging.getLogger(__name__)

### Data Loading

In [3]:
# Resolve the dataset path, load it, and keep only the first
# n_training_samples entries along the leading axis.
training_data_path = path_to_remote_path(
    "{}/{}".format(conf["paths"].data_path, training_data_name + ".npy"),
    remote_kernel,
)
file: ndarray = load_training_data(training_data_path)[:n_training_samples, ...]

# Wrap the array in the project's dataset class and build the batched loader.
audio_dataset = Audio_Data(file)
data_loader = create_dataloader(audio_dataset, conf["model"].batch_size)

logger.info(f"Data loaded with shape: {file.shape}")

2025-03-13 17:01:48,090 - INFO - Data loaded with shape: (2000, 224, 416)


### Model Creation
U-Net

In [8]:
# Build the U-Net from the shared model configuration and move it to the target device.
u_net = Conv_U_NET(
    in_channels=1,
    time_embed_dim=conf["model"].time_embed_dim,
    n_starting_filters=conf["model"].n_starting_filters,
    n_downsamples=conf["model"].n_downsamples,
    activation=nn.GELU(),
    device=device,
)
u_net = u_net.to(device)

# Fresh optimizer and cosine schedule; both may be overwritten from a checkpoint below.
optimizer = optim.AdamW(u_net.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=total_epochs,
    eta_min=1e-8,
)
start_epoch: int = 0

if not os.path.exists(full_model_path):
    # No checkpoint on disk: train the freshly initialised network.
    logger.info(f"Model {model_name} created with {count_parameters(u_net)} Parameters")
else:
    # NOTE(review): depending on the installed torch version's `weights_only`
    # default, torch.load may unpickle arbitrary objects — only load trusted files.
    model = torch.load(full_model_path, map_location=device)

    # The weights are always restored when a checkpoint exists.
    u_net.load_state_dict(model["model"])

    # Optimizer, scheduler and epoch counter are restored only when resuming;
    # with restart_training=True they keep the fresh values created above.
    resume_training = not restart_training
    if resume_training:
        optimizer.load_state_dict(model["optim"])
        scheduler.load_state_dict(model["scheduler"])
        start_epoch = model.get("epoch", 0)

    logger.info(f"Model {model_name} loaded with {count_parameters(u_net)} Parameters")

2025-03-13 17:21:14,030 - INFO - Model diffusion_v4 loaded with 17288513 Parameters


Diffusion

In [9]:
# Input shape handed to the diffusion wrapper: configured batch size, one channel,
# and the last two dimensions of the loaded training array.
diffusion_input_dim = [
    conf["model"].batch_size,
    1,
    file.shape[-2],
    file.shape[-1],
]

# Wrap the U-Net in the diffusion process with a linear noise schedule.
diffusion = Diffusion(
    model=u_net,
    noise_steps=conf["model"].diffusion_timesteps,
    noise_schedule="linear",
    input_dim=diffusion_input_dim,
    device=device,
)

#diffusion.visualize_diffusion_steps(x=torch.Tensor(file[:1]), noise_schedule=noise_schedule, device=device, n_spectograms=5)

### Train

In [None]:
# Collect the training arguments first so the call itself stays short.
train_kwargs = {
    "epochs": epochs,
    "data_loader": data_loader,
    "loss_function": nn.MSELoss(),
    "optimizer": optimizer,
    "lr_scheduler": scheduler,
    "gradient_accum": conf["model"].gradient_accum,
    "checkpoint_freq": checkpoint_freq,
    "model_path": full_model_path,
    "start_epoch": start_epoch,
}

# Run training for this session and plot whatever diffusion.train returns.
x = diffusion.train(**train_kwargs)
scatter_plot(x)

2025-03-13 17:21:17,002 - INFO - Training started on cuda
2025-03-13 17:22:20,563 - INFO - Epoch 01: Avg. Loss: 1.13831e-01 Remaining Time: 01h 44min 52s LR: 9.99973e-06
2025-03-13 17:23:23,893 - INFO - Epoch 02: Avg. Loss: 1.20976e-01 Remaining Time: 01h 43min 37s LR: 9.99890e-06
2025-03-13 17:24:29,405 - INFO - Epoch 03: Avg. Loss: 1.19720e-01 Remaining Time: 01h 43min 40s LR: 9.99754e-06
2025-03-13 17:25:34,384 - INFO - Epoch 04: Avg. Loss: 1.21057e-01 Remaining Time: 01h 42min 56s LR: 9.99562e-06
2025-03-13 17:26:35,093 - INFO - Epoch 05: Avg. Loss: 1.16676e-01 Remaining Time: 01h 40min 43s LR: 9.99315e-06
2025-03-13 17:27:38,020 - INFO - Epoch 06: Avg. Loss: 1.21376e-01 Remaining Time: 01h 39min 29s LR: 9.99014e-06
2025-03-13 17:28:41,908 - INFO - Epoch 07: Avg. Loss: 1.14651e-01 Remaining Time: 01h 38min 30s LR: 9.98659e-06
2025-03-13 17:29:43,638 - INFO - Epoch 08: Avg. Loss: 1.20707e-01 Remaining Time: 01h 37min 06s LR: 9.98248e-06
