In [2]:
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import torch
from model_distill_bert import getmodel



EMOTION_MODEL_NAME = "esuriddick/distilbert-base-uncased-finetuned-emotion"
# Labels 0..5 are recorded one class at a time; one extra pass (index 6)
# records the whole train split.
NUM_EMOTION_CLASSES = 6

tokenizer = AutoTokenizer.from_pretrained(EMOTION_MODEL_NAME)
# Check if a GPU is available and use it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Load the dataset
dataset_all = load_dataset("dair-ai/emotion")
# Select the train split
dataset_all = dataset_all['train']

# Nothing in this cell modifies the model, so it is loaded and moved to the
# device once instead of once per class.
model = getmodel(EMOTION_MODEL_NAME)
model.to(device)
model.eval()

# all_fc_vals[j] holds one activation vector per example of pass j.
all_fc_vals = []
for j in range(NUM_EMOTION_CLASSES + 1):
    if j == NUM_EMOTION_CLASSES:
        # Final pass: every example, no filtering needed.
        dataset = dataset_all
    else:
        dataset = dataset_all.filter(lambda x: x['label'] == j)

    # Record the second model output for each example.
    # NOTE(review): presumably the first fully connected layer's activations
    # for the CLS token -- confirm against model_distill_bert.getmodel.
    print("Recording activations...")
    fc_vals = []
    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            text = dataset[i]['text']
            # truncation=True keeps over-long inputs within the model's maximum length.
            inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
            outputs = model(**inputs)
            fc_vals.append(outputs[1].squeeze().cpu().numpy())
    all_fc_vals.append(fc_vals)

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Recording activations...


100%|██████████| 4666/4666 [00:31<00:00, 148.65it/s]


Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Recording activations...


100%|██████████| 5362/5362 [00:26<00:00, 199.87it/s]


Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Recording activations...


100%|██████████| 1304/1304 [00:06<00:00, 197.99it/s]


Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Recording activations...


100%|██████████| 2159/2159 [00:10<00:00, 198.56it/s]


Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Recording activations...


100%|██████████| 1937/1937 [00:09<00:00, 199.10it/s]


Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Recording activations...


100%|██████████| 572/572 [00:02<00:00, 196.69it/s]


Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Recording activations...


100%|██████████| 16000/16000 [01:19<00:00, 200.13it/s]


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import VBox, Output

output_widgets = []


def _set_ylim_from_zero(ax, values):
    """Start the y-axis at 0 and end it at the largest finite, positive value."""
    values = np.asarray(values)
    finite = values[np.isfinite(values)]
    if finite.size and finite.max() > 0:
        ax.set_ylim(0, finite.max())
    else:
        # No usable upper bound (all NaN or non-positive): only pin the lower bound.
        ax.set_ylim(bottom=0)


for i, v in enumerate(all_fc_vals):
    v = np.array(v)
    # Per-column statistics over axis 0 (the recorded examples).
    m = np.mean(np.abs(v), axis=0)
    s = np.std(v, axis=0)

    min_val = np.min(v, axis=0)
    max_val = np.max(v, axis=0)

    # Rescale the std by each column's observed range. Columns whose range is
    # zero (constant activation) become NaN instead of triggering a
    # divide-by-zero that would make the axis limits below invalid.
    value_range = max_val - min_val
    s = np.divide(
        s - min_val,
        value_range,
        out=np.full_like(s, np.nan, dtype=float),
        where=value_range != 0,
    )
    # m = (m-min_val) / (max_val - min_val)

    # Create a new figure for each set of values
    out = Output()
    with out:
        fig, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=(12, 6))

        # Plot the mean
        ax_mean.plot(m, 'bo', markersize=4)
        ax_mean.set_title(f'Mean of Activations - Set {i+1}')
        ax_mean.set_xlabel('Activation Index')
        ax_mean.set_ylabel('Mean Value')
        _set_ylim_from_zero(ax_mean, m)  # Ensure y-axis starts at 0

        # Plot the standard deviation
        ax_std.plot(s, 'ro', markersize=4)
        ax_std.set_title(f'Standard Deviation of Activations - Set {i+1}')
        ax_std.set_xlabel('Activation Index')
        ax_std.set_ylabel('Standard Deviation')
        _set_ylim_from_zero(ax_std, s)  # Ensure y-axis starts at 0

        # Show the plots
        fig.tight_layout()
        plt.show()
        # Release the figure explicitly so repeated runs do not accumulate figures.
        plt.close(fig)

    output_widgets.append(out)

# Display all figures in a vertical box
VBox(output_widgets)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output

output_widgets = []

for i, v in enumerate(all_fc_vals):
    v = np.array(v)
    # Per-column statistics over axis 0 (the recorded examples).
    m = np.mean(np.abs(v), axis=0)
    s = np.std(v, axis=0)

    min_val = np.min(v, axis=0)
    max_val = np.max(v, axis=0)
    # Rescale the std by each column's observed range; zero-range columns
    # become NaN instead of dividing by zero.
    value_range = max_val - min_val
    s = np.divide(
        s - min_val,
        value_range,
        out=np.full_like(s, np.nan, dtype=float),
        where=value_range != 0,
    )
    # m = (m-min_val) / (max_val - min_val)

    # Derive the x-axis from the data rather than hard-coding the width.
    unit_idx = np.arange(len(m))

    # Create a new figure for each set of values
    out = Output()
    with out:
        fig = go.Figure()

        # Plot the mean with markers
        fig.add_trace(go.Scatter(
            x=unit_idx,
            y=m,
            mode='markers',
            name='Mean',
            marker=dict(size=3, color='blue')
        ))

        # Plot the standard deviation with markers
        fig.add_trace(go.Scatter(
            x=unit_idx,
            y=s,
            mode='markers',
            name='Std Dev',
            marker=dict(size=3, color='red')
        ))

        # Connect each mean/std pair with a vertical segment. All segments live
        # in ONE trace: every (mean, std) pair is followed by a NaN, which
        # breaks the line, so no segment joins its neighbour. This replaces
        # one trace per index.
        fig.add_trace(go.Scatter(
            x=np.repeat(unit_idx, 3),
            y=np.column_stack([m, s, np.full(len(m), np.nan)]).ravel(),
            mode='lines',
            line=dict(color='gray', width=0.5),
            showlegend=False
        ))

        # Update layout
        y_top = float(np.nanmax(np.concatenate([m, s]))) * 1.1
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Set {i+1}',
            xaxis_title='Activation Index',
            yaxis_title='Value',
            yaxis=dict(range=[0, y_top]),
            height=600,
            width=1000,
            hovermode='closest'
        )

        fig.show()

    output_widgets.append(out)

# NOTE(review): the figures are emitted inside Output widgets; confirm whether
# the line below must be enabled for them to appear in your frontend.
# Display all figures in a vertical box
# VBox(output_widgets)

In [17]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output
import importlib
import utilities

importlib.reload(utilities)
from utilities import compute_masks
from IPython.display import display

output_widgets = []

for i, fc1 in enumerate(all_fc_vals):
    fc1 = np.array(fc1)
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_std_high_max = compute_masks(fc1, 0.15)
    # NOTE(review): from here on "std" refers to mask_max_low_std, so the
    # "Std Mask" panels below show that mask -- confirm this is intended.
    mask_std = mask_max_low_std

    # Per-column statistics over axis 0 (the recorded examples).
    m = np.mean(np.abs(fc1), axis=0)
    s = np.std(fc1, axis=0)
    min_val = np.min(fc1, axis=0)
    max_val = np.max(fc1, axis=0)

    # Rescale the std by each column's observed range; zero-range columns
    # become NaN instead of dividing by zero. The mean is left unscaled.
    value_range = max_val - min_val
    s_norm = np.divide(
        s - min_val,
        value_range,
        out=np.full_like(s, np.nan, dtype=float),
        where=value_range != 0,
    )
    m_norm = m#(m - min_val) / (max_val - min_val)

    # Indices where each mask equals 0, plus their set combinations.
    indices_max = np.where(mask_max == 0)[0]
    indices_std = np.where(mask_std == 0)[0]
    indices_intersection = np.intersect1d(indices_max, indices_std)
    indices_max_minus_std = np.setdiff1d(indices_max, indices_std)
    indices_std_minus_max = np.setdiff1d(indices_std, indices_max)

    # Count the indices in each set
    count_max = len(indices_max)
    count_std = len(indices_std)
    count_intersection = len(indices_intersection)
    count_max_minus_std = len(indices_max_minus_std)
    count_std_minus_max = len(indices_std_minus_max)

    out = Output()
    with out:
        # Create subplots with counts in titles (the sixth grid cell stays empty)
        fig = make_subplots(rows=2, cols=3,
                            subplot_titles=(f"Max Mask (Count: {count_max})",
                                            f"Std Mask (Count: {count_std})",
                                            f"Intersection (Count: {count_intersection})",
                                            f"Max - Std (Count: {count_max_minus_std})",
                                            f"Std - Max (Count: {count_std_minus_max})"))

        def add_traces(indices, row, col):
            """Draw mean and std markers for `indices` in subplot (row, col), joined pairwise by thin lines."""
            idx = np.asarray(indices, dtype=int)
            in_legend = (row == 1 and col == 1)  # one legend entry per series is enough
            fig.add_trace(
                go.Scatter(
                    x=idx,
                    y=m_norm[idx],
                    mode='markers',
                    name='Mean',
                    marker=dict(size=3, color='blue'),
                    showlegend=in_legend
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter(
                    x=idx,
                    y=s_norm[idx],
                    mode='markers',
                    name='Std Dev',
                    marker=dict(size=3, color='red'),
                    showlegend=in_legend
                ),
                row=row, col=col
            )
            if idx.size:
                # All connectors in ONE trace: each (mean, std) pair is followed
                # by a NaN that breaks the line. This replaces one trace per index.
                fig.add_trace(
                    go.Scatter(
                        x=np.repeat(idx, 3),
                        y=np.column_stack(
                            [m_norm[idx], s_norm[idx], np.full(idx.size, np.nan)]
                        ).ravel(),
                        mode='lines',
                        line=dict(color='gray', width=0.5),
                        showlegend=False
                    ),
                    row=row, col=col
                )

        # Add traces for all plots
        add_traces(indices_max, 1, 1)
        add_traces(indices_std, 1, 2)
        add_traces(indices_intersection, 1, 3)
        add_traces(indices_max_minus_std, 2, 1)
        add_traces(indices_std_minus_max, 2, 2)

        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Class {i+1}',
            height=1200,
            width=1800,
            hovermode='closest'
        )

        # Update x and y axis labels for all subplots
        for row in range(1, 3):
            for col in range(1, 4):
                if row == 2 and col == 3:
                    continue  # Skip the empty subplot
                fig.update_xaxes(title_text="Activation Index", row=row, col=col)
                fig.update_yaxes(title_text="Normalized Value", row=row, col=col)

        display(fig)

    output_widgets.append(out)

# Display all figures in a vertical box
# display(VBox(output_widgets))

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

output_widgets = []

for i, fc1 in enumerate(all_fc_vals):
    fc1 = np.array(fc1)
    # Only the first two masks are used here. The starred target tolerates a
    # compute_masks that returns additional masks.
    # NOTE(review): another cell in this notebook unpacks five values from
    # compute_masks -- confirm the current signature in utilities.
    mask_max, mask_std, *_ = compute_masks(fc1, 0.15)

    # Per-column statistics over axis 0 (the recorded examples).
    m = np.mean(np.abs(fc1), axis=0)
    s = np.std(fc1, axis=0)
    min_val = np.min(fc1, axis=0)
    max_val = np.max(fc1, axis=0)

    # Rescale the std by each column's observed range; zero-range columns
    # become NaN instead of dividing by zero. The mean is left unscaled.
    value_range = max_val - min_val
    s_norm = np.divide(
        s - min_val,
        value_range,
        out=np.full_like(s, np.nan, dtype=float),
        where=value_range != 0,
    )
    m_norm = m#(m - min_val) / (max_val - min_val)

    # Indices where each mask equals 0, plus their set combinations.
    indices_max = np.where(mask_max == 0)[0]
    indices_std = np.where(mask_std == 0)[0]
    indices_intersection = np.intersect1d(indices_max, indices_std)
    indices_max_minus_std = np.setdiff1d(indices_max, indices_std)
    indices_std_minus_max = np.setdiff1d(indices_std, indices_max)

    # Count the indices in each set
    count_all = len(m_norm)
    count_max = len(indices_max)
    count_std = len(indices_std)
    count_intersection = len(indices_intersection)
    count_max_minus_std = len(indices_max_minus_std)
    count_std_minus_max = len(indices_std_minus_max)

    out = Output()
    with out:
        # Create subplots with counts in titles
        fig = make_subplots(rows=2, cols=3,
                            subplot_titles=(f"All Activations (Count: {count_all})",
                                            f"Max Mask (Count: {count_max})",
                                            f"Std Mask (Count: {count_std})",
                                            f"Intersection (Count: {count_intersection})",
                                            f"Max - Std (Count: {count_max_minus_std})",
                                            f"Std - Max (Count: {count_std_minus_max})"))

        def add_traces(indices, row, col):
            """Draw mean and std markers for `indices` in subplot (row, col), joined pairwise by thin lines."""
            idx = np.asarray(indices, dtype=int)  # accepts a range or an index array
            in_legend = (row == 1 and col == 1)  # one legend entry per series is enough
            fig.add_trace(
                go.Scatter(
                    x=idx,
                    y=m_norm[idx],
                    mode='markers',
                    name='Mean',
                    marker=dict(size=3, color='blue'),
                    showlegend=in_legend
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter(
                    x=idx,
                    y=s_norm[idx],
                    mode='markers',
                    name='Std Dev',
                    marker=dict(size=3, color='red'),
                    showlegend=in_legend
                ),
                row=row, col=col
            )
            if idx.size:
                # All connectors in ONE trace: each (mean, std) pair is followed
                # by a NaN that breaks the line. This replaces one trace per index.
                fig.add_trace(
                    go.Scatter(
                        x=np.repeat(idx, 3),
                        y=np.column_stack(
                            [m_norm[idx], s_norm[idx], np.full(idx.size, np.nan)]
                        ).ravel(),
                        mode='lines',
                        line=dict(color='gray', width=0.5),
                        showlegend=False
                    ),
                    row=row, col=col
                )

        # Add traces for all activations
        add_traces(range(len(m_norm)), 1, 1)

        # Add traces for other plots
        add_traces(indices_max, 1, 2)
        add_traces(indices_std, 1, 3)
        add_traces(indices_intersection, 2, 1)
        add_traces(indices_max_minus_std, 2, 2)
        add_traces(indices_std_minus_max, 2, 3)

        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Class {i+1}',
            height=1200,
            width=1800,
            hovermode='closest'
        )

        # Update x and y axis labels for all subplots
        for row in range(1, 3):
            for col in range(1, 4):
                fig.update_xaxes(title_text="Activation Index", row=row, col=col)
                fig.update_yaxes(title_text="Normalized Value", row=row, col=col)

        display(fig)

    output_widgets.append(out)

# Display all figures in a vertical box
# display(VBox(output_widgets))

In [None]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

def create_index_tracking_plot(indices_per_class, title):
    """Plot which activation indices appear in which classes.

    Each class is one horizontal row of black markers (x = index, y = class
    position). An index that appears in more than one class additionally gets
    a line through its rows, coloured by how many classes contain it.

    Args:
        indices_per_class: Sequence with one iterable of indices per class.
            May be empty, in which case an empty figure is returned.
        title: Figure title.

    Returns:
        The assembled plotly ``go.Figure``.
    """
    num_classes = len(indices_per_class)
    # Sets give constant-time membership tests in the loop below.
    index_sets = [set(indices) for indices in indices_per_class]
    # set().union(*...) also works when there are no classes at all.
    all_indices = sorted(set().union(*index_sets))

    fig = go.Figure()

    # Create a color scale
    color_scale = px.colors.diverging.RdYlGn_r  # Red to Yellow to Green color scale

    # Add edges for indices present in multiple classes
    for idx in all_indices:
        classes_with_idx = [i for i, members in enumerate(index_sets) if idx in members]
        if len(classes_with_idx) > 1:
            x = [idx] * len(classes_with_idx)
            y = classes_with_idx
            # Normalize to [0, 1]; num_classes >= 2 is guaranteed inside this branch.
            color_index = (len(classes_with_idx) - 1) / (num_classes - 1)
            edge_color = px.colors.sample_colorscale(color_scale, [color_index])[0]

            fig.add_trace(go.Scatter(
                x=x,
                y=y,
                mode='lines',
                line=dict(color=edge_color, width=2),
                hoverinfo='text',
                hovertext=f'Index: {idx}<br>Present in {len(classes_with_idx)} classes',
                showlegend=False
            ))

    # Add scatter plots for each class
    for class_idx, indices in enumerate(indices_per_class):
        fig.add_trace(go.Scatter(
            x=indices,
            y=[class_idx] * len(indices),
            mode='markers',
            name=f'Class {class_idx + 1}',
            marker=dict(size=4, symbol='circle', color='black'),
            hoverinfo='text',
            hovertext=[f'Index: {idx}<br>Class: {class_idx + 1}' for idx in indices]
        ))

    fig.update_layout(
        title=title,
        xaxis_title='Activation Index',
        yaxis_title='Class',
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(num_classes)),
            ticktext=[f'Class {i+1}' for i in range(num_classes)]
        ),
        hovermode='closest',
        width=1500,
        height=800,
        plot_bgcolor='white',
        showlegend=False
    )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')

    # Add color bar via a data-less trace whose marker carries the colour scale.
    fig.add_trace(go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(
            colorscale=color_scale,
            showscale=True,
            cmin=1,
            cmax=num_classes,
            colorbar=dict(
                title='Number of Classes',
                tickvals=list(range(1, num_classes+1)),
                ticktext=list(range(1, num_classes+1))
            )
        ),
        hoverinfo='none',
        showlegend=False
    ))

    return fig

# Collect, for each entry of all_fc_vals, the indices where each mask equals 0.
max_indices_per_class = []
std_indices_per_class = []

for fc1 in all_fc_vals:
    # Pass an array, as the other plotting cells do, and tolerate a
    # compute_masks that returns more than two masks.
    # NOTE(review): another cell in this notebook unpacks five values from
    # compute_masks -- confirm the current signature in utilities.
    mask_max, mask_std, *_ = compute_masks(np.array(fc1), 0.15)
    max_indices_per_class.append(np.where(mask_max == 0)[0])
    std_indices_per_class.append(np.where(mask_std == 0)[0])

# Create and display visualizations, one Output widget per figure
output_widgets = []

for indices_per_class, plot_title in (
    (max_indices_per_class, 'Max Mask Indices Across Classes'),
    (std_indices_per_class, 'Std Mask Indices Across Classes'),
):
    out = Output()
    with out:
        display(create_index_tracking_plot(indices_per_class, plot_title))
    output_widgets.append(out)

# Display all visualizations
# display(VBox(output_widgets))

In [1]:
from transformers import DistilBertTokenizer
from datasets import load_dataset
from tqdm import tqdm
import torch
from prettytable import PrettyTable
from model_distill_bert import getmodel
from utilities import compute_accuracy, compute_masks, mask


AG_NEWS_MODEL_NAME = "textattack/distilbert-base-uncased-ag-news"
NUM_AG_NEWS_CLASSES = 4

# When True, accuracy on each class's complement (all other classes) is also
# measured and a results table is built. (Variable name kept as-is.)
compliment = True
results_table = PrettyTable()
if(compliment):
    # The columns match the rows added below, which hold the MAX-mask results.
    results_table.field_names = [
        "Class",
        "Base Accuracy", "Base Confidence",
        "Base Complement Acc", "Base Complement Conf",
        "MAX Accuracy", "MAX Confidence",
        "MAX Complement Acc", "MAX Complement Conf",
        "Total Masked",
    ]

# Per-class result series. The std_* / diff_from_max / *_masked_counts lists are
# not filled by this cell; they are kept so the names stay defined.
class_labels = []
base_accuracies = []
base_confidences = []
base_comp_acc = []
base_comp_conf = []
std_masked_counts = []
std_accuracies = []
std_confidences = []
std_comp_acc = []
std_comp_conf = []
max_masked_counts = []
max_accuracies = []
max_confidences = []
max_comp_acc = []
max_comp_conf = []
diff_from_max = []
total_masked = []


tokenizer = DistilBertTokenizer.from_pretrained(AG_NEWS_MODEL_NAME)
# Check if a GPU is available and use it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# dataset_test = load_dataset("glue", "sst2")

# Load the dataset
dataset_all = load_dataset("fancyzhx/ag_news")
# Select the train split
dataset_all = dataset_all['train']

# AG News class labels (per the dataset card -- confirm):
#     '0': World
#     '1': Sports
#     '2': Business
#     '3': Sci/Tech

for j in range(NUM_AG_NEWS_CLASSES):
    # A fresh model per class: mask() below is applied to the model, so each
    # class must start from the unmasked weights.
    model = getmodel(AG_NEWS_MODEL_NAME)
    dataset = dataset_all.filter(lambda x: x['label'] == j)
    dataset_complement = dataset_all.filter(lambda x: x['label'] != j)

    class_labels.append(f"Class {j}")
    acc = compute_accuracy(dataset, model, tokenizer)
    print("Class ",j, "base accuracy: ", acc)
    base_accuracies.append(acc[0])
    base_confidences.append(acc[1])
    if(compliment):
        acc = compute_accuracy(dataset_complement, model, tokenizer)
        print("Class ",j, "complement base accuracy: ", acc)
        base_comp_acc.append(acc[0])
        base_comp_conf.append(acc[1])

    # Record the second model output for each example of this class.
    # NOTE(review): presumably the first fully connected layer's activations
    # for the CLS token -- confirm against model_distill_bert.getmodel.
    print("Recording activations...")
    model.to(device)
    model.eval()
    fc_vals = []
    with torch.no_grad():
        for i in tqdm(range(len(dataset))):
            text = dataset[i]['text']
            # Truncate explicitly; max_length alone only triggers an implicit
            # fallback truncation with a warning.
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
            outputs = model(**inputs)
            fc_vals.append(outputs[1].squeeze().cpu().numpy())

    # The starred target tolerates a compute_masks that returns more masks.
    # NOTE(review): another cell in this notebook unpacks five values from
    # compute_masks -- confirm the current signature in utilities.
    mask_max, mask_std, mask_intersection, mask_max_low_std, *_ = compute_masks(fc_vals, 0.30)
    # Only mask_max is applied below. For the STD variant of this experiment,
    # pass this mask to mask() instead and record into the std_* lists.
    mask_std = mask_max_low_std

    print("Masking MAX...")
    model = mask(model,mask_max)
    # Number of zero entries in the mask.
    t = int(mask_max.shape[0]-torch.count_nonzero(mask_max))
    print("Total Masked :", t)
    total_masked.append(t)
    acc = compute_accuracy(dataset, model, tokenizer)
    print("accuracy after masking MAX: ", acc)
    max_accuracies.append(acc[0])
    max_confidences.append(acc[1])
    acc = compute_accuracy(dataset_complement, model, tokenizer)
    print("accuracy after masking MAX on complement: ", acc)
    max_comp_acc.append(acc[0])
    max_comp_conf.append(acc[1])
    if(compliment):
        results_table.add_row([
            class_labels[j],
            base_accuracies[j],
            base_confidences[j],
            base_comp_acc[j],
            base_comp_conf[j],
            max_accuracies[j],
            max_confidences[j],
            max_comp_acc[j],
            max_comp_conf[j],
            total_masked[j],
        ])

print(results_table)

Using custom data configuration fancyzhx--ag_news-f4012edcb412d6fa
Reusing dataset parquet (/u/amo-d1/grad/mha361/.cache/huggingface/datasets/parquet/fancyzhx--ag_news-f4012edcb412d6fa/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


  0%|          | 0/2 [00:00<?, ?it/s]

KeyboardInterrupt: 