In [None]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm, trange

from src.config import DCNConfig
from src.squad import SquadDataset
from src.model import CoattentionModel
from src.glove import GloVeEmbeddings


config = DCNConfig()

glove = GloVeEmbeddings(embedding_dim=config.glove_dim)
glove.load_glove_embeddings(config.glove_path)

train_dataset = SquadDataset(glove.word_to_idx, split="train")
train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=True)

eval_dataset = SquadDataset(glove.word_to_idx, split="validation")
eval_dataloader = DataLoader(eval_dataset, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm
INFO:glove:Loading GloVe embeddings from ../glove_embeddings/glove.840B.300d.txt
INFO:glove:Processed 0 lines
INFO:glove:Processed 100000 lines
INFO:glove:Processed 200000 lines
INFO:glove:Processed 300000 lines
INFO:glove:Processed 400000 lines
INFO:glove:Processed 500000 lines
INFO:glove:Processed 600000 lines
INFO:glove:Processed 700000 lines
INFO:glove:Processed 800000 lines
INFO:glove:Processed 900000 lines
INFO:glove:Processed 1000000 lines
INFO:glove:Processed 1100000 lines
INFO:glove:Processed 1200000 lines
INFO:glove:Processed 1300000 lines
INFO:glove:Processed 1400000 lines
INFO:glove:Processed 1500000 lines
INFO:glove:Processed 1600000 lines
INFO:glove:Processed 1700000 lines
INFO:glove:Processed 1800000 lines
INFO:glove:Processed 1900000 lines
INFO:glove:Processed 2000000 lines
INFO:glove:Processed 2100000 lines
INFO:glove:Loaded 2196021 words with 300d embeddings


In [16]:
print(f"Embedding matrix shape: {glove.get_embedding_matrix().shape}") # Should be (vocab_size, embedding_dim)
print(f"Vocabulary size: {len(glove.word_to_idx)}")
print(f"Index of 'the': {glove.word_to_idx.get('the', 'Not found')}")
print(f"Index of 'McDonald': {glove.word_to_idx.get('McDonald', 'Not found')}")
print(f"Index of '<PAD>': {glove.word_to_idx.get('<PAD>', 'Not found')}")
print(f"Index of '<UNK>': {glove.word_to_idx.get('<UNK>', 'Not found')}")
print(f"Total training samples: {len(train_dataset)}")
print(f"Total validation samples: {len(eval_dataset)}")
print("")

sample_idx = 33
print(f"\nSample #{sample_idx}")
print(f"Context: {eval_dataset.context_data[sample_idx][:190]}...")
print(f"Question: {eval_dataset.question_data[sample_idx]}")
print(f"Answer span: {eval_dataset.answer_span_data[sample_idx]}")
# answer_text = eval_dataset.context_data[sample_idx]
# print(f"Answer: '{answer_text}'")
print("")

print("Embedded view:")
context_ids, context_len, question_ids, question_len, answer_span, qid = eval_dataset[sample_idx]

print(f"Context IDs: {context_ids}")
print(f"Context length: {context_len}")
print(f"Question IDs: {question_ids}")
print(f"Question length: {question_len}")
print(f"Answer span: {answer_span}")


Embedding matrix shape: (2196021, 300)
Vocabulary size: 2196021
Index of 'the': 6
Index of 'McDonald': 9172
Index of '<PAD>': 0
Index of '<UNK>': 0
Total training samples: 87580
Total validation samples: 10570


Sample #33
Context: The Panthers finished the regular season with a 15–1 record, and quarterback Cam Newton was named the NFL Most Valuable Player (MVP). They defeated the Arizona Cardinals 49–15 in the NFC Cha...
Question: What team did the Panthers defeat?
Answer span: (25, 26)

Embedded view:
Context IDs: tensor([     6,  83217,   1697,      6,   1446,    547,     23,     10,    299,
          7956,     70,    922,      4,      7,   9484,   7529,  74034,     34,
          1654,      6,  22375,    126,   3505,    986,     18,  96888,     17,
             5,     53,   8123,      6,  20240,  50141,   3847,   7956,    299,
            11,      6, 147525,   6486,    243,      7,   2141,      8,     62,
           345,   2261,   3975,   3071,    270,      6,   6426,     34,   4742

In [3]:
def lengths_to_mask(lengths, max_len):
    """
    Convert sequence lengths to binary masks.
    
    Args:
        lengths: Tensor of shape (batch_size,) containing sequence lengths
        max_len: Maximum sequence length. If None, uses the maximum from lengths
        
    Returns:
        mask: Binary tensor of shape (batch_size, max_len) where 1 indicates valid positions
    """
    # Create a range tensor [0, 1, 2, ..., max_len-1]
    indices = torch.arange(0, max_len, dtype=lengths.dtype, device=lengths.device)
    
    # Expand dimensions: lengths -> (batch_size, 1), indices -> (1, max_len)
    # Then broadcast and compare
    mask = indices.unsqueeze(0) < lengths.unsqueeze(1)
    
    return mask.long()  # Convert boolean to 0/1 integers

## Train

In [None]:
model = CoattentionModel(config.hidden_dim, config.maxout_pool_size, glove.get_embedding_matrix(), config.max_dec_steps, config.dropout_ratio)
use_cuda = False
if use_cuda:
    model = model.cuda()

opt = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.reg_lambda)

# Training tracking
best_eval_loss = float('inf')
train_losses = []
eval_losses = []

os.makedirs(config.model_save_path, exist_ok=True)
device = torch.device('cuda' if use_cuda else 'cpu')
print("Training started!")
for epoch in trange(10, desc="Epoch"):
    epoch_train_loss = 0
    num_batches = 0

    for iteration, batch in enumerate(tqdm(train_dataloader)):
        
        # Skip batches if enabled
        if config.skip_frequency > 1 and iteration % config.skip_frequency != 0:
            continue
        
        # Unpack the batch
        context, context_lens, question, question_lens, answer_spans, _ = batch

        context = context.to(device)
        context_lens = context_lens.view(-1).to(device)
        question = question.to(device)
        question_lens = question_lens.view(-1).to(device)
        answer_spans = answer_spans.to(device)

        context_mask = lengths_to_mask(context_lens, config.context_len)
        question_mask = lengths_to_mask(question_lens, config.question_len)
        
        # Reset gradients
        opt.zero_grad()
        
        # === Forward pass ===
        loss, _, _ = model(context, context_mask, question, question_mask, context_lens, answer_spans)
        
        # === Backpropagation ===
        loss.backward()
        clip_grad_norm_(model.parameters(), config.max_grad_norm)
        opt.step()

        total_loss = loss.item()
        epoch_train_loss += total_loss
        num_batches += 1

        if num_batches % config.print_frequency == 0:
            print(f"Epoch: {epoch+1} Iteration: {iteration+1} loss: {total_loss}")
            
    # Calculate average training loss for this epoch
    avg_train_loss = epoch_train_loss / num_batches
    train_losses.append(avg_train_loss)
    
    # Validation at the end of each epoch
    if (epoch + 1) % config.eval_frequency == 0:  # Note: moved to epoch level
        print("Running validation...")
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for val_batch in eval_dataloader:
                context, context_lens, question, question_lens, answer_spans, _ = val_batch
                
                # Move validation data to device
                context = context.to(device)
                context_lens = context_lens.view(-1).to(device)
                question = question.to(device)
                question_lens = question_lens.view(-1).to(device)
                answer_spans = answer_spans.to(device)
                
                context_mask = lengths_to_mask(context_lens, config.context_len)
                question_mask = lengths_to_mask(question_lens, config.question_len)

                loss, _, _ = model(context, context_mask, question, question_mask, context_lens, answer_spans)

                val_loss += loss.item()
        
        avg_val_loss = val_loss / len(eval_dataloader)
        eval_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1} - Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
        
        # Save best model
        if avg_val_loss < best_eval_loss:
            best_eval_loss = avg_val_loss
            torch.save(model.state_dict(), os.path.join(config.model_save_path, 'best_model.pt'))
            print("New best model saved!")
        
        model.train()  # Switch back to training mode

    # === Save model checkpoint ===
    print("Saving model checkpoint...")
    torch.save(model.state_dict(), os.path.join(config.model_save_path, f'model_epoch_{epoch+1}.pt'))

print("Training completed!!!")

Training started!


100%|██████████| 1/1 [00:01<00:00,  1.77s/it]

Running validation...





Epoch 1 - Training Loss: 6.2035, Validation Loss: 4.9628
New best model saved!
Saving model checkpoint...


100%|██████████| 1/1 [00:01<00:00,  1.86s/it].94s/it]

Running validation...





Epoch 2 - Training Loss: 2.4814, Validation Loss: 4.9628
New best model saved!
Saving model checkpoint...


100%|██████████| 1/1 [00:02<00:00,  2.28s/it].23s/it]

Running validation...





Epoch 3 - Training Loss: 2.4814, Validation Loss: 4.9628
New best model saved!
Saving model checkpoint...


100%|██████████| 1/1 [00:02<00:00,  2.16s/it].54s/it]

Running validation...





Epoch 4 - Training Loss: 2.4814, Validation Loss: 2.4814
New best model saved!
Saving model checkpoint...


100%|██████████| 1/1 [00:01<00:00,  1.80s/it].03s/it]

Running validation...





Epoch 5 - Training Loss: 3.7221, Validation Loss: 6.2035
Saving model checkpoint...


100%|██████████| 1/1 [00:01<00:00,  1.99s/it].58s/it]

Running validation...





Epoch 6 - Training Loss: 6.2034, Validation Loss: 4.9628
Saving model checkpoint...


100%|██████████| 1/1 [00:02<00:00,  2.06s/it].58s/it]

Running validation...





Epoch 7 - Training Loss: 2.4814, Validation Loss: 6.2035
Saving model checkpoint...


100%|██████████| 1/1 [00:01<00:00,  1.96s/it].96s/it]

Running validation...





Epoch 8 - Training Loss: 2.4814, Validation Loss: 6.2035
Saving model checkpoint...


100%|██████████| 1/1 [00:02<00:00,  2.27s/it].63s/it]

Running validation...





Epoch 9 - Training Loss: 2.4814, Validation Loss: 6.2035
Saving model checkpoint...


100%|██████████| 1/1 [00:01<00:00,  1.88s/it].48s/it]

Running validation...





Epoch 10 - Training Loss: 2.4814, Validation Loss: 2.4814
New best model saved!
Saving model checkpoint...


Epoch: 100%|██████████| 10/10 [01:16<00:00,  7.66s/it]

Training completed!!!





## Plots

In [19]:
# Optional: Plot training curves
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Training Loss')
plt.plot(eval_losses, label='Evaluation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Evaluation Loss')
plt.legend()
plt.grid(True)
# plt.savefig('training_curves.png')
plt.show()

ModuleNotFoundError: No module named 'matplotlib'

In [4]:
import json
from collections import OrderedDict
from tqdm import tqdm
model = CoattentionModel(config.hidden_dim, config.maxout_pool_size, glove.get_embedding_matrix(), config.max_dec_steps, config.dropout_ratio)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load best model
model.load_state_dict(torch.load('models/best_model.pt', map_location=device))
model.eval()

predictions = OrderedDict()

with torch.no_grad():
    for val_batch in tqdm(eval_dataloader, desc="Evaluating"):
        context, context_lens, question, question_lens, answer_spans, qids = val_batch

        context = context.to(device)
        context_lens = context_lens.view(-1).to(device)
        question = question.to(device)
        question_lens = question_lens.view(-1).to(device)

        context_mask = lengths_to_mask(context_lens, config.context_len)
        question_mask = lengths_to_mask(question_lens, config.question_len)

        _, start_indices, end_indices = model(context, context_mask, question, question_mask, context_lens, answer_spans)


        batch_size = context.size(0)
        for i in range(batch_size):
            start = start_indices[i].item()
            end = end_indices[i].item()
            
            context_tokens = [glove.index_to_word(idx.item()) for idx in context[i]]

            # Clamp indices to valid range
            start = max(0, min(start, len(context_tokens) - 1))
            end = max(0, min(end, len(context_tokens) - 1))
            if end < start:
                end = start

            predicted_tokens = context_tokens[start:end+1]
            predicted_text = " ".join(predicted_tokens).strip()

            qid = qids[i]
            if isinstance(qid, torch.Tensor):  
                qid = qid.item() if qid.dim() == 0 else qid[0]

            predictions[qid] = predicted_text

# Save predictions to JSON
with open("predictions.json", "w", encoding="utf-8") as f:
    json.dump(predictions, f, indent=2)

print("Predictions saved to predictions.json.")


Evaluating: 100%|██████████| 331/331 [13:12<00:00,  2.39s/it]

Predictions saved to predictions.json.





In [5]:
""" Official evaluation script for v1.1 of the SQuAD dataset. """
from __future__ import print_function
from collections import Counter
import string
import re
import argparse
import json
import sys


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(dataset, predictions):
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    message = 'Unanswered question ' + qa['id'] + \
                              ' will receive score 0.'
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                prediction = predictions[qa['id']]
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}



In [6]:
import json

with open("../dev-v1.1.json", "r", encoding="utf-8") as dev_file:
    dev_data = json.load(dev_file)["data"]  

with open("predictions.json", "r", encoding="utf-8") as pred_file:
    predictions = json.load(pred_file)  



results = evaluate(dev_data, predictions)


print(f"Exact Match: {results['exact_match']:.2f}")
print(f"F1 Score: {results['f1']:.2f}")


Exact Match: 1.22
F1 Score: 4.13


In [None]:
N = 50
count = 0

for article in dev_data:
    for paragraph in article["paragraphs"]:
        for qa in paragraph["qas"]:
            qid = qa["id"]
            question = qa["question"]
            dev_answers = [a["text"] for a in qa["answers"]]
            predicted = predictions.get(qid, "[NO PREDICTION]")
            print(f"Q: {question}")
            print(f"Dev: {dev_answers}")
            print(f"Pred: {predicted}")
            print("-" * 80)

            count += 1
            if count >= N:
                break
        if count >= N:
            break
    if count >= N:
        break


Q: Which NFL team represented the AFC at Super Bowl 50?
Dev: ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']
Pred: bowl
--------------------------------------------------------------------------------
Q: Which NFL team represented the NFC at Super Bowl 50?
Dev: ['Carolina Panthers', 'Carolina Panthers', 'Carolina Panthers']
Pred: bowl
--------------------------------------------------------------------------------
Q: Where did Super Bowl 50 take place?
Dev: ['Santa Clara, California', "Levi's Stadium", "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."]
Pred: bowl
--------------------------------------------------------------------------------
Q: Which NFL team won Super Bowl 50?
Dev: ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']
Pred: bowl
--------------------------------------------------------------------------------
Q: What color was used to emphasize the 50th anniversary of the Super Bowl?
Dev: ['gold', 'gold', 'gold']
Pred: bowl
-----------