In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# --- Define Image Encoder ---
class ImageEncoder(nn.Module):
    """Two-layer strided 2-D CNN that maps single-channel images to embeddings.

    Args:
        embedding_dim: size of the output embedding.
        input_size: spatial size of the input images, either an int (square
            images) or an ``(height, width)`` pair. Defaults to 28, which
            reproduces the previously hard-coded ``64 * 6 * 6`` flatten size.

    Raises:
        ValueError: if ``input_size`` is too small to survive both convolutions.
    """

    def __init__(self, embedding_dim, input_size=28):
        super().__init__()
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # Derive the flatten size from the input instead of hard-coding it, so
        # other image sizes do not fail with a shape mismatch at forward time.
        out_h = self._conv_out(self._conv_out(height))
        out_w = self._conv_out(self._conv_out(width))
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                f"input_size={input_size!r} is too small for two kernel-3, stride-2 "
                "convolutions (each side needs at least 7 pixels)"
            )
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * out_h * out_w, embedding_dim),
        )

    @staticmethod
    def _conv_out(length):
        """Output length of one kernel-3, stride-2, unpadded convolution."""
        return (length - 3) // 2 + 1

    def forward(self, x):
        """Embed a batch of images.

        Expects ``x`` shaped (batch, 1, H, W) with (H, W) equal to the
        ``input_size`` given at construction; returns (batch, embedding_dim).
        """
        return self.encoder(x)


# --- Define Audio Encoder ---
class AudioEncoder(nn.Module):
    """Two-layer strided 1-D CNN that maps mono waveforms to embeddings.

    Args:
        embedding_dim: size of the output embedding.
        input_length: number of samples per clip. The final Linear layer's
            input size is derived from it, so every clip fed to ``forward``
            must have exactly this length.

    Raises:
        ValueError: if ``input_length`` is too short to survive both
            convolutions (fewer than 7 samples).
    """

    def __init__(self, embedding_dim, input_length):
        super().__init__()
        flat_dim = self._calculate_output_dim(input_length)
        # Fail fast: otherwise a too-short length silently builds Linear(0, ...)
        # or raises an opaque error deep inside torch.
        if flat_dim <= 0:
            raise ValueError(
                f"input_length={input_length} is too short for two kernel-3, stride-2 "
                "convolutions (need at least 7 samples)"
            )
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 32, 3, stride=2),
            nn.ReLU(),
            nn.Conv1d(32, 64, 3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_dim, embedding_dim),
        )

    def _calculate_output_dim(self, input_length):
        """Return the flattened feature count after both convolutions.

        Each conv is kernel 3, stride 2, no padding, so its output length is
        ``(length - 3) // 2 + 1``; the second conv has 64 output channels.
        """
        length = input_length
        length = (length - 3) // 2 + 1  # first conv
        length = (length - 3) // 2 + 1  # second conv
        return length * 64

    def forward(self, x):
        """Embed a batch of waveforms.

        Expects ``x`` shaped (batch, 1, input_length); returns
        (batch, embedding_dim).
        """
        return self.encoder(x)


# --- Cosine Similarity ---
def cosine_similarity(a, b):
    """Pairwise cosine similarity between the rows of ``a`` and the rows of ``b``.

    Both inputs are L2-normalised along their last dimension and then
    matrix-multiplied. For 2-D inputs of shape (N, D) and (M, D), the result is
    an (N, M) matrix whose entry [i, j] is cos(a[i], b[j]).
    """
    unit_a, unit_b = (F.normalize(t, dim=-1) for t in (a, b))
    return torch.matmul(unit_a, unit_b.T)


# --- NT-Xent Loss ---
def nt_xent_loss(img_embeds, aud_embeds, temperature=0.07):
    """Symmetric cross-modal contrastive (InfoNCE / CLIP-style) loss.

    Row i of ``img_embeds`` and row i of ``aud_embeds`` form the positive pair.
    Every other row of the opposite modality acts as a negative. Cross-entropy
    is computed over the temperature-scaled cosine-similarity matrix in both
    directions (image->audio over rows, audio->image over columns), and the two
    losses are averaged.

    Only cross-modal similarities are used; there are no image-image or
    audio-audio negatives as in SimCLR-style NT-Xent.
    Assumes both inputs are (batch, dim) with the same batch size.
    """
    logits = cosine_similarity(img_embeds, aud_embeds) / temperature
    # Matching pairs lie on the diagonal, so the target class for row i is i.
    labels = torch.arange(img_embeds.size(0), device=img_embeds.device)
    image_to_audio = F.cross_entropy(logits, labels)
    audio_to_image = F.cross_entropy(logits.T, labels)
    return (image_to_audio + audio_to_image) / 2


# --- Training Step with Alternating Updates ---
def train_alternating(image_encoder, audio_encoder, dataloader, optimizer_img, optimizer_audio, device, temperature=0.07, loss_fn=None):
    """Run one epoch of alternating contrastive updates.

    For each ``(images, audios)`` batch:
      1. the image encoder takes an optimizer step against audio embeddings
         computed under ``torch.no_grad()`` (audio side frozen);
      2. the audio encoder takes an optimizer step against image embeddings
         recomputed under ``torch.no_grad()`` from the just-updated image encoder.

    Args:
        image_encoder: module producing image embeddings (updated in step 1).
        audio_encoder: module producing audio embeddings (updated in step 2).
        dataloader: iterable yielding ``(images, audios)`` pairs.
        optimizer_img: optimizer stepped after the image-side loss.
        optimizer_audio: optimizer stepped after the audio-side loss.
        device: device each batch is moved to with ``.to(device)``.
        temperature: passed as the third argument to ``loss_fn``.
        loss_fn: callable ``(img_embeds, aud_embeds, temperature)`` returning a
            scalar tensor. Defaults to ``nt_xent_loss``.

    Returns:
        ``(mean_img_loss, mean_audio_loss)`` as Python floats averaged over the
        batches seen, or ``(nan, nan)`` if ``dataloader`` yields nothing.
        Callers that ignore the return value are unaffected.

    Each encoder's train/eval mode is restored on exit, so the function never
    leaves an encoder in eval mode as a side effect.
    """
    if loss_fn is None:
        loss_fn = nt_xent_loss

    img_was_training = image_encoder.training
    aud_was_training = audio_encoder.training
    img_loss_sum = 0.0
    aud_loss_sum = 0.0
    num_batches = 0

    try:
        for images, audios in dataloader:
            images, audios = images.to(device), audios.to(device)

            # --- Step 1: update the image encoder; audio embeddings are a fixed target ---
            audio_encoder.eval()
            image_encoder.train()

            optimizer_img.zero_grad()
            img_embeds = image_encoder(images)
            with torch.no_grad():
                aud_embeds = audio_encoder(audios)

            loss_img = loss_fn(img_embeds, aud_embeds, temperature)
            loss_img.backward()
            optimizer_img.step()

            # --- Step 2: update the audio encoder against the freshly updated image encoder ---
            image_encoder.eval()
            audio_encoder.train()

            optimizer_audio.zero_grad()
            with torch.no_grad():
                img_embeds = image_encoder(images)
            aud_embeds = audio_encoder(audios)

            loss_audio = loss_fn(img_embeds, aud_embeds, temperature)
            loss_audio.backward()
            optimizer_audio.step()

            img_loss_sum += loss_img.item()
            aud_loss_sum += loss_audio.item()
            num_batches += 1
    finally:
        # Undo the eval()/train() toggling above so callers see the modes they set.
        image_encoder.train(img_was_training)
        audio_encoder.train(aud_was_training)

    if num_batches == 0:
        return float("nan"), float("nan")
    return img_loss_sum / num_batches, aud_loss_sum / num_batches


# --- Example Usage ---
# Seed torch's global RNG before anything random happens (weight init, the
# dummy data, DataLoader shuffling) so Restart & Run All reproduces the run.
SEED = 0
torch.manual_seed(SEED)

embedding_dim = 128
input_audio_length = 16000  # Example length after resampling
# NOTE: the two kernel-3/stride-2 convs reduce 16000 samples to 3999 frames x 64
# channels = 255,936 flattened features, so the audio projection layer alone
# holds ~32.8M weights. Lower this length for quick experiments.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image_encoder = ImageEncoder(embedding_dim).to(device)
audio_encoder = AudioEncoder(embedding_dim, input_audio_length).to(device)

optimizer_img = torch.optim.Adam(image_encoder.parameters(), lr=1e-4)
optimizer_audio = torch.optim.Adam(audio_encoder.parameters(), lr=1e-4)

# --- Dummy DataLoader Example ---
class DummyDataset(torch.utils.data.Dataset):
    """Synthetic paired dataset of random (image, audio) tensors for smoke tests.

    Items are sampled fresh with ``torch.randn`` on every access, so the same
    index returns different tensors on each call.

    Args:
        size: number of items reported by ``len()``.
        audio_length: samples per audio clip. If None (the default), the
            notebook-level ``input_audio_length`` is read at access time,
            exactly as before. Pass it explicitly to avoid that hidden
            dependency on a global.
        image_shape: shape of each image tensor; defaults to ``(1, 28, 28)``.
    """

    def __init__(self, size, audio_length=None, image_shape=(1, 28, 28)):
        self.size = size
        self.audio_length = audio_length
        self.image_shape = tuple(image_shape)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # idx is ignored: every item is independent random noise.
        length = input_audio_length if self.audio_length is None else self.audio_length
        img = torch.randn(*self.image_shape)
        audio = torch.randn(1, length)
        return img, audio

dummy_dataset = DummyDataset(100)
dataloader = DataLoader(dummy_dataset, batch_size=16, shuffle=True)

# --- Run Training Epochs ---
NUM_EPOCHS = 5
for epoch_num in range(1, NUM_EPOCHS + 1):
    train_alternating(image_encoder, audio_encoder, dataloader, optimizer_img, optimizer_audio, device)
    print(f"Finished epoch {epoch_num}")


Finished epoch 1
Finished epoch 2
Finished epoch 3
Finished epoch 4
Finished epoch 5


In [None]:
# Smoke test: the encoder needs an input batch (calling it with no argument
# raises TypeError). Embed one dummy 1x28x28 image and show the output shape.
with torch.no_grad():
    sample_embedding = image_encoder(torch.randn(1, 1, 28, 28, device=device))
sample_embedding.shape