# Difficult Eval
I evaluated the models in `fine_tuning` on different test datasets. This is not good, because the results are not directly comparable. I will re-evaluate against a single test dataset and match the order of the errors (thankfully, I used the same seed to shuffle the data).

In [9]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import ast
import pyarrow as pa
from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
from collections import Counter
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

In [10]:
def make_binary_dataset(dataset, error_type='correct'):
    """Collapse the labels into two classes and undersample to a balanced dataset.

    Every example whose ``error_type`` field contains ``error_type`` is
    relabelled ``'correct'``; every other example is relabelled ``'incorrect'``.
    Both classes are then undersampled to the size of the smaller one.

    Args:
        dataset: dataset -- dataset with an ``error_type`` column
        error_type: str -- label that is mapped to ``'correct'``

    Returns:
        dataset: dataset -- binary dataset with equally many ``'correct'`` and
            ``'incorrect'`` examples (empty if either class is absent)
    """
    def map_to_binary(x):
        # `in` is a substring test when the field is a string and a
        # membership test when the field is a list of labels.
        x['error_type'] = 'correct' if error_type in x['error_type'] else 'incorrect'
        return x

    # Apply the binary mapping
    binary_dataset = dataset.map(map_to_binary)

    # Count both classes in a single pass. Sampling `num_correct` examples per
    # class only balances the data when 'correct' is the minority class, so
    # use the size of the smaller class instead.
    counts = Counter(binary_dataset['error_type'])
    n_per_class = min(counts['correct'], counts['incorrect'])

    # Undersample to ensure balanced classes
    return undersampling(binary_dataset, error_types=['correct', 'incorrect'], n=n_per_class)


def undersampling(dataset, error_types=['correct', 'intrinsic-NP', 'intrinsic-predicate', 'extrinsic-NP', 'extrinsic-predicate'],
                    n=400):
    """Keep at most ``n`` randomly chosen examples of each listed class.

    Args:
        dataset: dataset -- dataset with an ``error_type`` column
        error_types: list -- class labels to keep; other labels are dropped
        n: int -- maximum number of examples kept per class; classes with
            fewer than ``n`` examples are kept in full

    Returns:
        dataset: dataset -- the per-class samples concatenated and shuffled
            with a fixed seed (42)
    """
    # Start from an empty dataset and append one class sample at a time.
    balanced = Dataset.from_dict({
        'doc': [],
        'summ': [],
        'error_type': []
    })

    for label in error_types:
        members = dataset.filter(lambda x: x['error_type'] == label)
        keep = min(n, len(members))
        # Fixed seed so the same examples are picked on every run.
        picked = members.shuffle(seed=42).select(range(keep))
        balanced = concatenate_datasets([balanced, picked])

    # Mix the classes together, again with a fixed seed.
    return balanced.shuffle(seed=42)


def oversampling(dataset, error_types=['correct', 'intrinsic-NP', 'intrinsic-predicate', 'extrinsic-NP', 'extrinsic-predicate'], n=2330):
    """Resample every listed class to exactly ``n`` examples.

    Each class is repeated in full ``n // size`` times and topped up with
    ``n % size`` examples drawn from a seeded shuffle of that class. A class
    larger than ``n`` is therefore reduced to ``n`` examples, and a class with
    no examples contributes nothing.

    Args:
        dataset: dataset -- dataset with an ``error_type`` column
        error_types: list -- class labels to keep; other labels are dropped
        n: int -- number of examples per class in the result

    Returns:
        dataset: dataset -- the resampled classes concatenated and shuffled
            with a fixed seed (42)
    """
    def replicate_class(dataset, error_type, n):
        filtered = dataset.filter(lambda x: x['error_type'] == error_type)
        num_examples = len(filtered)

        if num_examples == 0:
            return filtered  # Return empty dataset if no examples

        # Calculate how many times to replicate the dataset
        num_repeats = n // num_examples
        num_remaining = n % num_examples

        # Partial copy that tops the class up to exactly n examples
        remaining = filtered.shuffle(seed=42).select(range(num_remaining))

        # Concatenate the full copies and the remainder in one call. The list
        # is never empty, so this also works when n < num_examples
        # (num_repeats == 0); concatenate_datasets raises ValueError on an
        # empty list.
        return concatenate_datasets([filtered] * num_repeats + [remaining])

    # Initialize an empty dataset for oversampling
    oversampled_dataset = Dataset.from_dict({
        'doc': [],
        'summ': [],
        'error_type': []
    })

    for error_type in error_types:
        oversampled = replicate_class(dataset, error_type, n)
        oversampled_dataset = concatenate_datasets([oversampled_dataset, oversampled])

    # Shuffle the final dataset
    oversampled_dataset = oversampled_dataset.shuffle(seed=42)

    return oversampled_dataset


def reformat_data_split_labels(dataset, dataset_name):
    """Reformats the dataset to have the same format for all datasets for consistency.

    Rows whose ``error_type`` holds several labels (a string containing a
    Python list literal) are split into one row per label. Rows whose label is
    not one of the five standard error types are then dropped.

    Args:
        dataset: dataset -- dataset to reformat
        dataset_name: str -- name of the dataset

    Returns:
        dataset: dataset -- reformatted dataset with columns ``doc``, ``summ``
            and ``error_type``

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    columns = ['doc', 'summ', 'error_type']

    def duplicate_and_label(example):
        """Duplicates examples with multiple error types, assigning one label per duplicate."""
        raw = example['errors']
        if raw is None:
            return []

        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            # A plain label such as 'correct' is not a Python literal.
            # literal_eval raises ValueError for most of them and SyntaxError
            # for text that does not parse at all (e.g. an empty string).
            labels = [raw]
        else:
            if isinstance(parsed, (list, tuple)):
                labels = list(parsed)
            elif isinstance(parsed, str):
                # A quoted single label; do not iterate over its characters.
                labels = [parsed]
            else:
                # Any other literal (number, dict, ...) is kept as the raw value.
                labels = [raw]

        return [{'doc': example['doc'], 'summ': example['summ'], 'error_type': label} for label in labels]

    def process_in_chunks(dataset, chunk_size=10000, map_function=duplicate_and_label):
        """Apply ``map_function`` to every row, ``chunk_size`` rows at a time."""
        processed_chunks = []

        # NOTE(review): this reads the underlying Arrow table directly;
        # presumably an indices mapping left by an earlier shuffle/select is
        # not applied here -- confirm if the input may have been reordered.
        for chunk in dataset.data.to_batches(max_chunksize=chunk_size):
            # Convert the record batch to a pandas DataFrame
            chunk_df = pa.Table.from_batches([chunk]).to_pandas()

            if map_function:
                # Rename the column before splitting lists of errors into separate examples
                chunk_df = chunk_df.rename(columns={'error_type': 'errors'})

                # Flatten with a comprehension: summing a Series of lists is
                # quadratic and returns 0 for an empty chunk.
                flattened_rows = [
                    new_row
                    for record in chunk_df.to_dict('records')
                    for new_row in map_function(record)
                ]

                # Explicit columns keep the schema even when no rows were produced
                chunk_df = pd.DataFrame(flattened_rows, columns=columns)

            processed_chunks.append(chunk_df)

        if not processed_chunks:
            # pd.concat raises on an empty list; return an empty dataset instead
            return Dataset.from_pandas(pd.DataFrame(columns=columns))

        # Combine all processed chunks back into a single DataFrame
        combined_df = pd.concat(processed_chunks, ignore_index=True)

        return Dataset.from_pandas(combined_df)

    if dataset_name == "Lislaam/AggreFact":
        error_types = ['correct', 'intrinsic-NP', 'intrinsic-predicate', 'extrinsic-NP', 'extrinsic-predicate']
        dataset = process_in_chunks(dataset)
        dataset = dataset.filter(lambda x: x['error_type'] in error_types)

    else:
        raise ValueError(f"Dataset {dataset_name} not supported.")
    return dataset

In [11]:
DATA = "Lislaam/AggreFact"

# Fetch the full validation and test splits and pool them so a new split can be made
validation_split, test_split = load_dataset(DATA, split=['validation[:]', 'test[:]'])
dataset = concatenate_datasets([validation_split, test_split])
# One row per error label; rows with non-standard error types are removed
dataset = reformat_data_split_labels(dataset, DATA)

Filter: 100%|██████████| 6540/6540 [00:00<00:00, 237433.98 examples/s]


In [12]:
# Oversample every class, then carve the result into 60/20/20 train/validation/test.
eval_dataset = oversampling(dataset)

# First cut: 20% held out as the test set
train_test = eval_dataset.train_test_split(test_size=0.2)
# Second cut: a quarter of the remaining 80% becomes validation (20% overall)
train_valid = train_test['train'].train_test_split(test_size=0.25)

# Gather the three splits under their usual names
eval_dataset = DatasetDict(
    train=train_valid['train'],
    validation=train_valid['test'],
    test=train_test['test'],
)

# NOTE: `eval` shadows the builtin; the name is kept because get_indices uses it as a default.
eval = eval_dataset['test'].to_pandas()

Filter: 100%|██████████| 5921/5921 [00:00<00:00, 119505.67 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 114659.63 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 120324.40 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 120075.40 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 122049.92 examples/s]


In [13]:
# Undersample every class, then carve the result into 60/20/20 train/validation/test.
under_data = undersampling(dataset)

# First cut: 20% held out as the test set
train_test = under_data.train_test_split(test_size=0.2)
# Second cut: a quarter of the remaining 80% becomes validation (20% overall)
train_valid = train_test['train'].train_test_split(test_size=0.25)

# Gather the three splits under their usual names
under_dataset = DatasetDict(
    train=train_valid['train'],
    validation=train_valid['test'],
    test=train_test['test'],
)

# Pandas view of the undersampled test split
under = under_dataset['test'].to_pandas()


Filter: 100%|██████████| 5921/5921 [00:00<00:00, 121169.19 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 123396.21 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 124467.98 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 122066.72 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 123306.14 examples/s]


In [14]:
# No resampling here: split the full reformatted dataset 60/20/20.
whole_data = dataset

# First cut: 20% held out as the test set
train_test = whole_data.train_test_split(test_size=0.2)
# Second cut: a quarter of the remaining 80% becomes validation (20% overall)
train_valid = train_test['train'].train_test_split(test_size=0.25)

# Gather the three splits under their usual names
whole_dataset = DatasetDict(
    train=train_valid['train'],
    validation=train_valid['test'],
    test=train_test['test'],
)

# Pandas view of the unresampled test split
whole = whole_dataset['test'].to_pandas()

In [15]:
def get_indices(df_subset, df_large=eval):    
    # Step 1: Create a dictionary to map each entry to its position in the larger dataset
    entry_to_index = {entry: idx for idx, entry in enumerate(df_large['data'])}

    # Step 2: Sort the smaller subset by the index positions found in the larger dataset
    df_subset['index'] = df_subset['data'].map(entry_to_index)

    # Step 3: Drop any rows where the entry wasn't found in the larger dataset
    df_subset = df_subset.dropna(subset=['index'])

    # Step 4: Convert index column to integer for sorting (optional but useful for correct ordering)
    df_subset['index'] = df_subset['index'].astype(int)