In [None]:
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from datasets import load_dataset
from tqdm import tqdm
import torch
from utilities import get_model_distilbert, record_activations

# --- Configuration ----------------------------------------------------------
mask_layer = 5          # forwarded to get_model_distilbert and record_activations
text_tag = "sentence"   # name of the text column passed to record_activations
compliment = True       # not read in this cell
# Batch size actually used for recording. The previous `batch_size = 32` was
# never read (the call below hard-coded 256); lower this if GPU memory is tight.
batch_size = 256

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
# Check if a GPU is available and use it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the dataset and keep only the train split
dataset_all = load_dataset("stanfordnlp/sst2")
dataset_all = dataset_all['train']

model = get_model_distilbert(checkpoint, mask_layer)
model.to(device)
model.eval()

# One entry per label value (0 and 1), in label order: whatever
# record_activations returns for the examples carrying that label.
all_fc_vals = []
for j in range(0, 2):
    dataset = dataset_all.filter(lambda x: x['label'] == j)
    fc_vals = record_activations(
        dataset,
        model,
        tokenizer,
        text_tag=text_tag,
        batch_size=batch_size,
        mask_layer=mask_layer,
    )
    all_fc_vals.append(fc_vals)

In [None]:
from datasets import load_dataset, concatenate_datasets

from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling
import random
import numpy as np
from utilities import evaluate_gpt2_classification as evaluate_gpt2_classification, mask_range_gpt,compute_masks, reset_gpt
import torch  

# --- Run configuration -------------------------------------------------------
# NOTE(review): the output saved under this cell (four topic labels, 120000
# training rows) does not look like it was produced with this dataset id, and
# num_classes below is 4 — confirm that dataset_name, text_tag and num_classes
# agree with each other before re-running.
dataset_name = "fancyzhx/dbpedia_14"

text_tag = "text"   # column read by format_data

tables = []
layer = 11          # stored on the model config as m_layer further down
# for layer in range(0,12):
per = 0.2
print("Percentage: ", per)
num_classes = 4     # number of label ids looped over in the evaluation below

# tao = 2.5

lab = "label"       # name of the label column
# tao = torch.inf

dataset = load_dataset(dataset_name)

print(dataset)

# --- Reproducibility ---------------------------------------------------------
seed_value = 42  # or any other integer

random.seed(seed_value)
np.random.seed(seed_value)

# Seed torch unconditionally: the previous version only seeded it when CUDA was
# available, leaving CPU-only runs unseeded. torch.cuda.manual_seed_all is
# documented as safe to call without CUDA (it is silently ignored).
torch.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)

import torch

# NOTE(review): anomaly detection is a debugging aid for autograd; this cell
# only evaluates a model — confirm it is still needed here.
torch.autograd.set_detect_anomaly(True)
# Load the stock GPT-2 tokenizer; label tokens and special tokens are added below.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Class names of the label feature, in label-id order.
label2text = dataset['train'].features[lab].names

# Report which label names the stock vocabulary already encodes as one token.
for label in label2text:
    label_tokens = tokenizer.encode(label, add_special_tokens=False)
    if len(label_tokens) == 1:
        print(f"'{label}' is already a single token (ID: {label_tokens[0]})")

# Exactly one token per label name is requested (the bare name; no
# space-prefixed variant), in label order.
# Presumably add_tokens skips names that are already in the vocabulary — the
# saved output reports 1 token added for 4 labels; confirm against the
# tokenizer documentation.
new_tokens = list(label2text)
num_added_tokens = tokenizer.add_tokens(new_tokens)
print(f"\nAdded {num_added_tokens} new tokens to the tokenizer")

# Registered after the label tokens; format_data / tokenize_and_prepare rely on
# sep_token and pad_token being defined on the tokenizer.
special_tokens = {
    'pad_token': '<|pad|>',
    'sep_token': '<|sep|>',
    'eos_token': '<|eos|>'
}
tokenizer.add_special_tokens(special_tokens)

def format_data(examples):
    """Turn a batch of raw examples into classification prompts.

    Each text is clipped to at most 400 tokenizer ids (encode, then decode
    back to a string) and wrapped as
    ``"Classify emotion: <text><sep_token>"``. The label string is not part of
    the prompt.

    Args:
        examples: batch mapping that provides the ``text_tag`` and ``lab`` columns.

    Returns:
        dict with a single key ``'formatted_text'`` holding one prompt per example.
    """
    prompts = []
    batch_pairs = zip(examples[text_tag], examples[lab])
    for raw_text, label_id in batch_pairs:
        # Round-trip through the tokenizer to enforce the 400-id limit.
        clipped_ids = tokenizer.encode(raw_text, max_length=400, truncation=True)
        clipped_text = tokenizer.decode(clipped_ids)
        # Looked up as before, but currently unused: the prompt stops at the
        # separator and does not include the label text.
        label_str = dataset['train'].features[lab].int2str(label_id)
        prompts.append(f"Classify emotion: {clipped_text}{tokenizer.sep_token}")
    return {'formatted_text': prompts}

def tokenize_and_prepare(examples):
    """Tokenize formatted prompts and build the matching label tensor.

    Prompts are padded/truncated to 408 ids. The label tensor is -100
    everywhere except at the position immediately after each separator token,
    where it carries the input id found there; positions whose label equals
    the pad token id are then set back to -100.

    Args:
        examples: batch mapping with a ``'formatted_text'`` column.

    Returns:
        dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
    """
    encoded = tokenizer(
        examples['formatted_text'],
        padding='max_length',
        max_length=408,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded['input_ids']

    # Locate every separator token as parallel (row, column) index tensors.
    sep_id = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)
    sep_rows, sep_cols = (input_ids == sep_id).nonzero(as_tuple=True)

    # Start from a fully ignored label tensor of the same shape and dtype.
    labels = input_ids.clone()
    labels.fill_(-100)

    # Un-mask only the token that directly follows a separator, if it exists.
    seq_len = labels.size(1)
    for row, col in zip(sep_rows, sep_cols):
        target = col + 1
        if target < seq_len:
            labels[row, target] = input_ids[row, target]

    # A padding token after the separator must not count as a target.
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        'input_ids': input_ids,
        'attention_mask': encoded['attention_mask'],
        'labels': labels,
    }
# Process the dataset: build prompts, then tokenize them and attach labels.
formatted_dataset = dataset.map(format_data, batched=True)
tokenized_dataset = formatted_dataset.map(
    tokenize_and_prepare,
    batched=True,
)

from transformers import GPT2LMHeadModel as gt
from models.gpt2 import GPT2LMHeadModel
# The pre-trained GPT-2 is loaded only to obtain a config whose vocabulary size
# matches the extended tokenizer; the evaluated model is the project's own
# GPT2LMHeadModel built from that config.
model1 = gt.from_pretrained('gpt2')

model1.resize_token_embeddings(len(tokenizer))

model1.config.m_layer = layer
import os

base_path = os.path.join("model_weights", dataset_name)
# exist_ok avoids the check-then-create race of the previous version.
os.makedirs(base_path, exist_ok=True)

weights_path = os.path.join(base_path, "weights.pth")

model = GPT2LMHeadModel(model1.config)

# map_location lets weights saved on a GPU load on a CPU-only machine;
# weights_only restricts unpickling to tensors/containers, which is all a
# state dict needs and is what torch's FutureWarning on torch.load asks for.
state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
dataset_all = tokenized_dataset['train']

# One entry per class id: the third element returned by
# evaluate_gpt2_classification for that class. Previously the list was created
# but never appended to, so the cells that iterate over all_fc_vals saw nothing.
all_fc_vals = []
for j in range(0, num_classes):
    dataset = dataset_all.filter(lambda x: x['label'] in [j])
    eval_result = evaluate_gpt2_classification(lab, model, dataset, tokenizer)
    print('Accuracy : ', eval_result[0], 'Confidence : ', eval_result[1])
    fc_vals = eval_result[2]
    all_fc_vals.append(fc_vals)

Percentage:  0.2
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})
'World' is already a single token (ID: 10603)
'Sports' is already a single token (ID: 18153)
'Business' is already a single token (ID: 24749)

Added 1 new tokens to the tokenizer


  model.load_state_dict(torch.load(weights_path))


Filter:   0%|          | 0/120000 [00:00<?, ? examples/s]

Evaluating:   0%|          | 0/938 [00:00<?, ?it/s]

  input_ids = torch.tensor(item['input_ids']).to(device)
  attention_mask = torch.tensor(item['attention_mask']).to(device)


Accuracy :  0.9682 Confidence :  0.9625


Filter:   0%|          | 0/120000 [00:00<?, ? examples/s]

KeyboardInterrupt: 

: 

In [None]:
import numpy as np
from utilities import compute_masks

# Collect the first mask returned by compute_masks (the one the other cells
# call mask_max) for every recorded class, in the order of all_fc_vals.
# The mean/std/min/max that used to be computed here were never read, and the
# stale `j = 10` / `output_widgets` assignments have been dropped.
max_all = []
for v in all_fc_vals:
    v = np.array(v)
    # Indexing instead of tuple-unpacking keeps this independent of how many
    # masks compute_masks returns.
    mask_max = compute_masks(v, 0.30)[0]
    max_all.append(mask_max)


In [None]:
# Compare the masks of the first two classes: first the number of positions
# where they differ, then the element-wise equality itself.
first_mask, second_mask = max_all[0], max_all[1]
print(torch.sum(first_mask != second_mask))
print(first_mask == second_mask)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import VBox, Output

output_widgets = []

for i, v in enumerate(all_fc_vals):
    v = np.array(v)
    m = np.mean(np.abs(v), axis=0)   # mean absolute activation per unit
    s = np.std(v, axis=0)            # standard deviation per unit

    min_val = np.min(v, axis=0)
    max_val = np.max(v, axis=0)

    # Rescale the std by each unit's observed value range. A unit whose
    # activation never changes has max == min; dividing there produced NaN/inf,
    # which then made plt.ylim fail. Such units are mapped to 0 instead.
    span = max_val - min_val
    s = np.divide(s - min_val, span, out=np.zeros(span.shape, dtype=float), where=span != 0)
    # m = (m-min_val) / (max_val - min_val)

    # Create a new figure for each set of values, captured in its own widget
    out = Output()
    with out:
        plt.figure(figsize=(12, 6))

        # Plot the mean
        plt.subplot(1, 2, 1)
        plt.plot(m, 'bo', markersize=4)
        plt.title(f'Mean of Activations - Set {i+1}')
        plt.xlabel('Activation Index')
        plt.ylabel('Mean Value')
        plt.ylim(0, np.max(m))  # Ensure y-axis starts at 0

        # Plot the standard deviation
        plt.subplot(1, 2, 2)
        plt.plot(s, 'ro', markersize=4)
        plt.title(f'Standard Deviation of Activations - Set {i+1}')
        plt.xlabel('Activation Index')
        plt.ylabel('Standard Deviation')
        plt.ylim(0, np.max(s))  # Ensure y-axis starts at 0

        # Show the plots
        plt.tight_layout()
        plt.show()

    output_widgets.append(out)

# Display all figures in a vertical box
VBox(output_widgets)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output

output_widgets = []

for i, v in enumerate(all_fc_vals):
    v = np.array(v)
    m = np.mean(np.abs(v), axis=0)   # mean absolute activation per unit
    s = np.std(v, axis=0)            # standard deviation per unit

    min_val = np.min(v, axis=0)
    max_val = np.max(v, axis=0)
    # Rescale the std by each unit's observed value range; units with a zero
    # range (max == min) are mapped to 0 instead of producing NaN/inf, which
    # would otherwise poison the y-axis range computed below.
    span = max_val - min_val
    s = np.divide(s - min_val, span, out=np.zeros(span.shape, dtype=float), where=span != 0)
    # m = (m-min_val) / (max_val - min_val)

    # Take the number of units from the data instead of assuming 768.
    n_units = len(m)
    unit_index = list(range(n_units))

    # Create a new figure for each set of values
    out = Output()
    with out:
        fig = go.Figure()

        # Plot the mean with markers
        fig.add_trace(go.Scatter(
            x=unit_index,
            y=m,
            mode='markers',
            name='Mean',
            marker=dict(size=3, color='blue')
        ))

        # Plot the standard deviation with markers
        fig.add_trace(go.Scatter(
            x=unit_index,
            y=s,
            mode='markers',
            name='Std Dev',
            marker=dict(size=3, color='red')
        ))

        # Vertical connectors between each unit's mean and std. All segments
        # live in ONE trace, separated by None gaps, instead of one trace per
        # unit (hundreds of traces make the figure very slow to build/render).
        connector_x, connector_y = [], []
        for j in range(n_units):
            connector_x.extend((j, j, None))
            connector_y.extend((float(m[j]), float(s[j]), None))
        fig.add_trace(go.Scatter(
            x=connector_x,
            y=connector_y,
            mode='lines',
            line=dict(color='gray', width=0.5),
            showlegend=False
        ))

        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Set {i+1}',
            xaxis_title='Activation Index',
            yaxis_title='Value',
            yaxis=dict(range=[0, max(np.max(m), np.max(s)) * 1.1]),
            height=600,
            width=1000,
            hovermode='closest'
        )

        fig.show()

    output_widgets.append(out)

# Display all figures in a vertical box
# VBox(output_widgets)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output
import importlib
import utilities

# Reload so that edits to utilities.py are picked up without a kernel restart.
importlib.reload(utilities)
from utilities import compute_masks
from IPython.display import display

output_widgets = []

for i, fc1 in enumerate(all_fc_vals):
    fc1 = np.array(fc1)
    # Other cells unpack six masks from compute_masks while this one unpacked
    # five. Index the result instead: the first entry is mask_max and the last
    # is mask_std_high_max in both layouts, and the latter is what this cell
    # uses as its "std" mask.
    masks = compute_masks(fc1, 0.15)
    mask_max = masks[0]
    mask_std = masks[-1]

    m = np.mean(np.abs(fc1), axis=0)
    s = np.std(fc1, axis=0)
    min_val = np.min(fc1, axis=0)
    max_val = np.max(fc1, axis=0)

    # Rescale the std by each unit's value range; zero-range units (max == min)
    # are mapped to 0 instead of NaN/inf. The mean is plotted unscaled.
    span = max_val - min_val
    s_norm = np.divide(s - min_val, span, out=np.zeros(span.shape, dtype=float), where=span != 0)
    m_norm = m#(m - min_val) / (max_val - min_val)

    # Index sets: a unit belongs to a set when its mask entry is 0.
    indices_max = np.where(mask_max == 0)[0]
    indices_std = np.where(mask_std == 0)[0]
    indices_intersection = np.intersect1d(indices_max, indices_std)
    indices_max_minus_std = np.setdiff1d(indices_max, indices_std)
    indices_std_minus_max = np.setdiff1d(indices_std, indices_max)

    # Count the indices in each set
    count_max = len(indices_max)
    count_std = len(indices_std)
    count_intersection = len(indices_intersection)
    count_max_minus_std = len(indices_max_minus_std)
    count_std_minus_max = len(indices_std_minus_max)

    out = Output()
    with out:
        # Create subplots with counts in titles (the sixth cell stays empty)
        fig = make_subplots(rows=2, cols=3,
                            subplot_titles=(f"Max Mask (Count: {count_max})",
                                            f"Std Mask (Count: {count_std})",
                                            f"Intersection (Count: {count_intersection})",
                                            f"Max - Std (Count: {count_max_minus_std})",
                                            f"Std - Max (Count: {count_std_minus_max})"))

        def add_traces(indices, row, col):
            """Draw mean/std markers and their connectors for `indices` in one subplot."""
            fig.add_trace(
                go.Scatter(
                    x=indices,
                    y=m_norm[indices],
                    mode='markers',
                    name='Mean',
                    marker=dict(size=3, color='blue'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter(
                    x=indices,
                    y=s_norm[indices],
                    mode='markers',
                    name='Std Dev',
                    marker=dict(size=3, color='red'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            if len(indices) == 0:
                return
            # All connectors go into a single trace separated by None gaps,
            # instead of one trace per index.
            connector_x, connector_y = [], []
            for j in indices:
                connector_x.extend((int(j), int(j), None))
                connector_y.extend((float(m_norm[j]), float(s_norm[j]), None))
            fig.add_trace(
                go.Scatter(
                    x=connector_x,
                    y=connector_y,
                    mode='lines',
                    line=dict(color='gray', width=0.5),
                    showlegend=False
                ),
                row=row, col=col
            )

        # Add traces for all plots
        add_traces(indices_max, 1, 1)
        add_traces(indices_std, 1, 2)
        add_traces(indices_intersection, 1, 3)
        add_traces(indices_max_minus_std, 2, 1)
        add_traces(indices_std_minus_max, 2, 2)

        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Class {i+1}',
            height=1200,
            width=1800,
            hovermode='closest'
        )

        # Update x and y axis labels for all subplots
        for row in range(1, 3):
            for col in range(1, 4):
                if row == 2 and col == 3:
                    continue  # Skip the empty subplot
                fig.update_xaxes(title_text="Activation Index", row=row, col=col)
                fig.update_yaxes(title_text="Normalized Value", row=row, col=col)

        display(fig)

    output_widgets.append(out)

# Display all figures in a vertical box
# display(VBox(output_widgets))

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

output_widgets = []

for i, fc1 in enumerate(all_fc_vals):
    fc1 = np.array(fc1)
    # Other cells unpack six masks from compute_masks; slicing keeps this cell
    # working whether it returns two masks or more. Only the first two
    # (mask_max, mask_std) are used here.
    mask_max, mask_std = compute_masks(fc1, 0.15)[:2]

    m = np.mean(np.abs(fc1), axis=0)
    s = np.std(fc1, axis=0)
    min_val = np.min(fc1, axis=0)
    max_val = np.max(fc1, axis=0)

    # Rescale the std by each unit's value range; zero-range units (max == min)
    # are mapped to 0 instead of NaN/inf. The mean is plotted unscaled.
    span = max_val - min_val
    s_norm = np.divide(s - min_val, span, out=np.zeros(span.shape, dtype=float), where=span != 0)
    m_norm = m#(m - min_val) / (max_val - min_val)

    # Index sets: a unit belongs to a set when its mask entry is 0.
    indices_max = np.where(mask_max == 0)[0]
    indices_std = np.where(mask_std == 0)[0]
    indices_intersection = np.intersect1d(indices_max, indices_std)
    indices_max_minus_std = np.setdiff1d(indices_max, indices_std)
    indices_std_minus_max = np.setdiff1d(indices_std, indices_max)

    # Count the indices in each set
    count_all = len(m_norm)
    count_max = len(indices_max)
    count_std = len(indices_std)
    count_intersection = len(indices_intersection)
    count_max_minus_std = len(indices_max_minus_std)
    count_std_minus_max = len(indices_std_minus_max)

    out = Output()
    with out:
        # Create subplots with counts in titles
        fig = make_subplots(rows=2, cols=3,
                            subplot_titles=(f"All Activations (Count: {count_all})",
                                            f"Max Mask (Count: {count_max})",
                                            f"Std Mask (Count: {count_std})",
                                            f"Intersection (Count: {count_intersection})",
                                            f"Max - Std (Count: {count_max_minus_std})",
                                            f"Std - Max (Count: {count_std_minus_max})"))

        def add_traces(indices, row, col):
            """Draw mean/std markers and their connectors for `indices` in one subplot."""
            indices_list = list(indices)  # Convert range or numpy array to list
            fig.add_trace(
                go.Scatter(
                    x=indices_list,
                    y=m_norm[indices_list],
                    mode='markers',
                    name='Mean',
                    marker=dict(size=3, color='blue'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter(
                    x=indices_list,
                    y=s_norm[indices_list],
                    mode='markers',
                    name='Std Dev',
                    marker=dict(size=3, color='red'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            if not indices_list:
                return
            # All connectors go into a single trace separated by None gaps,
            # instead of one trace per index.
            connector_x, connector_y = [], []
            for j in indices_list:
                connector_x.extend((int(j), int(j), None))
                connector_y.extend((float(m_norm[j]), float(s_norm[j]), None))
            fig.add_trace(
                go.Scatter(
                    x=connector_x,
                    y=connector_y,
                    mode='lines',
                    line=dict(color='gray', width=0.5),
                    showlegend=False
                ),
                row=row, col=col
            )

        # Add traces for all activations
        add_traces(range(len(m_norm)), 1, 1)

        # Add traces for other plots
        add_traces(indices_max, 1, 2)
        add_traces(indices_std, 1, 3)
        add_traces(indices_intersection, 2, 1)
        add_traces(indices_max_minus_std, 2, 2)
        add_traces(indices_std_minus_max, 2, 3)

        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Class {i+1}',
            height=1200,
            width=1800,
            hovermode='closest'
        )

        # Update x and y axis labels for all subplots
        for row in range(1, 3):
            for col in range(1, 4):
                fig.update_xaxes(title_text="Activation Index", row=row, col=col)
                fig.update_yaxes(title_text="Normalized Value", row=row, col=col)

        display(fig)

    output_widgets.append(out)

# Display all figures in a vertical box
# display(VBox(output_widgets))

In [None]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

def create_index_tracking_plot(indices_per_class, title):
    """Plot which activation indices are selected in which classes.

    Every class is drawn as a horizontal row of black markers at its selected
    indices. An index selected by more than one class additionally gets a
    vertical line through those classes, colored by how many classes share it.

    Args:
        indices_per_class: sequence with one collection of indices per class.
            May be empty, in which case an empty figure is produced.
        title: figure title.

    Returns:
        The plotly ``go.Figure``.
    """
    num_classes = len(indices_per_class)
    # Sets make the per-index membership test O(1) instead of a linear scan
    # over each class's index array.
    index_sets = [set(indices) for indices in indices_per_class]
    # set().union(*[]) is an empty set; set.union(*[]) raised TypeError.
    all_indices = sorted(set().union(*index_sets))

    fig = go.Figure()

    # Create a color scale
    color_scale = px.colors.diverging.RdYlGn_r  # Red to Yellow to Green color scale

    # Add edges for indices present in multiple classes
    for idx in all_indices:
        classes_with_idx = [c for c, members in enumerate(index_sets) if idx in members]
        shared_by = len(classes_with_idx)
        if shared_by > 1:
            # shared_by > 1 implies num_classes > 1, so the divisor is non-zero.
            color_index = (shared_by - 1) / (num_classes - 1)  # Normalize to [0, 1]
            edge_color = px.colors.sample_colorscale(color_scale, [color_index])[0]

            fig.add_trace(go.Scatter(
                x=[idx] * shared_by,
                y=classes_with_idx,
                mode='lines',
                line=dict(color=edge_color, width=2),
                hoverinfo='text',
                hovertext=f'Index: {idx}<br>Present in {shared_by} classes',
                showlegend=False
            ))

    # Add scatter plots for each class
    for class_idx, indices in enumerate(indices_per_class):
        fig.add_trace(go.Scatter(
            x=indices,
            y=[class_idx] * len(indices),
            mode='markers',
            name=f'Class {class_idx + 1}',
            marker=dict(size=4, symbol='circle', color='black'),
            hoverinfo='text',
            hovertext=[f'Index: {idx}<br>Class: {class_idx + 1}' for idx in indices]
        ))

    fig.update_layout(
        title=title,
        xaxis_title='Activation Index',
        yaxis_title='Class',
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(num_classes)),
            ticktext=[f'Class {i+1}' for i in range(num_classes)]
        ),
        hovermode='closest',
        width=1500,
        height=800,
        plot_bgcolor='white',
        showlegend=False
    )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')

    # Add color bar: an invisible trace whose only job is to show the scale
    # that the edge colors were sampled from.
    fig.add_trace(go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(
            colorscale=color_scale,
            showscale=True,
            cmin=1,
            cmax=num_classes,
            colorbar=dict(
                title='Number of Classes',
                tickvals=list(range(1, num_classes+1)),
                ticktext=list(range(1, num_classes+1))
            )
        ),
        hoverinfo='none',
        showlegend=False
    ))

    return fig

# Collect, for each class, the indices whose mask entry is 0.
max_indices_per_class = []
std_indices_per_class = []

for fc1 in all_fc_vals:
    # Other cells unpack six masks from compute_masks; slicing keeps this cell
    # working whether it returns two masks or more.
    mask_max, mask_std = compute_masks(fc1, 0.15)[:2]
    max_indices_per_class.append(np.where(mask_max == 0)[0])
    std_indices_per_class.append(np.where(mask_std == 0)[0])

# Create the two visualizations, each rendered into its own Output widget.
fig_max = create_index_tracking_plot(max_indices_per_class, 'Max Mask Indices Across Classes')
fig_std = create_index_tracking_plot(std_indices_per_class, 'Std Mask Indices Across Classes')

output_widgets = []
for tracking_fig in (fig_max, fig_std):
    out = Output()
    with out:
        display(tracking_fig)
    output_widgets.append(out)

# Display all visualizations
# display(VBox(output_widgets))

In [None]:
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import torch
from prettytable import PrettyTable
# from model_distill_bert import getmodel
from utilities import compute_accuracy, compute_masks, mask_distillbert, get_model_distilbert, record_activations

# --- Experiment configuration ------------------------------------------------
batch_size = 256
mask_layer = 5
text_tag = "text"
compliment = True   # also evaluate every model on the examples NOT in the class
checkpoint = "2O24dpower2024/distilbert-base-uncased-finetuned-emotion"

results_table = PrettyTable()
if compliment:
    results_table.field_names = ["Class", "Base Accuracy", "Base Confidence", "Base Complement Acc", "Base Compliment Conf", "STD Accuracy", "STD Confidence", "STD compliment ACC", "STD compliment Conf", "MAX Accuracy", "MAX Confidence", "Max compliment acc", "Max compliment conf", "Total Masked", "Intersedction"]

# Per-class metric columns; position j belongs to loop iteration j.
class_labels = []
base_accuracies = []
base_confidences = []
base_comp_acc = []
base_comp_conf = []
std_masked_counts = []
std_accuracies = []
std_confidences = []
std_comp_acc = []
std_comp_conf = []
max_masked_counts = []
max_accuracies = []
max_confidences = []
max_comp_acc = []
max_comp_conf = []
diff_from_max = []
total_masked = []

# Third element of every class-subset compute_accuracy result, three entries
# per iteration in the order: base, STD-masked, MAX-masked.
dataset_list = []
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Check if a GPU is available and use it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Load the dataset
dataset_all = load_dataset("dair-ai/emotion")
dataset_all = dataset_all['train']

for j in range(0,7):
    # A fresh model per iteration so masks from one class never leak into the next.
    model = get_model_distilbert(checkpoint, mask_layer)
    dataset = dataset_all.filter(lambda x: x['label'] in [j])
    dataset_complement = dataset_all.filter(lambda x: x['label'] not in [j])

    # The extra iteration j == 6 runs on the whole training split.
    if(j==6):
        dataset = dataset_all

    class_labels.append(f"Class {j}")
    acc = compute_accuracy(dataset, model, tokenizer, text_tag, batch_size=batch_size)
    dataset_list.append(acc[2])
    print("Class ",j, "base accuracy: ", acc[0], acc[1])
    base_accuracies.append(acc[0])
    base_confidences.append(acc[1])
    # Copy before extending: previously aug_dataset was the very list stored in
    # dataset_list, so the complement records appended below silently ended up
    # in the saved base entry as well (unlike the STD/MAX entries).
    aug_dataset = list(acc[2])
    if(compliment):
        acc = compute_accuracy(dataset_complement, model, tokenizer, text_tag , batch_size=batch_size)
        print("Class ",j, "complement base accuracy: ", acc[0], acc[1])
        base_comp_acc.append(acc[0])
        base_comp_conf.append(acc[1])
        aug_dataset.extend(acc[2])

    # aug_dataset layout: [class records..., complement records...]; the slices
    # passed as in_aug_dataset below split it at len(dataset).

    print("Recording activations...")
    fc_vals = record_activations(dataset, model, tokenizer, text_tag=text_tag, mask_layer=mask_layer, batch_size=batch_size)

    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max = compute_masks(fc_vals,0.50)
    # The "STD" experiment uses the std-high-max mask, not the plain std mask.
    mask_std = mask_std_high_max
    print("Masking STD...")
    model = mask_distillbert(model,mask_std)
    # Masked units are the zero entries of the mask.
    t = int(mask_std.shape[0]-torch.count_nonzero(mask_std))
    print("Total Masked :", t)
    total_masked.append(t)
    # Units that are zero in BOTH masks, i.e. masked by STD and by MAX.
    diff_from_max.append(int((torch.logical_or(mask_std, mask_max) == 0).sum().item()))
    acc = compute_accuracy(dataset, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[:len(dataset)])
    dataset_list.append(acc[2])
    print("accuracy after masking STD: ", acc[0], acc[1])
    std_accuracies.append(acc[0])
    std_confidences.append(acc[1])
    if(compliment):
        acc = compute_accuracy(dataset_complement, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[len(dataset):])
        print("accuracy after masking STD on complement: ", acc[0], acc[1])
        std_comp_acc.append(acc[0])
        std_comp_conf.append(acc[1])

    print("Masking MAX...")
    # NOTE(review): the model passed here already went through mask_distillbert
    # with mask_std; whether this call replaces that mask or stacks on top of
    # it depends on mask_distillbert — confirm.
    model = mask_distillbert(model,mask_max)
    t = int(mask_max.shape[0]-torch.count_nonzero(mask_max))
    print("Total Masked :", t)
    acc = compute_accuracy(dataset, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[:len(dataset)])
    dataset_list.append(acc[2])
    print("accuracy after masking MAX: ", acc[0], acc[1])
    max_accuracies.append(acc[0])
    max_confidences.append(acc[1])
    if(compliment):
        # Guarded like the base/STD complement runs; it used to run even with
        # compliment disabled.
        acc = compute_accuracy(dataset_complement, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[len(dataset):])
        print("accuracy after masking MAX on complement: ", acc[0], acc[1])
        max_comp_acc.append(acc[0])
        max_comp_conf.append(acc[1])
        results_table.add_row([
            class_labels[j],
            base_accuracies[j],
            base_confidences[j],
            base_comp_acc[j],
            base_comp_conf[j],
            std_accuracies[j],
            std_confidences[j],
            std_comp_acc[j],
            std_comp_conf[j],
            max_accuracies[j],
            max_confidences[j],
            max_comp_acc[j],
            max_comp_conf[j],
            total_masked[j],
            diff_from_max[j]
        ])

print(results_table)

In [None]:
import ipywidgets as widgets
from IPython.display import display, HTML

# Assuming you already have your DataFrame
# df = pd.DataFrame(dataset_list[1])

def display_df(dataframe, rows_per_page=10):
    """Render a DataFrame as a Bootstrap-styled HTML table with Previous/Next paging.

    Every call gets its own element ids and its own script scope, so several
    tables can be displayed in the same notebook without their pagers
    controlling each other's rows.

    Args:
        dataframe: object providing ``to_html`` and ``len`` (a pandas DataFrame).
        rows_per_page: number of table rows shown per page; must be >= 1.

    Raises:
        ValueError: if ``rows_per_page`` is smaller than 1.
    """
    import uuid  # local import: only needed to build a unique id per call

    if rows_per_page < 1:
        raise ValueError("rows_per_page must be at least 1")

    widget_id = f"df-pager-{uuid.uuid4().hex}"

    # Convert the dataframe to HTML
    table_html = dataframe.to_html(classes='table table-striped')

    # Add Bootstrap CSS; the container id scopes the row lookup in the script.
    html = f"""
    <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css">
    <div class="container" id="{widget_id}">
        {table_html}
    </div>
    """

    # Add pagination; an empty frame still counts as one (empty) page so the
    # label never reads "1 / 0".
    total_rows = len(dataframe)
    total_pages = max(1, (total_rows - 1) // rows_per_page + 1)

    pagination_html = f"""
    <nav>
        <ul class="pagination justify-content-center">
            <li class="page-item"><a class="page-link" href="#" id="{widget_id}-prev">Previous</a></li>
            <li class="page-item"><span class="page-link" id="{widget_id}-current">1 / {total_pages}</span></li>
            <li class="page-item"><a class="page-link" href="#" id="{widget_id}-next">Next</a></li>
        </ul>
    </nav>
    """

    # Add JavaScript for pagination functionality. The function wrapper keeps
    # currentPage & co. out of the page's global scope.
    js = f"""
    <script>
    (function() {{
        var currentPage = 1;
        var rowsPerPage = {rows_per_page};
        var totalPages = {total_pages};
        var container = document.getElementById('{widget_id}');
        var pageLabel = document.getElementById('{widget_id}-current');

        function showPage(page) {{
            var rows = container.querySelectorAll('table.table tbody tr');
            for (var i = 0; i < rows.length; i++) {{
                if (i >= (page - 1) * rowsPerPage && i < page * rowsPerPage) {{
                    rows[i].style.display = '';
                }} else {{
                    rows[i].style.display = 'none';
                }}
            }}
            pageLabel.textContent = page + ' / ' + totalPages;
        }}

        document.getElementById('{widget_id}-prev').addEventListener('click', function(e) {{
            e.preventDefault();
            if (currentPage > 1) {{
                currentPage--;
                showPage(currentPage);
            }}
        }});

        document.getElementById('{widget_id}-next').addEventListener('click', function(e) {{
            e.preventDefault();
            if (currentPage < totalPages) {{
                currentPage++;
                showPage(currentPage);
            }}
        }});

        showPage(1);
    }})();
    </script>
    """

    # Combine all HTML and JavaScript
    full_html = html + pagination_html + js

    # Display the result
    display(HTML(full_html))

In [None]:
import pandas as pd

# dataset_list receives three entries per loop iteration of the experiment
# cell (base, STD-masked, MAX-masked), so index 5 is the MAX-masked entry of
# the second iteration.
selected_records = dataset_list[5]
df = pd.DataFrame(selected_records)

In [None]:
# Render `df` as a paginated HTML table using display_df's default page size.
# Uncomment the next line to list the rows in reverse order instead.
# df = df.iloc[::-1]
display_df(df)

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Load the esuriddick emotion classifier; note that this rebinds `model`.
emotion_checkpoint = "esuriddick/distilbert-base-uncased-finetuned-emotion"
model = AutoModelForSequenceClassification.from_pretrained(emotion_checkpoint)

In [2]:
# Pasted results, one row per class. Each row is whitespace-separated:
# "Class <id>", then twelve metric values, then two trailing integer columns.
# format_latex_row keeps the twelve metrics and drops the two trailing columns.
# NOTE(review): the column order presumably follows the results table header of
# the masking experiment — confirm before reusing the numbers.
data = '''Class 0	0.9694	0.9559	0.913	0.9032	0.6781	0.6527	0.8445	0.8386	0.9828	0.9139	0.8465	0.7929	225	225
Class 1	0.9385	0.9382	0.9248	0.9085	0.8778	0.873	0.9282	0.9234	0.9409	0.9255	0.9101	0.8821	222	222
Class 2	0.9021	0.8718	0.9318	0.9226	0.7902	0.786	0.7619	0.7582	0.7028	0.6849	0.9025	0.8274	204	204
Class 3	0.9097	0.9055	0.9325	0.9206	0.4702	0.4671	0.9358	0.9282	0.8912	0.8946	0.93	0.9007	227	227
Class 4	0.869	0.8537	0.9378	0.9274	0.7471	0.7041	0.9017	0.8963	0.8506	0.8511	0.9248	0.8871	234	234
Class 5	0.8571	0.7979	0.9321	0.9229	0.254	0.2088	0.9036	0.8932	0.9603	0.8471	0.9174	0.8779	239	239'''

def format_latex_row(line):
    """Turn one whitespace-separated results row into a LaTeX table row.

    The first two fields form the row label (e.g. ``Class 0``). The last two
    fields are dropped (in the pasted data they are integer counts), and every
    field in between is rendered with three decimals and without a leading
    zero (``0.9694`` -> ``.969``).

    Args:
        line: a single row of text, fields separated by tabs and/or spaces.

    Returns:
        ``"<label> & v1 & v2 & ... \\\\"`` as a string.

    Raises:
        IndexError: if the line has fewer than two fields (e.g. a blank line).
        ValueError: if a value field cannot be parsed as a float.
    """
    parts = line.split()

    # Row label is made of the first two fields.
    class_name = f"{parts[0]} {parts[1]}"

    formatted_values = []
    for val in parts[2:-2]:  # skip the label and the two trailing columns
        text = f"{float(val):.3f}"
        # Strip only a *leading* zero. A blanket replace('0.', '.') also hits
        # the zero in front of the decimal point of values such as 10.5 or
        # 100.0 and silently turns them into 1.500 / 10.000.
        if text.startswith("0."):
            text = text[1:]
        elif text.startswith("-0."):
            text = "-" + text[2:]
        formatted_values.append(text)

    return f"{class_name} & " + " & ".join(formatted_values) + " \\\\"

# Emit one LaTeX row for every line of the pasted results block.
for line in data.split('\n'):
    latex_row = format_latex_row(line)
    print(latex_row)

Class 0 & .969 & .956 & .913 & .903 & .678 & .653 & .845 & .839 & .983 & .914 & .847 & .793 \\
Class 1 & .939 & .938 & .925 & .908 & .878 & .873 & .928 & .923 & .941 & .925 & .910 & .882 \\
Class 2 & .902 & .872 & .932 & .923 & .790 & .786 & .762 & .758 & .703 & .685 & .902 & .827 \\
Class 3 & .910 & .905 & .932 & .921 & .470 & .467 & .936 & .928 & .891 & .895 & .930 & .901 \\
Class 4 & .869 & .854 & .938 & .927 & .747 & .704 & .902 & .896 & .851 & .851 & .925 & .887 \\
Class 5 & .857 & .798 & .932 & .923 & .254 & .209 & .904 & .893 & .960 & .847 & .917 & .878 \\


In [1]:
from datasets import load_dataset, concatenate_datasets

from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling
import random
import numpy as np
from utilities import evaluate_gpt2_classification as evaluate_gpt2_classification, mask_range_gpt,compute_masks, reset_gpt, compute_mask_probe
import torch  
from tqdm import tqdm

# --- Experiment configuration ---------------------------------------------
dataset_name = "dair-ai/emotion"   # dataset identifier handed to load_dataset
text_tag, lab = "text", "label"    # column names for the input text and the class id
layer = 11                         # copied onto the model config as `m_layer` further down
num_classes = 6
per = 0.3
tables = []                        # result tables get appended here by later cells

print("Percentage: ", per)

# --- Data -------------------------------------------------------------------
dataset = load_dataset(dataset_name)
print(dataset)
# Seed every random number generator used in this notebook so runs are repeatable.
seed_value = 42  # or any other integer

random.seed(seed_value)
np.random.seed(seed_value)

# Seed torch unconditionally: previously this call sat inside the CUDA branch,
# which left the torch generator unseeded on CPU-only machines.
torch.manual_seed(seed_value)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)

# Autograd anomaly detection is switched on globally.
# NOTE(review): this is a debugging aid that slows autograd; confirm it is
# still wanted for the evaluation-only code in this cell.
torch.autograd.set_detect_anomaly(True)
# Tokenizer: the stock GPT-2 vocabulary, extended with the class names and with
# the pad / sep / eos markers used by the prompt format.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

special_tokens_dict = {}
label2text = dataset['train'].features[lab].names

# Report the class names that the tokenizer already encodes as a single token.
for class_name in label2text:
    piece_ids = tokenizer.encode(class_name, add_special_tokens=False)
    if len(piece_ids) == 1:
        print(f"'{class_name}' is already a single token (ID: {piece_ids[0]})")

# Every class name is submitted to add_tokens (one entry per label); the
# count it returns is what gets reported below.
new_tokens = [f'{class_name}' for class_name in label2text]
num_added_tokens = tokenizer.add_tokens(new_tokens)
print(f"\nAdded {num_added_tokens} new tokens to the tokenizer")

# Dedicated padding, separator and end-of-sequence markers.
special_tokens = {
    'pad_token': '<|pad|>',
    'sep_token': '<|sep|>',
    'eos_token': '<|eos|>'
}
tokenizer.add_special_tokens(special_tokens)

def format_data(examples):
    """Build the classification prompt for every example in a batch.

    Written for ``dataset.map(format_data, batched=True)``. Reads the
    notebook-level ``tokenizer`` and ``text_tag``.

    Args:
        examples: batch mapping; ``examples[text_tag]`` is iterated as the
            raw input texts.

    Returns:
        A dict with the single key ``'formatted_text'`` holding one prompt
        string per input text, each ending in the tokenizer's separator token.
    """
    formatted_texts = []
    for text in examples[text_tag]:
        # Encode with truncation and decode again so the text part of the
        # prompt is cut down to at most 400 tokens.
        token_ids = tokenizer.encode(text, max_length=400, truncation=True)
        clipped_text = tokenizer.decode(token_ids)

        # The prompt stops at the separator: the label string and the EOS
        # token are not appended (that suffix is disabled), so the label no
        # longer needs to be looked up per example here.
        formatted_texts.append(f"Classify emotion: {clipped_text}{tokenizer.sep_token}")
    return {'formatted_text': formatted_texts}

def tokenize_and_prepare(examples):
    """Tokenize a batch of formatted prompts and build the matching labels.

    Every label position is -100 except the position directly after a
    separator token, which keeps its input id — unless that id is the padding
    token, in which case it is set back to -100 as well.

    Args:
        examples: batch mapping with a ``'formatted_text'`` list of strings.

    Returns:
        dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
    """
    encoded = tokenizer(
        examples['formatted_text'],
        padding='max_length',
        max_length=408,
        truncation=True,
        return_tensors="pt"
    )
    input_ids = encoded['input_ids']

    # (row, column) coordinates of every separator token in the batch.
    sep_id = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)
    rows, cols = (input_ids == sep_id).nonzero(as_tuple=True)

    # A separator in the very last column has no following position to label.
    has_next = cols + 1 < input_ids.size(1)
    rows, target_cols = rows[has_next], cols[has_next] + 1

    # Start fully masked, then copy the ids that follow a separator.
    labels = input_ids.new_full(input_ids.shape, -100)
    labels[rows, target_cols] = input_ids[rows, target_cols]

    # Padding never contributes to the loss.
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        'input_ids': input_ids,
        'attention_mask': encoded['attention_mask'],
        'labels': labels
    }
# Two-stage preprocessing applied to the loaded dataset: build the prompt
# strings first, then tokenize them and attach labels.
formatted_dataset = dataset.map(format_data, batched=True)
tokenized_dataset = formatted_dataset.map(tokenize_and_prepare, batched=True)

import os

from transformers import GPT2LMHeadModel as gt
from models.gpt2 import GPT2LMHeadModel

# The stock Hugging Face GPT-2 is used here only as the source of the config
# for the project's own GPT2LMHeadModel: its embeddings are resized to the
# extended tokenizer and the layer of interest is recorded as `m_layer`.
model1 = gt.from_pretrained('gpt2')
model1.resize_token_embeddings(len(tokenizer))
model1.config.m_layer = layer

# Weights live under model_weights/<dataset_name>/weights.pth.
base_path = os.path.join("model_weights", dataset_name)
os.makedirs(base_path, exist_ok=True)  # no exists()/makedirs() race
weights_path = os.path.join(base_path, "weights.pth")

model = GPT2LMHeadModel(model1.config)

# map_location="cpu" lets weights saved from a GPU load on a CPU-only host;
# load_state_dict copies them into the model's own parameters either way.
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state dict needs, and avoids executing arbitrary pickled code.
model.load_state_dict(torch.load(weights_path, map_location="cpu", weights_only=True))




from prettytable import PrettyTable
import warnings

# Silence UserWarnings whose originating module name matches "torch.tensor".
warnings.filterwarnings("ignore", category=UserWarning, module="torch.tensor")

# Integer division: `2048/2` produced the float 1024.0, which is not a valid
# batch size for code that expects an int.
batch_size = 2048 // 2
compliment = True

# One row per class; the header is only configured when the complement
# metrics are being collected.
results_table = PrettyTable()
if compliment:
    results_table.field_names = ["Class", "Base Accuracy", "Base Confidence", "Base Complement Acc", "Base Compliment Conf", "STD Accuracy", "STD Confidence", "STD compliment ACC", "STD compliment Conf", "MAX Accuracy", "MAX Confidence", "Max compliment acc", "Max compliment conf", "Total Masked", "Intersection"]

# Per-class accumulators, filled by the evaluation loop in a later cell.
class_labels = []
base_accuracies = []
base_confidences = []
base_comp_acc = []
base_comp_conf = []
std_masked_counts = []
std_accuracies = []
std_confidences = []
std_comp_acc = []
std_comp_conf = []
max_masked_counts = []
max_accuracies = []
max_confidences = []
max_comp_acc = []
max_comp_conf = []
diff_from_max = []
total_masked = []

# Merge the train and test splits (the validation split is not included),
# shuffle with a fixed seed, then cut the result into two disjoint parts.

# NOTE: this rebinds `tokenized_dataset` from the per-split mapping to a single
# merged dataset, so this statement cannot be re-run on its own — the
# `['train']` / `['test']` lookups need the mapping produced by the map step.
tokenized_dataset = concatenate_datasets([tokenized_dataset['train'], tokenized_dataset['test']]).shuffle(seed=42)#.select(range(100))

# Number of examples after merging.
dataset_length = len(tokenized_dataset)


# Boundary between the two parts: the first 20% of the shuffled examples.
split_index = int(dataset_length * 0.2)  # 20% / 80% split point

# First 20%: `tokenized_dataset1`, the examples the later evaluation cell filters per class.
# Remaining 80%: `recording_dataset`, the examples used to record activations.
tokenized_dataset1 = tokenized_dataset.select(range(split_index))  # evaluation part (20%)
recording_dataset = tokenized_dataset.select(range(split_index, dataset_length))




import pickle

print("Recording activations...")
all_fc_vals = []
for j in range(num_classes):
    # Restrict the recording split to the examples of class j.
    dataset_recording = recording_dataset.filter(lambda x: x[lab] == j)
    # Only the third element of the evaluation result is kept per class
    # (presumably the recorded activations — confirm in `utilities`).
    fc_vals = evaluate_gpt2_classification(lab, model, dataset_recording, tokenizer)[2]
    all_fc_vals.append(fc_vals)

# Write the per-class values to disk ...
with open('all_fc_vals.pkl', 'wb') as f:
    pickle.dump(all_fc_vals, f)

# ... and read them straight back, so the rest of the notebook works on the
# round-tripped copy.
with open('all_fc_vals.pkl', 'rb') as f:
    all_fc_vals = pickle.load(f)

# from plot_correlation import create_correlation_plots

# create_correlation_plots(all_fc_vals, num_classes)

# Sweep the masking percentage from 5% to 100% in 5% steps and report, per
# class, how many units are zeroed by both the STD mask and the MAX mask.
# The sweep uses its own variable names so that the configured `per` and the
# `total_masked` accumulator list are not overwritten as a side effect.
for i in range(1, 21):
    sweep_per = 0.05 * i
    print(f"\nAnalysis for Percentage: {sweep_per:.2f}")

    # Create table for this percentage
    table = PrettyTable()
    table.field_names = ["Class", "Total Masked", "Common Zeros", "Common Zeros %"]

    for j in range(num_classes):
        # Only mask_max and mask_std are used below; the remaining masks are
        # unpacked to match compute_masks' return arity.
        mask_max, mask_std, mask_intersection, mask_max_low_std, \
        mask_max_high_std, mask_std_high_max, mask_max_random_off, \
        random_mask = compute_masks(all_fc_vals[j], sweep_per)

        # A zero entry in a mask marks a masked unit; count the units that
        # are zero in both masks.
        intersection_zeros = (~mask_std.bool()) & (~mask_max.bool())
        num_common_zeros = intersection_zeros.sum().item()
        num_masked = int(mask_std.shape[0] - torch.count_nonzero(mask_std))
        # Guard the ratio: a mask with no zeros would otherwise divide by zero.
        fraction_common_zeros = num_common_zeros / num_masked if num_masked else 0.0

        table.add_row([
            j,
            num_masked,
            num_common_zeros,
            f"{fraction_common_zeros:.4f}"
        ])

    print(table)
    print("-" * 50)  # Separator between tables


# Leave the accumulator as an empty list for the evaluation cell.
total_masked = []
# print("Training Probe...")

# from linear_probe import train_probe

# probe = train_probe(all_fc_vals,lambda_l1 = 1e-4, lambda_l2 = 1e-4)

# print("Probe Trained")

# probe.to('cpu')

# probe_weights = list(probe.parameters())[0]
# tokenized_dataset1 = tokenized_dataset['test']#.shuffle().select(range(200))
# recording_dataset = tokenized_dataset['train']#.shuffle().select(range(200))

        

Percentage:  0.3
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
})
'joy' is already a single token (ID: 2633)
'love' is already a single token (ID: 23205)
'anger' is already a single token (ID: 2564)

Added 3 new tokens to the tokenizer


  model.load_state_dict(torch.load(weights_path))


Recording activations...


Evaluating:   0%|          | 0/132 [00:00<?, ?it/s]

  input_ids = torch.tensor(item['input_ids']).to(device)
  attention_mask = torch.tensor(item['attention_mask']).to(device)


Evaluating:   0%|          | 0/152 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/37 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/61 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]


Analysis for Percentage: 0.05
+-------+--------------+--------------+----------------+
| Class | Total Masked | Common Zeros | Common Zeros % |
+-------+--------------+--------------+----------------+
|   0   |      38      |      12      |     0.3158     |
|   1   |      38      |      9       |     0.2368     |
|   2   |      38      |      8       |     0.2105     |
|   3   |      38      |      9       |     0.2368     |
|   4   |      38      |      10      |     0.2632     |
|   5   |      38      |      14      |     0.3684     |
+-------+--------------+--------------+----------------+
--------------------------------------------------

Analysis for Percentage: 0.10
+-------+--------------+--------------+----------------+
| Class | Total Masked | Common Zeros | Common Zeros % |
+-------+--------------+--------------+----------------+
|   0   |      76      |      18      |     0.2368     |
|   1   |      76      |      19      |     0.2500     |
|   2   |      76      |      23

In [29]:
# Fresh per-class accumulators for this run of the evaluation cell.
class_labels = []
base_accuracies = []
base_confidences = []
base_comp_acc = []
base_comp_conf = []
std_masked_counts = []
std_accuracies = []
std_confidences = []
std_comp_acc = []
std_comp_conf = []
max_masked_counts = []
max_accuracies = []
max_confidences = []
max_comp_acc = []
max_comp_conf = []
diff_from_max = []
total_masked = []

# Start from an empty table that keeps the configured header. Without this,
# every re-run of the cell (including interrupted ones) keeps appending rows
# to the same table object. Working on a copy also leaves tables that were
# already stored in `tables` untouched.
results_table = results_table.copy()
results_table.clear_rows()
# Per-class ablation. For every class: measure the unmasked baseline, then mask
# only the units zeroed by both the STD and the MAX criterion, then mask an
# equally sized set of the highest-mean units outside that overlap. Each
# setting is evaluated on the class itself and on all other classes.
for j in range(num_classes):  # every class, like the activation-recording loop

    fc_vals = all_fc_vals[j]
    model = reset_gpt(model)  # presumably restores the unmasked model — see utilities
    dataset = tokenized_dataset1.filter(lambda x: x[lab] in [j])
    dataset_complement = tokenized_dataset1.filter(lambda x: x[lab] not in [j])

    # --- Baseline (no masking) ---------------------------------------------
    class_labels.append(f"Class {j}")
    acc = evaluate_gpt2_classification(lab, model, dataset, tokenizer)
    print("Class ",j, "base accuracy: ", acc[0], acc[1])
    base_accuracies.append(acc[0])
    base_confidences.append(acc[1])
    if(compliment):
        acc = evaluate_gpt2_classification(lab, model, dataset_complement, tokenizer)
        print("Class ",j, "complement base accuracy: ", acc[0], acc[1])
        base_comp_acc.append(acc[0])
        base_comp_conf.append(acc[1])

    # --- Build the two masks (0 = masked unit, non-zero = kept) --------------
    # 0.3 is the fraction handed to compute_masks; it is a literal here rather
    # than being read from the notebook-level `per`.
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max, mask_max_random_off, random_mask = compute_masks(fc_vals,0.3)
    # mask_probe = compute_mask_probe(probe_weights[j], 0.3)
    print("Masking STD...")

    # Units zeroed by both masks.
    intersection_zeros = (~mask_std.bool()) & (~mask_max.bool())
    num_common_zeros = intersection_zeros.sum().item()

    # The STD mask is narrowed to that overlap only (in place).
    mask_std[~intersection_zeros] = 1

    # The MAX mask is rebuilt as the `num_common_zeros` units with the highest
    # mean over axis 0 of fc_vals, excluding the overlap (forced to -inf so it
    # sorts last). Both masks therefore zero the same number of units.
    intersection_indices = torch.where(intersection_zeros)
    fc_vals2 = torch.tensor(fc_vals)
    mean = torch.mean(fc_vals2, axis=0)
    mean[intersection_indices] = -torch.inf
    mask_max_indices = torch.argsort(mean, descending=True)[:num_common_zeros]
    mask_max = torch.ones_like(mask_max)
    mask_max[mask_max_indices] = 0.0

    print("Number of non zeros: ", torch.count_nonzero(mask_max))
    print("Number of common zeros: ", num_common_zeros)

    diff_from_max.append(num_common_zeros)

    # --- STD mask --------------------------------------------------------------
    tao = torch.inf#2.5
    model = mask_range_gpt(model, mask_std, fc_vals, tao)
    t = int(mask_std.shape[0]-torch.count_nonzero(mask_std))
    print("Total Masked :", t)
    total_masked.append(t)

    acc = evaluate_gpt2_classification(lab, model, dataset, tokenizer)
    print("accuracy after masking STD: ", acc[0], acc[1])
    std_accuracies.append(acc[0])
    std_confidences.append(acc[1])
    if(compliment):
        acc = evaluate_gpt2_classification(lab, model, dataset_complement, tokenizer)
        print("accuracy after masking STD on complement: ", acc[0], acc[1])
        std_comp_acc.append(acc[0])
        std_comp_conf.append(acc[1])

    # --- MAX mask --------------------------------------------------------------
    model = reset_gpt(model)
    print("Masking MAX...")
    model = mask_range_gpt(model, mask_max, fc_vals, tao)
    t = int(mask_max.shape[0]-torch.count_nonzero(mask_max))
    print("Total Masked :", t)
    acc = evaluate_gpt2_classification(lab, model, dataset, tokenizer)
    print("accuracy after masking MAX: ", acc[0], acc[1])
    max_accuracies.append(acc[0])
    max_confidences.append(acc[1])
    # The MAX complement is always evaluated: the summary after the loop
    # averages these lists regardless of `compliment`.
    acc = evaluate_gpt2_classification(lab, model, dataset_complement, tokenizer)
    print("accuracy after masking MAX on complement: ", acc[0], acc[1])
    max_comp_acc.append(acc[0])
    max_comp_conf.append(acc[1])

    # --- Table row: the values just appended for this class ---------------------
    if(compliment):
        results_table.add_row([
            class_labels[-1],
            base_accuracies[-1],
            base_confidences[-1],
            base_comp_acc[-1],
            base_comp_conf[-1],
            std_accuracies[-1],
            std_confidences[-1],
            std_comp_acc[-1],
            std_comp_conf[-1],
            max_accuracies[-1],
            max_confidences[-1],
            max_comp_acc[-1],
            max_comp_conf[-1],
            total_masked[-1],
            diff_from_max[-1]
        ])
# Show the per-class table and keep it alongside the tables of earlier runs.
print(results_table)
tables.append(results_table)

# Mean of each collected metric over the evaluated classes, to four decimals.
for summary_label, summary_values in (
    ("Average Base Accuracy: ", base_accuracies),
    ("Average Base Confidence: ", base_confidences),
    ("Average MAX Accuracy: ", max_accuracies),
    ("Average MAX Confidence: ", max_confidences),
    ("Average MAX Complement Accuracy: ", max_comp_acc),
    ("Average MAX Complement Confidence: ", max_comp_conf),
):
    print(summary_label, round(sum(summary_values) / len(summary_values), 4))

# for table in tables:
# print(table)
# print("\n")

Evaluating:   0%|          | 0/33 [00:00<?, ?it/s]

Class  0 base accuracy:  0.9694 0.9559


Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Class  0 complement base accuracy:  0.913 0.9032
Masking Probe...


KeyboardInterrupt: 

In [25]:
total_masked

[82, 87, 84]

In [21]:
std_comp_conf

[0.8284, 0.9164, 0.9289]

In [22]:
max_confidences

[0.8324, 0.7941, 0.7012]

In [23]:
max_comp_conf

[0.9016, 0.8985, 0.9089]