## Load data and chose categories

In [1]:
import json
import torch
import pickle

In [2]:
with open('data/squad2-train.json', 'r') as file: #place squad2-train.json in data folder. Download from https://rajpurkar.github.io/SQuAD-explorer/
    train_data = json.load(file)

In [3]:
categories = []

In [4]:
for i, category in enumerate(train_data['data']):
    title = category['title']
    question_count = 0
    context_count = len(category['paragraphs'])
    for paragraph in category['paragraphs']:
        question_count += len(paragraph['qas'])
    categories.append((i, title, question_count, context_count))

In [6]:
sorted(categories, key=lambda x: x[2], reverse=True)

[(171, 'Queen_Victoria', 883, 42),
 (7, 'New_York_City', 817, 148),
 (12, 'American_Idol', 790, 127),
 (0, 'Beyoncé', 753, 66),
 (1, 'Frédéric_Chopin', 697, 82),
 (11, 'Buddhism', 610, 149),
 (168, 'Pharmaceutical_industry', 586, 44),
 (250, 'New_Haven,_Connecticut', 582, 66),
 (380, 'Premier_League', 551, 40),
 (438, 'Hunting', 531, 36),
 (152, 'Antarctica', 525, 44),
 (6, '2008_Sichuan_earthquake', 521, 77),
 (107, 'Houston', 521, 48),
 (351, 'Steven_Spielberg', 512, 49),
 (291, 'PlayStation_3', 508, 44),
 (390, 'Alfred_North_Whitehead', 502, 47),
 (14, '2008_Summer_Olympics_torch_relay', 500, 74),
 (356, 'Charleston,_South_Carolina', 490, 48),
 (144, 'Macintosh', 487, 49),
 (349, 'Muammar_Gaddafi', 485, 76),
 (166, 'Multiracial_American', 481, 44),
 (383, 'San_Diego', 475, 49),
 (5, 'Spectre_(2015_film)', 462, 43),
 (379, 'Greeks', 458, 44),
 (91, 'Middle_Ages', 452, 93),
 (217, 'Yale_University', 452, 46),
 (365, 'Modern_history', 448, 88),
 (306, 'Mandolin', 446, 43),
 (172, 'Free

In [7]:
category_indices = [0, 1, 168, 6, 152, 438]

In [8]:
from dataclasses import dataclass
from typing import Optional

@dataclass
class Question:
    question: str
    context_index: int
    embedding: Optional[torch.tensor] = None
    transformed_embedding: Optional[torch.tensor] = None

@dataclass
class Context:
    context: str
    context_index: int
    embedding: Optional[torch.tensor] = None
    transformed_embedding: Optional[torch.tensor] = None

@dataclass
class DataCollection:
    questions: list[Question]
    contexts: list[Context]
    metadata: dict


In [13]:
for index in category_indices:
    question_collection: list[Question] = []
    context_collection: list[Context] = []
    for i, context in enumerate(train_data['data'][index]['paragraphs']):
        context_collection.append(Context(context=context['context'], context_index=i))
        for question in context['qas']:
            if question['is_impossible']:
                continue
            question_collection.append(Question(question=question['question'], context_index=i))
        # print(i, context['qas'][0])
        
    data_collection = DataCollection(
        questions=question_collection,
        contexts=context_collection,
        metadata={
            "description": "Raw question data",
            "category":train_data['data'][index]['title']
        }
    )

    with open(f"data/raw/{train_data['data'][index]['title']}-base.pkl", 'wb') as file:
        pickle.dump(data_collection, file)

In [10]:
question_collection: list[Question] = []
context_collection: list[Context] = []

In [11]:
train_data['data'][6]['paragraphs'][0]

{'qas': [{'question': 'In what year did the earthquake in Sichuan occur?',
   'id': '56cdca7862d2951400fa6826',
   'answers': [{'text': '2008', 'answer_start': 4}],
   'is_impossible': False},
  {'question': 'What was the earthquake named?',
   'id': '56cdca7862d2951400fa6827',
   'answers': [{'text': 'the Great Sichuan earthquake', 'answer_start': 31}],
   'is_impossible': False},
  {'question': 'How many people were killed as a result?',
   'id': '56cdca7862d2951400fa6828',
   'answers': [{'text': '69,197', 'answer_start': 206}],
   'is_impossible': False},
  {'question': 'What year did the Sichuan earthquake take place?',
   'id': '56d4f9902ccc5a1400d833c0',
   'answers': [{'text': '2008', 'answer_start': 4}],
   'is_impossible': False},
  {'question': 'What did the quake measure?',
   'id': '56d4f9902ccc5a1400d833c1',
   'answers': [{'text': '8.0 Ms and 7.9 Mw', 'answer_start': 73}],
   'is_impossible': False},
  {'question': 'What day did the earthquake occur?',
   'id': '56d4f990

In [12]:
for i, context in enumerate(train_data['data'][category_index]['paragraphs']):
    context_collection.append(Context(context=context['context'], context_index=i))
    for question in context['qas']:
        if question['is_impossible']:
            continue
        question_collection.append(Question(question=question['question'], context_index=i))
    print(i, context['qas'][0])


NameError: name 'category_index' is not defined

In [11]:
data_collection = DataCollection(
    questions=question_collection,
    contexts=context_collection,
    metadata={
        "description": "Raw question data",
        "category":train_data['data'][category_index]['title']
    }
)

In [13]:
data_collection.metadata

{'description': 'Raw question data', 'category': 'Beyoncé'}