# Train Diffusion Model
### Imports

In [1]:
try: 
    import librosa
except:
    !pip install librosa


#Set Dir 
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.optim as optim

# Utils
import numpy as np
from numpy import ndarray
import logging

# Base Scripts
from Libraries.U_Net import *
from Libraries.Diffusion import *
from Libraries.Utils import *
from MainScripts.Conf import conf

### Config
General

In [None]:
remote_kernel: bool = True

logging_level: int = logging.INFO
model_name: str = "diffusion_v7"
full_model_path: str = path_to_remote_path("{}/{}".format(conf["paths"].model_path, model_name + ".pth"), remote_kernel)
checkpoint_freq: int = 10 #0 for no checkpoint saving
training_data_name: str = "training_full_low_res"

device = "cuda" if torch.cuda.is_available() else "cpu"
restart_training: bool = True #If True and model already exists optimizer and lr_scheduler are reset
learning_rate: float = 5e-4 #Starting lr/first lr for Threshold Scheduler
epochs: int = 100
n_training_samples: int = 4000

logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')
logger: logging.Logger = logging.getLogger(__name__)

### Data Loading

In [3]:
file: ndarray = load_training_data(path_to_remote_path("{}/{}".format(conf["paths"].data_path, training_data_name + ".npy"), remote_kernel))[:n_training_samples, ...]
data_loader = create_dataloader(Audio_Data(file), conf["model"].batch_size)
logger.info(f"Data loaded with shape: {file.shape}")

2025-03-17 18:01:40,506 - INFO - Data loaded with shape: (2000, 224, 416)


### Model Creation
U-Net

In [4]:
u_net = Conv_U_NET(in_channels=1,
                    time_embed_dim=conf["model"].time_embed_dim, 
                    n_starting_filters=conf["model"].n_starting_filters, 
                    n_downsamples=conf["model"].n_downsamples, 
                    activation=nn.SiLU(), 
                    device=device
                ).to(device)

optimizer = optim.AdamW(u_net.parameters(), lr=learning_rate)
scheduler = Threshold_LR(optimizer, [1, 0.1, 0.09, 0.85, 0.08, 0.07], [learning_rate, 2e-4, 1e-4, 1e-5, 1e-6, 1e-7])
start_epoch: int = 0

if os.path.exists(full_model_path):
    model = torch.load(full_model_path, map_location=device)
    u_net.load_state_dict(model["model"])
    if not restart_training:
        optimizer.load_state_dict(model["optim"])
        scheduler.load_state_dict(model["scheduler"])
        start_epoch = model.get("epoch", 0)
    logger.info(f"Model {model_name} loaded with {count_parameters(u_net)} Parameters")
else: 
    logger.info(f"Model {model_name} created with {count_parameters(u_net)} Parameters")

2025-03-17 18:01:42,296 - INFO - Model diffusion_v7 loaded with 72857025 Parameters


Diffusion

In [5]:
diffusion = Diffusion(model=u_net, 
                        noise_steps=conf["model"].diffusion_timesteps, 
                        noise_schedule="linear", 
                        input_dim=[conf["model"].batch_size, 1, file.shape[-2], file.shape[-1]],
                        device=device
                    )

#diffusion.visualize_diffusion_steps(x=torch.Tensor(file[:1]), noise_schedule=noise_schedule, device=device, n_spectograms=5)

### Train

In [None]:
x = diffusion.train(epochs=epochs, 
                    data_loader=data_loader, 
                    loss_function=nn.MSELoss(),
                    optimizer=optimizer, 
                    lr_scheduler=scheduler, 
                    gradient_accum=conf["model"].gradient_accum,
                    checkpoint_freq=checkpoint_freq, 
                    model_path=full_model_path, 
                    start_epoch=start_epoch
                )
scatter_plot(x)

2025-03-17 18:01:42,500 - INFO - Training started on cuda
2025-03-17 18:03:01,340 - INFO - Epoch 011: Avg. Loss: 8.00835e-02 Remaining Time: 02h 10min 04s LR: 1.00000e-04
2025-03-17 18:04:22,276 - INFO - Epoch 012: Avg. Loss: 7.88112e-02 Remaining Time: 02h 10min 28s LR: 5.00000e-05
2025-03-17 18:05:40,788 - INFO - Epoch 013: Avg. Loss: 8.03173e-02 Remaining Time: 02h 08min 24s LR: 1.00000e-04
2025-03-17 18:07:01,559 - INFO - Epoch 014: Avg. Loss: 7.57898e-02 Remaining Time: 02h 07min 37s LR: 5.00000e-05
2025-03-17 18:08:21,792 - INFO - Epoch 015: Avg. Loss: 7.64216e-02 Remaining Time: 02h 06min 26s LR: 5.00000e-05
2025-03-17 18:09:41,813 - INFO - Epoch 016: Avg. Loss: 8.04569e-02 Remaining Time: 02h 05min 09s LR: 1.00000e-04
2025-03-17 18:11:02,520 - INFO - Epoch 017: Avg. Loss: 7.80051e-02 Remaining Time: 02h 04min 00s LR: 5.00000e-05
2025-03-17 18:12:24,995 - INFO - Epoch 018: Avg. Loss: 8.14918e-02 Remaining Time: 02h 03min 08s LR: 1.00000e-04
2025-03-17 18:13:45,324 - INFO - Epoch