In [1]:
import torch
import copy
import numpy as np
from tqdm import tqdm
import os
from utils.prune_utils.concern_identification import ConcernIdentification
from utils.prune_utils.weight_remover import WeightRemover
from utils.dataset import download_musdb, load_data
from utils.train import train_model, valid, load_not_compatible_weights, get_model_from_config
from utils.config_utils import load_config

In [2]:
# Global backend / multiprocessing setup, followed by fetching the MUSDB data.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False  # Fix possible slow down with dilation convolutions
# force=True keeps this cell re-runnable: without it a second execution in the same
# kernel raises "RuntimeError: context has already been set".
torch.multiprocessing.set_start_method('spawn', force=True)
train_data, test_data = download_musdb()

In [3]:
# Device on which the model and all tensors in this notebook are placed.
device = 'cuda:0'

# Start from the training config on disk, then apply the notebook-specific overrides
# below in the listed order.
cfg = load_config('train_config.yaml')
_cfg_overrides = {
    'model_type': 'mdx23c',
    'config_path': 'Configs/mdx23c_config.yaml',
    'results_path': 'Results/',
    'data_path': 'Datasets/musdb18hq/train',
    'num_workers': 0,
    'valid_path': 'Datasets/musdb18hq/valid',
    'seed': 44,
    # Checkpoint loaded into the model before any pruning is applied.
    'start_check_point': 'Results/model_mdx23c(before prune).ckpt',
}
for _field, _setting in _cfg_overrides.items():
    setattr(cfg, _field, _setting)

In [4]:
# Build the model and its config, load the pre-pruning checkpoint, and move the model
# to the target device (once -- a second .to(device) on the same model is a no-op).
mdx23c, mdx23c_config = get_model_from_config(cfg.model_type, cfg.config_path)
load_not_compatible_weights(mdx23c, cfg.start_check_point, verbose=False)
mdx23c = mdx23c.to(device)

# Third argument is passed through to load_data unchanged
# (presumably the batch size -- TODO confirm against utils.dataset.load_data).
train_dataloader, valid_dataloader, test_dataloader = load_data(mdx23c_config, cfg, 2)

# Validation run on the loaded checkpoint, before anything is pruned.
valid(mdx23c, cfg, mdx23c_config, device)

# Independent copy of the model taken before WeightRemover is attached; it is handed
# to ConcernIdentification later as the reference model.
ref_mdx23c = copy.deepcopy(mdx23c)
# NOTE(review): 0.8 is an opaque WeightRemover argument here -- presumably a pruning
# threshold/ratio; confirm against utils.prune_utils.weight_remover.
weight_remover = WeightRemover(mdx23c, device, 0.8)

Use augmentation for training
Loading songs data from cache: Results/metadata_train.pkl. If you updated dataset remove metadata_train.pkl before training!
Found tracks in dataset: 80
Use augmentation for training
Loading songs data from cache: Results/metadata_valid.pkl. If you updated dataset remove metadata_valid.pkl before training!
Found tracks in dataset: 20
Use augmentation for training
Loading songs data from cache: Results/metadata_test.pkl. If you updated dataset remove metadata_test.pkl before training!
Found tracks in dataset: 50


100%|██████████| 20/20 [04:57<00:00, 14.87s/it, sdr_vocals=8.22, sdr_other=19.2]

Instr SDR vocals: 11.3885
Instr SDR other: 17.8419
SDR Avg: 14.6152





In [5]:
print("start pruning")

# Number of process/apply_removal rounds and the shape of each random probe input.
N_PRUNE_ROUNDS = 10
PROBE_SHAPE = (2, 2, 261120)  # presumably (batch, channels, samples) -- TODO confirm

# Seeded generator: the probe inputs drive the removal, so drawing them from the
# unseeded global NumPy RNG made the pruned model differ between kernel restarts.
probe_rng = np.random.default_rng(cfg.seed)

for _round in range(N_PRUNE_ROUNDS):
    # Uniform [0, 1) noise, created as float32 directly on the target device.
    tr = torch.tensor(probe_rng.random(PROBE_SHAPE), dtype=torch.float32, device=device)
    with torch.no_grad():
        # The return value is not needed; process() is called for its effect on
        # weight_remover, which apply_removal() then acts on.
        weight_remover.process(tr)
    weight_remover.apply_removal()

# Re-validate after weight removal; left as the last expression so the returned
# score is displayed as the cell output.
valid(mdx23c, cfg, mdx23c_config, device)

start pruning
before 262144
after 235270
before 262144
after 236813
before 262144
after 235857
before 262144
after 236601
before 65536
after 58728
before 65536
after 59277
before 65536
after 58787
before 65536
after 59124
before 16384
after 14617
before 16384
after 14814
before 16384
after 14385
before 16384
after 14797
before 4096
after 3682
before 4096
after 3729
before 4096
after 3673
before 4096
after 3730
before 1024
after 940
before 1024
after 937
before 1024
after 932
before 1024
after 953
before 256
after 235
before 256
after 235
before 256
after 241
before 256
after 239
before 1024
after 937
before 1024
after 947
before 1024
after 932
before 1024
after 943
before 4096
after 3696
before 4096
after 3754
before 4096
after 3669
before 4096
after 3707
before 16384
after 14545
before 16384
after 14768
before 16384
after 14603
before 16384
after 14819
before 65536
after 57795
before 65536
after 58814
before 65536
after 57423
before 65536
after 58749
before 262144
after 231605
before 

100%|██████████| 20/20 [05:08<00:00, 15.43s/it, sdr_vocals=8.19, sdr_other=19.1]

Instr SDR vocals: 11.3138
Instr SDR other: 17.7662
SDR Avg: 14.5400





14.5400019747205

In [6]:
# Set up concern identification: pair the untouched reference copy with the model
# being pruned, and build a training dataloader whose config lists only "vocals".
print("CI")
# NOTE(review): 0.4 is an opaque ConcernIdentification argument here -- presumably a
# pruning threshold/ratio; confirm against utils.prune_utils.concern_identification.
ci = ConcernIdentification(ref_mdx23c, mdx23c, device, 0.4)
# Only the config half of the returned pair is kept; the first element is discarded.
# A fresh config is requested so the instruments override below is applied to
# temp_config (presumably a separate object from mdx23c_config -- TODO confirm).
_, temp_config = get_model_from_config(cfg.model_type, cfg.config_path)
temp_config.training.instruments = ["vocals"]
# Keep only the first (training) dataloader; the other two returned loaders are unused.
temp_dataloader, _, _ = load_data(temp_config, cfg, 2)

CI
Use augmentation for training
Loading songs data from cache: Results/metadata_train.pkl. If you updated dataset remove metadata_train.pkl before training!
Found tracks in dataset: 80
Use augmentation for training
Loading songs data from cache: Results/metadata_valid.pkl. If you updated dataset remove metadata_valid.pkl before training!
Found tracks in dataset: 20
Use augmentation for training
Loading songs data from cache: Results/metadata_test.pkl. If you updated dataset remove metadata_test.pkl before training!
Found tracks in dataset: 50


In [7]:
# Run concern identification on a bounded number of training batches: each batch's
# mixture is fed to ci.process() without gradients, then ci.apply_prune() is called.
#
# NOTE(review): the recorded run of this cell aborted on its first batch with
# "RuntimeError: expand(CUDABoolType{[256, 1, 256]}, size=[-1, 1024])". No expand call
# exists in this notebook, so the fix presumably belongs in ConcernIdentification
# (a 3-D mask expanded with a 2-D size) -- TODO confirm and fix there.

# Same cap as the previous `if i > 100: break` check, i.e. batches 0..101.
MAX_CI_BATCHES = 102

# Context manager so the progress bar is closed on early exit or on an exception.
with tqdm(temp_dataloader) as pbar:
    # Only the mixtures are consumed; the per-batch targets are no longer copied to
    # the device because nothing in the loop reads them.
    for i, (_targets, mixes) in enumerate(pbar):
        x = mixes.to(device)
        with torch.no_grad():
            ci.process(x)
        ci.apply_prune()
        # Checked after processing so no extra batch is fetched from the dataloader.
        if i + 1 >= MAX_CI_BATCHES:
            break

  0%|          | 0/1000 [00:04<?, ?it/s]

before 234226





RuntimeError: expand(CUDABoolType{[256, 1, 256]}, size=[-1, 1024]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)

In [None]:
# Persist the pruned model's weights next to the other results.
# Plain string literal: the name has no placeholders, so the f-prefix was unnecessary.
store_path = os.path.join(cfg.results_path, 'model_mdx23c(after prune).ckpt')
state_dict = mdx23c.state_dict()
torch.save(state_dict, store_path)