In [7]:
import torch
import torchaudio
import torch.optim as optim
from torch import nn
from torchinfo import summary
from torch.utils.data import DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from model import Denoiser
from audio_dataset import Audio_Dataset
from trainer import Trainer
from loss import DenoiserLoss


In [8]:
# Hyper-parameters and paths shared by the datasets, the model setup and the Trainer.
config = dict(
    # First CUDA device when one is visible, otherwise the CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # Root directory handed to Audio_Dataset.
    data_dir="data",
    batch_size=32,
    epochs=1000,
    learning_rate=0.0005,
    # NOTE(review): the progress bars logged by the training cell show 1097 train /
    # 313 val batches per epoch, not 1096 / 311 — confirm how Trainer uses these.
    batches_per_epoch=1096,
    batches_per_epoch_val=311,
    # Split fractions read by Audio_Dataset; presumably the remaining 0.1 is a
    # held-out test split — TODO confirm in audio_dataset.py.
    train=0.7,
    val=0.2,
)

In [9]:
def _make_loader(set_type, shuffle, drop_last):
    """Build the Audio_Dataset split ``set_type`` and wrap it in a DataLoader.

    Args:
        set_type: split name passed to Audio_Dataset (e.g. "train", "val").
        shuffle: reshuffle sample order every epoch.
        drop_last: discard the final batch if it is smaller than the batch size.

    Returns:
        A ``(dataset, dataloader)`` pair; the batch size is ``config["batch_size"]``.
    """
    dataset = Audio_Dataset(config, set_type=set_type)
    loader = DataLoader(
        dataset=dataset,
        batch_size=config["batch_size"],
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=2,
    )
    return dataset, loader


# Training: reshuffle every epoch and drop the ragged final batch so every
# optimizer step sees a full batch.
train_dataset, train_dataloader = _make_loader("train", shuffle=True, drop_last=True)

# Validation: fixed order and no dropped samples, so every epoch is scored on the
# identical, complete split and val loss is comparable from epoch to epoch.
# (shuffle=True + drop_last=True left out a different random subset each epoch.)
# NOTE(review): if Trainer or DenoiserLoss assumes exactly batch_size samples per
# batch, the final partial validation batch may need handling there.
val_dataset, val_dataloader = _make_loader("val", shuffle=False, drop_last=False)

In [10]:
# Resume from a previously saved state_dict checkpoint.
trained_weights = './trained_weights/model_030'

model = Denoiser(depth=4, N_attention=2)
# map_location remaps the stored tensors onto config["device"], so a checkpoint
# saved from a GPU still loads when only the CPU fallback is available.
# torch.load unpickles the file: only load checkpoints from trusted sources.
model.load_state_dict(torch.load(trained_weights, map_location=config["device"]))
model.to(config["device"])
print(model)

loss_fn = DenoiserLoss()
# NOTE(review): only the model weights are restored; Adam's moment estimates and
# the scheduler state start fresh rather than resuming from the checkpoint.
optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])
# Halve the LR after 20 scheduler.step() calls without a (relative) improvement
# of at least 1e-5 in the monitored metric; presumably Trainer steps it once per
# epoch with the validation loss — TODO confirm in trainer.py.
# `verbose` is deprecated in newer PyTorch releases and may emit a warning.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, factor=0.5, patience=20, verbose=True, threshold=0.00001
)

Denoiser(
  (encoder): ModuleList(
    (0): Sequential(
      (0): Conv1d(1, 48, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(48, 96, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (1): Sequential(
      (0): Conv1d(48, 96, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(96, 192, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
    (2): Sequential(
      (0): Conv1d(96, 192, kernel_size=(8,), stride=(4,))
      (1): ReLU()
      (2): Conv1d(192, 384, kernel_size=(1,), stride=(1,))
      (3): GLU(dim=1)
    )
  )
  (attention): ModuleList(
    (0): Sequential(
      (0): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
      )
      (1): Linear(in_features=192, out_features=384, bias=True)
      (2): Linear(in_features=384, out_features=192, bias=True)
    )
  )
  (decoder): ModuleList(
    (0): Sequential(
      (0): Conv1d(192, 384, kernel_size=(1,), 

In [11]:
# Per-layer parameter counts. No input size is supplied, so torchinfo reports
# parameter counts only (no output shapes).
model_stats = summary(model=model)
model_stats

Layer (type:depth-idx)                                       Param #
Denoiser                                                     --
├─ModuleList: 1-1                                            --
│    └─Sequential: 2-1                                       --
│    │    └─Conv1d: 3-1                                      432
│    │    └─ReLU: 3-2                                        --
│    │    └─Conv1d: 3-3                                      4,704
│    │    └─GLU: 3-4                                         --
│    └─Sequential: 2-2                                       --
│    │    └─Conv1d: 3-5                                      36,960
│    │    └─ReLU: 3-6                                        --
│    │    └─Conv1d: 3-7                                      18,624
│    │    └─GLU: 3-8                                         --
│    └─Sequential: 2-3                                       --
│    │    └─Conv1d: 3-9                                      147,648
│    │    └─ReLU: 

In [6]:
# Everything the Trainer needs: model, objective, optimiser, LR schedule, run length.
trainer_kwargs = dict(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    epochs=config["epochs"],
    config=config,
    scheduler=scheduler,
)
trainer = Trainer(**trainer_kwargs)

# Run the train/validation loop; whatever Trainer.train returns (presumably the
# trained model) is rebound to `model`.
model = trainer.train(train_dataloader, val_dataloader)

100%|██████████| 1097/1097 [04:33<00:00,  4.01it/s]


Training epoch complete


100%|██████████| 313/313 [01:15<00:00,  4.16it/s]


Eval epoch complete
Epoch: 1/1000, Train Loss=0.3457496464, Val Loss=0.4443694382


100%|██████████| 1097/1097 [04:27<00:00,  4.09it/s]


Training epoch complete


100%|██████████| 313/313 [01:15<00:00,  4.14it/s]


Eval epoch complete
Epoch: 2/1000, Train Loss=0.4201224148, Val Loss=0.3738174652


100%|██████████| 1097/1097 [04:20<00:00,  4.20it/s]


Training epoch complete


100%|██████████| 313/313 [01:11<00:00,  4.39it/s]


Eval epoch complete
Epoch: 3/1000, Train Loss=0.3586572409, Val Loss=0.3673672948


100%|██████████| 1097/1097 [04:06<00:00,  4.45it/s]


Training epoch complete


100%|██████████| 313/313 [01:11<00:00,  4.35it/s]


Eval epoch complete
Epoch: 4/1000, Train Loss=0.3521054089, Val Loss=0.3677671105


100%|██████████| 1097/1097 [04:06<00:00,  4.45it/s]


Training epoch complete


100%|██████████| 313/313 [01:12<00:00,  4.32it/s]


Eval epoch complete
Epoch: 5/1000, Train Loss=0.3871172965, Val Loss=0.371842296


100%|██████████| 1097/1097 [04:05<00:00,  4.46it/s]


Training epoch complete


100%|██████████| 313/313 [01:11<00:00,  4.41it/s]


Eval epoch complete
Epoch: 6/1000, Train Loss=0.3483588994, Val Loss=0.370568738


100%|██████████| 1097/1097 [8:59:25<00:00, 29.50s/it]      


Training epoch complete


100%|██████████| 313/313 [01:13<00:00,  4.24it/s]


Eval epoch complete
Epoch: 7/1000, Train Loss=0.3404135406, Val Loss=0.366237641


100%|██████████| 1097/1097 [04:20<00:00,  4.21it/s]


Training epoch complete


100%|██████████| 313/313 [01:13<00:00,  4.26it/s]


Eval epoch complete
Epoch: 8/1000, Train Loss=0.387498796, Val Loss=0.3737628617


100%|██████████| 1097/1097 [04:16<00:00,  4.27it/s]


Training epoch complete


100%|██████████| 313/313 [01:10<00:00,  4.47it/s]


Eval epoch complete
Epoch: 9/1000, Train Loss=0.3492754698, Val Loss=0.3631069736


100%|██████████| 1097/1097 [03:53<00:00,  4.70it/s]


Training epoch complete


100%|██████████| 313/313 [01:08<00:00,  4.60it/s]


Eval epoch complete
Epoch: 10/1000, Train Loss=0.3289176524, Val Loss=0.3638101881


100%|██████████| 1097/1097 [03:53<00:00,  4.71it/s]


Training epoch complete


100%|██████████| 313/313 [01:07<00:00,  4.60it/s]


Eval epoch complete
Epoch: 11/1000, Train Loss=0.3623604476, Val Loss=0.3899573355


100%|██████████| 1097/1097 [03:53<00:00,  4.70it/s]


Training epoch complete


100%|██████████| 313/313 [01:09<00:00,  4.53it/s]


Eval epoch complete
Epoch: 12/1000, Train Loss=0.391418308, Val Loss=0.3664561693


100%|██████████| 1097/1097 [03:52<00:00,  4.72it/s]


Training epoch complete


100%|██████████| 313/313 [01:07<00:00,  4.62it/s]


Eval epoch complete
Epoch: 13/1000, Train Loss=0.3452591598, Val Loss=0.3646861727


100%|██████████| 1097/1097 [04:03<00:00,  4.51it/s]


Training epoch complete


  0%|          | 1/313 [1:32:13<479:32:38, 5533.20s/it]


KeyboardInterrupt: 