In [1]:
import torch
from transformers import BertTokenizer, BertModel
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# Define a simple dataset class
class QADataset(Dataset):
    def __init__(self, questions, passages, tokenizer, max_length=128):
        self.questions = questions
        self.passages = passages
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        question = self.questions[idx]
        passage = self.passages[idx]
        question_enc = self.tokenizer(question, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        passage_enc = self.tokenizer(passage, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        return question_enc['input_ids'].squeeze(), question_enc['attention_mask'].squeeze(), passage_enc['input_ids'].squeeze(), passage_enc['attention_mask'].squeeze()

In [2]:
class DualEncoderModel(torch.nn.Module):
    def __init__(self, question_encoder, passage_encoder):
        super(DualEncoderModel, self).__init__()
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder

    def forward(self, question_ids, question_mask, passage_ids, passage_mask):
        question_outputs = self.question_encoder(input_ids=question_ids, attention_mask=question_mask)
        passage_outputs = self.passage_encoder(input_ids=passage_ids, attention_mask=passage_mask)
        
        # Pooling the output embeddings of the [CLS] token
        question_pooler_output = question_outputs.pooler_output
        passage_pooler_output = passage_outputs.pooler_output
        return question_pooler_output, passage_pooler_output

In [3]:
def triplet_loss(anchor, positive, negative, margin=1.0):
    distance_positive = F.cosine_similarity(anchor, positive)
    distance_negative = F.cosine_similarity(anchor, negative)
    loss_values = torch.relu(distance_positive - distance_negative + margin)
    loss = torch.mean(loss_values)
    return loss

In [4]:
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)

# Initialize two BERT models
question_encoder = BertModel.from_pretrained(model_name)
passage_encoder = BertModel.from_pretrained(model_name)
model = DualEncoderModel(question_encoder, passage_encoder)

In [5]:
import pandas as pd
df=pd.read_csv("archive2/nq_small.csv")

In [6]:
questions=df["question"].values
passages=df["context"].values
questions[:5]

array(['1 kilohertz is equal to how many hertz?',
       'How big is a 1 18 scale model?',
       'When did christianity become official religion of rome?',
       'When did salt and pepper became a pair?',
       'Who was president when world war 2 ended?'], dtype=object)

In [None]:
from tqdm import tqdm
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# questions = ["What is the capital of France?", "Who wrote Romeo and Juliet?"]
# passages = ["Paris is the capital of France.", "William Shakespeare wrote Romeo and Juliet."]

# Create dataset and dataloader
dataset = QADataset(questions, passages, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# Training loop
num_epochs = 5
margin = 0.5

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False)

    for question_ids, question_mask, passage_ids, passage_mask in progress_bar:
        question_ids, question_mask = question_ids, question_mask
        passage_ids, passage_mask = passage_ids, passage_mask

        question_embeddings, passage_embeddings = model(question_ids, question_mask, passage_ids, passage_mask)
        negative_embeddings = torch.roll(passage_embeddings, 1, dims=0)

        loss = triplet_loss(question_embeddings, passage_embeddings, negative_embeddings, margin)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        progress_bar.set_postfix({'loss': total_loss / len(progress_bar)})

    progress_bar.close()

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}")


                                                                                                                       

Epoch 1, Loss: 0.2694773841841535


Epoch 2/5:  88%|██████████████████████████████████████████████▋      | 4695/5329 [4:01:30<37:12,  3.52s/it, loss=0.122]

In [None]:
# Define file paths for saving the model
model_path = 'dual_encoder_model.pth'

# Save the model's state dictionary along with any necessary metadata
torch.save({
    'question_encoder_state_dict': model.question_encoder.state_dict(),
    'passage_encoder_state_dict': model.passage_encoder.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # Add any other relevant information such as tokenizer, hyperparameters, etc.
}, model_path)

print(f"Model saved to '{model_path}'")