# Dependencies

In [118]:
from transformers import RobertaModel, RobertaTokenizer, RobertaConfig
import json
import os
import torch

# Data preparation

In [136]:
def create_datasets(default_directory: str, *, remove_bad_files: bool = True):
    """
    Build (previous paragraph, next paragraph, label) triples for each split.

    ``default_directory`` must contain ``easy``, ``medium`` and ``hard``
    subfolders, each holding ``train`` and ``validation`` folders. Every
    ``<name>.txt`` file in those folders is split into paragraphs on newlines.
    Each one is paired with ``truth-<name>.json``, whose ``"changes"`` list must
    hold exactly one label per boundary between adjacent paragraphs.

    Args:
        default_directory: Root directory containing both the training and
            validation data.
        remove_bad_files: If True (the default, matching the historical
            behaviour), delete text/label file pairs whose label count does not
            match the number of paragraph boundaries. If False, such pairs are
            only skipped. In both cases they contribute no entries.

    Returns:
        A dictionary with keys ``"train"`` and ``"validation"``. Each value is a
        list of ``[previous_paragraph, next_paragraph, label]`` entries.

    Raises:
        FileNotFoundError: If a split folder or a ``truth-*.json`` file is missing.
    """
    data_dict = {
        "train": [],
        "validation": [],
    }

    for split in ["train", "validation"]:
        for difficulty in ["easy", "medium", "hard"]:
            # Layout is <root>/<difficulty>/<split>
            current_directory = os.path.join(default_directory, difficulty, split)

            # Sort so the output order is deterministic; os.listdir order is arbitrary.
            for filename in sorted(os.listdir(current_directory)):
                # Only .txt files are documents; label files are looked up from them.
                if not filename.endswith(".txt"):
                    continue

                text_path = os.path.join(current_directory, filename)
                # Swap only the trailing ".txt" suffix, not every occurrence in the name.
                stem = filename[: -len(".txt")]
                label_path = os.path.join(current_directory, "truth-" + stem + ".json")

                # Read as UTF-8 explicitly rather than relying on the platform's
                # locale-dependent default encoding.
                with open(text_path, encoding="utf-8") as f:
                    paragraphs = f.read().strip().split("\n")
                with open(label_path, encoding="utf-8") as f:
                    truth = json.load(f)

                # A missing/null "changes" entry is treated as an empty label list
                # so the length check below handles it instead of len(None) crashing.
                labels = truth.get("changes") or []

                # There must be exactly one label per boundary between paragraphs.
                if len(labels) != len(paragraphs) - 1:
                    if remove_bad_files:
                        os.remove(text_path)
                        os.remove(label_path)
                        print("Removed bad formatted files")
                    # Skip this pair: pairing it anyway would raise IndexError
                    # (too few labels) or attach the wrong labels (too many).
                    continue

                # Pair each paragraph with its successor and the label for that boundary.
                data_dict[split].extend(
                    [previous, following, label]
                    for previous, following, label in zip(paragraphs, paragraphs[1:], labels)
                )
    return data_dict


In [137]:
# Dataset root; create_datasets expects easy/medium/hard subfolders beneath it.
default_directory = "../pan24-multi-author-analysis"
data_dict = create_datasets(default_directory)

In [140]:
# Report how many training paragraph pairs were built.
train_pairs = data_dict.get("train")
print(len(train_pairs))

51962
