In [None]:
!pip install transformers datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import json
import numpy as np
import pandas as pd
import ast
from transformers import AutoTokenizer
import math
import os
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
import re

class DistilBertConfig:
    """Plain container for the hyper-parameters read by the model classes below.

    Every instance starts with the same fixed defaults; callers may overwrite
    individual attributes after construction.
    """

    def __init__(self):
        defaults = {
            # Embedding tables.
            'vocab_size': 30522,
            'max_position_embeddings': 512,
            'sinusoidal_pos_embds': False,
            # Encoder geometry.
            'n_layers': 6,
            'n_heads': 12,
            'dim': 768,
            'hidden_dim': 3072,
            # Regularisation.
            'dropout': 0.1,
            'attention_dropout': 0.1,
            # Activation name and weight-init standard deviation.
            'activation': 'gelu',
            'initializer_range': 0.02,
            # Dropout applied just before the QA head.
            'qa_dropout': 0.1,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)

class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product attention split across ``config.n_heads`` heads.

    Reads ``n_heads``, ``dim`` and ``attention_dropout`` from ``config``.
    ``dim`` must be divisible by ``n_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        assert self.dim % self.n_heads == 0, "dim must be divisible by n_heads"
        self.head_dim = self.dim // self.n_heads

        self.q_lin = nn.Linear(config.dim, config.dim)
        self.k_lin = nn.Linear(config.dim, config.dim)
        self.v_lin = nn.Linear(config.dim, config.dim)
        self.out_lin = nn.Linear(config.dim, config.dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: tensor of shape (batch, query_len, dim).
            key: tensor of shape (batch, key_len, dim).
            value: tensor of shape (batch, key_len, dim).
            mask: optional (batch, key_len) tensor; positions equal to 0 are
                excluded from attention. The tensor is not modified.

        Returns:
            Tensor of shape (batch, query_len, dim).
        """
        batch_size, query_len, dim = query.size()

        def split_heads(projected):
            # (batch, length, dim) -> (batch, heads, length, head_dim). The
            # length is taken from the tensor itself so key/value sequences
            # may differ in length from the query sequence.
            return projected.view(
                batch_size, projected.size(1), self.n_heads, self.head_dim
            ).transpose(1, 2)

        q = split_heads(self.q_lin(query))
        k = split_heads(self.k_lin(key))
        v = split_heads(self.v_lin(value))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Broadcast the key mask over heads and query positions.
            blocked = (mask == 0).unsqueeze(1).unsqueeze(2)
            # Use the dtype's own minimum instead of a literal -1e9, which is
            # not representable in half precision.
            scores = scores.masked_fill(blocked, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, query_len, dim)

        return self.out_lin(context)

class FFN(nn.Module):
    """Two-layer feed-forward sub-block: dim -> hidden_dim -> dim.

    A GELU non-linearity and dropout sit between the two linear layers.
    """

    def __init__(self, config):
        super().__init__()
        # Sub-modules are registered in this order; parameter names such as
        # ``lin1.weight`` are derived from these attribute names.
        self.lin1 = nn.Linear(config.dim, config.hidden_dim)
        self.lin2 = nn.Linear(config.hidden_dim, config.dim)
        self.dropout = nn.Dropout(p=config.dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        """Expand, activate, drop out, then project back to the input width."""
        expanded = self.activation(self.lin1(x))
        return self.lin2(self.dropout(expanded))

class TransformerBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward.

    Each sub-layer output goes through dropout, is added to the sub-layer's
    input, and the sum is layer-normalised.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadSelfAttention(config)
        self.sa_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x, attn_mask=None):
        """Return the layer output for hidden states ``x``.

        ``attn_mask`` is forwarded unchanged to the attention module.
        """
        # Attention sub-layer with residual connection.
        attended = self.dropout(self.attention(x, x, x, mask=attn_mask))
        normed_attention = self.sa_layer_norm(attended + x)

        # Feed-forward sub-layer with residual connection.
        transformed = self.dropout(self.ffn(normed_attention))
        return self.output_layer_norm(transformed + normed_attention)

class DistilBertModel(nn.Module):
    """Token + position embeddings followed by a stack of TransformerBlocks.

    Position information comes from a learned embedding table unless
    ``config.sinusoidal_pos_embds`` is true, in which case fixed sine/cosine
    values are computed on the fly and no position table is created.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=0)
        if not config.sinusoidal_pos_embds:
            self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)

        self.transformer = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        self.dropout = nn.Dropout(p=config.dropout)

        self.init_weights()

    def init_weights(self):
        """Re-draw Linear/Embedding weights from N(0, initializer_range) and zero Linear biases."""
        std = self.config.initializer_range
        for module in self.modules():
            is_linear = isinstance(module, nn.Linear)
            if is_linear or isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
            if is_linear and module.bias is not None:
                module.bias.data.zero_()

    def get_position_embeddings(self, seq_len, device):
        """Return a (seq_len, dim) tensor of position embeddings on ``device``."""
        if not self.config.sinusoidal_pos_embds:
            # Learned table lookup for positions 0 .. seq_len - 1.
            return self.position_embeddings(torch.arange(seq_len, device=device))

        width = self.config.dim
        position = torch.arange(seq_len, device=device).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, width, 2, device=device).float()
            * (-math.log(10000.0) / width)
        )
        angles = position * div_term

        # Even columns carry the sine, odd columns the cosine.
        pos_emb = torch.zeros(seq_len, width, device=device)
        pos_emb[:, 0::2] = torch.sin(angles)
        pos_emb[:, 1::2] = torch.cos(angles)
        return pos_emb

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` (batch, seq_len) into hidden states (batch, seq_len, dim)."""
        # Unpacking also enforces that input_ids is two-dimensional.
        _, seq_len = input_ids.size()

        token_embeddings = self.embeddings(input_ids)
        position_embeddings = self.get_position_embeddings(seq_len, input_ids.device)
        hidden_state = self.dropout(token_embeddings + position_embeddings)

        for layer in self.transformer:
            hidden_state = layer(hidden_state, attention_mask)

        return hidden_state

class DistilBertForQuestionAnswering(nn.Module):
    """DistilBertModel encoder with a linear head that scores answer start/end tokens."""

    def __init__(self, config):
        super().__init__()
        self.distilbert = DistilBertModel(config)
        # Two outputs per token: start logit and end logit.
        self.qa_outputs = nn.Linear(config.dim, 2)
        self.dropout = nn.Dropout(p=config.qa_dropout)

    def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None):
        """Score every token as a possible answer start and end.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: optional mask forwarded to the encoder.
            start_positions: optional (batch,) or (batch, 1) gold start indices.
            end_positions: optional (batch,) or (batch, 1) gold end indices.

        Returns:
            Dict with ``start_logits`` and ``end_logits`` of shape
            (batch, seq_len). When both position tensors are given, the dict
            also has ``loss``: the mean of the start and end cross-entropy.
            The caller's position tensors are left unmodified.
        """
        sequence_output = self.distilbert(input_ids, attention_mask)
        sequence_output = self.dropout(sequence_output)

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = {
            'start_logits': start_logits,
            'end_logits': end_logits
        }

        if start_positions is not None and end_positions is not None:
            # Accept labels carrying a trailing singleton dimension.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions outside the sequence are clamped onto seq_len, which
            # the loss then ignores. clamp() (not clamp_()) keeps the caller's
            # tensors intact.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            outputs['loss'] = (start_loss + end_loss) / 2

        return outputs

class SQuADDataset(Dataset):
    """SQuAD-style examples read from a CSV with id/question/context/answers columns.

    The ``answers`` cell is parsed from its string form (either a numpy-array
    repr or a plain dict literal). Only the first answer is kept. Rows whose
    answers cannot be parsed are counted and skipped.
    """

    # Stringified numpy arrays, e.g. "'text': array(['foo'], dtype=object)".
    _TEXT_ARRAY_RE = re.compile(r"'text':\s*array\(\[([^\]]+)\],\s*dtype=object\)")
    _START_ARRAY_RE = re.compile(r"'answer_start':\s*array\(\[([^\]]+)\],\s*dtype=int32\)")
    # A single- or double-quoted string.
    _QUOTED_RE = re.compile(r"'([^']*)'|\"([^\"]*)\"")
    # "array([...], dtype=...)" or "array([...])" -> captures the bracketed list.
    _ARRAY_WRAPPER_RE = re.compile(r"array\((\[.*?\])\s*(?:,\s*dtype=[^)]*)?\)", re.S)

    def __init__(self, data_path, tokenizer, max_length=512):
        """Load ``data_path`` and keep every row whose answer can be parsed.

        Raises:
            ValueError: if no row could be parsed.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = []

        print(f"Loading data from {data_path}")
        df = pd.read_csv(data_path)
        print(f"Loaded {len(df)} rows from CSV")

        print("First row sample:")
        print(df.head(1).to_dict('records'))

        successful_parses = 0
        failed_parses = 0

        for idx, row in df.iterrows():
            try:
                answers_str = str(row['answers'])

                if idx < 3:
                    print(f"Row {idx} answers format: {answers_str}")

                answer_text, answer_start = self._parse_answers(answers_str)

                if answer_text is None or answer_start is None:
                    failed_parses += 1
                    # Only the first few failures are printed to keep logs short.
                    if failed_parses <= 5:
                        print(f"Failed to parse row {idx}: {answers_str}")
                    continue

                self.examples.append({
                    'id': str(row['id']),
                    'question': str(row['question']),
                    'context': str(row['context']),
                    'answer_text': answer_text,
                    'answer_start': answer_start,
                    'answer_end': answer_start + len(answer_text)
                })
                successful_parses += 1

            except Exception as e:
                # Best-effort loading: a bad row is reported and skipped.
                failed_parses += 1
                if failed_parses <= 5:
                    print(f"Error parsing row {idx}: {e}")

        print(f"Successfully parsed {successful_parses} examples")
        print(f"Failed to parse {failed_parses} examples")
        print(f"Total examples in dataset: {len(self.examples)}")

        if len(self.examples) == 0:
            raise ValueError("No examples could be parsed from the dataset. Please check the data format.")

    @classmethod
    def _first_quoted(cls, fragment):
        """Return the first quoted string inside ``fragment``, or None."""
        if fragment is None:
            return None
        match = cls._QUOTED_RE.search(fragment)
        if match is None:
            return None
        return match.group(1) or match.group(2) or ''

    @staticmethod
    def _first_int(fragment):
        """Return the first run of digits inside ``fragment`` as an int, or None."""
        if fragment is None:
            return None
        match = re.search(r'\d+', fragment)
        return int(match.group(0)) if match else None

    @staticmethod
    def _first_item(value):
        """Return the first element of a list/tuple (None if empty), else ``value`` itself."""
        if isinstance(value, (list, tuple)):
            return value[0] if value else None
        return value

    @staticmethod
    def _array_body(section):
        """Return the text between the first 'array([' and the next ']', or None."""
        marker = "array(["
        begin = section.find(marker)
        if begin == -1:
            return None
        begin += len(marker)
        end = section.find("]", begin)
        if end <= begin:
            return None
        return section[begin:end]

    @classmethod
    def _parse_literal(cls, answers_str):
        """Parse ``answers_str`` as a dict literal; return (text, start) or None."""
        # Unwrap array(...) so the string becomes a plain Python literal.
        cleaned = cls._ARRAY_WRAPPER_RE.sub(r"\1", answers_str)
        try:
            parsed = ast.literal_eval(cleaned)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return None

        if not isinstance(parsed, dict):
            return None
        if 'text' not in parsed or 'answer_start' not in parsed:
            return None

        answer_text = cls._first_item(parsed['text'])
        answer_start = cls._first_item(parsed['answer_start'])
        if answer_text is None or answer_start is None:
            return None
        try:
            return answer_text, int(answer_start)
        except (TypeError, ValueError):
            return None

    @classmethod
    def _parse_by_scanning(cls, answers_str):
        """Last-resort parse: locate the two 'array([' bodies by substring search."""
        if "'text':" not in answers_str or "'answer_start':" not in answers_str:
            return None, None

        text_section = answers_str[answers_str.find("'text':") + len("'text':"):]
        start_section = answers_str[answers_str.find("'answer_start':") + len("'answer_start':"):]

        answer_text = cls._first_quoted(cls._array_body(text_section))
        answer_start = cls._first_int(cls._array_body(start_section))
        if answer_text is None or answer_start is None:
            return None, None
        return answer_text, answer_start

    def _parse_answers(self, answers_str):
        """Extract the first answer text and its character offset.

        Tries, in order: the exact numpy repr, a dict literal (with any
        ``array(...)`` wrappers removed), then a substring scan.

        Returns:
            ``(answer_text, answer_start)``, or ``(None, None)`` when nothing
            could be extracted.
        """
        try:
            text_match = self._TEXT_ARRAY_RE.search(answers_str)
            start_match = self._START_ARRAY_RE.search(answers_str)

            if text_match and start_match:
                answer_text = self._first_quoted(text_match.group(1))
                answer_start = self._first_int(start_match.group(1))
                # The format was recognised; do not fall through to other strategies.
                if answer_text is None or answer_start is None:
                    return None, None
                return answer_text, answer_start

            parsed = self._parse_literal(answers_str)
            if parsed is not None:
                return parsed

            return self._parse_by_scanning(answers_str)

        except Exception as e:
            # Parsing is best-effort; report and let the caller skip the row.
            print(f"Error in _parse_answers: {e}")
            return None, None

    @staticmethod
    def _locate_answer(offsets, sequence_ids, answer_start, answer_end):
        """Map a character span in the context to (start_token, end_token).

        Only tokens whose sequence id is 1 (the second tokenizer input, i.e.
        the context) are considered, because question-token offsets are
        relative to the question string and would produce false matches.
        Returns (0, 0) when the span is not fully covered.
        """
        start_token = None
        end_token = None
        for i, ((tok_start, tok_end), seq_id) in enumerate(zip(offsets, sequence_ids)):
            if seq_id != 1:
                continue
            if start_token is None and tok_start <= answer_start < tok_end:
                start_token = i
            if tok_start < answer_end <= tok_end:
                end_token = i
                break
        if start_token is None or end_token is None:
            return 0, 0
        return start_token, end_token

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return its tensors.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` (each of length
            ``max_length``) and scalar ``start_positions`` / ``end_positions``
            token indices (both 0 when the answer is not inside the encoded
            context).
        """
        example = self.examples[idx]

        encoding = self.tokenizer(
            example['question'],
            example['context'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            return_offsets_mapping=True
        )

        start_positions, end_positions = self._locate_answer(
            encoding['offset_mapping'][0].tolist(),
            encoding.sequence_ids(0),
            example['answer_start'],
            example['answer_end'],
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'start_positions': torch.tensor(start_positions, dtype=torch.long),
            'end_positions': torch.tensor(end_positions, dtype=torch.long)
        }

def load_pretrained_weights(model, pretrained_model_name='distilbert-base-uncased'):
    """Copy matching weights from a transformers DistilBertModel into ``model``.

    Parameter names of the pretrained model are translated to this file's
    naming scheme. A parameter is copied only when the translated name exists
    in ``model`` with the same shape; everything else is left as initialised.
    Any failure is printed and swallowed, leaving ``model`` unchanged.
    """
    def translate(name):
        """Return the local name for a pretrained parameter, or None to skip it."""
        if name.startswith('transformer.layer.'):
            return name.replace('transformer.layer.', 'distilbert.transformer.')
        if not name.startswith('embeddings.'):
            return f'distilbert.{name}'
        if 'word_embeddings' in name:
            return name.replace('embeddings.word_embeddings', 'distilbert.embeddings')
        if 'position_embeddings' in name:
            return name.replace('embeddings.position_embeddings', 'distilbert.position_embeddings')
        # Other embedding parameters are skipped.
        return None

    try:
        from transformers import DistilBertModel
        pretrained_model = DistilBertModel.from_pretrained(pretrained_model_name)

        model_dict = model.state_dict()
        pretrained_dict = {}

        for name, param in pretrained_model.named_parameters():
            new_name = translate(name)
            if new_name is None:
                continue
            if new_name in model_dict and param.shape == model_dict[new_name].shape:
                pretrained_dict[new_name] = param

        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        print(f"Loaded {len(pretrained_dict)} pretrained parameters")

    except ImportError:
        print("transformers library not found. Training from scratch.")
    except Exception as e:
        print(f"Error loading pretrained weights: {e}")


def train_model(model, train_loader, val_loader, epochs=3, lr=2e-5):
    """Fine-tune ``model`` and write one checkpoint per epoch.

    Args:
        model: module whose forward accepts ``input_ids``, ``attention_mask``,
            ``start_positions`` and ``end_positions`` and returns a dict
            containing ``'loss'``.
        train_loader: iterable of batches (dicts with those four keys).
        val_loader: iterable of validation batches in the same format.
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate; decayed linearly to 10% over ``epochs``.

    Side effects:
        Moves ``model`` to CUDA when available, creates ``checkpoints/`` and
        saves ``checkpoints/distilbert_qa_epoch_<n>.pth`` after every epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"Training on device: {device}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=epochs)

    os.makedirs('checkpoints', exist_ok=True)

    def run_batch(batch):
        """Move one batch to ``device`` and return the model's output dict."""
        return model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            start_positions=batch['start_positions'].to(device),
            end_positions=batch['end_positions'].to(device)
        )

    model.train()
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        total_loss = 0
        train_pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}", unit="batch")

        for batch_idx, batch in enumerate(train_pbar):
            optimizer.zero_grad()

            loss = run_batch(batch)['loss']
            loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

            train_pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Avg Loss': f'{total_loss/(batch_idx+1):.4f}'
            })

        model.eval()
        val_loss = 0
        val_pbar = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}", unit="batch")

        with torch.no_grad():
            for batch in val_pbar:
                batch_loss = run_batch(batch)['loss'].item()
                val_loss += batch_loss
                val_pbar.set_postfix({'Val Loss': f'{batch_loss:.4f}'})

        model.train()
        scheduler.step()

        # Guard both averages so an empty loader reports 0 instead of raising.
        avg_train_loss = total_loss / len(train_loader) if len(train_loader) > 0 else 0
        avg_val_loss = val_loss / len(val_loader) if len(val_loader) > 0 else 0

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
        }

        checkpoint_path = f'checkpoints/distilbert_qa_epoch_{epoch+1}.pth'
        torch.save(checkpoint, checkpoint_path)

        print(f'Epoch {epoch+1} completed - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
        print(f'Checkpoint saved: {checkpoint_path}')

    print("\nTraining completed!")

def main():
    """Build the QA model, load data, train for 10 epochs and save the final weights.

    If the validation CSV does not exist, 10% of the training set is split
    off at random and used for validation instead.
    """
    model = DistilBertForQuestionAnswering(DistilBertConfig())
    load_pretrained_weights(model, 'distilbert-base-uncased')

    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

    train_dataset = SQuADDataset('/kaggle/input/the-stanford-question-answering-dataset/train.csv', tokenizer)

    try:
        val_dataset = SQuADDataset('/kaggle/input/the-stanford-question-answering-dataset/validation.csv', tokenizer)
    except FileNotFoundError:
        from torch.utils.data import random_split

        total = len(train_dataset)
        train_size = int(0.9 * total)
        train_dataset, val_dataset = random_split(train_dataset, [train_size, total - train_size])

    # Only the training batches are shuffled.
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    print(f"Training examples: {len(train_dataset)}")
    print(f"Validation examples: {len(val_dataset)}")

    train_model(model, train_loader, val_loader, epochs=10, lr=2e-5)

    torch.save(model.state_dict(), 'distilbert_qa_model.pth')
    print("Model saved successfully!")

# Run the training pipeline only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()


In [None]:
import torch
from transformers import AutoTokenizer
import pandas as pd
import re
import random

def load_trained_model(model_path='distilbert_qa_model.pth'):
    """Rebuild the QA model, load its weights from ``model_path`` and switch to eval mode.

    Returns:
        ``(model, device)`` where ``device`` is CUDA when available, else CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = DistilBertForQuestionAnswering(DistilBertConfig())
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    print(f"Model loaded successfully from {model_path}")
    print(f"Running on device: {device}")

    return model, device

def load_contexts_from_csv(csv_path='/kaggle/input/the-stanford-question-answering-dataset/train.csv'):
    """Read the unique contexts and their titles from a SQuAD-style CSV.

    Args:
        csv_path: CSV file with at least ``context`` and ``title`` columns.

    Returns:
        ``(contexts, context_mapping)``: the de-duplicated contexts in
        first-seen order, and a dict from context to title. When a context
        appears with several titles, the last row wins.

    Raises:
        KeyError: if either required column is missing.
    """
    print(f"Loading contexts from {csv_path}")
    df = pd.read_csv(csv_path)

    contexts = df['context'].drop_duplicates().tolist()

    # One pass over the two columns; later rows overwrite earlier ones,
    # matching a row-by-row assignment.
    context_mapping = dict(zip(df['context'], df['title']))

    print(f"Loaded {len(contexts)} unique contexts from {len(df)} total examples")
    return contexts, context_mapping

def find_relevant_context(question, contexts, context_mapping, top_k=3):
    """Rank ``contexts`` by word overlap with ``question`` and return the best ``top_k``.

    The score is the fraction of the question's distinct lower-cased,
    whitespace-separated words that also occur in the context (0 for an empty
    question). Ties keep their original order.

    Returns:
        List of dicts with keys ``context``, ``title`` (``'Unknown'`` when the
        context is not in ``context_mapping``) and ``score``.
    """
    query_terms = set(question.lower().split())

    def describe(text):
        """Build the scored record for a single context."""
        shared = query_terms & set(text.lower().split())
        fraction = len(shared) / len(query_terms) if query_terms else 0
        return {
            'context': text,
            'title': context_mapping.get(text, 'Unknown'),
            'score': fraction
        }

    # sorted() is stable, so equal scores stay in input order.
    ranked = sorted(
        [describe(text) for text in contexts],
        key=lambda record: record['score'],
        reverse=True,
    )
    return ranked[:top_k]

def answer_question(model, tokenizer, question, context, device, max_length=512):
    """Extract the answer span for ``question`` from ``context``.

    Candidate tokens are restricted to the context (tokenizer sequence id 1),
    so the prediction cannot land on question, special or padding tokens. The
    end token is chosen at or after the start token.

    Args:
        model: callable returning a dict with ``start_logits`` and
            ``end_logits`` of shape (1, seq_len).
        tokenizer: tokenizer whose output supports ``return_offsets_mapping``
            and ``sequence_ids``.
        question: question text.
        context: passage to search for the answer.
        device: device on which the model inputs are placed.
        max_length: padded/truncated sequence length.

    Returns:
        Dict with ``answer`` (str), ``confidence`` (mean of the start and end
        softmax probabilities over the candidate tokens) and
        ``context_title`` (always None; filled in by the caller).
    """
    inputs = tokenizer(
        question,
        context,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
        return_offsets_mapping=True
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    offsets = inputs['offset_mapping'][0].tolist()

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    start_logits = outputs['start_logits'][0]
    end_logits = outputs['end_logits'][0]

    in_context = torch.tensor(
        [seq_id == 1 for seq_id in inputs.sequence_ids(0)],
        dtype=torch.bool,
        device=start_logits.device,
    )
    if in_context.any():
        # Push every non-context position to the lowest representable score.
        floor = torch.finfo(start_logits.dtype).min
        start_logits = start_logits.masked_fill(~in_context, floor)
        end_logits = end_logits.masked_fill(~in_context, floor)

    start_idx = int(torch.argmax(start_logits))
    # Search the end only from the start onwards so the span is never inverted.
    end_idx = start_idx + int(torch.argmax(end_logits[start_idx:]))

    start_confidence = torch.softmax(start_logits, dim=0)[start_idx].item()
    end_confidence = torch.softmax(end_logits, dim=0)[end_idx].item()
    confidence = (start_confidence + end_confidence) / 2

    # Fallback answer: decoded word pieces.
    answer = tokenizer.decode(input_ids[0, start_idx:end_idx + 1], skip_special_tokens=True)

    if bool(in_context[start_idx]) and bool(in_context[end_idx]):
        # Context-token offsets index directly into ``context``, so slicing
        # the original string recovers its casing and spacing.
        span = context[offsets[start_idx][0]:offsets[end_idx][1]].strip()
        if span:
            answer = span

    return {
        'answer': answer,
        'confidence': confidence,
        'context_title': None
    }

# ---------------------------------------------------------------------------
# Inference demo. Runs at module level: loads the saved model, picks a context
# for one hard-coded question by keyword overlap, prints the predicted answer,
# then prints a few random contexts from the CSV.
# ---------------------------------------------------------------------------
print("Loading trained model and tokenizer...")
model, device = load_trained_model('distilbert_qa_model.pth')
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

print("\nLoading contexts from training data...")
contexts, context_mapping = load_contexts_from_csv()


# The single question answered by this demo.
question = "Which NFL team represented the AFC at Super Bowl 50?"


print("\n" + "="*60)
print("QUESTION ANSWERING WITH AUTOMATIC CONTEXT SELECTION")
print("="*60)
print(f"Question: {question}")


# Rank every context by word overlap with the question and keep the top three.
relevant_contexts = find_relevant_context(question, contexts, context_mapping, top_k=3)

print(f"\nTop {len(relevant_contexts)} most relevant contexts found:")
for i, ctx_info in enumerate(relevant_contexts, 1):
    print(f"{i}. {ctx_info['title']} (relevance: {ctx_info['score']:.3f})")

# NOTE(review): find_relevant_context returns the top_k entries regardless of
# score, so this list is empty only when `contexts` itself is empty; a context
# with zero overlap is still used to answer.
if relevant_contexts:
    # Only the highest-ranked context is passed to the model.
    best_context = relevant_contexts[0]
    result = answer_question(model, tokenizer, question, best_context['context'], device)
    # answer_question leaves 'context_title' as None; fill it in here.
    result['context_title'] = best_context['title']
    
    print(f"\n{'='*60}")
    print("ANSWER")
    print("="*60)
    print(f"🎯 {result['answer']}")
    print(f"📊 Confidence: {result['confidence']:.2%}")
    print(f"📚 Source: {result['context_title']}")
    print(f"🔍 Context relevance: {best_context['score']:.3f}")

    # Show at most the first 200 characters of the context that was used.
    context_preview = best_context['context'][:200] + "..." if len(best_context['context']) > 200 else best_context['context']
    print(f"\n📖 Context used: {context_preview}")
    
else:
    print("❌ No relevant context found for this question.")


print(f"\n{'='*50}")
print("🎲 RANDOM EXAMPLES FROM YOUR TRAINING DATA:")
print("="*50)
# Up to three (context, title) pairs drawn without replacement; unseeded, so
# the selection differs between runs.
sample_contexts = random.sample(list(context_mapping.items()), min(3, len(context_mapping)))
for context, title in sample_contexts:
    print(f"📚 {title}")
    # Truncate long contexts to a 150-character preview.
    preview = context[:150] + "..." if len(context) > 150 else context
    print(f"   {preview}\n")
