<a href="https://colab.research.google.com/github/TongQM/SONAR_VLM/blob/main/sonar_embedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install SONAR embeddings and requirements

In [116]:
!pip show torch
# !pip install fairseq2 --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/pt2.6.0/cu124
# !pip install sonar-space

Name: torch
Version: 2.6.0+cu124
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /usr/local/lib/python3.11/dist-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-cusparselt-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvtx-cu12, sympy, triton, typing-extensions
Required-by: accelerate, fairseq2n, fastai, peft, sentence-transformers, sonar-space, timm, torchaudio, torchvision


# Mount Google drive and set work directory

In [3]:
import os
from google.colab import drive

drive.mount('/content/gdrive')
os.chdir('/content/')
# !ls annotations/

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


## Download Data

In [9]:
!mkdir images
!cd images
# # Download the images from the COCO dataset
!wget http://images.cocodataset.org/zips/train2017.zip
!wget http://images.cocodataset.org/zips/val2017.zip
!wget http://images.cocodataset.org/zips/test2017.zip
# wget http://images.cocodataset.org/zips/unlabeled2017.zip

# Unzip the images
!unzip train2017.zip
!unzip val2017.zip
!unzip test2017.zip
# unzip unlabeled2017.zip

# # Remove the zip files
# !rm train2017.zip
# !rm val2017.zip
# !rm test2017.zip
# !rm unlabeled2017.zip

# # Download the annotations from the COCO dataset
!cd ../
!mkdir annotations
!cd annotations
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
# wget http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip
!wget http://images.cocodataset.org/annotations/image_info_test2017.zip
# wget http://images.cocodataset.org/annotations/image_info_unlabeled2017.zip

# # Unzip the annotations
!unzip annotations_trainval2017.zip
# unzip stuff_annotations_trainval2017.zip
!unzip image_info_test2017.zip
# unzip image_info_unlabeled2017.zip

# # Remove the zip files
# !rm annotations_trainval2017.zip
# !rm stuff_annotations_trainval2017.zip
# !rm image_info_test2017.zip
# !rm image_info_unlabeled2017.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 extracting: test2017/000000318429.jpg  
 extracting: test2017/000000381787.jpg  
 extracting: test2017/000000287482.jpg  
 extracting: test2017/000000459614.jpg  
 extracting: test2017/000000010879.jpg  
 extracting: test2017/000000044270.jpg  
 extracting: test2017/000000233321.jpg  
 extracting: test2017/000000515253.jpg  
 extracting: test2017/000000562454.jpg  
 extracting: test2017/000000219662.jpg  
 extracting: test2017/000000270965.jpg  
 extracting: test2017/000000497647.jpg  
 extracting: test2017/000000273880.jpg  
 extracting: test2017/000000413721.jpg  
 extracting: test2017/000000546526.jpg  
 extracting: test2017/000000189624.jpg  
 extracting: test2017/000000432585.jpg  
 extracting: test2017/000000464855.jpg  
 extracting: test2017/000000489326.jpg  
 extracting: test2017/000000202481.jpg  
 extracting: test2017/000000572501.jpg  
 extracting: test2017/000000188680.jpg  
 extracting: test2017/00000040087

### Import necessary packages

In [123]:
import json
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt
from PIL import Image
import requests
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
from torchvision import transforms
import torch.nn.functional as F
from torch.nn import CosineEmbeddingLoss
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.models import resnet50, resnet152, ResNet152_Weights, ResNet50_Weights
from torchsummary import summary

from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from sonar.inference_pipelines.text import EmbeddingToTextModelPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'mps' if torch.backends.mps.is_available() else device
print(f"Using device: {device}")

Using device: cuda


## Load text2embedding model

In [124]:
t2vec_model = TextToEmbeddingModelPipeline(encoder="text_sonar_basic_encoder",
                                           tokenizer="text_sonar_basic_encoder",
                                           device=torch.device(device),
                                           dtype=torch.float16)
sentences = ['I want to get an intern offer pls', 'Why I have so bad luck']
embeddings = t2vec_model.predict(sentences, source_lang="eng_Latn")


vec2text_model = EmbeddingToTextModelPipeline(decoder="text_sonar_basic_decoder",
                                              tokenizer="text_sonar_basic_encoder",
                                              device=torch.device(device),
                                              dtype=torch.float16)
reconstructed = vec2text_model.predict(embeddings, target_lang="eng_Latn", max_seq_len=512)
print("Original sentences:")
for sentence in sentences:
    print(sentence)
print("\nReconstructed sentences:")
for sentence in reconstructed:
    print(sentence)

Original sentences:
I want to get an intern offer pls
Why I have so bad luck

Reconstructed sentences:
I want to get an internship offer pls
Why I have such bad luck


## Load train and validation data

### Precompute the embeddings of captions for each image

In [141]:
import json
import torch
from collections import defaultdict
from tqdm import tqdm

def precompute_caption_embeddings(
    coco_json: str,
    text_encoder,
    output_path: str,
    numcaps: int = 5,
    batch_size: int = 512,
    device: str = "cpu",
):
    """
    Reads the COCO captions JSON, groups & truncates to `numcaps` per image,
    encodes *all* captions through your text_encoder, then splits them back
    into a dict mapping `"<image_id>.jpg"` → Tensor of shape (numcaps, D),
    and torch.saves that dict to `output_path`.
    """
    # 1) load & group captions
    with open(coco_json) as f:
        coco = json.load(f)
    caps = defaultdict(list)
    for ann in coco["annotations"]:
        key = f"{ann['image_id']:012d}.jpg"
        caps[key].append(ann["caption"])
    # truncate
    for k in caps:
        caps[k] = sorted(caps[k])[:numcaps]

    # 2) flatten all captions
    items   = list(caps.items())                # [(img_name, [cap1,…]), …]
    all_caps = [cap for _, clist in items for cap in clist]

    # 3) encode in batches for memory‐efficiency
    text_encoder.to(device).eval()
    all_embs = []
    with torch.no_grad():
        for i in tqdm(range(0, len(all_caps), batch_size)):
            batch = all_caps[i : i + batch_size]
            emb = text_encoder.predict(batch, source_lang="eng_Latn")
            emb = torch.as_tensor(emb) if not isinstance(emb, torch.Tensor) else emb
            all_embs.append(emb.cpu())
    all_embs = torch.cat(all_embs, dim=0)       # (sum(counts), D)

    # 4) split back per image
    embedding_map = {}
    idx = 0
    D   = all_embs.size(1)
    for img_name, clist in items:
        m = len(clist)
        embedding_map[img_name] = all_embs[idx : idx + m]  # (m, D)
        idx += m

    # 5) save to disk
    torch.save(embedding_map, output_path)

In [143]:
precompute_caption_embeddings(
    coco_json     = "./data/captions_val2017.json",
    text_encoder  = t2vec_model,
    output_path   = "/content/coco2017_valcaption_embs.pt",
    numcaps       = 5,
    batch_size    = 1024,
    device        = device,
)

100%|██████████| 25/25 [01:14<00:00,  2.98s/it]


In [156]:
# Load valcaption_embs
val_caption_embs = torch.load("/content/coco2017_valcaption_embs.pt")
sample_embs = val_caption_embs['000000519491.jpg']

# The original captions
with open("./data/captions_val2017.json") as f:
    coco = json.load(f)
    caps = defaultdict(list)
    for ann in coco["annotations"]:
        key = f"{ann['image_id']:012d}.jpg"
        caps[key].append(ann["caption"])
print(f"The original captions are {caps['000000519491.jpg']}")
print(f"The deciphered captions are {vec2text_model.predict(sample_embs, target_lang='eng_Latn', max_seq_len=512)}.")

The original captions are ['A tall clock tower with a statue on top.', 'There is a clock in the top of a tall tower', 'A large clock tower with a gargoyle atop sits in front of a clear blue sky.', 'A large clock tower with a statue on the top.', 'Clock tower with a bronze statue on top on a sunny day. ']
The deciphered captions are ['A large clock tower with a gargoyle on top sits in front of a clear blue sky.', 'A large clock tower with a statue on the top.', 'A tall clock tower with a statue on top.', 'Clock tower with a bronze statue on top in a sunny day.', 'There is a clock in the top of a tall tower'].


In [154]:
# Download coco2017_valcaption_embs.pt
from google.colab import files
files.download('/content/coco2017_valcaption_embs.pt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [157]:
precompute_caption_embeddings(
    coco_json     = "./data/captions_train2017.json",
    text_encoder  = t2vec_model,
    output_path   = "/content/coco2017_traincaption_embs.pt",
    numcaps       = 5,
    batch_size    = 1024,
    device        = device,
)

100%|██████████| 578/578 [29:05<00:00,  3.02s/it]


In [158]:
files.download('/content/coco2017_traincaption_embs.pt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Define dataloader with SONAR embeddings as labels

In [125]:
class COCOCaptionTextDataset(Dataset):
    """
    Returns (image_tensor, list_of_captions) for each idx.
    """
    def __init__(self, img_dir, coco_json, text_encoder, transform=None, numcaps=5, subset=1.0):
        self.img_dir = Path(img_dir)
        self.transform = transform or transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        self.text_encoder = text_encoder
        self.numcaps = numcaps

        with open(coco_json) as f:
            coco = json.load(f)
        # group captions by image_id
        caps = defaultdict(list)
        for ann in coco["annotations"]:
            caps[f"{ann['image_id']:012d}.jpg"].append(ann["caption"])

        # For each image, sort the captions and keep only the first 5
        for img, captions in caps.items():
            caps[img] = sorted(captions)[:numcaps]

        self.samples = sorted((str(self.img_dir / img), caps[img])
                              for img in caps)

        # If subset < 1.0, take the first subset proportion sample of the dataset
        if subset < 1.0:
            total = len(self.samples)
            k = max(1, int(total * subset))
            self.samples = self.samples[:k]

        self.length = len(self.samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, captions = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        return img, captions

    def collate_fn(self, batch):
        # 1) Stack images
        imgs = torch.stack([b[0] for b in batch], dim=0)

        # 2) Flatten out all captions
        all_caps = []
        counts   = []
        for _, caps in batch:
            counts.append(len(caps))
            all_caps.extend(caps)

        # 3) Run teacher in one shot
        #    Assume it returns a torch.Tensor of shape (sum(counts), D)
        with torch.no_grad():
            embs = self.text_encoder.predict(all_caps, source_lang="eng_Latn").to(device)
            # if it returns numpy, wrap: embs = torch.from_numpy(embs).to(device)

        # 4) Split & reduce per sample (e.g. mean)
        D = embs.size(1)
        labels = []
        idx = 0
        for n in counts:
            chunk = embs[idx: idx+n]      # shape (n, D)
            labels.append(chunk)  # → (D,)
            idx += n
        labels = torch.stack(labels, dim=0)  # → (batch_size, D)

        return imgs.to(device), labels


def collate_and_encode(batch, text_encoder, device="cuda"):
    """
    batch: list of (image_tensor, [cap1, cap2, ..., cap5])
    text_encoder: any model with .encode(list[str]) → Tensor(n_captions, D)
    """
    # 1) Stack images
    imgs = torch.stack([b[0] for b in batch], dim=0)

    # 2) Flatten out all captions
    all_caps = []
    counts   = []
    for _, caps in batch:
        counts.append(len(caps))
        all_caps.extend(caps)

    # 3) Run teacher in one shot
    #    Assume it returns a torch.Tensor of shape (sum(counts), D)
    with torch.no_grad():
        embs = text_encoder.predict(all_caps, source_lang="eng_Latn").to(device)
        # if it returns numpy, wrap: embs = torch.from_numpy(embs).to(device)

    # 4) Split & reduce per sample (e.g. mean)
    D = embs.size(1)
    labels = []
    idx = 0
    for n in counts:
        chunk = embs[idx: idx+n]      # shape (n, D)
        labels.append(chunk)  # → (D,)
        idx += n
    labels = torch.stack(labels, dim=0)  # → (batch_size, D)

    return imgs.to(device), labels


# ---- usage ----
train_dataset = COCOCaptionTextDataset(
    img_dir="./data/images/train2017",
    coco_json="./data/captions_train2017.json",
    text_encoder=t2vec_model,
    numcaps=5,
    subset=0.1,  # 0.1 = 10% of the dataset
)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    num_workers=0,
    pin_memory=False,
    collate_fn=train_dataset.collate_fn,
)

val_dataset = COCOCaptionTextDataset(
    img_dir="./data/images/val2017",
    coco_json="./data/captions_val2017.json",
    text_encoder=t2vec_model,
    numcaps=5
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    num_workers=0,
    pin_memory=False,
    collate_fn=val_dataset.collate_fn,
)

In [126]:
print(f"There are {len(train_dataset)} training samples.")
print(f"There are {len(val_dataset)} validation samples.")

There are 11828 training samples.
There are 5000 validation samples.


## Define the model

### Define multi-positive InfoNCE loss

In [127]:
def multi_pos_infonce_loss(img_embs, cap_embs, temperature):
    """
    img_embs: (B, D)  after F.normalize
    cap_embs: (B, m, D) after F.normalize
    """
    B, m, D = cap_embs.shape
    # flatten captions: (B*m, D)
    flat_caps = cap_embs.view(B*m, D)
    # similarity: (B, B*m)
    logits = img_embs @ flat_caps.t() / temperature

    # numerator: sum over each image's m positives
    # we know that positives for image i are indices [i*m : i*m + m]
    pos_mask = torch.zeros_like(logits, dtype=torch.bool)
    for i in range(B):
        start = i * m
        pos_mask[i, start : start + m] = True

    # exp(logits)
    exp_logits = logits.exp()
    numerator   = exp_logits.masked_select(pos_mask).view(B, m).sum(dim=1)
    denominator = exp_logits.sum(dim=1)
    loss = -torch.log(numerator / denominator).mean()
    return loss

### Resnet50 Encoder

In [133]:
class ResNet50Embedder(nn.Module):
    """
    ResNet50 backbone producing fixed-size embeddings (e.g., 1024-D).

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained weights.
        embedding_dim (int): Dimensionality of the output embedding.
    """
    def __init__(self, pretrained: bool = True, embedding_dim: int = 1024, weights=ResNet50_Weights.DEFAULT):
        super().__init__()
        # Load ResNet50 backbone
        if pretrained:
            # Use pretrained weights
            weights = ResNet50_Weights.DEFAULT
        base_model = resnet50(weights=weights)
        # Save the feature dimensionality for projection
        in_features = base_model.fc.in_features
        # Replace the final fully connected layer with a new one
        base_model.fc = nn.Linear(in_features, embedding_dim)
        # Initialize the new layer
        nn.init.xavier_uniform_(base_model.fc.weight)
        self.backbone = base_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: image -> embedding

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).

        Returns:
            torch.Tensor: Embedding of shape (batch_size, embedding_dim).
        """
        embeddings = self.backbone(x)   # shape: (batch_size, embedding_dim)
        return embeddings

    def load_weights(self, path: str):
        """
        Load weights from a file.

        Args:
            path (str): Path to the weights file.
        """
        state_dict = torch.load(path, weights_only=True)
        self.backbone.load_state_dict(state_dict)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Example: produce 1024-D embeddings
encoder = ResNet50Embedder(pretrained=True, embedding_dim=1024).to(device)
# summary(model, (3, 224, 224), device=device.type)
dummy_input = torch.randn(10, 3, 224, 224, device=device)
embed = encoder(dummy_input)
print(f"Output embedding shape: {embed.shape}")


Output embedding shape: torch.Size([10, 1024])


In [None]:
def train_epoch(model: nn.Module,
          train_loader: DataLoader,
          optimizer: torch.optim.Optimizer,
          scaler: torch.cuda.amp.GradScaler,
          criterion: nn.Module,
          device: str,
          epoch: int = 1,
          num_epochs: int = 10) -> None:
    """
    Train the model for one epoch.
    """
    model.train()
    cum_loss, cnt = 0.0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")
    for images, caption_embs in pbar:
        images, caption_embs = images.to(device), caption_embs.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast(device_type=device, dtype=torch.float16):
            img_embs = encoder(images)  # (B, D)
            img_norm = F.normalize(img_embs, dim=-1)      # (B, D)
            cap_norm = F.normalize(caption_embs, dim=-1)  # (B, m, D)
            loss = criterion(img_norm, cap_norm, temperature=0.1)
            cum_loss += loss.item()
            cnt += 1

        pbar.set_postfix(train_loss=f"{loss.item():.4f}")

        # Backward with mixed precision
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    return cum_loss / cnt

In [134]:
# --- Validation function ---
@torch.no_grad()
def validate(model: nn.Module, loader: DataLoader, device: torch.device, margin: float = 0.0) -> float:
    """
    Runs one epoch of validation, returning the average CosineEmbeddingLoss
    over all (positive + negative) pairs.

    Positives: (image_i, each of its m captions) → target +1
    Negatives: (image_i, each of the next image’s m captions) → target -1
    """
    model.eval()
    loss_fn = CosineEmbeddingLoss(margin=margin, reduction="mean")
    total_loss = 0.0
    n_batches  = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc=f"Validation")
        for images, cap_embs in pbar:
            images, cap_embs = images.to(device), cap_embs.to(device)
            B, m, D = cap_embs.shape

            # 1) Compute normalized image and caption embeddings
            img_embs  = model(images)                    # (B, D)
            img_norm  = F.normalize(img_embs,  dim=-1)    # (B, D)
            cap_norm  = F.normalize(cap_embs,    dim=-1)  # (B, m, D)

            # 2) Build positive pairs
            #    (B, D) → (B, 1, D) → (B, m, D) → (B*m, D)
            img_pos   = img_norm.unsqueeze(1).expand(-1,m,-1).reshape(B*m, D)
            cap_pos   = cap_norm.reshape(B*m, D)
            pos_tgt   = torch.ones(B*m, device=device)    # +1

            # 3) Build negative pairs by “rolling” captions by 1 in batch
            cap_neg   = cap_norm.roll(shifts=1, dims=0)   # (B, m, D)
            img_neg   = img_pos                            # same images
            cap_neg   = cap_neg.reshape(B*m, D)
            neg_tgt   = -torch.ones(B*m, device=device)   # -1

            # 4) Concatenate into one big batch
            all_imgs = torch.cat([img_pos, img_neg], dim=0)  # (2*B*m, D)
            all_caps = torch.cat([cap_pos, cap_neg], dim=0)  # (2*B*m, D)
            all_lbls = torch.cat([pos_tgt, neg_tgt], dim=0)  # (2*B*m,)

            # 5) Compute loss
            loss = loss_fn(all_imgs, all_caps, all_lbls)
            total_loss += loss.item()
            n_batches  += 1

            pbar.set_postfix(curr_loss=f"{loss.item():.4f}")

    return total_loss / n_batches

In [135]:
def load_model(model: nn.Module, path: str):
    """
    Load weights from a file.
    """
    state_dict = torch.load(path, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"Loaded weights from {path}")
    return model

def save_model(model: nn.Module, path: str):
    """
    Save weights to a file.
    """
    torch.save(model.state_dict(), path)
    print(f"Saved weights to {path}")

In [136]:
# Hyperparameters
lr = 1e-4
weight_decay = 1e-2
temperature = 0.07
num_epochs = 10

# Setup
# train_criterion = nn.CrossEntropyLoss()
val_criterion = nn.CosineEmbeddingLoss()
optimizer = torch.optim.AdamW(encoder.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
scaler = torch.amp.GradScaler()
encoder.train()

best_val_loss = float("inf")

for epoch in range(1, num_epochs + 1):
    # Train
    train_loss = train_epoch(encoder, train_loader, optimizer, scaler, multi_pos_infonce_loss, device, epoch=epoch, num_epochs=num_epochs)
    val_loss = validate(encoder, val_loader, device, margin=0.0)
    scheduler.step(val_loss)
    print(f"Epoch {epoch} Training InfoNCE Loss: {train_loss:.4f}")
    print(f"Epoch {epoch} Validation CosineEmbeddingLoss: {val_loss:.4f}")

    save_model(encoder, f"last_epoch.pth")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_model(encoder, f"best_encoder.pth")

Epoch 1/10: 100%|██████████| 370/370 [05:39<00:00,  1.09it/s, train_loss=0.8017]
Validation: 100%|██████████| 157/157 [02:10<00:00,  1.20it/s, curr_loss=0.3486]


Epoch 1 Training InfoNCE Loss: 1.6766
Epoch 1 Validation CosineEmbeddingLoss: 0.3609
Saved weights to last_epoch.pth
Saved weights to best_encoder.pth


Epoch 2/10: 100%|██████████| 370/370 [05:05<00:00,  1.21it/s, train_loss=0.5435]
Validation: 100%|██████████| 157/157 [02:02<00:00,  1.28it/s, curr_loss=0.3363]


Epoch 2 Training InfoNCE Loss: 0.9283
Epoch 2 Validation CosineEmbeddingLoss: 0.3492
Saved weights to last_epoch.pth
Saved weights to best_encoder.pth


Epoch 3/10:  83%|████████▎ | 307/370 [04:13<00:52,  1.21it/s, train_loss=0.6531]


KeyboardInterrupt: 

### Resnet152 Encoder

In [None]:
class ResNet152Embedder(nn.Module):
    """
    ResNet50 backbone producing fixed-size embeddings (e.g., 1024-D).

    Args:
        pretrained (bool): If True, loads ImageNet-pretrained weights.
        embedding_dim (int): Dimensionality of the output embedding.
    """
    def __init__(self, pretrained: bool = True, embedding_dim: int = 1024):
        super().__init__()
        # Load ResNet50 backbone
        weights = ResNet152_Weights.DEFAULT if pretrained else None
        base_model = resnet152(weights=weights)
        # Save the feature dimensionality for projection
        in_features = base_model.fc.in_features
        # Replace the final fully connected layer with a new one
        base_model.fc = nn.Linear(in_features, embedding_dim)
        # Initialize the new layer
        nn.init.xavier_uniform_(base_model.fc.weight)
        self.backbone = base_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: image -> embedding

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, 3, H, W).

        Returns:
            torch.Tensor: Embedding of shape (batch_size, embedding_dim).
        """
        embeddings = self.backbone(x)   # shape: (batch_size, embedding_dim)
        return embeddings

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Example: produce 1024-D embeddings
model = ResNet152Embedder(pretrained=True, embedding_dim=1024).to(device)
summary(model, (3, 224, 224), device=device.type)
dummy_input = torch.randn(10, 3, 224, 224, device=device)
embed = model(dummy_input)
print(f"Output embedding shape: {embed.shape}")

Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /Users/miaoyidi/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth
100%|██████████| 230M/230M [00:04<00:00, 54.1MB/s] 


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,