In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import numpy as np

# 🔥 Use MPS if available
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# --------------------------
# Encoder with MLP Projection
# --------------------------
class SimpleRetrieverEncoder(nn.Module):
    def __init__(self, base_model="sentence-transformers/all-MiniLM-L6-v2", proj_dim=128):
        super().__init__()
        self.encoder = SentenceTransformer(base_model)
        self.embedding_adapter = nn.Sequential(
            nn.Linear(384, 384),
            nn.ReLU()
        )
        self.projector = nn.Sequential(
            nn.Linear(384, 256),
            nn.ReLU(),
            nn.Linear(256, proj_dim)
        )

    def forward(self, texts):
        with torch.no_grad():
            emb = self.encoder.encode(texts, convert_to_tensor=True)
        emb = emb.to(next(self.projector.parameters()).device)
        emb = self.embedding_adapter(emb)
        return self.projector(emb)

# --------------------------
# Dataset + Loss
# --------------------------
class QCDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return item["question"], item["context"]

def cosine_similarity_loss(q_embed, c_embed):
    scores = F.cosine_similarity(q_embed, c_embed)
    targets = torch.ones_like(scores)
    return F.mse_loss(scores, targets)

# --------------------------
# Load and Train
# --------------------------
print("Loading data...")
dataset = load_dataset("neural-bridge/rag-dataset-12000")
train_data = dataset["train"]
train_loader = DataLoader(QCDataset(train_data), batch_size=32, shuffle=True)

print("Initializing models...")
question_model = SimpleRetrieverEncoder().to(device)
context_model = SimpleRetrieverEncoder().to(device)

optimizer = torch.optim.Adam(
    list(question_model.parameters()) + list(context_model.parameters()), lr=1e-4
)

print("Training...")
for epoch in range(10):
    total_loss = 0.0
    for questions, contexts in train_loader:
        # Get question embeddings and context embeddings
        q_embed = question_model(questions).to(device)
        c_embed = context_model(contexts).to(device)

        loss = cosine_similarity_loss(q_embed, c_embed)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1} - Loss: {total_loss:.4f}")

  from .autonotebook import tqdm as notebook_tqdm


Loading data...
Initializing models...
Training...
Epoch 1 - Loss: 10.9604
Epoch 2 - Loss: 0.0190
Epoch 3 - Loss: 0.0045


In [2]:
import numpy as np
import json
import os

print("Encoding and saving context embeddings...")

all_contexts = [item["context"] for item in train_data]
all_ids = [str(i) for i in range(len(all_contexts))]

context_embeddings = []
batch_size = 64
for i in range(0, len(all_contexts), batch_size):
    batch = all_contexts[i:i+batch_size]
    with torch.no_grad():
        emb = context_model(batch).cpu().numpy()
        context_embeddings.append(emb)

# Stack and save
context_embeddings = np.vstack(context_embeddings)
np.save("context_embeddings.npy", context_embeddings)

with open("context_texts.json", "w") as f:
    json.dump(all_contexts, f)

print("Saved to: context_embeddings.npy and context_texts.json")

Encoding and saving context embeddings...
Saved to: context_embeddings.npy and context_texts.json


In [3]:
from sklearn.metrics.pairwise import cosine_similarity

# Load saved data
embeddings = np.load("context_embeddings.npy")
with open("context_texts.json") as f:
    context_texts = json.load(f)

# Encode the query
query = "What is the Berry Export Summary 2028 and what is its purpose?"
with torch.no_grad():
    query_embedding = question_model([query]).cpu().numpy()

# Compute cosine similarity
sims = cosine_similarity(query_embedding, embeddings)[0]
top_k_indices = sims.argsort()[::-1][:3]

print("\nTop-3 Retrieved Contexts:")
for i, idx in enumerate(top_k_indices):
    print(f"\nRank {i+1} (Score: {sims[idx]:.4f}): {context_texts[idx]}")


Top-3 Retrieved Contexts:

Rank 1 (Score: 0.9987): Export Regulations and Compliance
Hansen, Fay, Business Credit
U.S. export laws and regulations are far-reaching and have become more so in recent years. Even large, sophisticated U.S. companies with substantial resources and compliance programs occasionally run afoul of the law and face time-consuming investigations and significant fines. In December 2003, Sun Microsystems and two of its subsidiaries agreed to pay $291,000 in fines to settle charges involving illegal exports of computers to military end-users in China and Egypt. In the same month, Honeywell International paid a penalty to settle changes that it illegally exported chemicals to Mexico. In February of this year, Morton International and its French and Japanese affiliates agreed to pay a $647,500 penalty to settle charges in connection with the export and reexport of chemical compounds in violation of U.S. regulations.
"The export control regime of the United States is o