In [1]:
import os

import numpy as np
import torch
from torch.utils.data import Dataset

In [2]:
from embedding import Encoder, EncoderWrapper

In [None]:
# Size argument handed to each Encoder(...) constructed in load_model below.
embedding_dim = 128

In [None]:
class SpectrogramParamDataset(Dataset):
    """Dataset backed by a directory of ``.pt`` files.

    Each file is expected to deserialize (via ``torch.load``) to a mapping
    with a ``'data'`` entry and a ``'params'`` entry. Items are returned as a
    ``(data, params)`` pair of float32 tensors.
    """

    def __init__(self, data_dir):
        """Collect the ``.pt`` files found directly inside *data_dir*.

        Args:
            data_dir: Directory to scan (not recursive). Paths are stored in
                sorted order so that indices are stable between runs.
        """
        self.file_paths = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith('.pt')
        )

    def __len__(self):
        """Return the number of sample files that were indexed."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load sample *idx* from disk.

        Args:
            idx: Position in the sorted list of sample files.

        Returns:
            Tuple ``(data, params)``, both converted to float32.
        """
        sample = torch.load(self.file_paths[idx])
        data = sample['data']         # shape: [J, 2, F, T]
        # shape: [P] or [1, P]. Normalise to a tensor first so the rank check
        # and the float conversion below also work for array-like params.
        params = torch.as_tensor(sample['params'])

        # A 2-D params tensor carries a leading singleton row; keep that row.
        if params.ndim == 2:
            params = params[0]  # [P]

        return data.float(), params.float()

    def load_model(self, path):
        """Build two encoders and restore their weights from a checkpoint.

        The encoders are stored on the instance as ``self.encoder1`` and
        ``self.encoder2`` and switched to evaluation mode.

        Args:
            path: Checkpoint file readable by ``torch.load``. It must contain
                the keys ``'encoder1_state_dict'`` and
                ``'encoder2_state_dict'``.
        """
        self.encoder1 = Encoder(embedding_dim)
        self.encoder2 = Encoder(embedding_dim)

        # Map checkpoint tensors to GPU when available, otherwise to CPU, so
        # that a checkpoint saved on a GPU machine can still be read here.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(path, map_location=device)

        # Restore weights for each encoder from its own state dict.
        self.encoder1.load_state_dict(checkpoint['encoder1_state_dict'])
        self.encoder2.load_state_dict(checkpoint['encoder2_state_dict'])

        # Evaluation mode: no further training is intended after loading.
        self.encoder1.eval()
        self.encoder2.eval()