<a href="https://colab.research.google.com/github/TongQM/SONAR_VLM/blob/main/sonar_embedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install SONAR embeddings and requirements

In [None]:
# !pip show torch
# !pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.6.0/cu124
# !pip install sonar-space

Name: torch
Version: 2.6.0
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /opt/homebrew/anaconda3/envs/vlm/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: fairseq2n, sonar-space, torchaudio


# Mount Google drive and set work directory

In [None]:
# import os
# from google.colab import drive

# drive.mount('/content/gdrive')
# os.chdir('/content/gdrive/My Drive/sonar')
# !ls annotations/

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


### Import necessary packages

In [164]:
import json
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.models import resnet50, resnet152, ResNet152_Weights, ResNet50_Weights
from torchsummary import summary

from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from sonar.inference_pipelines.text import EmbeddingToTextModelPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
device = 'mps' if torch.backends.mps.is_available() else device
print(f"Using device: {device}")

Using device: mps


## Load text2embedding model

In [208]:
t2vec_model = TextToEmbeddingModelPipeline(encoder="text_sonar_basic_encoder",
                                           tokenizer="text_sonar_basic_encoder",
                                           device=torch.device('cpu'),
                                           dtype=torch.float16)
sentences = ['My name is SONAR.', 'I can embed the sentences into vectorial space.']
embeddings = t2vec_model.predict(sentences, source_lang="eng_Latn")


vec2text_model = EmbeddingToTextModelPipeline(decoder="text_sonar_basic_decoder",
                                              tokenizer="text_sonar_basic_encoder",
                                              device=torch.device('cpu'),
                                              dtype=torch.float16)
reconstructed = vec2text_model.predict(embeddings, target_lang="eng_Latn", max_seq_len=512)
print("Original sentences:")
for sentence in sentences:
    print(sentence)
print("\nReconstructed sentences:")
for sentence in reconstructed:
    print(sentence)

Original sentences:
My name is SONAR.
I can embed the sentences into vectorial space.

Reconstructed sentences:
My name is SONAR.
I can embed the sentences into vector space.


## Load train and validation data

### Define dataloader with SONAR embeddings as labels

In [209]:
class COCOCaptionTextDataset(Dataset):
    """
    Returns (image_tensor, list_of_captions) for each idx.
    """
    def __init__(self, img_dir, coco_json, transform=None, numcaps=5):
        self.img_dir = Path(img_dir)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        with open(coco_json) as f:
            coco = json.load(f)
        # group captions by image_id
        caps = defaultdict(list)
        for ann in coco["annotations"]:
            caps[f"{ann['image_id']:012d}.jpg"].append(ann["caption"])
        
        # For each image, sort the captions and keep only the first 5
        for img, captions in caps.items():
            caps[img] = sorted(captions)[:numcaps]
        
        self.samples = sorted((str(self.img_dir / img), caps[img]) 
                              for img in caps)
        

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, captions = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        return img, captions


def collate_and_encode(batch, text_encoder, device="cuda"):
    """
    batch: list of (image_tensor, [cap1, cap2, ..., cap5])
    text_encoder: any model with .encode(list[str]) → Tensor(n_captions, D)
    """
    # 1) Stack images
    imgs = torch.stack([b[0] for b in batch], dim=0)

    # 2) Flatten out all captions
    all_caps = []
    counts   = []
    for _, caps in batch:
        counts.append(len(caps))
        all_caps.extend(caps)

    # 3) Run teacher in one shot
    #    Assume it returns a torch.Tensor of shape (sum(counts), D)
    with torch.no_grad():
        embs = text_encoder.predict(all_caps, source_lang="eng_Latn").to(device)
        # if it returns numpy, wrap: embs = torch.from_numpy(embs).to(device)

    # 4) Split & reduce per sample (e.g. mean)
    D = embs.size(1)
    labels = []
    idx = 0
    for n in counts:
        chunk = embs[idx: idx+n]      # shape (n, D)
        labels.append(chunk)  # → (D,)
        idx += n
    labels = torch.stack(labels, dim=0)  # → (batch_size, D)

    return imgs.to(device), labels


# ---- usage ----
train_dataset = COCOCaptionTextDataset(
    img_dir="./data/images/train2017",
    coco_json="./data/annotations/annotations/captions_train2017.json",
)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    num_workers=0,
    pin_memory=True,
    collate_fn=lambda b: collate_and_encode(b, t2vec_model, device="mps"),
)

val_dataset = COCOCaptionTextDataset(
    img_dir="./data/images/val2017",
    coco_json="./data/annotations/annotations/captions_val2017.json",
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    num_workers=0,
    pin_memory=True,
    collate_fn=lambda b: collate_and_encode(b, t2vec_model, device="mps"),
)

### Define multi-positive InfoNCE loss

In [193]:
def multi_pos_infonce_loss(img_embs, cap_embs, temperature):
    """
    img_embs: (B, D)  after F.normalize
    cap_embs: (B, m, D) after F.normalize
    """
    B, m, D = cap_embs.shape
    # flatten captions: (B*m, D)
    flat_caps = cap_embs.view(B*m, D)                             
    # similarity: (B, B*m)
    logits = img_embs @ flat_caps.t() / temperature               

    # numerator: sum over each image's m positives
    # we know that positives for image i are indices [i*m : i*m + m]
    pos_mask = torch.zeros_like(logits, dtype=torch.bool)
    for i in range(B):
        start = i * m
        pos_mask[i, start : start + m] = True

    # exp(logits)
    exp_logits = logits.exp()
    numerator   = exp_logits.masked_select(pos_mask).view(B, m).sum(dim=1) 
    denominator = exp_logits.sum(dim=1)                                 
    loss = -torch.log(numerator / denominator).mean()
    return loss

### Resnet50 Encoder

In [200]:
class ResNet50Embedder(nn.Module):
    """
    ResNet50 backbone producing fixed-size embeddings (e.g., 1024-D).

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained weights.
        embedding_dim (int): Dimensionality of the output embedding.
    """
    def __init__(self, pretrained: bool = True, embedding_dim: int = 1024, weights=ResNet50_Weights.DEFAULT):
        super().__init__()
        # Load ResNet50 backbone
        if pretrained:
            # Use pretrained weights
            weights = ResNet50_Weights.DEFAULT
        base_model = resnet50(weights=weights)
        # Save the feature dimensionality for projection
        in_features = base_model.fc.in_features
        # Replace the final fully connected layer with a new one
        base_model.fc = nn.Linear(in_features, embedding_dim)
        # Initialize the new layer
        nn.init.xavier_uniform_(base_model.fc.weight)
        self.backbone = base_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: image -> embedding

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).

        Returns:
            torch.Tensor: Embedding of shape (batch_size, embedding_dim).
        """
        embeddings = self.backbone(x)   # shape: (batch_size, embedding_dim)
        return embeddings
    
    def load_weights(self, path: str):
        """
        Load weights from a file.

        Args:
            path (str): Path to the weights file.
        """
        state_dict = torch.load(path, weights_only=True)
        self.backbone.load_state_dict(state_dict)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps" if torch.backends.mps.is_available() else device)
# Example: produce 1024-D embeddings
encoder = ResNet50Embedder(pretrained=True, embedding_dim=1024).to(device)
# summary(model, (3, 224, 224), device=device.type)
dummy_input = torch.randn(10, 3, 224, 224, device=device)
embed = encoder(dummy_input)
print(f"Output embedding shape: {embed.shape}")


Output embedding shape: torch.Size([10, 1024])


In [210]:
# --- Validation function ---
def validate(model: nn.Module, loader: DataLoader, val_criterion, device: torch.device) -> float:
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, caption_embs in loader:
            images, caption_embs = images.to(device), caption_embs.to(device)
            img_embs = model(images)                   # (B, D)
            # normalize
            img_norm = F.normalize(img_embs,    dim=1) # (B, D)
            cap_norm = F.normalize(caption_embs, dim=1) # (B, D)
            # CosineEmbeddingLoss needs a target of +1 for each pair
            targets = torch.ones(images.size(0), device=device)
            loss = val_criterion(img_norm, cap_norm, targets)
            total_loss += loss.item() * images.size(0)
    avg_loss = total_loss / len(loader.dataset)
    model.train()
    return avg_loss

In [211]:
# Hyperparameters
lr = 1e-4
weight_decay = 1e-2
temperature = 0.07
num_epochs = 10

# Setup
# train_criterion = nn.CrossEntropyLoss()
val_criterion = nn.CosineEmbeddingLoss()
optimizer = torch.optim.AdamW(encoder.parameters(), lr=lr, weight_decay=weight_decay)
scaler = torch.amp.GradScaler()
encoder.train()

for epoch in range(1, num_epochs + 1):
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")
    for images, caption_embs in pbar:
        images, caption_embs = images.to(device), caption_embs.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast(device_type='mps', dtype=torch.float16):
            img_embs = encoder(images)  # (B, D)
            img_norm = F.normalize(img_embs, dim=1)      # (B, D)
            cap_norm = F.normalize(caption_embs, dim=2)  # (B, m, D)
            loss = multi_pos_infonce_loss(img_embs, cap_norm, temperature)

        # Backward with mixed precision
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        pbar.set_postfix(train_loss=f"{loss.item():.4f}")

    val_loss = validate(encoder, val_loader, val_criterion, device)
    print(f"Epoch {epoch} Validation CosineEmbeddingLoss: {val_loss:.4f}")

Epoch 1/10:   1%|          | 32/3697 [05:07<9:47:50,  9.62s/it, train_loss=nan]


KeyboardInterrupt: 

### Resnet152 Encoder

In [95]:
class ResNet152Embedder(nn.Module):
    """
    ResNet50 backbone producing fixed-size embeddings (e.g., 1024-D).

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained weights.
        embedding_dim (int): Dimensionality of the output embedding.
    """
    def __init__(self, pretrained: bool = True, embedding_dim: int = 1024):
        super().__init__()
        # Load ResNet50 backbone
        weights = ResNet152_Weights.DEFAULT if pretrained else None
        base_model = resnet152(weights=weights)
        # Save the feature dimensionality for projection
        in_features = base_model.fc.in_features
        # Replace the final fully connected layer with a new one
        base_model.fc = nn.Linear(in_features, embedding_dim)
        # Initialize the new layer
        nn.init.xavier_uniform_(base_model.fc.weight)
        self.backbone = base_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: image -> embedding

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).

        Returns:
            torch.Tensor: Embedding of shape (batch_size, embedding_dim).
        """
        embeddings = self.backbone(x)   # shape: (batch_size, embedding_dim)
        return embeddings

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Example: produce 1024-D embeddings
model = ResNet152Embedder(pretrained=True, embedding_dim=1024).to(device)
summary(model, (3, 224, 224), device=device.type)
dummy_input = torch.randn(10, 3, 224, 224, device=device)
embed = model(dummy_input)
print(f"Output embedding shape: {embed.shape}")

Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /Users/miaoyidi/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth
100%|██████████| 230M/230M [00:04<00:00, 54.1MB/s] 


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,