In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset, Dataset
import os
import numpy as np
from sklearn.model_selection import train_test_split

os.chdir("/root/dev/vcmr")


In [2]:
# HPARAMS
num_epochs = 1000
learning_rate = 1e-4
batch_size = 64

In [3]:
# PRELOAD
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_dir = './image_features_train'
music_dir = './music_features_train'

files = sorted([f for f in os.listdir(image_dir) if f.endswith('.npy')])
filenames = [os.path.splitext(f)[0] for f in files]
image_embeddings = []
music_embeddings = []
for f in files:
    image_embed = np.load(os.path.join(image_dir, f)).astype(np.float32)
    music_embed = np.load(os.path.join(music_dir, f)).astype(np.float32)
    image_embeddings.append(image_embed)
    music_embeddings.append(music_embed)
image_embeddings = torch.tensor(image_embeddings, dtype=torch.float32)
music_embeddings = torch.tensor(music_embeddings, dtype=torch.float32)

num_samples = len(filenames)
num_image_embeddings = image_embeddings.size(1)
num_music_embeddings = music_embeddings.size(1)
print(num_samples, num_image_embeddings, num_music_embeddings)

train_indices, val_indices = train_test_split(list(range(num_samples)), test_size=0.2, random_state=42)

  image_embeddings = torch.tensor(image_embeddings, dtype=torch.float32)


37656 512 512


In [4]:
class PairDataset(Dataset):
    def __init__(self, image_embeds, music_embeds, indices, positive=True):
        self.image_embeds = image_embeds
        self.music_embeds = music_embeds
        self.indices = indices
        self.positive = positive
        self.labels = torch.ones(len(indices), dtype=torch.float32) if positive else torch.ones(len(indices), dtype=torch.float32) * -1.0

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        i = self.indices[idx]
        img = self.image_embeds[i]
        if self.positive:
            music = self.music_embeds[i]
            label = self.labels[idx]
        else:
            j = (i + 4) % self.image_embeds.size(0)
            music = self.music_embeds[j]
            label = self.labels[idx]
        return img, music, label

train_positive_dataset = PairDataset(image_embeddings, music_embeddings, train_indices, positive=True)
train_negative_dataset = PairDataset(image_embeddings, music_embeddings, train_indices, positive=False)
val_positive_dataset = PairDataset(image_embeddings, music_embeddings, val_indices, positive=True)
val_negative_dataset = PairDataset(image_embeddings, music_embeddings, val_indices, positive=False)
train_dataset = ConcatDataset([train_positive_dataset, train_negative_dataset])
val_dataset = ConcatDataset([val_positive_dataset, val_negative_dataset])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)


In [None]:

# LOSS
def cosine_similarity_loss(image_embeds, music_embeds, labels):
    loss = F.cosine_embedding_loss(image_embeds, music_embeds, labels)
    return loss

# MODEL
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=512, output_dim=128):
        super(ProjectionHead, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )
    def forward(self, x):
        return self.mlp(x)

image_proj = ProjectionHead(input_dim=512, output_dim=128).to(device)
music_proj = ProjectionHead(input_dim=512, output_dim=128).to(device)

# OPT
optimizer = torch.optim.Adam(
    list(image_proj.parameters()) + list(music_proj.parameters()), lr=learning_rate
)

from torch.amp import autocast, GradScaler

scaler = GradScaler()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)



In [None]:
for epoch in range(num_epochs):
    image_proj.train()
    music_proj.train()
    train_loss = 0.0
    for image_embed, music_embed, label in train_loader:
        image_embed = image_embed.to(device, non_blocking=True)
        music_embed = music_embed.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True)
        optimizer.zero_grad()

        with autocast("cuda"):
            projected_image = image_proj(image_embed)
            projected_music = music_proj(music_embed)
            loss = cosine_similarity_loss(projected_image, projected_music, label)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)

    image_proj.eval()
    music_proj.eval()
    val_loss = 0.0
    with torch.no_grad():
        for image_embed, music_embed, label in val_loader:
            image_embed = image_embed.to(device, non_blocking=True)
            music_embed = music_embed.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)

            with autocast("cuda"):
                projected_image = image_proj(image_embed)
                projected_music = music_proj(music_embed)
                loss = cosine_similarity_loss(projected_image, projected_music, label)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)

    scheduler.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

Epoch 1/1000, Train Loss: 0.3883, Val Loss: 0.3587
Epoch 2/1000, Train Loss: 0.3395, Val Loss: 0.3536
Epoch 3/1000, Train Loss: 0.3157, Val Loss: 0.3510
Epoch 4/1000, Train Loss: 0.2935, Val Loss: 0.3528
Epoch 5/1000, Train Loss: 0.2733, Val Loss: 0.3505
Epoch 6/1000, Train Loss: 0.2274, Val Loss: 0.3510
Epoch 7/1000, Train Loss: 0.2130, Val Loss: 0.3542
Epoch 8/1000, Train Loss: 0.2051, Val Loss: 0.3554
Epoch 9/1000, Train Loss: 0.1977, Val Loss: 0.3564
Epoch 10/1000, Train Loss: 0.1922, Val Loss: 0.3566
Epoch 11/1000, Train Loss: 0.1826, Val Loss: 0.3579
Epoch 12/1000, Train Loss: 0.1824, Val Loss: 0.3589
Epoch 13/1000, Train Loss: 0.1805, Val Loss: 0.3590
Epoch 14/1000, Train Loss: 0.1799, Val Loss: 0.3590
Epoch 15/1000, Train Loss: 0.1798, Val Loss: 0.3591
Epoch 16/1000, Train Loss: 0.1785, Val Loss: 0.3593
Epoch 17/1000, Train Loss: 0.1781, Val Loss: 0.3599
Epoch 18/1000, Train Loss: 0.1778, Val Loss: 0.3598
Epoch 19/1000, Train Loss: 0.1782, Val Loss: 0.3598
Epoch 20/1000, Train 