In [2]:
%load_ext autoreload
%autoreload 2

## Dataset & Dataloader

In [3]:
import os
os.chdir("../")

from omegaconf import OmegaConf

from neuralfp.data.datasets import MusicSegmentDataset, collate_data
from neuralfp.utils.common import load_dataset

In [4]:
config = "configs/train.yaml"
config = OmegaConf.load(config)

dataset = MusicSegmentDataset(config["dataset"]["train"])

Loading IR: 100%|██████████| 345/345 [00:03<00:00, 110.66it/s]


In [5]:
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset=dataset,
    collate_fn=collate_data,
    shuffle=False,
    **config["dataset"]["loaders"],
)

In [6]:
# import tqdm

# for batch in tqdm.tqdm(dataloader):
#     features, targets = batch
#     print(features.shape, targets.shape)

## Model

In [7]:
from neuralfp.model.neuralfp import NeuralAudioFingerprinter

In [8]:
config = "configs/train.yaml"
config = OmegaConf.load(config)

model = NeuralAudioFingerprinter(**config["model"]["neuralfp"])

In [9]:
import tqdm
import torch

for batch in tqdm.tqdm(dataloader):
    features, targets = batch
    print("features", features.shape)
    xs = torch.stack([features, targets], dim=0)
    print("xs", xs.shape)

    xs = torch.flatten(xs, 0, 1)
    out = model(xs)
    print("out", out.shape)




features torch.Size([116, 256, 32])
xs torch.Size([2, 116, 256, 32])


  0%|          | 1/5000 [00:26<36:36:52, 26.37s/it]

out torch.Size([232, 128])
features torch.Size([116, 256, 32])
xs torch.Size([2, 116, 256, 32])


  0%|          | 1/5000 [00:29<40:56:09, 29.48s/it]


KeyboardInterrupt: 

## Loss function

In [10]:
from neuralfp.criterion.contrastive_loss import NTxentLoss

criterion = NTxentLoss()

In [11]:
import tqdm
import torch

for batch in tqdm.tqdm(dataloader):
    features, targets = batch
    xs = torch.stack([features, targets], dim=0)
    xs = torch.flatten(xs, 0, 1)
    out = model(xs)
    n_anchors = out.shape[0] // 2
    print("n_anchors", n_anchors)
    loss = criterion(
        out[:n_anchors, :], out[n_anchors:, :], n_anchors
    )
    print("loss", loss)

  0%|          | 0/5000 [00:00<?, ?it/s]



n_anchors 116


  0%|          | 1/5000 [00:10<14:44:08, 10.61s/it]

loss tensor(10.8773, grad_fn=<AddBackward0>)


  0%|          | 2/5000 [00:13<8:40:40,  6.25s/it] 

n_anchors 116
loss tensor(10.8913, grad_fn=<AddBackward0>)


  0%|          | 3/5000 [00:17<6:56:30,  5.00s/it]

n_anchors 116
loss tensor(10.8739, grad_fn=<AddBackward0>)


  0%|          | 4/5000 [00:20<5:57:24,  4.29s/it]

n_anchors 116
loss tensor(10.8706, grad_fn=<AddBackward0>)


  0%|          | 5/5000 [00:23<5:28:33,  3.95s/it]

n_anchors 117
loss tensor(10.8902, grad_fn=<AddBackward0>)
n_anchors 117


  0%|          | 6/5000 [00:27<5:14:21,  3.78s/it]

loss tensor(10.8964, grad_fn=<AddBackward0>)


  0%|          | 7/5000 [00:30<5:05:41,  3.67s/it]

n_anchors 116
loss tensor(10.8862, grad_fn=<AddBackward0>)


  0%|          | 8/5000 [00:33<4:53:36,  3.53s/it]

n_anchors 117
loss tensor(10.8986, grad_fn=<AddBackward0>)


  0%|          | 9/5000 [00:37<4:51:06,  3.50s/it]

n_anchors 116
loss tensor(10.8768, grad_fn=<AddBackward0>)
n_anchors 117


  0%|          | 10/5000 [00:40<4:47:46,  3.46s/it]

loss tensor(10.8910, grad_fn=<AddBackward0>)


  0%|          | 11/5000 [00:44<4:55:56,  3.56s/it]

n_anchors 116
loss tensor(10.8888, grad_fn=<AddBackward0>)
n_anchors 117


  0%|          | 12/5000 [00:48<4:55:30,  3.55s/it]

loss tensor(10.8949, grad_fn=<AddBackward0>)
n_anchors 116


  0%|          | 13/5000 [00:51<4:55:32,  3.56s/it]

loss tensor(10.8878, grad_fn=<AddBackward0>)


  0%|          | 13/5000 [00:53<5:39:51,  4.09s/it]


KeyboardInterrupt: 

## Load checkpoint

In [5]:
import torch

checkpoint = torch.load("/home/huynd/Code/AI-beat-maker/train/artifacts/neuralfp_epoch88.pt", map_location="cpu")

In [6]:
checkpoint["state_dict"]

{'model': OrderedDict([('encoder.convs.0.conv1.weight',
               tensor([[[[-5.5425e-02, -1.3337e-01,  3.2048e-01]]],
               
               
                       [[[-5.1148e-01,  3.4289e-01,  4.2535e-01]]],
               
               
                       [[[ 1.6304e-01,  7.4732e-01,  6.5381e-02]]],
               
               
                       [[[ 3.0992e-01,  2.7043e-02, -2.7534e-01]]],
               
               
                       [[[ 2.6602e-01, -4.0101e-01, -2.6661e-01]]],
               
               
                       [[[ 4.1530e-01,  3.4362e-01, -1.6174e-01]]],
               
               
                       [[[-2.9111e-01, -3.6940e-01,  1.8244e-01]]],
               
               
                       [[[ 4.7709e-01, -2.7537e-02,  3.5096e-01]]],
               
               
                       [[[ 4.0124e-01, -3.2955e-01,  9.9622e-02]]],
               
               
                       [[[-8.4981e-02,  7.02