In [1]:
import random, re, json
from pathlib import Path
from typing import List, Dict, Optional, Tuple

import wikipedia
import pandas as pd

In [3]:
# Basic configuration
TOPICS = ["Air pollution", "Inflation"]  # topics 
SEED = 42
wikipedia.set_lang("en") # set Wikipedia language English
random.seed(SEED)  # ensure consistent random choices

# Fetch a Wikipedia page by title
# If the title is ambiguous, try the first few suggested options
# Returns a dictionary with title, URL, and content.


In [5]:
def fetch_page(title: str) -> Dict[str, str]:
    try:
        p = wikipedia.page(title, auto_suggest=False)
        return {"title": p.title, "url": p.url, "content": p.content}
    except wikipedia.DisambiguationError as e:
        for opt in e.options[:3]:
            try:
                p = wikipedia.page(opt, auto_suggest=False)
                return {"title": p.title, "url": p.url, "content": p.content}
            except Exception:
                continue
        return {}
    except Exception:
        return {}

# Test the fetch_page function
page_data = fetch_page("Air pollution")

# Check if content was successfully fetched
if page_data:
    print(f"Page title: {page_data['title']}")
    print(f"URL: {page_data['url']}")
    print("\nFirst 500 characters of content:\n")
    print(page_data["content"][:500])  # print only first 500 chars for readability
else:
    print("Could not fetch the page.")


Page title: Air pollution
URL: https://en.wikipedia.org/wiki/Air_pollution

First 500 characters of content:

Air pollution is the presence of substances in the air that are harmful to humans, other living beings or the environment. Pollutants can be gases, like ozone or nitrogen oxides, or small particles like soot and dust. Both outdoor and indoor air can be polluted.
Outdoor air pollution comes from burning fossil fuels for electricity and transport, wildfires, some industrial processes, waste management, demolition and agriculture. Indoor air pollution is often from burning firewood or agricultural 


In [7]:

def clean_text(text: str) -> str:
    # If the input isn't a string, return an empty string
    if not isinstance(text, str):
        return ""
    
    # Split the text into words (this automatically removes extra spaces and newlines)
    words = text.split()

    # Join the words back together with a single space between them
    cleaned_text = " ".join(words)

    # Remove any leftover leading/trailing spaces just in case
    return cleaned_text.strip()

#Test the clean_text function
if page_data:
    raw_text = page_data["content"]
    cleaned = clean_text(raw_text)
    print("\nBefore cleaning:", len(raw_text), "characters")
    print("After cleaning:", len(cleaned), "characters")
    print("\nFirst 300 chars after cleaning:\n", cleaned[:1000])

#if page_data:
    # Clean the content
#    cleaned = clean_text(page_data["content"])

    # Choose a filename (based on title)
#    filename = f"{page_data['title'].replace(' ', '_')}.txt"

    # Save to file
#    with open(filename, "w", encoding="utf-8") as f:
#        f.write(f"Title: {page_data['title']}\n")
#        f.write(f"URL: {page_data['url']}\n\n")
#       f.write(cleaned)

#    print(f"Saved cleaned article to {filename}")
#else:
#    print("Could not fetch the page.")



Before cleaning: 45145 characters
After cleaning: 45004 characters

First 300 chars after cleaning:
 Air pollution is the presence of substances in the air that are harmful to humans, other living beings or the environment. Pollutants can be gases, like ozone or nitrogen oxides, or small particles like soot and dust. Both outdoor and indoor air can be polluted. Outdoor air pollution comes from burning fossil fuels for electricity and transport, wildfires, some industrial processes, waste management, demolition and agriculture. Indoor air pollution is often from burning firewood or agricultural waste for cooking and heating. Other sources of air pollution include dust storms and volcanic eruptions. Many sources of local air pollution, especially burning fossil fuels, also release greenhouse gases that cause global warming. However air pollution may limit warming locally. Air pollution kills 7 or 8 million people each year. It is a significant risk factor for a number of diseases, inclu

In [9]:
def paragraphs_from_content(text: str) -> List[str]:

    # Step 1: Clean up extra spaces
    text = clean_text(text)
    if not text:
        return []

    # Step 2: Remove unwanted sections (and everything after them)
    unwanted_sections = ["Notes", "References", "Further reading", "External links"]
    pattern_sections = r"==\s*(?:" + "|".join(unwanted_sections) + r")\s*==.*"
    text = re.sub(pattern_sections, "", text, flags=re.IGNORECASE | re.DOTALL)

    # Step 3: Remove LaTeX-style math or formulas
    formula_patterns = [
        r"\{\\displaystyle.*?\}",     # Remove {\displaystyle ...}
        r"\\frac\s*\{.*?\}\s*\{.*?\}",# Remove \frac{a}{b}
        r"\\[a-zA-Z]+",               # Remove LaTeX macros (\times, \left, etc.)
        r"[\d]+\s*[/]\s*[\d]+",       # Fractions like 211.080/202.416
        r"[=×]\s*[\d\w\.\-\+/%\s]+",  # Equations like = 4.28% or × 100%
        r"\([^)]*\d+[^)]*\)"          # Parentheses containing numbers (math-style)
    ]

    for fp in formula_patterns:
        text = re.sub(fp, "", text)

    # Step 4: Split by section headers (== Section ==)
    parts = re.split(r"\s*==+.*?==+\s*", text)

    # Step 5: Remove empty or too-short chunks
    paragraphs = [p.strip() for p in parts if p.strip() and len(p.split()) >= 8]

    return paragraphs


In [11]:
def split_into_small_paragraphs(paragraphs: List[str]) -> List[str]:
    """
    Break long paragraphs into smaller chunks (~3 sentences each).
    Remove:
      - Incomplete or cut-off sentences
      - Paragraphs that start with punctuation (e.g., ',', '.', ';', ')')
      - Paragraphs that contain only one sentence
    """
    new_paragraphs = []
    for para in paragraphs:
        # Step 1: Split paragraph into sentences by punctuation (. ? !)
        sentences = re.split(r'(?<=[.!?])\s+', para)

        # Step 2: Keep only valid, complete sentences
        sentences = [s.strip() for s in sentences if s and s[-1] in ".!?"]

        # Step 3: Group sentences into chunks of 3
        for i in range(0, len(sentences), 3):
            chunk_sents = sentences[i:i+3]
            chunk = " ".join(chunk_sents).strip()

            # Step 4: Skip invalid chunks
            if not chunk:
                continue

            # Skip if starts with punctuation (fragment)
            if re.match(r"^[,.';:)\]]", chunk):
                continue

            # Skip if contains only 1 sentence
            if len(chunk_sents) < 2:
                continue

            new_paragraphs.append(chunk)

    return new_paragraphs



In [13]:
def basic_normalize(s: str) -> str:
    if not isinstance(s, str):
        return ""
    s = re.sub(r"\[[^\]]*\]", "", s)          # drop bracketed refs like [12]
    s = re.sub(r"\s+", " ", s)                # collapse whitespace
    return s.strip()

In [15]:
def make_variants(
    topic_paras: Dict[str, List[str]],
    *,
    seed: int = 42,
    min_words: int = 8
) -> Dict[str, pd.DataFrame]:
    rng = random.Random(seed)

    rows = []
    for topic, paras in topic_paras.items():
        for i, p in enumerate(paras, start=1):
            if not isinstance(p, str):
                continue
            p = basic_normalize(p)
            if len(p.split()) < min_words:
                continue
            rows.append({"topic": topic, "text": p, "order": i})

    # de-duplicate by text
    df = pd.DataFrame(rows).drop_duplicates(subset=["text"]).reset_index(drop=True)

    shuffled = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)

    # pair within-topic when possible (fallback to global)
    pairs = []
    groups = {k: v.copy() for k, v in df.groupby("topic")}
    for _, r in df.iterrows():
        pool = groups[r["topic"]]
        pool = pool[pool["text"] != r["text"]]
        if pool.empty:
            pool = df[df["text"] != r["text"]]
            if pool.empty:
                continue
        ref = pool.sample(n=1, random_state=r.name + seed).iloc[0]
        pairs.append({
            "input_topic": r["topic"], "input_text": r["text"],
            "ref_topic": ref["topic"], "ref_text": ref["text"]
        })
    pairs_df = pd.DataFrame(pairs)

    return {"original_df": df, "shuffled_df": shuffled, "pairs_df": pairs_df}



In [17]:
from collections import defaultdict

def build_topic_paragraphs(topics, max_per_topic=200):
    topic_paras = defaultdict(list)
    for t in topics:
        page = fetch_page(t)
        if not page:
            continue
        paras = paragraphs_from_content(page["content"])
        smalls = split_into_small_paragraphs(paras)
        # keep some per topic (you can adjust)
        topic_paras[t] = smalls[:max_per_topic]
    return dict(topic_paras)

topic_paras = build_topic_paragraphs(TOPICS, max_per_topic=400)  # tweak if needed
variants = make_variants(topic_paras, seed=SEED, min_words=8)
variants["original_df"].head()


Unnamed: 0,topic,text,order
0,Air pollution,Air pollution is the presence of substances in...,1
1,Air pollution,Outdoor air pollution comes from burning fossi...,2
2,Air pollution,"Many sources of local air pollution, especiall...",3
3,Air pollution,It is a significant risk factor for a number o...,4
4,Air pollution,"Overall, the World Bank has estimated that wel...",5


In [19]:
from sklearn.model_selection import train_test_split

# (A) Original -> label,text
df_all = variants["original_df"][["topic", "text"]].rename(columns={"topic":"label"})

# (B) Cap to 2,000 total with stratification by label
N_MAX = 2000
if len(df_all) > N_MAX and df_all["label"].nunique() > 1:
    _, df_all_small = train_test_split(
        df_all,
        test_size=N_MAX,
        stratify=df_all["label"],
        random_state=SEED
    )
    df_all = df_all_small.reset_index(drop=True)
else:
    df_all = df_all.sample(frac=1.0, random_state=SEED).reset_index(drop=True)

len(df_all), df_all["label"].value_counts()


(95,
 label
 Inflation        49
 Air pollution    46
 Name: count, dtype: int64)

In [21]:
# 80/20 first
train_df, temp_df = train_test_split(
    df_all,
    test_size=0.2,
    stratify=df_all["label"],
    random_state=SEED
)

# then 10/10 from the 20%
valid_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df["label"],
    random_state=SEED
)

print(len(train_df), len(valid_df), len(test_df))


76 9 10


In [25]:
train_df.to_csv(r"wiki_train.csv", index=False)  # columns: label,text
valid_df.to_csv(r"wiki_test.csv", index=False)
test_df.to_csv(r"wiki_valid.csv",  index=False)