# Difficult Eval
The runs in `fine_tuning` were each evaluated on a different dataset, which is not good: their scores cannot be compared directly. Here I re-evaluate them against a single test dataset and MATCH the order of the error labels (thankfully, I used the same seed to shuffle the data).

In [5]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import ast
import pyarrow as pa
from datasets import load_dataset, concatenate_datasets, Dataset, DatasetDict
from collections import Counter
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from utils import *
from sft import *
from constants import SYSTEM_INSTRUCTION

In [57]:
# Error-type name -> digit string that is appended to the prompt as the target label.
LABEL_CONVERSIONS = {
    "correct": '0',
    "intrinsic-NP": '1',
    "intrinsic-predicate": '2',
    "extrinsic-NP": '3',
    "extrinsic-predicate": '4',
}

# Tokenizer for Llama-3-8B-Instruct; its eos_token terminates prompts built with training=True.
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')

def formatting_prompts_func(example, training=True):
    """Build one prompt string per row of a batched example.

    Args:
        example: batch mapping with parallel 'doc', 'summ' and 'error_type' sequences.
        training: when True, the numeric label from LABEL_CONVERSIONS, " ." and the
            tokenizer's EOS token are appended after "### Output: ".

    Returns:
        list of prompt strings, one per entry of example['error_type'].
    """
    prompts = []
    for idx, error_type in enumerate(example["error_type"]):
        prompt = f"{SYSTEM_INSTRUCTION}\n ### Text1: {example['doc'][idx]}\n ### Text2: {example['summ'][idx]}\n ### Output: "
        if training:
            # Supervised target: label digit, a separator, then end-of-sequence.
            prompt = prompt + f"{LABEL_CONVERSIONS[error_type]} ." + tokenizer.eos_token
        prompts.append(prompt)
    return prompts


def reformat_data_split_labels(dataset, dataset_name='Lislaam/AggreFact'):
    """Expand multi-label rows into one row per error type and keep only the known labels.

    Args:
        dataset: dataset -- dataset to reformat; its Arrow table is read through
            ``dataset.data`` and it is expected to carry 'id', 'doc', 'summ' and
            'error_type' columns
        dataset_name: str -- name of the dataset (only 'Lislaam/AggreFact' is supported)
    Returns:
        dataset: dataset -- reformatted dataset with columns 'id', 'doc', 'summ',
            'error_type', one row per (example, error type) pair
    Raises:
        ValueError: if dataset_name is not supported.
    """
    def duplicate_and_label(example):
        """Duplicates examples with multiple error types, assigning one label per duplicate."""
        raw = example['errors']
        if raw is None:
            return []

        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            # Not a Python literal (e.g. the bare word 'correct', or text that does
            # not even parse): keep the raw value as the single label.
            labels = [raw]
        else:
            if isinstance(parsed, (list, tuple)):
                labels = list(parsed)
            else:
                # A scalar literal is one label; iterating it would split a quoted
                # string into characters or fail on a number.
                labels = [parsed if isinstance(parsed, str) else raw]

        return [
            {'id': example['id'], 'doc': example['doc'], 'summ': example['summ'], 'error_type': label}
            for label in labels
        ]

    def process_in_chunks(dataset, chunk_size=10000, map_function=duplicate_and_label):
        """Apply map_function to every row, one Arrow batch at a time, and rebuild a Dataset."""
        processed_chunks = []

        for chunk in dataset.data.to_batches(max_chunksize=chunk_size):
            # Arrow record batch -> table -> pandas DataFrame
            chunk_df = pa.Table.from_batches([chunk]).to_pandas()

            if map_function:
                # Rename the column before splitting lists of errors into separate examples
                chunk_df = chunk_df.rename(columns={'error_type': 'errors'})

                # Flatten in a single pass; summing a Series of lists re-copies the
                # accumulated list for every row and breaks on an empty chunk.
                flattened_rows = [
                    record
                    for _, row in chunk_df.iterrows()
                    for record in map_function(row.to_dict())
                ]

                # Convert the flattened list of dictionaries to a DataFrame
                chunk_df = pd.DataFrame(flattened_rows)

            processed_chunks.append(chunk_df)

        # Combine all processed chunks back into a single DataFrame
        combined_df = pd.concat(processed_chunks, ignore_index=True)

        return Dataset.from_pandas(combined_df)

    if dataset_name == "Lislaam/AggreFact":
        error_types = ['correct', 'intrinsic-NP', 'intrinsic-predicate', 'extrinsic-NP', 'extrinsic-predicate']
        dataset = process_in_chunks(dataset)
        # Drop rows whose label is not one of the five standard error types.
        dataset = dataset.filter(lambda x: x['error_type'] in error_types)
        #dataset = dataset.filter(lambda x: len(x['doc']) < 1800)
        #dataset = dataset.map(error_type_map)
    else:
        raise ValueError(f"Dataset {dataset_name} not supported.")
    return dataset


def oversampling(dataset, error_types=['correct', 'intrinsic-NP', 'intrinsic-predicate', 'extrinsic-NP', 'extrinsic-predicate'], n=2330):
    """Resample the dataset so every non-empty error type contributes exactly n rows.

    Args:
        dataset: dataset with an 'error_type' column.
        error_types: labels to balance; rows with any other label are dropped.
        n: target number of rows per label.

    Returns:
        A shuffled (seed=42) dataset. Classes with fewer than n rows are repeated
        whole and topped up with a seeded random subset; classes with more than n
        rows are cut down to a seeded random subset of n rows.
    """
    def replicate_class(dataset, error_type, n):
        filtered = dataset.filter(lambda x: x['error_type'] == error_type)
        num_examples = len(filtered)

        if num_examples == 0:
            return filtered  # Return empty dataset if no examples
        if n <= 0:
            return filtered.select(range(0))  # Nothing requested for this class

        # Whole copies of the class, plus a partial copy to reach exactly n rows
        num_repeats, num_remaining = divmod(n, num_examples)

        # Only add parts that exist: when the class is larger than n there are no
        # whole copies, and concatenate_datasets rejects an empty list.
        parts = [filtered] * num_repeats
        if num_remaining:
            parts.append(filtered.shuffle(seed=42).select(range(num_remaining)))

        return concatenate_datasets(parts)

    # Initialize an empty dataset for oversampling
    oversampled_dataset = Dataset.from_dict({
        'doc': [],
        'summ': [],
        'error_type': []
    })
    for error_type in error_types:
        oversampled = replicate_class(dataset, error_type, n)
        oversampled_dataset = concatenate_datasets([oversampled_dataset, oversampled])
    # Shuffle the final dataset
    oversampled_dataset = oversampled_dataset.shuffle(seed=42)
    return oversampled_dataset


def undersampling(dataset, error_types=['correct', 'intrinsic-NP', 'intrinsic-predicate', 'extrinsic-NP', 'extrinsic-predicate'],
                    n=400):
    """Keep at most n rows per error type and return them shuffled.

    Args:
        dataset: dataset with an 'error_type' column.
        error_types: labels to keep; rows with any other label are dropped.
        n: maximum number of rows kept per label.

    Returns:
        Dataset holding the per-class samples, shuffled with seed 42.
    """
    def sample_class(dataset, error_type, n):
        matching = dataset.filter(lambda x: x['error_type'] == error_type)
        # Never ask for more rows than the class has.
        keep = min(n, len(matching))
        return matching.shuffle(seed=42).select(range(keep))

    # Start from an empty dataset and append each class sample in turn.
    balanced = Dataset.from_dict({
        'doc': [],
        'summ': [],
        'error_type': []
    })
    for error_type in error_types:
        balanced = concatenate_datasets([balanced, sample_class(dataset, error_type, n)])

    # Final shuffle so the classes are interleaved.
    return balanced.shuffle(seed=42)


def extract(json_file_path):
    """Read a JSON list of result entries and split it into labels and predictions.

    Args:
        json_file_path: path to a JSON file holding a sequence of objects, each
            with a 'label' and a 'prediction' key.

    Returns:
        (true_labels, predicted_labels): two lists in file order.
    """
    with open(json_file_path, 'r') as handle:
        entries = json.load(handle)

    true_labels = [entry['label'] for entry in entries]
    predicted_labels = [entry['prediction'] for entry in entries]
    return true_labels, predicted_labels


def get_score(predictions, references):
    """Score predictions against references, overall and per error type.

    Each reference is passed through ``preprocess``; the result may be a single
    label or a list of labels. A prediction earns credit for a label when
    ``soft_match(prediction, label)`` is truthy; for a list reference every
    matched label adds 1/len(list) to the overall total.

    Args:
        predictions: indexable sequence of predictions, aligned with references.
        references: iterable of raw reference labels.

    Returns:
        dict with key 'total' (credit divided by the number of references) and
        one key per error type (matched references of that type divided by the
        number of references of that type). A value is None when there is
        nothing to divide by: no references at all for 'total', no reference of
        that type for a per-class score.
    """
    #processed_preds = [preprocess(pred, model) for pred in predictions]
    processed_refs = [preprocess(ref) for ref in references] # Should always be processable
    classes = ('extrinsic-NP', 'extrinsic-predicate', 'intrinsic-NP',
               'intrinsic-predicate', 'correct')

    # Flatten once: list references contribute each of their labels.
    flat_refs = [
        label
        for ref in processed_refs
        for label in (ref if isinstance(ref, list) else [ref])
    ]
    ref_counts = {cls: flat_refs.count(cls) for cls in classes}

    total = 0
    class_hits = dict.fromkeys(classes, 0)
    # Check if any ref is within pred
    for i, ref in enumerate(processed_refs):
        if isinstance(ref, list):
            for label in ref:
                if soft_match(predictions[i], label): # Check if that ref is in the pred
                    total += 1 / len(ref)
                    class_hits[label] += 1
        elif soft_match(predictions[i], ref):
            total += 1
            class_hits[ref] += 1

    # Avoid dividing by zero when there are no references at all.
    scores = {'total': total / len(processed_refs) if processed_refs else None}
    for cls in classes:
        scores[cls] = class_hits[cls] / ref_counts[cls] if ref_counts[cls] else None
    return scores

In [20]:
# Load the full validation and test splits of AggreFact.
dataset = load_dataset("Lislaam/AggreFact", split=['validation[:]', 'test[:]'])
dataset = concatenate_datasets([dataset[0], dataset[1]]) # Turn into one dataset to make new split
dataset = reformat_data_split_labels(dataset, "Lislaam/AggreFact") # Get rid of non-standard error_type examples and split data

# Balance the classes with the default target of 2330 rows per error type.
eval_dataset = oversampling(dataset)

# Split the dataset into train and test sets (80% train, 20% test)
# NOTE(review): no seed is passed to train_test_split here, so this split is
# presumably not reproducible across runs -- confirm against the fine-tuning code.
train_test = eval_dataset.train_test_split(test_size=0.2)

# Further split the train set into train and validation sets (75% train, 25% validation of the original 80%)
# Overall this gives 60% train / 20% validation / 20% test.
train_valid = train_test['train'].train_test_split(test_size=0.25)

# Combine the splits into a DatasetDict
eval_dataset = DatasetDict({
    'train': train_valid['train'],
    'validation': train_valid['test'],
    'test': train_test['test']
})

# Add a 'formatted_text' column holding the prompt without the label (training=False).
eval_dataset = eval_dataset.map(
    lambda x: {"formatted_text": formatting_prompts_func(x, False)},
    batched=True,
)

Filter: 100%|██████████| 6540/6540 [00:00<00:00, 199689.50 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 97332.08 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 95852.32 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 95466.54 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 94296.43 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 95085.30 examples/s]
Map: 100%|██████████| 6990/6990 [00:00<00:00, 36255.18 examples/s]
Map: 100%|██████████| 2330/2330 [00:00<00:00, 40135.07 examples/s]
Map: 100%|██████████| 2330/2330 [00:00<00:00, 35849.80 examples/s]


In [21]:
# Display the splits, their columns and their row counts.
eval_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'doc', 'summ', 'error_type', 'formatted_text'],
        num_rows: 6990
    })
    validation: Dataset({
        features: ['id', 'doc', 'summ', 'error_type', 'formatted_text'],
        num_rows: 2330
    })
    test: Dataset({
        features: ['id', 'doc', 'summ', 'error_type', 'formatted_text'],
        num_rows: 2330
    })
})

In [31]:
def match_index(test_set, eval_set=eval_dataset['test']):
    """Locate every row of test_set inside eval_set by its 'id'.

    Args:
        test_set: dataset whose 'id' column is looked up.
        eval_set: dataset searched for those ids. The default is evaluated once,
            when this definition runs, so it refers to the eval_dataset['test']
            that existed at that moment.

    Returns:
        A list with one entry per test_set row: the position in eval_set that
        holds the same 'id' (the last such position when the id occurs more than
        once), or None when the id is not present.
    """
    # id -> position in eval_set; later duplicates overwrite earlier ones.
    position_by_id = {example_id: position for position, example_id in enumerate(eval_set['id'])}

    # Missing ids map to None.
    return [position_by_id.get(example_id) for example_id in test_set['id']]

In [70]:
# Rebuild the oversampled 60/20/20 split a second time, under the name `data`.
data = oversampling(dataset)

# Split the dataset into train and test sets (80% train, 20% test)
# NOTE(review): no seed is passed to train_test_split, so this split presumably
# differs from the one stored in eval_dataset -- confirm.
train_test = data.train_test_split(test_size=0.2)

# Further split the train set into train and validation sets (75% train, 25% validation of the original 80%)
train_valid = train_test['train'].train_test_split(test_size=0.25)

# Combine the splits into a DatasetDict
data = DatasetDict({
    'train': train_valid['train'],
    'validation': train_valid['test'],
    'test': train_test['test']
})

# Add a 'formatted_text' column holding the prompt without the label (training=False).
data = data.map(
    lambda x: {"formatted_text": formatting_prompts_func(x, False)},
    batched=True,
)

Filter: 100%|██████████| 5921/5921 [00:00<00:00, 94349.81 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 94203.79 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 94457.83 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 96503.40 examples/s]
Filter: 100%|██████████| 5921/5921 [00:00<00:00, 95986.43 examples/s]
Map: 100%|██████████| 6990/6990 [00:00<00:00, 42093.83 examples/s]
Map: 100%|██████████| 2330/2330 [00:00<00:00, 38766.05 examples/s]
Map: 100%|██████████| 2330/2330 [00:00<00:00, 40670.90 examples/s]


In [71]:
# Number of rows in data['test'] whose id does not occur in eval_dataset['test'].
sum([1 for i in match_index(data['test']) if i==None])

183

In [72]:
# Load the stored labels and predictions of the naive-oversampling run.
labels, preds = extract("fine_tuning safe copy/meta-llama/Meta-Llama-3-8B-Instruct/naive_oversampling/summary.json")
len(labels), len(preds)

(2330, 2330)

In [73]:
# For each row of data['test']: position of the same id in eval_dataset['test'], or None.
map_index = match_index(data['test'])
trues = []
labs = []

# Keep only the positions whose id was found in eval_dataset['test'].
# NOTE(review): the eval row is read at position l, not at map_index[l], so the
# mapping acts only as a filter here -- confirm this is the intended alignment.
for l in range(len(labels)):
    if map_index[l] != None:
        trues.append(eval_dataset['test'][l]['error_type'])
        labs.append(labels[l])

len(trues), len(labs)

(2147, 2147)

In [74]:
# The stored labels are passed as `predictions` and the eval-set labels as `references`.
get_score(labs, trues)

{'total': 0.20353982300884957,
 'extrinsic-NP': 0.19523809523809524,
 'extrinsic-predicate': 0.1864801864801865,
 'intrinsic-NP': 0.21123595505617979,
 'intrinsic-predicate': 0.2116788321167883,
 'correct': 0.21266968325791855}