In [1]:
import torch
import copy
import numpy as np
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
from utils.prune_utils.concern_identification import ConcernIdentification
from utils.prune_utils.weight_remover import WeightRemover
from utils.dataset import load_data
from utils.train import valid, load_not_compatible_weights, get_model_from_config
from utils.config_utils import load_config

In [2]:
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.multiprocessing.set_start_method('spawn')

In [3]:
print("Starting the pruning script")
device = 'cuda:0'
cfg = load_config('train_config.yaml')
cfg.model_type = 'mdx23c'
cfg.config_path = 'Configs/mdx23c_config.yaml'
cfg.results_path = 'Results/mdx23c'
cfg.data_path = 'Datasets/musdb18hq/train'
cfg.num_workers = 0
cfg.valid_path = 'Datasets/musdb18hq/valid'
cfg.seed = 44
cfg.start_check_point = 'Results/model_mdx23c(before prune).ckpt'

Starting the pruning script


In [4]:
print("Loading model and configuration")
mdx23c, mdx23c_config = get_model_from_config(cfg.model_type, cfg.config_path)
load_not_compatible_weights(mdx23c, cfg.start_check_point, verbose=False)
mdx23c = mdx23c.to(device)

Loading model and configuration


In [5]:
print("Loading data")
train_dataloader, valid_dataloader, test_dataloader = load_data(mdx23c_config, cfg, 2)

Loading data
Use augmentation for training
Loading songs data from cache: Results/metadata_train.pkl. If you updated dataset remove metadata_train.pkl before training!
Found tracks in dataset: 80
Use augmentation for training
Loading songs data from cache: Results/metadata_valid.pkl. If you updated dataset remove metadata_valid.pkl before training!
Found tracks in dataset: 20
Use augmentation for training
Loading songs data from cache: Results/metadata_test.pkl. If you updated dataset remove metadata_test.pkl before training!
Found tracks in dataset: 50


In [6]:
ref_mdx23c = copy.deepcopy(mdx23c)
weight_remover = WeightRemover(mdx23c, device, 0.9)

In [7]:
print("Start pruning")
for idx in range(10):
    tr = torch.tensor(np.random.rand(2, 2, 261120), dtype=torch.float32).to(device)
    with torch.no_grad():
        y_ = weight_remover.process(tr)
    weight_remover.apply_removal()
valid(mdx23c, cfg, mdx23c_config, device)

Start pruning
before 262144
after 235270
before 262144
after 236813
before 262144
after 235857
before 262144
after 236601
before 65536
after 58728
before 65536
after 59277
before 65536
after 58787
before 65536
after 59124
before 16384
after 14617
before 16384
after 14814
before 16384
after 14385
before 16384
after 14797
before 4096
after 3682
before 4096
after 3729
before 4096
after 3673
before 4096
after 3730
before 1024
after 940
before 1024
after 937
before 1024
after 932
before 1024
after 953
before 256
after 235
before 256
after 235
before 256
after 241
before 256
after 239
before 1024
after 937
before 1024
after 947
before 1024
after 932
before 1024
after 943
before 4096
after 3696
before 4096
after 3754
before 4096
after 3669
before 4096
after 3707
before 16384
after 14545
before 16384
after 14768
before 16384
after 14603
before 16384
after 14819
before 65536
after 57795
before 65536
after 58814
before 65536
after 57423
before 65536
after 58749
before 262144
after 231605
before 

100%|██████████| 20/20 [05:57<00:00, 17.90s/it, sdr_vocals=8.2, sdr_other=19.2] 

Instr SDR vocals: 11.3220
Instr SDR other: 17.7745
SDR Avg: 14.5483





14.548261772508022

In [8]:
print("Concern Identification")
ci = ConcernIdentification(ref_mdx23c, mdx23c, device, 0.7)
_, temp_config = get_model_from_config(cfg.model_type, cfg.config_path)
temp_config.training.instruments = ["vocals"]
temp_dataloader, _, _ = load_data(temp_config, cfg, 2)

Concern Identification
Use augmentation for training
Loading songs data from cache: Results/metadata_train.pkl. If you updated dataset remove metadata_train.pkl before training!
Found tracks in dataset: 80
Use augmentation for training
Loading songs data from cache: Results/metadata_valid.pkl. If you updated dataset remove metadata_valid.pkl before training!
Found tracks in dataset: 20
Use augmentation for training
Loading songs data from cache: Results/metadata_test.pkl. If you updated dataset remove metadata_test.pkl before training!
Found tracks in dataset: 50


In [9]:
for i, (batch, mixes) in enumerate(temp_dataloader):
    y = batch.to(device)
    x = mixes.to(device)
    with torch.no_grad():
        ci.process(x)
    ci.apply_prune()
    if i > 100:
        break

before 235270
after 213279
before 235709
after 214265
before 235857
after 212974
before 235035
after 213565
before 58728
after 53580
before 58837
after 53571
before 58787
after 53555
before 58890
after 53621
before 14617
after 13260
before 14751
after 13510
before 14385
after 12830
before 14643
after 13423
before 3682
after 3378
before 3710
after 3427
before 3673
after 3343
before 3699
after 3457
before 939
after 890
before 925
after 880
before 922
after 821
before 949
after 873
before 233
after 221
before 234
after 226
before 240
after 230
before 238
after 228
before 920
after 859
before 942
after 900
before 920
after 877
before 937
after 903
before 3670
after 3356
before 3731
after 3495
before 3669
after 3227
before 3639
after 3519
before 14545
after 13310
before 14611
after 13339
before 14603
after 13118
before 14683
after 13416
before 57795
after 51754
before 58814
after 53003
before 57423
after 51500
before 58749
after 52165
before 231605
after 206693
before 232939
after 208791
be

In [10]:
store_path = os.path.join(cfg.results_path, f'model_mdx23c(after prune).ckpt')
state_dict = mdx23c.state_dict()
torch.save(state_dict, store_path)
print("Pruning script finished successfully")

Pruning script finished successfully
