In [1]:
import torch
import torchaudio
from torch import nn
from torch.nn import functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

In [2]:
# Report the GPU this notebook runs on. The per-device queries raise on a host
# without a usable CUDA device, so they are only issued when CUDA is available;
# otherwise the cell displays a short notice instead of failing.
if torch.cuda.is_available():
    torch.cuda.device_count()
    torch.cuda.current_device()
    torch.cuda.device(0)
torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CUDA not available"

'NVIDIA GeForce RTX 3060 Ti'

In [None]:
# NOTE(review): os.path.join() requires at least one path component, so this
# line raises TypeError as written. TODO: fill in the components of the
# directory that (presumably) holds the noisy samples.
NOISY_SAMPLES = os.path.join()

In [12]:
class Denoiser(torch.nn.Module):
    """1-D convolutional encoder / self-attention / decoder with skip connections.

    The encoder is a stack of ``depth`` strided ``Conv1d`` stages, each followed
    by a 1x1 convolution and a GLU (or ReLU) activation. The bottleneck applies
    ``N_attention`` self-attention blocks. The decoder mirrors the encoder with
    ``ConvTranspose1d`` stages and adds the matching encoder output before each
    stage.

    Args:
        n_layers: stored as ``self.D``; not otherwise used (the number of
            stages actually built is ``depth``).
        output_channels: stored as ``self.H``; not otherwise used (the width of
            the first stage is ``hidden``).
        chin: number of input channels.
        chout: number of output channels.
        hidden: number of channels produced by the first encoder stage.
        depth: number of encoder stages (and of decoder stages).
        N_attention: number of self-attention blocks in the bottleneck.
        kernel_size: kernel size of the strided (transposed) convolutions.
        stride: stride of the strided (transposed) convolutions.
        growth: channel multiplier applied from one stage to the next.
        max_hidden: upper bound on the channel count of any stage.
        glu: use ``nn.GLU`` (with doubled 1x1-conv width) instead of ``nn.ReLU``.
        causal, resample, normalize, rescale, floor, sample_rate: accepted for
            interface compatibility; not used by this implementation.

    Note:
        The attention blocks use 8 heads, so the bottleneck width
        (``hidden * growth ** (depth - 1)``, capped by ``max_hidden``) must be
        divisible by 8.
    """

    def __init__(
        self,
        n_layers,
        output_channels,
        chin=1,
        chout=1,
        hidden=48,
        depth=5,
        N_attention=3,
        kernel_size=8,
        stride=4,
        causal=True,
        resample=4,
        growth=2,
        max_hidden=10_000,
        normalize=True,
        glu=True,
        rescale=0.1,
        floor=1e-3,
        sample_rate=22_050,
    ):
        super().__init__()
        self.D = n_layers
        self.H = output_channels
        self.K = kernel_size
        # The convolutions are built with `stride`, so record that value
        # (the previous `kernel_size // 2` only matched it for the defaults).
        self.S = stride
        self.depth = depth

        self.encoder = nn.ModuleList([])
        self.decoder = nn.ModuleList([])
        self.attention = nn.ModuleList([])
        activation = nn.GLU(1) if glu else nn.ReLU()
        # GLU halves the channel dimension, so the 1x1 conv feeding it is doubled.
        ch_scale = 2 if glu else 1

        for index in range(depth):
            encode = [
                nn.Conv1d(chin, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * ch_scale, 1),
                activation,
            ]
            self.encoder.append(nn.Sequential(*encode))

            decode = [
                nn.Conv1d(hidden, ch_scale * hidden, 1),
                activation,
                nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
            ]
            # Every decoder stage except the outermost one ends with a ReLU.
            if index > 0:
                decode.append(nn.ReLU())
            # Decoder stages run in reverse order of the encoder stages.
            self.decoder.insert(0, nn.Sequential(*decode))

            chout = hidden
            chin = hidden
            hidden = min(int(growth * hidden), max_hidden)

        # After the loop `chin` is the channel count of the bottleneck.
        for _ in range(N_attention):
            # Kept as a plain list of layers: MultiheadAttention takes
            # (query, key, value) and returns a tuple, so it cannot be chained
            # inside nn.Sequential; forward() applies the three layers itself.
            self.attention.append(
                nn.ModuleList(
                    [
                        nn.MultiheadAttention(embed_dim=chin, num_heads=8),
                        nn.Linear(chin, 2 * chin),
                        nn.Linear(2 * chin, chin),
                    ]
                )
            )

    def _valid_length(self, length):
        """Return the smallest length >= ``length`` that the stack maps onto itself.

        For such a length every strided convolution consumes its input exactly,
        so the decoder output has the same number of samples as the input.
        """
        for _ in range(self.depth):
            # ceil((length - K) / S) + 1 frames, and at least one frame.
            length = max(-(-(length - self.K) // self.S) + 1, 1)
        for _ in range(self.depth):
            length = (length - 1) * self.S + self.K
        return int(length)

    def forward(self, input):
        """Run the network on ``input`` of shape ``(batch, chin, time)``.

        The signal is zero-padded on the right to a length the network can
        reconstruct exactly, and the result is trimmed back, so the output has
        shape ``(batch, chout, time)``.
        """
        length = input.shape[-1]
        x = F.pad(input, (0, self._valid_length(length) - length))

        skip_outputs = []
        for encode in self.encoder:
            x = encode(x)
            skip_outputs.append(x)

        # MultiheadAttention (batch_first=False) expects (time, batch, channels).
        x = x.permute(2, 0, 1)
        for self_attention, expand, project in self.attention:
            x, _ = self_attention(x, x, x, need_weights=False)
            x = project(expand(x))
        # Back to (batch, channels, time) for the convolutional decoder.
        x = x.permute(1, 2, 0)

        for decode in self.decoder:
            skip = skip_outputs.pop(-1)
            x = x + skip[..., : x.shape[-1]]
            x = decode(x)

        return x[..., :length]

In [None]:
class Trainer(object):
    """Training loop with periodic checkpoints and early stopping.

    Each batch yielded by the dataloaders must be a mapping with the keys
    ``"noisy"`` (model input) and ``"clean"`` (target).

    Args:
        model: the ``nn.Module`` to train.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        optimizer: optimizer over the model's parameters.
        epochs: maximum number of epochs.
        scheduler: optional scheduler; it is stepped once per epoch with the
            latest training loss, so it must accept a metric argument.
        device: device the batches are moved to. When given, the model is moved
            there too; when omitted, the device of the model's first parameter
            is used (CPU if the model has no parameters).
        batches_per_epoch: optional cap on training batches per epoch
            (``None`` = the whole dataloader).
        batches_per_epoch_val: optional cap on validation batches per epoch
            (``None`` = the whole dataloader).

    Attributes:
        loss: ``{"train": [...], "val": [...]}`` mean loss per completed epoch.
    """

    def __init__(
        self,
        model,
        loss_fn,
        optimizer,
        epochs,
        scheduler=None,
        device=None,
        batches_per_epoch=None,
        batches_per_epoch_val=None,
    ):
        self.model = model
        self.loss = {"train": [], "val": []}
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.epochs = epochs
        self.scheduler = scheduler
        self.batches_per_epoch = batches_per_epoch
        self.batches_per_epoch_val = batches_per_epoch_val
        self.checkpoint_frequency = 100
        self.early_stopping_epochs = 10
        self.early_stopping_avg = 10
        self.early_stopping_precision = 5

        if device is None:
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
        else:
            self.model.to(device)
        self.device = device

    def train(self, train_dataloader, val_dataloader):
        """Train for up to ``self.epochs`` epochs and return the model.

        Writes ``model_<epoch>`` every ``checkpoint_frequency`` epochs and
        ``model_final`` at the end (state dicts, in the working directory).
        """
        # Initialised up front so the early-stopping check below never sees
        # unbound names, whatever `early_stopping_avg` is set to.
        min_val_loss = float("inf")
        no_decrease_epochs = 0

        for epoch in range(self.epochs):
            self._epoch_train(train_dataloader)
            self._epoch_eval(val_dataloader)
            print(
                "Epoch: {}/{}, Train Loss={}, Val Loss={}".format(
                    epoch + 1,
                    self.epochs,
                    np.round(self.loss["train"][-1], 10),
                    np.round(self.loss["val"][-1], 10),
                )
            )

            # reducing LR if no improvement
            if self.scheduler is not None:
                self.scheduler.step(self.loss["train"][-1])

            # saving model
            if (epoch + 1) % self.checkpoint_frequency == 0:
                torch.save(
                    self.model.state_dict(), "model_{}".format(str(epoch + 1).zfill(3))
                )

            # early stopping: during the first `early_stopping_avg` epochs only
            # the baseline is tracked; afterwards the rounded moving average of
            # the validation loss must keep improving on the best value seen.
            if epoch < self.early_stopping_avg:
                min_val_loss = np.round(
                    np.mean(self.loss["val"]), self.early_stopping_precision
                )
                no_decrease_epochs = 0
            else:
                val_loss = np.round(
                    np.mean(self.loss["val"][-self.early_stopping_avg:]),
                    self.early_stopping_precision,
                )
                if val_loss >= min_val_loss:
                    no_decrease_epochs += 1
                else:
                    min_val_loss = val_loss
                    no_decrease_epochs = 0

            if no_decrease_epochs > self.early_stopping_epochs:
                print("Early Stopping")
                break

        torch.save(self.model.state_dict(), "model_final")
        return self.model

    @staticmethod
    def _mean_loss(batch_losses):
        """Mean of the per-batch losses; NaN when no batch was processed."""
        return float(np.mean(batch_losses)) if batch_losses else float("nan")

    def _epoch_train(self, dataloader):
        """Run one optimisation pass and append its mean loss to ``loss["train"]``."""
        self.model.train()
        running_loss = []

        for i, data in enumerate(dataloader):
            if self.batches_per_epoch is not None and i >= self.batches_per_epoch:
                break
            inputs = data["noisy"].to(self.device)
            labels = data["clean"].to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, labels)
            loss.backward()
            self.optimizer.step()

            running_loss.append(loss.item())

        # Recorded unconditionally so every epoch contributes exactly one entry,
        # including when the dataloader is shorter than the batch cap.
        self.loss["train"].append(self._mean_loss(running_loss))

    def _epoch_eval(self, dataloader):
        """Evaluate without gradients and append the mean loss to ``loss["val"]``."""
        self.model.eval()
        running_loss = []

        with torch.no_grad():
            for i, data in enumerate(dataloader):
                if (
                    self.batches_per_epoch_val is not None
                    and i >= self.batches_per_epoch_val
                ):
                    break
                inputs = data["noisy"].to(self.device)
                labels = data["clean"].to(self.device)

                outputs = self.model(inputs)
                loss = self.loss_fn(outputs, labels)

                running_loss.append(loss.item())

        self.loss["val"].append(self._mean_loss(running_loss))