## Top-5 Document Retrieval

In [5]:
import torch
import torch.nn as nn
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

class CBOW(torch.nn.Module):
    """Continuous bag-of-words model: averaged context embeddings -> vocabulary logits.

    Args:
        voc: Vocabulary size (rows of the embedding table, width of the output).
        emb: Dimensionality of each token embedding.
    """

    def __init__(self, voc, emb):
        super().__init__()
        # Attribute names are part of the state_dict keys, so they stay as-is.
        self.embeddings = torch.nn.Embedding(voc, emb)
        self.linear = torch.nn.Linear(emb, voc)

    def forward(self, inpt):
        """Return one row of vocabulary logits per row of token indices in ``inpt``."""
        # Look up every index, average over dim 1, then project to the vocabulary.
        averaged = self.embeddings(inpt).mean(dim=1)
        return self.linear(averaged)


class QryTower(nn.Module):
    """Single-layer GRU encoder turning a padded batch of query embeddings into one vector each.

    Args:
        embedding_dim: Size of every input token embedding.
        hidden_dim: Size of the GRU hidden state, and therefore of the output vector.
    """

    def __init__(self, embedding_dim=100, hidden_dim=64):
        super().__init__()
        self.rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, x, lengths):
        """Encode each sequence using only its first ``lengths[i]`` time steps.

        Args:
            x: Batch-first padded embeddings, ``[batch_size, max_len, embedding_dim]``.
            lengths: 1-D integer tensor holding the unpadded length of every row.

        Returns:
            Tensor ``[batch_size, hidden_dim]``: the final GRU hidden state of
            every sequence, in the same order as the rows of ``x``.
        """
        # pack_padded_sequence rejects zero-length rows, so an empty text is
        # treated as a single step (its first, i.e. padding, vector).
        lengths = lengths.clamp(min=1).cpu()

        # enforce_sorted=False lets PyTorch do the sort internally; the GRU then
        # hands the hidden state back already restored to the input order.
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        _, hidden = self.rnn(packed)  # hidden: [1, batch_size, hidden_dim]
        return hidden.squeeze(0)  # [batch_size, hidden_dim]

class DocTower(nn.Module):
    """Single-layer GRU encoder turning a padded batch of document embeddings into one vector each.

    Args:
        embedding_dim: Size of every input token embedding.
        hidden_dim: Size of the GRU hidden state, and therefore of the output vector.
    """

    def __init__(self, embedding_dim=100, hidden_dim=64):
        super().__init__()
        self.rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, x, lengths):
        """Encode each document using only its first ``lengths[i]`` time steps.

        Args:
            x: Batch-first padded embeddings, ``[batch_size, max_len, embedding_dim]``.
            lengths: 1-D integer tensor holding the unpadded length of every row.

        Returns:
            Tensor ``[batch_size, hidden_dim]``: the final GRU hidden state of
            every document, in the same order as the rows of ``x``.
        """
        # Zero-length rows make pack_padded_sequence raise; count an empty
        # document as one step so it is encoded from its first (padding) vector.
        lengths = lengths.clamp(min=1).cpu()

        # With enforce_sorted=False the batch does not need to be pre-sorted and
        # the returned hidden state is already in the original batch order.
        packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        _, hidden = self.rnn(packed)  # hidden: [1, batch_size, hidden_dim]
        return hidden.squeeze(0)  # [batch_size, hidden_dim]



## Retrieving the Model

In [9]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Both towers are built with their default sizes; load_state_dict below only
# succeeds if those match the architecture the checkpoint was saved from.
qryTower = QryTower().to(device)
docTower = DocTower().to(device)

# The checkpoint is a dict with one state_dict per tower plus the
# token_to_index mapping. torch.load unpickles the file, so only open
# checkpoints that come from a trusted source.
checkpoint = torch.load("two_tower_model_GRU_padding.pt", map_location=device)
token_to_index = checkpoint["token_to_index"]

# Restore the weights and switch each tower to evaluation mode.
qryTower.load_state_dict(checkpoint["qryTower"])
qryTower.eval()
docTower.load_state_dict(checkpoint["docTower"])
docTower.eval()

print("✅ Model loaded and ready for inference!")


✅ Model loaded and ready for inference!


## Precompute Document Embeddings

In [None]:
from torch.utils.data import Dataset
import torch

class TripletDataset(Dataset):
    """Rows of (query, positive passage, negative passage) served as padded embedding tensors.

    Text is lower-cased and split on whitespace. Queries are padded or truncated
    to the longest query found in ``df``; passages to the longest passage
    (positive or negative) found in ``df``.

    Args:
        df: Frame with ``"query"``, ``"positive_passage"`` and
            ``"negative_passage"`` text columns.
        token_to_index: Mapping from token to its index in ``embedding_layer``.
        embedding_layer: Module exposing ``embedding_dim`` that is called with
            index tensors to look up token vectors.
        device: Device on which index tensors and OOV vectors are created.
    """

    def __init__(self, df, token_to_index, embedding_layer, device):
        self.df = df
        self.token_to_index = token_to_index
        self.embedding_layer = embedding_layer
        self.device = device

        self.embedding_dim = embedding_layer.embedding_dim
        # token -> random vector, created on first sight and reused afterwards
        self.oov_embeddings = {}

        self.query_max_len = max(len(self._tokenize(q)) for q in df["query"])
        passages = df["positive_passage"].tolist() + df["negative_passage"].tolist()
        self.doc_max_len = max(len(self._tokenize(p)) for p in passages)

    @staticmethod
    def _tokenize(text):
        """Lower-case ``text`` and split it on whitespace."""
        return text.lower().split()

    def embed(self, token):
        """Look up ``token`` in the vocabulary, or return its cached random OOV vector."""
        if token not in self.token_to_index:
            cached = self.oov_embeddings.get(token)
            if cached is None:
                # First time this unknown token is seen: draw a small random vector.
                cached = torch.randn(self.embedding_dim, device=self.device) * 0.1
                self.oov_embeddings[token] = cached
            return cached

        index = torch.tensor(self.token_to_index[token], device=self.device)
        return self.embedding_layer(index)

    def embed_text(self, text, max_len):
        """Embed ``text`` as a ``[max_len, embedding_dim]`` tensor.

        Returns:
            ``(embedded, true_len)`` where ``true_len`` is the number of real
            (non-padding) tokens kept after truncating to ``max_len``.
        """
        vectors = [self.embed(tok) for tok in self._tokenize(text)[:max_len]]
        true_len = len(vectors)

        missing = max_len - true_len
        if missing > 0:
            # Index 0 of the embedding layer is used as the padding vector.
            pad_vec = self.embedding_layer(torch.tensor(0, device=self.device))
            vectors += [pad_vec] * missing

        return torch.stack(vectors), true_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(query, q_len, positive, p_len, negative, n_len)`` for row ``idx``."""
        record = self.df.iloc[idx]

        query, q_len = self.embed_text(record["query"], self.query_max_len)
        pos, p_len = self.embed_text(record["positive_passage"], self.doc_max_len)
        neg, n_len = self.embed_text(record["negative_passage"], self.doc_max_len)

        return query, q_len, pos, p_len, neg, n_len


In [None]:
docTower.eval()

from datasets import load_dataset

# Soft-negative triplets; only the "train" split is used, as a pandas frame.
# (A hard-negatives variant lives at
# "cocoritzy/week_2_triplet_dataset_hard_negatives".)
df_sn = load_dataset("cocoritzy/week_2_triplet_dataset_soft_negatives")["train"].to_pandas()

# Candidate pool for retrieval: the first 1000 positive passages (adjust as needed).
all_doc_texts = df_sn["positive_passage"].tolist()[:1000]

# NOTE(review): `triplet_dataset` is not created anywhere in the visible
# notebook - it has to be a TripletDataset built before this cell runs,
# otherwise this cell fails with a NameError on a fresh kernel.
max_doc_len = triplet_dataset.doc_max_len

doc_vectors = []
with torch.no_grad():
    for passage in all_doc_texts:
        # Embed one passage at a time and wrap it as a batch of size 1.
        padded, n_tokens = triplet_dataset.embed_text(passage, max_doc_len)
        batch = padded.unsqueeze(0).to(device)
        batch_len = torch.tensor([n_tokens], device=device)

        encoded = docTower(batch, batch_len)  # [1, hidden_dim]
        doc_vectors.append(encoded.squeeze(0))

doc_matrix = torch.stack(doc_vectors).to(device)  # [num_docs, hidden_dim]


NameError: name 'df' is not defined

In [None]:
qryTower.eval()

query = "how do I get solar panels for free"

# Inference only: run the forward pass under no_grad, like the document
# encoding loop, so no autograd graph is built for the query vector.
with torch.no_grad():
    # embed_text lower-cases and whitespace-splits the query, then pads or
    # truncates it to query_max_len and reports the number of real tokens.
    qry_tensor, qry_len = triplet_dataset.embed_text(query, triplet_dataset.query_max_len)
    qry_tensor = qry_tensor.unsqueeze(0).to(device)  # batch of size 1
    qry_len = torch.tensor([qry_len], device=device)

    qry_vec = qryTower(qry_tensor, qry_len)  # [1, hidden_dim]


In [None]:
from torch.nn.functional import cosine_similarity

# Cosine similarity between the query and every document: the [1, hidden_dim]
# query broadcasts against the rows of doc_matrix, giving one score per document.
similarities = cosine_similarity(qry_vec, doc_matrix)  # [num_docs]

# Keep the 5 best matches, but never ask topk for more entries than there are
# documents - it raises when k exceeds the number of scores.
top_k = min(5, similarities.numel())
top_indices = similarities.topk(top_k).indices.cpu().numpy()

# Display results, best match first.
print(f"\n🔍 Query: {query}")
print(f"\n📚 Top {top_k} most relevant passages:")
for i in top_indices:
    print(f"\n🔸 Similarity: {similarities[i].item():.4f}")
    print(all_doc_texts[i])
