In [None]:
# Reload edited modules (e.g. the local neuralfp package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

## Dataset & Dataloader

In [None]:
import os
# The relative paths used below (e.g. "configs/train.yaml") resolve from the
# parent directory, so move there. Guarded so that re-running this cell does not
# keep climbing the directory tree (a bare os.chdir("../") is not idempotent).
if not os.path.isfile("configs/train.yaml"):
    os.chdir("../")

from omegaconf import OmegaConf

from neuralfp.data.datasets import MusicSegmentDataset, collate_data
from neuralfp.utils.common import load_dataset

In [None]:
# Parse the training config directly so `config` always names the loaded
# config object, never the path string.
config = OmegaConf.load("configs/train.yaml")

# Training split, built from the `dataset.train` section of the config.
dataset = MusicSegmentDataset(config["dataset"]["train"])

In [None]:
from torch.utils.data import DataLoader

# Iterate the training dataset in order (no shuffling) while inspecting batches.
# `collate_data` assembles each batch; every other DataLoader option is taken
# from the `dataset.loaders` section of the config.
dataloader = DataLoader(
    dataset,
    shuffle=False,
    collate_fn=collate_data,
    **config["dataset"]["loaders"],
)

In [None]:
# import tqdm

# for batch in tqdm.tqdm(dataloader):
#     features, targets = batch
#     print(features.shape, targets.shape)

## Model

In [None]:
from neuralfp.model.neuralfp import NeuralAudioFingerprinter

In [None]:
# Re-read the config so this cell picks up any edits to configs/train.yaml;
# bind the parsed config directly rather than reusing `config` for the path.
config = OmegaConf.load("configs/train.yaml")

# Model hyper-parameters come from the `model.neuralfp` section of the config.
model = NeuralAudioFingerprinter(**config["model"]["neuralfp"])

In [None]:
import tqdm
import torch

# Shape smoke test of the forward pass on every batch. No backward pass is taken,
# so run under no_grad to avoid building autograd graphs, and log through
# tqdm.write so the messages do not corrupt the progress bar.
with torch.no_grad():
    for batch in tqdm.tqdm(dataloader):
        features, targets = batch
        tqdm.tqdm.write(f"features {features.shape}")

        # Shape (2, *features.shape): index 0 holds `features`, index 1 `targets`.
        xs = torch.stack([features, targets], dim=0)
        tqdm.tqdm.write(f"xs {xs.shape}")

        # Merge the pair axis into the batch axis: the first half of the rows
        # are `features`, the second half `targets`.
        xs = torch.flatten(xs, 0, 1)
        out = model(xs)
        tqdm.tqdm.write(f"out {out.shape}")


## Loss function

In [None]:
from neuralfp.criterion.contrastive_loss import NTxentLoss

# Contrastive criterion from neuralfp.criterion, constructed with its default arguments
# (presumably the default temperature — confirm in contrastive_loss.py).
criterion = NTxentLoss()

In [None]:
import tqdm
import torch

# Evaluate the contrastive loss on every batch (no backward pass), under no_grad
# to skip autograd bookkeeping; tqdm.write keeps the progress bar intact.
with torch.no_grad():
    for batch in tqdm.tqdm(dataloader):
        features, targets = batch
        # Stack then flatten: rows [0, B) embed `features`, rows [B, 2B) `targets`.
        xs = torch.stack([features, targets], dim=0)
        xs = torch.flatten(xs, 0, 1)
        out = model(xs)

        # Split the embeddings back into the anchor half and the target half.
        n_anchors = out.shape[0] // 2
        tqdm.tqdm.write(f"n_anchors {n_anchors}")
        loss = criterion(
            out[:n_anchors, :], out[n_anchors:, :], n_anchors
        )
        tqdm.tqdm.write(f"loss {loss}")

## Load checkpoint

In [None]:
import os

import torch

# Checkpoint location. Override it with the NEURALFP_CHECKPOINT environment
# variable instead of editing a machine-specific absolute path in the notebook.
CHECKPOINT_PATH = os.environ.get(
    "NEURALFP_CHECKPOINT",
    "/home/huynd/Code/AI-beat-maker/train/artifacts/neuralfp_epoch88.pt",
)

# Load all tensors onto the CPU regardless of the device they were saved from.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

In [None]:
# Summarise the saved weights as parameter name -> shape instead of dumping every
# tensor into the cell output. Assumes state_dict is a mapping whose values are
# tensors; anything without a .shape is shown by its type name.
{
    name: tuple(value.shape) if hasattr(value, "shape") else type(value).__name__
    for name, value in checkpoint["state_dict"].items()
}