In [1]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertModel, BertTokenizerFast, AdamW, get_linear_schedule_with_warmup


In [2]:
def prepare_dataset(split="train", sample_size=None, seed=None):
    """Load one MS MARCO v2.1 split and convert it into extractive-QA examples.

    For each query, the first reference answer is searched for verbatim in the
    selected passages (in order); the first selected passage that contains it
    becomes the context. Queries with no answers, an empty/whitespace answer,
    or no selected passage containing the answer are dropped.

    Args:
        split: dataset split name, e.g. "train" or "validation".
        sample_size: if truthy, randomly subsample this many queries (capped at
            the split size) before processing.
        seed: optional seed for the subsampling. None keeps the previous
            behaviour of drawing from the global `random` state.

    Returns:
        list of dicts with keys "context", "question", "answer_text",
        "answer_start", "answer_end" (character offsets into context; end is
        exclusive).
    """
    print(f"Loading MS MARCO dataset ({split} split)...")
    ds = load_dataset("microsoft/ms_marco", "v2.1")
    dataset = ds[split]

    if sample_size:
        # A dedicated Random instance makes the subsample reproducible without
        # touching the global RNG; fall back to the module when no seed is given.
        rng = random.Random(seed) if seed is not None else random
        indices = rng.sample(range(len(dataset)), min(sample_size, len(dataset)))
        dataset = dataset.select(indices)

    processed_examples = []

    for example in tqdm(dataset, desc="Processing examples"):
        answers = example["answers"]
        if not answers:
            continue

        answer_text = answers[0]
        # An empty/whitespace answer would "match" trivially and teach nothing.
        if not answer_text.strip():
            continue

        passages = example["passages"]
        # Several passages may be marked selected; use the first one that
        # actually contains the answer text rather than giving up on the first.
        for passage_text, is_selected in zip(passages["passage_text"], passages["is_selected"]):
            if not is_selected:
                continue
            answer_start = passage_text.find(answer_text)
            if answer_start == -1:
                continue
            processed_examples.append({
                "context": passage_text,
                "question": example["query"],
                "answer_text": answer_text,
                "answer_start": answer_start,
                "answer_end": answer_start + len(answer_text),
            })
            break

    print(f"Processed {len(processed_examples)} examples")
    return processed_examples

In [3]:
class QADataset(Dataset):
    """Tokenised (question, context) pairs with token-level answer-span labels.

    Each item is a dict of 1-D tensors returned by `tokenizer` (batch dimension
    removed, offset mapping dropped) plus `start_positions` / `end_positions`:
    the indices of the first and last context tokens covering the answer's
    characters. If the answer is empty or not entirely inside the context
    tokens that survived truncation, both labels are 0 (the first position).
    """

    def __init__(self, examples, tokenizer, max_length=384):
        # examples: dicts with "question", "context", "answer_start", "answer_end"
        # (character offsets into context, end exclusive), as built by prepare_dataset.
        # tokenizer must be a *fast* tokenizer: offsets and sequence_ids are needed.
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]

        # `stride` is intentionally not passed: it only affects overflow windows,
        # which are not requested here, so only the first window is ever used.
        encoding = self.tokenizer(
            example["question"],
            example["context"],
            max_length=self.max_length,
            truncation="only_second",
            return_offsets_mapping=True,
            padding="max_length",
            return_tensors="pt",
        )

        # Per-token sequence id: None for special/padding tokens, 0 for the
        # question, 1 for the context. Read before popping keys from the encoding.
        sequence_ids = encoding.sequence_ids(0)
        offsets = encoding.pop("offset_mapping")[0].tolist()

        # Remove the batch dimension the tokenizer adds with return_tensors="pt".
        item = {key: value.squeeze(0) for key, value in encoding.items()}

        token_start, token_end = self._answer_token_span(
            offsets, sequence_ids, example["answer_start"], example["answer_end"]
        )
        item["start_positions"] = torch.tensor(token_start)
        item["end_positions"] = torch.tensor(token_end)
        return item

    @staticmethod
    def _answer_token_span(offsets, sequence_ids, start_char, end_char):
        """Map the character span [start_char, end_char) to inclusive token indices.

        Returns (0, 0) when there are no context tokens, the span is empty, or the
        answer is not fully covered by the (possibly truncated) context, so start
        and end labels are always consistent with each other.
        """
        context_tokens = [i for i, seq_id in enumerate(sequence_ids) if seq_id == 1]
        if not context_tokens or end_char <= start_char:
            return 0, 0

        first, last = context_tokens[0], context_tokens[-1]
        if offsets[first][0] > start_char or offsets[last][1] < end_char:
            # Answer (partly) truncated away: label both ends as "no answer".
            return 0, 0

        # First context token ending after the answer start and last one starting
        # before the answer end; tolerant of answers that begin/end on whitespace.
        # Both searches are guaranteed to succeed by the bounds check above.
        token_start = next(i for i in context_tokens if offsets[i][1] > start_char)
        token_end = next(i for i in reversed(context_tokens) if offsets[i][0] < end_char)
        if token_start > token_end:
            # Span covers only inter-token whitespace.
            return 0, 0
        return token_start, token_end

In [4]:
class BertForQA(nn.Module):
    """Extractive question-answering head on top of a pretrained BERT encoder.

    A single linear layer maps each token's final hidden state to two scores:
    column 0 is the "answer starts here" logit, column 1 the "answer ends here"
    logit. Submodule names (`bert`, `qa_outputs`) define the state_dict keys.
    """

    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.qa_outputs = nn.Linear(hidden_size, 2)

    def forward(self, input_ids, attention_mask, token_type_ids=None, start_positions=None, end_positions=None):
        """Score every token as a candidate answer start and end.

        Returns a dict with "start_logits" and "end_logits" (the two output
        columns with the last dimension removed) and "loss": the mean of the
        start and end cross-entropy losses when both position tensors are
        given, otherwise None.
        """
        encoder_out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        span_scores = self.qa_outputs(encoder_out.last_hidden_state)
        start_logits = span_scores[..., 0]
        end_logits = span_scores[..., 1]

        have_labels = start_positions is not None and end_positions is not None
        if have_labels:
            criterion = nn.CrossEntropyLoss()
            loss = (criterion(start_logits, start_positions) + criterion(end_logits, end_positions)) / 2
        else:
            loss = None

        return {
            "loss": loss,
            "start_logits": start_logits,
            "end_logits": end_logits,
        }

In [5]:
def _qa_batch_loss(model, batch, device):
    """Move one batch dict to `device`, run `model` on it and return its loss tensor."""
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        token_type_ids=batch.get("token_type_ids", None),
        start_positions=batch["start_positions"],
        end_positions=batch["end_positions"],
    )
    return outputs["loss"]


def train_model(model, train_dataloader, val_dataloader=None, epochs=3, lr=5e-5, num_warmup_steps=0):
    """Fine-tune `model` with AdamW and a linear learning-rate schedule.

    Args:
        model: module whose forward accepts input_ids, attention_mask,
            token_type_ids, start_positions, end_positions and returns a dict
            containing "loss" (e.g. BertForQA).
        train_dataloader: yields dicts of tensors with "input_ids",
            "attention_mask", "start_positions", "end_positions"
            ("token_type_ids" optional).
        val_dataloader: optional loader in the same format; when given and
            non-empty, the mean validation loss is printed after every epoch.
        epochs: number of passes over train_dataloader.
        lr: peak learning rate.
        num_warmup_steps: linear warm-up steps before the linear decay
            (0 keeps the previous pure-decay schedule).

    Returns:
        The same model object, trained in place and left on the chosen device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # torch.optim.AdamW replaces transformers.AdamW (deprecated and removed in
    # recent transformers releases); eps/weight_decay follow the old defaults.
    optimizer = optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress_bar:
            loss = _qa_batch_loss(model, batch, device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

        # max(..., 1) guards against an empty training loader.
        avg_train_loss = running_loss / max(len(train_dataloader), 1)
        print(f"Epoch {epoch+1} - Average training loss: {avg_train_loss:.4f}")

        if val_dataloader is not None and len(val_dataloader) > 0:
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for batch in tqdm(val_dataloader, desc="Validating"):
                    val_loss += _qa_batch_loss(model, batch, device).item()

            avg_val_loss = val_loss / len(val_dataloader)
            print(f"Validation loss: {avg_val_loss:.4f}")

    return model

def _best_context_span(start_logits, end_logits, sequence_ids, max_answer_length=None):
    """Find the highest-scoring (start, end) token pair inside the context of one window.

    Args:
        start_logits, end_logits: 1-D numpy arrays of per-token scores.
        sequence_ids: per-token sequence ids for the same window (1 marks context).
        max_answer_length: optional cap on span length in tokens.

    Returns:
        (score, start_token, end_token) maximising start_logits[s] + end_logits[e]
        over context tokens with s <= e (and e - s < max_answer_length if given),
        or None when no such span exists.
    """
    is_context = np.array([seq_id == 1 for seq_id in sequence_ids], dtype=bool)
    if not is_context.any():
        return None

    start_scores = np.where(is_context, start_logits, -np.inf)
    end_scores = np.where(is_context, end_logits, -np.inf)
    n_tokens = len(is_context)

    # pair_scores[s, e] = score of the span starting at token s and ending at token e.
    pair_scores = start_scores[:, None] + end_scores[None, :]
    allowed = np.triu(np.ones((n_tokens, n_tokens), dtype=bool))  # s <= e
    if max_answer_length is not None:
        # Also forbid e - s >= max_answer_length (span longer than the cap).
        allowed &= ~np.triu(np.ones((n_tokens, n_tokens), dtype=bool), k=max_answer_length)
    pair_scores = np.where(allowed, pair_scores, -np.inf)

    best_start, best_end = np.unravel_index(int(np.argmax(pair_scores)), pair_scores.shape)
    best_score = float(pair_scores[best_start, best_end])
    if not np.isfinite(best_score):
        return None
    return best_score, int(best_start), int(best_end)


def answer_question(model, tokenizer, question, context, max_length=384, stride=128, max_answer_length=None):
    """Predict the answer to `question` as a substring of `context`.

    Long contexts are split into overlapping windows of `max_length` tokens
    (question included) that share `stride` tokens; every window is scored and
    the single best span across all windows, by start logit + end logit, is
    returned.

    Args:
        model: callable returning a dict with "start_logits"/"end_logits"
            (e.g. BertForQA); moved to the available device and set to eval mode.
        tokenizer: fast tokenizer (offset mappings and sequence_ids are required).
        question, context: raw strings.
        max_length: tokens per window.
        stride: overlap, in tokens, between consecutive context windows.
        max_answer_length: optional cap on answer length in tokens (None = unlimited).

    Returns:
        The predicted answer sliced from `context`, or "" if no context token
        could be scored.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    inputs = tokenizer(
        question,
        context,
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt",
    )

    num_windows = inputs["input_ids"].shape[0]
    window_sequence_ids = [inputs.sequence_ids(w) for w in range(num_windows)]
    window_offsets = inputs.pop("offset_mapping").tolist()

    token_type_ids = inputs.get("token_type_ids")
    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            token_type_ids=None if token_type_ids is None else token_type_ids.to(device),
        )

    start_logits = outputs["start_logits"].cpu().numpy()
    end_logits = outputs["end_logits"].cpu().numpy()

    best = None  # (score, char_start, char_end)
    for w in range(num_windows):
        span = _best_context_span(start_logits[w], end_logits[w], window_sequence_ids[w], max_answer_length)
        if span is None:
            continue
        score, token_start, token_end = span
        if best is None or score > best[0]:
            best = (score, window_offsets[w][token_start][0], window_offsets[w][token_end][1])

    if best is None:
        return ""
    return context[best[1]:best[2]]

In [6]:

# Reproducibility: prepare_dataset subsamples with Python's `random`, the
# training DataLoader shuffles and the new QA head is initialised with torch's
# RNG; numpy is seeded too in case later cells sample with it.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run configuration (values unchanged from the original experiment).
TRAIN_SAMPLE_SIZE = 100_000
VAL_SAMPLE_SIZE = 10_000
BATCH_SIZE = 16
# NOTE: in the recorded run validation loss was lowest after epoch 1
# (1.0892 -> 1.0985 -> 1.2364) while training loss kept falling; fewer epochs
# or early stopping would likely generalise better.
EPOCHS = 3

# Load and prepare dataset
train_examples = prepare_dataset(split="train", sample_size=TRAIN_SAMPLE_SIZE)
val_examples = prepare_dataset(split="validation", sample_size=VAL_SAMPLE_SIZE)

# Initialize tokenizer - use the fast version explicitly: offset mappings and
# sequence_ids, which QADataset relies on, require a fast tokenizer.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Create datasets
train_dataset = QADataset(train_examples, tokenizer)
val_dataset = QADataset(val_examples, tokenizer)

# Create data loaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Initialize model
model = BertForQA()

# Train model
trained_model = train_model(
    model,
    train_dataloader,
    val_dataloader,
    epochs=EPOCHS,
)


Loading MS MARCO dataset (train split)...


README.md: 0.00B [00:00, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/210M [00:00<?, ?B/s]

train-00000-of-00007.parquet:   0%|          | 0.00/240M [00:00<?, ?B/s]

train-00001-of-00007.parquet:   0%|          | 0.00/240M [00:00<?, ?B/s]

train-00002-of-00007.parquet:   0%|          | 0.00/241M [00:00<?, ?B/s]

train-00003-of-00007.parquet:   0%|          | 0.00/242M [00:00<?, ?B/s]

train-00004-of-00007.parquet:   0%|          | 0.00/242M [00:00<?, ?B/s]

train-00005-of-00007.parquet:   0%|          | 0.00/242M [00:00<?, ?B/s]

train-00006-of-00007.parquet:   0%|          | 0.00/244M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/204M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/101093 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/808731 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/101092 [00:00<?, ? examples/s]

Processing examples: 100%|██████████| 100000/100000 [00:16<00:00, 5960.33it/s]


Processed 22788 examples
Loading MS MARCO dataset (validation split)...


Processing examples: 100%|██████████| 10000/10000 [00:01<00:00, 6465.48it/s]


Processed 1463 examples


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Epoch 1/3: 100%|██████████| 1425/1425 [28:35<00:00,  1.20s/it, loss=1.0407]


Epoch 1 - Average training loss: 1.4859


Validating: 100%|██████████| 92/92 [00:36<00:00,  2.52it/s]


Validation loss: 1.0892


Epoch 2/3: 100%|██████████| 1425/1425 [28:39<00:00,  1.21s/it, loss=0.2553]


Epoch 2 - Average training loss: 0.8994


Validating: 100%|██████████| 92/92 [00:36<00:00,  2.51it/s]


Validation loss: 1.0985


Epoch 3/3: 100%|██████████| 1425/1425 [28:39<00:00,  1.21s/it, loss=0.0646]


Epoch 3 - Average training loss: 0.5442


Validating: 100%|██████████| 92/92 [00:36<00:00,  2.51it/s]

Validation loss: 1.2364





In [13]:
# Save model to current directory instead of a subdirectory
model_path = "./"  # Current directory
weights_path = os.path.join(model_path, "pytorch_model.bin")
# Only the state_dict is saved: reload with BertForQA() followed by
# load_state_dict(torch.load(weights_path)).
torch.save(trained_model.state_dict(), weights_path)
tokenizer.save_pretrained(model_path)
print(f"Model saved to {os.path.abspath(model_path)}")

# Example usage
context = """
Mohamed Salah is an Egyptian professional footballer who plays as a forward for Liverpool and the Egypt national team. Known for his incredible speed, dribbling, and finishing, he has won multiple Premier League and Champions League titles. Salah has broken numerous records, including becoming Liverpool’s all-time top scorer in the Champions League. He is a national hero in Egypt, inspiring millions with his achievements. His humility and dedication make him one of the greatest footballers of his generation.
"""
question = "What is mohamed salah nationality?"

answer = answer_question(trained_model, tokenizer, question, context)
print(f"Question: {question}")
print(f"Answer: {answer}")

Model saved to current directory
Question: What is mohamed salah nationality?
Answer: Egyptian
