In [1]:
#Cicadence
import glob
import os

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from models.cicada_base import CicadaBaseAutoencoder
from models.cicada_custom import CicadaCustomAutoencoder
from spectro_data import generateSpectrogramFiles, SpectrogramDataset

NOISY_DATA_PATH = "data/processed/28spk/noisy_specs.pt"
CLEAN_DATA_PATH = "data/processed/28spk/clean_specs.pt"

batch_size = 32
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG library the notebook imports so dataset splitting, loader
# shuffling and weight initialisation are reproducible after a kernel restart.
SEED = 42
torch.manual_seed(SEED)
# Last statement returns None, so the cell no longer echoes a Generator repr.
np.random.seed(SEED)


<torch._C.Generator at 0x1130df050>

In [2]:
# Build the cached spectrogram tensors only when they are missing. A fresh
# checkout can then Restart & Run All, while later runs skip the slow
# preprocessing step.
for raw_dir, spec_path in (
    ("data/raw/28spk/clean_train/", CLEAN_DATA_PATH),
    ("data/raw/28spk/noisy_train/", NOISY_DATA_PATH),
):
    if not os.path.exists(spec_path):
        generateSpectrogramFiles(raw_dir, spec_path)

In [3]:
data = SpectrogramDataset(NOISY_DATA_PATH, CLEAN_DATA_PATH)

# 80 / 15 / remainder split; the test split absorbs rounding so every sample is used.
train_size = int(0.8 * len(data))
val_size = int(0.15 * len(data))
test_size = len(data) - train_size - val_size

# A dedicated, explicitly seeded generator (same value as torch.manual_seed in
# the config cell) makes the split independent of global-RNG consumption.
# Without it, re-running this cell draws a different permutation and samples
# migrate between train/val/test.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set, test_set = random_split(
    data, [train_size, val_size, test_size], generator=split_generator
)
assert len(train_set) + len(val_set) + len(test_set) == len(data)
print(f"Train: {len(train_set)}, Val: {len(val_set)}, Test: {len(test_set)}")

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)


Train: 9257, Val: 1735, Test: 580


In [4]:
# Sanity-check one training batch. MSE training requires the noisy input and
# the clean target to have identical shapes.
noisy_batch, clean_batch = next(iter(train_loader))
assert noisy_batch.shape == clean_batch.shape, "noisy/clean batch shapes differ"
assert noisy_batch.shape[0] <= batch_size
# Observed: [32, 1, 256, 290] — presumably (batch, channel, freq, time); confirm in SpectrogramDataset.
print(noisy_batch.shape, clean_batch.shape)

torch.Size([32, 1, 256, 290]) torch.Size([32, 1, 256, 290])


In [5]:
# Architecture under test; CicadaBaseAutoencoder() is the baseline alternative.
model = CicadaCustomAutoencoder()

# Training hyperparameters (read as defaults by the `train` signature).
num_epochs, lr = 2, 1e-4


In [6]:
def compute_snr(clean, estimate, eps=1e-10):
    """Return the signal-to-noise ratio of ``estimate`` against ``clean`` in dB.

    The residual ``clean - estimate`` is treated as noise. Power is summed over
    every element of the tensors, so a batched input yields a single SNR for
    the whole batch rather than per-sample values.

    Args:
        clean: reference tensor.
        estimate: tensor compared against ``clean`` (broadcasting applies),
            e.g. a model output or the raw noisy input.
        eps: floor applied to both signal and noise power. A perfect estimate
            or an all-zero reference then gives a finite value instead of
            ``inf``/``nan`` from a zero division or ``log10(0)``.

    Returns:
        The SNR as a Python float.
    """
    signal_power = torch.sum(clean ** 2).clamp_min(eps)
    noise_power = torch.sum((clean - estimate) ** 2).clamp_min(eps)
    return (10 * torch.log10(signal_power / noise_power)).item()

In [7]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the per-sample mean loss."""
    model.train()
    running_loss = 0.0
    for noisy, clean in tqdm(loader):
        noisy, clean = noisy.to(device), clean.to(device)

        optimizer.zero_grad()
        outputs = model(noisy)
        loss = criterion(outputs, clean)
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; re-weight by batch size so the
        # final average is per sample even when the last batch is smaller.
        running_loss += loss.item() * noisy.size(0)
    return running_loss / len(loader.dataset)


def _evaluate(model, loader, criterion, device):
    """Return ``(mean loss, mean SNR improvement in dB)`` over ``loader`` without gradients."""
    model.eval()
    total_loss = 0.0
    total_snr_improvement = 0.0
    with torch.no_grad():
        for noisy, clean in tqdm(loader):
            noisy, clean = noisy.to(device), clean.to(device)

            outputs = model(noisy)
            loss = criterion(outputs, clean)
            total_loss += loss.item() * noisy.size(0)

            # Compute SNR gain of the model output over the untouched noisy
            # input, both measured against the clean target. compute_snr gives
            # one value per batch, so weight it by batch size when averaging.
            snr_improvement = compute_snr(clean, outputs) - compute_snr(clean, noisy)
            total_snr_improvement += snr_improvement * noisy.size(0)

    n_samples = len(loader.dataset)
    return total_loss / n_samples, total_snr_improvement / n_samples


def train(model, train_loader, val_loader, num_epochs=num_epochs, learning_rate=lr, device=device):
    """Fit ``model`` to map noisy spectrograms to clean ones, reporting validation metrics.

    Trains with Adam on an MSE loss. After every epoch the model is evaluated
    on ``val_loader``. The mean training loss, the validation loss and the SNR
    improvement of the output over the noisy input are then printed.

    Args:
        model: module mapping a noisy batch to an estimate of the clean batch.
        train_loader: yields ``(noisy, clean)`` batches used for optimisation.
        val_loader: yields ``(noisy, clean)`` batches used for evaluation.
        num_epochs: number of passes over ``train_loader``. The default is
            captured from the notebook global when this cell is executed.
        learning_rate: Adam step size (default captured from global ``lr``).
        device: device string the model and batches are moved to.

    Returns:
        The same ``model`` object, trained in place.
    """
    model.to(device)
    criterion = nn.MSELoss()
    # Use the argument rather than the global `lr`, so callers can actually tune it.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        print("Training ... ")
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)

        print("Evaluating ... ")
        val_loss, avg_snr_improvement = _evaluate(model, val_loader, criterion, device)

        print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | SNR Improvement: {avg_snr_improvement:.2f} dB")

    print("Training Complete!")

    return model

In [8]:
# `train` updates `model` in place and returns it. Binding the result keeps the
# trained model explicit and stops the cell from echoing the module repr.
model = train(model, train_loader, val_loader, num_epochs=num_epochs, learning_rate=lr, device=device)

Training ... 


  2%|▏         | 7/290 [00:42<28:21,  6.01s/it]


KeyboardInterrupt: 