In [2]:
import sys
sys.path.append('../src')

from hydra import compose, initialize

import pytorch_lightning as pl
import torch
import torch.nn as nn

from train import initialize_loaders, initialize_model, initialize_featurizer

from tqdm import tqdm

In [3]:
# Compose the project's Hydra configuration from ../src/conf/config; every
# later cell reads its settings from ``cfg``.
with initialize(
    config_path="../src/conf",
    job_name="test_app",
    version_base=None,
):
    cfg = compose(config_name="config")
# To inspect the composed config (needs OmegaConf imported first):
#     print(OmegaConf.to_yaml(cfg))

In [4]:
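# NOTE(review): these three lines are the only place in this notebook where
# `train_loader`, `featurizer`, `inverse_featurizer` and `model` are assigned,
# yet the PLModel / Trainer cells further down use all of them. Uncomment and
# run this cell before the training cells, otherwise they fail with NameError
# on a fresh kernel.
# train_loader, val_loader = initialize_loaders(cfg)
# featurizer, inverse_featurizer = initialize_featurizer(cfg)
# model, opt, sch = initialize_model(cfg)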

In [7]:
from data import SAD
import musdb
import torch

from pathlib import Path
from typing import Iterable

In [8]:
def prepare_save_line(
    track_name: str,
    start_indices: torch.Tensor,
    window_size: int
) -> Iterable[str]:
    """Lazily produce one tab-separated index line per salient window.

    Each yielded line has the form ``"<track_name>\\t<start>\\t<end>\\n"``
    where ``end`` is ``start + window_size``.

    Args:
        track_name: Name written in the first column.
        start_indices: Iterable of window start positions.
        window_size: Offset added to every start position to get the end.

    Yields:
        The formatted line for each start position, in iteration order.
    """
    for start in start_indices:
        stop = start + window_size
        yield f"{track_name}\t{start}\t{stop}\n"
        
        
def save_to_file(
    file_path: str,
    target: str,
    sad: SAD,
    tracks=None,
):
    """Write the salient-window index of every track to a text file.

    For each track the audio of ``target`` is converted to a float32 tensor,
    passed to ``sad.calculate_salient_indices`` and the resulting start
    indices are written as ``"<track name>\\t<start>\\t<end>\\n"`` lines.

    Args:
        file_path: Destination file (string or path-like); it is overwritten.
        target: Key into ``track.targets`` selecting the stem to analyse.
        sad: Detector providing ``calculate_salient_indices`` and
            ``window_size``.
        tracks: Iterable of tracks to process. When omitted, the
            notebook-level ``db`` is used, which keeps existing three-argument
            calls working while allowing the dependency to be passed in
            explicitly.
    """
    if tracks is None:
        # Backward-compatible fallback to the database built in a later cell.
        tracks = db

    # Explicit encoding: track names are written verbatim and must not depend
    # on the platform's default locale.
    with open(file_path, 'w', encoding='utf-8') as wf:
        for track in tqdm(tracks):
            # NOTE(review): the transpose presumably turns (samples, channels)
            # into (channels, samples) - confirm against SAD's expected input.
            y = torch.tensor(
                track.targets[target].audio.T, dtype=torch.float32
            )
            indices = sad.calculate_salient_indices(y)
            wf.writelines(
                prepare_save_line(track.name, indices, sad.window_size)
            )

In [15]:
# Locations: MUSDB18-HQ root and the folder receiving the index files.
db_dir = '../../../datasets/musdb18hq'
directory = Path('../src/files/')

# Stems for which an index file is generated.
targets = ['vocals']

# Partition to process: subset is 'train' or 'test'; split is 'train' or 'valid'.
subset, split = 'train', 'train'

In [16]:
# Open the local WAV copy of the dataset for the selected subset/split;
# nothing is downloaded.
db = musdb.DB(root=db_dir, subsets=subset, split=split, is_wav=True, download=False)

# Detector configured from the ``sad`` section of the Hydra config.
sad = SAD(**cfg.sad)

In [18]:
# Resolve the file-name suffix for the selected partition once. It does not
# depend on the target, and an unsupported subset/split pair now fails loudly
# instead of leaving ``file_path`` unbound (or stale from an earlier run).
if subset == split == 'train':
    partition = 'train'
elif subset == 'train' and split == 'valid':
    partition = 'valid'
elif subset == 'test':
    partition = 'test'
else:
    raise ValueError(
        f"Unsupported subset/split combination: subset={subset!r}, split={split!r}"
    )

# One index file per target, e.g. ``vocals_train.txt``.
for target in targets:
    file_path = directory / f"{target}_{partition}.txt"
    save_to_file(file_path, target, sad)

 30%|███       | 26/86 [00:17<00:40,  1.49it/s]


KeyboardInterrupt: 

In [4]:
class PLModel(pl.LightningModule):
    """Lightning wrapper that trains ``model`` on featurized batches.

    Every entry of a batch is passed through ``featurizer`` after the batch
    has been transferred to the device. The training objective is the sum of
    three L1 terms: real part, imaginary part, and the signal obtained by
    applying ``inverse_featurizer`` to prediction and target.
    """

    def __init__(
        self,
        model: nn.Module,
        featurizer: nn.Module,
        inverse_featurizer: nn.Module,
    ):
        """Store the sub-modules and create the three L1 criteria.

        Args:
            model: Network applied to the featurized mixture.
            featurizer: Transform applied to every batch entry.
            inverse_featurizer: Transform applied to prediction and target
                before the third loss term.
        """
        super().__init__()

        # Sub-modules are registered in this order: transforms, network, criteria.
        self.featurizer = featurizer
        self.inverse_featurizer = inverse_featurizer
        self.model = model

        # One criterion per loss term.
        self.mae_specR = nn.L1Loss()
        self.mae_specI = nn.L1Loss()
        self.mae_time = nn.L1Loss()

    def on_after_batch_transfer(
        self, batch, dataloader_idx
    ):
        """Replace every batch entry, in place, by its featurized version.

        Args:
            batch: Mapping whose values are fed to ``self.featurizer``.
            dataloader_idx: Unused.

        Returns:
            The same ``batch`` object with its values replaced.
        """
        for key in batch:
            raw = batch[key]
            batch[key] = self.featurizer(raw)
        return batch

    def training_step(
        self, batch, batch_idx
    ):
        """Run the model on ``batch['mix']`` and score it against ``batch['tgt']``.

        Args:
            batch: Featurized batch holding the keys ``'mix'`` and ``'tgt'``.
            batch_idx: Unused.

        Returns:
            The value computed by :meth:`loss`.
        """
        mixture = batch['mix']
        target = batch['tgt']
        estimate = self.model(mixture)
        return self.loss(estimate, target)

    def loss(self, mix, tgt):
        """Sum the three L1 terms between prediction ``mix`` and target ``tgt``.

        Args:
            mix: Model output; its ``.real`` and ``.imag`` parts are compared.
            tgt: Featurized target, handled the same way.

        Returns:
            ``L1(real) + L1(imag) + L1(inverse_featurizer outputs)``.
        """
        # Terms computed directly on the featurized representation.
        real_term = self.mae_specR(mix.real, tgt.real)
        imag_term = self.mae_specI(mix.imag, tgt.imag)

        # Term computed after mapping both sides through the inverse transform.
        wave_est = self.inverse_featurizer(mix)
        wave_ref = self.inverse_featurizer(tgt)
        time_term = self.mae_time(wave_est, wave_ref)

        return real_term + imag_term + time_term

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``lr=1e-3``."""
        trainable = self.parameters()
        return torch.optim.Adam(trainable, lr=1e-3)

In [5]:
# Wrap the network and both transforms in the Lightning module.
plmodel = PLModel(
    model=model,
    featurizer=featurizer,
    inverse_featurizer=inverse_featurizer,
)

# Single-device CPU trainer in fast_dev_run mode (smoke test only).
trainer = pl.Trainer(fast_dev_run=True, accelerator="cpu", devices=1)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


In [6]:
# Fit on the training loader only; no validation loader is passed.
trainer.fit(model=plmodel, train_dataloaders=train_loader)


  | Name               | Type               | Params
----------------------------------------------------------
0 | featurizer         | Spectrogram        | 0     
1 | inverse_featurizer | InverseSpectrogram | 0     
2 | model              | BandSplitRNN       | 11.7 M
3 | mae_specR          | L1Loss             | 0     
4 | mae_specI          | L1Loss             | 0     
5 | mae_time           | L1Loss             | 0     
----------------------------------------------------------
11.7 M    Trainable params
0         Non-trainable params
11.7 M    Total params
46.687    Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/amanturamatov/opt/anaconda3/envs/SourceSeparationBandSplitRNN/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/amanturamatov/opt/anaconda3/envs/SourceSeparationBandSplitRNN/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/amanturamatov/opt/anaconda3/envs/SourceSeparationBandSplitRNN/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/amanturamatov/PythonProjects/projects/SourceSeparationBandSplitRNN/notebooks/../src/data/dataset.py", line 67, in __getitem__
    mix, tgt = self.prepare_fragments(
  File "/Users/amanturamatov/PythonProjects/projects/SourceSeparationBandSplitRNN/notebooks/../src/data/dataset.py", line 54, in prepare_fragments
    tgt_frags, mask = self.sad(tgt_audio)
  File "/Users/amanturamatov/PythonProjects/projects/SourceSeparationBandSplitRNN/notebooks/../src/data/preprocessing.py", line 95, in __call__
    y_salient = self.calculate_salient(y, segment_saliency_mask)
  File "/Users/amanturamatov/PythonProjects/projects/SourceSeparationBandSplitRNN/notebooks/../src/data/preprocessing.py", line 73, in calculate_salient
    y = y[:, mask, ...].view(C, D1, D2*D3)
RuntimeError: shape '[2, 78, 264600]' is invalid for input of size 31752000
