In [1]:
import torch
import copy
import numpy as np
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
from utils.prune_utils.concern_identification import ConcernIdentification
from utils.prune_utils.weight_remover import WeightRemover
from utils.dataset import load_data
from utils.train import valid, load_not_compatible_weights, get_model_from_config
from utils.config_utils import load_config

In [2]:
# cuDNN settings: let cuDNN benchmark convolution algorithms and pick the
# fastest one, and do not restrict it to deterministic algorithms.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
# force=True keeps this cell re-runnable: without it, a second execution in the
# same kernel raises "RuntimeError: context has already been set".
torch.multiprocessing.set_start_method('spawn', force=True)

In [3]:
# Run configuration: pick the compute device, load the base training config
# and override the fields this pruning run needs.
print("Starting the pruning script")
# Use the first CUDA GPU when one is present; otherwise fall back to the CPU so
# the notebook still runs (much more slowly) instead of failing on `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
cfg = load_config('train_config.yaml')
# Model selection and its architecture config file.
cfg.model_type = 'mdx23c'
cfg.config_path = 'Configs/mdx23c_config.yaml'
# Output directory and dataset locations (relative to the working directory).
cfg.results_path = 'Results/mdx23c'
cfg.data_path = 'Datasets/musdb18hq/train'
cfg.num_workers = 0
cfg.valid_path = 'Datasets/musdb18hq/valid'
cfg.seed = 44
# Checkpoint of the unpruned model that is loaded in the next cell.
cfg.start_check_point = 'Results/model_mdx23c(before prune).ckpt'

Starting the pruning script


In [4]:
# Build the model and its config object from the YAML named in cfg, load the
# pre-pruning checkpoint into the model, then move the model to `device`.
print("Loading model and configuration")
mdx23c, mdx23c_config = get_model_from_config(cfg.model_type, cfg.config_path)
# NOTE(review): the helper name suggests it tolerates checkpoints whose keys or
# shapes do not fully match the model — confirm in utils.train.
load_not_compatible_weights(mdx23c, cfg.start_check_point, verbose=False)
mdx23c = mdx23c.to(device)

Loading model and configuration


In [7]:
# Baseline: evaluate the model before any pruning. The call is the last
# expression of the cell, so its return value (the average SDR in the recorded
# output) is displayed.
valid(mdx23c, cfg, mdx23c_config, device)

100%|██████████| 20/20 [05:36<00:00, 16.81s/it, sdr_vocals=8.22, sdr_other=19.2]

Instr SDR vocals: 11.3885
Instr SDR other: 17.8419
SDR Avg: 14.6152





14.615231353169484

In [5]:
# Create the train / valid / test data loaders from the model config and cfg.
print("Loading data")
# NOTE(review): the third argument (2) is presumably the batch size — confirm
# against utils.dataset.load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(mdx23c_config, cfg, 2)

Loading data
Use augmentation for training
Loading songs data from cache: Results/mdx23c/metadata_train.pkl. If you updated dataset remove metadata_train.pkl before training!
Found tracks in dataset: 80
Use augmentation for training
Loading songs data from cache: Results/mdx23c/metadata_valid.pkl. If you updated dataset remove metadata_valid.pkl before training!
Found tracks in dataset: 20
Use augmentation for training
Loading songs data from cache: Results/mdx23c/metadata_test.pkl. If you updated dataset remove metadata_test.pkl before training!
Found tracks in dataset: 50


In [6]:
# Keep an independent deep copy of the model taken before WeightRemover is
# constructed on it; the copy is handed to ConcernIdentification later as the
# reference model. The copy must stay ahead of the WeightRemover line.
ref_mdx23c = copy.deepcopy(mdx23c)
# NOTE(review): 0.8 is an undocumented numeric setting of WeightRemover
# (threshold or ratio?) — confirm in utils.prune_utils.weight_remover.
weight_remover = WeightRemover(mdx23c, device, 0.8)

In [8]:
# Weight-removal pass: feed random probe batches through the weight remover,
# apply the removal after each one, then re-validate the pruned model.
print("Start pruning")
PRUNE_ITERATIONS = 10
# NOTE(review): presumably (batch, channels, samples) — confirm against the
# input the model expects.
PROBE_SHAPE = (2, 2, 261120)
# Dedicated generator seeded from cfg.seed, so the probe batches (and therefore
# the pruning result) are the same every time this cell is run.
probe_rng = torch.Generator(device=device).manual_seed(cfg.seed)
for _prune_step in range(PRUNE_ITERATIONS):
    # Uniform [0, 1) noise created directly on the target device as float32;
    # avoids building a float64 NumPy array on the host and copying it over.
    tr = torch.rand(PROBE_SHAPE, generator=probe_rng, dtype=torch.float32, device=device)
    with torch.no_grad():
        # The return value is not needed, so it is not kept alive.
        weight_remover.process(tr)
    weight_remover.apply_removal()
# Last expression of the cell: its return value is displayed.
valid(mdx23c, cfg, mdx23c_config, device)

Start pruning
before 262144
after 235270
before 262144
after 236813
before 262144
after 235857
before 262144
after 236601
before 65536
after 58728
before 65536
after 59277
before 65536
after 58787
before 65536
after 59124
before 16384
after 14617
before 16384
after 14814
before 16384
after 14385
before 16384
after 14797
before 4096
after 3682
before 4096
after 3729
before 4096
after 3673
before 4096
after 3730
before 1024
after 940
before 1024
after 937
before 1024
after 932
before 1024
after 953
before 256
after 235
before 256
after 235
before 256
after 241
before 256
after 239
before 1024
after 937
before 1024
after 947
before 1024
after 932
before 1024
after 943
before 4096
after 3696
before 4096
after 3754
before 4096
after 3669
before 4096
after 3707
before 16384
after 14545
before 16384
after 14768
before 16384
after 14603
before 16384
after 14819
before 65536
after 57795
before 65536
after 58814
before 65536
after 57423
before 65536
after 58749
before 262144
after 231605
before 

100%|██████████| 20/20 [05:47<00:00, 17.36s/it, sdr_vocals=8.19, sdr_other=19.1]

Instr SDR vocals: 11.3138
Instr SDR other: 17.7663
SDR Avg: 14.5400





14.540026382374885

In [9]:
# Prepare concern identification: pair the reference copy with the model that
# the weight remover has been applied to, and build a vocals-only data loader.
print("Concern Identification")
# NOTE(review): 0.6 is an undocumented numeric setting of ConcernIdentification
# — confirm its meaning in utils.prune_utils.concern_identification.
ci = ConcernIdentification(ref_mdx23c, mdx23c, device, 0.6)
# Re-read the config into a separate object (the model built alongside it is
# discarded) so the instrument list can be overridden without touching
# mdx23c_config.
_, temp_config = get_model_from_config(cfg.model_type, cfg.config_path)
temp_config.training.instruments = ["vocals"]
# Only the first of the three loaders returned by load_data is kept.
temp_dataloader, _, _ = load_data(temp_config, cfg, 2)

Concern Identification
Use augmentation for training
Loading songs data from cache: Results/mdx23c/metadata_train.pkl. If you updated dataset remove metadata_train.pkl before training!
Found tracks in dataset: 80
Use augmentation for training
Loading songs data from cache: Results/mdx23c/metadata_valid.pkl. If you updated dataset remove metadata_valid.pkl before training!
Found tracks in dataset: 20
Use augmentation for training
Loading songs data from cache: Results/mdx23c/metadata_test.pkl. If you updated dataset remove metadata_test.pkl before training!
Found tracks in dataset: 50


In [10]:
# Concern-identification pass: run mixtures from the vocals-only loader through
# `ci` and prune after every batch.
# The limit is 102 because the original check (`i > 100`, evaluated after the
# batch was processed) let batches 0..101 through; the count is kept unchanged.
MAX_CONCERN_BATCHES = 102
for i, (batch, mixes) in enumerate(temp_dataloader):
    # Only the mixture is consumed here; `batch` is not used, so it is no
    # longer copied to the device.
    x = mixes.to(device)
    with torch.no_grad():
        ci.process(x)
    ci.apply_prune()
    # Checked after processing so that no extra batch is fetched from the loader.
    if i + 1 >= MAX_CONCERN_BATCHES:
        break

before 234226
after 212700
before 235611
after 214170
before 234904
after 212149
before 234891
after 213456
before 58419
after 53311
before 58776
after 53477
before 58583
after 53393
before 58888
after 53619
before 14465
after 13123
before 14751
after 13510
before 14297
after 12762
before 14639
after 13422
before 3645
after 3337
before 3710
after 3427
before 3658
after 3330
before 3699
after 3457
before 939
after 890
before 925
after 880
before 922
after 821
before 949
after 873
before 233
after 221
before 234
after 226
before 240
after 230
before 238
after 228
before 918
after 857
before 942
after 900
before 916
after 873
before 937
after 903
before 3667
after 3353
before 3731
after 3495
before 3608
after 3222
before 3626
after 3506
before 14463
after 13245
before 14601
after 13331
before 14556
after 13083
before 14664
after 13419
before 57555
after 51587
before 58512
after 52783
before 56945
after 51141
before 58428
after 51995
before 230564
after 205899
before 231059
after 207013
be

In [12]:
# Evaluate the model again after the concern-identification pruning pass. The
# call is the last expression of the cell, so its return value (the average SDR
# in the recorded output) is displayed.
valid(mdx23c, cfg, mdx23c_config, device)

100%|██████████| 20/20 [06:12<00:00, 18.64s/it, sdr_vocals=3.07, sdr_other=14]   

Instr SDR vocals: 6.0725
Instr SDR other: 12.4738
SDR Avg: 9.2731





9.273112375028784

In [11]:
# Persist the pruned model's weights under cfg.results_path.
# Create the output directory first so torch.save does not fail when the
# results folder has not been created yet.
os.makedirs(cfg.results_path, exist_ok=True)
# Plain string literal: the original f-string had no placeholders.
store_path = os.path.join(cfg.results_path, 'model_mdx23c(after prune).ckpt')
state_dict = mdx23c.state_dict()
torch.save(state_dict, store_path)
print("Pruning script finished successfully")

Pruning script finished successfully
