In [1]:
import os
import sys
sys.path.append("..")
import logging
from tqdm import tqdm
import numpy as np

import torch
from torch import nn
from torch.utils import data
import torch.optim as optim

import hydra
from omegaconf import DictConfig

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from mdx.models.utils.model_utils import get_model_class
from mdx.dataloaders.audiodataloader import AudioDataset

from torchsummary import summary
import torchmetrics
from torchmetrics.audio import SignalDistortionRatio, SignalNoiseRatio, ScaleInvariantSignalDistortionRatio
from torchmetrics.regression import MeanSquaredError

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import train_pl

# The training script's configuration object; `config` is the short alias used below.
Configuration = train_pl.Configuration
config = Configuration

Global seed set to 1996


In [3]:
# Seed every RNG Lightning knows about, using the configured seed.
pl.seed_everything(config.general.seed)

# Make cuDNN select deterministic algorithms and disable autotuning so GPU runs are
# reproducible. (The attribute name must be spelled exactly `deterministic`; a typo
# just creates an unused attribute on the module and silently leaves it disabled.)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Global seed set to 1996


In [6]:
# Number of samples per example passed to the dataset (2 * 44100; presumably a
# 2-second clip at 44.1 kHz — confirm against config.dataset.sampling_rate).
clip_samples = 44100 * 2

ds = AudioDataset(
    paths=config.dataset.paths,
    sampling_rate=config.dataset.sampling_rate,
    sources=config.dataset.sources,
    targets=config.dataset.targets,
    n_samples=clip_samples,
    debug=True,
    pre_init=True,
)
dl = torch.utils.data.DataLoader(ds, batch_size=4, num_workers=4)

100%|██████████| 10/10 [03:45<00:00, 22.51s/it]


In [8]:
# Peek at one batch: input waveforms under 'input', per-stem target waveforms under 'output'.
first_batch = next(iter(dl))
first_batch

{'input': tensor([[[-0.0087, -0.0154, -0.0345,  ...,  0.2963,  0.3022,  0.3101],
          [-0.0343, -0.0543, -0.0886,  ...,  0.2556,  0.2747,  0.2947]],
 
         [[ 0.2423,  0.1734,  0.1118,  ...,  0.0082, -0.0134, -0.0425],
          [-0.0085, -0.0552, -0.0240,  ...,  0.0961,  0.0849,  0.0656]],
 
         [[ 0.0223, -0.0102, -0.0122,  ...,  0.0773,  0.1697,  0.2008],
          [ 0.0100, -0.0190, -0.0170,  ...,  0.1178,  0.1763,  0.2176]],
 
         [[-0.0995, -0.0969, -0.0938,  ..., -0.0335, -0.0140,  0.0041],
          [-0.0241, -0.0251, -0.0246,  ..., -0.0133, -0.0194, -0.0166]]]),
 'output': {'bass': tensor([[[-0.0777, -0.0774, -0.0771,  ..., -0.0045, -0.0045, -0.0045],
           [-0.0777, -0.0774, -0.0771,  ..., -0.0045, -0.0045, -0.0045]],
  
          [[ 0.0602,  0.0600,  0.0598,  ...,  0.0089,  0.0091,  0.0091],
           [ 0.0611,  0.0609,  0.0608,  ...,  0.0085,  0.0086,  0.0087]],
  
          [[ 0.0040,  0.0041,  0.0041,  ...,  0.0151,  0.0149,  0.0147],
           [

In [9]:
# Resolve the model class from the configured model name (instantiated in the next cell).
model_name = config.model.model_name
model = get_model_class(model_name)

In [10]:
# Instantiate from a freshly looked-up class instead of `model = model(...)`: the latter
# rebinds `model` to an instance, so re-running this cell would try to call the instance.
model_cls = get_model_class(config.model.model_name)
model = model_cls(
    input_channels=2,  # the input batches have 2 channels per example
    # NOTE(review): the training loop stacks `config.dataset.targets`, not `sources`;
    # confirm both lists have the same length.
    target_sources_num=len(config.dataset.sources),
)

In [11]:
# Layer-by-layer summary for a single (channels=2, samples=2*44100) input, run on CPU.
summary_input_size = (2, 44100 * 2)
summary(model=model, input_size=summary_input_size, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1            [-1, 1025, 201]       2,099,200
            Conv1d-2            [-1, 1025, 201]       2,099,200
              STFT-3  [[-1, 1, 201, 1025], [-1, 1, 201, 1025]]               0
       BatchNorm2d-4         [-1, 1025, 201, 2]           2,050
       BatchNorm2d-5          [-1, 8, 256, 256]              16
            Conv2d-6         [-1, 32, 256, 256]           2,304
       BatchNorm2d-7         [-1, 32, 256, 256]              64
            Conv2d-8         [-1, 32, 256, 256]           9,216
            Conv2d-9         [-1, 32, 256, 256]             288
     ConvBlockRes-10         [-1, 32, 256, 256]               0
      BatchNorm2d-11         [-1, 32, 256, 256]              64
           Conv2d-12         [-1, 32, 256, 256]           9,216
      BatchNorm2d-13         [-1, 32, 256, 256]              64
           Conv2d-14    

In [12]:
# Sanity-check the MAE metric on random tensors shaped like one (batch=1, 2-channel) clip.
loss_fn = torchmetrics.regression.MeanAbsoluteError()
random_pred, random_ref = torch.rand((1, 2, 44100)), torch.rand((1, 2, 44100))
loss_fn(random_pred, random_ref)  # This needs to be done for each target.

tensor(0.3330)

In [13]:
class LossFn(nn.Module):
    """Apply a pairwise loss to one prediction/target pair, or average it over several.

    Args:
        function: callable ``function(prediction, target)`` returning a loss value
            (e.g. a torchmetrics metric such as ``MeanAbsoluteError``).
        order: names of the target stems. Stored on ``self.order`` as a list for
            reference; ``forward`` does not use it.
    """

    def __init__(self, function, order=('bass', 'drums', 'other', 'vocals')):
        super().__init__()
        self.function = function
        # Fresh list per instance: a list default is one object shared by every
        # instance, so mutating one instance's `order` would leak into all others.
        self.order = list(order)

    def forward(self, sources, targets):
        """Compute the loss.

        Args:
            sources: a single prediction, or a list/tuple of predictions (one per target).
            targets: the matching target, or a sequence of targets of the same length.

        Returns:
            ``function(sources, targets)`` for a single pair, otherwise the mean of
            ``function(source_i, target_i)`` over all pairs.

        Raises:
            ValueError: if a sequence of sources is empty or its length differs from
                the number of targets.
        """
        if isinstance(sources, (list, tuple)):
            targets = list(targets)
            # zip() would silently drop unmatched pairs and bias the average.
            if len(sources) != len(targets):
                raise ValueError(
                    f"got {len(sources)} sources but {len(targets)} targets"
                )
            if not sources:
                raise ValueError("cannot average a loss over an empty list of sources")
            total = 0
            for source, target in zip(sources, targets):
                total = total + self.function(source, target)
            return total / len(sources)
        return self.function(sources, targets)

In [25]:
# Check the list path of LossFn: average MAE over four random (pred, target) pairs.
loss_fn = LossFn(torchmetrics.regression.MeanAbsoluteError())
n_pairs = 4
random_preds = [torch.rand((1, 2, 44100)) for _ in range(n_pairs)]
random_refs = [torch.rand((1, 2, 44100)) for _ in range(n_pairs)]
loss_fn(random_preds, random_refs)

tensor(0.3333)

In [28]:
## One forward pass, scored against the SAME batch's targets
model.eval()
inp = next(iter(dl))
# Stack the per-target waveforms along the channel axis in `config.dataset.targets`
# order, matching how the training loop builds its targets. (Previously this used a
# second `next(iter(dl))` call, i.e. targets from a different fetch than `inp`.)
target_wave = torch.cat([inp['output'][name] for name in config.dataset.targets], dim=1)
with torch.no_grad():
    out = model(inp['input'])
# torchmetrics convention is (preds, target); MAE is symmetric, so the value is unchanged.
loss_fn(out['waveform'], target_wave)

tensor(0.0894)


In [32]:
# First 10 samples of channel 0 of the first example: target vs. prediction.
# Reuse `inp` from the forward-pass cell so both sides come from the same batch
# (a fresh `next(iter(dl))` is not guaranteed to match what `out` was computed on).
target_wave = torch.cat([inp['output'][name] for name in config.dataset.targets], dim=1)
target_wave[:1, :1, :10], out['waveform'][:1, :1, :10]

(tensor([[[-0.0372, -0.0382, -0.0393, -0.0404, -0.0415, -0.0427, -0.0439,
           -0.0451, -0.0463, -0.0475]]]),
 tensor([[[ 0.0756,  0.1533,  0.0892,  0.0506, -0.0041, -0.0062,  0.0128,
            0.0567,  0.0569,  0.0717]]]))

In [36]:
def get_lr_lambda(step: int, warm_up_steps: int, reduce_lr_steps: int) -> float:
    r"""Get the learning-rate multiplier for ``LambdaLR``. E.g.,

    .. code-block:: python

        lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000)

        from torch.optim.lr_scheduler import LambdaLR
        LambdaLR(optimizer, lr_lambda)

    Args:
        step: int, current scheduler step (``LambdaLR`` passes its step counter)
        warm_up_steps: int, steps of linear warm-up from 0 to 1; ``<= 0`` disables warm-up
        reduce_lr_steps: int, multiply the learning rate by 0.9 every #reduce_lr_steps steps

    Returns:
        learning-rate multiplier: float
    """
    # Guard the warm-up branch: with warm_up_steps == 0, step 0 would divide by zero.
    if warm_up_steps > 0 and step <= warm_up_steps:
        return step / warm_up_steps
    return 0.9 ** (step // reduce_lr_steps)

In [39]:
# Hyperparams definition
n_steps = 5000           # total optimizer steps (one per batch) for the training loop
warm_up_steps = 1000     # linear LR warm-up length, in steps
reduce_lr_steps = 10000  # after warm-up, multiply the LR by 0.9 every this many steps
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0.0,
    amsgrad=True,
)
# The training loop calls `lr_scheduler.step()` after every batch. ExponentialLR with
# gamma=0.9 stepped per batch shrinks the LR by 10% per batch (1e-3 -> ~2.7e-8 after
# 100 batches), which effectively halts training. Use the step-based warm-up/decay
# schedule from `get_lr_lambda`, which is meant to be stepped per batch. The values are
# bound as lambda defaults so reassigning the globals later cannot change the schedule.
lr_scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step, _w=warm_up_steps, _r=reduce_lr_steps: get_lr_lambda(
        step, warm_up_steps=_w, reduce_lr_steps=_r
    ),
)

In [None]:
# Train for `n_steps` optimizer steps (one per batch), cycling over `dl` as often as
# needed. Relies on `model`, `dl`, `loss_fn`, `optimizer`, `lr_scheduler` and `n_steps`
# from the cells above.
device = config.device
# `Module.to` modifies the module in place, so `optimizer` (built from
# `model.parameters()`) keeps referring to the moved parameters.
model.to(device)
loss_fn.to(device)  # metric state tensors must live on the same device as the inputs
metric = {
    "sdr": SignalDistortionRatio().to(device),
    "snr": SignalNoiseRatio().to(device),
    "sisdr": ScaleInvariantSignalDistortionRatio().to(device),
}

if len(dl) == 0:
    # Without this guard the `while` loop below would never terminate.
    raise ValueError("DataLoader `dl` yields no batches; cannot train.")

losses = []        # training loss per step
train_scores = []  # training-batch SDR per step (no placeholder 0 to bias the mean)
global_step = 0
model.train()
progress_bar = tqdm(total=n_steps)
try:
    while global_step < n_steps:
        for batch in dl:
            if global_step >= n_steps:
                break
            input_wave = batch['input'].to(device)
            # Stack target stems along the channel axis in `config.dataset.targets`
            # order, on the same device as the model output.
            target_waves = torch.cat(
                [batch['output'][name] for name in config.dataset.targets], dim=1
            ).to(device)

            optimizer.zero_grad()
            outputs = model(input_wave)['waveform']
            # torchmetrics convention is (preds, target).
            loss = loss_fn(outputs, target_waves)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1.0)
            optimizer.step()
            lr_scheduler.step()

            losses.append(loss.item())
            # SDR is not symmetric: prediction first, reference second. Detach so the
            # metric does not extend the autograd graph.
            train_scores.append(metric['sdr'](outputs.detach(), target_waves).item())

            global_step += 1
            progress_bar.set_description(f"Step is {global_step}")
            progress_bar.set_postfix(loss=f"{loss.item()}", Training=f"{np.mean(train_scores)}")
            progress_bar.update(1)
finally:
    progress_bar.close()

# TODO: add a periodic validation pass (model.eval() + torch.no_grad(), scored with
# `metric`) on a held-out loader. The previous commented-out draft came from a
# classification script (logits / argmax / f1_score) and did not apply to this model.

Step is 0:   0%|          | 0/5000 [00:41<?, ?it/s]
Step is 0:   0%|          | 0/5000 [00:00<?, ?it/s]

: 

: 