In [None]:
import copy
import json
import numpy as np
import matplotlib.pyplot as plt

from utils import ConfigObject
from utils import reserve_pop
from utils import id_generator
from utils import writer
from utils import LibriSpeechGenerator

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
import torch.nn.functional as F

from parts import VSConvBlock
from parts import DownSamplingBlock
from parts import UpSamplingBlock
from parts import OutBlock

In [None]:
import time
import os

# Work from the parent of the notebook directory so the relative paths used
# later (e.g. "train.pt", checkpoint/history files) presumably resolve there.
# The target is resolved only on the first execution, so re-running this cell
# does not keep walking up the directory tree.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
print(os.getcwd())
os.chdir(PROJECT_ROOT)

## Parameters

In [None]:
jsonConfig = {
    "test_platform": False,
    "ds_prop": 0.25,    # Dataset proportion
    "sr": 16000,        # Sampling rate (Hz)
    # "n_samples": 65536, # 4.096 seconds
    "n_samples": 16000, # 1 second at 16 kHz

    "n_channels": 1,    # Number of input channels (mono)
    "n_classes": 1,     # Number of output channels (mono)
    "depth": 5,         # Depth (number of encoder/decoder layers)
    "fsize": 24,        # Base filter size for convolution kernels
    "moffset": 8,       # Channel offset for calculating layer dimensions

    "batch_size": 32,
    "epochs": 25,
    "shuffle": True,
    "num_workers": 8,   # Number of DataLoader worker processes
    "verbose": 100,     # Logging frequency (print every N batches)

    "checkpoint_path": "ae_checkpoint.pt",  # best-by-validation checkpoint
    "model_path": "ae_last_model.pt",       # state after the last epoch run

    "save_last_batch": True,
    "history_path": "ae.json"  # Path for saving training history
}

# Convert the jsonConfig dictionary to a configuration object
# for easier access to parameters in the code
config = ConfigObject(**jsonConfig)

# Seed the NumPy and torch RNGs (weight initialisation, DataLoader shuffling)
# so that Restart & Run All reproduces the same run. Kept outside jsonConfig so
# the keyword arguments passed to ConfigObject are unchanged.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# DataLoader keyword arguments shared by the training and validation loaders
_params = dict(
    batch_size=config.batch_size,
    shuffle=config.shuffle,
    num_workers=config.num_workers,
)

In [5]:
class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call the instance once per epoch with the validation loss. After
    ``patience`` consecutive calls without an improvement, ``early_stop``
    is set to True (and stays True).

    Args:
        patience: number of consecutive non-improving calls tolerated.
        verbose: if True, print the counter on every non-improving call.
        min_delta: a call counts as an improvement only when the loss is
            lower than ``best_loss - min_delta``. The default 0.0 keeps a
            plain strict ``val_loss < best_loss`` comparison.
    """

    def __init__(self, patience=10, verbose=False, min_delta=0.0):
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.counter = 0          # consecutive calls without improvement
        self.best_loss = None     # lowest loss seen so far (None before the first call)
        self.early_stop = False   # becomes True once counter reaches patience

    def __call__(self, val_loss):
        # The first observation always becomes the baseline.
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

## Import Data

In [6]:
# Load the pre-split training/validation tensors from the current working directory.
# NOTE(review): torch.load unpickles the file; only load files from a trusted source.
X_train, X_val = (torch.load(path) for path in ("train.pt", "val.pt"))

In [None]:
# Sanity-check the loaded tensor shapes (displayed as a tuple)
(X_train.shape, X_val.shape)

## Data Generators

In [8]:
# Wrap each split in an autoencoder-mode dataset, then batch both with the
# shared loader settings in `_params`.
lsg = LibriSpeechGenerator(config, X_train, mode="ae")
lsg_val = LibriSpeechGenerator(config, X_val, mode="ae")

ls_generator, ls_val_generator = (
    data.DataLoader(dataset, **_params) for dataset in (lsg, lsg_val)
)

## Model

In [None]:
class SEWUNet(nn.Module):
    """Speech Enhancement using Wave-U-Net.

    A 1-D U-Net made of:
      * ``depth`` ``DownSamplingBlock`` encoder stages, each returning the
        downsampled tensor and a skip tensor;
      * a ``VSConvBlock`` bottleneck;
      * ``depth`` ``UpSamplingBlock`` decoder stages, fed with the encoder
        skips in reverse order;
      * an ``OutBlock`` (tanh activation) that also receives the raw input.

    Args:
        config: object exposing ``n_channels``, ``n_classes``, ``depth``,
            ``fsize`` and ``moffset``.
        fd: kernel size of the encoder and bottleneck convolutions.
        fu: kernel size of the decoder convolutions.
    """

    def __init__(self, config, fd=15, fu=5):
        super(SEWUNet, self).__init__()

        # Hyperparameters
        self.n_channels = config.n_channels  # Number of input channels
        self.n_classes = config.n_classes    # Number of output channels
        self.depth = config.depth            # Number of encoder/decoder layers
        self.fsize = config.fsize            # Channel growth per encoder level
        self.moffset = config.moffset        # Constant channel offset per level
        self.fd = fd                         # Downsampling convolution kernel size
        self.fu = fu                         # Upsampling convolution kernel size

        # Encoder channels: [n_channels, fsize*1 + moffset, ..., fsize*depth + moffset];
        # consecutive pairs are each encoder block's (in_ch, out_ch).
        self.enc_filters = [self.n_channels]
        self.enc_filters += [self.fsize * i + self.moffset
                             for i in range(1, self.depth + 1)]
        # Stored as a list: a bare ``zip`` is a one-shot iterator that would be
        # left empty once the construction loop below has consumed it.
        self.n_encoder = list(zip(self.enc_filters, self.enc_filters[1:]))

        # Bottleneck block sizes (mid_in equals the last encoder's out channels)
        mid_in = self.fsize * self.depth + self.moffset
        mid_out = self.fsize * (self.depth + 1) + self.moffset

        # Decoder channels. Per the arithmetic below, each decoder's input width
        # is the previous stage's output width plus the matching skip's width
        # (presumably UpSamplingBlock concatenates them).
        self.out_dec = reserve_pop(self.enc_filters)
        self.in_dec = [mid_out + self.enc_filters[-1]]
        self.in_dec += [self.out_dec[i] + self.out_dec[i + 1]
                        for i in range(self.depth - 1)]
        self.n_decoder = list(zip(self.in_dec, self.out_dec))

        # Architecture and parameters
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # Build the encoder part of the U-net architecture
        for in_ch, out_ch in self.n_encoder:
            self.encoder.append(DownSamplingBlock(
                in_ch=in_ch,
                out_ch=out_ch,
                kernel_size=self.fd,
                padding=self.fd // 2,
                activation=nn.LeakyReLU(0.1))
            )

        # Bottleneck block for the U-net
        self.mid_block = VSConvBlock(
            in_ch=mid_in,
            out_ch=mid_out,
            kernel_size=self.fd,
            padding=self.fd // 2,
            activation=nn.LeakyReLU(0.1))

        # Build the decoder part of the U-net architecture
        for in_ch, out_ch in self.n_decoder:
            self.decoder.append(UpSamplingBlock(
                in_ch=in_ch,
                out_ch=out_ch,
                kernel_size=self.fu,
                padding=self.fu // 2,
                activation=nn.LeakyReLU(0.1),
                mode="linear")  # linear interpolation
            )

        # Output block input: last decoder channels plus the raw input's channels
        # (forward hands the original input to OutBlock). Using n_channels rather
        # than a literal 1 keeps this correct for multi-channel input.
        out_ch = self.out_dec[-1] + self.n_channels
        self.out_block = OutBlock(
            in_ch=out_ch,
            out_ch=self.n_classes,
            activation=nn.Tanh()
        )

    def forward(self, x):
        """Enhance ``x`` and return the output of the final ``OutBlock``.

        Args:
            x: input batch; presumably (batch, n_channels, n_samples) — TODO confirm.
        """
        # ``x`` is only rebound below, never modified in place here, so the
        # original input can be kept by reference. ``copy.copy`` on a tensor
        # goes through its pickling protocol: it shares storage anyway and can
        # rebuild the tensor without its autograd history.
        net_in = x
        skips = []  # encoder skip outputs, consumed in reverse by the decoder

        # Encoder part
        for down in self.encoder:
            x, skip = down(x)
            skips.append(skip)

        # Bottleneck layer
        x = self.mid_block(x)

        # Decoder part: deepest skip first
        for up in self.decoder:
            x = up(x, skips.pop())

        # Output layer (with a skip from the raw input)
        return self.out_block(x, net_in)

## Trainer

In [10]:
model = SEWUNet(config, fd=15, fu=5)  # default encoder/decoder kernel sizes, spelled out

In [11]:
# Training parameters: use the GPU when one is available, and collect the
# per-epoch training/validation loss and SNR in `history`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
history = {key: [] for key in ('loss', 'SNR', 'val_loss', 'val_SNR')}

In [12]:
def CustomMetric():
    """Build a signal-to-noise-ratio metric in dB.

    Returns:
        A callable ``SNR(X, Y)`` where ``X`` is the estimate and ``Y`` the
        reference. Power is the squared L2 norm over dim 2 divided by
        ``X.shape[2]``. For each remaining index, ``10 * log10(P_Y / P_{X - Y})``
        is computed, and the mean of those values is returned as a scalar tensor.
    """
    def SNR(X, Y):
        length = X.shape[2]
        signal_power = torch.norm(Y, dim=2) ** 2 / length
        noise_power = torch.norm(X - Y, dim=2) ** 2 / length
        snr_db = 10 * torch.log10(signal_power / noise_power)
        return snr_db.mean()

    return SNR

In [13]:
# Build optimizer
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-6,
    betas=(0.9, 0.999))

# Lower the learning rate when the validation loss plateaus
lr_patience = 5
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=lr_patience)

# Loss and metric
m_loss = nn.L1Loss()    # MAE Loss
m_snr = CustomMetric()

# Early stopping has to be more patient than the scheduler. ReduceLROnPlateau
# only lowers the LR after more than `lr_patience` bad epochs, so with equal
# patience, early stopping would end training before the LR is ever reduced.
early_stopping = EarlyStopping(patience=2 * lr_patience, verbose=True)

In [14]:
# Move all parameters and buffers to the selected device (Module.to returns the module)
model = model.to(device=device)

In [None]:
# Number of trainable parameters in the model.
# Tensor.numel() gives a plain int (1 for 0-d parameters). By contrast,
# np.prod(p.size()) returns NumPy scalars, and the float 1.0 for an empty shape.
n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_trainable_params

In [16]:
# Function to display training metrics
def _display_metrics(epoch, it, steps, loss, metric):
    """Print ``Epoch [e/E], Step [s/S], Loss: <loss>, SNR: <metric>``.

    ``epoch`` and ``it`` are 0-based and printed 1-based. The epoch total is
    read from the global ``config.epochs``.
    """
    epoch_part = f"Epoch [{epoch + 1:02d}/{config.epochs:02d}]"
    step_part = f"Step [{it + 1:03d}/{steps:03d}]"
    print(f"{epoch_part}, {step_part}, Loss: {loss}, SNR: {metric}")

In [None]:
# Train the model over epochs. The best checkpoint is selected by validation
# loss, with ties broken by higher validation SNR. Training stops early once the
# validation loss stops improving.
steps = len(ls_generator)
start_time = time.time()
best_val_loss = float('inf')
best_val_metric = float('-inf')

for epoch in range(config.epochs):
    # Running sums over the epoch, weighted by batch size
    loss, metric = 0.0, 0.0
    val_loss, val_metric = 0.0, 0.0

    # ======================== Training ============================= #
    # Re-enable training behaviour (dropout / batch-norm updates, if any block
    # uses them); validation below switches the model to eval mode.
    model.train()
    for i, (local_batch, local_labels) in enumerate(ls_generator):
        # Transfer to Device
        local_batch = local_batch.to(device)
        local_labels = local_labels.to(device)

        # Reset accumulated gradients
        optimizer.zero_grad()

        # Forward pass, backward pass, optimize
        outputs = model(local_batch)
        loss_batch = m_loss(outputs, local_labels)
        # The SNR is only reported, never back-propagated: compute it on a
        # detached tensor so no autograd graph is built for it.
        batch_metric = m_snr(outputs.detach(), local_labels)
        loss_batch.backward()
        optimizer.step()

        # Accumulate batch-size-weighted sums so the epoch average is per sample
        loss += loss_batch.item() * len(local_batch)
        metric += batch_metric.item() * len(local_batch)

        # Print the loss every "verbose" batches
        if (i + 1) % config.verbose == 0:
            _display_metrics(epoch, i, steps,
                loss_batch.item(), batch_metric.item())

    # Average the epoch's training statistics and save them to history
    history['loss'].append(loss / len(lsg))
    history['SNR'].append(metric / len(lsg))

    # Print training statistics
    print(".:. Training metrics =", end=" ")
    print("Loss: {}, SNR: {}".format(loss / len(lsg), metric / len(lsg)))

    # ======================= Validation ============================ #
    model.eval()
    with torch.no_grad():
        for local_batch, local_labels in ls_val_generator:
            # Transfer to device
            local_batch = local_batch.to(device)
            local_labels = local_labels.to(device)

            # Predict, get loss and metric
            outputs = model(local_batch)
            val_loss += m_loss(outputs, local_labels).item() \
                * len(local_batch)

            val_metric += m_snr(outputs, local_labels).item() \
                * len(local_batch)

    val_loss /= len(lsg_val)
    val_metric /= len(lsg_val)

    # Print validation statistics
    print(".:. Validation metrics =", end=" ")
    print("Loss: {}, SNR: {}".format(val_loss, val_metric))

    # Save the epoch-averaged validation metrics to history and step the scheduler
    history['val_loss'].append(val_loss)
    history['val_SNR'].append(val_metric)
    lr_scheduler.step(val_loss)

    # Save the best model
    if val_loss < best_val_loss or (val_loss == best_val_loss and val_metric >= best_val_metric):
        best_val_loss = val_loss
        best_val_metric = val_metric
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': val_loss
        }, config.checkpoint_path)

    # Early stopping
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping")
        break

elapsed_time = time.time() - start_time
history['elapsed_time'] = elapsed_time

In [18]:
# Save the model/optimizer state as of the last epoch that ran. This is separate
# from the best-by-validation checkpoint written during training.
last_state = {
    'epoch': epoch + 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': val_loss,
}
torch.save(last_state, config.model_path)

In [None]:
# Plot per-epoch training and validation loss
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(history['loss'], label='train')
ax.plot(history['val_loss'], label='val')
ax.set_ylabel('Loss')
ax.set_xlabel('Epoch')
ax.set_title('Training history')
ax.legend()
plt.show()

In [None]:
# Plot per-epoch training and validation SNR
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(history['SNR'], label='train')
ax.plot(history['val_SNR'], label='val')
ax.set_ylabel('SNR')
ax.set_xlabel('Epoch')
ax.set_title('Training history')
ax.legend()
plt.show()

In [21]:
# Write the per-epoch metrics and total training time to JSON
with open(config.history_path, 'w') as history_file:
    json.dump(history, fp=history_file)