In [1]:
import os
from tqdm import tqdm_notebook as tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
import torch.optim as optim
from utilities import train, eval, pad, get_model_bert
from POS_dataset import PosDataset
from prettytable import PrettyTable
import nltk
tagged_sents = nltk.corpus.treebank.tagged_sents()

tags = list(set(word_pos[1] for sent in tagged_sents for word_pos in sent))

",".join(tags)

tags = ["<pad>"] + tags

tag2idx = {tag:idx for idx, tag in enumerate(tags)}
idx2tag = {idx:tag for idx, tag in enumerate(tags)}

# Let's split the data into train and test (or eval)
from sklearn.model_selection import train_test_split
train_data, test_data = train_test_split(tagged_sents, test_size=.1)
len(train_data), len(test_data)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

from transformers import AutoTokenizer, AutoModelForTokenClassification, TokenClassificationPipeline
import torch

model_name = "QCRI/bert-base-multilingual-cased-pos-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)

class MaskLayer(nn.Module):
    def __init__(self, lower_bound, upper_bound, replacement_values):
        super(MaskLayer, self).__init__()
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.replacement_values = replacement_values

 

    def forward(self, x):
        lower_bound = self.lower_bound.to(dtype=x.dtype, device=x.device).view(1, 1, -1)
        upper_bound = self.upper_bound.to(dtype=x.dtype, device=x.device).view(1, 1, -1)
        replacement_values = self.replacement_values.to(dtype=x.dtype, device=x.device).view(1, 1, -1)

 

        mask = (x >= lower_bound) & (x <= upper_bound)
        x = torch.where(mask, replacement_values, x)
        return x
    
    def set_perms(self,lower_bound, upper_bound, replacement_values):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.replacement_values = replacement_values

class Net(nn.Module):
    def __init__(self, vocab_size=None):
        super().__init__()
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.bert = self.model.bert
        self.masking_layer = torch.ones(768).to("cuda")
        self.mask_layer = MaskLayer(torch.tensor(float('inf')), torch.tensor(float('-inf')), torch.tensor(0.0))

        self.fc = nn.Linear(768, vocab_size)
        self.device = device

    def forward(self, x, y):
        '''
        x: (N, T). int64
        y: (N, T). int64
        '''
        x = x.to(device)
        y = y.to(device)
        
        if self.training:
            self.bert.train()
            encoded_layers = self.bert(x)
            enc = encoded_layers[-1]
        else:
            self.bert.eval()
            with torch.no_grad():
                encoded_layers = self.bert(x)
                enc = encoded_layers[-1]
        # enc = nn.ReLU(enc)
        # enc = enc * self.masking_layer
        enc = self.mask_layer(enc)
        logits = self.fc(enc)
        y_hat = logits.argmax(-1)
        confidence = logits.softmax(-1).max(-1).values
        return enc, logits, y, y_hat, confidence
    
    
model = Net(vocab_size=len(tag2idx))
model.to(device)

train_dataset = PosDataset(train_data, tokenizer, tag2idx)
eval_dataset = PosDataset(test_data, tokenizer, tag2idx)

train_iter = data.DataLoader(dataset=train_dataset,
                             batch_size=8,
                             shuffle=True,
                             num_workers=1,
                             collate_fn=pad)
test_iter = data.DataLoader(dataset=eval_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=pad)
activation_iter = data.DataLoader(dataset=train_dataset+eval_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=pad)

optimizer = optim.Adam(model.parameters(), lr = 0.0001)

criterion = nn.CrossEntropyLoss(ignore_index=0)

for i in range(10):
    train(model, train_iter, optimizer, criterion)

from utilities import  eval

model.masking_layer = torch.ones(768).to("cuda")
activation_iter = data.DataLoader(dataset=train_dataset+eval_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1,
                             collate_fn=pad)

enc_dict = eval(model, activation_iter, idx2tag, tag2idx,33)

Some weights of the model checkpoint at QCRI/bert-base-multilingual-cased-pos-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


step: 0, loss: 4.012983798980713
step: 10, loss: 0.3446679413318634
step: 20, loss: 0.28085753321647644
step: 30, loss: 0.12452409416437149
step: 40, loss: 0.11375977098941803
step: 50, loss: 0.12363813072443008
step: 60, loss: 0.0389459989964962
step: 70, loss: 0.08982899785041809
step: 80, loss: 0.16142012178897858
step: 90, loss: 0.25490763783454895
step: 100, loss: 0.12713389098644257
step: 110, loss: 0.11952819675207138
step: 120, loss: 0.07232357561588287
step: 130, loss: 0.09034412354230881
step: 140, loss: 0.06357061862945557
step: 150, loss: 0.05770187824964523
step: 160, loss: 0.04031563177704811
step: 170, loss: 0.0304564218968153
step: 180, loss: 0.0903044268488884
step: 190, loss: 0.05765240266919136
step: 200, loss: 0.11337007582187653
step: 210, loss: 0.05553313344717026
step: 220, loss: 0.15457281470298767
step: 230, loss: 0.13043850660324097
step: 240, loss: 0.07155825197696686
step: 250, loss: 0.06299138069152832
step: 260, loss: 0.06146363541483879
step: 270, loss: 0

In [3]:
def mask_range_bert(model, mask, fc_vals):
    mean = torch.tensor(np.mean(fc_vals, axis=0))
    std = torch.tensor(np.std(fc_vals, axis=0))
    mask = mask.to(torch.bool)
    a = 2.5
    lower_bound = torch.full_like(mean, torch.inf)
    lower_bound[~mask] = mean[~mask] - a*std[~mask]
    upper_bound = torch.full_like(mean, -torch.inf)
    upper_bound[~mask] = mean[~mask] + a*std[~mask]
    
    model.mask_layer.lower_bound = lower_bound.to(device)
    model.mask_layer.upper_bound = upper_bound.to(device)
    
    return model

In [2]:
results_table = PrettyTable()

results_table.field_names = results_table.field_names = ["Class", "Base Accuracy", "Base Confidence", "Base Complement Acc", "Base Compliment Conf", "MAX Accuracy", "MAX Confidence", "Max compliment acc", "Max compliment conf"]

class_labels = []
base_accuracies = []
base_confidences = []
base_comp_acc = []
base_comp_conf = []
max_accuracies = []
max_confidences = []
max_comp_acc = []
max_comp_conf = []

In [7]:
from utilities import compute_masks, eval
for tok in range(45):
    print(idx2tag[tok],"----------------")
    # model.masking_layer = torch.ones(768).to("cuda")
    model.mask_layer = MaskLayer(torch.tensor(float('inf')), torch.tensor(float('-inf')), torch.tensor(0.0)).to("cuda")
    activation_iter = data.DataLoader(dataset=train_dataset+eval_dataset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=1,
                                collate_fn=pad)
    print("Original:")
    enc_dict = eval(model, activation_iter, idx2tag, tag2idx, tok)
    class_labels.append(idx2tag[tok])
    base_accuracies.append(enc_dict[1][0])
    base_confidences.append(enc_dict[1][1])
    base_comp_acc.append(enc_dict[2][0])
    base_comp_conf.append(enc_dict[2][1])
    print("Tok:", enc_dict[1])
    print('Compliment:', enc_dict[2])

    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max = compute_masks(enc_dict[0][tok],1)
    # print("STD:")print("Tok:", enc_dict[1])
    print('Compliment:', enc_dict[2])
    # model.masking_layer = mask_std.to("cuda")

    # enc_dict = eval(model, activation_iter, idx2tag, tag2idx, tok)
    print("Max:")
    # model.masking_layer = mask_max.to("cuda")
    model = mask_range_bert(model, mask_max, enc_dict[0][tok])

    enc_dict = eval(model, activation_iter, idx2tag, tag2idx, tok)
    max_accuracies.append(enc_dict[1][0])
    max_confidences.append(enc_dict[1][1])
    max_comp_acc.append(enc_dict[2][0])
    max_comp_conf.append(enc_dict[2][1])
    print("Tok:", enc_dict[1])
    print('Compliment:', enc_dict[2])
    print("-----------------------------")
    
    results_table.add_row([
                class_labels[tok],
                base_accuracies[tok],
                base_confidences[tok],
                base_comp_acc[tok],
                base_comp_conf[tok],
                max_accuracies[tok],
                max_confidences[tok],
                max_comp_acc[tok],
                max_comp_conf[tok],
            ])
print(results_table)

<pad> ----------------
Original:
<pad>: N/A (0 occurrences)
Tok: (0, 0)
Compliment: (0, 0)
Compliment: (0, 0)
Max:
<pad>: N/A (0 occurrences)
Tok: (0, 0)
Compliment: (0, 0)
-----------------------------
WRB ----------------
Original:
<pad>: N/A (0 occurrences)
Tok: (1.0, 0.9994)
Compliment: (0.9934, 0.9942)
Compliment: (0.9934, 0.9942)
Max:
<pad>: N/A (0 occurrences)
Tok: (0.2079, 0.0354)
Compliment: (0.9935, 0.9908)
-----------------------------
VBN ----------------
Original:
<pad>: N/A (0 occurrences)
Tok: (0.9859, 0.9937)
Compliment: (0.9936, 0.9942)
Compliment: (0.9936, 0.9942)
Max:
<pad>: N/A (0 occurrences)
Tok: (0.1012, 0.0393)
Compliment: (0.9937, 0.9777)
-----------------------------
LS ----------------
Original:
<pad>: N/A (0 occurrences)
Tok: (1.0, 0.9912)
Compliment: (0.9934, 0.9942)
Compliment: (0.9934, 0.9942)
Max:
<pad>: N/A (0 occurrences)
Tok: (0.0769, 0.0231)
Compliment: (0.9934, 0.9899)
-----------------------------
NNS ----------------
Original:
<pad>: N/A (0 occurr

In [5]:
import numpy as np
import torch

def compute_masks(fc_vals, percent):
    # Convert input to numpy array
    fc_vals_array = np.array(fc_vals)
    
    # Compute statistics
    mean_vals = np.mean(np.abs(fc_vals_array), axis=0)
    std_vals = np.std(fc_vals_array, axis=0)
    min_vals = np.min(fc_vals_array, axis=0)
    max_vals = np.max(fc_vals_array, axis=0)
    
    # Normalize standard deviation
    std_vals_normalized = (std_vals - min_vals) / (max_vals - min_vals)
    
    # Convert to PyTorch tensors
    mean_vals_tensor = torch.from_numpy(mean_vals)
    std_vals_tensor = torch.from_numpy(std_vals_normalized)
    
    # Compute masks
    mask_max = compute_max_mask(mean_vals_tensor, percent)
    mask_std = compute_std_mask(std_vals_tensor, percent)
    mask_max_low_std = compute_max_low_std_mask(mean_vals_tensor, std_vals_tensor, percent)
    mask_intersection = torch.logical_or(mask_std, mask_max).float()
    
    return mask_max, mask_std, mask_intersection, mask_max_low_std

def compute_max_mask(values, percent):
    sorted_indices = torch.argsort(values, descending=True)
    mask_count = int(percent * len(values))
    mask = torch.ones_like(values)
    mask[sorted_indices[:mask_count]] = 0.0
    return mask

def compute_std_mask(values, percent):
    sorted_indices = torch.argsort(values, descending=False)
    mask_count = int(percent * len(values))
    mask = torch.ones_like(values)
    mask[sorted_indices[:mask_count]] = 0.0
    return mask

def compute_max_low_std_mask(mean_vals, std_vals, percent):
    # Get indices of bottom 50% std values
    bottom_50_percent_std_count = int(0.99 * len(std_vals))
    bottom_50_percent_std_indices = torch.argsort(std_vals)[:bottom_50_percent_std_count]
    
    # Create a mask for bottom 50% std values
    bottom_50_percent_std_mask = torch.zeros_like(std_vals, dtype=torch.bool)
    bottom_50_percent_std_mask[bottom_50_percent_std_indices] = True
    
    # Filter mean values
    mean_vals_filtered = mean_vals.clone()
    mean_vals_filtered[~bottom_50_percent_std_mask] = float('-inf')
    
    # Compute mask
    return compute_max_mask(mean_vals_filtered, percent)

In [6]:
for i, fc in enumerate(enc_dict):
    print(f"Layer {fc}", i)

Layer 0 0
Layer 38 1
Layer 44 2
Layer 20 3
Layer 10 4
Layer 17 5
Layer 24 6
Layer 15 7
Layer 32 8
Layer 8 9
Layer 25 10
Layer 9 11
Layer 41 12
Layer 30 13
Layer 39 14
Layer 7 15
Layer 2 16
Layer 22 17
Layer 31 18
Layer 12 19
Layer 26 20
Layer 27 21
Layer 11 22
Layer 46 23
Layer 23 24
Layer 40 25
Layer 4 26
Layer 19 27
Layer 45 28
Layer 43 29
Layer 42 30
Layer 14 31
Layer 33 32
Layer 37 33
Layer 34 34
Layer 36 35
Layer 21 36
Layer 16 37
Layer 29 38
Layer 5 39
Layer 28 40
Layer 35 41
Layer 13 42
Layer 1 43
Layer 6 44
Layer 3 45
Layer 18 46


In [7]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output
from IPython.display import display

output_widgets = []

for i, fc1 in enumerate(enc_dict):
    tag = idx2tag[fc1]
    fc1 = enc_dict[fc1]
    
    fc1 = np.array(fc1)
    mask_max, mask_std, mask_intersection, mask_max_low_std = compute_masks(fc1, 0.15)
    
    m = np.mean(np.abs(fc1), axis=0)
    s = np.std(fc1, axis=0)
    min_val = np.min(fc1, axis=0)
    max_val = np.max(fc1, axis=0)
    
    # Normalize std and mean
    s_norm = (s - min_val) / (max_val - min_val)
    m_norm = m#(m - min_val) / (max_val - min_val)
    
    # Create indices for different masks
    indices_max = np.where(mask_max == 0)[0]
    indices_std = np.where(mask_std == 0)[0]
    indices_intersection = np.intersect1d(indices_max, indices_std)
    indices_max_minus_std = np.setdiff1d(indices_max, indices_std)
    indices_std_minus_max = np.setdiff1d(indices_std, indices_max)
    
    # Count the indices in each set
    count_all = len(m_norm)
    count_max = len(indices_max)
    count_std = len(indices_std)
    count_intersection = len(indices_intersection)
    count_max_minus_std = len(indices_max_minus_std)
    count_std_minus_max = len(indices_std_minus_max)
    
    out = Output()
    with out:
        # Create subplots with counts in titles
        fig = make_subplots(rows=2, cols=3, 
                            subplot_titles=(f"All Activations (Count: {count_all})",
                                            f"Max Mask (Count: {count_max})", 
                                            f"Std Mask (Count: {count_std})", 
                                            f"Intersection (Count: {count_intersection})",
                                            f"Max - Std (Count: {count_max_minus_std})", 
                                            f"Std - Max (Count: {count_std_minus_max})"))
        
        # Helper function to add traces
        def add_traces(indices, row, col):
            indices_list = list(indices)  # Convert range or numpy array to list
            fig.add_trace(
                go.Scatter(
                    x=indices_list,
                    y=m_norm[indices_list],
                    mode='markers',
                    name='Mean',
                    marker=dict(size=3, color='blue'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter(
                    x=indices_list,
                    y=s_norm[indices_list],
                    mode='markers',
                    name='Std Dev',
                    marker=dict(size=3, color='red'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            for j in indices_list:
                fig.add_trace(
                    go.Scatter(
                        x=[j, j],
                        y=[m_norm[j], s_norm[j]],
                        mode='lines',
                        line=dict(color='gray', width=0.5),
                        showlegend=False
                    ),
                    row=row, col=col
                )
        
        # Add traces for all activations
        add_traces(range(len(m_norm)), 1, 1)
        
        # Add traces for other plots
        add_traces(indices_max, 1, 2)
        add_traces(indices_std, 1, 3)
        add_traces(indices_intersection, 2, 1)
        add_traces(indices_max_minus_std, 2, 2)
        add_traces(indices_std_minus_max, 2, 3)
        
        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Class {i+1}'+ tag,
            height=1200,
            width=1800,
            hovermode='closest'
        )
        
        # Update x and y axis labels for all subplots
        for row in range(1, 3):
            for col in range(1, 4):
                fig.update_xaxes(title_text="Activation Index", row=row, col=col)
                fig.update_yaxes(title_text="Normalized Value", row=row, col=col)
        
        display(fig)
    
    output_widgets.append(out)

# Display all figures in a vertical box
# display(VBox(output_widgets))

In [None]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

def create_index_tracking_plot(indices_per_class, title):
    num_classes = len(indices_per_class)
    all_indices = sorted(set.union(*[set(indices) for indices in indices_per_class]))
    
    fig = go.Figure()
    
    # Create a color scale
    color_scale = px.colors.diverging.RdYlGn_r  # Red to Yellow to Green color scale

    # Add edges for indices present in multiple classes
    for idx in all_indices:
        classes_with_idx = [i for i, indices in enumerate(indices_per_class) if idx in indices]
        if len(classes_with_idx) > 1:
            x = [idx] * len(classes_with_idx)
            y = classes_with_idx
            color_index = (len(classes_with_idx) - 1) / (num_classes - 1)  # Normalize to [0, 1]
            edge_color = px.colors.sample_colorscale(color_scale, [color_index])[0]
            
            fig.add_trace(go.Scatter(
                x=x,
                y=y,
                mode='lines',
                line=dict(color=edge_color, width=2),
                hoverinfo='text',
                hovertext=f'Index: {idx}<br>Present in {len(classes_with_idx)} classes',
                showlegend=False
            ))
    
    # Add scatter plots for each class
    for class_idx, indices in enumerate(indices_per_class):
        fig.add_trace(go.Scatter(
            x=indices,
            y=[class_idx] * len(indices),
            mode='markers',
            name=f'Class {class_idx + 1}',
            marker=dict(size=4, symbol='circle', color='black'),
            hoverinfo='text',
            hovertext=[f'Index: {idx}<br>Class: {class_idx + 1}' for idx in indices]
        ))
    
    fig.update_layout(
        title=title,
        xaxis_title='Activation Index',
        yaxis_title='Class',
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(num_classes)),
            ticktext=[f'Class {i+1}' for i in range(num_classes)]
        ),
        hovermode='closest',
        width=1500,
        height=800,
        plot_bgcolor='white',
        showlegend=False
    )
    
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    
    # Add color bar
    fig.add_trace(go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(
            colorscale=color_scale,
            showscale=True,
            cmin=1,
            cmax=num_classes,
            colorbar=dict(
                title='Number of Classes',
                tickvals=list(range(1, num_classes+1)),
                ticktext=list(range(1, num_classes+1))
            )
        ),
        hoverinfo='none',
        showlegend=False
    ))
    
    return fig

# Collect indices for each class
max_indices_per_class = []
std_indices_per_class = []

for fc1 in all_fc_vals:
    mask_max, mask_std = compute_masks(fc1, 0.15)
    max_indices_per_class.append(np.where(mask_max == 0)[0])
    std_indices_per_class.append(np.where(mask_std == 0)[0])

# Create and display visualizations
output_widgets = []

out = Output()
with out:
    fig_max = create_index_tracking_plot(max_indices_per_class, 'Max Mask Indices Across Classes')
    display(fig_max)
output_widgets.append(out)

out = Output()
with out:
    fig_std = create_index_tracking_plot(std_indices_per_class, 'Std Mask Indices Across Classes')
    display(fig_std)
output_widgets.append(out)

# Display all visualizations
# display(VBox(output_widgets))