In [1]:
import json
import torch
import ultraprint.common as p
import os
import random
import numpy as np
from tqdm import tqdm
from torch import nn
import librosa.display
import torch.nn.functional as F

In [2]:
data_dir = 'D:/Downloads/MusicBench'
file_dir = 'D:/Downloads/MusicBench/MusicBench/datashare'
model_dir = 'models/ranit/description'
model_name = model_dir+'/ranit_description_embedder_v1'

#create directories if they don't exist
os.makedirs(data_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   

In [3]:
from txtai.embeddings import Embeddings

# Create an embeddings index
embeddings = Embeddings({"path": "sentence-transformers/paraphrase-MiniLM-L6-v2"})

def transform(sentence):
    if type(sentence) == str:
        sentence = [sentence]
    return embeddings.batchtransform(sentence)



In [4]:
if os.path.exists(data_dir + '/train_data.pt'):
    p.green('Loading data from cache')
    data = torch.load(data_dir + '/train_data.pt')
else:
    # load "MusicBench_train.json"
    p.red('File not found, Loading data from file')
    data = []
    with open(data_dir + '/MusicBench_train.json', 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in tqdm(lines, desc="Loading data"):
            temp_json = json.loads(line.strip())
            data.append({
                'location': file_dir+"/"+temp_json['location'],
                'vector': transform(temp_json['main_caption'])
            })
    # save the data
    torch.save(data, data_dir + '/train_data.pt')

[92m Loading data from cache[00m


In [5]:
def get_batch(n=5):
    #get random files from withing each folder and for each also store the array containing the expected output
    x = []
    y = []
    labels = []
    for i in range(n):
        random_data = random.choice(data)
        random_integer = random.randint(0, 1)

        if random_integer == 0:
            x.append(random_data['location'])
            y.append(random_data['vector'][0])
            labels.append(1)
        else:
            another_random_data = random.choice(data)
            # make sure that the two vectors are not the same
            while another_random_data['location'] == random_data['location']:
                another_random_data = random.choice(data)

            x.append(random_data['location'])
            y.append(another_random_data['vector'][0])
            # calculate the cosine similarity between the two vectors
            labels.append(-1)

    #turn y into a tensor
    tensor_y = torch.tensor(y, dtype=torch.float32, device=device)
    labels = torch.tensor(labels, dtype=torch.float32, device=device)

    return x, tensor_y, labels

get_batch(5)

  tensor_y = torch.tensor(y, dtype=torch.float32, device=device)


(['D:/Downloads/MusicBench/MusicBench/datashare/data_aug2/tv14XEQcY0c_3.wav',
  'D:/Downloads/MusicBench/MusicBench/datashare/data_aug2/cmJj7SxQEp8_7.wav',
  'D:/Downloads/MusicBench/MusicBench/datashare/data_aug2/5-tx4Fgqetc_8.wav',
  'D:/Downloads/MusicBench/MusicBench/datashare/data_aug2/WK-gdfCurCg_7.wav',
  'D:/Downloads/MusicBench/MusicBench/datashare/data_aug2/AAP5pAB-4jM_5.wav'],
 tensor([[-0.0498, -0.0006, -0.1066,  ...,  0.0109, -0.0298,  0.0252],
         [-0.0587,  0.0015, -0.0023,  ...,  0.0860,  0.0409,  0.0053],
         [ 0.0565, -0.0086, -0.0462,  ...,  0.0653,  0.0678, -0.0262],
         [-0.0317,  0.0066, -0.0134,  ...,  0.0813, -0.0615, -0.0531],
         [ 0.0128, -0.0364,  0.0080,  ...,  0.0727,  0.0246,  0.0222]],
        device='cuda:0'),
 tensor([ 1., -1.,  1.,  1.,  1.], device='cuda:0'))

In [None]:
from desc_model import DescriptionEmbedder


def contrastive_loss(output1, output2, label, margin=1.0):
    # Compute Euclidean distance
    euclidean_distance = F.pairwise_distance(output1, output2)
    # Adjust label: map 1 -> 0 (similar), -1 -> 1 (dissimilar)
    label = (1 - label) / 2  # 1 -> 0, -1 -> 1
    # Contrastive loss formula
    loss = torch.mean(
        (1 - label) * torch.pow(euclidean_distance, 2) +
        label * torch.pow(torch.clamp(margin - euclidean_distance, min=0.0), 2)
    )
    return loss

def train(epochs=100, batch_size=5, learning_rate=0.001, weight_decay=1e-2):
    
    # Initialize model and optimizer
    # Model outputs 10 classes
    model = DescriptionEmbedder(device=device)
    model.load(model_name)
    model.train()
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)  # Switched to AdamW

    # Training loop
    for epoch in range(epochs):
        epoch_losses = []
        
        # Create progress bar for each epoch
        with tqdm(range(100), desc=f'Epoch {epoch+1}/{epochs}') as pbar:
            
            for _ in pbar:
                try:
                    # Get batch
                    audio_paths, vectors, labels = get_batch(batch_size)

                    # continue if batch size is less than expected
                    if len(audio_paths) != batch_size:
                        p.yellow(f'Batch size mismatch, skipping batch')
                        continue

                    # Forward pass
                    output_vectors = model(audio_paths)

                    # check if length of probabilities is equal to length of labels
                    if len(output_vectors) != len(vectors):
                        p.yellow(f'Length mismatch between probabilities and labels, skipping batch')
                        continue

                    # Check gradient
                    if not output_vectors.requires_grad:
                        p.yellow("Warning: probabilities lost gradient tracking")
                        continue
                
                    # Calculate loss
                    loss = contrastive_loss(output_vectors, vectors, labels)
                    
                    if not loss.requires_grad:
                        p.yellow("Warning: loss lost gradient tracking")
                        continue

                    # Backward pass
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                    # Track progress
                    epoch_losses.append(loss.item())

                    pbar.set_postfix({'loss': np.mean(epoch_losses)})

                except Exception as e:
                    #print full stack trace
                    p.red(e)
                    continue
        
        # Save model checkpoint every epochs
        model.save(model_name)
        p.green(f'\nCheckpoint saved at epoch {epoch+1}')
        p.blue(f'Epoch {epoch+1} average loss: {np.mean(epoch_losses):.4f}')

# Start training
if __name__ == "__main__":
    train(batch_size=2)

In [15]:
def load_model(model_class, model_path, device='cuda'):
    model = model_class(device=device)
    model.load(model_path)
    model.eval()
    return model

# Load model
model = load_model(DescriptionEmbedder, model_name)

# Get batch (assuming get_batch is implemented elsewhere)
audio_paths, vectors, labels = get_batch(n=1)
print("Audio paths:", audio_paths)
print("Labels shape:", vectors.shape)

with torch.no_grad():
    output_vectors = model(audio_paths)

print("Labels:", labels)

#compute cosine similarity output_vectors, vectors
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
output = cos(output_vectors, vectors)
print("Cosine similarity:", output.item())

print("Output vectors shape:", output_vectors)

Audio paths: ['D:/Downloads/MusicBench/MusicBench/datashare/data/DdxW_JziHTA.wav']
Labels shape: torch.Size([1, 384])
Labels: tensor([-1.], device='cuda:0')
Cosine similarity: 0.7931150197982788
Output vectors shape: tensor([[-6.9185e-03, -9.2591e-03, -1.8954e-03, -9.7962e-02, -9.3861e-02,
          4.4947e-02,  6.2599e-02, -7.5256e-02,  4.9556e-02,  1.2668e-02,
          3.2759e-02,  1.9892e-02,  2.1348e-02, -6.8725e-02,  5.8483e-02,
          4.9880e-02,  8.6015e-02,  6.9803e-02,  2.0605e-02, -2.0048e-02,
         -1.7037e-03,  8.4226e-02, -2.2626e-02, -3.5646e-02, -6.0768e-02,
         -3.3919e-02, -2.3313e-02,  6.4171e-02,  3.4541e-02, -3.5094e-02,
          4.0316e-02,  1.0456e-01,  7.3339e-02, -4.0929e-02, -9.7383e-02,
          1.0786e-01, -6.3471e-02, -2.6066e-02, -1.0236e-01,  1.6811e-02,
         -3.1472e-02,  1.0054e-02, -1.1470e-03, -1.5678e-02, -4.3293e-02,
         -6.4183e-03, -5.7657e-02, -1.4711e-02,  3.5724e-02,  4.1702e-02,
         -8.7406e-02, -4.5196e-02,  4.1524e