In [1]:
from transformers import BertModel, BertTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F

class Retriever_model(nn.Module):
    def __init__(self):

        super(Retriever_model, self).__init__()
        self.encoder = BertModel.from_pretrained("bert-base-uncased")

        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.pooler = nn.AdaptiveAvgPool1d(1)
        self.dense = nn.Linear(self.encoder.config.hidden_size, 512)
        self.dense2 = nn.Linear(512,256)
        self.dense3 = nn.Linear(256,256)
    
    def forward(self, input_ids, attention_mask):

        with torch.no_grad(): 
            outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
            last_hidden = outputs.last_hidden_state 
            pooled = torch.mean(last_hidden, dim=1)

        pooled = F.relu(self.dense(pooled))
        pooled = F.relu(self.dense2(pooled))
        return self.dense3(pooled) 

    def encode_texts(self, texts, device):
        encoding = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        return self.forward(encoding["input_ids"].to(device), encoding["attention_mask"].to(device))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# from torch.utils.data import Dataset
# import torch
# import random
# from torch.utils.data import ConcatDataset, DataLoader
# from datasets import load_dataset


# data = load_dataset("neural-bridge/rag-dataset-12000")
# train_data = data.get("train")


# class Model_Dataset(Dataset):
#     def __init__(self,data):
#         self.data = data

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, index):

#         item = self.data[index]
#         context = item['context']
#         question = item['question']

#         return  {
#             "context" : context , 
#             "question" : question
#         }
    

# train_dataset = Model_Dataset(train_data)
# train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [3]:
# def multiple_negatives_ranking_loss(query_embeds, context_embeds):
#     # [batch_size, embedding_dim] → cosine sim matrix
#     sim_scores = torch.matmul(query_embeds, context_embeds.T)  # [B, B]
#     labels = torch.arange(len(sim_scores)).to(sim_scores.device)  # [0, 1, ..., B-1]
#     return F.cross_entropy(sim_scores, labels)

In [4]:
# device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
# model = Retriever_model().to(device)
# print(device)

# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)
# num_epochs = 1

# for epoch in range(num_epochs):
#     model.train()
#     total_loss = 0.0

#     for batch in train_dataloader:
#         queries = batch["question"]
#         contexts = batch["context"]

#         query_embeddings = model.encode_texts(queries, device)  
#         context_embeddings = model.encode_texts(contexts, device) 

#         loss = multiple_negatives_ranking_loss(query_embeddings, context_embeddings)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()

#     print(f"Epoch {epoch+1} | Loss: {total_loss:.8f}")

In [5]:
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch

# Dataset
class RagDataset(Dataset):
    def __init__(self, questions, contexts):
        self.questions = questions
        self.contexts = contexts

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        return {
            "question": self.questions[idx],
            "context": self.contexts[idx]
        }

# Collate to keep strings intact
def collate_batch(batch):
    return {
        "question": [item["question"] for item in batch],
        "context": [item["context"] for item in batch]
    }

# Loss
def multiple_negatives_ranking_loss(q_embeds, c_embeds):
    sim = torch.matmul(q_embeds, c_embeds.T)
    labels = torch.arange(len(sim)).to(sim.device)
    return F.cross_entropy(sim, labels)

# Load data
from datasets import load_dataset
data = load_dataset("neural-bridge/rag-dataset-12000", split="train")
questions, contexts = data["question"], data["context"]

# Dataloader
train_dataset = RagDataset(questions, contexts)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_batch)

# Model
model = Retriever_model()
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Optimizer
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5)

# Training loop
for epoch in range(5):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        q_texts, c_texts = batch["question"], batch["context"]
        q_embed = model.encode_texts(q_texts, device)
        c_embed = model.encode_texts(c_texts, device)

        loss = multiple_negatives_ranking_loss(q_embed, c_embed)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.8f}")

ValueError: Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.