# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
from collections import Counter
import string
import re

# Run on the GPU when CUDA reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# ================================
# Step 1: Dataset definition
# ================================
class SquadDataset(Dataset):
    """SQuAD-format QA examples as fixed-length id sequences with token spans.

    Each stored example holds the padded/truncated context ids, the
    padded/truncated question ids, and the inclusive start/end token indices
    of the first listed answer inside the tokenized context.
    """

    # Translation table deleting every ASCII punctuation character.  A table
    # avoids the regex character-class pitfall where the unescaped ``\]`` in
    # ``string.punctuation`` left backslashes in the text.
    _PUNCT_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self, filename, max_vocab_size=10000, max_passage_length=300, max_question_length=30):
        """Load a SQuAD JSON file, build the vocabulary and index the examples.

        Args:
            filename: Path to a JSON file with the ``data -> paragraphs -> qas``
                layout.
            max_vocab_size: Upper bound on vocabulary size, including the
                ``<pad>`` and ``<unk>`` entries.
            max_passage_length: Number of context token ids kept per example.
            max_question_length: Number of question token ids kept per example.

        Questions are skipped when they have no answers, when the answer
        cannot be aligned to context tokens, or when the answer lies beyond
        ``max_passage_length``.
        """
        self.max_passage_length = max_passage_length
        self.max_question_length = max_question_length
        self.data = []
        self.word2idx = {}
        self.idx2word = {}

        # Explicit encoding so loading does not depend on the platform default.
        with open(filename, 'r', encoding='utf-8') as f:
            squad = json.load(f)

        self.build_vocab(squad, max_vocab_size)

        for article in squad['data']:
            for paragraph in article['paragraphs']:
                context = paragraph['context']
                context_tokens = self.tokenize(context)
                context_ids = self.encode(context_tokens, self.max_passage_length)

                for qa in paragraph['qas']:
                    # Questions without any answer carry no span to learn from.
                    answers = qa.get('answers')
                    if not answers:
                        continue
                    answer = answers[0]

                    start_idx, end_idx = self.find_answer_span(
                        context, context_tokens, answer['text'], answer['answer_start']
                    )
                    if start_idx == -1 or end_idx == -1:
                        continue
                    # The span must survive truncation of the context.
                    if start_idx >= self.max_passage_length or end_idx >= self.max_passage_length:
                        continue

                    question_tokens = self.tokenize(qa['question'])
                    self.data.append({
                        'context': context_ids,
                        'question': self.encode(question_tokens, self.max_question_length),
                        'start': start_idx,
                        'end': end_idx
                    })

    def build_vocab(self, squad, max_vocab_size):
        """Fill ``word2idx``/``idx2word`` from the most frequent corpus tokens.

        Index 0 is ``<pad>`` and index 1 is ``<unk>``; the remaining slots go
        to the ``max_vocab_size - 2`` most common context/question tokens.
        """
        word_counts = Counter()
        for article in squad['data']:
            for paragraph in article['paragraphs']:
                word_counts.update(self.tokenize(paragraph['context']))
                for qa in paragraph['qas']:
                    word_counts.update(self.tokenize(qa['question']))

        most_common = word_counts.most_common(max_vocab_size - 2)
        vocab = ['<pad>', '<unk>'] + [w for w, _ in most_common]
        self.word2idx = {w: idx for idx, w in enumerate(vocab)}
        self.idx2word = dict(enumerate(vocab))

    def tokenize(self, text):
        """Lower-case ``text``, delete ASCII punctuation, split on whitespace."""
        return text.lower().translate(self._PUNCT_TABLE).split()

    def encode(self, tokens, max_length):
        """Map tokens to ids, truncated or ``<pad>``-padded to ``max_length``.

        Tokens missing from the vocabulary map to the ``<unk>`` id.
        """
        unk_id = self.word2idx['<unk>']
        ids = [self.word2idx.get(t, unk_id) for t in tokens[:max_length]]
        ids.extend([self.word2idx['<pad>']] * (max_length - len(ids)))
        return ids

    def detokenize(self, tokens):
        """Join tokens back into a single space-separated string."""
        return " ".join(tokens)

    def find_answer_span(self, context, tokens, answer_text, char_start):
        """Locate the answer as an inclusive token span inside ``tokens``.

        The answer is tokenized exactly like the context, so punctuation and
        case differences cannot prevent a match.  ``char_start`` selects which
        occurrence is meant: the number of tokens in ``context[:char_start]``
        is the expected start index.  If the tokens there do not match (for
        example because the character offset is slightly off), the matching
        occurrence closest to the expected index is used instead.

        Args:
            context: Raw context string that ``tokens`` was produced from.
            tokens: ``self.tokenize(context)``.
            answer_text: Raw answer string.
            char_start: Character offset of the answer within ``context``.

        Returns:
            ``(start, end)`` token indices, or ``(-1, -1)`` when the answer
            has no tokens or does not occur in ``tokens``.
        """
        answer_tokens = self.tokenize(answer_text)
        if not answer_tokens:
            return -1, -1
        width = len(answer_tokens)

        # Tokens fully contained in the prefix precede the answer.
        expected = len(self.tokenize(context[:max(char_start, 0)]))
        if tokens[expected:expected + width] == answer_tokens:
            return expected, expected + width - 1

        # Fallback: every position where the answer tokens occur verbatim.
        candidates = [
            i for i in range(len(tokens) - width + 1)
            if tokens[i:i + width] == answer_tokens
        ]
        if not candidates:
            return -1, -1
        start = min(candidates, key=lambda i: abs(i - expected))
        return start, start + width - 1

    def __len__(self):
        """Return the number of retained question/answer examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of ``torch.long`` tensors."""
        item = self.data[idx]
        return {
            'context': torch.tensor(item['context'], dtype=torch.long),
            'question': torch.tensor(item['question'], dtype=torch.long),
            'start': torch.tensor(item['start'], dtype=torch.long),
            'end': torch.tensor(item['end'], dtype=torch.long)
        }

# ================================
# Step 2: Model definition
# ================================
class QAModel(nn.Module):
    """BiLSTM span predictor producing per-token start and end logits.

    The passage and the question are embedded with a shared embedding table
    and encoded by separate bidirectional LSTMs.  The question is summarised
    as one vector, concatenated to every passage position, and two linear
    heads score each position as answer start / answer end.
    """

    def __init__(self, vocab_size, embedding_dim=100, hidden_dim=128):
        """Create the embedding table, the two encoders and the two heads.

        Args:
            vocab_size: Number of rows in the embedding table; id 0 is the
                padding id.
            embedding_dim: Size of each token embedding.
            hidden_dim: Hidden size per LSTM direction.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.passage_lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.question_lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        # Each head sees passage (2 * hidden) + question summary (2 * hidden).
        self.linear_start = nn.Linear(hidden_dim * 4, 1)
        self.linear_end = nn.Linear(hidden_dim * 4, 1)

    def forward(self, context, question):
        """Score every passage position as an answer start and an answer end.

        Args:
            context: Long tensor of passage token ids, shape (batch, passage_len).
            question: Long tensor of question token ids, shape (batch, question_len).

        Returns:
            ``(start_logits, end_logits)``, each of shape (batch, passage_len).
        """
        context_embed = self.embedding(context)
        question_embed = self.embedding(question)

        passage_output, _ = self.passage_lstm(context_embed)
        question_output, _ = self.question_lstm(question_embed)

        # Average only over real question tokens: padded time steps still
        # produce non-zero LSTM states and would otherwise dilute the summary.
        question_mask = (question != self.embedding.padding_idx).unsqueeze(-1).to(question_output.dtype)
        # clamp keeps an all-padding question from dividing by zero.
        token_count = question_mask.sum(dim=1).clamp(min=1.0)
        question_repr = (question_output * question_mask).sum(dim=1) / token_count

        # Broadcast the summary along the passage axis; torch.cat materialises it.
        question_repr = question_repr.unsqueeze(1).expand(-1, passage_output.size(1), -1)
        combined = torch.cat([passage_output, question_repr], dim=-1)

        start_logits = self.linear_start(combined).squeeze(-1)
        end_logits = self.linear_end(combined).squeeze(-1)

        return start_logits, end_logits

# ================================
# Helper functions
# ================================
def create_mask(tensor, pad_idx=0):
    """Return a float32 mask: 1.0 where ``tensor`` differs from ``pad_idx``, else 0.0."""
    is_real_token = torch.ne(tensor, pad_idx)
    return is_real_token.to(torch.float32)

def masked_logits(logits, mask):
    """Replace logits at masked-out positions with a large negative value.

    Args:
        logits: Floating-point score tensor.
        mask: Tensor broadcastable to ``logits``; zero/False marks positions
            to suppress, any non-zero/True value keeps the logit unchanged.

    Returns:
        A tensor like ``logits`` with suppressed positions set to ``-1e9``.

    ``masked_fill`` is used instead of ``logits * mask - 1e9 * (1 - mask)``
    because multiplying an inf/NaN logit by zero yields NaN, and because the
    arithmetic form does not accept boolean masks.
    """
    return logits.masked_fill(~mask.bool(), -1e9)

# ================================
# Step 3: Training loop
# ================================
# Hyperparameters
max_vocab_size = 10000
max_passage_length = 300
max_question_length = 30
batch_size = 8
num_epochs = 3
learning_rate = 0.001

# Data pipeline: index the training file, then serve shuffled mini-batches.
train_dataset = SquadDataset('train-v1.1.json', max_vocab_size, max_passage_length, max_question_length)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Model, optimiser and loss; the embedding table is sized from the vocabulary.
vocab_size = len(train_dataset.word2idx)
model = QAModel(vocab_size).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in train_loader:
        # Move the four batch fields to the training device in one pass.
        context, question, start_pos, end_pos = (
            batch[field].to(device) for field in ('context', 'question', 'start', 'end')
        )

        mask = create_mask(context).to(device)

        optimizer.zero_grad()

        # Padding positions are pushed to a very low score before the loss.
        start_logits, end_logits = (
            masked_logits(raw_logits, mask) for raw_logits in model(context, question)
        )

        # The objective is the sum of the start and end cross-entropies.
        loss_start = criterion(start_logits, start_pos)
        loss_end = criterion(end_logits, end_pos)
        loss = loss_start + loss_end

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean loss per batch for this epoch.
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

# ================================
# Step 4: Inference
# ================================
def predict(context_text, question_text, model, dataset, max_passage_length=300, max_question_length=30):
    """Extract the most likely answer span for a question from a context.

    Args:
        context_text: Raw passage to search for the answer.
        question_text: Raw question string.
        model: Callable returning ``(start_logits, end_logits)`` for a batch
            of context ids and question ids; it is switched to eval mode.
        dataset: Object providing ``tokenize``, ``encode`` and ``detokenize``.
        max_passage_length: Number of context token ids fed to the model.
        max_question_length: Number of question token ids fed to the model.

    Returns:
        The detokenized answer tokens, or the string
        ``"I couldn't find a valid answer."`` when the context has no tokens.

    The span is decoded jointly: among all pairs with ``start <= end`` the
    one maximising ``start_logit + end_logit`` is chosen.  Taking the two
    argmaxes independently can yield ``start > end`` and report a failure
    even though a valid span exists; whenever the independent argmaxes are
    already ordered, the joint search returns the same span.
    """
    model.eval()
    context_tokens = dataset.tokenize(context_text)
    question_tokens = dataset.tokenize(question_text)

    # Only positions holding real (non-truncated) context tokens can answer.
    num_candidates = min(len(context_tokens), max_passage_length)
    if num_candidates <= 0:
        return "I couldn't find a valid answer."

    context_ids = dataset.encode(context_tokens, max_passage_length)
    question_ids = dataset.encode(question_tokens, max_question_length)

    context_tensor = torch.tensor([context_ids], dtype=torch.long).to(device)
    question_tensor = torch.tensor([question_ids], dtype=torch.long).to(device)
    mask = create_mask(context_tensor)

    with torch.no_grad():
        start_logits, end_logits = model(context_tensor, question_tensor)
        start_logits = masked_logits(start_logits, mask)
        end_logits = masked_logits(end_logits, mask)

        start_scores = start_logits[0, :num_candidates]
        end_scores = end_logits[0, :num_candidates]

        # span_scores[i, j] is the score of the span starting at i, ending at j.
        span_scores = start_scores.unsqueeze(1) + end_scores.unsqueeze(0)
        positions = torch.arange(num_candidates, device=span_scores.device)
        inverted = positions.unsqueeze(1) > positions.unsqueeze(0)
        span_scores = span_scores.masked_fill(inverted, float('-inf'))

        # argmax without ``dim`` indexes the flattened (row-major) matrix.
        best_flat = torch.argmax(span_scores).item()

    start_idx, end_idx = divmod(best_flat, num_candidates)
    answer_tokens = context_tokens[start_idx:end_idx + 1]
    return dataset.detokenize(answer_tokens)

# ================================
# Example usage
# ================================
# Sample passage/question pair used to smoke-test the trained model.
context_example = "The Transformers library was created by Hugging Face in 2018 to provide state-of-the-art NLP models."
question_example = "Who created the Transformers library?"

answer = predict(
    context_text=context_example,
    question_text=question_example,
    model=model,
    dataset=train_dataset,
)

# One print call; each report line is terminated by a newline as before.
print(
    "\n=== Prediction Example ===",
    f"Context: {context_example}",
    f"Question: {question_example}",
    f"Predicted Answer: {answer}",
    sep="\n",
)

# Output: Using device: cpu
