In [None]:
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from datasets import load_dataset
from tqdm import tqdm
import torch
from utilities import get_model_distilbert, record_activations

# --- Run configuration -------------------------------------------------------
mask_layer = 5          # forwarded to get_model_distilbert / record_activations
text_tag = "sentence"   # dataset column holding the raw text
compliment = True       # not read in this cell
# Single source of truth for the batch size. The value actually used by the
# original call was 256 (a separate, unused `batch_size = 32` was declared);
# lower it here if GPU memory is tight.
batch_size = 256
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"

tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
# Use a GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the dataset and keep only the train split.
dataset_all = load_dataset("stanfordnlp/sst2")
dataset_all = dataset_all['train']

model = get_model_distilbert(checkpoint, mask_layer)
model.to(device)
model.eval()

# One entry per label value (0 and 1), in that order; consumed by the
# plotting cells further down.
all_fc_vals = []
for j in range(0, 2):
    dataset = dataset_all.filter(lambda x: x['label'] in [j])
    fc_vals = record_activations(
        dataset,
        model,
        tokenizer,
        text_tag=text_tag,
        batch_size=batch_size,
        mask_layer=mask_layer,
    )
    all_fc_vals.append(fc_vals)

In [None]:
from datasets import load_dataset, concatenate_datasets

from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling
import random
import numpy as np
from utilities import evaluate_gpt2_classification as evaluate_gpt2_classification, mask_range_gpt,compute_masks, reset_gpt
import torch  

dataset_name = "fancyzhx/dbpedia_14"

# Column names read from the dataset.
text_tag = "text"
lab = "label"

tables = []
layer = 11
per = 0.2
print("Percentage: ", per)
# NOTE(review): the saved output of this cell lists features ['text', 'label'],
# 120000 train rows and the labels World/Sports/Business, i.e. it was produced
# with a different dataset than `dataset_name` above. Confirm that
# `dataset_name`, `text_tag` and `num_classes` belong together before re-running.
num_classes = 4

dataset = load_dataset(dataset_name)

print(dataset)

# Seed every RNG in use. torch is seeded unconditionally so CPU-only runs are
# reproducible as well; the CUDA generators are seeded only when present.
seed_value = 42  # or any other integer

random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)

import torch

torch.autograd.set_detect_anomaly(True)

# Load tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

special_tokens_dict = {}
new_tokens = []
label2text = dataset['train'].features[lab].names

# Register every class name as a token. Names that already encode to exactly
# one token are reported; all names are passed to add_tokens either way.
for label in label2text:
    label_tokens = tokenizer.encode(label, add_special_tokens=False)
    if len(label_tokens) == 1:
        print(f"'{label}' is already a single token (ID: {label_tokens[0]})")
    new_tokens.append(label)

# Add the tokens to the tokenizer
num_added_tokens = tokenizer.add_tokens(new_tokens)
print(f"\nAdded {num_added_tokens} new tokens to the tokenizer")

special_tokens = {
    'pad_token': '<|pad|>',
    'sep_token': '<|sep|>',
    'eos_token': '<|eos|>'
}
tokenizer.add_special_tokens(special_tokens)

def format_data(examples):
    """Build the classification prompt for every example of a batch.

    Each text is clipped to at most 400 tokens by an encode/decode round trip
    through the module-level ``tokenizer`` and wrapped as
    ``"Classify emotion: <text><sep_token>"``. The label is deliberately not
    appended to the prompt.

    The function reads only ``examples[text_tag]``; it no longer looks the
    label up through the global ``dataset``, whose value was unused and which
    is rebound later in the notebook.

    Args:
        examples: batch mapping as passed by ``Dataset.map(..., batched=True)``;
            must contain the ``text_tag`` column.

    Returns:
        ``{'formatted_text': [...]}`` with one prompt per input text.
    """
    sep = tokenizer.sep_token
    formatted_texts = []
    for text in examples[text_tag]:
        token_ids = tokenizer.encode(text, max_length=400, truncation=True)
        clipped_text = tokenizer.decode(token_ids)
        # The prompt wording is kept verbatim.
        # NOTE(review): presumably the saved weights were trained on this exact
        # prompt — confirm before changing it.
        formatted_texts.append(f"Classify emotion: {clipped_text}{sep}")
    return {'formatted_text': formatted_texts}

def tokenize_and_prepare(examples):
    """Tokenize a batch of prompts and derive next-token labels.

    The prompts in ``examples['formatted_text']`` are padded/truncated to 408
    tokens. Labels are -100 everywhere except at the position directly after
    each separator token, where the input id is copied; positions whose label
    equals the pad token id are reset to -100 afterwards.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``.
    """
    encoded = tokenizer(
        examples['formatted_text'],
        padding='max_length',
        max_length=408,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded['input_ids']

    # Locate every separator, then keep those that still have a successor.
    sep_id = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)
    rows, cols = (input_ids == sep_id).nonzero(as_tuple=True)
    has_next = cols + 1 < input_ids.size(1)
    rows = rows[has_next]
    next_cols = cols[has_next] + 1

    # Start fully masked and reveal only the token following a separator.
    labels = torch.full_like(input_ids, -100)
    labels[rows, next_cols] = input_ids[rows, next_cols]

    # A revealed padding token carries no signal: mask it again.
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        'input_ids': input_ids,
        'attention_mask': encoded['attention_mask'],
        'labels': labels,
    }
# Process the dataset: build the prompts, then tokenize them.
formatted_dataset = dataset.map(format_data, batched=True)
tokenized_dataset = formatted_dataset.map(
    tokenize_and_prepare,
    batched=True,
)

from transformers import GPT2LMHeadModel as gt
from models.gpt2 import GPT2LMHeadModel

# The stock GPT-2 is loaded only to obtain a config whose vocabulary size
# matches the extended tokenizer; the project model is built from that config.
model1 = gt.from_pretrained('gpt2')
model1.resize_token_embeddings(len(tokenizer))
model1.config.m_layer = layer

import os

base_path = os.path.join("model_weights", dataset_name)
os.makedirs(base_path, exist_ok=True)
weights_path = os.path.join(base_path, "weights.pth")

model = GPT2LMHeadModel(model1.config)

# weights_only=True restricts unpickling to tensors/containers (this is what
# the FutureWarning in the saved output asks for); map_location lets a
# checkpoint written on a GPU load on a CPU-only machine.
state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

dataset_all = tokenized_dataset['train']

# One entry per class, in label order. The per-class subset gets its own name
# so the global `dataset` (the full DatasetDict) is not overwritten.
all_fc_vals = []
for j in range(0, num_classes):
    class_dataset = dataset_all.filter(lambda x: x[lab] in [j])
    fc_vals = evaluate_gpt2_classification(lab, model, class_dataset, tokenizer)
    print('Accuracy : ', fc_vals[0], 'Confidence : ', fc_vals[1])
    # Index 2 of the result is stored as this class's activation record; the
    # original loop computed it but never appended it, leaving the list empty.
    all_fc_vals.append(fc_vals[2])

Percentage:  0.2
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})
'World' is already a single token (ID: 10603)
'Sports' is already a single token (ID: 18153)
'Business' is already a single token (ID: 24749)

Added 1 new tokens to the tokenizer


  model.load_state_dict(torch.load(weights_path))


Filter:   0%|          | 0/120000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/938 [00:00<?, ?it/s]

  input_ids = torch.tensor(item['input_ids']).to(device)
  attention_mask = torch.tensor(item['attention_mask']).to(device)


Accuracy :  0.9682 Confidence :  0.9625


Filter:   0%|          | 0/120000 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [None]:
import numpy as np
from utilities import compute_masks

# Collect the first mask returned by compute_masks (bound to `mask_max` in the
# other cells) for every class, using a fraction argument of 0.30.
# The result is indexed instead of unpacked into a fixed number of names:
# other cells of this notebook unpack 2, 5 or 6 values from compute_masks, so
# a fixed-arity unpack breaks whenever the helper's return shape changes.
max_all = []
for class_activations in all_fc_vals:
    masks = compute_masks(np.array(class_activations), 0.30)
    max_all.append(masks[0])


In [None]:
# Compare the masks of the first two classes: first the number of positions
# where they disagree, then the element-wise agreement itself.
first_mask, second_mask = max_all[0], max_all[1]
print(torch.sum(first_mask != second_mask))
print(first_mask == second_mask)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import VBox, Output

output_widgets = []

for i, v in enumerate(all_fc_vals):
    v = np.array(v)
    m = np.mean(np.abs(v), axis=0)
    s = np.std(v, axis=0)

    min_val = np.min(v, axis=0)
    max_val = np.max(v, axis=0)

    # Scale the std by the per-unit value range. Units whose activation never
    # changes have a zero range; they are mapped to 0 instead of NaN/inf, which
    # would otherwise also poison the y-axis limit below.
    value_range = max_val - min_val
    s = np.divide(
        s - min_val,
        value_range,
        out=np.zeros_like(s, dtype=float),
        where=value_range != 0,
    )

    # One figure per set of values, captured in its own Output widget.
    out = Output()
    with out:
        fig, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=(12, 6))

        # Mean of the absolute activations
        ax_mean.plot(m, 'bo', markersize=4)
        ax_mean.set_title(f'Mean of Activations - Set {i+1}')
        ax_mean.set_xlabel('Activation Index')
        ax_mean.set_ylabel('Mean Value')
        ax_mean.set_ylim(0, np.max(m))  # Ensure y-axis starts at 0

        # Range-scaled standard deviation
        ax_std.plot(s, 'ro', markersize=4)
        ax_std.set_title(f'Standard Deviation of Activations - Set {i+1}')
        ax_std.set_xlabel('Activation Index')
        ax_std.set_ylabel('Standard Deviation')
        ax_std.set_ylim(0, np.max(s))  # Ensure y-axis starts at 0

        fig.tight_layout()
        plt.show()

    output_widgets.append(out)

# Display all figures in a vertical box
VBox(output_widgets)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output

output_widgets = []

for i, v in enumerate(all_fc_vals):
    v = np.array(v)
    m = np.mean(np.abs(v), axis=0)
    s = np.std(v, axis=0)

    min_val = np.min(v, axis=0)
    max_val = np.max(v, axis=0)

    # Range-scale the std; constant units (zero range) map to 0 rather than
    # NaN/inf.
    value_range = max_val - min_val
    s = np.divide(
        s - min_val,
        value_range,
        out=np.zeros_like(s, dtype=float),
        where=value_range != 0,
    )

    # The number of units comes from the data instead of a hard-coded 768.
    n_units = m.shape[0]
    unit_idx = np.arange(n_units)

    # Create a new figure for each set of values
    out = Output()
    with out:
        fig = go.Figure()

        # Mean markers
        fig.add_trace(go.Scatter(
            x=unit_idx,
            y=m,
            mode='markers',
            name='Mean',
            marker=dict(size=3, color='blue')
        ))

        # Standard-deviation markers
        fig.add_trace(go.Scatter(
            x=unit_idx,
            y=s,
            mode='markers',
            name='Std Dev',
            marker=dict(size=3, color='red')
        ))

        # Vertical connectors between the mean and std of each unit, drawn as
        # ONE trace: every segment is followed by a NaN, which plotly renders
        # as a gap. This replaces one trace per unit.
        connector_x = np.repeat(unit_idx, 3)
        connector_y = np.column_stack([m, s, np.full(n_units, np.nan)]).ravel()
        fig.add_trace(go.Scatter(
            x=connector_x,
            y=connector_y,
            mode='lines',
            line=dict(color='gray', width=0.5),
            showlegend=False
        ))

        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Set {i+1}',
            xaxis_title='Activation Index',
            yaxis_title='Value',
            yaxis=dict(range=[0, max(np.max(m), np.max(s)) * 1.1]),
            height=600,
            width=1000,
            hovermode='closest'
        )

        fig.show()

    output_widgets.append(out)

# Display all figures in a vertical box
# VBox(output_widgets)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output
import importlib
import utilities

importlib.reload(utilities)
from utilities import compute_masks
from IPython.display import display

output_widgets = []

for i, fc1 in enumerate(all_fc_vals):
    fc1 = np.array(fc1)

    # compute_masks is unpacked with different arities across this notebook
    # (2, 5 and 6 values). Indexing keeps this cell working for every variant:
    # the first element is the max mask, and the last element is the one the
    # original bound to `mask_std_high_max` and then used as the std mask.
    masks = compute_masks(fc1, 0.15)
    mask_max = masks[0]
    mask_std = masks[-1]

    m = np.mean(np.abs(fc1), axis=0)
    s = np.std(fc1, axis=0)
    min_val = np.min(fc1, axis=0)
    max_val = np.max(fc1, axis=0)

    # Range-scale the std (constant units map to 0 instead of NaN/inf); the
    # mean is plotted unscaled.
    value_range = max_val - min_val
    s_norm = np.divide(
        s - min_val,
        value_range,
        out=np.zeros_like(s, dtype=float),
        where=value_range != 0,
    )
    m_norm = m

    # Index sets: positions where each mask equals 0, plus their overlap and
    # set differences.
    indices_max = np.where(mask_max == 0)[0]
    indices_std = np.where(mask_std == 0)[0]
    indices_intersection = np.intersect1d(indices_max, indices_std)
    indices_max_minus_std = np.setdiff1d(indices_max, indices_std)
    indices_std_minus_max = np.setdiff1d(indices_std, indices_max)

    # Count the indices in each set
    count_max = len(indices_max)
    count_std = len(indices_std)
    count_intersection = len(indices_intersection)
    count_max_minus_std = len(indices_max_minus_std)
    count_std_minus_max = len(indices_std_minus_max)

    out = Output()
    with out:
        # Create subplots with counts in titles
        fig = make_subplots(rows=2, cols=3,
                            subplot_titles=(f"Max Mask (Count: {count_max})",
                                            f"Std Mask (Count: {count_std})",
                                            f"Intersection (Count: {count_intersection})",
                                            f"Max - Std (Count: {count_max_minus_std})",
                                            f"Std - Max (Count: {count_std_minus_max})"))

        def add_traces(indices, row, col):
            """Plot mean/std markers and their connectors for `indices` in one subplot."""
            idx = np.asarray(indices, dtype=int)
            fig.add_trace(
                go.Scatter(
                    x=idx,
                    y=m_norm[idx],
                    mode='markers',
                    name='Mean',
                    marker=dict(size=3, color='blue'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter(
                    x=idx,
                    y=s_norm[idx],
                    mode='markers',
                    name='Std Dev',
                    marker=dict(size=3, color='red'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            # All connectors of the subplot in a single trace; the NaN after
            # each segment is rendered as a gap. Replaces one trace per index.
            connector_x = np.repeat(idx, 3)
            connector_y = np.column_stack(
                [m_norm[idx], s_norm[idx], np.full(idx.size, np.nan)]
            ).ravel()
            fig.add_trace(
                go.Scatter(
                    x=connector_x,
                    y=connector_y,
                    mode='lines',
                    line=dict(color='gray', width=0.5),
                    showlegend=False
                ),
                row=row, col=col
            )

        # Add traces for all plots
        add_traces(indices_max, 1, 1)
        add_traces(indices_std, 1, 2)
        add_traces(indices_intersection, 1, 3)
        add_traces(indices_max_minus_std, 2, 1)
        add_traces(indices_std_minus_max, 2, 2)

        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Class {i+1}',
            height=1200,
            width=1800,
            hovermode='closest'
        )

        # Update x and y axis labels for all subplots
        for row in range(1, 3):
            for col in range(1, 4):
                if row == 2 and col == 3:
                    continue  # Skip the empty subplot
                fig.update_xaxes(title_text="Activation Index", row=row, col=col)
                fig.update_yaxes(title_text="Normalized Value", row=row, col=col)

        display(fig)

    output_widgets.append(out)

# Display all figures in a vertical box
# display(VBox(output_widgets))

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

output_widgets = []

for i, fc1 in enumerate(all_fc_vals):
    fc1 = np.array(fc1)

    # Only the first two masks are needed here. Taking them by position keeps
    # the cell working when compute_masks returns more than two values, as the
    # other cells of this notebook expect.
    masks = compute_masks(fc1, 0.15)
    mask_max, mask_std = masks[0], masks[1]

    m = np.mean(np.abs(fc1), axis=0)
    s = np.std(fc1, axis=0)
    min_val = np.min(fc1, axis=0)
    max_val = np.max(fc1, axis=0)

    # Range-scale the std (constant units map to 0 instead of NaN/inf); the
    # mean is plotted unscaled.
    value_range = max_val - min_val
    s_norm = np.divide(
        s - min_val,
        value_range,
        out=np.zeros_like(s, dtype=float),
        where=value_range != 0,
    )
    m_norm = m

    # Index sets: positions where each mask equals 0, plus their overlap and
    # set differences.
    indices_max = np.where(mask_max == 0)[0]
    indices_std = np.where(mask_std == 0)[0]
    indices_intersection = np.intersect1d(indices_max, indices_std)
    indices_max_minus_std = np.setdiff1d(indices_max, indices_std)
    indices_std_minus_max = np.setdiff1d(indices_std, indices_max)

    # Count the indices in each set
    count_all = len(m_norm)
    count_max = len(indices_max)
    count_std = len(indices_std)
    count_intersection = len(indices_intersection)
    count_max_minus_std = len(indices_max_minus_std)
    count_std_minus_max = len(indices_std_minus_max)

    out = Output()
    with out:
        # Create subplots with counts in titles
        fig = make_subplots(rows=2, cols=3,
                            subplot_titles=(f"All Activations (Count: {count_all})",
                                            f"Max Mask (Count: {count_max})",
                                            f"Std Mask (Count: {count_std})",
                                            f"Intersection (Count: {count_intersection})",
                                            f"Max - Std (Count: {count_max_minus_std})",
                                            f"Std - Max (Count: {count_std_minus_max})"))

        def add_traces(indices, row, col):
            """Plot mean/std markers and their connectors for `indices` in one subplot."""
            # Accepts a range as well as a numpy array of positions.
            idx = np.asarray(list(indices), dtype=int)
            fig.add_trace(
                go.Scatter(
                    x=idx,
                    y=m_norm[idx],
                    mode='markers',
                    name='Mean',
                    marker=dict(size=3, color='blue'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter(
                    x=idx,
                    y=s_norm[idx],
                    mode='markers',
                    name='Std Dev',
                    marker=dict(size=3, color='red'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            # All connectors of the subplot in a single trace; the NaN after
            # each segment is rendered as a gap. Replaces one trace per index.
            connector_x = np.repeat(idx, 3)
            connector_y = np.column_stack(
                [m_norm[idx], s_norm[idx], np.full(idx.size, np.nan)]
            ).ravel()
            fig.add_trace(
                go.Scatter(
                    x=connector_x,
                    y=connector_y,
                    mode='lines',
                    line=dict(color='gray', width=0.5),
                    showlegend=False
                ),
                row=row, col=col
            )

        # Add traces for all activations
        add_traces(range(len(m_norm)), 1, 1)

        # Add traces for other plots
        add_traces(indices_max, 1, 2)
        add_traces(indices_std, 1, 3)
        add_traces(indices_intersection, 2, 1)
        add_traces(indices_max_minus_std, 2, 2)
        add_traces(indices_std_minus_max, 2, 3)

        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Class {i+1}',
            height=1200,
            width=1800,
            hovermode='closest'
        )

        # Update x and y axis labels for all subplots
        for row in range(1, 3):
            for col in range(1, 4):
                fig.update_xaxes(title_text="Activation Index", row=row, col=col)
                fig.update_yaxes(title_text="Normalized Value", row=row, col=col)

        display(fig)

    output_widgets.append(out)

# Display all figures in a vertical box
# display(VBox(output_widgets))

In [None]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

def create_index_tracking_plot(indices_per_class, title):
    """Plot which activation indices occur in which classes.

    Every class is drawn as a row of black markers at its indices. An index
    that occurs in more than one class additionally gets a line through the
    rows it occurs in, coloured by the number of classes that share it.

    Args:
        indices_per_class: sequence with one collection of indices per class.
            May be empty, in which case the figure contains no data traces.
        title: figure title.

    Returns:
        The assembled ``go.Figure``.
    """
    num_classes = len(indices_per_class)

    # Build one set per class up front: membership tests below are then O(1)
    # instead of a scan of the index array per (index, class) pair.
    index_sets = [set(indices) for indices in indices_per_class]
    # `set().union(*...)` also works for an empty list, unlike `set.union(*...)`.
    all_indices = sorted(set().union(*index_sets))

    fig = go.Figure()

    # Reversed RdYlGn diverging scale, sampled by the share of classes.
    color_scale = px.colors.diverging.RdYlGn_r

    # Add edges for indices present in multiple classes
    for idx in all_indices:
        classes_with_idx = [
            class_idx for class_idx, members in enumerate(index_sets) if idx in members
        ]
        shared_by = len(classes_with_idx)
        if shared_by > 1:
            # shared_by > 1 implies num_classes >= 2, so the divisor is non-zero.
            color_index = (shared_by - 1) / (num_classes - 1)  # Normalize to [0, 1]
            edge_color = px.colors.sample_colorscale(color_scale, [color_index])[0]

            fig.add_trace(go.Scatter(
                x=[idx] * shared_by,
                y=classes_with_idx,
                mode='lines',
                line=dict(color=edge_color, width=2),
                hoverinfo='text',
                hovertext=f'Index: {idx}<br>Present in {shared_by} classes',
                showlegend=False
            ))

    # Add scatter plots for each class
    for class_idx, indices in enumerate(indices_per_class):
        fig.add_trace(go.Scatter(
            x=indices,
            y=[class_idx] * len(indices),
            mode='markers',
            name=f'Class {class_idx + 1}',
            marker=dict(size=4, symbol='circle', color='black'),
            hoverinfo='text',
            hovertext=[f'Index: {idx}<br>Class: {class_idx + 1}' for idx in indices]
        ))

    fig.update_layout(
        title=title,
        xaxis_title='Activation Index',
        yaxis_title='Class',
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(num_classes)),
            ticktext=[f'Class {i+1}' for i in range(num_classes)]
        ),
        hovermode='closest',
        width=1500,
        height=800,
        plot_bgcolor='white',
        showlegend=False
    )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')

    # Data-less trace whose only purpose is to carry the colour bar.
    class_counts = list(range(1, num_classes + 1))
    fig.add_trace(go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(
            colorscale=color_scale,
            showscale=True,
            cmin=1,
            cmax=num_classes,
            colorbar=dict(
                title='Number of Classes',
                tickvals=class_counts,
                ticktext=class_counts
            )
        ),
        hoverinfo='none',
        showlegend=False
    ))

    return fig

# Collect, for each class, the positions where the max mask and the std mask
# equal 0.
max_indices_per_class = []
std_indices_per_class = []

for fc1 in all_fc_vals:
    # Take the first two masks by position so this cell keeps working when
    # compute_masks returns more than two values, as other cells expect.
    masks = compute_masks(fc1, 0.15)
    mask_max, mask_std = masks[0], masks[1]
    max_indices_per_class.append(np.where(mask_max == 0)[0])
    std_indices_per_class.append(np.where(mask_std == 0)[0])

# Create and display visualizations, each captured in its own Output widget.
output_widgets = []

out = Output()
with out:
    fig_max = create_index_tracking_plot(max_indices_per_class, 'Max Mask Indices Across Classes')
    display(fig_max)
output_widgets.append(out)

out = Output()
with out:
    fig_std = create_index_tracking_plot(std_indices_per_class, 'Std Mask Indices Across Classes')
    display(fig_std)
output_widgets.append(out)

# Display all visualizations
# display(VBox(output_widgets))

In [None]:
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import torch
from prettytable import PrettyTable
# from model_distill_bert import getmodel
from utilities import compute_accuracy, compute_masks, mask_distillbert, get_model_distilbert, record_activations

# --- Run configuration -------------------------------------------------------
batch_size = 256
mask_layer = 5
text_tag = "text"
# When True, every evaluation is repeated on the complement of the class
# (all other labels) and the results table is filled.
compliment = True

results_table = PrettyTable()
if compliment:
    results_table.field_names = ["Class", "Base Accuracy", "Base Confidence", "Base Complement Acc", "Base Compliment Conf", "STD Accuracy", "STD Confidence", "STD compliment ACC", "STD compliment Conf", "MAX Accuracy", "MAX Confidence", "Max compliment acc", "Max compliment conf", "Total Masked", "Intersedction"]

# Per-class result columns (one entry appended per loop iteration).
class_labels = []
base_accuracies = []
base_confidences = []
base_comp_acc = []
base_comp_conf = []
std_masked_counts = []
std_accuracies = []
std_confidences = []
std_comp_acc = []
std_comp_conf = []
max_masked_counts = []
max_accuracies = []
max_confidences = []
max_comp_acc = []
max_comp_conf = []
diff_from_max = []
total_masked = []

# Three entries are appended per class, in this order:
# base, STD-masked, MAX-masked (each evaluated on the class subset).
dataset_list = []

checkpoint = "2O24dpower2024/distilbert-base-uncased-finetuned-emotion"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Check if a GPU is available and use it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the dataset
dataset_all = load_dataset("dair-ai/emotion")
dataset_all = dataset_all['train']

for j in range(0, 7):
    # A fresh model per class, so masks from the previous iteration are gone.
    model = get_model_distilbert(checkpoint, mask_layer)
    dataset = dataset_all.filter(lambda x: x['label'] in [j])
    dataset_complement = dataset_all.filter(lambda x: x['label'] not in [j])

    # The last iteration is evaluated on the whole split instead of one class.
    if j == 6:
        dataset = dataset_all

    # --- Baseline (no masking) ----------------------------------------------
    class_labels.append(f"Class {j}")
    acc = compute_accuracy(dataset, model, tokenizer, text_tag, batch_size=batch_size)
    dataset_list.append(acc[2])
    print("Class ",j, "base accuracy: ", acc[0], acc[1])
    base_accuracies.append(acc[0])
    base_confidences.append(acc[1])
    # Copy before extending: the original aliased the list just stored in
    # `dataset_list`, so appending the complement rows silently grew that
    # stored entry as well.
    aug_dataset = list(acc[2])
    if compliment:
        acc = compute_accuracy(dataset_complement, model, tokenizer, text_tag, batch_size=batch_size)
        print("Class ",j, "complement base accuracy: ", acc[0], acc[1])
        base_comp_acc.append(acc[0])
        base_comp_conf.append(acc[1])
        # Rows [0:len(dataset)] belong to the class, the rest to the complement.
        aug_dataset.extend(acc[2])

    # --- Activations and masks ----------------------------------------------
    print("Recording activations...")
    fc_vals = record_activations(dataset, model, tokenizer, text_tag=text_tag, mask_layer=mask_layer, batch_size=batch_size)

    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max = compute_masks(fc_vals, 0.50)
    # The "STD" experiment uses the last mask returned by compute_masks.
    mask_std = mask_std_high_max

    # --- STD masking ---------------------------------------------------------
    print("Masking STD...")
    model = mask_distillbert(model, mask_std)
    t = int(mask_std.shape[0] - torch.count_nonzero(mask_std))
    print("Total Masked :", t)
    total_masked.append(t)
    # Number of positions that are zero in BOTH masks.
    diff_from_max.append(int((torch.logical_or(mask_std, mask_max) == 0).sum().item()))
    acc = compute_accuracy(dataset, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[:len(dataset)])
    dataset_list.append(acc[2])
    print("accuracy after masking STD: ", acc[0], acc[1])
    std_accuracies.append(acc[0])
    std_confidences.append(acc[1])
    if compliment:
        acc = compute_accuracy(dataset_complement, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[len(dataset):])
        print("accuracy after masking STD on complement: ", acc[0], acc[1])
        std_comp_acc.append(acc[0])
        std_comp_conf.append(acc[1])

    # --- MAX masking ---------------------------------------------------------
    # NOTE(review): the MAX mask is applied to the model that already went
    # through mask_distillbert with the STD mask. Whether the second call
    # replaces or stacks on the first is decided inside utilities — confirm.
    print("Masking MAX...")
    model = mask_distillbert(model, mask_max)
    t = int(mask_max.shape[0] - torch.count_nonzero(mask_max))
    print("Total Masked :", t)
    acc = compute_accuracy(dataset, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[:len(dataset)])
    dataset_list.append(acc[2])
    print("accuracy after masking MAX: ", acc[0], acc[1])
    max_accuracies.append(acc[0])
    max_confidences.append(acc[1])
    # Guarded like the base and STD complement runs: without the complement
    # baseline, `aug_dataset[len(dataset):]` is empty and there is nothing to
    # compare against.
    if compliment:
        acc = compute_accuracy(dataset_complement, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[len(dataset):])
        print("accuracy after masking MAX on complement: ", acc[0], acc[1])
        max_comp_acc.append(acc[0])
        max_comp_conf.append(acc[1])

    if compliment:
        results_table.add_row([
            class_labels[j],
            base_accuracies[j],
            base_confidences[j],
            base_comp_acc[j],
            base_comp_conf[j],
            std_accuracies[j],
            std_confidences[j],
            std_comp_acc[j],
            std_comp_conf[j],
            max_accuracies[j],
            max_confidences[j],
            max_comp_acc[j],
            max_comp_conf[j],
            total_masked[j],
            diff_from_max[j]
        ])

print(results_table)

In [None]:
import ipywidgets as widgets
from IPython.display import display, HTML

# Assuming you already have your DataFrame
# df = pd.DataFrame(dataset_list[1])

def display_df(dataframe, rows_per_page=10):
    """Render a DataFrame as a Bootstrap-styled HTML table with paging.

    The table, the Previous/Next controls and the paging script are emitted as
    one HTML fragment. Every call gets its own wrapper id and its script only
    touches elements inside that wrapper, so several tables can be shown in
    the same notebook without their controls or page counters interfering.

    Args:
        dataframe: object providing ``to_html`` and ``len`` (a pandas DataFrame).
        rows_per_page: number of table rows visible at once; must be >= 1.

    Raises:
        ValueError: if ``rows_per_page`` is smaller than 1.
    """
    import uuid

    if rows_per_page < 1:
        raise ValueError("rows_per_page must be at least 1")

    # Unique per call; used to scope both the markup and the script.
    widget_id = f"df-{uuid.uuid4().hex}"

    # Convert the dataframe to HTML
    table_html = dataframe.to_html(classes='table table-striped')

    # An empty frame still has one (empty) page, so the counter reads "1 / 1".
    total_rows = len(dataframe)
    total_pages = max(1, (total_rows - 1) // rows_per_page + 1)

    # Table and pagination controls, wrapped in the per-call container.
    markup = f"""
    <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css">
    <div id="{widget_id}">
        <div class="container">
            {table_html}
        </div>
        <nav>
            <ul class="pagination justify-content-center">
                <li class="page-item"><a class="page-link prev-page" href="#">Previous</a></li>
                <li class="page-item"><span class="page-link current-page">1 / {total_pages}</span></li>
                <li class="page-item"><a class="page-link next-page" href="#">Next</a></li>
            </ul>
        </nav>
    </div>
    """

    # Paging logic. The IIFE keeps its state out of the page's global scope,
    # and every lookup goes through `root`.
    script = f"""
    <script>
    (function() {{
        var root = document.getElementById('{widget_id}');
        if (!root) {{ return; }}

        var rows = root.querySelectorAll('table.table tbody tr');
        var pageLabel = root.querySelector('.current-page');
        var currentPage = 1;
        var rowsPerPage = {rows_per_page};
        var totalPages = {total_pages};

        function showPage(page) {{
            for (var i = 0; i < rows.length; i++) {{
                var visible = i >= (page - 1) * rowsPerPage && i < page * rowsPerPage;
                rows[i].style.display = visible ? '' : 'none';
            }}
            pageLabel.textContent = page + ' / ' + totalPages;
        }}

        root.querySelector('.prev-page').addEventListener('click', function(e) {{
            e.preventDefault();
            if (currentPage > 1) {{
                currentPage--;
                showPage(currentPage);
            }}
        }});

        root.querySelector('.next-page').addEventListener('click', function(e) {{
            e.preventDefault();
            if (currentPage < totalPages) {{
                currentPage++;
                showPage(currentPage);
            }}
        }});

        showPage(1);
    }})();
    </script>
    """

    # Display the result
    display(HTML(markup + script))

In [None]:
import pandas as pd

# The evaluation loop appends three entries to `dataset_list` per class
# iteration (base, STD-masked, MAX-masked), so index 5 is the MAX-masked
# result of the second iteration (j == 1).
df = pd.DataFrame(dataset_list[5])

In [None]:
# Optional: reverse the row order before rendering.
# df = df.iloc[::-1]
# Render `df` as a paged HTML table (10 rows per page by default).
display_df(df)

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load the sequence-classification checkpoint identified by the model id below.
# NOTE(review): this rebinds the global name `model`, which the first notebook
# cell also assigns — confirm that later cells expect this checkpoint.
model = AutoModelForSequenceClassification.from_pretrained("esuriddick/distilbert-base-uncased-finetuned-emotion")

In [24]:
data = '''Class 0	0.8667	0.6855	0.8684	0.6391	0.7333	0.0161	0.8675	0.4598	0.5667	0.0092	0.8645	0.1126	384	199
Class 1	1	0.808	0.8676	0.6384	0.7	0.0093	0.8631	0.4242	0.6333	0.0088	0.8634	0.1054	384	194
Class 2	0.9	0.589	0.8682	0.6396	0.5333	0.0137	0.8665	0.4151	0.5333	0.0106	0.8596	0.1042	384	210
Class 3	1	0.7492	0.8676	0.6387	0.7	0.0167	0.8691	0.4978	0.7	0.0109	0.8717	0.1076	384	184
Class 4	0.9667	0.754	0.8678	0.6387	0.7	0.0142	0.8691	0.4898	0.5667	0.0104	0.8675	0.1135	384	204
Class 5	1	0.8596	0.8676	0.6381	1	0.0097	0.8709	0.4649	1	0.0097	0.8691	0.1131	384	177
Class 6	0.7	0.5288	0.8693	0.6399	0.5	0.0141	0.8717	0.4407	0.4667	0.0093	0.8728	0.1102	384	186
Class 7	0.9667	0.7714	0.8678	0.6386	0.9333	0.0173	0.8693	0.4771	0.7667	0.0097	0.8697	0.1061	384	187
Class 8	0.8667	0.6623	0.8684	0.6392	0.6333	0.0121	0.864	0.4495	0.6333	0.0097	0.8697	0.1107	384	191
Class 9	0.9	0.7383	0.8682	0.6388	0.4667	0.0115	0.8671	0.4616	0.4333	0.0093	0.8669	0.1064	384	195
Class 10	1	0.7031	0.8676	0.639	0.1667	0.0097	0.8726	0.3615	0.1333	0.0084	0.8678	0.1083	384	203
Class 11	0.9667	0.5542	0.8678	0.6398	0.3	0.0124	0.8616	0.4243	0.2	0.0092	0.8607	0.1081	384	192
Class 12	0.8667	0.6339	0.8684	0.6394	0.6667	0.0117	0.8682	0.4583	0.6333	0.0091	0.8667	0.1123	384	198
Class 13	1	0.694	0.8676	0.639	0.6667	0.0132	0.8654	0.4148	0.6667	0.0098	0.8654	0.1054	384	196
Class 14	0.9667	0.7763	0.8678	0.6386	0.8	0.0094	0.868	0.4339	0.8	0.0091	0.868	0.1107	384	191
Class 15	1	0.7357	0.8676	0.6388	0.8667	0.0122	0.8644	0.4311	0.8	0.0099	0.8684	0.11	384	189
Class 16	0.8667	0.5376	0.8684	0.6399	0.1667	0.011	0.8693	0.416	0.1333	0.0085	0.8636	0.1023	384	194
Class 17	1	0.8239	0.8676	0.6383	1	0.0109	0.8627	0.425	0.9667	0.0102	0.8603	0.108	384	182
Class 18	1	0.7937	0.8676	0.6385	0.9	0.0116	0.8665	0.4662	0.7333	0.0096	0.8671	0.1085	384	183
Class 19	0.9667	0.7218	0.8678	0.6389	0.2	0.0115	0.8651	0.4389	0.1	0.0089	0.8616	0.1081	384	204
Class 20	1	0.6735	0.8676	0.6392	0.7	0.0129	0.8687	0.4422	0.6	0.01	0.8656	0.1034	384	196
Class 21	1	0.7379	0.8676	0.6388	0.6333	0.0128	0.862	0.4434	0.5	0.01	0.8644	0.1059	384	203
Class 22	0.9	0.7642	0.8682	0.6387	0.8333	0.0104	0.8654	0.4449	0.8333	0.0088	0.8653	0.108	384	193
Class 23	0.9667	0.6396	0.8678	0.6393	0.6	0.0139	0.8702	0.4739	0.5667	0.0091	0.8671	0.1112	384	203
Class 24	1	0.8164	0.8676	0.6384	0.8667	0.0108	0.8709	0.4844	0.7667	0.0095	0.8697	0.1068	384	202
Class 25	0.9667	0.753	0.8678	0.6387	0.8667	0.0122	0.8695	0.4475	0.7	0.0098	0.8709	0.1104	384	190
Class 26	0.9333	0.7457	0.868	0.6388	0.9	0.0122	0.8697	0.5148	0.9	0.0099	0.8691	0.1098	384	183
Class 27	0.8333	0.5667	0.8686	0.6397	0.2667	0.0121	0.8733	0.4111	0.2667	0.0087	0.8709	0.1127	384	210
Class 28	0.9333	0.668	0.868	0.6392	0.6	0.0131	0.8676	0.453	0.4667	0.0095	0.8651	0.1098	384	190
Class 29	1	0.7926	0.8676	0.6385	0.7333	0.0127	0.8647	0.4622	0.5667	0.0095	0.8665	0.1098	384	197
Class 30	1	0.6601	0.8676	0.6392	0.5333	0.0105	0.8698	0.4343	0.5667	0.0094	0.866	0.1087	384	206
Class 31	0.9667	0.7784	0.8678	0.6386	0.9	0.0126	0.8667	0.4741	0.8	0.0094	0.866	0.1118	384	209
Class 32	0.9667	0.8114	0.8678	0.6384	0.9	0.0104	0.8662	0.4603	0.9333	0.0096	0.8676	0.1073	384	206
Class 33	1	0.8504	0.8676	0.6382	1	0.0109	0.8698	0.4952	0.9667	0.01	0.8658	0.1079	384	176
Class 34	1	0.8042	0.8676	0.6384	1	0.0106	0.8625	0.4679	0.9667	0.0095	0.8574	0.1077	384	200
Class 35	0.9667	0.7253	0.8678	0.6389	0.6667	0.0118	0.8656	0.4178	0.6	0.0096	0.87	0.1091	384	196
Class 36	1	0.8209	0.8676	0.6383	1	0.0101	0.866	0.475	1	0.0099	0.8704	0.1101	384	196
Class 37	1	0.77	0.8676	0.6386	0.8667	0.0111	0.8633	0.4078	0.7333	0.0098	0.8665	0.1095	384	204
Class 38	1	0.832	0.8676	0.6383	0.7333	0.0137	0.8651	0.4764	0.4667	0.0092	0.8618	0.1098	384	196
Class 39	1	0.8287	0.8676	0.6383	0.9667	0.0107	0.8687	0.5185	0.9667	0.0095	0.8658	0.1073	384	197
Class 40	1	0.7854	0.8676	0.6385	0.9	0.0147	0.8709	0.4575	0.7	0.0099	0.8644	0.1099	384	196
Class 41	0.8667	0.5996	0.8684	0.6396	0.5	0.0119	0.8656	0.4208	0.5	0.0095	0.8622	0.1074	384	168
Class 42	0.514	0.29	0.9471	0.717	0.036	0.0056	0.9467	0.3255	0.037	0.007	0.9476	0.1226	384	195
Class 43	0.8333	0.5676	0.8686	0.6397	0.5	0.0128	0.864	0.4204	0.4667	0.0094	0.8675	0.1051	384	185
Class 44	0.9333	0.7291	0.868	0.6388	0.8	0.0142	0.868	0.4585	0.7333	0.0099	0.8665	0.1041	384	212
Class 45	1	0.8544	0.8676	0.6382	1	0.014	0.8682	0.5627	1	0.0102	0.8676	0.1082	384	186
Class 46	0.8	0.603	0.8687	0.6395	0.3333	0.0115	0.8718	0.4216	0.3	0.0095	0.8651	0.1081	384	190
Class 47	0.9667	0.7855	0.8678	0.6385	0.8333	0.0135	0.8656	0.4746	0.7	0.0098	0.8631	0.1038	384	185
Class 48	0.9333	0.6621	0.868	0.6392	0.6	0.0107	0.866	0.3827	0.6	0.0092	0.868	0.1064	384	188
Class 49	0.8	0.6157	0.8687	0.6395	0.4667	0.014	0.8693	0.4216	0.4333	0.009	0.8695	0.1051	384	201
Class 50	1	0.7456	0.8676	0.6388	0.4667	0.0097	0.8669	0.4467	0.4667	0.0092	0.8698	0.1098	384	210
Class 51	1	0.7898	0.8676	0.6385	0.9	0.0172	0.8673	0.4929	0.6667	0.0103	0.8686	0.11	384	195
Class 52	1	0.8303	0.8676	0.6383	1	0.0155	0.8636	0.4668	0.8333	0.0103	0.8673	0.1061	384	189
Class 53	0.8667	0.6475	0.8684	0.6393	0.5667	0.0114	0.8704	0.4375	0.5667	0.0093	0.8773	0.1114	384	193
Class 54	0.9	0.607	0.8682	0.6395	0.3667	0.0118	0.8686	0.3784	0.3333	0.0096	0.8656	0.1081	384	182
Class 55	0.9667	0.6171	0.8678	0.6395	0.2667	0.0089	0.8664	0.402	0.2667	0.0087	0.8675	0.1115	384	212
Class 56	0.9	0.6483	0.8682	0.6393	0.5667	0.0118	0.8744	0.4138	0.5333	0.0102	0.8676	0.107	384	192
Class 57	0.9	0.5923	0.8682	0.6396	0.5333	0.0142	0.8642	0.3944	0.3667	0.0092	0.8693	0.1064	384	191
Class 58	0.9333	0.7161	0.868	0.6389	0.8333	0.0111	0.8687	0.4555	0.7	0.0093	0.8665	0.1094	384	186
Class 59	1	0.8709	0.8676	0.6381	1	0.0137	0.8644	0.5103	0.8333	0.0095	0.866	0.1089	384	188
Class 60	0.8	0.5944	0.8687	0.6396	0.4	0.0123	0.8715	0.4215	0.3	0.0093	0.8676	0.11	384	200
Class 61	1	0.7286	0.8676	0.6388	0.4	0.0149	0.8616	0.4251	0.2	0.0092	0.8656	0.1088	384	186
Class 62	0.9	0.7418	0.8682	0.6388	0.9	0.012	0.8645	0.5055	0.8667	0.0088	0.8618	0.1084	384	201
Class 63	0.9333	0.7355	0.868	0.6388	0.7667	0.022	0.8689	0.5147	0.6333	0.0102	0.8654	0.1096	384	189
Class 64	0.9	0.7202	0.8682	0.6389	0.8667	0.013	0.8711	0.4932	0.8	0.0091	0.8728	0.1097	384	200
Class 65	1	0.7029	0.8676	0.639	0.5	0.014	0.8698	0.4401	0.4667	0.0101	0.8671	0.1091	384	209
Class 66	0.9667	0.7721	0.8678	0.6386	0.8	0.021	0.8671	0.538	0.7667	0.011	0.8693	0.1111	384	187
Class 67	1	0.7152	0.8676	0.6389	0.7333	0.0096	0.8678	0.4136	0.7	0.0095	0.8687	0.109	384	192
Class 68	0.8	0.5838	0.8687	0.6396	0.4333	0.0088	0.87	0.3841	0.4333	0.0089	0.8689	0.109	384	191
Class 69	0.9667	0.7655	0.8678	0.6386	0.8333	0.014	0.8664	0.4929	0.7333	0.0096	0.864	0.1047	384	206
Class 70	0.7333	0.5502	0.8691	0.6398	0.4667	0.0133	0.8676	0.4171	0.4	0.0086	0.864	0.1032	384	194
Class 71	0.8667	0.5836	0.8684	0.6396	0.3333	0.0104	0.8664	0.4019	0.3333	0.0081	0.868	0.1081	384	209
Class 72	1	0.7706	0.8676	0.6386	0.8333	0.0138	0.8667	0.432	0.6333	0.01	0.8698	0.1108	384	199
Class 73	1	0.7789	0.8676	0.6386	0.8667	0.0115	0.8656	0.4639	0.7333	0.0093	0.8665	0.1098	384	187
Class 74	1	0.7985	0.8676	0.6385	0.7	0.0108	0.862	0.4273	0.6667	0.0095	0.8618	0.1054	384	200
Class 75	0.8667	0.4779	0.8684	0.6402	0.5	0.01	0.8662	0.3932	0.5	0.0096	0.8684	0.11	384	212
Class 76	1	0.7738	0.8676	0.6386	0.8	0.0126	0.8647	0.4367	0.5	0.0093	0.8631	0.1118	384	200
Class 77	0.9667	0.6602	0.8678	0.6392	0.5667	0.0129	0.8645	0.4103	0.4333	0.0103	0.8675	0.1027	384	189
Class 78	0.9	0.6567	0.8682	0.6392	0.5667	0.0146	0.8678	0.4158	0.3333	0.0097	0.868	0.1098	384	209
Class 79	0.8667	0.6551	0.8684	0.6393	0.6	0.0147	0.8687	0.4458	0.5333	0.01	0.8665	0.1102	384	210
Class 80	1	0.7924	0.8676	0.6385	0.8333	0.015	0.8676	0.4631	0.6667	0.0098	0.8667	0.1094	384	198
Class 81	1	0.8355	0.8676	0.6383	0.8333	0.0115	0.8676	0.4713	0.8	0.0099	0.8664	0.1111	384	198
Class 82	0.9667	0.6965	0.8678	0.639	0.7333	0.0105	0.8691	0.3627	0.7333	0.0098	0.8686	0.1085	384	193
Class 83	0.9667	0.7965	0.8678	0.6385	0.8333	0.0099	0.8711	0.4438	0.8333	0.0091	0.8689	0.1053	384	183
Class 84	0.9667	0.7283	0.8678	0.6388	0.7667	0.0139	0.8686	0.4526	0.6	0.0099	0.868	0.1027	384	202
Class 85	1	0.7278	0.8676	0.6389	0.7667	0.0112	0.8707	0.4645	0.7	0.0098	0.8669	0.1119	384	193
Class 86	0.9667	0.7204	0.8678	0.6389	0.7	0.0127	0.8665	0.4365	0.6333	0.0098	0.8667	0.1111	384	193
Class 87	1	0.6578	0.8676	0.6392	0.3333	0.0106	0.8744	0.4161	0.3333	0.0091	0.8695	0.1121	384	196
Class 88	1	0.8461	0.8676	0.6382	1	0.0203	0.8676	0.5081	0.8667	0.0098	0.8748	0.1097	384	185
Class 89	0.8	0.6843	0.8687	0.6391	0.8	0.011	0.8673	0.5183	0.7667	0.009	0.8669	0.107	384	193
Class 90	0.9	0.6098	0.8682	0.6395	0.2333	0.0104	0.8698	0.3781	0.2333	0.0094	0.8693	0.1099	384	220
Class 91	1	0.6915	0.8676	0.6391	0.4333	0.0156	0.8649	0.4487	0.3	0.0096	0.8631	0.1047	384	189
Class 92	1	0.8676	0.8676	0.6381	1	0.0122	0.8675	0.5131	0.9333	0.0103	0.8707	0.111	384	210
Class 93	0.9667	0.6933	0.8678	0.639	0.5333	0.0112	0.8616	0.3852	0.3667	0.0094	0.8631	0.1086	384	192
Class 94	0.9333	0.8079	0.868	0.6384	0.9333	0.0131	0.8667	0.4849	0.9333	0.0093	0.8642	0.1064	384	201
Class 95	0.9667	0.788	0.8678	0.6385	0.9	0.0106	0.866	0.4435	0.7667	0.0093	0.8662	0.1068	384	209
Class 96	1	0.7579	0.8676	0.6387	0.6333	0.0108	0.8585	0.3865	0.6333	0.01	0.8607	0.1042	384	197
Class 97	0.9	0.6711	0.8682	0.6392	0.5	0.0159	0.87	0.4495	0.2667	0.0091	0.8633	0.106	384	209
Class 98	1	0.8303	0.8676	0.6383	0.9	0.0108	0.8682	0.4789	0.9	0.0102	0.8715	0.1083	384	197
Class 99	0.9667	0.7154	0.8678	0.6389	0.7	0.0106	0.8647	0.4179	0.6333	0.009	0.868	0.1093	384	188
Class 100	1	0.8755	0.8676	0.638	1	0.0151	0.8671	0.528	0.9667	0.0098	0.8653	0.1036	384	209
Class 101	0.9333	0.6996	0.868	0.639	0.5667	0.0132	0.8673	0.4339	0.5333	0.0089	0.862	0.1085	384	199
Class 102	0.8667	0.6747	0.8684	0.6391	0.4333	0.0115	0.8697	0.4165	0.3	0.0089	0.87	0.1127	384	212
Class 103	1	0.8107	0.8676	0.6384	0.8667	0.01	0.8665	0.4131	0.8333	0.0098	0.862	0.1105	384	195
Class 104	0.9667	0.6437	0.8678	0.6393	0.5667	0.0131	0.8684	0.4363	0.5667	0.0097	0.8709	0.1077	384	202
Class 105	0.9667	0.6634	0.8678	0.6392	0.7	0.0119	0.8737	0.4279	0.6333	0.0102	0.8702	0.1128	384	193
Class 106	0.9667	0.6461	0.8678	0.6393	0.3667	0.0174	0.8664	0.4191	0.1667	0.0099	0.8622	0.1043	384	199
Class 107	0.8	0.5461	0.8687	0.6398	0.3667	0.0126	0.8735	0.4513	0.3333	0.0094	0.8717	0.1106	384	201
Class 108	0.9667	0.7696	0.8678	0.6386	0.8667	0.009	0.8682	0.4573	0.7667	0.0086	0.8722	0.1093	384	202
Class 109	0.9667	0.7304	0.8678	0.6388	0.7333	0.0106	0.8689	0.456	0.7333	0.0095	0.8684	0.1067	384	188
Class 110	0.9333	0.6648	0.868	0.6392	0.3333	0.0102	0.8706	0.4313	0.2667	0.0087	0.8702	0.1064	384	218
Class 111	0.9333	0.6545	0.868	0.6393	0.7333	0.0144	0.8704	0.4668	0.6333	0.0092	0.8739	0.1108	384	187
Class 112	0.8667	0.5469	0.8684	0.6398	0.4	0.012	0.8658	0.4142	0.4	0.0088	0.8709	0.1105	384	188
Class 113	0.9667	0.8063	0.8678	0.6384	0.9667	0.0097	0.8654	0.5067	0.9667	0.0089	0.866	0.1087	384	199
Class 114	0.9	0.6331	0.8682	0.6394	0.3333	0.0105	0.8728	0.3988	0.3333	0.0095	0.8676	0.1091	384	186
Class 115	0.9333	0.6261	0.868	0.6394	0.2667	0.0117	0.8618	0.3792	0.2333	0.0098	0.8662	0.1065	384	183
Class 116	1	0.8005	0.8676	0.6385	1	0.0095	0.8669	0.4306	0.9	0.0093	0.8658	0.1133	384	179
Class 117	0.9667	0.7434	0.8678	0.6388	0.7	0.0113	0.8644	0.4322	0.4667	0.0095	0.8698	0.1079	384	194
Class 118	1	0.7257	0.8676	0.6389	0.7	0.0143	0.8645	0.4512	0.7	0.0093	0.8622	0.1003	384	200
Class 119	1	0.8004	0.8676	0.6385	0.8	0.0132	0.864	0.4422	0.6333	0.0101	0.8642	0.1088	384	189
Class 120	0.9667	0.7947	0.8678	0.6385	0.9333	0.0138	0.8633	0.4764	0.8	0.0091	0.864	0.108	384	200
Class 121	0.9333	0.6605	0.868	0.6392	0.3333	0.01	0.8662	0.3663	0.3667	0.0092	0.8675	0.111	384	225
Class 122	1	0.7937	0.8676	0.6385	0.9667	0.0121	0.8653	0.445	0.8	0.0094	0.8634	0.1063	384	188
Class 123	0.9333	0.7195	0.868	0.6389	0.4333	0.0117	0.8675	0.436	0.4	0.0093	0.8673	0.1085	384	192
Class 124	0.9667	0.7197	0.8678	0.6389	0.8	0.0107	0.8698	0.4238	0.7667	0.0094	0.8676	0.1104	384	202
Class 125	0.8667	0.5987	0.8684	0.6396	0.5667	0.0123	0.8707	0.4426	0.5667	0.0097	0.8691	0.1097	384	191
Class 126	0.9667	0.5716	0.8678	0.6397	0.5	0.0121	0.8644	0.4291	0.4333	0.0091	0.8603	0.1051	384	192
Class 127	1	0.8266	0.8676	0.6383	1	0.0106	0.8682	0.4938	0.9667	0.0098	0.8686	0.1095	384	193
Class 128	1	0.8069	0.8676	0.6384	0.8333	0.0186	0.8656	0.4761	0.4333	0.0095	0.8689	0.1069	384	200
Class 129	1	0.7665	0.8676	0.6386	0.9	0.0153	0.8684	0.445	0.6667	0.0101	0.8662	0.1116	384	180
Class 130	1	0.8077	0.8676	0.6384	0.9333	0.0179	0.8686	0.4905	0.7667	0.0098	0.866	0.1081	384	196
Class 131	0.9	0.6773	0.8682	0.6391	0.5333	0.0113	0.8729	0.4005	0.4333	0.0095	0.8698	0.1085	384	196
Class 132	1	0.7671	0.8676	0.6386	1	0.015	0.866	0.4853	0.7333	0.0091	0.8638	0.1097	384	197
Class 133	0.9	0.7215	0.8682	0.6389	0.7	0.0118	0.8691	0.4302	0.6333	0.0097	0.8686	0.1061	384	196
Class 134	0.9333	0.6043	0.868	0.6395	0.4	0.0162	0.8675	0.4471	0.3667	0.0091	0.8695	0.1119	384	199
Class 135	1	0.8334	0.8676	0.6383	1	0.0163	0.8669	0.5282	0.9667	0.0102	0.8676	0.1059	384	200
Class 136	0.9667	0.6949	0.8678	0.639	0.4667	0.0112	0.8673	0.4252	0.4667	0.0096	0.8669	0.108	384	202
Class 137	1	0.8039	0.8676	0.6384	0.7667	0.0128	0.8623	0.4372	0.6333	0.0098	0.8622	0.1063	384	192
Class 138	0.9333	0.7579	0.868	0.6387	0.8	0.0109	0.8664	0.4332	0.8333	0.0094	0.8684	0.1079	384	193
Class 139	0.9667	0.7284	0.8678	0.6388	0.5333	0.0126	0.8647	0.4628	0.4667	0.0093	0.8645	0.1119	384	193
Class 140	1	0.7881	0.8676	0.6385	0.6333	0.0104	0.8698	0.4396	0.5667	0.0092	0.8676	0.1095	384	178
Class 141	0.9667	0.7638	0.8678	0.6387	0.6333	0.0177	0.8693	0.501	0.5333	0.0094	0.8669	0.1061	384	191
Class 142	1	0.8219	0.8676	0.6383	0.9	0.0099	0.8653	0.4469	0.8667	0.0095	0.858	0.1049	384	200
Class 143	0.9667	0.7176	0.8678	0.6389	0.5667	0.0109	0.8678	0.4413	0.5667	0.0096	0.8687	0.1106	384	190
Class 144	0.9	0.5452	0.8682	0.6399	0.3333	0.0109	0.8616	0.4068	0.3333	0.0093	0.8631	0.11	384	190
Class 145	0.9	0.6503	0.8682	0.6393	0.3	0.0139	0.8709	0.4449	0.2333	0.0099	0.8645	0.1097	384	203
Class 146	0.9	0.7262	0.8682	0.6389	0.8667	0.0155	0.8682	0.478	0.7	0.0096	0.864	0.1099	384	193
Class 147	1	0.7295	0.8676	0.6388	0.5333	0.012	0.8669	0.4116	0.4333	0.0096	0.8702	0.1067	384	186
Class 148	0.8667	0.5723	0.8684	0.6397	0.4667	0.0099	0.862	0.3923	0.4667	0.0087	0.86	0.1029	384	195
Class 149	0.9667	0.738	0.8678	0.6388	0.6	0.0107	0.8654	0.3782	0.6	0.0097	0.8667	0.1118	384	199'''

def _drop_leading_zero(text):
    """Remove only the integer-part zero of a fixed-point string.

    '0.867' -> '.867' and '-0.500' -> '-.500'; anything else (e.g. '1.000',
    '10.500', 'nan') is returned unchanged.
    """
    if text.startswith("0."):
        return text[1:]
    if text.startswith("-0."):
        return "-" + text[2:]
    return text


def format_latex_row(line):
    """Convert one whitespace-separated result row into a LaTeX table row.

    The first two tokens form the row label (e.g. ``Class 12``). The last two
    tokens are dropped (in the table pasted above they are the constant 384
    and the column the original comment calls "intersection"). Every token in
    between is parsed as a float, rendered with three decimals, and has its
    leading zero removed (``0.867`` -> ``.867``).

    Args:
        line: A single row such as ``"Class 0\\t0.8667\\t...\\t384\\t199"``.
            Any run of whitespace (tabs or spaces) separates columns.

    Returns:
        The row as ``"<label> & v1 & v2 & ... \\\\"``, i.e. cells joined with
        `` & `` and terminated by a LaTeX line break.

    Raises:
        IndexError: If the line has fewer than two tokens (e.g. a blank line).
        ValueError: If a value token cannot be parsed as a float.
    """
    # Split on any whitespace so tab- and space-separated rows both work.
    parts = line.split()

    # Row label is the first two tokens, e.g. "Class" + "12".
    class_name = f"{parts[0]} {parts[1]}"

    # Keep only the metric columns: skip the label and the last two columns.
    values = parts[2:-2]

    # Three-decimal rendering, then strip *only* a leading zero. A blanket
    # str.replace('0.', '.') would also corrupt values such as 10.500 -> 1.500.
    formatted_values = [_drop_leading_zero(f"{float(val):.3f}") for val in values]

    # Combine into LaTeX format (cells joined by " & ", ended by "\\").
    return f"{class_name} & " + " & ".join(formatted_values) + " \\\\"

# Example usage: print one LaTeX row per line of the pasted table.
# Blank lines (e.g. from a trailing newline in `data`) are skipped, because
# format_latex_row indexes the first two tokens and would raise IndexError.
for line in data.split('\n'):
    if line.strip():
        print(format_latex_row(line))

Class 0 & .867 & .685 & .868 & .639 & .733 & .016 & .868 & .460 & .567 & .009 & .865 & .113 \\
Class 1 & 1.000 & .808 & .868 & .638 & .700 & .009 & .863 & .424 & .633 & .009 & .863 & .105 \\
Class 2 & .900 & .589 & .868 & .640 & .533 & .014 & .867 & .415 & .533 & .011 & .860 & .104 \\
Class 3 & 1.000 & .749 & .868 & .639 & .700 & .017 & .869 & .498 & .700 & .011 & .872 & .108 \\
Class 4 & .967 & .754 & .868 & .639 & .700 & .014 & .869 & .490 & .567 & .010 & .868 & .114 \\
Class 5 & 1.000 & .860 & .868 & .638 & 1.000 & .010 & .871 & .465 & 1.000 & .010 & .869 & .113 \\
Class 6 & .700 & .529 & .869 & .640 & .500 & .014 & .872 & .441 & .467 & .009 & .873 & .110 \\
Class 7 & .967 & .771 & .868 & .639 & .933 & .017 & .869 & .477 & .767 & .010 & .870 & .106 \\
Class 8 & .867 & .662 & .868 & .639 & .633 & .012 & .864 & .450 & .633 & .010 & .870 & .111 \\
Class 9 & .900 & .738 & .868 & .639 & .467 & .011 & .867 & .462 & .433 & .009 & .867 & .106 \\
Class 10 & 1.000 & .703 & .868 & .639 & .167 