# Training Notebook for Attention UNet
### See train_unet.ipynb for more details (all training notebooks follow a similar structure)

In [None]:
import sys
import os
# When launched from the notebooks directory, step up to the project root first
# so every relative path used below (data/, ckpts/) resolves from there.
if "notebook" in os.getcwd():
    os.chdir("../")

# Register the project root as an *absolute* path so `src.*` is importable.
# A relative entry such as "../" is resolved against whatever the cwd is when
# an import happens, so after the chdir above it pointed one level too high.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import traceback
import gc
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split
from src.data.waveform_data import WaveformDataset
from skopt import gp_minimize
from skopt.space import Real, Integer, Categorical
from skopt.utils import use_named_args
from src.models.waveform.cicada_unet_att import CicadaUNetAttModel

# Seed *before* building the model so its weight initialisation is reproducible.
torch.manual_seed(42)
model = CicadaUNetAttModel()

NOISY_DATA_PATH = "data/processed/28spk/combined_noisy_waves.pt"
CLEAN_DATA_PATH = "data/processed/28spk/combined_clean_waves.pt"

batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"
# Re-seed after model construction so the random data split below does not
# depend on how many random numbers the model's initialisation consumed.
torch.manual_seed(42)  # Consistent results

HYPER_OPT = False  # Enable for Bayesian hyperparameter sweeps
# Passed as WaveformDataset(fraction=f); sweeps use a smaller subset and fewer epochs.
f = 0.25 if HYPER_OPT else 1
num_epochs = 10 if HYPER_OPT else 30
# Presumably the tuned learning rate from an earlier sweep — TODO confirm.
lr = 1e-4 if HYPER_OPT else 0.0002848026275422517

print(num_epochs, lr)


In [None]:
data = WaveformDataset(NOISY_DATA_PATH, CLEAN_DATA_PATH, fraction=f)

# 80% train / 15% val; the test split takes the remainder so no sample is dropped
# by the int() truncation.
train_size, val_size = (int(share * len(data)) for share in (0.8, 0.15))
test_size = len(data) - train_size - val_size

train_set, val_set, test_set = random_split(data, [train_size, val_size, test_size])
print(f"Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")

# Only the training loader shuffles; evaluation order stays fixed.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=batch_size, shuffle=is_train)
    for subset, is_train in ((train_set, True), (val_set, False), (test_set, False))
)


In [None]:
# Pull a single training batch to sanity-check tensor shapes before training.
noisy_batch, clean_batch = next(iter(train_loader))
print(*(batch.shape for batch in (noisy_batch, clean_batch)))  # Expected: [32, 1, 262144]

In [None]:

def count_parameters(model):
    """Print and return the parameter counts of ``model``.

    Args:
        model: any ``torch.nn.Module``.

    Returns:
        tuple[int, int]: ``(total_params, trainable_params)``, where trainable
        parameters are those with ``requires_grad`` set.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")

    return total_params, trainable_params
# Left as a bare expression on purpose: as the cell's last line, Jupyter also
# displays the returned (total, trainable) tuple.
count_parameters(model)



In [None]:
# Per-epoch histories appended by train(): training loss, validation loss,
# and mean validation SNR improvement (dB).
t_losses, v_losses, snrs = [], [], []


def compute_snr(clean, estimate):
    """Return the signal-to-noise ratio of ``estimate`` against ``clean`` in dB.

    Signal power is the sum of squares of ``clean``; noise power is the sum of
    squares of the residual ``clean - estimate``. Both sums run over every
    element, so a whole batch is scored as one long signal. A zero residual
    yields ``inf`` (``nan`` if ``clean`` is also all zeros).

    Returns:
        float: the SNR as a plain Python number.
    """
    signal_power = clean.pow(2).sum()
    residual_power = (clean - estimate).pow(2).sum()
    return (10 * torch.log10(signal_power / residual_power)).item()
    
def train(model, train_loader, val_loader, num_epochs=num_epochs, lr=1e-4, n_encoders=None, s=None, k=None, 
          num_heads=None, hidden_channels=None, d=None, device=device):
    """Train an Attention-UNet denoiser with MSE loss and track validation SNR gain.

    If any architecture hyperparameter (``n_encoders``, ``s``, ``k``,
    ``num_heads``, ``hidden_channels``, ``d``) is given, a fresh
    ``CicadaUNetAttModel`` is built from the non-``None`` ones (used by the
    hyperparameter sweep). Otherwise ``model`` itself is trained in place.

    Each epoch appends the train loss, validation loss and mean SNR improvement
    (output SNR minus noisy-input SNR, dB) to the module-level ``t_losses``,
    ``v_losses`` and ``snrs`` lists. Outside sweeps (``HYPER_OPT`` false), a
    checkpoint is written to ``ckpts/`` whenever the SNR improvement is a new best.

    Args:
        model: module to train when no architecture hyperparameters are given.
        train_loader: yields ``(noisy, clean)`` tensor pairs for training.
        val_loader: yields ``(noisy, clean)`` tensor pairs for validation.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        n_encoders, s, k, num_heads, hidden_channels, d: optional
            ``CicadaUNetAttModel`` constructor arguments.
        device: device that the model and batches are moved to.

    Returns:
        The final epoch's validation loss when ``HYPER_OPT`` is set (the value
        the sweep minimises), otherwise the trained model.
    """
    arch_kwargs = {
        key: value
        for key, value in (
            ("n_encoders", n_encoders),
            ("s", s),
            ("k", k),
            ("num_heads", num_heads),
            ("hidden_channels", hidden_channels),
            ("d", d),
        )
        if value is not None
    }
    # Build a new model only when architecture hyperparameters were supplied.
    # Previously the caller's model was always discarded, and a new one was built
    # with every argument explicitly None. That only works if the constructor
    # treats None as "use the default".
    if arch_kwargs or model is None:
        model = CicadaUNetAttModel(**arch_kwargs)
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Start at -inf so the best epoch is checkpointed even when every epoch's
    # SNR improvement is negative (a 0 threshold silently skipped saving).
    best_snr = float("-inf")
    if not HYPER_OPT:
        os.makedirs("ckpts", exist_ok=True)

    for epoch in range(num_epochs):

        # Training
        print("Training ... ")
        model.train()
        train_loss = 0.0
        for noisy, clean in tqdm(train_loader):
            noisy, clean = noisy.to(device), clean.to(device)

            optimizer.zero_grad()
            outputs = model(noisy)
            sample_loss = criterion(outputs, clean)
            sample_loss.backward()
            optimizer.step()

            # MSELoss returns a batch mean; weight by batch size for a per-sample average.
            train_loss += sample_loss.item() * noisy.size(0)

        train_loss /= len(train_loader.dataset)

        # Validation
        print("Evaluating ... ")
        model.eval()
        val_loss = 0.0
        total_snr_improvement = 0
        with torch.no_grad():  # No gradient computation
            for noisy, clean in tqdm(val_loader):
                noisy, clean = noisy.to(device), clean.to(device)

                outputs = model(noisy)
                loss = criterion(outputs, clean)
                val_loss += loss.item() * noisy.size(0)

                # Compute SNR improvement of the model output over the raw noisy input
                snr_noisy = compute_snr(clean, noisy)
                snr_output = compute_snr(clean, outputs)
                snr_improvement = snr_output - snr_noisy
                total_snr_improvement += snr_improvement * noisy.size(0)

        val_loss /= len(val_loader.dataset)  # Average loss
        avg_snr_improvement = total_snr_improvement / len(val_loader.dataset)
        if avg_snr_improvement > best_snr:
            best_snr = avg_snr_improvement
            if not HYPER_OPT:
                # NOTE: `epoch` is 0-based in the file name; the log line below prints epoch+1.
                torch.save(model.state_dict(), f"ckpts/cicadence_unet_att_epoch_{epoch}.pt")

        t_losses.append(train_loss)
        v_losses.append(val_loss)
        snrs.append(avg_snr_improvement)
        print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | SNR Improvement: {avg_snr_improvement:.2f} dB")

    print("Training Complete!")
    if HYPER_OPT:
        return val_loss
    return model

In [None]:
def objective(params):
    """Objective for ``gp_minimize``: decode a parameter vector and train on it.

    ``params`` is positional and follows the order of ``space``:
    ``[n_encoders, s_start, kernel_size, learning_rate, num_heads,
    hidden_channels, dropout, s_step1, s_step2, ...]``.

    The stride list ``s`` starts with 1 and gets one entry per encoder. Each
    step tries ``2 * width - offset`` and falls back to plain doubling whenever
    that candidate would not exceed the current width.

    Returns:
        Whatever ``train`` returns (the final validation loss during sweeps).
    """
    n_encoders, width, kernel, lr, num_heads, hidden_channels, d = params[:7]
    offsets = params[7:]

    s = [1]
    for i in range(n_encoders):
        s.append(width)
        candidate = 2 * width - offsets[i]
        width = candidate if candidate > width else width * 2

    # The same kernel size is used for every encoder.
    k = [kernel for _ in range(n_encoders)]
    return train(
        model, train_loader, val_loader,
        num_epochs=num_epochs, lr=lr,
        n_encoders=n_encoders, s=s, k=k,
        num_heads=num_heads, hidden_channels=hidden_channels, d=d,
    )
    
if HYPER_OPT:
    # objective() reads params positionally, so this order must match its unpacking.
    # objective() consumes exactly one stride offset per encoder, so the number of
    # s_step dimensions is tied to the largest n_encoders value. Extra dimensions
    # could never influence the loss and only enlarge the search space.
    MAX_ENCODERS = 8
    space = [
        Integer(5, MAX_ENCODERS, name='n_encoders'),
        Integer(2, 16, name='s_start'),    # First stride after the leading 1
        Categorical([3, 5, 7], name='kernel_size'),
        Real(1e-5, 1e-3, prior='log-uniform', name='learning_rate'),
        Integer(2, 8, name='num_heads'),
        Real(0.3, 0.7, prior='uniform', name='hidden_channels'),
        Real(0.1, 0.4, prior='uniform', name='dropout'),  # passed to the model as `d`
        *[Integer(1, 10, name=f's_step{i}') for i in range(1, MAX_ENCODERS + 1)],
    ]
        

In [None]:
try:
    if HYPER_OPT:
        results = gp_minimize(objective, space, n_calls=25, random_state=42)
        print("Best hyperparameters:")
        for dim, val in zip(space, results.x):
            print(f"{dim.name}: {val}")
        print(f"Best validation loss: {results.fun}")
    else:
        print("ACTUALLY TRAINING")
        # Keep the returned model: train() may build a new instance, and
        # discarding its result would lose the trained weights for later cells.
        model = train(model, train_loader, val_loader, num_epochs=num_epochs, lr=lr, device=device)
except Exception as e:
    # Notebook top level: report the failure (message + traceback) without
    # killing the kernel, so the recorded loss histories remain inspectable.
    print(e)
    traceback.print_exc()
finally:
    gc.collect()
    if torch.cuda.is_available():
        # Hand cached GPU blocks back to the driver after a run or a crash.
        torch.cuda.empty_cache()

In [None]:
# Dump the per-epoch histories recorded by train(): train loss, val loss, SNR gain (dB).
for history in (t_losses, v_losses, snrs):
    print(history)