In [1]:
import torch
import torchaudio
import torch.optim as optim
from torch import nn
from torchinfo import summary
from torch.utils.data import DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from model import Denoiser
from audio_dataset import Audio_Dataset
from trainer import Trainer
from loss import DenoiserLoss


In [2]:
# Central run configuration. The same dict is handed to Audio_Dataset and
# Trainer, which read their own keys from it (their code is not visible in
# this notebook, so the per-key notes below are hedged where needed).
config = {
    # Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
    "device": (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    ),
    # Data location -- presumably resolved by Audio_Dataset; TODO confirm.
    "data_dir": "data",
    # Intended number of samples per batch.
    "batch_size": 100,
    # Forwarded to Trainer as its `epochs` argument.
    "epochs": 1000,
    # Initial learning rate handed to the Adam optimizer.
    # NOTE(review): 0.1 is unusually large for Adam -- confirm it is intentional.
    "learning_rate": 0.1,
    # Presumably caps on train / validation batches per epoch, consumed
    # elsewhere (not read in this notebook) -- TODO confirm.
    "batches_per_epoch": 50,
    "batches_per_epoch_val": 20,
    # Presumably dataset split fractions used by Audio_Dataset -- TODO confirm.
    "train": 0.7,
    "val": 0.2,
}

In [3]:
# Build the training and validation pipelines.
#
# The batch size is read from `config["batch_size"]` rather than being
# hard-coded, so changing the config cell actually changes the loaders
# (previously both loaders ignored the config and always used 100).

# Training split: reshuffled every epoch; the last incomplete batch is dropped
# so every batch has exactly `batch_size` samples.
train_dataset = Audio_Dataset(config, set_type="train")
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    drop_last=True,
    num_workers=2,  # load batches in two background worker processes
)

# Validation split, configured the same way as the training loader.
# NOTE(review): shuffle=True is kept as in the original; presumably it lets a
# capped number of validation batches cover a different subset each epoch --
# confirm against Trainer.
val_dataset = Audio_Dataset(config, set_type="val")
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    drop_last=True,
    num_workers=2,
)

In [4]:
# Instantiate the denoiser, move it to the configured device and print its
# layer layout.
model = Denoiser()
model = model.to(config["device"])
print(model)

# Training objective and an Adam optimizer over every model parameter.
loss_fn = DenoiserLoss()
optimizer = optim.Adam(params=model.parameters(), lr=config["learning_rate"])

# Halve the learning rate once the value passed to `scheduler.step(...)` has
# failed to improve (decrease, by the relative threshold below) for more than
# 20 consecutive steps. Which metric is monitored is decided by whoever calls
# `step` (presumably Trainer -- not visible here).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# against the installed version before upgrading.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=20,
    threshold=0.00001,
    verbose=True,
)

Denoiser(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv1d(1, 48, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(48, 96, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (1): Sequential(
      (0): Conv1d(48, 96, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(96, 192, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (2): Sequential(
      (0): Conv1d(96, 192, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(192, 384, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
  )
  (attention): ModuleList(
    (0): Sequential(
      (0): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
      )
      (1): Linear(in_features=192, out_features=384, bias=True)
      (2): Linear(in_features=384, out_features=192, bias=True)
    )
  )
  (decoder): ModuleList(
    (0): Sequential(
      (0): Conv1d(192, 384, kernel_size=(1,), 

In [5]:
# Display torchinfo's per-layer table (layer type, depth index, parameter count).
summary(model)

Layer (type:depth-idx)                                       Param #
Denoiser                                                     --
├─ModuleList: 1-1                                            --
│    └─Sequential: 2-1                                       --
│    │    └─Conv1d: 3-1                                      432
│    │    └─ReLU: 3-2                                        --
│    │    └─Conv1d: 3-3                                      4,704
│    │    └─GLU: 3-4                                         --
│    └─Sequential: 2-2                                       --
│    │    └─Conv1d: 3-5                                      36,960
│    │    └─ReLU: 3-6                                        --
│    │    └─Conv1d: 3-7                                      18,624
│    │    └─GLU: 3-8                                         --
│    └─Sequential: 2-3                                       --
│    │    └─Conv1d: 3-9                                      147,648
│    │    └─ReLU: 

In [6]:
# Wire the model, objective, optimizer and LR scheduler into the project's
# Trainer; the epoch count comes from the shared config, which is also passed
# through in full.
trainer = Trainer(
    model=model, loss_fn=loss_fn,
    optimizer=optimizer, scheduler=scheduler,
    epochs=config["epochs"], config=config,
)

# Run training with the train/validation loaders and rebind `model` to
# whatever `Trainer.train` returns.
model = trainer.train(train_dataloader, val_dataloader)

  0%|          | 0/351 [00:00<?, ?it/s]

{'noisy': tensor([[[ 6.2354e-02,  5.7179e-02,  4.7627e-02,  ..., -1.0490e-03,
          -7.2836e-03, -9.5174e-03]],

        [[ 4.6136e-05,  2.7331e-03,  2.9101e-03,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 2.9993e-03,  2.4323e-03,  2.8769e-03,  ..., -7.6399e-04,
          -1.3561e-03, -1.5548e-03]],

        ...,

        [[-2.2313e-03, -2.1952e-03, -4.3396e-03,  ..., -7.3664e-02,
          -7.9593e-02, -7.8593e-02]],

        [[-1.5039e-02, -1.5624e-02, -1.3504e-02,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 5.8690e-04,  1.0670e-02,  3.0618e-02,  ..., -3.0622e-02,
          -7.3170e-02, -9.9186e-02]]], device='cuda:0'), 'clean': tensor([[[-4.1955e-04, -1.5043e-04,  5.3628e-04,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]],

        [[ 2.3808e-05, -1.7695e-06, -1.2629e-06,  ...,  1.0458e-02,
           6.7130e-03,  3.7196e-03]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -1.4650e-03,
          -1.6874e-03, -4

  0%|          | 1/351 [00:10<1:02:46, 10.76s/it]

{'noisy': tensor([[[-0.0104, -0.0115, -0.0123,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0479, -0.0442, -0.0361,  ...,  0.0127,  0.0100,  0.0139]],

        [[-0.0034, -0.0026, -0.0017,  ...,  0.0095,  0.0088,  0.0073]],

        ...,

        [[ 0.1472,  0.2269,  0.1841,  ..., -0.0889, -0.0540,  0.0042]],

        [[ 0.0096,  0.0097,  0.0086,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0140, -0.0039,  0.0113,  ...,  0.0296,  0.0007,  0.0020]]],
       device='cuda:0'), 'clean': tensor([[[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')}


  1%|          | 2/351 [00:13<33:43,  5.80s/it]  

{'noisy': tensor([[[-0.0342, -0.0226, -0.0145,  ..., -0.1246, -0.1217, -0.1290]],

        [[-0.0039, -0.0054, -0.0064,  ..., -0.0031, -0.0016, -0.0006]],

        [[-0.0752, -0.1204, -0.1263,  ..., -0.0790, -0.0638, -0.0267]],

        ...,

        [[-0.1543, -0.1187, -0.0381,  ...,  0.0003,  0.0199,  0.0031]],

        [[-0.1559, -0.1107, -0.0758,  ...,  0.0690,  0.0371,  0.0457]],

        [[-0.0528, -0.0502, -0.1647,  ..., -0.0764, -0.0687,  0.0066]]],
       device='cuda:0'), 'clean': tensor([[[-1.8913e-01, -1.6909e-01, -1.4843e-01,  ...,  4.4366e-02,
           3.7217e-02,  3.5971e-02]],

        [[-6.3947e-02, -7.6201e-02, -7.0454e-02,  ..., -3.9580e-02,
          -5.2358e-02, -5.1714e-02]],

        [[ 2.1223e-03,  4.4324e-03,  6.2671e-03,  ...,  4.1500e-03,
           2.9911e-04, -6.1057e-03]],

        ...,

        [[-8.1600e-02, -7.3400e-02, -5.5996e-02,  ...,  4.8317e-03,
           4.0129e-03,  4.4789e-03]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  7.6801

  1%|          | 3/351 [00:14<22:00,  3.79s/it]

{'noisy': tensor([[[-0.0510, -0.0564, -0.0539,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0842, -0.1014, -0.1106,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.1450,  0.0080, -0.0749,  ..., -0.0099, -0.0758, -0.1079]],

        ...,

        [[-0.1397,  0.0474,  0.0706,  ...,  0.1158,  0.1314,  0.1013]],

        [[-0.0649, -0.0751, -0.0713,  ...,  0.0000,  0.0000,  0.0000]],

        [[-0.0134, -0.0309, -0.0415,  ..., -0.0208,  0.0798,  0.1474]]],
       device='cuda:0'), 'clean': tensor([[[ 1.9561e-02,  1.5819e-02,  1.4842e-02,  ..., -9.0432e-04,
          -9.4227e-04, -1.0542e-03]],

        [[ 2.1305e-03,  2.6168e-03,  2.2400e-04,  ...,  8.2954e-03,
           8.2493e-03,  7.7089e-03]],

        [[ 7.6656e-03,  7.6906e-03,  8.1053e-03,  ..., -2.9722e-02,
          -2.8095e-02, -2.6133e-02]],

        ...,

        [[-6.0563e-02, -5.9859e-02, -5.8734e-02,  ..., -8.9420e-03,
          -9.9048e-03, -1.0188e-02]],

        [[ 5.9847e-04,  1.0869e-03,  1.6065e-03,  ...,  2.6589

  1%|          | 4/351 [00:19<24:11,  4.18s/it]

{'noisy': tensor([[[ 0.0070,  0.0121,  0.0197,  ..., -0.1822, -0.2059, -0.2343]],

        [[ 0.0136,  0.0154,  0.0156,  ..., -0.0155, -0.0089, -0.0039]],

        [[ 0.0373,  0.2343,  0.0833,  ...,  0.0705,  0.0097, -0.0684]],

        ...,

        [[-0.0419, -0.0435, -0.0390,  ...,  0.0223,  0.0181,  0.0171]],

        [[ 0.0917,  0.1043,  0.1127,  ..., -0.0213, -0.0311, -0.0394]],

        [[ 0.0787,  0.0722,  0.0685,  ...,  0.0000,  0.0000,  0.0000]]],
       device='cuda:0'), 'clean': tensor([[[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')}


  1%|          | 4/351 [00:21<31:10,  5.39s/it]


KeyboardInterrupt: 