# Wikipedia Dataset generation

In this notebook, we will generate a dataset from Wikipedia articles. We will use the `wikipediaapi` library to download the articles and extract their text. We will only consider the most popular English articles to keep the dataset from growing too large.

In [97]:
import wikipediaapi as wiki
from tqdm.notebook import tqdm
import os
import csv
import faiss
import torch
from datasets import load_dataset, Features, Value, Sequence, Dataset
from functools import partial

from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Create user agent
wiki_wiki = wiki.Wikipedia("RUG NLP Q&A", 'en')

## Most Popular Articles

We will use the Popular pages list on Wikipedia to find the most popular articles. That page contains 1079 links to other pages. We will extract the links and then download the articles.

In [27]:
# Get all pages mentioned in this url https://en.wikipedia.org/wiki/Wikipedia:Popular_pages
popular_pages = wiki_wiki.page("Wikipedia:Popular_pages")
popular_pages_links = popular_pages.links

print(f"Found {len(popular_pages_links)} popular pages")

Found 1079 popular pages


We will conform to the medallion standard for the dataset. Thus, we will have raw data, bronze data, silver data, and gold data. The raw data will be the text of the articles. The bronze data will be the processed text of the articles. The silver data will be a CSV file containing the title and the text of each article. The gold data will be a faiss index built from the embeddings of 100-word chunks of each article.
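The four layers map onto four folders under `data/`, and the cells below write into them. A minimal sketch for creating the folders up front is shown here; the `ensure_layers` helper and the `LAYERS` constant are assumptions made for illustration and are not part of the original notebook.

```python
import os

# One folder per medallion layer, in processing order
LAYERS = ("raw", "bronze", "silver", "gold")


def ensure_layers(root="data"):
    """Create one folder per medallion layer and return their paths."""
    paths = {}
    for layer in LAYERS:
        path = os.path.join(root, layer)
        os.makedirs(path, exist_ok=True)  # no error if the folder already exists
        paths[layer] = path
    return paths
```

Calling `ensure_layers()` once before the Raw Data cell would make the later `open(..., "w")` calls safe on a fresh checkout.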

## Raw Data

In [23]:
# Get the text of all the pages as specified in the links above
# Save each page in data/raw, using the sanitized page title as the file name
os.makedirs("data/raw", exist_ok=True)

for page in tqdm(popular_pages_links):
    page = wiki_wiki.page(page)

    # Replace / and ? in the title with _ so it is a valid file name
    file_name = page.title.replace("/", "_").replace("?", "_")

    # If the file already exists, skip it (check the same name that is written below)
    if os.path.exists(f"data/raw/{file_name}.txt"):
        continue

    with open(f"data/raw/{file_name}.txt", "w") as f:
        f.write(page.text)


  0%|          | 0/1079 [00:00<?, ?it/s]
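The cell above only replaces `/` and `?` in page titles. Other characters, such as `:` or `*`, are also rejected by some file systems (notably on Windows). A minimal sketch of a more general sanitizer follows; the `safe_file_name` helper and the chosen character set are assumptions for illustration, not what the notebook actually ran.

```python
import re

# Characters that are reserved in file names on common file systems (assumed set)
_UNSAFE_CHARS = re.compile(r'[\\/:*?"<>|]')


def safe_file_name(title):
    """Replace every unsafe character in a page title with an underscore."""
    return _UNSAFE_CHARS.sub("_", title)
```

Using one helper for both the existence check and the write keeps the two file names identical.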

## Bronze Data

We will remove the See also, References, External links, and Notes sections from the articles.

In [28]:
# Load the data from the raw files, remove the see also, references, external links and notes sections. Then save it in data/bronze
for file in tqdm(os.listdir("data/raw/")):
    with open(f"data/raw/{file}", "r") as f:
        dataset = f.read()

    # Remove the see also, references, external links and notes sections
    dataset = dataset.split("See also")[0]
    dataset = dataset.split("References")[0]
    dataset = dataset.split("External links")[0]
    dataset = dataset.split("Notes ")[0]

    with open(f"data/bronze/{file}", "w") as f:
        f.write(dataset)

  0%|          | 0/1025 [00:00<?, ?it/s]
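Note that `str.split("References")[0]` cuts the text at the first occurrence of the word anywhere in the article, so a sentence that merely mentions "References" also truncates the article there. A more conservative sketch is shown below, assuming that section headings appear on a line of their own in the raw text; the `strip_trailing_sections` helper is hypothetical and was not used to produce the data in this notebook.

```python
import re

# Headings whose sections (and everything after them) should be dropped
TRAILING_SECTIONS = ("See also", "References", "External links", "Notes")

# Match a heading only when it is the sole content of a line
_HEADING_RE = re.compile(
    r"^[ \t]*(?:" + "|".join(re.escape(h) for h in TRAILING_SECTIONS) + r")[ \t]*$",
    flags=re.MULTILINE,
)


def strip_trailing_sections(text):
    """Cut the text at the first line that consists only of a listed heading."""
    match = _HEADING_RE.search(text)
    if match is None:
        return text
    return text[:match.start()].rstrip()
```

With this variant, the word "References" inside a sentence is left alone, and only a standalone heading line triggers the cut.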

## Silver Data

We will create a CSV file containing the title and the text of each article. The `datasets` library can load this format directly. We use tab as the delimiter to prevent any issues with commas in the text, and we replace newlines and tabs inside the text with spaces so that each article stays on a single row.

In [32]:
# Load the text of all the pages in data/bronze and store it in a csv file with the title of the page as the first column and the text as the second column
with open("data/silver/data.csv", "w") as out_file:
    writer = csv.writer(out_file, delimiter="\t", lineterminator="\n")

    writer.writerow(["title", "text"])

    for file in tqdm(os.listdir("data/bronze/")):
        # Use a separate name for the input file so it does not shadow the output file
        with open(f"data/bronze/{file}", "r") as in_file:
            # Flatten newlines and tabs so each article fits on one tab-separated row
            text = in_file.read().replace("\n", " ").replace("\t", " ")

        writer.writerow([file.replace(".txt", ""), text])

  0%|          | 0/1025 [00:00<?, ?it/s]
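To illustrate why the text is flattened before writing, here is a small round trip through an in-memory tab-separated file. This is a sketch on toy data; the `flatten` helper is an assumed name that mirrors the two `replace` calls in the cell above.

```python
import csv
import io


def flatten(text):
    """Collapse newlines and tabs to spaces so a record stays on one line."""
    return text.replace("\n", " ").replace("\t", " ")


# Write a header and one record to an in-memory buffer
buffer = io.StringIO()
writer = csv.writer(buffer, delimiter="\t", lineterminator="\n")
writer.writerow(["title", "text"])
writer.writerow(["Example", flatten("First line,\nwith a comma\tand a tab.")])

# Read the records back with the same delimiter
buffer.seek(0)
rows = list(csv.reader(buffer, delimiter="\t"))
```

The comma survives untouched because it is not the delimiter, and the record occupies exactly one line of the file.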

## Gold Data

We will generate the index using faiss and a context encoder.

### Preparing the data

In [3]:
def split_text(text, max_len=100):
    """
    Split the text using " " into chunks of max_len words
    """
    text = text.split(" ")
    chunks = []
    for i in range(0, len(text), max_len):
        chunks.append(" ".join(text[i:i+max_len]))
        
    return chunks
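As a quick illustration of the chunking behaviour, the sketch below repeats the same logic in compact form and runs it on toy strings with a small `max_len`. It also shows one consequence of splitting on a single space: consecutive spaces produce empty tokens, and these count toward `max_len`.

```python
def split_text(text, max_len=100):
    """Same logic as the cell above: chunks of at most max_len space-separated tokens."""
    words = text.split(" ")
    return [" ".join(words[i:i + max_len]) for i in range(0, len(words), max_len)]


# Seven words with max_len=3 give two full chunks and one shorter chunk
chunks = split_text("one two three four five six seven", max_len=3)

# The double space yields an empty token, so the first chunk holds only one real word
gappy = split_text("a  b c", max_len=2)
```

Because the silver data replaces newlines with spaces, double spaces do occur in the articles, so some chunks contain slightly fewer than 100 real words.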
    

In [4]:
def split_dataset(dataset):
    """
    Split the dataset into chunks of 100 words, returns a dictionary with the title and the text
    """
    titles, texts = [], []
    for title, text in zip(dataset["title"], dataset["text"]):

        if text is None:
            continue

        chunks = split_text(text)
        for chunk in chunks:
            titles.append(title)
            texts.append(chunk)

    return {"title": titles, "text": texts}
        

In [5]:
# Load the dataset and split the text into chunks of 100 words
dataset = load_dataset("csv", data_files="data/silver/data.csv", delimiter="\t")
chunked_dataset = dataset.map(split_dataset, batched=True)
print(f"Split the dataset into {len(chunked_dataset['train']['title'])} chunks")
print(chunked_dataset['train']["text"][0])

Split the dataset into 80567 chunks
The hyphen-minus symbol - is the form of hyphen most commonly used in digital documents. On most keyboards, it is the only character that resembles a minus sign or a dash so it is also used for these. The name hyphen-minus derives from the original ASCII standard, where it was called hyphen–(minus). The character is referred to as a hyphen, a minus sign, or a dash according to the context where it is being used.  Description In early typewriters and character encodings, a single key/code was almost always used for hyphen, minus, various dashes, and strikethrough, since they all


### Generating the embeddings
We will use the `paraphrase-albert-base-v2` model from sentence-transformers to generate the embeddings, chosen for its balance of speed and accuracy.

In [162]:
# Load the pretrained model and tokenizer, we will not train the model yet but use it to encode the text
encoder = SentenceTransformer('sentence-transformers/paraphrase-albert-base-v2').to(device)

model.safetensors:   0%|          | 0.00/46.7M [00:00<?, ?B/s]

In [163]:
# Get the embedding shape by encoding a test string
embedding_shape = encoder.encode("test").shape
print(f"Embedding shape: {embedding_shape}")

new_features = Features(
    {
        "text": Value("string"), 
        "title": Value("string"), 
        "embeddings": Sequence(Value("float32"))
    }
)

Embedding shape: (768,)


In [164]:
# Reduce this amount if you run out of memory
BATCH_SIZE = 64

def embed(dataset, encoder):
    """
    Embed each (text, title) pair using the encoder
    """

    # Named pairs rather than input to avoid shadowing the built-in input()
    pairs = list(zip(dataset["text"], dataset["title"]))
    embeddings = encoder.encode(pairs, show_progress_bar=True, device=device, batch_size=BATCH_SIZE)

    # Free up memory
    torch.cuda.empty_cache()

    return {"title": dataset["title"], "text": dataset["text"], "embeddings": embeddings}

In [165]:
# Embed the text without mapping
embedded_dataset = embed(chunked_dataset["train"], encoder)

Batches:   0%|          | 0/1259 [00:00<?, ?it/s]

### Creating the index

We will use the faiss library to create the index with the inner product as the similarity measure. This is the same measure used in the original paper.
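With an inner-product index, the score of a chunk is the dot product between the query embedding and the chunk embedding, and a higher score means a closer match. The embeddings here are not normalized, so scores are unbounded and longer vectors tend to score higher. The brute-force sketch below illustrates the ranking on toy vectors; `top_k_inner_product` is a hypothetical helper written for this example and says nothing about how faiss is implemented internally.

```python
def inner_product(a, b):
    """Dot product of two equally sized vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k_inner_product(query, vectors, k):
    """Return (score, position) pairs for the k vectors with the largest inner product."""
    scored = [(inner_product(query, v), i) for i, v in enumerate(vectors)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:k]


vectors = [[1.0, 0.0], [0.0, 1.0], [3.0, 3.0]]
query = [1.0, 0.5]
results = top_k_inner_product(query, vectors, k=2)
```

In this toy case the vector `[3.0, 3.0]` ranks first mainly because of its magnitude, which is the behaviour to keep in mind when reading the large scores in the test output further down.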

In [166]:
# Convert to huggingface dataset
index_dataset = Dataset.from_dict(embedded_dataset, features=new_features)

index = faiss.IndexFlatIP(embedding_shape[0])
index_dataset.add_faiss_index("embeddings", custom_index=index)

# Create gold folder if it does not exist
os.makedirs("data/gold", exist_ok=True)

# Save the faiss index (only the index is written, not the dataset itself)
index_dataset.get_index("embeddings").save("data/gold/index.faiss")

  0%|          | 0/81 [00:00<?, ?it/s]

### Testing the index

In [171]:
TOP_K = 5
text = "What is the capital of the Netherlands?"

# Get the embeddings for the question
embeddings = encoder.encode(text, device=device)

# Search the faiss index for the most similar embeddings
D, I = index_dataset.get_index("embeddings").search(embeddings, TOP_K)

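# Get the text, titles and scores of the most similar embeddings
# (the index uses the inner product, so a higher value means a closer match)
for distance, idx in zip(D, I):
    # Use idx rather than index so the faiss index created earlier is not shadowed
    row = index_dataset[int(idx)]
    print(f"Distance: {distance}")
    print(f"Title: {row['title']}")
    print(f"Text: {row['text']}")
    print()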

Distance: 231.4854736328125
Title: Netherlands
Text: The Netherlands, informally Holland, is a country located in northwestern Europe with overseas territories in the Caribbean. It is the largest of the four constituent countries of the Kingdom of the Netherlands. The Netherlands consists of twelve provinces; it borders Germany to the east and Belgium to the south, with a North Sea coastline to the north and west. It shares maritime borders with the United Kingdom, Germany, and Belgium. The official language is Dutch, with West Frisian as a secondary official language in the province of Friesland. Dutch, English, and Papiamento are official in the Caribbean territories.Netherlands literally means

Distance: 217.090576171875
Title: Kingdom of the Netherlands
Text: The Kingdom of the Netherlands (Dutch: Koninkrijk der Nederlanden, pronounced [ˈkoːnɪŋkrɛik dɛr ˈneːdərlɑndə(n)] ), commonly known simply as the Netherlands, is a sovereign state consisting of a collection of constituent terri