In [None]:
import os
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset, load_from_disk , Dataset, DatasetDict
from sentence_transformers import SentenceTransformer
import time

In [None]:
from google.colab import drive

# Google Drive is mounted at /content/drive; the REPO path in the
# configuration cell lives underneath this mount point.
drive.mount("/content/drive")

In [None]:
# Repository root. Defaults to the author's Google Drive checkout; set the
# VISLA_REPO environment variable to run from a different location without
# editing the notebook.
REPO = os.environ.get(
    "VISLA_REPO", "/content/drive/MyDrive/00-github/sentence-embedding-sensitivity"
)
DATA = os.path.join(REPO, "Data")
# Output directory for the per-model .npz embedding / score files.
DATASETS_SAVE_PATH = os.path.join(DATA, "visla_datasets")
# VISLA benchmark splits (.tsv files).
GEN_DATA = os.path.join(REPO, "VISLA", "Generic_VISLA.tsv")
# NOTE(review): SPA_DATA is not read by any cell in this section — confirm
# whether the spatial split is processed elsewhere.
SPA_DATA = os.path.join(REPO, "VISLA", "Spatial_VISLA.tsv")

In [None]:
# Short label -> model id passed to SentenceTransformer(). The short label is
# what appears in the saved file name (VISLA_generic_<label>.npz).
model_dict = dict(
    [
        ("par-dis-roberta", "paraphrase-distilroberta-base-v1"),
        ("roberta-base-v3", "msmarco-roberta-base-v3"),
        ("par-mpnet", "paraphrase-mpnet-base-v2"),
        ("par-xlm-r", "paraphrase-xlm-r-multilingual-v1"),
        ("labse", "LaBSE"),
        ("e5-base", "intfloat/e5-base-v2"),
        ("gte-base", "thenlper/gte-base"),
        ("bge-base-v15", "BAAI/bge-base-en-v1.5"),
    ]
)

In [None]:
# @title make dataset folder
# Creates the output directory (and any missing parents); an already
# existing directory is accepted, so re-running this cell is harmless.
os.makedirs(name=DATASETS_SAVE_PATH, mode=0o777, exist_ok=True)

In [None]:
# @title load generic dataset
# Columns read by the encoding loop further down. They are validated here so
# that a malformed TSV fails immediately instead of after a model download.
REQUIRED_COLUMNS = ["caption", "second positive", "negative-caption"]

generic_df = pd.read_csv(GEN_DATA, sep="\t")

missing_columns = [col for col in REQUIRED_COLUMNS if col not in generic_df.columns]
if missing_columns:
    raise KeyError(f"{GEN_DATA} is missing required column(s): {missing_columns}")

# Empty cells are parsed as NaN floats rather than strings, which would only
# surface later as an obscure error inside model.encode.
null_counts = generic_df[REQUIRED_COLUMNS].isna().sum()
if null_counts.any():
    raise ValueError(
        f"{GEN_DATA} has missing values in required columns:\n"
        f"{null_counts[null_counts > 0].to_string()}"
    )

In [None]:
# @title  Fast GPU modes
# Pick the device, then derive the encode batch size and autocast dtype:
#   A100             -> batch 1024, bfloat16
#   other CUDA GPUs  -> batch 256,  float16
#   CPU              -> batch 64,   no autocast dtype (None)
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

if device != "cuda":
    gpu_name = "CPU"
    ENCODE_BS = 64
    AMP_DTYPE = None
else:
    # Allow TF32 tensor-core math for fp32 matmuls and cuDNN kernels.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    gpu_name = torch.cuda.get_device_name(0)
    ENCODE_BS, AMP_DTYPE = (
        (1024, torch.bfloat16) if "A100" in gpu_name else (256, torch.float16)
    )

print(f"Running on {gpu_name}, batch_size={ENCODE_BS}, dtype={AMP_DTYPE}")

In [None]:
def _encode_normalized(model, sentences, batch_size):
    """Encode sentences into L2-normalized embeddings, returned as float32.

    Args:
        model: a loaded SentenceTransformer.
        sentences: list of input strings.
        batch_size: sentences per forward pass.

    Returns:
        np.ndarray with one row per sentence. A model whose weights were cast
        to half precision may hand back float16 arrays; casting to float32
        keeps the cosine dot products below out of fp16 arithmetic.
    """
    embeddings = model.encode(
        sentences,
        batch_size=batch_size,
        convert_to_numpy=True,
        show_progress_bar=True,
        normalize_embeddings=True,
    )
    return embeddings.astype(np.float32, copy=False)


# The sentence columns are identical for every model, so extract them once.
caption = generic_df["caption"].tolist()
positive_sent = generic_df["second positive"].tolist()
negative_sent = generic_df["negative-caption"].tolist()

for model_name, model_id in model_dict.items():
    print(f"Model {model_name} is on processing...")
    model = SentenceTransformer(model_id, device=device)
    model.eval()
    if device == "cuda":
        # Best-effort cast of the underlying HF transformer to AMP_DTYPE via the
        # private `_first_module().auto_model` attribute. If that fails, keep
        # going with autocast alone, but make the fallback visible.
        try:
            base = model._first_module().auto_model
            base.to(dtype=AMP_DTYPE, device=device)
        except Exception as exc:
            print(f"Could not cast {model_name} to {AMP_DTYPE} ({exc!r}); using autocast only")

    # NOTE(review): the intfloat/e5-* model card asks for a "query: " /
    # "passage: " prefix on every input; sentences here are encoded without
    # one. Confirm this is intended for the comparison.
    #
    # Autocast only when an AMP dtype was chosen. With AMP_DTYPE=None on CPU,
    # torch.autocast(device_type="cpu") would otherwise be enabled with its
    # default bfloat16 dtype and silently lower precision.
    with torch.inference_mode(), torch.autocast(
        device_type=device, dtype=AMP_DTYPE, enabled=AMP_DTYPE is not None
    ):
        t1 = time.perf_counter()
        embA = _encode_normalized(model, caption, ENCODE_BS)
        embB = _encode_normalized(model, positive_sent, ENCODE_BS)
        embC = _encode_normalized(model, negative_sent, ENCODE_BS)
        t2 = time.perf_counter()
    print(f"Encoded in {t2-t1:.1f} seconds")
    print(
        f"Embeddings: {embA.shape}, {embB.shape}, {embC.shape}, dtype={embA.dtype}"
    )

    # Rows are unit-normalized, so the row-wise dot product is the cosine similarity.
    t1 = time.perf_counter()
    cosine_scores_pos = np.sum(embA * embB, axis=1)
    cosine_scores_neg = np.sum(embA * embC, axis=1)
    # Positive value: the caption is closer to its positive than to its negative.
    cosine_scores = cosine_scores_pos - cosine_scores_neg
    t2 = time.perf_counter()
    print(f"Calculated cosine similarity in {t2-t1:.1f} seconds")
    print(
        f"Cosine scores: min {cosine_scores.min():.4f}, max {cosine_scores.max():.4f}, mean {cosine_scores.mean():.4f}, std {cosine_scores.std():.4f}"
    )

    save_path = os.path.join(
        DATASETS_SAVE_PATH, f"VISLA_generic_{model_name}.npz"
    )
    # Arrays are stored as float16 to keep the compressed files small.
    np.savez_compressed(
        save_path,
        embedding1=embA.astype(np.float16),
        embedding2=embB.astype(np.float16),
        embedding3=embC.astype(np.float16),
        cosine_scores_pos=cosine_scores_pos.astype(np.float16),
        cosine_scores_neg=cosine_scores_neg.astype(np.float16),
        cosine_scores=cosine_scores.astype(np.float16),
    )
    print(f"Saved to {save_path}")