# Model creation

In [34]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
from torchaudio.transforms import Spectrogram, AmplitudeToDB
from torch import hann_window
from tqdm import tqdm
import matplotlib.pyplot as plt

# Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU.
DEVICE = ("cuda" if torch.cuda.is_available()
          else "mps" if torch.backends.mps.is_available()
          else "cpu")

if DEVICE == 'cuda':
    # `torch.cuda.amp.autocast` / `torch.cuda.amp.GradScaler` are deprecated in
    # favour of the device-agnostic `torch.amp` API. Bind the CUDA device here
    # so call sites can keep writing `autocast()` / `GradScaler()`; fall back to
    # the legacy classes on PyTorch releases whose `torch.amp` has no GradScaler.
    if hasattr(torch.amp, 'GradScaler'):
        from functools import partial
        autocast   = partial(torch.amp.autocast, 'cuda')
        GradScaler = partial(torch.amp.GradScaler, 'cuda')
    else:
        from torch.cuda.amp import autocast, GradScaler
else:
    # Mixed precision is only used on CUDA. These stand-ins keep AMP-style call
    # sites valid: `autocast()` is a do-nothing context manager and the scaler
    # passes the loss through unchanged and steps the optimizer directly.
    from contextlib import nullcontext
    autocast = nullcontext

    class GradScaler:
        """No-op stand-in for torch's GradScaler on non-CUDA devices."""
        def scale(self, loss): return loss
        def step(self, optimizer): optimizer.step()
        def update(self): pass
        # Any other scaler method (e.g. `unscale_`) becomes a no-op returning None.
        def __getattr__(self, name): return lambda *args, **kwargs: None

BATCH_SIZE     = 8
EPOCHS         = 1
LEARNING_RATE  = 5e-5
SAMPLE_RATE    = 16000   # Hz
N_FFT          = 1024    # FFT size -> N_FFT // 2 + 1 frequency bins per frame
HOP_LENGTH     = 32      # samples between STFT frames (2 ms at 16 kHz)
WIN_LENGTH     = N_FFT
N_FRAMES       = 256     # frames kept per example: 256 * 32 = 8192 samples, ~0.5 s at 16 kHz
CLEAR_DIR      = '../../data/train/clean'
DEGRADED_DIR   = '../../data/train/degraded'
MODEL_SAVE     = '../../model/UNet_audio_restoration.pth'
CHECKPOINT_DIR = '../../test'

print(f'Device use: {DEVICE}')

Device use: mps


# Dataset

AudioPairDataset is a class extending `torch.utils.data.Dataset` to handle pairs of “clean” and “degraded” audio files. It loads the signals, converts them to decibel-scale spectrograms, normalizes them, and ensures a fixed length via padding or cropping, returning a tuple ready for denoising model training.

---

## Constructor Parameters

- **clean_dir**  
  Directory containing clean audio files.  
- **degraded_dir**  
  Directory containing degraded audio files.  
- **sample_rate** (default: 16000)  
  Expected sampling rate of the audio files (each file's actual rate is returned by `torchaudio.load`).  
- **n_fft** (default: 1024)  
  FFT window size used for spectrogram computation.  
- **hop_length** (default: 128)  
  Hop length (in samples) between successive windows.  
- **n_frames** (default: 256)  
  Desired number of time frames per spectrogram, enforced via padding or cropping.

### Internal Components

- **File list**: stores sorted filenames from `clean_dir`.  
- **Spectrogram**: instance computing a power spectrogram from raw signal.  
- **AmplitudeToDB**: converts power spectrogram values to decibel scale.

---

#### `__len__`

Returns the total number of examples (i.e., the number of files in the “clean” directory), enabling PyTorch DataLoader support for batching.

---

## Padding or Cropping Helper

- Accepts a 3-D spectrogram tensor of shape `[channels, frequency_bins, time_frames]`.  
- If `time_frames` < `n_frames`, pads with zeros at the end to reach the desired length.  
- If `time_frames` ≥ `n_frames`, crops to keep only the first `n_frames`.

---

#### `__getitem__`

1. Selects the filename at the given index.  
2. Loads clean and degraded audio with `torchaudio.load`.  
3. Converts to mono by averaging channels.  
4. Computes power spectrogram and converts to decibel scale.  
5. Applies z-score normalization (subtract mean, divide by standard deviation plus a small epsilon).  
6. Pads or crops each spectrogram to ensure `n_frames`.  
7. Returns a tuple `(degraded_spectrogram, clean_spectrogram)`.

---

In [35]:
class AudioPairDataset(Dataset):
    """Pairs of (degraded, clean) normalized dB spectrograms for denoising.

    Files are matched by name: every visible regular file in ``clean_dir`` is
    expected to have a counterpart with the same filename in ``degraded_dir``.

    Args:
        clean_dir: directory with the clean reference recordings.
        degraded_dir: directory with the degraded recordings (same filenames).
        sample_rate: target sampling rate; audio stored at a different rate is
            resampled to it before the STFT so frequency bins line up.
        n_fft: FFT size, also used as the window length.
        hop_length: hop between successive STFT frames, in samples.
        n_frames: fixed number of time frames per spectrogram (pad or crop).
    """

    def __init__(self, clean_dir, degraded_dir,
                 sample_rate=16000, n_fft=1024,
                 hop_length=128, n_frames=256):
        self.clean_dir    = clean_dir
        self.degraded_dir = degraded_dir
        self.sample_rate  = sample_rate
        self.files        = self._list_audio_files(clean_dir)
        self.n_frames     = n_frames
        # Power spectrogram; no window_fn is passed, so torchaudio's default window is used.
        self.spec  = Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=n_fft,
            power=2.0
        )
        self.to_db = AmplitudeToDB(stype='power')

    @staticmethod
    def _list_audio_files(directory):
        """Return the sorted names of visible regular files in ``directory``.

        Hidden entries (e.g. macOS ``.DS_Store``) and sub-directories are
        skipped because they are not decodable audio files.
        """
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.files)

    def pad_or_crop(self, S):
        """Zero-pad at the end, or crop, the time axis of ``S`` to ``n_frames``.

        Args:
            S: spectrogram of shape ``[C, F, T]``.

        Returns:
            Tensor of shape ``[C, F, n_frames]``.
        """
        _, _, time = S.shape
        if time < self.n_frames:
            return F.pad(S, (0, self.n_frames - time))
        return S[:, :, :self.n_frames]

    def _load_spec(self, path):
        """Load ``path`` and return its z-scored dB spectrogram ``[1, F, n_frames]``."""
        wave, sr = torchaudio.load(path)
        wave = wave.mean(0, keepdim=True)   # down-mix to mono: [1, samples]
        if sr != self.sample_rate:
            # Without this, files at another rate would map to different frequency bins.
            wave = torchaudio.functional.resample(wave, sr, self.sample_rate)
        S = self.to_db(self.spec(wave))
        S = (S - S.mean()) / (S.std() + 1e-6)   # per-example z-score
        return self.pad_or_crop(S)

    def __getitem__(self, idx):
        """Return ``(degraded_spectrogram, clean_spectrogram)`` for example ``idx``."""
        fn = self.files[idx]
        S_c = self._load_spec(os.path.join(self.clean_dir, fn))
        S_d = self._load_spec(os.path.join(self.degraded_dir, fn))
        return S_d, S_c

# Model

The UNet model is a fully convolutional neural network designed for image-to-image tasks. It consists of symmetric encoder and decoder paths with skip connections, allowing precise localization and context integration.

---

## Building Block: UNetBlock

Each UNetBlock comprises two convolutional layers with the following sequence:
- 3×3 convolution (padding=1)  
- Instance normalization (affine)  
- ReLU activation (in-place)  
- 3×3 convolution (padding=1)  
- Instance normalization (affine)  
- ReLU activation (in-place)

This block preserves spatial dimensions and refines feature representations.

---

## Architecture Components

### Encoder Path

1. **enc1**: UNetBlock mapping `in_channels → 64`  
2. **pool**: 2×2 max pooling  
3. **enc2**: UNetBlock mapping `64 → 128`  
4. **pool**: 2×2 max pooling  
5. **enc3**: UNetBlock mapping `128 → 256`  

Each encoder stage reduces spatial resolution by half after the block via max pooling.

### Bottleneck

- **bottleneck**: UNetBlock mapping `256 → 512`  
- Applied after the third pooling to capture the most abstract features.

### Decoder Path

1. **up3**: Transposed convolution `512 → 256` upsampling by factor 2  
2. **dec3**: UNetBlock on concatenated `[upsampled, cropped enc3]` (`512 → 256`)  
3. **up2**: Transposed convolution `256 → 128` upsampling by factor 2  
4. **dec2**: UNetBlock on concatenated `[upsampled, cropped enc2]` (`256 → 128`)  
5. **up1**: Transposed convolution `128 → 64` upsampling by factor 2  
6. **dec1**: UNetBlock on concatenated `[upsampled, cropped enc1]` (`128 → 64`)  

Skip connections fuse high-resolution encoder features with decoder upsampled maps for precise reconstruction.

### Final Convolution

- **final_conv**: 1×1 convolution mapping `64 → out_channels` to produce the output feature map.

---

## Center Crop Utility

A helper function crops encoder feature maps to match decoder spatial dimensions before concatenation:
- Computes offsets `(dh, dw)` from shape differences  
- Returns centrally cropped tensor of target height and width

---

## Forward Pass Summary

1. **Encoding**  
   - e1 = enc1(x)  
   - e2 = enc2(pool(e1))  
   - e3 = enc3(pool(e2))  
2. **Bottleneck**  
   - b = bottleneck(pool(e3))  
3. **Decoding with Skip Connections**  
   - d3 = up3(b) → crop e3 to d3 size → dec3(cat(d3, e3_cropped))  
   - d2 = up2(d3) → crop e2 to d2 size → dec2(cat(d2, e2_cropped))  
   - d1 = up1(d2) → crop e1 to d1 size → dec1(cat(d1, e1_cropped))  
4. **Output**  
   - out = final_conv(d1)

---

## Additional Notes

- Instance normalization normalizes each example independently, so its behavior does not depend on the batch size (it works even with batches of size one).  
- Skip connections mitigate information loss and support fine detail reconstruction.  
- The model is flexible in input/output channels and can be adapted to various image dimensions.  

In [36]:
class UNetBlock(nn.Module):
    """Two stacked (3x3 conv -> affine InstanceNorm -> ReLU) stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count goes
    from ``in_ch`` to ``out_ch`` (the first conv does the channel change, the
    second maps ``out_ch`` to ``out_ch``).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.InstanceNorm2d(out_ch, affine=True),
                nn.ReLU(inplace=True),
            ]
        # Kept as a single Sequential named `block` so state_dict keys stay stable.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class UNet(nn.Module):
    """Three-level 2-D U-Net mapping ``in_ch`` input maps to ``out_ch`` outputs.

    Encoder widths are 64 -> 128 -> 256 with a 512-channel bottleneck. Each
    decoder stage upsamples 2x with a transposed convolution and fuses the
    matching encoder features (centre-cropped) through a skip connection.
    Max-pooling floors odd sizes, so when a spatial dimension is not divisible
    by 8 the output is smaller than the input.
    """

    def __init__(self, in_ch=1, out_ch=1):
        super().__init__()
        # Attribute names (and creation order) define the state_dict layout.
        self.enc1 = UNetBlock(in_ch, 64)
        self.enc2 = UNetBlock(64, 128)
        self.enc3 = UNetBlock(128, 256)
        self.bottleneck = UNetBlock(256, 512)
        self.pool = nn.MaxPool2d(2)
        # Each up-conv halves the channels; its decoder block sees [upsampled, skip].
        self.up3, self.dec3 = nn.ConvTranspose2d(512, 256, 2, 2), UNetBlock(512, 256)
        self.up2, self.dec2 = nn.ConvTranspose2d(256, 128, 2, 2), UNetBlock(256, 128)
        self.up1, self.dec1 = nn.ConvTranspose2d(128, 64, 2, 2), UNetBlock(128, 64)
        self.final_conv = nn.Conv2d(64, out_ch, 1)

    def center_crop(self, src, tgt):
        """Crop ``src`` (N, C, H, W) around its centre to ``tgt``'s height and width."""
        _, _, src_h, src_w = src.shape
        _, _, tgt_h, tgt_w = tgt.shape
        top = (src_h - tgt_h) // 2
        left = (src_w - tgt_w) // 2
        return src[:, :, top:top + tgt_h, left:left + tgt_w]

    def forward(self, x):
        # Encoder: remember each level's features before pooling them.
        skips = []
        feats = x
        for encoder in (self.enc1, self.enc2, self.enc3):
            feats = encoder(feats)
            skips.append(feats)
            feats = self.pool(feats)
        feats = self.bottleneck(feats)

        # Decoder: consume the skips deepest-first.
        stages = ((self.up3, self.dec3), (self.up2, self.dec2), (self.up1, self.dec1))
        for (upsample, decoder), skip in zip(stages, reversed(skips)):
            feats = upsample(feats)
            fused = torch.cat([feats, self.center_crop(skip, feats)], dim=1)
            feats = decoder(fused)
        return self.final_conv(feats)

# Training

- **Checkpoint Directory**  
  Creates `CHECKPOINT_DIR` if it doesn’t exist and initializes `best_val` to track the lowest validation loss.

- **Dataset and DataLoader**  
  - Instantiates `AudioPairDataset` with parameters for sample rate, FFT size, hop length and number of frames.  
  - Splits the full dataset into 80 % training and 20 % validation.  
  - Wraps each split in a `DataLoader` with a specified `BATCH_SIZE`, shuffling only the training loader.

- **Model, Loss, Optimizer, Scheduler**  
  - Builds a `UNet` model and moves it to the chosen `DEVICE`.  
  - Optimizes an L1 loss plus 0.1 × spectral convergence between the predicted and clean spectrograms.  
  - Chooses the Adam optimizer with a specified learning rate.  
  - Applies a `ReduceLROnPlateau` scheduler (`factor=0.5`, `patience=3`) that halves the LR once the validation loss has not improved for more than 3 consecutive epochs.

---

## Training Loop

For each epoch from 1 to `EPOCHS`:

1. **Training Phase**  
   - Sets the model to train mode.  
   - Iterates over the training loader, moving inputs and targets to `DEVICE`.  
   - Performs a forward pass, cropping the clean target if its shape exceeds the model output.  
   - Computes loss, backpropagates gradients, and updates model parameters.  
   - Accumulates running loss and records the epoch’s average training loss.

2. **Validation Phase**  
   - Switches the model to evaluation mode and disables gradient computation.  
   - Iterates over the validation loader, repeating the forward pass and any necessary cropping.  
   - Accumulates validation loss and records the average validation loss.  
   - Steps the LR scheduler with the current `avg_val` loss.

3. **Checkpointing**  
   - **Periodic Checkpoints**: Every 5 epochs, saves a full checkpoint (model & optimizer states, epoch, and validation loss) to `CHECKPOINT_DIR`.  
   - **Best Model**: If the current `avg_val` is lower than `best_val`, updates `best_val`, and saves the model weights as `best.pth`.

4. **Logging**  
   - Prints epoch summary including training and validation losses.  
   - Prints notifications when checkpoints or a new best model are saved.

---

## Final Save

After all epochs, saves the final model weights to `MODEL_SAVE` and prints a confirmation message.  

In [37]:
def spectral_convergence(x, y):
    """Spectral convergence ``||y - x||_F / ||y||_F``.

    Args:
        x: estimated spectrogram.
        y: reference spectrogram (same shape as ``x``).

    Returns:
        0-d tensor: Frobenius norm of the error relative to that of the reference.
    """
    error_norm = (y - x).norm(p='fro')
    reference_norm = y.norm(p='fro')
    return error_norm / reference_norm

# --- Dataset and data loaders ------------------------------------------------
SPLIT_SEED   = 42    # fixed so the train/validation split is identical across runs
VAL_FRACTION = 0.2
SC_WEIGHT    = 0.1   # weight of the spectral-convergence term in the loss

# torch.save does not create parent directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(os.path.dirname(MODEL_SAVE) or '.', exist_ok=True)

ds = AudioPairDataset(CLEAR_DIR, DEGRADED_DIR,
                      SAMPLE_RATE, N_FFT, HOP_LENGTH, N_FRAMES)
n_val = int(VAL_FRACTION * len(ds))
if n_val == 0:
    # Otherwise the validation average below divides by zero after a full epoch.
    raise ValueError(
        f"Dataset has {len(ds)} example(s): too few for a non-empty "
        f"{VAL_FRACTION:.0%} validation split."
    )
train_ds, val_ds = random_split(
    ds, [len(ds) - n_val, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Pinned host memory only benefits host -> CUDA copies.
PIN_MEMORY = (DEVICE == 'cuda')
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE,
    shuffle=True,  num_workers=0, pin_memory=PIN_MEMORY
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE,
    shuffle=False, num_workers=0, pin_memory=PIN_MEMORY
)

# --- Model, optimizer, scheduler, AMP ------------------------------------------
model     = UNet().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the LR once validation loss stops improving (patience=3). `verbose` is
# deprecated for LR schedulers, so LR reductions are printed explicitly below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)
use_amp = (DEVICE == 'cuda')
scaler  = GradScaler() if use_amp else None


def compute_loss(out, target):
    """L1 + SC_WEIGHT * spectral convergence between prediction and target.

    The model output can be smaller than the target along the last two axes
    (pooling floors odd sizes), so the target is cropped to match first.
    Both tensors are z-scored dB spectrograms: they are already log-scaled and
    routinely below -1, so they are compared directly; wrapping them in
    ``torch.log1p`` would yield NaN for every value below -1.
    """
    if out.shape != target.shape:
        target = target[..., :out.size(2), :out.size(3)]
    l1 = F.l1_loss(out, target)
    sc = spectral_convergence(out, target)
    return l1 + SC_WEIGHT * sc


train_losses, val_losses = [], []
best_val = float('inf')

# --- Training loop ---------------------------------------------------------------
for epoch in range(1, EPOCHS+1):
    model.train()
    run_loss = 0.0

    for deg, clean in tqdm(train_loader, desc=f"Train Ep{epoch}/{EPOCHS} on {DEVICE}"):
        # deg, clean: [B, 1, F, T]
        deg = deg.to(DEVICE)
        cl  = clean.to(DEVICE)
        optimizer.zero_grad()

        if use_amp:
            with autocast():
                out = model(deg)
                loss = compute_loss(out, cl)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(deg)
            loss = compute_loss(out, cl)
            loss.backward()
            optimizer.step()

        run_loss += loss.item()

    avg_train = run_loss / len(train_loader)
    train_losses.append(avg_train)

    # Validation
    model.eval()
    val_run = 0.0
    with torch.no_grad():
        for deg, clean in val_loader:
            out = model(deg.to(DEVICE))
            val_run += compute_loss(out, clean.to(DEVICE)).item()

    avg_val = val_run / len(val_loader)
    val_losses.append(avg_val)

    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(avg_val)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after < lr_before:
        print(f"Epoch {epoch}: reducing learning rate to {lr_after:.2e}")

    # Checkpoints
    if epoch % 5 == 0:
        ckpt_path = os.path.join(CHECKPOINT_DIR, f'checkpoint_ep{epoch}.pth')
        torch.save({'epoch': epoch,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'scheduler_state': scheduler.state_dict(),
                    'val_loss': avg_val},
                   ckpt_path)
        print(f"Checkpoint saved to {ckpt_path}")
    if avg_val < best_val:
        best_val = avg_val
        torch.save(model.state_dict(),
                   os.path.join(CHECKPOINT_DIR, 'best.pth'))
        print(f"New best model (val loss {best_val:.4f})")

    print(f"Epoch {epoch}: Train Loss = {avg_train:.4f}, Val Loss = {avg_val:.4f}")

# Save the final model
torch.save(model.state_dict(), MODEL_SAVE)
print(f"✅ Model saved to {MODEL_SAVE}")

Train Ep1/1 on mps:   6%|▋         | 50/800 [03:56<59:00,  4.72s/it]  


KeyboardInterrupt: 

# Plot

This snippet generates a line chart to visualize the model’s performance over epochs:

1. **Create a new figure**  
   Initializes a fresh plotting canvas.

2. **Plot training loss**  
   - Uses `train_losses` list  
   - Labels the curve as “Train Loss”

3. **Plot validation loss**  
   - Uses `val_losses` list  
   - Labels the curve as “Val Loss”

4. **Label axes**  
   - X-axis: “Epoch”  
   - Y-axis: “L1 Loss”

5. **Add legend**  
   Displays labels for both curves in the plot area.

6. **Set title**  
   Titles the chart “Training & Validation Loss” for context.

7. **Render the plot**  
   Calls `plt.show()` to display the figure in the output.

In [None]:
# Epochs are 1-based; each curve gets its own x range because an interrupted
# run can leave one more training entry than validation entries.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, marker='o', label='Train Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, marker='o', label='Val Loss')
ax.set_xlabel('Epoch')
# The optimized objective is L1 + 0.1 * spectral convergence, not plain L1.
ax.set_ylabel('Loss (L1 + 0.1 × spectral convergence)')
ax.legend()
ax.set_title('Training & Validation Loss')
plt.show()