# Key-Value Memory Network for Wikipedia Biography QA

This notebook demonstrates how to use the Key-Value Memory Network implementation to answer questions about people in the Wikipedia Biography dataset.

## Environment Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import os
import json
import random
import numpy as np
import pickle
from tqdm import tqdm
import nltk
import unidecode

from model.memory_network import KVMemoryNetwork
from utils.data_utils import Vocab, multihot, tokenize

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

## Download and Load the Raw Data

In [None]:
# Download the dataset.
# NOTE: `!` lines are IPython shell escapes. Re-running this cell makes
# `git clone` report that the directory already exists; IPython still runs
# the remaining commands.
!git clone https://github.com/rlebret/wikipedia-biography-dataset.git
# The data appears to ship as a split zip archive; the `z??` glob
# concatenates its parts (in shell glob order) into a single zip file.
!cat wikipedia-biography-dataset/wikipedia-biography-dataset.z?? > tmp.zip
# -o: overwrite existing files without prompting.
!unzip -o tmp.zip
# Remove the temporary concatenated archive.
!rm tmp.zip

# Read the training split: one person title per line in train.title and the
# matching serialized infobox on the same line number of train.box.
# The encoding is explicit because Wikipedia text may contain non-ASCII
# characters, and open()'s default encoding depends on the platform locale.
with open("wikipedia-biography-dataset/train/train.title", "r", encoding="utf-8") as file:
    train_titles = [line.rstrip() for line in file]

with open("wikipedia-biography-dataset/train/train.box", "r", encoding="utf-8") as file:
    train_boxes = [line.rstrip() for line in file]

In [None]:
def make_db(titles, boxes):
    """Create database from titles and boxes."""
    db = {}  
    for i in tqdm(range(len(titles))):
        box = boxes[i]
        d = {}
        for pair in re.findall(r'([a-zA-Z_]+)[0-9]*\:([\\w\\d]+)', box):
            key, value = pair
            key = key.strip()
            value = value.strip()
            if 'image' not in key:
                if key[-1] == '_':
                    key = key[:-1]
                if key not in d:
                    d[key] = value
                else:
                    d[key] += ' ' + value
        if 'office' in d:
            db[titles[i]] = d
    return db

# Parse the training split into the person -> fields table used below.
DB = make_db(titles=train_titles, boxes=train_boxes)
print(f"Created database with {len(DB)} entries")

def make_vocab(DB):
    """Return a ``Vocab`` filled with every token found in ``DB``.

    The whole dict is rendered with ``str()`` and passed to ``tokenize``, so
    titles, field names and field values in that text are all candidate
    tokens (exact splitting is up to ``tokenize``).
    """
    vocabulary = Vocab()
    for token in tqdm(tokenize(str(DB))):
        vocabulary.add_word(token)
    return vocabulary

# Build the vocabulary from everything in the database.
VOCAB = make_vocab(DB=DB)
print(f"Created vocabulary with {VOCAB.num_words()} words")

## Prepare Training Data

In [None]:
def create_formatted_datasets(DB, vocab, train_size=500, test_size=100):
    """Split ``DB`` into multi-hot training tensors and raw held-out records.

    The first ``train_size + test_size`` people in ``DB`` (dict iteration
    order) that have at least one field each become a list of string triples
    ``(name, "<name> <field>", "<name> <value>")``, one triple per field.

    Args:
        DB: Mapping of person name -> {field: value}, as built by ``make_db``.
        vocab: Vocabulary passed to ``multihot``.
        train_size: Number of people used for training.
        test_size: Number of people held out for testing.

    Returns:
        ``(train_data, test_raw)``. ``train_data`` has one
        ``(questions, keys, values)`` tuple of stacked float32 tensors per
        training person. ``test_raw`` is a list of per-person lists of the raw
        string triples; these are NOT converted to tensors.
    """
    def to_vec(text):
        # Multi-hot encoding of ``text`` as a float32 tensor.
        return torch.tensor(multihot(text, vocab), dtype=torch.float32)

    limit = train_size + test_size

    raw_data = []
    for name, content in DB.items():
        if len(raw_data) >= limit:
            break
        person_data = [
            (name, f"{name} {key}", f"{name} {value}")
            for key, value in content.items()
        ]
        if not person_data:
            # People with no parsed fields are skipped and not counted.
            continue
        raw_data.append(person_data)
        if len(raw_data) % 100 == 0:
            print(f"Processed {len(raw_data)} entries")

    train_raw = raw_data[:train_size]
    test_raw = raw_data[train_size:limit]

    # Encode each training person as three tensors stacked along a new leading
    # dimension of size num_fields.
    # NOTE(review): the question row encodes only the person's name, not the
    # field being asked about. Confirm this is the intended training signal.
    train_data = []
    for person_data in train_raw:
        rows = [(to_vec(n), to_vec(k), to_vec(v)) for n, k, v in person_data]
        questions, keys, values = (torch.stack(column) for column in zip(*rows))
        train_data.append((questions, keys, values))

    return train_data, test_raw

# Build the training tensors and the held-out raw test records.
train_data, test_data = create_formatted_datasets(
    DB,
    VOCAB,
    train_size=500,
    test_size=100,
)
print(f"Training samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")

## Initialize and Train Model

In [None]:
# Size the model from the vocabulary's internal word -> index map.
vocab_size = len(VOCAB._word2index)
embed_size = 256

model = KVMemoryNetwork(vocab_size, embed_size).to(device)
print("Model initialized")

# Train the model.
# NOTE(review): train_model and plot_training_curves are not imported in this
# notebook's import cell. Confirm they are defined or imported elsewhere
# before running this cell.
losses, accuracies = train_model(
    model, train_data, num_epochs=10, learning_rate=0.001, device=device
)

# Plot the curves returned by train_model.
plot_training_curves(losses, accuracies)

## Question Answering

Test the model with some example questions about historical figures.

In [None]:
def answer_question(question, person, model, DB, vocab, device, num_distractors=2):
    """Answer ``question`` about ``person`` using the trained memory network.

    Memory holds one slot per field of ``person`` plus the fields of up to
    ``num_distractors`` other people drawn at random from ``DB``. Each slot's
    key is the multi-hot encoding of ``"<person> <field>"`` and its value is
    the multi-hot encoding of the raw field value. The model output is scored
    against every value embedding by dot product, and the raw value of the
    best-scoring slot is returned.

    Args:
        question: Question text to encode with ``multihot``.
        person: Person title, exactly as it appears in ``DB``.
        model: Model called as ``model(q, k, v)`` that also exposes
            ``get_value_embeddings``.
        DB: Mapping of person -> {field: value}, as built by ``make_db``.
        vocab: Vocabulary passed to ``multihot``.
        device: Torch device for the input tensors.
        num_distractors: Number of other people whose fields are mixed into
            memory. Capped at the number of other people in ``DB``.

    Returns:
        The predicted field value as a string, or
        ``"Person not found in database."`` if ``person`` is not in ``DB``.
    """
    if person not in DB:
        return "Person not found in database."

    def encode(text):
        # Multi-hot encoding of ``text`` as a float32 tensor.
        return torch.tensor(multihot(text, vocab), dtype=torch.float32)

    # Draw distractors from a list in DB order rather than from a set. String
    # hashing is randomized per interpreter run, so set order (and therefore
    # the pick) was irreproducible even under a fixed NumPy seed. The count is
    # capped so small databases don't make np.random.choice raise.
    other_persons = [p for p in DB if p != person]
    n_distractors = min(num_distractors, len(other_persons))
    if n_distractors > 0:
        rand_persons = np.random.choice(other_persons, n_distractors, replace=False)
    else:
        rand_persons = []

    # Each memory slot is (key text, raw value). The raw value is kept so the
    # predicted slot index can be mapped back to a readable answer.
    memory = [(f"{person} {key}", value) for key, value in DB[person].items()]
    for other in rand_persons:
        memory.extend((f"{other} {key}", value) for key, value in DB[other].items())

    q_tensor = encode(question).unsqueeze(0).to(device)
    k_tensor = torch.stack([encode(k) for k, _ in memory]).unsqueeze(0).to(device)
    v_tensor = torch.stack([encode(v) for _, v in memory]).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(q_tensor, k_tensor, v_tensor)
        all_values = model.get_value_embeddings(v_tensor.squeeze(0))
        similarity = torch.matmul(output, all_values.t())
        # argmax over the flattened scores. This equals the memory-slot index
        # only if ``output`` has a single row. NOTE(review): confirm the shape.
        pred_idx = torch.argmax(similarity).item()

    # Return the raw field text, not its multi-hot tensor.
    return memory[pred_idx][1]

# Test some questions.
# Each entry is (question, person). `person` must match a DB key exactly; the
# example names below are lowercase.
# NOTE(review): DB is built from the train split only and keeps only people
# with an "office" field, so any of these may print
# "Person not found in database." Verify membership first.
test_questions = [
    ("When was Alexander Hamilton born?", "alexander hamilton"),
    ("What was Alexander Hamilton's party?", "alexander hamilton"),
    ("Where was George Washington born?", "george washington"),
    ("What was Abraham Lincoln's occupation?", "abraham lincoln")
]

print("Testing question answering:")
for question, person in test_questions:
    print(f"\nQ: {question}")
    answer = answer_question(question, person, model, DB, VOCAB, device)
    print(f"A: {answer}")