In [1]:
from datasets import load_dataset, concatenate_datasets

from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling
import random
import numpy as np
from utilities import evaluate_gpt2_classification as evaluate_gpt2_classification, mask_range_gpt,compute_masks, reset_gpt, compute_mask_probe, mask_gpt2
import torch  
from tqdm import tqdm

# --- Experiment configuration ---------------------------------------------
dataset_name = "fancyzhx/dbpedia_14"  # dataset id passed to load_dataset
text_tag = "content"                  # column read as the input text
lab = "label"                         # column read as the class label
num_classes = 14                      # number of classes iterated over below
layer = 11                            # copied onto the model config as m_layer
tables = []

# Thresholds tried previously (kept for reference):
# tao = 2.5
# tao = torch.inf

# Load every split of the dataset and show its structure
dataset = load_dataset(dataset_name)
print(dataset)

# print(dataset['train'].features)






#######################Filter dataset####################
from datasets import DatasetDict, Dataset, Features, ClassLabel, Value
import pandas as pd

def _take_up_to(group, limit, seed=42):
    """Return at most ``limit`` rows of ``group``, drawn without replacement."""
    if group.empty:
        return group
    # n never exceeds len(group), so replacement is never required; sampling
    # with replacement here would duplicate some rows and silently drop others.
    return group.sample(n=min(len(group), limit), replace=False, random_state=seed)


def _rows_for_label(groups, frame, label_idx):
    """Return the rows of one class, or an empty frame if the class is absent."""
    if label_idx in groups.groups:
        return groups.get_group(label_idx)
    # iloc[0:0] keeps the column dtypes, unlike pd.DataFrame(columns=...)
    return frame.iloc[0:0]


def sample_balanced_dataset(dataset_dict, max_train_per_class=800, max_test_per_class=200):
    """
    Sample a per-class capped subset while preserving the original feature
    structure including ClassLabel.

    Args:
        dataset_dict: mapping with 'train' and 'test' splits; each split must
            provide ``.features`` (with a 'label' ClassLabel) and ``.to_pandas()``.
        max_train_per_class: upper bound on rows kept per class in 'train'.
        max_test_per_class: upper bound on rows kept per class in 'test'.

    Returns:
        A DatasetDict with 'train' and 'test' splits built with the original
        features. Classes holding fewer rows than the cap are kept whole, so
        the result is exactly balanced only when every class reaches the cap.
    """
    # Keep the original schema so ClassLabel survives the pandas round trip
    original_features = dataset_dict['train'].features
    label_names = original_features['label'].names

    train_df = dataset_dict['train'].to_pandas()
    test_df = dataset_dict['test'].to_pandas()

    train_groups = train_df.groupby('label')
    test_groups = test_df.groupby('label')

    sampled_train_dfs = []
    sampled_test_dfs = []

    print("\nClass distribution:")
    print("\nLabel | Label Name | Train Samples | Test Samples | Final Train | Final Test")
    print("-" * 85)

    for idx, label_name in enumerate(label_names):
        # A class may be missing from either split; treat it as zero rows
        train_group = _rows_for_label(train_groups, train_df, idx)
        test_group = _rows_for_label(test_groups, test_df, idx)

        sampled_train = _take_up_to(train_group, max_train_per_class)
        sampled_test = _take_up_to(test_group, max_test_per_class)

        sampled_train_dfs.append(sampled_train)
        sampled_test_dfs.append(sampled_test)

        print(f"{idx:5d} | {label_name:10s} | {len(train_group):12d} | "
              f"{len(test_group):11d} | {len(sampled_train):10d} | {len(sampled_test):9d}")

    # Stitch the per-class samples back together with a fresh index
    final_train_df = pd.concat(sampled_train_dfs, ignore_index=True)
    final_test_df = pd.concat(sampled_test_dfs, ignore_index=True)

    # Convert back to datasets while preserving the original features
    final_train_dataset = Dataset.from_pandas(final_train_df, features=original_features)
    final_test_dataset = Dataset.from_pandas(final_test_df, features=original_features)

    sampled_dataset = DatasetDict({
        'train': final_train_dataset,
        'test': final_test_dataset
    })

    print("\nFinal dataset sizes:")
    print(f"Train: {len(final_train_dataset)} samples")
    print(f"Test: {len(final_test_dataset)} samples")

    # Print the schema so the caller can confirm it was preserved
    print("\nVerifying feature structure:")
    print(sampled_dataset['train'].features)

    return sampled_dataset

# dataset = sample_balanced_dataset(dataset, max_train_per_class=800, max_test_per_class=200)

###########################################



# Seed every random number generator the script relies on
seed_value = 42  # or any other integer

random.seed(seed_value)
np.random.seed(seed_value)

# Seed PyTorch unconditionally: the CPU generator needs seeding even when no
# GPU is present (previously this only happened when CUDA was available).
torch.manual_seed(seed_value)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)

# Anomaly detection is a debugging aid for autograd; the PyTorch docs note it
# slows execution. NOTE(review): confirm it is still wanted for this run.
torch.autograd.set_detect_anomaly(True)
# Build the GPT-2 tokenizer, then register the class names and the
# pad/sep/eos markers used by the prompt format.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

special_tokens_dict = {}
label2text = dataset['train'].features[lab].names
new_tokens = []

for label in label2text:
    # Report class names the base vocabulary already encodes as one token
    label_tokens = tokenizer.encode(label, add_special_tokens=False)
    is_single_token = len(label_tokens) == 1
    if is_single_token:
        print(f"'{label}' is already a single token (ID: {label_tokens[0]})")

    # Every class name is queued, whether or not it was reported above
    special_token = f'{label}'
    new_tokens.append(special_token)

# Register the queued class names; add_tokens returns how many were added
num_added_tokens = tokenizer.add_tokens(new_tokens)
print(f"\nAdded {num_added_tokens} new tokens to the tokenizer")

special_tokens = dict(
    pad_token='<|pad|>',
    sep_token='<|sep|>',
    eos_token='<|eos|>',
)
tokenizer.add_special_tokens(special_tokens)

def format_data(examples):
    """Build the classification prompt for every example in a batch.

    Each text is cut to at most 400 tokens (encode with truncation, then
    decode back to a string) and wrapped as
    ``"Classify emotion: <text><sep_token>"``. The label is not appended to
    the prompt.

    Args:
        examples: batch mapping whose ``text_tag`` entry is a list of strings.

    Returns:
        dict with a single key ``'formatted_text'`` holding the prompts.
    """
    # NOTE(review): the prompt says "emotion" although the data is DBpedia;
    # presumably the saved weights were trained with this exact wording, so
    # it is kept byte-for-byte - confirm before changing.
    formatted_texts = []
    for text in examples[text_tag]:
        token_ids = tokenizer.encode(text, max_length=400, truncation=True)
        truncated_text = tokenizer.decode(token_ids)
        formatted_texts.append(f"Classify emotion: {truncated_text}{tokenizer.sep_token}")
    return {'formatted_text': formatted_texts}

def tokenize_and_prepare(examples):
    """Tokenize a batch of prompts and derive the label tensor.

    Prompts are padded/truncated to 408 tokens. Labels are -100 everywhere
    except at the position right after each separator token, which copies the
    input id found there; positions holding the pad token id are then reset
    to -100 as well.

    Args:
        examples: batch mapping with a ``'formatted_text'`` list of strings.

    Returns:
        dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
    """
    encoded = tokenizer(
        examples['formatted_text'],
        padding='max_length',
        max_length=408,
        truncation=True,
        return_tensors="pt"
    )
    input_ids = encoded['input_ids']

    # Locate every separator token as (row, column) index tensors
    sep_token_id = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)
    rows, cols = (input_ids == sep_token_id).nonzero(as_tuple=True)

    # Start fully masked, then expose only the token following each separator
    labels = torch.full_like(input_ids, -100)
    has_successor = cols + 1 < input_ids.size(1)
    rows, next_cols = rows[has_successor], cols[has_successor] + 1
    labels[rows, next_cols] = input_ids[rows, next_cols]

    # A pad token right after the separator must not count as a target
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        'input_ids': input_ids,
        'attention_mask': encoded['attention_mask'],
        'labels': labels
    }
    
# Drop rows whose label is -1 before any prompt is built
dataset = dataset.filter(function=lambda x: x[lab] != -1)

# Two batched passes: build the prompt strings, then tokenize them
formatted_dataset = dataset.map(function=format_data, batched=True)
tokenized_dataset = formatted_dataset.map(function=tokenize_and_prepare, batched=True)

from transformers import GPT2LMHeadModel as gt
from models.gpt2 import GPT2LMHeadModel
import os

# The stock Hugging Face GPT-2 is loaded only to obtain a config: embeddings
# are resized to the extended tokenizer and the target layer is recorded on
# the config before the project model is built from it.
model1 = gt.from_pretrained('gpt2')
model1.resize_token_embeddings(len(tokenizer))
model1.config.m_layer = layer

# Weights live under model_weights/<dataset_name>/weights.pth
base_path = os.path.join("model_weights", dataset_name)
os.makedirs(base_path, exist_ok=True)  # no separate exists() check needed
weights_path = os.path.join(base_path, "weights.pth")

model = GPT2LMHeadModel(model1.config)

# map_location="cpu" lets weights saved on a GPU load on a CPU-only host;
# weights_only=True restricts unpickling to tensors/containers, which is all
# a state dict needs, and avoids the torch.load FutureWarning.
model.load_state_dict(
    torch.load(weights_path, map_location="cpu", weights_only=True)
)
# NOTE(review): a model built from a config starts in training mode unless
# the evaluation helper switches it - confirm model.eval() is applied there.




from prettytable import PrettyTable
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch.tensor")

# Integer division: 2048/4 produced the float 512.0, which is not a valid
# batch size wherever an int is required.
batch_size = 2048 // 4
# mask_layer = 5
compliment = True  # also evaluate on the complement (all other classes)
results_table = PrettyTable()
if compliment:
    results_table.field_names = ["Class", "Base Accuracy", "Base Confidence", "Base Complement Acc", "Base Compliment Conf", "STD Accuracy", "STD Confidence", "STD compliment ACC", "STD compliment Conf", "MAX Accuracy", "MAX Confidence", "Max compliment acc", "Max compliment conf", "Total Masked", "Intersection"]

# Per-class result accumulators
class_labels = []
base_accuracies = []
base_confidences = []
base_comp_acc = []
base_comp_conf = []
std_masked_counts = []
std_accuracies = []
std_confidences = []
std_comp_acc = []
std_comp_conf = []
max_masked_counts = []
max_accuracies = []
max_confidences = []
max_comp_acc = []
max_comp_conf = []
diff_from_max = []
total_masked = []

# The dataset's own splits are used as-is (an earlier merge/shuffle/re-split
# variant was disabled): 'test' is evaluated, 'train' is used for recording.
tokenized_dataset1 = tokenized_dataset['test']#.shuffle().select(range(2000))
recording_dataset = tokenized_dataset['train']#.shuffle().select(range(2000))

    



# Per-class recordings from the train split plus unmasked baseline scores
# on the test split (class subset and its complement).
all_fc_vals = []
base_accuracies = []
base_confidences = []
base_comp_acc = []
base_comp_conf = []
print("Recording activations...")
for j in range(num_classes):
    dataset_recording = recording_dataset.filter(lambda x: x[lab] == j)
    dataset = tokenized_dataset1.filter(lambda x: x[lab] == j)
    dataset_complement = tokenized_dataset1.filter(lambda x: x[lab] != j)

    # Element 2 of the result is kept as this class's recorded values
    # (presumably the layer activations - confirm against utilities).
    fc_vals = evaluate_gpt2_classification(lab, model, dataset_recording, tokenizer)[2]
    all_fc_vals.append(np.array(fc_vals))

    # Same evaluation for the class subset, then for its complement
    for subset, acc_store, conf_store, tag in (
        (dataset, base_accuracies, base_confidences, "base accuracy: "),
        (dataset_complement, base_comp_acc, base_comp_conf, "complement base accuracy: "),
    ):
        acc = evaluate_gpt2_classification(lab, model, subset, tokenizer)
        acc_store.append(acc[0])
        conf_store.append(acc[1])
        print("Class ",j, tag, acc[0], acc[1])
    



Using the latest cached version of the dataset since fancyzhx/dbpedia_14 couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'dbpedia_14' at /u/amo-d1/grad/mha361/.cache/huggingface/datasets/fancyzhx___dbpedia_14/dbpedia_14/0.0.0/9abd46cf7fc8b4c64290f26993c540b92aa145ac (last modified on Mon Mar  3 14:49:57 2025).


DatasetDict({
    train: Dataset({
        features: ['label', 'title', 'content'],
        num_rows: 560000
    })
    test: Dataset({
        features: ['label', 'title', 'content'],
        num_rows: 70000
    })
})
'Company' is already a single token (ID: 39154)
'Artist' is already a single token (ID: 43020)
'Building' is already a single token (ID: 25954)
'Animal' is already a single token (ID: 40002)
'Film' is already a single token (ID: 39750)





Added 9 new tokens to the tokenizer


  model.load_state_dict(torch.load(weights_path))


Recording activations...


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

  input_ids = torch.tensor(item['input_ids']).to(device)
  attention_mask = torch.tensor(item['attention_mask']).to(device)


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  0 base accuracy:  0.9756 0.9691


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  0 complement base accuracy:  0.9935 0.992


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  1 base accuracy:  0.9898 0.9868


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  1 complement base accuracy:  0.9924 0.9906


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  2 base accuracy:  0.9894 0.9856


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  2 complement base accuracy:  0.9924 0.9907


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  3 base accuracy:  0.9976 0.9965


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  3 complement base accuracy:  0.9918 0.9899


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  4 base accuracy:  0.9868 0.984


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  4 complement base accuracy:  0.9926 0.9909


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  5 base accuracy:  0.9944 0.994


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  5 complement base accuracy:  0.9921 0.9901


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  6 base accuracy:  0.986 0.9816


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  6 complement base accuracy:  0.9927 0.991


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  7 base accuracy:  0.9964 0.9956


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  7 complement base accuracy:  0.9919 0.99


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  8 base accuracy:  0.9992 0.999


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  8 complement base accuracy:  0.9917 0.9897


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  9 base accuracy:  0.9978 0.9979


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  9 complement base accuracy:  0.9918 0.9898


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  10 base accuracy:  0.9968 0.9963


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  10 complement base accuracy:  0.9919 0.9899


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  11 base accuracy:  0.9944 0.9938


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  11 complement base accuracy:  0.9921 0.9901


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  12 base accuracy:  0.9934 0.9928


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  12 complement base accuracy:  0.9921 0.9902


Filter:   0%|          | 0/560000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/70000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

Class  13 base accuracy:  0.9936 0.9925


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

Class  13 complement base accuracy:  0.9921 0.9902


In [2]:
# Threshold handed to the mask layer's histogram fit
per = 0.05
tables = []
table_thresholds = []  # threshold used for each entry of `tables`
# for i in range(0,20):
#     per = 0.02+(i*0.005)
print("percentage: ", per)

# model.transformer.mask_layer.fit_kde(all_fc_vals, threshold=per)

model.transformer.mask_layer.fit_histogram(all_fc_vals, threshold=per, num_bins=1000)

results_table = PrettyTable()
if compliment:
    results_table.field_names = ["Class", "Base Accuracy", "Base Confidence", "Base Complement Acc", "Base Compliment Conf", "STD Accuracy", "STD Confidence", "STD compliment ACC", "STD compliment Conf"]

class_labels = []
std_masked_counts = []
std_accuracies = []
std_confidences = []
std_comp_acc = []
std_comp_conf = []
max_masked_counts = []
max_accuracies = []
max_confidences = []
max_comp_acc = []
max_comp_conf = []
diff_from_max = []
total_masked = []

for j in range(num_classes):
    # Select the mask for class j, then evaluate with it active
    model.transformer.mask_layer.set_class(j)
    fc_vals = all_fc_vals[j]
    # model = mask_gpt2(model, torch.ones(768).to('cuda'))
    # Only the test-split subsets are needed here; the train-split filter the
    # loop used to run on every iteration was never read and has been dropped.
    dataset = tokenized_dataset1.filter(lambda x: x[lab] in [j])
    dataset_complement = tokenized_dataset1.filter(lambda x: x[lab] not in [j])

    class_labels.append(f"Class {j}")
    print("Class ",j, "base accuracy: ", base_accuracies[j], base_confidences[j])
    if compliment:
        print("Class ",j, "complement base accuracy: ", base_comp_acc[j], base_comp_conf[j])

    acc = evaluate_gpt2_classification(lab, model, dataset, tokenizer)
    print("accuracy after masking STD: ", acc[0], acc[1])
    std_accuracies.append(acc[0])
    std_confidences.append(acc[1])

    if compliment:
        acc = evaluate_gpt2_classification(lab, model, dataset_complement, tokenizer)
        print("accuracy after masking STD on complement: ", acc[0], acc[1])
        std_comp_acc.append(acc[0])
        std_comp_conf.append(acc[1])
        results_table.add_row([
            class_labels[j],
            base_accuracies[j],
            base_confidences[j],
            base_comp_acc[j],
            base_comp_conf[j],
            std_accuracies[j],
            std_confidences[j],
            std_comp_acc[j],
            std_comp_conf[j],
        ])
# NOTE(review): the mask layer stays set to the last class after this loop;
# confirm whether it should be reset before any further evaluation.
print(results_table)
tables.append(results_table)
table_thresholds.append(per)

# Report each table with the threshold it was produced with. The old loop
# printed a fabricated 0.01-step counter and overwrote `per`.
for threshold, table in zip(table_thresholds, tables):
    print("percentage: ", threshold)
    print(table)

percentage:  0.05
Class  0 base accuracy:  0.9756 0.9691
Class  0 complement base accuracy:  0.9935 0.992


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.0246 0.0193


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9893 0.9794
Class  1 base accuracy:  0.9898 0.9868
Class  1 complement base accuracy:  0.9924 0.9906


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.1866 0.0858


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9735 0.9593
Class  2 base accuracy:  0.9894 0.9856
Class  2 complement base accuracy:  0.9924 0.9907


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.011 0.0084


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9846 0.9788
Class  3 base accuracy:  0.9976 0.9965
Class  3 complement base accuracy:  0.9918 0.9899


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.0798 0.0487


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9825 0.9623
Class  4 base accuracy:  0.9868 0.984
Class  4 complement base accuracy:  0.9926 0.9909


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.0582 0.0407


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9398 0.9085
Class  5 base accuracy:  0.9944 0.994
Class  5 complement base accuracy:  0.9921 0.9901


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.0544 0.0378


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9786 0.961
Class  6 base accuracy:  0.986 0.9816
Class  6 complement base accuracy:  0.9927 0.991


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.0164 0.0121


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9854 0.9799
Class  7 base accuracy:  0.9964 0.9956
Class  7 complement base accuracy:  0.9919 0.99


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.0498 0.0241


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9716 0.9392
Class  8 base accuracy:  0.9992 0.999
Class  8 complement base accuracy:  0.9917 0.9897


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.0942 0.0552


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9786 0.9476
Class  9 base accuracy:  0.9978 0.9979
Class  9 complement base accuracy:  0.9918 0.9898


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.0526 0.0379


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.98 0.9552
Class  10 base accuracy:  0.9968 0.9963
Class  10 complement base accuracy:  0.9919 0.9899


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.1436 0.0816


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9885 0.9778
Class  11 base accuracy:  0.9944 0.9938
Class  11 complement base accuracy:  0.9921 0.9901


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.1174 0.0616


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9874 0.9723
Class  12 base accuracy:  0.9934 0.9928
Class  12 complement base accuracy:  0.9921 0.9902


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.0302 0.0238


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9722 0.9467
Class  13 base accuracy:  0.9936 0.9925
Class  13 complement base accuracy:  0.9921 0.9902


Evaluating:   0%|          | 0/40 [00:00<?, ?it/s]

accuracy after masking STD:  0.0974 0.0528


Evaluating:   0%|          | 0/508 [00:00<?, ?it/s]

accuracy after masking STD on complement:  0.9638 0.9293
+----------+---------------+-----------------+---------------------+----------------------+--------------+----------------+--------------------+---------------------+
|  Class   | Base Accuracy | Base Confidence | Base Complement Acc | Base Compliment Conf | STD Accuracy | STD Confidence | STD compliment ACC | STD compliment Conf |
+----------+---------------+-----------------+---------------------+----------------------+--------------+----------------+--------------------+---------------------+
| Class 0  |     0.9756    |      0.9691     |        0.9935       |        0.992         |    0.0246    |     0.0193     |       0.9893       |        0.9794       |
| Class 1  |     0.9898    |      0.9868     |        0.9924       |        0.9906        |    0.1866    |     0.0858     |       0.9735       |        0.9593       |
| Class 2  |     0.9894    |      0.9856     |        0.9924       |        0.9907        |    0.011     |  