In [3]:
from data.processing import ParseBalanced

# Inputs for the ParseBalanced call below. The same three names are assigned
# again, with the same values, in the configuration cell.
subset_file_name = "autotagging_top50tags"
data_directory = "E:/mtg-jamendo/"
directory = "large-melspec-dataset-top-50-LIBROSA-256-Triplet"

# Presumably converts the raw audio under `data_directory` into the chunked
# dataset stored under E:/SongsDataset/<directory> -- confirm in
# data.processing. The saved run of this cell ended in a KeyboardInterrupt.
ParseBalanced(
    subset_file_name,
    f"{data_directory}",
    f"E:/SongsDataset/{directory}",
    convert=True,
    target_per_genre=1300,
    chunk_size=256,
    chunks_per_batch=1,
    write_individually=True,
)

KeyboardInterrupt: 

In [1]:
from torchaudio.transforms import TimeMasking, FrequencyMasking
#from loss.ConstrastiveLoss import InfoCNELoss
from info_nce import InfoNCE
from data.data_utils import *

# Dataset identifiers shared by the cells that follow.
subset_file_name = "autotagging_top50tags"
data_directory = "E:/mtg-jamendo/"
directory = "large-melspec-dataset-top-50-LIBROSA-256-Triplet"

# Augmentation pipeline: additive Gaussian noise followed by time and
# frequency masking. The mask sizes are 5% of 256 and of 128 -- presumably the
# chunk length in frames and the number of mel bins; confirm against the data.
# Nothing else in the visible cells references `augmentations`.
augmentations = Compose(
    [
        AddGaussianNoise(std=0.25),
        TimeMasking(time_mask_param=int(0.05 * 256)),
        FrequencyMasking(freq_mask_param=int(0.05 * 128)),
    ]
)

class Config:
    """Run settings handed to the training functions as a plain namespace.

    Every value is a class attribute evaluated once, when this cell runs. The
    class itself (not an instance) is what the ``train_*`` calls receive.
    """

    # === General ===
    model_name = "ViT-Contrastive-Custom-Slope"
    # Falls back to CPU when no CUDA device is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float32
    # Windows-style relative directory, one sub-folder per model name.
    save_path = f"trained_models\\{model_name}\\"
    seed = 42

    # === Training ===
    num_classes = 50
    num_epochs = 10
    batch_size = 64
    max_batch_size = 64
    learning_rate = 1e-4
    # NOTE(review): identical to learning_rate; if a scheduler decays towards
    # this value the schedule is flat -- confirm that is intended.
    min_learning_rate = 1e-4
    weight_decay = 1e-4

    # Presumably fractions of the total training steps used by the LR
    # schedule -- confirm in the training code.
    warmup_threshold = 1.0 / 100.0
    step_coefficient = 25.0 / 100.0

    gamma = 2.0
    save_checkpoints = True

    # === Dataset ===
    transforms = None
    use_masks = True
    num_workers = 1
    prefetch_factor = 3
    val_split = 0.1
    shuffle = True
    # One weight of 50 per class. Placed on `device` rather than a literal
    # "cuda" so that defining this class does not crash on CPU-only machines.
    pos_weight = (torch.ones(num_classes) * 50).to(device)
    criterion = InfoNCE()

In [2]:
from torch.utils.data import DataLoader

large_directory = directory

# Each split is built from a `data` folder and a `genre_labels` folder.
train_dataset = StreamViewDataset(
    f"E:\\SongsDataset\\{large_directory}\\train_set\\data",
    f"E:\\SongsDataset\\{large_directory}\\train_set\\genre_labels",
)
test_dataset = StreamViewDataset(
    f"E:\\SongsDataset\\{large_directory}\\test_set\\data",
    f"E:\\SongsDataset\\{large_directory}\\test_set\\genre_labels",
)

# Both loaders share the same settings, all taken from Config. `shuffle` now
# honours Config.shuffle instead of a hard-coded True (same value today).
_loader_options = dict(
    batch_size=Config.batch_size,
    shuffle=Config.shuffle,
    num_workers=Config.num_workers,
)
# DataLoader only accepts prefetch_factor when worker processes are used, so
# it is omitted when num_workers is 0 to keep that configuration usable.
if Config.num_workers > 0:
    _loader_options["prefetch_factor"] = Config.prefetch_factor

train_dataloader = DataLoader(train_dataset, **_loader_options)

test_dataloader = DataLoader(test_dataset, **_loader_options)

In [3]:
from models.AudioTransformer import AudioTransformer
from utils import misc

# Rebinds the notebook-wide `model`; whichever model cell ran last is the one
# the training cells receive.
model = AudioTransformer(
    latent_space=32,
    input_dim=128,
    d_model=256,
    dim_feedforward=512,
    length=256,
    num_heads=8,
    encoder_layers=16,
    decoder_layers=16,
    dropout=0.1,
    use_alibi=True,
    custom_slopes=True,
)

# Report the model size as returned by misc.model_size.
print(f"{misc.model_size(model)} Parameters")

25462368 Parameters


In [6]:
from models.AudioViT import AudioViT
from utils import misc

# Rebinds the notebook-wide `model` to an AudioViT instance.
model = AudioViT(
    latent_space=32,
    input_dim=128,
    d_model=256,
    length=256,
    num_heads=8,
    encoder_layers=8,
    decoder_layers=8,
    dropout=0.1,
    use_alibi=True,
    use_rope=True,
)

# Report the model size as returned by misc.model_size.
print(f"{misc.model_size(model)} Parameters")

23554720 Parameters


In [None]:
from models.AudioPreCNNTransformer import AudioPreCNNTransformer
from utils import misc

# Rebinds the notebook-wide `model`. Note the sequence length of 1024 here,
# whereas the executed model cells use length=256.
model = AudioPreCNNTransformer(
    latent_space=512,
    input_dim=128,
    length=1024,
    num_heads=8,
    transformer_layers=8,
    d_model=256,
    dropout=0.1,
)
# Report the model size as returned by misc.model_size.
print(f"{misc.model_size(model)} Parameters")

In [None]:
from models.AudioTransformerWeaved import AudioTransformerWeaved
from utils import misc

# Rebinds the notebook-wide `model`. Note the sequence length of 1024 here,
# whereas the executed model cells use length=256.
model = AudioTransformerWeaved(
    latent_space=128,
    input_dim=128,
    d_model=256,
    dim_feedforward=512,
    length=1024,
    num_heads=8,
    encoder_layers=8,
    decoder_layers=8,
    dropout=0.1,
    use_alibi=True,
)

# Report the model size as returned by misc.model_size.
print(f"{misc.model_size(model)} Parameters")

In [9]:
from models.AudioViTEncoder import AudioViTEncoder
from utils import misc

# Rebinds the notebook-wide `model` to an AudioViTEncoder instance.
# NOTE(review): the saved run of the training cell that follows failed with a
# matmul shape mismatch (64x128 vs 256x128); it appears this model was the one
# bound at that point -- check its output width against the training code.
model = AudioViTEncoder(
    latent_space=32,
    input_dim=128,
    d_model=256,
    length=256,
    num_heads=8,
    encoder_layers=8,
    decoder_layers=8,
    dropout=0.1,
    use_alibi=True,
    use_rope=True,
)

# Report the model size as returned by misc.model_size.
print(f"{misc.model_size(model)} Parameters")

13792256 Parameters


In [10]:
from training.contrastive_training import train_contrastive
# Contrastive training on the currently bound `model`. The test loader is
# passed before the train loader -- presumably matching train_contrastive's
# parameter order; confirm in training.contrastive_training.
# The saved run of this cell stopped with a RuntimeError (matmul shape mismatch).
train_contrastive(
    model,
    test_dataloader,
    train_dataloader,
    Config,
    show_graph=False,
)

  0%|          | 0/155 [00:03<?, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x128 and 256x128)

In [3]:
from models.AudioTransformerEmbeddings import AudioTransformerEmbeddings
from utils import misc

# Rebinds the notebook-wide `model` to an AudioTransformerEmbeddings instance.
model = AudioTransformerEmbeddings(
    latent_space=128,
    input_dim=128,
    d_model=256,
    dim_feedforward=512,
    length=256,
    num_heads=8,
    encoder_layers=8,
    dropout=0.1,
    use_alibi=True,
    use_rope=True,
    custom_slope=1,
)

# Report the model size as returned by misc.model_size.
print(f"{misc.model_size(model)} Parameters")

24468914 Parameters


In [4]:
from training.contrastive_training import train_contrastive
# Contrastive training on the currently bound `model`. The test loader is
# passed before the train loader -- presumably matching train_contrastive's
# parameter order; confirm in training.contrastive_training.
# In the saved output the epoch train loss falls from 4.0232 to 2.5832 over
# the ten epochs.
train_contrastive(
    model,
    test_dataloader,
    train_dataloader,
    Config,
    show_graph=False,
)

  0%|          | 0/155 [00:03<?, ?it/s]

Batch Loss [1/155]: 4.1884

Batch Loss [2/155]: 4.1526

Batch Loss [3/155]: 4.1499

Batch Loss [4/155]: 4.1457

Batch Loss [5/155]: 4.1451

Batch Loss [6/155]: 4.1244

Batch Loss [7/155]: 4.0774

Batch Loss [8/155]: 3.9222

Batch Loss [9/155]: 4.3615

Batch Loss [10/155]: 4.0819

Batch Loss [11/155]: 4.1155

Batch Loss [12/155]: 4.1398

Batch Loss [13/155]: 4.1480

Batch Loss [14/155]: 4.1465

Batch Loss [15/155]: 4.1500

Batch Loss [16/155]: 4.1527

Batch Loss [17/155]: 4.1539

Batch Loss [18/155]: 4.1544

Batch Loss [19/155]: 4.1559

Batch Loss [20/155]: 4.1561

Batch Loss [21/155]: 4.1558

Batch Loss [22/155]: 4.1562

Batch Loss [23/155]: 4.1552

Batch Loss [24/155]: 4.1561

Batch Loss [25/155]: 4.1577

Batch Loss [26/155]: 4.1569

Batch Loss [27/155]: 4.1568

Batch Loss [28/155]: 4.1564

Batch Loss [29/155]: 4.1559

Batch Loss [30/155]: 4.1561

Batch Loss [31/155]: 4.1563

Batch Loss [32/155]: 4.1561

Batch Loss [33/155]: 4.1565

Batch Loss [34/155]: 4.1563

Batch Loss [35/155]: 4.

  0%|          | 0/18 [00:00<?, ?it/s]

[Epoch 0] Train:  4.0232
Test:  Contrastive Loss: 3.7577


  0%|          | 0/155 [00:03<?, ?it/s]

Batch Loss [1/155]: 3.8405

Batch Loss [2/155]: 3.7926

Batch Loss [3/155]: 3.7873

Batch Loss [4/155]: 3.8034

Batch Loss [5/155]: 3.8721

Batch Loss [6/155]: 4.0102

Batch Loss [7/155]: 4.0158

Batch Loss [8/155]: 3.7368

Batch Loss [9/155]: 3.8911

Batch Loss [10/155]: 3.8592

Batch Loss [11/155]: 3.7474

Batch Loss [12/155]: 3.5344

Batch Loss [13/155]: 3.5241

Batch Loss [14/155]: 3.6664

Batch Loss [15/155]: 3.3324

Batch Loss [16/155]: 3.7630

Batch Loss [17/155]: 3.6764

Batch Loss [18/155]: 3.5998

Batch Loss [19/155]: 3.8721

Batch Loss [20/155]: 3.8764

Batch Loss [21/155]: 3.8514

Batch Loss [22/155]: 3.6784

Batch Loss [23/155]: 3.7626

Batch Loss [24/155]: 3.5968

Batch Loss [25/155]: 3.7355

Batch Loss [26/155]: 3.9297

Batch Loss [27/155]: 3.5766

Batch Loss [28/155]: 3.8647

Batch Loss [29/155]: 3.5471

Batch Loss [30/155]: 3.5490

Batch Loss [31/155]: 3.6468

Batch Loss [32/155]: 3.5345

Batch Loss [33/155]: 3.7559

Batch Loss [34/155]: 3.6981

Batch Loss [35/155]: 4.

  0%|          | 0/18 [00:00<?, ?it/s]

[Epoch 1] Train:  3.4836
Test:  Contrastive Loss: 3.1360


  0%|          | 0/155 [00:03<?, ?it/s]

Batch Loss [1/155]: 3.2981

Batch Loss [2/155]: 3.5625

Batch Loss [3/155]: 3.6153

Batch Loss [4/155]: 3.3741

Batch Loss [5/155]: 3.1808

Batch Loss [6/155]: 3.4592

Batch Loss [7/155]: 2.8864

Batch Loss [8/155]: 3.4723

Batch Loss [9/155]: 3.4055

Batch Loss [10/155]: 3.0117

Batch Loss [11/155]: 3.3070

Batch Loss [12/155]: 3.2370

Batch Loss [13/155]: 3.1811

Batch Loss [14/155]: 3.1600

Batch Loss [15/155]: 3.1813

Batch Loss [16/155]: 3.0550

Batch Loss [17/155]: 3.2667

Batch Loss [18/155]: 2.9552

Batch Loss [19/155]: 3.3671

Batch Loss [20/155]: 3.1685

Batch Loss [21/155]: 3.5511

Batch Loss [22/155]: 3.0719

Batch Loss [23/155]: 3.1001

Batch Loss [24/155]: 2.8897

Batch Loss [25/155]: 2.9194

Batch Loss [26/155]: 3.0195

Batch Loss [27/155]: 3.5323

Batch Loss [28/155]: 3.1495

Batch Loss [29/155]: 2.9395

Batch Loss [30/155]: 3.2017

Batch Loss [31/155]: 2.7772

Batch Loss [32/155]: 3.0174

Batch Loss [33/155]: 3.1835

Batch Loss [34/155]: 3.0413

Batch Loss [35/155]: 3.

  0%|          | 0/18 [00:00<?, ?it/s]

[Epoch 2] Train:  3.1956
Test:  Contrastive Loss: 3.1084


  0%|          | 0/155 [00:03<?, ?it/s]

Batch Loss [1/155]: 3.0421

Batch Loss [2/155]: 2.9511

Batch Loss [3/155]: 3.1770

Batch Loss [4/155]: 3.3029

Batch Loss [5/155]: 2.8730

Batch Loss [6/155]: 3.1153

Batch Loss [7/155]: 3.6709

Batch Loss [8/155]: 3.0855

Batch Loss [9/155]: 2.7987

Batch Loss [10/155]: 2.7534

Batch Loss [11/155]: 3.2721

Batch Loss [12/155]: 3.2517

Batch Loss [13/155]: 3.3331

Batch Loss [14/155]: 2.9804

Batch Loss [15/155]: 3.0275

Batch Loss [16/155]: 3.4674

Batch Loss [17/155]: 3.2641

Batch Loss [18/155]: 3.0097

Batch Loss [19/155]: 3.0534

Batch Loss [20/155]: 3.4083

Batch Loss [21/155]: 3.3051

Batch Loss [22/155]: 3.3070

Batch Loss [23/155]: 2.8847

Batch Loss [24/155]: 3.3194

Batch Loss [25/155]: 2.9499

Batch Loss [26/155]: 3.3059

Batch Loss [27/155]: 3.2021

Batch Loss [28/155]: 3.2367

Batch Loss [29/155]: 3.3345

Batch Loss [30/155]: 2.8073

Batch Loss [31/155]: 2.6892

Batch Loss [32/155]: 3.1819

Batch Loss [33/155]: 2.7545

Batch Loss [34/155]: 3.2491

Batch Loss [35/155]: 3.

  0%|          | 0/18 [00:00<?, ?it/s]

[Epoch 3] Train:  3.1101
Test:  Contrastive Loss: 3.0885


  0%|          | 0/155 [00:03<?, ?it/s]

Batch Loss [1/155]: 2.6913

Batch Loss [2/155]: 3.1025

Batch Loss [3/155]: 2.8624

Batch Loss [4/155]: 2.9823

Batch Loss [5/155]: 3.2944

Batch Loss [6/155]: 2.8856

Batch Loss [7/155]: 3.0882

Batch Loss [8/155]: 3.4618

Batch Loss [9/155]: 2.8896

Batch Loss [10/155]: 2.6142

Batch Loss [11/155]: 2.9882

Batch Loss [12/155]: 2.7989

Batch Loss [13/155]: 3.2420

Batch Loss [14/155]: 3.1664

Batch Loss [15/155]: 3.1133

Batch Loss [16/155]: 2.8073

Batch Loss [17/155]: 3.4149

Batch Loss [18/155]: 2.9738

Batch Loss [19/155]: 3.3024

Batch Loss [20/155]: 3.4123

Batch Loss [21/155]: 3.2222

Batch Loss [22/155]: 3.0960

Batch Loss [23/155]: 2.9716

Batch Loss [24/155]: 3.0357

Batch Loss [25/155]: 3.0447

Batch Loss [26/155]: 3.0406

Batch Loss [27/155]: 3.5186

Batch Loss [28/155]: 3.2198

Batch Loss [29/155]: 3.2987

Batch Loss [30/155]: 3.0716

Batch Loss [31/155]: 2.9497

Batch Loss [32/155]: 2.7856

Batch Loss [33/155]: 2.8548

Batch Loss [34/155]: 3.2370

Batch Loss [35/155]: 3.

  0%|          | 0/18 [00:00<?, ?it/s]

[Epoch 4] Train:  2.9896
Test:  Contrastive Loss: 2.9599


  0%|          | 0/155 [00:03<?, ?it/s]

Batch Loss [1/155]: 2.8469

Batch Loss [2/155]: 2.7223

Batch Loss [3/155]: 2.8344

Batch Loss [4/155]: 3.1824

Batch Loss [5/155]: 3.3654

Batch Loss [6/155]: 3.0972

Batch Loss [7/155]: 2.9265

Batch Loss [8/155]: 3.0113

Batch Loss [9/155]: 3.0162

Batch Loss [10/155]: 3.0195

Batch Loss [11/155]: 2.9068

Batch Loss [12/155]: 2.7497

Batch Loss [13/155]: 3.1078

Batch Loss [14/155]: 2.9592

Batch Loss [15/155]: 3.0224

Batch Loss [16/155]: 3.0663

Batch Loss [17/155]: 2.7032

Batch Loss [18/155]: 3.1835

Batch Loss [19/155]: 2.6513

Batch Loss [20/155]: 3.2090

Batch Loss [21/155]: 2.6095

Batch Loss [22/155]: 3.0229

Batch Loss [23/155]: 3.2113

Batch Loss [24/155]: 2.8797

Batch Loss [25/155]: 2.9817

Batch Loss [26/155]: 2.9433

Batch Loss [27/155]: 3.3217

Batch Loss [28/155]: 2.9197

Batch Loss [29/155]: 2.9864

Batch Loss [30/155]: 2.7503

Batch Loss [31/155]: 2.6702

Batch Loss [32/155]: 2.6200

Batch Loss [33/155]: 2.7730

Batch Loss [34/155]: 3.4008

Batch Loss [35/155]: 2.

  0%|          | 0/18 [00:00<?, ?it/s]

[Epoch 5] Train:  2.9073
Test:  Contrastive Loss: 2.8778


  0%|          | 0/155 [00:03<?, ?it/s]

Batch Loss [1/155]: 3.3330

Batch Loss [2/155]: 3.2400

Batch Loss [3/155]: 2.8894

Batch Loss [4/155]: 2.6360

Batch Loss [5/155]: 2.8867

Batch Loss [6/155]: 2.2880

Batch Loss [7/155]: 2.8384

Batch Loss [8/155]: 3.0213

Batch Loss [9/155]: 2.6251

Batch Loss [10/155]: 3.1536

Batch Loss [11/155]: 2.8734

Batch Loss [12/155]: 2.9391

Batch Loss [13/155]: 3.1916

Batch Loss [14/155]: 2.7642

Batch Loss [15/155]: 2.6112

Batch Loss [16/155]: 2.6110

Batch Loss [17/155]: 2.9648

Batch Loss [18/155]: 2.5096

Batch Loss [19/155]: 2.9813

Batch Loss [20/155]: 2.7457

Batch Loss [21/155]: 2.8642

Batch Loss [22/155]: 3.2810

Batch Loss [23/155]: 2.6882

Batch Loss [24/155]: 2.9640

Batch Loss [25/155]: 2.4064

Batch Loss [26/155]: 2.6226

Batch Loss [27/155]: 3.0960

Batch Loss [28/155]: 3.1989

Batch Loss [29/155]: 2.8461

Batch Loss [30/155]: 2.6999

Batch Loss [31/155]: 2.9911

Batch Loss [32/155]: 2.3826

Batch Loss [33/155]: 2.9716

Batch Loss [34/155]: 2.9422

Batch Loss [35/155]: 2.

  0%|          | 0/18 [00:00<?, ?it/s]

[Epoch 6] Train:  2.8212
Test:  Contrastive Loss: 2.7419


  0%|          | 0/155 [00:03<?, ?it/s]

Batch Loss [1/155]: 2.7353

Batch Loss [2/155]: 2.7859

Batch Loss [3/155]: 2.6750

Batch Loss [4/155]: 3.0084

Batch Loss [5/155]: 2.6063

Batch Loss [6/155]: 2.7833

Batch Loss [7/155]: 3.4698

Batch Loss [8/155]: 2.7249

Batch Loss [9/155]: 2.7267

Batch Loss [10/155]: 2.7425

Batch Loss [11/155]: 2.3420

Batch Loss [12/155]: 2.9534

Batch Loss [13/155]: 2.9140

Batch Loss [14/155]: 2.6093

Batch Loss [15/155]: 2.8584

Batch Loss [16/155]: 2.8271

Batch Loss [17/155]: 2.4673

Batch Loss [18/155]: 2.3253

Batch Loss [19/155]: 2.8019

Batch Loss [20/155]: 3.1746

Batch Loss [21/155]: 2.5976

Batch Loss [22/155]: 2.7936

Batch Loss [23/155]: 3.0719

Batch Loss [24/155]: 2.6509

Batch Loss [25/155]: 3.1183

Batch Loss [26/155]: 2.6229

Batch Loss [27/155]: 2.5749

Batch Loss [28/155]: 3.0110

Batch Loss [29/155]: 2.3895

Batch Loss [30/155]: 2.7478

Batch Loss [31/155]: 2.4448

Batch Loss [32/155]: 2.7105

Batch Loss [33/155]: 2.6872

Batch Loss [34/155]: 2.5083

Batch Loss [35/155]: 2.

  0%|          | 0/18 [00:00<?, ?it/s]

[Epoch 7] Train:  2.7265
Test:  Contrastive Loss: 2.7541


  0%|          | 0/155 [00:03<?, ?it/s]

Batch Loss [1/155]: 2.1357

Batch Loss [2/155]: 2.4402

Batch Loss [3/155]: 2.7098

Batch Loss [4/155]: 3.0688

Batch Loss [5/155]: 2.6289

Batch Loss [6/155]: 2.3479

Batch Loss [7/155]: 2.5101

Batch Loss [8/155]: 2.5024

Batch Loss [9/155]: 2.3499

Batch Loss [10/155]: 2.6873

Batch Loss [11/155]: 2.7202

Batch Loss [12/155]: 2.4679

Batch Loss [13/155]: 2.7765

Batch Loss [14/155]: 2.7184

Batch Loss [15/155]: 2.5186

Batch Loss [16/155]: 2.5707

Batch Loss [17/155]: 2.7563

Batch Loss [18/155]: 2.4547

Batch Loss [19/155]: 2.5311

Batch Loss [20/155]: 2.1665

Batch Loss [21/155]: 2.4842

Batch Loss [22/155]: 2.7562

Batch Loss [23/155]: 3.2144

Batch Loss [24/155]: 2.3407

Batch Loss [25/155]: 2.8353

Batch Loss [26/155]: 2.9414

Batch Loss [27/155]: 2.8967

Batch Loss [28/155]: 2.5629

Batch Loss [29/155]: 2.6213

Batch Loss [30/155]: 3.1570

Batch Loss [31/155]: 3.1657

Batch Loss [32/155]: 2.8460

Batch Loss [33/155]: 2.7603

Batch Loss [34/155]: 2.5515

Batch Loss [35/155]: 2.

  0%|          | 0/18 [00:00<?, ?it/s]

[Epoch 8] Train:  2.6642
Test:  Contrastive Loss: 2.5575


  0%|          | 0/155 [00:03<?, ?it/s]

Batch Loss [1/155]: 2.3081

Batch Loss [2/155]: 2.4887

Batch Loss [3/155]: 2.4743

Batch Loss [4/155]: 2.4905

Batch Loss [5/155]: 2.8146

Batch Loss [6/155]: 2.7982

Batch Loss [7/155]: 2.9638

Batch Loss [8/155]: 2.4079

Batch Loss [9/155]: 2.5997

Batch Loss [10/155]: 2.7343

Batch Loss [11/155]: 2.5180

Batch Loss [12/155]: 2.6350

Batch Loss [13/155]: 2.9264

Batch Loss [14/155]: 2.8807

Batch Loss [15/155]: 2.4665

Batch Loss [16/155]: 2.7810

Batch Loss [17/155]: 2.1995

Batch Loss [18/155]: 2.5400

Batch Loss [19/155]: 2.9568

Batch Loss [20/155]: 2.1293

Batch Loss [21/155]: 2.4708

Batch Loss [22/155]: 2.2857

Batch Loss [23/155]: 2.6717

Batch Loss [24/155]: 2.2578

Batch Loss [25/155]: 2.4573

Batch Loss [26/155]: 2.4425

Batch Loss [27/155]: 2.5966

Batch Loss [28/155]: 2.8571

Batch Loss [29/155]: 2.5796

Batch Loss [30/155]: 2.2298

Batch Loss [31/155]: 2.4132

Batch Loss [32/155]: 2.4424

Batch Loss [33/155]: 2.6939

Batch Loss [34/155]: 2.5359

Batch Loss [35/155]: 3.

  0%|          | 0/18 [00:00<?, ?it/s]

[Epoch 9] Train:  2.5832
Test:  Contrastive Loss: 2.5203


In [7]:
from training.autoencoding_training import train_autoencode
# Autoencoder training on the currently bound `model`. The test loader is
# passed before the train loader -- presumably matching train_autoencode's
# parameter order; confirm in training.autoencoding_training.
# The saved output records a single epoch for this run.
train_autoencode(
    model,
    test_dataloader,
    train_dataloader,
    Config,
    show_graph=False,
)

  0%|          | 0/98 [00:00<?, ?it/s]

Mid-Epoch Loss [1/98]: 1.855

Mid-Epoch Loss [2/98]: 1.814

Mid-Epoch Loss [3/98]: 1.78

Mid-Epoch Loss [4/98]: 1.751

Mid-Epoch Loss [5/98]: 1.729

Mid-Epoch Loss [6/98]: 1.707

Mid-Epoch Loss [7/98]: 1.689

Mid-Epoch Loss [8/98]: 1.675

Mid-Epoch Loss [9/98]: 1.66

Mid-Epoch Loss [10/98]: 1.648

Mid-Epoch Loss [11/98]: 1.635

Mid-Epoch Loss [12/98]: 1.623

Mid-Epoch Loss [13/98]: 1.611

Mid-Epoch Loss [14/98]: 1.6

Mid-Epoch Loss [15/98]: 1.589

Mid-Epoch Loss [16/98]: 1.577

Mid-Epoch Loss [17/98]: 1.566

Mid-Epoch Loss [18/98]: 1.556

Mid-Epoch Loss [19/98]: 1.546

Mid-Epoch Loss [20/98]: 1.536

Mid-Epoch Loss [21/98]: 1.527

Mid-Epoch Loss [22/98]: 1.517

Mid-Epoch Loss [23/98]: 1.507

Mid-Epoch Loss [24/98]: 1.499

Mid-Epoch Loss [25/98]: 1.49

Mid-Epoch Loss [26/98]: 1.481

Mid-Epoch Loss [27/98]: 1.473

Mid-Epoch Loss [28/98]: 1.465

Mid-Epoch Loss [29/98]: 1.457

Mid-Epoch Loss [30/98]: 1.449

Mid-Epoch Loss [31/98]: 1.442

Mid-Epoch Loss [32/98]: 1.434

Mid-Epoch Loss [33/98]

  0%|          | 0/11 [00:00<?, ?it/s]

[Epoch 0] Train:  1.1093
Test:  Cos: 0.4985, MSE: 0.7842


In [4]:
from training.autoencoding_training import train_autoencode
# Autoencoder training on the currently bound `model`. The test loader is
# passed before the train loader -- presumably matching train_autoencode's
# parameter order; confirm in training.autoencoding_training.
# The saved run was stopped by a KeyboardInterrupt during its second epoch.
train_autoencode(
    model,
    test_dataloader,
    train_dataloader,
    Config,
    show_graph=False,
)

  0%|          | 0/98 [00:00<?, ?it/s]

Mid-Epoch Loss [1/98]: 0.7627

Mid-Epoch Loss [2/98]: 0.7602

Mid-Epoch Loss [3/98]: 0.6974

Mid-Epoch Loss [4/98]: 0.6474

Mid-Epoch Loss [5/98]: 0.6145

Mid-Epoch Loss [6/98]: 0.5882

Mid-Epoch Loss [7/98]: 0.5669

Mid-Epoch Loss [8/98]: 0.551

Mid-Epoch Loss [9/98]: 0.535

Mid-Epoch Loss [10/98]: 0.5224

Mid-Epoch Loss [11/98]: 0.5126

Mid-Epoch Loss [12/98]: 0.5031

Mid-Epoch Loss [13/98]: 0.4956

Mid-Epoch Loss [14/98]: 0.4897

Mid-Epoch Loss [15/98]: 0.4836

Mid-Epoch Loss [16/98]: 0.4773

Mid-Epoch Loss [17/98]: 0.4733

Mid-Epoch Loss [18/98]: 0.4682

Mid-Epoch Loss [19/98]: 0.4636

Mid-Epoch Loss [20/98]: 0.4598

Mid-Epoch Loss [21/98]: 0.4568

Mid-Epoch Loss [22/98]: 0.4535

Mid-Epoch Loss [23/98]: 0.4503

Mid-Epoch Loss [24/98]: 0.4477

Mid-Epoch Loss [25/98]: 0.4455

Mid-Epoch Loss [26/98]: 0.4429

Mid-Epoch Loss [27/98]: 0.4401

Mid-Epoch Loss [28/98]: 0.4383

Mid-Epoch Loss [29/98]: 0.4359

Mid-Epoch Loss [30/98]: 0.4338

Mid-Epoch Loss [31/98]: 0.4319

Mid-Epoch Loss [32/

  0%|          | 0/11 [00:00<?, ?it/s]

[Epoch 0] Train:  0.3797
Test:  Cos: 0.2006, MSE: 0.3512


  0%|          | 0/98 [00:00<?, ?it/s]

Mid-Epoch Loss [1/98]: 0.3538

Mid-Epoch Loss [2/98]: 0.3486

Mid-Epoch Loss [3/98]: 0.35

Mid-Epoch Loss [4/98]: 0.351

Mid-Epoch Loss [5/98]: 0.3511

Mid-Epoch Loss [6/98]: 0.3497

Mid-Epoch Loss [7/98]: 0.3504

Mid-Epoch Loss [8/98]: 0.348

Mid-Epoch Loss [9/98]: 0.3482

Mid-Epoch Loss [10/98]: 0.3481

Mid-Epoch Loss [11/98]: 0.3483

Mid-Epoch Loss [12/98]: 0.3478

Mid-Epoch Loss [13/98]: 0.3489

Mid-Epoch Loss [14/98]: 0.3499

Mid-Epoch Loss [15/98]: 0.3486

Mid-Epoch Loss [16/98]: 0.3487

Mid-Epoch Loss [17/98]: 0.3483

Mid-Epoch Loss [18/98]: 0.3492

Mid-Epoch Loss [19/98]: 0.3497

Mid-Epoch Loss [20/98]: 0.3502

Mid-Epoch Loss [21/98]: 0.3498

Mid-Epoch Loss [22/98]: 0.3493

Mid-Epoch Loss [23/98]: 0.3498

Mid-Epoch Loss [24/98]: 0.3501

Mid-Epoch Loss [25/98]: 0.3501

Mid-Epoch Loss [26/98]: 0.3501

Mid-Epoch Loss [27/98]: 0.3506

Mid-Epoch Loss [28/98]: 0.3512

Mid-Epoch Loss [29/98]: 0.3512

Mid-Epoch Loss [30/98]: 0.3513

Mid-Epoch Loss [31/98]: 0.3517

Mid-Epoch Loss [32/98

KeyboardInterrupt: 

In [None]:
import librosa
import IPython
import numpy as np
import torch
import os

from datasets import tqdm
from training.inference import load_and_parse_audio

def _inference_device(model):
    """Return the device inputs should be moved to before calling ``model``."""
    if isinstance(model, torch.nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            return first_param.device
    # Anything that is not a parameterised nn.Module keeps the original
    # behaviour of feeding CUDA tensors.
    return torch.device("cuda")


def test(model, songs_dir="E:\\SongsDataset\\songs\\", first_song=100, last_song=110,
         chunk_size=1024, preview_frames=512, sample_rate=44100):
    """Reconstruct a few songs with ``model`` and show/play input vs. output.

    For each selected entry of ``songs_dir`` the audio is loaded and chunked,
    every chunk is standardised to zero mean / unit variance, passed through
    the model, and the first ``preview_frames`` columns of the concatenated
    input and reconstruction are plotted with ``graph`` and turned back into
    audio with librosa (mel -> STFT magnitude -> Griffin-Lim) for playback.

    Args:
        model: Callable returning ``(reconstruction, latent)`` for a batch of
            chunks. If it is a ``torch.nn.Module`` it is switched to eval mode
            for the duration of the call and restored afterwards.
        songs_dir: Directory whose entries are the songs to try.
        first_song, last_song: Slice bounds into the sorted directory listing.
        chunk_size: Forwarded to ``load_and_parse_audio``.
            NOTE(review): several model cells in this notebook are built with
            length=256 -- confirm this matches the model being evaluated.
        preview_frames: Number of leading columns kept for plotting/playback.
        sample_rate: Playback rate given to ``IPython.display.Audio``.
            NOTE(review): the librosa inversion calls use their default
            sampling rate -- confirm it agrees with this value.

    Returns:
        None. Results are displayed in the notebook.
    """
    # os.listdir returns entries in arbitrary order; sorting makes the
    # [first_song:last_song] selection reproducible.
    all_folders = sorted(os.listdir(songs_dir))
    device = _inference_device(model)

    # torch.no_grad() does not disable dropout, so evaluate in eval mode and
    # put the model back into its previous mode when done.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    try:
        with torch.no_grad():
            for each_song in tqdm(all_folders[first_song:last_song]):
                song_path = os.path.join(songs_dir, each_song)
                chunks = load_and_parse_audio(song_path, convert=True, chunk_size=chunk_size).to(device)
                if len(chunks) == 0:
                    # Nothing to reconstruct; torch.stack would raise on an empty list.
                    continue
                batch = torch.stack(list(chunks))

                # Per-chunk standardisation over the two trailing dimensions.
                mean = batch.mean(dim=[1, 2], keepdim=True)
                std = batch.std(dim=[1, 2], keepdim=True)
                batch = (batch - mean) / (std + 1e-6)

                reconstructed, _ = model(batch)

                # Join the chunks side by side along the last axis
                # (presumably time, with mel bins on axis 0 -- confirm).
                input_tensor = np.concatenate(batch.cpu().detach().numpy(), axis=1)
                reconstructed = np.concatenate(reconstructed.cpu().detach().numpy(), axis=1)

                input_tensor = input_tensor[:, :preview_frames]
                reconstructed = reconstructed[:, :preview_frames]

                # `graph` is not imported in this cell; presumably it comes from
                # an earlier star-import -- confirm it is in scope.
                graph(input_tensor, reconstructed)

                # NOTE(review): the arrays are standardised (can be negative),
                # while mel_to_stft is documented for mel power spectrograms --
                # confirm whether the normalisation should be undone first.
                S_recon = librosa.feature.inverse.mel_to_stft(reconstructed)
                Y_recon = librosa.griffinlim(S_recon)

                S_orig = librosa.feature.inverse.mel_to_stft(input_tensor)
                Y_orig = librosa.griffinlim(S_orig)

                IPython.display.display(IPython.display.Audio(Y_orig, rate=sample_rate))
                IPython.display.display(IPython.display.Audio(Y_recon, rate=sample_rate))
    finally:
        if is_module:
            model.train(was_training)

In [None]:
# Run the reconstruction demo on whichever model is currently bound.
test(model=model)