In [None]:
import os
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import numpy as np
from transformers import ViTModel, ViTFeatureExtractor, ASTFeatureExtractor
from transformers import ASTConfig, ASTModel

In [None]:
import cv2
def resize_spectrogram(spectrogram, target_size=1024):
    """Transpose a 2-D spectrogram and force its leading axis to ``target_size``.

    The input is transposed first, so the axis that ends up being resized is
    the input's second axis. Inputs that are too short along that axis are
    zero-padded at the end; inputs that are long enough are cut down to their
    first ``target_size`` rows.

    Args:
        spectrogram: 2-D NumPy array.
        target_size: Required length of the first axis of the result.

    Returns:
        Array of shape ``(target_size, spectrogram.shape[0])``.
    """
    transposed = spectrogram.T
    missing_rows = target_size - transposed.shape[0]

    # Already long enough: keep only the leading rows.
    if missing_rows <= 0:
        return transposed[:target_size, :]

    # Too short: append zero rows and leave the second axis untouched.
    return np.pad(transposed, ((0, missing_rows), (0, 0)), mode='constant')

In [None]:
# Define the dataset class
class ImageAudioDataset(Dataset):
    """Paired image / spectrogram dataset backed by two folders of ``.jpg`` files.

    Files are paired purely by position after sorting each folder's file
    names, so both folders must sort into the same order for the pairs to be
    meaningful. Only the file counts are checked, not the names.

    Args:
        images_folder: Directory containing the ``.jpg`` images.
        spectrograms_folder: Directory containing the ``.jpg`` spectrograms.
        image_transform: Optional callable applied to the RGB PIL image.
        spectrogram_transform: Optional callable applied to the resized
            grayscale spectrogram array.
    """

    def __init__(self, images_folder, spectrograms_folder, image_transform=None, spectrogram_transform=None):
        self.images_folder = images_folder
        self.spectrograms_folder = spectrograms_folder
        self.image_transform = image_transform
        self.spectrogram_transform = spectrogram_transform
        self.image_filenames = sorted(f for f in os.listdir(images_folder) if f.endswith('.jpg'))
        self.spectrogram_filenames = sorted(f for f in os.listdir(spectrograms_folder) if f.endswith('.jpg'))
        assert len(self.image_filenames) == len(self.spectrogram_filenames), "Mismatch between image and spectrogram files"

    def __len__(self):
        """Number of (image, spectrogram) pairs."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load, resize and transform the ``idx``-th (image, spectrogram) pair.

        Raises:
            OSError: If OpenCV cannot read the spectrogram file.
        """
        image_path = os.path.join(self.images_folder, self.image_filenames[idx])
        spectrogram_path = os.path.join(self.spectrograms_folder, self.spectrogram_filenames[idx])

        image = Image.open(image_path).convert('RGB')

        spectrogram = cv2.imread(spectrogram_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside the resize.
        if spectrogram is None:
            raise OSError(f"cv2.imread could not read spectrogram file: {spectrogram_path}")
        spectrogram = resize_spectrogram(spectrogram)

        if self.image_transform:
            image = self.image_transform(image)
        if self.spectrogram_transform:
            spectrogram = self.spectrogram_transform(spectrogram)

        return image, spectrogram

# Define the models
class ImageFeatureExtractor(nn.Module):
    """Image encoder: a pretrained ViT whose per-token outputs are averaged into one vector."""

    def __init__(self):
        super().__init__()
        # Backbone weights are fetched by checkpoint name.
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

    def forward(self, x_image):
        """Run the ViT on ``x_image`` and mean-pool ``last_hidden_state`` over dim 1."""
        hidden_states = self.vit(x_image).last_hidden_state
        return hidden_states.mean(dim=1)

class AudioFeatureExtractor(nn.Module):
    """Audio encoder: a pretrained AST whose per-token outputs are averaged into one vector."""

    def __init__(self):
        super().__init__()
        # Backbone weights are fetched by checkpoint name.
        self.ast = ASTModel.from_pretrained('MIT/ast-finetuned-audioset-10-10-0.4593')

    def forward(self, x_audio):
        """Squeeze dim 1 of ``x_audio``, run the AST, and mean-pool ``last_hidden_state`` over dim 1."""
        # squeeze(dim=1) only removes the axis when it has size 1.
        model_input = x_audio.squeeze(dim=1)
        hidden_states = self.ast(model_input).last_hidden_state
        return hidden_states.mean(dim=1)

class FocalAttention(nn.Module):
    """Applies one learned linear projection to each modality's feature vector.

    Args:
        dim_image: Size of the incoming image feature vectors.
        dim_audio: Size of the incoming audio feature vectors.
    """

    def __init__(self, dim_image, dim_audio):
        super().__init__()
        self.dim_image, self.dim_audio = dim_image, dim_audio
        # NOTE(review): the projections are crossed (image -> dim_audio,
        # audio -> dim_image), so the two outputs only share a size when
        # dim_image == dim_audio — confirm this is intended.
        self.fc_image = nn.Linear(in_features=dim_image, out_features=dim_audio)
        self.fc_audio = nn.Linear(in_features=dim_audio, out_features=dim_image)

    def forward(self, image_features, audio_features):
        """Return ``(projected_image_features, projected_audio_features)``."""
        projected_image = self.fc_image(image_features)
        projected_audio = self.fc_audio(audio_features)
        return projected_image, projected_audio

In [None]:
class ImageAudioMatchingModel(nn.Module):
    """Chains the two modality encoders with the module that projects their outputs.

    Args:
        image_feature_extractor: Module mapping an image batch to feature vectors.
        audio_feature_extractor: Module mapping an audio batch to feature vectors.
        focal_attention: Module taking ``(image_features, audio_features)`` and
            returning a pair of embeddings.
    """

    def __init__(self, image_feature_extractor, audio_feature_extractor, focal_attention):
        super().__init__()
        self.image_feature_extractor = image_feature_extractor
        self.audio_feature_extractor = audio_feature_extractor
        self.focal_attention = focal_attention

    def forward(self, image, audio):
        """Encode both inputs and return ``(image_embeddings, audio_embeddings)``."""
        encoded_image = self.image_feature_extractor(image)
        encoded_audio = self.audio_feature_extractor(audio)
        img_emb, aud_emb = self.focal_attention(encoded_image, encoded_audio)
        return img_emb, aud_emb

In [None]:
class ContrastiveLoss(nn.Module):
    """Bidirectional hinge loss over in-batch hardest negatives.

    Row ``i`` of ``image_embeddings`` and row ``i`` of ``audio_embeddings`` are
    treated as a matching pair; every other row in the batch is a negative.

    Args:
        alpha: Margin by which a positive pair's cosine similarity should
            exceed that of the hardest negative.
    """

    def __init__(self, alpha=0.2):
        super().__init__()
        self.alpha = alpha

    def forward(self, image_embeddings, audio_embeddings):
        """Return the mean margin violation over the batch as a scalar tensor.

        Args:
            image_embeddings: Tensor of shape ``(batch, dim)``.
            audio_embeddings: Tensor of shape ``(batch, dim)``.
        """
        # Pairwise cosine similarity: entry (i, j) compares image i with audio j.
        cos_sim = nn.functional.cosine_similarity(image_embeddings.unsqueeze(1), audio_embeddings.unsqueeze(0), dim=2)
        positive_pair_sim = torch.diagonal(cos_sim)

        # The diagonal holds the positive pairs. It must be excluded before
        # searching for the hardest negative; otherwise the max is always at
        # least the positive similarity, each hinge term is >= alpha, and the
        # loss can never drop below 2 * alpha.
        positive_mask = torch.eye(cos_sim.size(0), cos_sim.size(1), dtype=torch.bool, device=cos_sim.device)
        negative_sim = cos_sim.masked_fill(positive_mask, float('-inf'))

        # With a batch of one there are no negatives: the max is -inf and the
        # clamp below turns the term into 0.
        hardest_negative_image = negative_sim.max(dim=1)[0]
        hardest_negative_audio = negative_sim.max(dim=0)[0]

        loss = torch.mean(
            torch.clamp(self.alpha - positive_pair_sim + hardest_negative_image, min=0) +
            torch.clamp(self.alpha - positive_pair_sim + hardest_negative_audio, min=0)
        )
        return loss

In [None]:
# Define image and audio transformations
# Image pipeline: resize, convert to a tensor, then normalize the three channels.
image_transform = T.Compose([
    T.Resize((224, 224)),  # 224x224 matches the '-224' in the ViT checkpoint name; only images go through this pipeline (spectrograms are sized by resize_spectrogram)
    T.ToTensor(),
    # NOTE(review): confirm these per-channel statistics are the ones the ViT
    # checkpoint was trained with; its own preprocessor config may differ.
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Spectrogram pipeline: convert to a tensor and normalize the single channel.
spectrogram_transform = T.Compose([
    #T.Resize((512,128)),
    T.ToTensor(),
    # NOTE(review): presumably rescales ToTensor's output to roughly [-1, 1];
    # confirm this matches the input statistics the AST checkpoint expects.
    T.Normalize(mean=[0.5], std=[0.5]),
])

#spectrogram_transform = T.ToTensor()

In [None]:
# Build the training dataset from the two data folders and wrap it in a shuffling loader
images_folder = 'images'
spectrograms_folder = 'spectrograms'
dataset = ImageAudioDataset(
    images_folder=images_folder,
    spectrograms_folder=spectrograms_folder,
    image_transform=image_transform,
    spectrogram_transform=spectrogram_transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Assemble the two encoders and the projection module into the matching model
image_feature_extractor = ImageFeatureExtractor()
audio_feature_extractor = AudioFeatureExtractor()
focal_attention = FocalAttention(  # Adjust dimensions as needed
    dim_image=768,
    dim_audio=768,
)
model = ImageAudioMatchingModel(
    image_feature_extractor,
    audio_feature_extractor,
    focal_attention,
)

# Adam over every model parameter, paired with the margin-based contrastive loss
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = ContrastiveLoss(alpha=0.2)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
# Training loop
def train_model(num_epochs=10):
    """Train the module-level ``model`` and print the mean batch loss per epoch.

    Uses the module-level ``dataloader``, ``optimizer`` and ``criterion``.

    Args:
        num_epochs: Number of full passes over ``dataloader``.
    """
    # Batches are not guaranteed to live on the same device as the model
    # (e.g. once the model has been moved to a GPU), so send each batch to
    # wherever the model's parameters currently are.
    train_device = next(model.parameters()).device

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for images, spectrograms in dataloader:
            images = images.to(train_device)
            spectrograms = spectrograms.to(train_device)

            optimizer.zero_grad()
            image_embeddings, audio_embeddings = model(images, spectrograms)
            loss = criterion(image_embeddings, audio_embeddings)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(dataloader)}')

In [None]:
# Start training
# Calls train_model with its default arguments; it trains the module-level
# model in place and prints one loss line per epoch.
train_model()

  self.pid = os.fork()


Epoch [1/10], Loss: 0.4362578123807907


KeyboardInterrupt: 

In [None]:
# Define test dataset class (same as training dataset class but for test data)
class TestImageAudioDataset(Dataset):
    """Paired image / spectrogram dataset for evaluation, backed by two folders of ``.jpg`` files.

    Files are paired purely by position after sorting each folder's file
    names, so both folders must sort into the same order for the pairs to be
    meaningful. Only the file counts are checked, not the names.

    Args:
        images_folder: Directory containing the ``.jpg`` images.
        spectrograms_folder: Directory containing the ``.jpg`` spectrograms.
        image_transform: Optional callable applied to the RGB PIL image.
        spectrogram_transform: Optional callable applied to the resized
            grayscale spectrogram array.
    """

    def __init__(self, images_folder, spectrograms_folder, image_transform=None, spectrogram_transform=None):
        self.images_folder = images_folder
        self.spectrograms_folder = spectrograms_folder
        self.image_transform = image_transform
        self.spectrogram_transform = spectrogram_transform
        self.image_filenames = sorted(f for f in os.listdir(images_folder) if f.endswith('.jpg'))
        self.spectrogram_filenames = sorted(f for f in os.listdir(spectrograms_folder) if f.endswith('.jpg'))
        assert len(self.image_filenames) == len(self.spectrogram_filenames), "Mismatch between image and spectrogram files"

    def __len__(self):
        """Number of (image, spectrogram) pairs."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load, resize and transform the ``idx``-th (image, spectrogram) pair.

        Raises:
            OSError: If OpenCV cannot read the spectrogram file.
        """
        image_path = os.path.join(self.images_folder, self.image_filenames[idx])
        spectrogram_path = os.path.join(self.spectrograms_folder, self.spectrogram_filenames[idx])

        image = Image.open(image_path).convert('RGB')

        spectrogram = cv2.imread(spectrogram_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside the resize.
        if spectrogram is None:
            raise OSError(f"cv2.imread could not read spectrogram file: {spectrogram_path}")
        spectrogram = resize_spectrogram(spectrogram)

        if self.image_transform:
            image = self.image_transform(image)
        if self.spectrogram_transform:
            spectrogram = self.spectrogram_transform(spectrogram)

        return image, spectrogram

# Build the evaluation dataset and a loader that walks it in a fixed order
# NOTE(review): these are the same folder names the training dataset uses
# ('images', 'spectrograms') — confirm that evaluating on them is intended.
test_images_folder = 'images'
test_spectrograms_folder = 'spectrograms'
test_dataset = TestImageAudioDataset(
    images_folder=test_images_folder,
    spectrograms_folder=test_spectrograms_folder,
    image_transform=image_transform,
    spectrogram_transform=spectrogram_transform,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=2,
)


In [None]:
def compute_recall_at_1(model, dataloader, device):
    """Image-to-audio Recall@1 over every sample yielded by ``dataloader``.

    Sample ``i``'s image counts as a hit when, among all audio embeddings,
    the one with the highest cosine similarity is sample ``i``'s own.

    Args:
        model: Callable module returning ``(image_embeddings, audio_embeddings)``
            for a batch; it is switched to eval mode and left that way.
        dataloader: Iterable of ``(images, spectrograms)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Fraction of hits as a float in ``[0, 1]``; ``0.0`` when the
        dataloader yields no batches.
    """
    model.eval()
    image_batches = []
    audio_batches = []

    with torch.no_grad():
        for images, spectrograms in dataloader:
            image_embeddings, audio_embeddings = model(images.to(device), spectrograms.to(device))
            # Collect on the CPU so the full similarity matrix does not compete
            # with the model for accelerator memory.
            image_batches.append(image_embeddings.cpu())
            audio_batches.append(audio_embeddings.cpu())

    # Nothing to rank: torch.cat would raise on an empty list.
    if not image_batches:
        return 0.0

    # Unit-normalise once, then a single matmul yields all pairwise cosine
    # similarities without materialising an (N, N, dim) intermediate.
    image_features = nn.functional.normalize(torch.cat(image_batches, dim=0), dim=1)
    audio_features = nn.functional.normalize(torch.cat(audio_batches, dim=0), dim=1)
    cos_sim = image_features @ audio_features.T

    # Only the top-ranked audio matters for Recall@1, so argmax replaces a full sort.
    num_samples = cos_sim.size(0)
    best_match = cos_sim.argmax(dim=1)
    hits = int((best_match == torch.arange(num_samples)).sum())
    return hits / num_samples

In [None]:
# Use the GPU when one is available, otherwise the CPU, and place the model there
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)

# Score retrieval on the test dataloader and report it to four decimals
recall_at_1 = compute_recall_at_1(model, test_dataloader, device)
print(f'Recall@1: {recall_at_1:.4f}')

Recall@1: 0.0500
