# Model creation

In [None]:
# Import statements
import os
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
from tqdm import tqdm
import matplotlib.pyplot as plt

# Select the compute device: prefer CUDA, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Training hyper-parameters.
BATCH_SIZE = 8
EPOCHS = 80
LEARNING_RATE = 1e-4

# Audio front-end settings.
SAMPLE_RATE = 16000
N_MELS = 128

# Locations of the paired training audio and of the saved model weights.
CLEAR_DIR = "../../data/train/clean"
DEGRADED_DIR = "../../data/train/degraded"
MODEL_SAVE = "../../model/UNet_audio_restoration.pth"

# Dataset
This section defines the dataset that pairs each clean recording with its degraded counterpart.

In [None]:
class AudioPairDataset(Dataset):
    """Dataset of (degraded, clean) dB-scaled mel spectrogram pairs.

    File names are listed from ``clean_dir``; the degraded version of each
    recording is read from ``degraded_dir`` under the same file name.

    Args:
        clean_dir: Directory containing the clean recordings.
        degraded_dir: Directory containing the degraded recordings.
        sample_rate: Rate (Hz) the mel transform is configured for. Files
            stored at a different rate are resampled to it on load.
        n_mels: Number of mel bands.
        n_frames: Number of time frames every spectrogram is padded or
            cropped to.
    """

    def __init__(self, clean_dir, degraded_dir, sample_rate=16000, n_mels=128, n_frames=256):
        self.clean_dir = clean_dir
        self.degraded_dir = degraded_dir
        # Sorted so that the index -> file mapping is deterministic.
        self.filenames = sorted(os.listdir(clean_dir))
        self.sample_rate = sample_rate
        self.n_frames = n_frames
        self.n_mels = n_mels

        # Transforms
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=1024,
            hop_length=512,
            n_mels=n_mels
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def __len__(self):
        """Return the number of files found in ``clean_dir``."""
        return len(self.filenames)

    @staticmethod
    def _fit_frames(spec, n_frames):
        """Right-pad with zeros or crop ``spec`` so its last axis is ``n_frames``.

        NOTE(review): the padding value is 0.0, which on a dB scale is not
        the silence floor - confirm this is the intended fill value.
        """
        missing = n_frames - spec.shape[-1]
        if missing > 0:
            return torch.nn.functional.pad(spec, (0, missing))
        return spec[..., :n_frames]

    def _load_mono(self, path):
        """Load ``path`` as a single-channel waveform at ``self.sample_rate``."""
        waveform, file_rate = torchaudio.load(path)
        # Down-mix to mono by averaging the channels.
        waveform = waveform.mean(dim=0, keepdim=True)
        # The mel transform is built for ``self.sample_rate``; feeding it audio
        # recorded at another rate would silently shift every mel band.
        if file_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, file_rate, self.sample_rate
            )
        return waveform

    def _to_db_mel(self, waveform):
        """Convert a waveform to a dB-scaled mel spectrogram of fixed length."""
        mel_db = self.amplitude_to_db(self.to_mel(waveform))
        return self._fit_frames(mel_db, self.n_frames)

    def __getitem__(self, idx):
        """Return ``(degraded_mel, clean_mel)`` for the file at ``idx``.

        Both tensors have shape ``(1, n_mels, n_frames)``.
        """
        fname = self.filenames[idx]

        # File paths: same name in both directories.
        clean_path = os.path.join(self.clean_dir, fname)
        degraded_path = os.path.join(self.degraded_dir, fname)

        clean_mel = self._to_db_mel(self._load_mono(clean_path))
        degraded_mel = self._to_db_mel(self._load_mono(degraded_path))

        return degraded_mel, clean_mel

# Model

This section defines the U-Net model.

In [None]:
class UNetBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch normalisation and ReLU.

    A padding of 1 keeps the spatial size unchanged; only the channel count
    goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage changes the channel count, the second keeps it.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv / norm / activation stack."""
        return self.block(x)
        

class UNet(nn.Module):
    """Three-level U-Net with skip connections.

    The encoder halves the spatial size three times and the decoder doubles
    it back, concatenating the matching encoder features at every level.
    Because of the three poolings the network internally needs a height and
    width that are multiples of 8; inputs of any other size are zero-padded
    on the bottom/right and the output is cropped back, so the output always
    has the same spatial size as the input.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
    """

    # Total down-sampling factor of the encoder (three 2x max-poolings).
    _SIZE_MULTIPLE = 8

    def __init__(self, in_channels = 1, out_channels = 1):
        super().__init__()

        # Encoder
        self.enc1 = UNetBlock(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = UNetBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = UNetBlock(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = UNetBlock(256, 512)

        # Decoder: each dec block receives the up-sampled features
        # concatenated with the skip connection, hence twice the channels.
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = UNetBlock(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = UNetBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = UNetBlock(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Map ``x`` of shape ``(N, in_channels, H, W)`` to ``(N, out_channels, H, W)``."""
        height, width = x.shape[-2:]
        # Amount needed to reach the next multiple of 8 (0 if already aligned).
        pad_h = -height % self._SIZE_MULTIPLE
        pad_w = -width % self._SIZE_MULTIPLE
        needs_padding = bool(pad_h or pad_w)
        if needs_padding:
            # Without this, odd intermediate sizes make the up-sampled tensors
            # smaller than the skip connections and torch.cat fails.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))

        # Bottleneck
        b = self.bottleneck(self.pool3(e3))

        # Decoder with skip connections
        d3 = self.up3(b)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        # Output
        out = self.final_conv(d1)
        if needs_padding:
            # Drop the rows/columns that only exist because of the padding.
            out = out[..., :height, :width]
        return out


# Training
This section trains the model and saves its weights.

In [None]:
# Build the paired dataset and a shuffling loader over it.
train_dataset = AudioPairDataset(
    clean_dir = CLEAR_DIR, 
    degraded_dir = DEGRADED_DIR,
    sample_rate = SAMPLE_RATE,
    n_mels = N_MELS
)

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)

model = UNet(in_channels = 1, out_channels = 1).to(DEVICE)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# Create the checkpoint directory up front: discovering that it is missing
# only at torch.save() time would throw away the whole training run.
os.makedirs(os.path.dirname(MODEL_SAVE) or ".", exist_ok = True)

# Mean L1 loss of every epoch, in order.
loss_history = []
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    loop = tqdm(train_loader, desc = f"Epoch {epoch + 1}/{EPOCHS}", leave = True)

    for degraded, clean in loop:
        degraded, clean = degraded.to(DEVICE), clean.to(DEVICE)

        # The model predicts the clean spectrogram from the degraded one.
        output = model(degraded)
        loss = criterion(output, clean)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        loop.set_postfix(loss = loss.item())

    # Average over the number of batches in the epoch.
    avg_loss = epoch_loss / len(train_loader)
    loss_history.append(avg_loss)


# Only the weights are stored; the UNet class is needed to load them back.
torch.save(model.state_dict(), MODEL_SAVE)
    

# Plot

In [None]:
# Plot the mean training loss of every completed epoch.
plt.figure(figsize=(8, 5))
# Derive the x-axis from the recorded history instead of EPOCHS so the plot
# still works when training stopped early and fewer epochs were recorded.
plt.plot(range(1, len(loss_history) + 1), loss_history, marker='o')
plt.title("Training Loss per Epoch")
plt.xlabel("Epoch")
plt.ylabel("L1 Loss")
plt.grid(True)
plt.tight_layout()
plt.show()