# Train Diffusion Model
### Imports

In [1]:
#Set Dir 
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
import torch.optim as optim
# Utils
import numpy as np
from numpy import ndarray
import logging

# Base Scripts
from Libraries.U_Net import *
from Libraries.Diffusion import *
from Libraries.Utils import *
from Conf import *

### Config
General

In [None]:
# --- Logging --------------------------------------------------------------
logging_level: int = logging.INFO

# --- Run identity and file locations ---------------------------------------
# Model weights are stored as <MODEL_PATH>/<model_name>.pth.
model_name: str = "diffusion_v1"
model_path: str = f"{MODEL_PATH}/{model_name}.pth"
# Checkpoint saving frequency (presumably in epochs -- confirm in the training
# loop); a value of 0 disables checkpoint saving.
checkpoint_freq: int = 5
# Base name of the .npy training file under DATA_PATH.
training_data_name: str = "training_1280"

# --- Hardware -------------------------------------------------------------
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# --- Training schedule ----------------------------------------------------
batch_size: int = 32
epochs: int = 100
diffusion_timesteps: int = 500

# Configure the root logger once, then grab this module's logger.
logging.basicConfig(
    level=logging_level,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger: logging.Logger = logging.getLogger(__name__)

U-Net

In [None]:
# --- Optimiser / learning-rate schedule (consumed by AdamW and StepLR) ------
learning_rate: float = 1.0e-5  # initial learning rate handed to AdamW
lr_decay: int = 40             # StepLR step_size: scheduler.step() calls between decays
lr_gamma: float = 0.1          # StepLR gamma: factor the LR is multiplied by at each decay

# --- U-Net architecture ----------------------------------------------------
n_starting_filters: int = 32   # channel count handed to U_NET as n_starting_filters
n_blocks: int = 2              # passed as n_res_layers; per the original note, each downsamples by 2
n_groups: int = 8              # group count for group normalisation inside the U-Net

### Data Loading

In [None]:
# Load the training set from <DATA_PATH>/<training_data_name>.npy via the project
# helper. Later cells use `file.shape` and index its last two axes, so this is
# presumably an array-like (e.g. numpy ndarray) -- confirm in load_training_data.
file = load_training_data(f"{DATA_PATH}/{training_data_name}.npy")

In [None]:
# Wrap the loaded data in the project's Audio_Data dataset and build a loader
# that yields `batch_size` samples per batch (shuffling/worker settings, if any,
# live inside create_dataloader -- not visible here).
data_loader = create_dataloader(Audio_Data(file), batch_size)
logger.info(f"Data loaded with shape: {file.shape}")

### Model Creation
U-Net

In [None]:

# Build the U-Net. The first two input_shape entries are 0 placeholders
# (presumably batch/channel dims U_NET does not need up front -- confirm in
# U_NET); the last two come from the trailing axes of the training data.
model = U_NET(
    in_channels=1,
    device=device,
    input_shape=[0, 0, file.shape[-2], file.shape[-1]],
    n_res_layers=n_blocks,
    n_starting_filters=n_starting_filters,
    n_groups=n_groups,
).to(device)

if os.path.exists(model_path):
    # Resume from a saved state_dict. weights_only=True restricts unpickling to
    # tensors and plain containers, so a tampered .pth file cannot execute
    # arbitrary code (weights_only=False would allow that). A state_dict needs
    # nothing more, since the result goes straight into load_state_dict.
    model.load_state_dict(torch.load(model_path, weights_only=True, map_location=device))
    logger.info(f"Model {model_name} loaded with {count_parameters(model)} Parameters")
else:
    logger.info(f"Model {model_name} created with {count_parameters(model)} Parameters")

# NOTE(review): this cell restores only the model weights. AdamW moments and the
# StepLR position start fresh on every run, so a resumed run restarts the LR
# schedule unless a later cell restores optimizer/scheduler state.
weight_decay: float = 0.05  # AdamW decoupled weight-decay coefficient
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Multiply the learning rate by lr_gamma every lr_decay calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay, gamma=lr_gamma)

Diffusion