# In [2]:
import os
import pandas as pd
from datasets import DatasetDict, Dataset, Features, Value, ClassLabel
from collections import defaultdict
import numpy as np

def classification_dataset(
    data_dir: str, target_author: str = "Hugo", binary_classification: bool = True
) -> DatasetDict:
    """Build shuffled train/validation/test splits for author classification.

    ``data_dir`` must contain one sub-directory per author. Every regular file
    inside an author directory is read as UTF-8 text, and each non-blank line
    of a file becomes one example (surrounding whitespace, including the
    trailing newline, is stripped).

    Args:
        data_dir: Directory holding one sub-directory per author.
        target_author: Author treated as the positive class in binary mode.
            Ignored when ``binary_classification`` is False.
        binary_classification: If True, label examples 1 for
            ``target_author`` and 0 for every other author. If False, label
            each example with the index of its author in the alphabetically
            sorted list of author directories.

    Returns:
        A ``DatasetDict`` with ``"train"`` (80%), ``"validation"`` (10%) and
        ``"test"`` (10%) splits, each with a string ``text`` column and a
        ``ClassLabel`` ``label`` column.

    Raises:
        ValueError: If no non-blank line of text is found under ``data_dir``.
    """
    import warnings

    authors = sorted(
        name
        for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    )

    # Build the mapping once so the printed mapping, the labels assigned to
    # the examples and the ClassLabel names below can never disagree.
    if binary_classification:
        author_to_label = {
            author: 1 if author == target_author else 0 for author in authors
        }
        class_names = ["Other", target_author]
        if target_author not in author_to_label:
            warnings.warn(
                f"Target author {target_author!r} has no directory in "
                f"{data_dir!r}; every example will be labeled 'Other'.",
                stacklevel=2,
            )
    else:
        author_to_label = {author: index for index, author in enumerate(authors)}
        class_names = authors

    print(f"Author-to-Label Mapping: {author_to_label}")

    texts = []
    labels = []
    for author in authors:
        author_dir = os.path.join(data_dir, author)
        label = author_to_label[author]

        # os.listdir order is filesystem-dependent; sorting keeps the row
        # order, and therefore the seeded shuffle below, reproducible.
        for file_name in sorted(os.listdir(author_dir)):
            file_path = os.path.join(author_dir, file_name)
            if not os.path.isfile(file_path):
                continue  # nested directories cannot be opened as text

            with open(file_path, "r", encoding="utf-8") as file:
                # Blank separator lines are not paragraphs; skip them rather
                # than emitting empty labeled examples.
                paragraphs = [line.strip() for line in file if line.strip()]

            texts.extend(paragraphs)
            labels.extend([label] * len(paragraphs))

    if not texts:
        raise ValueError(f"No text examples found under {data_dir!r}")

    # Shuffle with a fixed seed, then cut at 80% / 90% of the rows.
    df = (
        pd.DataFrame({"text": texts, "label": labels})
        .sample(frac=1, random_state=42)
        .reset_index(drop=True)
    )
    train_end = int(0.8 * len(df))
    valid_end = int(0.9 * len(df))
    split_frames = {
        "train": df.iloc[:train_end],
        "validation": df.iloc[train_end:valid_end],
        "test": df.iloc[valid_end:],
    }

    features = Features(
        {
            "text": Value("string"),
            "label": ClassLabel(num_classes=len(class_names), names=class_names),
        }
    )

    # Reset each slice's index and do not preserve it, so no index column is
    # carried into the datasets alongside the declared features.
    return DatasetDict(
        {
            split_name: Dataset.from_pandas(
                frame.reset_index(drop=True),
                features=features,
                preserve_index=False,
            )
            for split_name, frame in split_frames.items()
        }
    )