In [None]:
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from datasets import load_dataset
from tqdm import tqdm
import torch
from utilities import get_model_distilbert, record_activations

mask_layer = 5
text_tag = "sentence"
compliment = True

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
# Check if a GPU is available and use it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the dataset
dataset_all = load_dataset("stanfordnlp/sst2")
# Select the train split
dataset_all = dataset_all['train']
model = get_model_distilbert("distilbert-base-uncased-finetuned-sst-2-english", mask_layer)
model.to(device)
model.eval()
all_fc_vals = []
batch_size = 32  # You can adjust this based on your GPU memory
for j in range(0,2):
    dataset = dataset_all.filter(lambda x: x['label'] in [j])
    fc_vals = record_activations(dataset, model, tokenizer, text_tag='sentence', batch_size=256, mask_layer=mask_layer)
    all_fc_vals.append(fc_vals)

In [None]:
import numpy as np
from utilities import compute_masks

output_widgets = []
j = 10
max_all = []
for i, v in enumerate(all_fc_vals):
    v = np.array(v)
    m = np.mean(np.abs(v), axis=0)
    s = np.std(v, axis=0)
    mini = np.min(v, axis=0)
    maxi = np.max(v, axis=0)
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max = compute_masks(v, 0.30)
    max_all.append(mask_max)


In [None]:
#print number of differences between the two classes
print(torch.sum(max_all[0]!=max_all[1]))
print(max_all[0]==max_all[1])

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import VBox, Output

output_widgets = []

for i, v in enumerate(all_fc_vals):
    v = np.array(v)
    m = np.mean(np.abs(v), axis=0)
    s = np.std(v, axis=0)
    
    
    min_val = np.min(v, axis=0)
    max_val = np.max(v, axis=0)    
    
    s = (s-min_val) / (max_val - min_val)
    # m = (m-min_val) / (max_val - min_val)

    # Create a new figure for each set of values
    out = Output()
    with out:
        plt.figure(figsize=(12, 6))

        # Plot the mean
        plt.subplot(1, 2, 1)
        plt.plot(m, 'bo', markersize=4)
        plt.title(f'Mean of Activations - Set {i+1}')
        plt.xlabel('Activation Index')
        plt.ylabel('Mean Value')
        plt.ylim(0, np.max(m))  # Ensure y-axis starts at 0

        # Plot the standard deviation
        plt.subplot(1, 2, 2)
        plt.plot(s, 'ro', markersize=4)
        plt.title(f'Standard Deviation of Activations - Set {i+1}')
        plt.xlabel('Activation Index')
        plt.ylabel('Standard Deviation')
        plt.ylim(0, np.max(s))  # Ensure y-axis starts at 0

        # Show the plots
        plt.tight_layout()
        plt.show()

    output_widgets.append(out)

# Display all figures in a vertical box
VBox(output_widgets)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output

output_widgets = []

for i, v in enumerate(all_fc_vals):
    v = np.array(v)
    m = np.mean(np.abs(v), axis=0)
    s = np.std(v, axis=0)
    
    min_val = np.min(v, axis=0)
    max_val = np.max(v, axis=0)
    s = (s-min_val) / (max_val - min_val)
    # m = (m-min_val) / (max_val - min_val)
    # Create a new figure for each set of values
    out = Output()
    with out:
        fig = go.Figure()

        # Plot the mean with markers
        fig.add_trace(go.Scatter(
            x=list(range(768)),
            y=m,
            mode='markers',
            name='Mean',
            marker=dict(size=3, color='blue')
        ))

        # Plot the standard deviation with markers
        fig.add_trace(go.Scatter(
            x=list(range(768)),
            y=s,
            mode='markers',
            name='Std Dev',
            marker=dict(size=3, color='red')
        ))

        # Add lines connecting corresponding points
        for j in range(768):
            fig.add_trace(go.Scatter(
                x=[j, j],
                y=[m[j], s[j]],
                mode='lines',
                line=dict(color='gray', width=0.5),
                showlegend=False
            ))

        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Set {i+1}',
            xaxis_title='Activation Index',
            yaxis_title='Value',
            yaxis=dict(range=[0, max(np.max(m), np.max(s)) * 1.1]),
            height=600,
            width=1000,
            hovermode='closest'
        )

        fig.show()
    
    output_widgets.append(out)

# Display all figures in a vertical box
# VBox(output_widgets)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output
import importlib
import utilities

importlib.reload(utilities)
from utilities import compute_masks
from IPython.display import display

output_widgets = []

for i, fc1 in enumerate(all_fc_vals):
    fc1 = np.array(fc1)
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_std_high_max = compute_masks(fc1, 0.15)
    mask_std = mask_std_high_max
    
    m = np.mean(np.abs(fc1), axis=0)
    s = np.std(fc1, axis=0)
    min_val = np.min(fc1, axis=0)
    max_val = np.max(fc1, axis=0)
    
    # Normalize std and mean
    s_norm = (s - min_val) / (max_val - min_val)
    m_norm = m#(m - min_val) / (max_val - min_val)
    
    # Create indices for different masks
    indices_max = np.where(mask_max == 0)[0]
    indices_std = np.where(mask_std == 0)[0]
    indices_intersection = np.intersect1d(indices_max, indices_std)
    indices_max_minus_std = np.setdiff1d(indices_max, indices_std)
    indices_std_minus_max = np.setdiff1d(indices_std, indices_max)
    
    # Count the indices in each set
    count_max = len(indices_max)
    count_std = len(indices_std)
    count_intersection = len(indices_intersection)
    count_max_minus_std = len(indices_max_minus_std)
    count_std_minus_max = len(indices_std_minus_max)
    
    out = Output()
    with out:
        # Create subplots with counts in titles
        fig = make_subplots(rows=2, cols=3, 
                            subplot_titles=(f"Max Mask (Count: {count_max})", 
                                            f"Std Mask (Count: {count_std})", 
                                            f"Intersection (Count: {count_intersection})",
                                            f"Max - Std (Count: {count_max_minus_std})", 
                                            f"Std - Max (Count: {count_std_minus_max})"))
        
        # Helper function to add traces
        def add_traces(indices, row, col):
            fig.add_trace(
                go.Scatter(
                    x=indices,
                    y=m_norm[indices],
                    mode='markers',
                    name='Mean',
                    marker=dict(size=3, color='blue'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter(
                    x=indices,
                    y=s_norm[indices],
                    mode='markers',
                    name='Std Dev',
                    marker=dict(size=3, color='red'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            for j in indices:
                fig.add_trace(
                    go.Scatter(
                        x=[j, j],
                        y=[m_norm[j], s_norm[j]],
                        mode='lines',
                        line=dict(color='gray', width=0.5),
                        showlegend=False
                    ),
                    row=row, col=col
                )
        
        # Add traces for all plots
        add_traces(indices_max, 1, 1)
        add_traces(indices_std, 1, 2)
        add_traces(indices_intersection, 1, 3)
        add_traces(indices_max_minus_std, 2, 1)
        add_traces(indices_std_minus_max, 2, 2)
        
        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Class {i+1}',
            height=1200,
            width=1800,
            hovermode='closest'
        )
        
        # Update x and y axis labels for all subplots
        for row in range(1, 3):
            for col in range(1, 4):
                if row == 2 and col == 3:
                    continue  # Skip the empty subplot
                fig.update_xaxes(title_text="Activation Index", row=row, col=col)
                fig.update_yaxes(title_text="Normalized Value", row=row, col=col)
        
        display(fig)
    
    output_widgets.append(out)

# Display all figures in a vertical box
# display(VBox(output_widgets))

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

output_widgets = []

for i, fc1 in enumerate(all_fc_vals):
    fc1 = np.array(fc1)
    mask_max, mask_std = compute_masks(fc1, 0.15)
    
    m = np.mean(np.abs(fc1), axis=0)
    s = np.std(fc1, axis=0)
    min_val = np.min(fc1, axis=0)
    max_val = np.max(fc1, axis=0)
    
    # Normalize std and mean
    s_norm = (s - min_val) / (max_val - min_val)
    m_norm = m#(m - min_val) / (max_val - min_val)
    
    # Create indices for different masks
    indices_max = np.where(mask_max == 0)[0]
    indices_std = np.where(mask_std == 0)[0]
    indices_intersection = np.intersect1d(indices_max, indices_std)
    indices_max_minus_std = np.setdiff1d(indices_max, indices_std)
    indices_std_minus_max = np.setdiff1d(indices_std, indices_max)
    
    # Count the indices in each set
    count_all = len(m_norm)
    count_max = len(indices_max)
    count_std = len(indices_std)
    count_intersection = len(indices_intersection)
    count_max_minus_std = len(indices_max_minus_std)
    count_std_minus_max = len(indices_std_minus_max)
    
    out = Output()
    with out:
        # Create subplots with counts in titles
        fig = make_subplots(rows=2, cols=3, 
                            subplot_titles=(f"All Activations (Count: {count_all})",
                                            f"Max Mask (Count: {count_max})", 
                                            f"Std Mask (Count: {count_std})", 
                                            f"Intersection (Count: {count_intersection})",
                                            f"Max - Std (Count: {count_max_minus_std})", 
                                            f"Std - Max (Count: {count_std_minus_max})"))
        
        # Helper function to add traces
        def add_traces(indices, row, col):
            indices_list = list(indices)  # Convert range or numpy array to list
            fig.add_trace(
                go.Scatter(
                    x=indices_list,
                    y=m_norm[indices_list],
                    mode='markers',
                    name='Mean',
                    marker=dict(size=3, color='blue'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter(
                    x=indices_list,
                    y=s_norm[indices_list],
                    mode='markers',
                    name='Std Dev',
                    marker=dict(size=3, color='red'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            for j in indices_list:
                fig.add_trace(
                    go.Scatter(
                        x=[j, j],
                        y=[m_norm[j], s_norm[j]],
                        mode='lines',
                        line=dict(color='gray', width=0.5),
                        showlegend=False
                    ),
                    row=row, col=col
                )
        
        # Add traces for all activations
        add_traces(range(len(m_norm)), 1, 1)
        
        # Add traces for other plots
        add_traces(indices_max, 1, 2)
        add_traces(indices_std, 1, 3)
        add_traces(indices_intersection, 2, 1)
        add_traces(indices_max_minus_std, 2, 2)
        add_traces(indices_std_minus_max, 2, 3)
        
        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Class {i+1}',
            height=1200,
            width=1800,
            hovermode='closest'
        )
        
        # Update x and y axis labels for all subplots
        for row in range(1, 3):
            for col in range(1, 4):
                fig.update_xaxes(title_text="Activation Index", row=row, col=col)
                fig.update_yaxes(title_text="Normalized Value", row=row, col=col)
        
        display(fig)
    
    output_widgets.append(out)

# Display all figures in a vertical box
# display(VBox(output_widgets))

In [None]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

def create_index_tracking_plot(indices_per_class, title):
    num_classes = len(indices_per_class)
    all_indices = sorted(set.union(*[set(indices) for indices in indices_per_class]))
    
    fig = go.Figure()
    
    # Create a color scale
    color_scale = px.colors.diverging.RdYlGn_r  # Red to Yellow to Green color scale

    # Add edges for indices present in multiple classes
    for idx in all_indices:
        classes_with_idx = [i for i, indices in enumerate(indices_per_class) if idx in indices]
        if len(classes_with_idx) > 1:
            x = [idx] * len(classes_with_idx)
            y = classes_with_idx
            color_index = (len(classes_with_idx) - 1) / (num_classes - 1)  # Normalize to [0, 1]
            edge_color = px.colors.sample_colorscale(color_scale, [color_index])[0]
            
            fig.add_trace(go.Scatter(
                x=x,
                y=y,
                mode='lines',
                line=dict(color=edge_color, width=2),
                hoverinfo='text',
                hovertext=f'Index: {idx}<br>Present in {len(classes_with_idx)} classes',
                showlegend=False
            ))
    
    # Add scatter plots for each class
    for class_idx, indices in enumerate(indices_per_class):
        fig.add_trace(go.Scatter(
            x=indices,
            y=[class_idx] * len(indices),
            mode='markers',
            name=f'Class {class_idx + 1}',
            marker=dict(size=4, symbol='circle', color='black'),
            hoverinfo='text',
            hovertext=[f'Index: {idx}<br>Class: {class_idx + 1}' for idx in indices]
        ))
    
    fig.update_layout(
        title=title,
        xaxis_title='Activation Index',
        yaxis_title='Class',
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(num_classes)),
            ticktext=[f'Class {i+1}' for i in range(num_classes)]
        ),
        hovermode='closest',
        width=1500,
        height=800,
        plot_bgcolor='white',
        showlegend=False
    )
    
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    
    # Add color bar
    fig.add_trace(go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(
            colorscale=color_scale,
            showscale=True,
            cmin=1,
            cmax=num_classes,
            colorbar=dict(
                title='Number of Classes',
                tickvals=list(range(1, num_classes+1)),
                ticktext=list(range(1, num_classes+1))
            )
        ),
        hoverinfo='none',
        showlegend=False
    ))
    
    return fig

# Collect indices for each class
max_indices_per_class = []
std_indices_per_class = []

for fc1 in all_fc_vals:
    mask_max, mask_std = compute_masks(fc1, 0.15)
    max_indices_per_class.append(np.where(mask_max == 0)[0])
    std_indices_per_class.append(np.where(mask_std == 0)[0])

# Create and display visualizations
output_widgets = []

out = Output()
with out:
    fig_max = create_index_tracking_plot(max_indices_per_class, 'Max Mask Indices Across Classes')
    display(fig_max)
output_widgets.append(out)

out = Output()
with out:
    fig_std = create_index_tracking_plot(std_indices_per_class, 'Std Mask Indices Across Classes')
    display(fig_std)
output_widgets.append(out)

# Display all visualizations
# display(VBox(output_widgets))

In [3]:
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import torch
from prettytable import PrettyTable
# from model_distill_bert import getmodel
from utilities import compute_accuracy, compute_masks, mask_distillbert, get_model_distilbert, record_activations

batch_size = 256
mask_layer = 5
text_tag = "text"
compliment = True
results_table = PrettyTable()
if(compliment):
   results_table.field_names = results_table.field_names = ["Class", "Base Accuracy", "Base Confidence", "Base Complement Acc", "Base Compliment Conf", "STD Accuracy", "STD Confidence", "STD compliment ACC", "STD compliment Conf", "MAX Accuracy", "MAX Confidence", "Max compliment acc", "Max compliment conf", "Total Masked", "Intersedction"]#, "Same as Max"]#"MAX Accuracy", "MAX Confidence", "Max compliment acc", "Max compliment conf"
# results_table.field_names = ["Class", "Base Accuracy", "Base Confidence", "STD Accuracy", "STD Confidence", "Same as Max"]#, "MAX Accuracy", "MAX Confidence", "Max compliment acc", "Max compliment conf"]

class_labels = []
base_accuracies = []
base_confidences = []
base_comp_acc = []
base_comp_conf = []
std_masked_counts = []
std_accuracies = []
std_confidences = []
std_comp_acc = []
std_comp_conf = []
max_masked_counts = []
max_accuracies = []
max_confidences = []
max_comp_acc = []
max_comp_conf = []
diff_from_max = []
total_masked = []

dataset_list = []
tokenizer = AutoTokenizer.from_pretrained("2O24dpower2024/distilbert-base-uncased-finetuned-emotion")
# Check if a GPU is available and use it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Load the dataset
dataset_all = load_dataset("dair-ai/emotion")
dataset_all = dataset_all['train']

for j in range(0,7):
    # model = get_model_distilbert("esuriddick/distilbert-base-uncased-finetuned-emotion", mask_layer)
    
    model = get_model_distilbert("2O24dpower2024/distilbert-base-uncased-finetuned-emotion", mask_layer)
    dataset = dataset_all.filter(lambda x: x['label'] in [j])
    dataset_complement = dataset_all.filter(lambda x: x['label'] not in [j])
    
    if(j==6):
        dataset = dataset_all

    class_labels.append(f"Class {j}")
    acc = compute_accuracy(dataset, model, tokenizer, text_tag, batch_size=batch_size)
    dataset_list.append(acc[2])
    print("Class ",j, "base accuracy: ", acc[0], acc[1])
    base_accuracies.append(acc[0])
    base_confidences.append(acc[1])
    aug_dataset = acc[2]
    if(compliment):
        acc = compute_accuracy(dataset_complement, model, tokenizer, text_tag , batch_size=batch_size)
        print("Class ",j, "complement base accuracy: ", acc[0], acc[1])
        base_comp_acc.append(acc[0])
        base_comp_conf.append(acc[1])
        aug_dataset.extend(acc[2])
        

    #record the activations of the first fully connected layer, CLS tokken
    print("Recording activations...")
    # progress_bar = tqdm(total=len(dataset))
    # model.to(device)
    # model.eval()
    # fc_vals = []
    # with torch.no_grad():
    #     for i in range(len(dataset)):
    #         text = dataset[i]['sentence']
    #         inputs = tokenizer(text, return_tensors="pt").to(device)
    #         outputs = model(**inputs)
    #         fc_vals.append(outputs[1][mask_layer+1][:, 0].squeeze().cpu().numpy())
    #         progress_bar.update(1)
    #     progress_bar.close()

    fc_vals = record_activations(dataset, model, tokenizer, text_tag=text_tag, mask_layer=mask_layer, batch_size=batch_size)

        
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max = compute_masks(fc_vals,0.50)
    mask_std = mask_std_high_max
    print("Masking STD...")
    model = mask_distillbert(model,mask_std)
    t = int(mask_std.shape[0]-torch.count_nonzero(mask_std))
    print("Total Masked :", t)
    total_masked.append(t)
    diff_from_max.append(int((torch.logical_or(mask_std, mask_max) == 0).sum().item()))
    acc = compute_accuracy(dataset, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[:len(dataset)]) 
    dataset_list.append(acc[2])
    print("accuracy after masking STD: ", acc[0], acc[1])
    std_accuracies.append(acc[0])
    std_confidences.append(acc[1])
    if(compliment):
        acc = compute_accuracy(dataset_complement, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[len(dataset):])
        print("accuracy after masking STD on complement: ", acc[0], acc[1])
        std_comp_acc.append(acc[0])
        std_comp_conf.append(acc[1])

    print("Masking MAX...")
    model = mask_distillbert(model,mask_max)
    t = int(mask_max.shape[0]-torch.count_nonzero(mask_max))
    print("Total Masked :", t)
    # total_masked.append(t)
    acc = compute_accuracy(dataset, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[:len(dataset)])
    dataset_list.append(acc[2])
    print("accuracy after masking MAX: ", acc[0], acc[1])
    max_accuracies.append(acc[0])
    max_confidences.append(acc[1])
    acc = compute_accuracy(dataset_complement, model, tokenizer, text_tag, batch_size=batch_size, in_aug_dataset=aug_dataset[len(dataset):])
    print("accuracy after masking MAX on complement: ", acc[0], acc[1])
    max_comp_acc.append(acc[0])
    max_comp_conf.append(acc[1])
    if(compliment):
        results_table.add_row([
            class_labels[j],
            base_accuracies[j],
            base_confidences[j],
            base_comp_acc[j],
            base_comp_conf[j],
            std_accuracies[j],
            std_confidences[j],
            std_comp_acc[j],
            std_comp_conf[j],
            max_accuracies[j],
            max_confidences[j],
            max_comp_acc[j],
            max_comp_conf[j],
            total_masked[j],
            diff_from_max[j]
        ])
    # results_table.add_row([
    #     class_labels[j],
    #     base_accuracies[j],
    #     base_confidences[j],
    #     std_accuracies[j],
    #     std_confidences[j],
    #     # max_accuracies[j],
    #     # max_confidences[j],
    #     diff_from_max[j]
    # ])

print(results_table)

  model.load_state_dict(torch.load(weights_path))
100%|██████████| 4666/4666 [00:02<00:00, 2151.11it/s]


4608 4666
Class  0 base accuracy:  0.9781 0.9671


100%|██████████| 11334/11334 [00:04<00:00, 2340.96it/s]


11264 11334
Class  0 complement base accuracy:  0.9288 0.9165
Recording activations...


100%|██████████| 4666/4666 [00:01<00:00, 2363.31it/s]


Masking STD...
Total Masked : 384


100%|██████████| 4666/4666 [00:01<00:00, 2337.77it/s]


4608 4666
accuracy after masking STD:  0.9769 0.6375


100%|██████████| 11334/11334 [00:04<00:00, 2335.39it/s]


11264 11334
accuracy after masking STD on complement:  0.9272 0.6026
Masking MAX...
Total Masked : 384


100%|██████████| 4666/4666 [00:02<00:00, 2320.31it/s]


4608 4666
accuracy after masking MAX:  0.7752 0.2297


100%|██████████| 11334/11334 [00:04<00:00, 2315.30it/s]


11264 11334
accuracy after masking MAX on complement:  0.9378 0.6149


100%|██████████| 5362/5362 [00:02<00:00, 2258.95it/s]


5120 5362
Class  1 base accuracy:  0.9489 0.9491


100%|██████████| 10638/10638 [00:04<00:00, 2331.89it/s]


10496 10638
Class  1 complement base accuracy:  0.9403 0.923
Recording activations...


100%|██████████| 5362/5362 [00:02<00:00, 2323.53it/s]


Masking STD...
Total Masked : 384


100%|██████████| 5362/5362 [00:02<00:00, 2169.65it/s]


5120 5362
accuracy after masking STD:  0.9472 0.5879


100%|██████████| 10638/10638 [00:04<00:00, 2305.08it/s]


10496 10638
accuracy after masking STD on complement:  0.9412 0.6175
Masking MAX...
Total Masked : 384


100%|██████████| 5362/5362 [00:02<00:00, 2284.15it/s]


5120 5362
accuracy after masking MAX:  0.735 0.2254


100%|██████████| 10638/10638 [00:04<00:00, 2310.82it/s]


10496 10638
accuracy after masking MAX on complement:  0.9563 0.6338


100%|██████████| 1304/1304 [00:00<00:00, 2140.93it/s]


1280 1304
Class  2 base accuracy:  0.8919 0.8247


100%|██████████| 14696/14696 [00:06<00:00, 2339.26it/s]


14592 14696
Class  2 complement base accuracy:  0.9477 0.9407
Recording activations...


100%|██████████| 1304/1304 [00:00<00:00, 2241.33it/s]


Masking STD...
Total Masked : 384


100%|██████████| 1304/1304 [00:00<00:00, 2247.98it/s]


1280 1304
accuracy after masking STD:  0.7899 0.4178


100%|██████████| 14696/14696 [00:06<00:00, 2301.87it/s]


14592 14696
accuracy after masking STD on complement:  0.9558 0.6204
Masking MAX...
Total Masked : 384


100%|██████████| 1304/1304 [00:00<00:00, 1828.23it/s]


1280 1304
accuracy after masking MAX:  0.4525 0.211


100%|██████████| 14696/14696 [00:06<00:00, 2301.31it/s]


14592 14696
accuracy after masking MAX on complement:  0.962 0.6308


100%|██████████| 2159/2159 [00:00<00:00, 2221.54it/s]


2048 2159
Class  3 base accuracy:  0.9472 0.9357


100%|██████████| 13841/13841 [00:05<00:00, 2317.43it/s]


13824 13841
Class  3 complement base accuracy:  0.9426 0.9312
Recording activations...


100%|██████████| 2159/2159 [00:00<00:00, 2370.34it/s]


Masking STD...
Total Masked : 384


100%|██████████| 2159/2159 [00:00<00:00, 2332.85it/s]


2048 2159
accuracy after masking STD:  0.9375 0.562


100%|██████████| 13841/13841 [00:06<00:00, 2261.36it/s]


13824 13841
accuracy after masking STD on complement:  0.9447 0.612
Masking MAX...
Total Masked : 384


100%|██████████| 2159/2159 [00:00<00:00, 2337.36it/s]


2048 2159
accuracy after masking MAX:  0.7133 0.2264


100%|██████████| 13841/13841 [00:06<00:00, 2285.35it/s]


13824 13841
accuracy after masking MAX on complement:  0.9483 0.6471


100%|██████████| 1937/1937 [00:00<00:00, 2219.91it/s]


1792 1937
Class  4 base accuracy:  0.9251 0.8981


100%|██████████| 14063/14063 [00:06<00:00, 2257.18it/s]


13824 14063
Class  4 complement base accuracy:  0.9457 0.9363
Recording activations...


100%|██████████| 1937/1937 [00:00<00:00, 2290.44it/s]


Masking STD...
Total Masked : 384


100%|██████████| 1937/1937 [00:00<00:00, 2257.23it/s]


1792 1937
accuracy after masking STD:  0.9303 0.5526


100%|██████████| 14063/14063 [00:06<00:00, 2269.81it/s]


13824 14063
accuracy after masking STD on complement:  0.9417 0.6143
Masking MAX...
Total Masked : 384


100%|██████████| 1937/1937 [00:00<00:00, 2260.18it/s]


1792 1937
accuracy after masking MAX:  0.5225 0.2095


100%|██████████| 14063/14063 [00:06<00:00, 2295.57it/s]


13824 14063
accuracy after masking MAX on complement:  0.9565 0.657


100%|██████████| 572/572 [00:00<00:00, 1915.43it/s]


512 572
Class  5 base accuracy:  0.7675 0.7679


100%|██████████| 15428/15428 [00:06<00:00, 2321.88it/s]


15360 15428
Class  5 complement base accuracy:  0.9497 0.9367
Recording activations...


100%|██████████| 572/572 [00:00<00:00, 2252.47it/s]


Masking STD...
Total Masked : 384


100%|██████████| 572/572 [00:00<00:00, 2180.27it/s]


512 572
accuracy after masking STD:  0.7238 0.4084


100%|██████████| 15428/15428 [00:06<00:00, 2300.66it/s]


15360 15428
accuracy after masking STD on complement:  0.9494 0.6207
Masking MAX...
Total Masked : 384


100%|██████████| 572/572 [00:00<00:00, 2115.29it/s]


512 572
accuracy after masking MAX:  0.4003 0.2066


100%|██████████| 15428/15428 [00:06<00:00, 2275.63it/s]


15360 15428
accuracy after masking MAX on complement:  0.9508 0.6128


100%|██████████| 16000/16000 [00:07<00:00, 2269.37it/s]


15872 16000
Class  6 base accuracy:  0.9432 0.9318


100%|██████████| 16000/16000 [00:06<00:00, 2325.02it/s]


15872 16000
Class  6 complement base accuracy:  0.9432 0.9318
Recording activations...


100%|██████████| 16000/16000 [00:06<00:00, 2347.91it/s]


Masking STD...
Total Masked : 384


100%|██████████| 16000/16000 [00:06<00:00, 2301.49it/s]


15872 16000
accuracy after masking STD:  0.9429 0.6007


100%|██████████| 16000/16000 [00:07<00:00, 2270.12it/s]


15872 16000
accuracy after masking STD on complement:  0.9429 0.6007
Masking MAX...
Total Masked : 384


100%|██████████| 16000/16000 [00:06<00:00, 2307.13it/s]


15872 16000
accuracy after masking MAX:  0.9307 0.3962


100%|██████████| 16000/16000 [00:07<00:00, 2222.58it/s]

15872 16000
accuracy after masking MAX on complement:  0.9307 0.3962
+---------+---------------+-----------------+---------------------+----------------------+--------------+----------------+--------------------+---------------------+--------------+----------------+--------------------+---------------------+--------------+---------------+
|  Class  | Base Accuracy | Base Confidence | Base Complement Acc | Base Compliment Conf | STD Accuracy | STD Confidence | STD compliment ACC | STD compliment Conf | MAX Accuracy | MAX Confidence | Max compliment acc | Max compliment conf | Total Masked | Intersedction |
+---------+---------------+-----------------+---------------------+----------------------+--------------+----------------+--------------------+---------------------+--------------+----------------+--------------------+---------------------+--------------+---------------+
| Class 0 |     0.9781    |      0.9671     |        0.9288       |        0.9165        |    0.9769    |     0.637




In [4]:
import ipywidgets as widgets
from IPython.display import display, HTML

# Assuming you already have your DataFrame
# df = pd.DataFrame(dataset_list[1])

def display_df(dataframe, rows_per_page=10):
    # Convert the dataframe to HTML
    html = dataframe.to_html(classes='table table-striped')
    
    # Add Bootstrap CSS
    html = f"""
    <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css">
    <div class="container">
        {html}
    </div>
    """
    
    # Add pagination
    total_rows = len(dataframe)
    total_pages = (total_rows - 1) // rows_per_page + 1
    
    pagination_html = f"""
    <nav>
        <ul class="pagination justify-content-center">
            <li class="page-item"><a class="page-link" href="#" id="prev-page">Previous</a></li>
            <li class="page-item"><span class="page-link" id="current-page">1 / {total_pages}</span></li>
            <li class="page-item"><a class="page-link" href="#" id="next-page">Next</a></li>
        </ul>
    </nav>
    """
    
    # Add JavaScript for pagination functionality
    js = f"""
    <script>
        var currentPage = 1;
        var rowsPerPage = {rows_per_page};
        var totalPages = {total_pages};
        
        function showPage(page) {{
            var rows = document.querySelectorAll('table.table tbody tr');
            for (var i = 0; i < rows.length; i++) {{
                if (i >= (page - 1) * rowsPerPage && i < page * rowsPerPage) {{
                    rows[i].style.display = '';
                }} else {{
                    rows[i].style.display = 'none';
                }}
            }}
            document.getElementById('current-page').textContent = page + ' / ' + totalPages;
        }}
        
        document.getElementById('prev-page').addEventListener('click', function(e) {{
            e.preventDefault();
            if (currentPage > 1) {{
                currentPage--;
                showPage(currentPage);
            }}
        }});
        
        document.getElementById('next-page').addEventListener('click', function(e) {{
            e.preventDefault();
            if (currentPage < totalPages) {{
                currentPage++;
                showPage(currentPage);
            }}
        }});
        
        showPage(1);
    </script>
    """
    
    # Combine all HTML and JavaScript
    full_html = html + pagination_html + js
    
    # Display the result
    display(HTML(full_html))

In [16]:
import pandas as pd

df = pd.DataFrame(dataset_list[5])

In [17]:
# df = df.iloc[::-1]
display_df(df)

Unnamed: 0,0,1,2,3
0,i waited to hold my precious boy in my arms no i did not get to feel his sweet skin against mine after his birth no i could not rub his soft hair or look into his beautiful eyes but god had a plan,1,"[0.11188905, 0.13397843, 0.44713435, 0.09947785, 0.09712214, 0.11039818]",0.001006
1,i feel your soul in mine calling for our beloved,1,"[0.14084665, 0.141238, 0.419752, 0.10933736, 0.091798075, 0.09702783]",0.002334
2,i don t discuss even my feelings for beloved with anyone,1,"[0.12717606, 0.13951164, 0.43672562, 0.11356184, 0.09333057, 0.08969426]",0.004192
3,i had a feeling going into this book that its a little too well loved to be orthodox,1,"[0.11467561, 0.13965856, 0.44958448, 0.110723354, 0.08804612, 0.097311914]",0.004544
4,i cry when i think of the utter devastation my mum will feel to lose her beloved companion of years,1,"[0.15565604, 0.14066613, 0.4036414, 0.10491234, 0.0958999, 0.09922424]",0.004846
5,i feel after reading allthingsbucks blog which brought tears to my eyes and a lump in my throat and a feeling of not having a worthwhile thing to be upset about that i shouldnt write such a lame blog,1,"[0.5673408, 0.10830105, 0.090447344, 0.117138565, 0.06744149, 0.049330745]",0.004907
6,i cant help but wonder if the other mom i walked with felt the same way i was feeling as she watched her sweet girls with my isaac,1,"[0.1094298, 0.13671184, 0.4531789, 0.09666956, 0.09108621, 0.11292372]",0.006197
7,i know a lot of people are whining that a first boot cant possibly be a favourite but you guys know how i feel about my beloved a href http winterpaysforsummer,1,"[0.14002214, 0.14437118, 0.42046064, 0.10590176, 0.09148694, 0.09775737]",0.006458
8,i feel cared for and accepted,1,"[0.1370945, 0.12992272, 0.4488269, 0.1007134, 0.095239334, 0.08820315]",0.006701
9,i cant escape the tears of sadness and just true grief i feel at the loss of my sweet friend and sister,1,"[0.116946526, 0.13760547, 0.44626892, 0.10051321, 0.09236305, 0.10630279]",0.006963


In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("esuriddick/distilbert-base-uncased-finetuned-emotion")