### Objective: 
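1. Create folders based on the unique rating values and create empty files inside those folders with filename=article_pageid. (Note: the download worker in this notebook names each saved file after the sanitized article title rather than the page ID.)
2. Download the articles from Wikipedia and save each one as a file, using article_pageid as the page ID to look up through the Wikipedia API.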

#### RUNNING IMPORTS

In [11]:
import pandas as pd
import os
import threading
import wikipedia
import queue
import pickle
import re
from tqdm import tqdm
from glob import glob


### EDIT THIS TAB TO CONFIGURE NOTEBOOK SETTINGS

In [12]:
# CONFIG
# Root directory under which one sub-folder per FOLDER_BY value is created.
BASE_FOLDER = "TEST"
# Tab-separated input table; it must contain the FOLDER_BY and FILE_BY columns.
TSV_FILE = "test_set.tsv"
# Group into folder by this column
FOLDER_BY = "rating"
# Column holding the Wikipedia page ids; each value is passed to
# wikipedia.page(pageid=...) by the download worker.
FILE_BY = "article_pageid"
# Extension appended to the name of every file that is written.
FILE_FORMAT = ".txt"
# THREADS
# Number of Worker threads started by RunWorkers.
THREAD_COUNT = 40
# ERROR FILE
# Pickle file that receives the list of page ids that failed during a run.
ERROR_FILE = "ERROR_LIST.pickle"
# SAFE CHARACTERS IN FILENAMES
# Non-alphanumeric characters that are kept when a page title becomes a file name.
SAFE_CHARACTERS = (".", "_", "-")
# CLEANING CONFIG
# Section headings to drop; they are compared after removing spaces and
# lowercasing, so "See also" also matches "see also" or "SeeAlso".
DELETE_SECTIONS = [
    "References",
    "External links",
    "See also",
    "Further reading",
    "Notes",
    "Bibliography",
    "Sources",
]
# Cleaning level applied to the text before it is saved (levels are listed in
# the string that follows this assignment).
CLEANING_LEVEL = 0
"""
    WARNING:
        Don't clean the data while downloading if you plan to use different cleaning levels later on.
        Cleaning headings will make it impossible to use the clean functions later on.
    CLEANING LEVELS:
        0: No cleaning
        1: Clean All Headings
        2: Delete All Headings
        3: Delete Sections only
        4: Delete Selected Sections and Clean Headings
        5: Delete Selected Sections and Delete All Headings
    RECOMMENDED: 
        0: No Cleaning
        3: Delete Sections only
"""
# END CONFIG

## UTILITY FUNCTIONS TO MAKE CONFIG USABLE


def _headingClean(x):
    """
    Normalize a heading for comparison: strip every space and lowercase it.

    If the value does not support the string operations (for example it is
    not a str), the error is printed and the value is returned unchanged.
    """
    try:
        without_spaces = x.replace(" ", "")
        normalized = without_spaces.lower()
    except Exception as e:
        print("ERROR at clean_heading function:", e)
        return x
    else:
        return normalized


# Normalize the configured section names once so that removeSections can
# compare them directly against normalized headings.
DELETE_SECTIONS = [_headingClean(section) for section in DELETE_SECTIONS]


##### UTILITY FUNCTIONS FOR CLEANING DATA

In [13]:
# Matches one wiki-style heading on a single line: an opening run of two or
# more "=", the heading text, and a closing run at least as long as the
# opening one. The back-reference keeps deep headings such as
# "==== Title ====" in ONE match; a plain "==.*?==+" would split them into two
# bare "====" matches and leave the title text behind.
RE_HEADINGS = re.compile(r"(={2,}).*?\1=*", re.MULTILINE)


def cleanHeadings(x, DEL_HEADINGS=False):
    """
    Strip heading markup from the text, or delete the headings entirely.

    Args:
        x (str): Article text with headings written as "== Title ==".
        DEL_HEADINGS (bool): When True every heading matched by RE_HEADINGS is
            removed together with its title. When False (default) only the
            "=" markers and the space next to them are removed, so the title
            text stays in place.

    Returns:
        str: The cleaned text.

    Warning: after this function has run the headings can no longer be
    detected, so run it only after deleting unwanted sections.
    """
    if DEL_HEADINGS:
        return RE_HEADINGS.sub("", x)
    # Longest markers first, so a deep heading is not left with stray "="
    # characters by a shorter marker matching inside it. Six is the deepest
    # heading level wiki markup defines.
    for depth in range(6, 1, -1):
        x = x.replace("=" * depth + " ", "")
    for depth in range(6, 1, -1):
        x = x.replace(" " + "=" * depth, "")
    return x


def removeSections(x):
    """
    Remove every section whose heading is listed in DELETE_SECTIONS.

    A section starts at its heading and ends where the next heading of any
    level starts (or at the end of the text). Headings are compared after
    dropping "=" and spaces and lowercasing, which is the same normalization
    that is applied to DELETE_SECTIONS.

    Sections are cut out by position. Earlier versions removed them with
    str.replace on the section text, which also deleted identical passages
    elsewhere and could prevent a later section from being found.

    Args:
        x (str): Article text with headings written as "== Title ==".

    Returns:
        str: The text without the selected sections.
    """
    headings = [(m.start(0), m.end(0)) for m in RE_HEADINGS.finditer(x)]
    kept = []
    cursor = 0  # start of the text that has not been copied or skipped yet
    for i, (start, end) in enumerate(headings):
        secname = x[start:end].replace("=", "").replace(" ", "").lower()
        if secname not in DELETE_SECTIONS:
            continue
        # The section runs up to the next heading, or to the end of the text.
        stop = headings[i + 1][0] if i + 1 < len(headings) else len(x)
        kept.append(x[cursor:start])
        cursor = stop
    kept.append(x[cursor:])
    return "".join(kept)


def clean(x):
    """
    Apply the cleaning selected by the CLEANING_LEVEL setting to a text.

    CLEANING LEVELS:
        0: No cleaning
        1: Clean All Headings
        2: Delete All Headings
        3: Delete Sections only
        4: Delete Selected Sections and Clean Headings
        5: Delete Selected Sections and Delete All Headings

    Raises:
        Exception: if CLEANING_LEVEL is not one of the levels above.
    """
    level = CLEANING_LEVEL
    if level not in (0, 1, 2, 3, 4, 5):
        raise Exception("Invalid CLEANING_LEVEL configured. Please check the config.")
    if level == 0:
        return x
    # Levels 3-5 drop the configured sections before touching any heading.
    if level in (3, 4, 5):
        x = removeSections(x)
    if level == 3:
        return x
    # Levels 2 and 5 delete headings; levels 1 and 4 only strip the markers.
    if level in (2, 5):
        return cleanHeadings(x, DEL_HEADINGS=True)
    return cleanHeadings(x)


##### READING THE TSV FILE AND CREATING INFERENCES

In [14]:
# Load the article table and derive the lookups used by the later cells.
df = pd.read_csv(TSV_FILE, sep="\t")
df.head()
# page id -> folder (rating) the page is saved under
file_dict = {page_id: rating for page_id, rating in zip(df[FILE_BY], df[FOLDER_BY])}
# rating -> a page id; later cells only use the keys (the distinct ratings)
folder_dict = {rating: page_id for rating, page_id in zip(df[FOLDER_BY], df[FILE_BY])}
# every distinct page id, in order of first appearance
urllist = df[FILE_BY].unique().tolist()


#### CREATING FOLDERS

In [15]:
# Create the base folder and one sub-folder per distinct FOLDER_BY value.
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it has no
# check-then-create race and also works when BASE_FOLDER is a nested path.
os.makedirs(BASE_FOLDER, exist_ok=True)
for folder in folder_dict:
    folder_path = os.path.join(BASE_FOLDER, folder)
    os.makedirs(folder_path, exist_ok=True)
    print(f"folder created: {folder_path}")


folder created: TEST\FA
folder created: TEST\GA
folder created: TEST\B
folder created: TEST\C
folder created: TEST\Start
folder created: TEST\Stub


### MAIN ASYNC WORKER CLASS AND RUNNER FUNCTION

In [None]:
# Page ids whose download or save failed during the current run.
ERROR_LIST = []


class Worker(threading.Thread):
    """
    Thread that takes Wikipedia page ids from a queue and saves each page.

    For every id the page is fetched with wikipedia.page(pageid=...), its
    content is reduced to ASCII and passed through clean(), and the result is
    written to BASE_FOLDER/<folder>/<title><FILE_FORMAT>, where <folder> comes
    from file_dict and <title> is the page title restricted to alphanumeric
    characters and SAFE_CHARACTERS.

    Ids that fail for any reason are appended to ERROR_LIST. The thread exits
    once the queue has stayed empty for 3 seconds.

    Args:
        q (queue.Queue): Queue of page ids to process.
        *args, **kwargs: Forwarded to threading.Thread.
    """

    def __init__(self, q, *args, **kwargs):
        self.q = q
        super().__init__(*args, **kwargs)

    def run(self):
        while True:
            # Fetch outside the processing try-block so that "work" is always
            # bound when the error handler below records it.
            try:
                work = self.q.get(timeout=3)
            except queue.Empty:
                return
            try:
                print(
                    f"{self.name} working on {work} with {self.q.qsize()} items remaining"
                )
                page = wikipedia.page(pageid=work)
                title = "".join(
                    c for c in page.title if c.isalnum() or c in SAFE_CHARACTERS
                ).rstrip()
                # Characters outside ASCII are dropped, not transliterated.
                content = clean(page.content.encode("ascii", "ignore").decode("ascii"))
                folder = file_dict[work]
                path = os.path.join(BASE_FOLDER, folder, f"{title}{FILE_FORMAT}")
                # The with-block closes the file; no explicit close() needed.
                with open(path, "w") as f:
                    f.write(content)
                print(f"{path} written")
            except Exception as e:
                print(f"{self.name} error", e)
                ERROR_LIST.append(work)
            finally:
                # Always balance the get(); otherwise q.join() in RunWorkers
                # would wait forever for this item.
                self.q.task_done()


def RunWorkers(urllist):
    """
    Download all given pages with a pool of Worker threads, then index the result.

    Steps:
        1. Reset ERROR_LIST and queue every page id.
        2. Start THREAD_COUNT Worker threads and wait until the queue is drained.
        3. Write BASE_FOLDER/index.tsv listing (category, filename) for every
           saved file found under the folders in folder_dict.
        4. Print the number of failures and pickle ERROR_LIST to ERROR_FILE.

    Args:
        urllist (iterable): Page ids to download.
    """
    global ERROR_LIST
    ERROR_LIST = []
    q = queue.Queue()
    for work in urllist:
        q.put_nowait(work)
    for _ in range(THREAD_COUNT):
        Worker(q).start()
    q.join()
    # INDEX DATASET INTO TSV_FILE
    # Match on FILE_FORMAT (the extension the workers write) instead of a
    # hard-coded "*.txt", so the index follows the configuration.
    indexlist = [
        [folder, os.path.basename(file)]
        for folder in folder_dict
        for file in glob(os.path.join(BASE_FOLDER, folder, "*" + FILE_FORMAT))
    ]
    table = pd.DataFrame(indexlist, columns=["category", "filename"])
    table.to_csv(os.path.join(BASE_FOLDER, "index.tsv"), sep="\t", index=True)
    print(f"{len(ERROR_LIST)} errors")
    # Use a with-block so the pickle file is flushed and closed right away.
    with open(ERROR_FILE, "wb") as error_file:
        pickle.dump(ERROR_LIST, error_file)


## RUN THIS FOR FIRST RUN / DOWNLOAD FROM CSV

In [None]:
# Download every distinct page id read from the TSV file. Ids that fail are
# pickled to ERROR_FILE by RunWorkers so they can be retried afterwards.
RunWorkers(urllist)


## RUN THIS FOR ERROR CORRECTION FROM PREVIOUS RUN
Re-run this cell 2-3 times to clear the remaining errors.

In [None]:
# Reload the page ids that failed in the previous run and try them again.
# NOTE: unpickling can execute arbitrary code; only load an ERROR_FILE that
# was written by RunWorkers in this notebook.
with open(ERROR_FILE, "rb") as error_file:
    ERRORS = pickle.load(error_file)
print(f"Errors: {len(ERRORS)}")

RunWorkers(ERRORS)


#### INDEXING DATASET
Run this once to index the dataset if it has already been downloaded. Indexing happens automatically whenever the dataset is downloaded.

In [None]:
# INDEX DATASET INTO TSV_FILE
# Build BASE_FOLDER/index.tsv with one (category, filename) row per saved file.
# Files are matched on FILE_FORMAT, the extension used when they were written,
# instead of a hard-coded "*.txt".
indexlist = [
    [folder, os.path.basename(file)]
    for folder in folder_dict
    for file in glob(os.path.join(BASE_FOLDER, folder, "*" + FILE_FORMAT))
]
table = pd.DataFrame(indexlist, columns=["category", "filename"])
table.to_csv(os.path.join(BASE_FOLDER, "index.tsv"), sep="\t", index=True)

In [49]:
# PYTORCH DATASET CREATION
from torch.utils.data import Dataset,random_split
import torch
import re


class WikipediaDataset(Dataset):
    """
    Dataset of article files listed in an index TSV.

    Each item is a dict {"text": <file content after cleaning>, "label":
    <category>}. The file is read from root_dir/<category>/<filename>, where
    category and filename come from the corresponding row of the index.

    All cleaning uses this instance's own settings (clean_level and
    delete_sections); it does not depend on any notebook-level globals.
    """

    # Sections removed when no delete_sections argument is given.
    DEFAULT_DELETE_SECTIONS = (
        "References",
        "External links",
        "See also",
        "Further reading",
        "Notes",
        "Bibliography",
        "Sources",
    )

    # Matches one wiki-style heading on a single line: an opening run of two
    # or more "=", the heading text, and a closing run at least as long. The
    # back-reference keeps deep headings such as "==== Title ====" in ONE
    # match instead of two bare "====" matches.
    RE_HEADINGS = re.compile(r"(={2,}).*?\1=*", re.MULTILINE)

    def __init__(
        self,
        tsv_file,
        root_dir,
        clean_level=0,
        delete_sections=None,
    ):
        """
        Args:
            tsv_file (string): Path to the tsv file with annotations; it must
                provide the columns "category" and "filename".
            root_dir (string): Database Base Directory.
            clean_level (int, optional): Cleaning level applied to every
                sample, 0-5 (see clean). Defaults to 0, no cleaning.
            delete_sections (list, optional): Section names to delete when the
                clean level deletes sections. Defaults to
                DEFAULT_DELETE_SECTIONS.
        """
        self.wiki_frame = pd.read_csv(tsv_file, sep="\t")
        self.root_dir = root_dir
        self.clean_level = clean_level
        # None instead of a list literal avoids a shared mutable default.
        if delete_sections is None:
            delete_sections = self.DEFAULT_DELETE_SECTIONS
        self.delete_sections = [self._headingClean(name) for name in delete_sections]

    def _headingClean(self, x, args=None):
        """
        Internal helper: strip spaces and lowercase a heading so that
        comparisons can be performed. Values that do not support this are
        reported and returned unchanged. The args parameter is unused.
        """
        try:
            return x.replace(" ", "").lower()
        except Exception as e:
            print("ERROR at clean_heading function:", e)
            return x

    def cleanHeadings(self, x, DEL_HEADINGS=False):
        """
        Strip heading markup from the text, or delete the headings entirely.

        Args:
            x (str): Text with headings written as "== Title ==".
            DEL_HEADINGS (bool): When True every heading matched by
                RE_HEADINGS is removed together with its title. When False
                only the "=" markers and the adjacent space are removed.

        Returns:
            str: The cleaned text.
        """
        if DEL_HEADINGS:
            return self.RE_HEADINGS.sub("", x)
        # Longest markers first, so a deep heading is not left with stray "="
        # characters by a shorter marker matching inside it.
        for depth in range(6, 1, -1):
            x = x.replace("=" * depth + " ", "")
        for depth in range(6, 1, -1):
            x = x.replace(" " + "=" * depth, "")
        return x

    def removeSections(self, x):
        """
        Remove every section whose heading is listed in self.delete_sections.

        A section starts at its heading and ends where the next heading of any
        level starts (or at the end of the text), so a sub-section of a
        removed section is kept unless its own name is listed too. Sections
        are cut out by position, not by searching for their text.
        """
        headings = [(m.start(0), m.end(0)) for m in self.RE_HEADINGS.finditer(x)]
        kept = []
        cursor = 0  # start of the text that has not been copied or skipped yet
        for i, (start, end) in enumerate(headings):
            secname = x[start:end].replace("=", "").replace(" ", "").lower()
            if secname not in self.delete_sections:
                continue
            stop = headings[i + 1][0] if i + 1 < len(headings) else len(x)
            kept.append(x[cursor:start])
            cursor = stop
        kept.append(x[cursor:])
        return "".join(kept)

    def clean(self, x):
        """
        Clean the text according to self.clean_level.

        CLEANING LEVELS:
            0: No cleaning
            1: Clean All Headings
            2: Delete All Headings
            3: Delete Sections only
            4: Delete Selected Sections and Clean Headings
            5: Delete Selected Sections and Delete All Headings

        Raises:
            Exception: if self.clean_level is not one of the levels above.
        """
        # Every branch goes through the instance's own methods, so that
        # self.delete_sections and self.RE_HEADINGS are what is applied.
        if self.clean_level == 0:
            return x
        elif self.clean_level == 1:
            return self.cleanHeadings(x)
        elif self.clean_level == 2:
            return self.cleanHeadings(x, DEL_HEADINGS=True)
        elif self.clean_level == 3:
            return self.removeSections(x)
        elif self.clean_level == 4:
            return self.cleanHeadings(self.removeSections(x))
        elif self.clean_level == 5:
            return self.cleanHeadings(self.removeSections(x), DEL_HEADINGS=True)
        else:
            raise Exception(
                "Invalid clean_level configured. Please reinitialize dataloader."
            )

    def __len__(self):
        """Return the number of rows in the index."""
        return len(self.wiki_frame)

    def __getitem__(self, idx):
        """
        Return {"text": cleaned file content, "label": category} for row idx.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        x = self.wiki_frame.iloc[idx]
        cat, file_name = x["category"], x["filename"]
        # The with-block closes the file handle as soon as it has been read.
        with open(os.path.join(self.root_dir, cat, file_name), "r") as f:
            txt = f.read()
        txt = self.clean(txt)
        sample = {"text": txt, "label": cat}
        return sample


In [None]:
# Root directory of the downloaded dataset (this rebinds BASE_FOLDER from the
# download configuration, now with a trailing slash).
BASE_FOLDER = 'TEST/'
# Index file listing (category, filename) for every downloaded article.
INDEX_PATH = os.path.join(BASE_FOLDER,"index.tsv")
# Cleaning level passed to WikipediaDataset; 4 = delete the selected sections
# and clean the headings.
CLEAN_LEVEL = 4
# Directory named "transformers-cache" under the current working directory.
# NOTE(review): not used in the visible part of the notebook; presumably a
# model cache for a later cell - confirm.
CACHE_DIR = os.path.join(os.getcwd(), 'transformers-cache')

In [None]:
dataset = WikipediaDataset(INDEX_PATH, BASE_FOLDER, clean_level=CLEAN_LEVEL)
length = len(dataset)
print(f"Dataset length: {length}")
# 15% of the samples are held out for testing, 10% of the remainder for
# validation, and whatever is left is used for training.
test_l = int(length * 0.15)
valid_l = int((length - test_l) * 0.1)
train_l = length - test_l - valid_l
# The split is taken over the sample indices (range), not the samples, with a
# fixed seed so that it is reproducible.
train, test, validation = random_split(
    range(len(dataset)),
    [train_l, test_l, valid_l],
    generator=torch.Generator().manual_seed(42),
)
print(f"Train: {len(train)}, Test: {len(test)}, Validation: {len(validation)}")
# Mapping from string label to numeric id, used to convert labels for training.
labels_ids = {
    label: index
    for index, label in enumerate(("B", "C", "FA", "GA", "Start", "Stub"))
}

# Number of labels used in training; it decides the size of the
# classification head.
n_labels = len(labels_ids)