astronomy_test_data.csv

sparsecl_checkpoint.pth

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm

# --- CONFIG ---
# The two input files are read from /content (the Colab upload location);
# upload them there before running this cell.
CSV_PATH = "/content/astronomy_test_data.csv"
CHECKPOINT_PATH = "/content/sparsecl_checkpoint.pth"

# Backbone encoder; must be the same model the SparseCL checkpoint was trained from.
MODEL_NAME = "BAAI/bge-base-en-v1.5"

# Prefer the GPU when one is visible, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- MODEL DEFINITION (Must match your training) ---
class SentenceEncoder(nn.Module):
    """Hugging Face backbone followed by first-token (CLS) pooling and L2 normalization.

    The architecture has to match the one used during training so that the
    SparseCL checkpoint's state_dict loads into it with ``load_state_dict``.
    """

    def __init__(self, model_name):
        super().__init__()
        # Keep the attribute name `model`: it becomes the prefix of every
        # state_dict key, so renaming it would break loading the checkpoint.
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Return one unit-length (L2-normalized) embedding per input sequence."""
        backbone_out = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Token 0 of each sequence (the [CLS] token for BERT-style tokenizers such as BGE's).
        first_token = backbone_out.last_hidden_state[:, 0]
        return F.normalize(first_token, p=2, dim=1)

def calculate_hoyer(vector, eps=1e-8):
    """Return the Hoyer sparsity of ``vector`` along its last dimension.

    Hoyer sparsity is ``(sqrt(d) - L1/L2) / (sqrt(d) - 1)``, where ``d`` is the
    size of the last dimension. It is 1 for a vector with a single non-zero
    entry and 0 for a vector whose entries all have the same magnitude.

    Args:
        vector: tensor of shape ``(..., d)``. Every leading dimension is treated
            as a batch dimension, so a ``(batch, d)`` input gives a ``(batch,)``
            result.
        eps: lower bound applied to the L2 norm so that an all-zero vector does
            not divide by zero.

    Returns:
        Tensor of shape ``vector.shape[:-1]`` with values clamped to [0, 1].
        An all-zero vector (e.g. the difference of two identical embeddings)
        scores 1.0. Without the clamp it would score sqrt(d)/(sqrt(d)-1) > 1,
        and rounding can also push ordinary results marginally out of range.

    Raises:
        ValueError: if ``d < 2``, where the measure is undefined because the
            denominator ``sqrt(d) - 1`` is zero.
    """
    d = vector.shape[-1]
    if d < 2:
        raise ValueError(f"Hoyer sparsity needs at least 2 dimensions, got d={d}")
    sqrt_d = d ** 0.5
    l1 = torch.norm(vector, p=1, dim=-1)
    # clamp_min (instead of adding eps) leaves the norm of every non-zero vector exact.
    l2 = torch.norm(vector, p=2, dim=-1).clamp_min(eps)
    hoyer = (sqrt_d - l1 / l2) / (sqrt_d - 1)
    return hoyer.clamp(0.0, 1.0)

# --- MAIN SCRIPT ---
TEXT_COLUMNS = ["claim", "paraphrase", "contradiction"]
BATCH_SIZE = 32  # CSV rows per forward pass; each row contributes three texts
MAX_LENGTH = 128  # tokenizer truncation length
HOYER_BOUNDARY = 0.45  # hand-picked visual aid, not a fitted threshold
# Explicit mapping so colors do not depend on which label appears first in the data.
PALETTE = {"Paraphrase": "green", "Contradiction": "red"}

print(f"Loading model: {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = SentenceEncoder(MODEL_NAME).to(DEVICE)

# Load SparseCL Checkpoint.
# map_location=DEVICE covers both the CUDA and the CPU-only case in one call.
# torch.load unpickles the file; weights_only=True restricts that to tensors and
# plain containers, which is all a state_dict needs.
ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(ckpt)
model.eval()
print("Checkpoint loaded successfully.")

# Load Data
df = pd.read_csv(CSV_PATH)
print(f"Loaded {len(df)} rows from {CSV_PATH}")

# A missing cell would otherwise be embedded as the literal string "nan".
missing = df[TEXT_COLUMNS].isna().any(axis=1)
if missing.any():
    print(f"Skipping {int(missing.sum())} rows with missing text.")
df_eval = df.loc[~missing].reset_index(drop=True)
if df_eval.empty:
    # Fail here with a clear message instead of deep inside seaborn on an empty frame.
    raise ValueError(f"No complete rows in {CSV_PATH}; nothing to embed or plot.")


def _embed(texts):
    """Tokenize ``texts`` and return their normalized embeddings, one row per text."""
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors='pt').to(DEVICE)
    return model(inputs['input_ids'], inputs['attention_mask'])


# Storage
results = []

print("Computing embeddings...")
with torch.no_grad():
    for start in tqdm(range(0, len(df_eval), BATCH_SIZE)):
        batch = df_eval.iloc[start:start + BATCH_SIZE]
        n_rows = len(batch)

        # One forward pass per batch: all claims, then all paraphrases, then all contradictions.
        texts = [str(text) for col in TEXT_COLUMNS for text in batch[col]]
        emb_claim, emb_para, emb_contra = _embed(texts).split(n_rows)

        # 1. Cosine similarity: embeddings are unit-norm, so the row-wise dot product suffices.
        cos_para = (emb_claim * emb_para).sum(dim=1)
        cos_contra = (emb_claim * emb_contra).sum(dim=1)

        # 2. Hoyer sparsity of the difference vectors, one value per row.
        hoyer_para = calculate_hoyer(emb_claim - emb_para)
        hoyer_contra = calculate_hoyer(emb_claim - emb_contra)

        # Append for plotting, keeping the per-row order: paraphrase, then contradiction.
        for cp, cc, hp, hc in zip(cos_para.tolist(), cos_contra.tolist(),
                                  hoyer_para.tolist(), hoyer_contra.tolist()):
            results.append({"Type": "Paraphrase", "Cosine": cp, "Hoyer": hp})
            results.append({"Type": "Contradiction", "Cosine": cc, "Hoyer": hc})

df_res = pd.DataFrame(results)

# --- PLOTTING ---
sns.set_style("whitegrid")
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Plot 1: Histogram
sns.histplot(data=df_res, x="Hoyer", hue="Type", kde=True, bins=20, ax=axes[0], palette=PALETTE, alpha=0.6)
axes[0].set_title("Hoyer Sparsity Distribution", fontsize=14, fontweight='bold')
axes[0].set_xlabel("Hoyer Sparsity Score", fontsize=12)
axes[0].set_ylabel("Frequency", fontsize=12)

# Plot 2: Scatter
sns.scatterplot(data=df_res, x="Cosine", y="Hoyer", hue="Type", style="Type", ax=axes[1], palette=PALETTE, s=80, alpha=0.7)
axes[1].set_title("Cosine vs. Hoyer Separation", fontsize=14, fontweight='bold')
axes[1].set_xlabel("Cosine Similarity", fontsize=12)
axes[1].set_ylabel("Hoyer Sparsity Score", fontsize=12)
axes[1].axhline(y=HOYER_BOUNDARY, color='gray', linestyle='--', label="Potential Decision Boundary")  # Approx visual aid
# The hue legend was already built by seaborn, so the axhline label alone never
# shows up. Label the line directly: x in axes coordinates, y in data coordinates.
axes[1].text(0.01, HOYER_BOUNDARY, "Potential Decision Boundary", transform=axes[1].get_yaxis_transform(),
             color='gray', va='bottom', fontsize=10)

fig.tight_layout()
fig.savefig("sparsecl_distribution.png", dpi=300)
plt.show()

print("Plot saved as 'sparsecl_distribution.png'. Please download it!")