In [1]:
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import torch
from prettytable import PrettyTable
# from model_distill_bert import getmodel
from utilities import compute_accuracy, compute_masks, mask_distillbert, get_model_distilbert, record_activations, mask_range_distilbert



tokenizer = AutoTokenizer.from_pretrained("esuriddick/distilbert-base-uncased-finetuned-emotion")
# Check if a GPU is available and use it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Load the dataset
dataset_all = load_dataset("dair-ai/emotion")
# Select the train split
dataset_all = dataset_all['train']
all_fc_vals = []
for j in range(0,6):
    model = get_model_distilbert("2O24dpower2024/distilbert-base-uncased-finetuned-emotion",5)
    dataset = dataset_all.filter(lambda x: x['label'] in [j])
    dataset_complement = dataset_all.filter(lambda x: x['label'] not in [j])
    print("Recording activations...")
    progress_bar = tqdm(total=len(dataset))
    model.to(device)
    model.eval()
    fc_vals = []
    all_fc_vals.append(record_activations(dataset, model, tokenizer, text_tag='text', mask_layer=5, batch_size=256))

  model.load_state_dict(torch.load(weights_path))


Recording activations...


  0%|          | 0/4666 [00:00<?, ?it/s]

  0%|          | 0/4666 [00:00<?, ?it/s]

Recording activations...


  0%|          | 0/4666 [00:11<?, ?it/s]


  0%|          | 0/5362 [00:00<?, ?it/s]

Recording activations...


  0%|          | 0/5362 [00:13<?, ?it/s]


  0%|          | 0/1304 [00:00<?, ?it/s]

Recording activations...


  0%|          | 0/1304 [00:09<?, ?it/s]


  0%|          | 0/2159 [00:00<?, ?it/s]

Recording activations...


  0%|          | 0/2159 [00:10<?, ?it/s]


  0%|          | 0/1937 [00:00<?, ?it/s]

Recording activations...


  0%|          | 0/1937 [00:09<?, ?it/s]


  0%|          | 0/572 [00:00<?, ?it/s]

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

output_widgets = []

for i, fc1 in enumerate(all_fc_vals):
    fc1 = np.array(fc1)
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max = compute_masks(fc1, 0.5)
    # mask_std = mask_max_low_std
    
    m = np.mean(np.abs(fc1), axis=0)
    s = np.std(fc1, axis=0)
    min_val = np.min(fc1, axis=0)
    max_val = np.max(fc1, axis=0)
    
    # Normalize std and mean
    s_norm = (s - min_val) / (max_val - min_val)
    m_norm = m#(m - min_val) / (max_val - min_val)
    
    # Create indices for different masks
    indices_max = np.where(mask_max == 0)[0]
    indices_std = np.where(mask_std == 0)[0]
    indices_intersection = np.intersect1d(indices_max, indices_std)
    indices_max_minus_std = np.setdiff1d(indices_max, indices_std)
    indices_std_minus_max = np.setdiff1d(indices_std, indices_max)
    
    # Count the indices in each set
    count_all = len(m_norm)
    count_max = len(indices_max)
    count_std = len(indices_std)
    count_intersection = len(indices_intersection)
    count_max_minus_std = len(indices_max_minus_std)
    count_std_minus_max = len(indices_std_minus_max)
    
    out = Output()
    with out:
        # Create subplots with counts in titles
        fig = make_subplots(rows=2, cols=3, 
                            subplot_titles=(f"All Activations (Count: {count_all})",
                                            f"Max Mask (Count: {count_max})", 
                                            f"Std Mask (Count: {count_std})", 
                                            f"Intersection (Count: {count_intersection})",
                                            f"Max - Std (Count: {count_max_minus_std})", 
                                            f"Std - Max (Count: {count_std_minus_max})"))
        
        # Helper function to add traces
        def add_traces(indices, row, col):
            indices_list = list(indices)  # Convert range or numpy array to list
            fig.add_trace(
                go.Scatter(
                    x=indices_list,
                    y=m_norm[indices_list],
                    mode='markers',
                    name='Mean',
                    marker=dict(size=3, color='blue'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            fig.add_trace(
                go.Scatter(
                    x=indices_list,
                    y=s_norm[indices_list],
                    mode='markers',
                    name='Std Dev',
                    marker=dict(size=3, color='red'),
                    showlegend=(row == 1 and col == 1)
                ),
                row=row, col=col
            )
            for j in indices_list:
                fig.add_trace(
                    go.Scatter(
                        x=[j, j],
                        y=[m_norm[j], s_norm[j]],
                        mode='lines',
                        line=dict(color='gray', width=0.5),
                        showlegend=False
                    ),
                    row=row, col=col
                )
        
        # Add traces for all activations
        add_traces(range(len(m_norm)), 1, 1)
        
        # Add traces for other plots
        add_traces(indices_max, 1, 2)
        add_traces(indices_std, 1, 3)
        add_traces(indices_intersection, 2, 1)
        add_traces(indices_max_minus_std, 2, 2)
        add_traces(indices_std_minus_max, 2, 3)
        
        # Update layout
        fig.update_layout(
            title=f'Mean and Standard Deviation of Activations - Class {i+1}',
            height=1200,
            width=1800,
            hovermode='closest'
        )
        
        # Update x and y axis labels for all subplots
        for row in range(1, 3):
            for col in range(1, 4):
                fig.update_xaxes(title_text="Activation Index", row=row, col=col)
                fig.update_yaxes(title_text="Normalized Value", row=row, col=col)
        
        display(fig)
    
    output_widgets.append(out)

# Display all figures in a vertical box
# display(VBox(output_widgets))

In [None]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import VBox, Output
from utilities import compute_masks
from IPython.display import display

def create_index_tracking_plot(indices_per_class, title):
    num_classes = len(indices_per_class)
    all_indices = sorted(set.union(*[set(indices) for indices in indices_per_class]))
    
    fig = go.Figure()
    
    # Create a color scale
    color_scale = px.colors.diverging.RdYlGn_r  # Red to Yellow to Green color scale

    # Add edges for indices present in multiple classes
    for idx in all_indices:
        classes_with_idx = [i for i, indices in enumerate(indices_per_class) if idx in indices]
        if len(classes_with_idx) > 1:
            x = [idx] * len(classes_with_idx)
            y = classes_with_idx
            color_index = (len(classes_with_idx) - 1) / (num_classes - 1)  # Normalize to [0, 1]
            edge_color = px.colors.sample_colorscale(color_scale, [color_index])[0]
            
            fig.add_trace(go.Scatter(
                x=x,
                y=y,
                mode='lines',
                line=dict(color=edge_color, width=2),
                hoverinfo='text',
                hovertext=f'Index: {idx}<br>Present in {len(classes_with_idx)} classes',
                showlegend=False
            ))
    
    # Add scatter plots for each class
    for class_idx, indices in enumerate(indices_per_class):
        fig.add_trace(go.Scatter(
            x=indices,
            y=[class_idx] * len(indices),
            mode='markers',
            name=f'Class {class_idx + 1}',
            marker=dict(size=4, symbol='circle', color='black'),
            hoverinfo='text',
            hovertext=[f'Index: {idx}<br>Class: {class_idx + 1}' for idx in indices]
        ))
    
    fig.update_layout(
        title=title,
        xaxis_title='Activation Index',
        yaxis_title='Class',
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(num_classes)),
            ticktext=[f'Class {i+1}' for i in range(num_classes)]
        ),
        hovermode='closest',
        width=1500,
        height=800,
        plot_bgcolor='white',
        showlegend=False
    )
    
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    
    # Add color bar
    fig.add_trace(go.Scatter(
        x=[None],
        y=[None],
        mode='markers',
        marker=dict(
            colorscale=color_scale,
            showscale=True,
            cmin=1,
            cmax=num_classes,
            colorbar=dict(
                title='Number of Classes',
                tickvals=list(range(1, num_classes+1)),
                ticktext=list(range(1, num_classes+1))
            )
        ),
        hoverinfo='none',
        showlegend=False
    ))
    
    return fig

# Collect indices for each class
max_indices_per_class = []
std_indices_per_class = []

for fc1 in all_fc_vals:
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max= compute_masks(fc1, 0.2)
    max_indices_per_class.append(np.where(mask_max == 0)[0])
    std_indices_per_class.append(np.where(mask_std == 0)[0])

# Create and display visualizations
output_widgets = []

out = Output()
with out:
    fig_max = create_index_tracking_plot(max_indices_per_class, 'Max Mask Indices Across Classes')
    display(fig_max)
output_widgets.append(out)

out = Output()
with out:
    fig_std = create_index_tracking_plot(std_indices_per_class, 'Std Mask Indices Across Classes')
    display(fig_std)
output_widgets.append(out)

# Display all visualizations
# display(VBox(output_widgets))

In [None]:
import torch
import numpy as np
import plotly.graph_objects as go
from ipywidgets import VBox, Output
from IPython.display import display

def find_common_masked_neurons(masks_per_class):
    """
    Find neurons that are masked (inactive) in all classes
    """
    # Convert all masks to numpy boolean arrays
    bool_masks = []
    for mask in masks_per_class:
        if torch.is_tensor(mask):
            mask = mask.cpu().numpy()
        bool_masks.append(mask.astype(np.bool_))
    
    # Find neurons that are masked (True) in all classes
    common_masked = bool_masks[0]  # Start with first mask
    for mask in bool_masks[1:]:
        common_masked = common_masked & mask
    
    return common_masked

def compute_activation_stats(fc_vals, tao=0.2):
    """
    Compute activation statistics for neurons
    """
    if torch.is_tensor(fc_vals):
        fc_vals = fc_vals.cpu().numpy()
    
    mean = np.mean(fc_vals, axis=0)
    std = np.std(fc_vals, axis=0)
    
    lower_bound = mean - tao * std
    upper_bound = mean + tao * std
    
    return mean, std, lower_bound, upper_bound

def create_common_masked_plot(fc_vals_per_class, masks_per_class, tao=0.2, title="Common Masked Neurons Across Classes"):
    """
    Create a visualization of activation ranges for neurons that are commonly masked across all classes.
    """
    # Find common masked neurons
    common_masked = find_common_masked_neurons(masks_per_class)
    masked_indices = np.where(common_masked)[0]
    
    if len(masked_indices) == 0:
        print("No common masked neurons found across all classes!")
        return None
        
    print(f"Found {len(masked_indices)} neurons masked across all classes")
    print(f"Indices of commonly masked neurons: {masked_indices}")
    
    # Colors for different classes
    colors = [f'rgb({r}, {g}, {b})' for r, g, b in [
        (31, 119, 180),  # blue
        (255, 127, 14),  # orange
        (44, 160, 44),   # green
        (214, 39, 40),   # red
        (148, 103, 189), # purple
        (140, 86, 75),   # brown
        (227, 119, 194), # pink
        (127, 127, 127), # gray
    ]]
    
    fig = go.Figure()
    
    # Process each class
    for class_idx, fc_vals in enumerate(fc_vals_per_class):
        if torch.is_tensor(fc_vals):
            fc_vals = fc_vals.cpu().numpy()
            
        # Compute statistics for masked neurons only
        mean, std, lower_bound, upper_bound = compute_activation_stats(fc_vals, tao)
        
        color = colors[class_idx % len(colors)]
        fill_color = color.replace('rgb', 'rgba').replace(')', ', 0.2)')
        
        # Add range area
        fig.add_trace(go.Scatter(
            x=np.concatenate([masked_indices, masked_indices[::-1]]),
            y=np.concatenate([upper_bound[common_masked], lower_bound[common_masked][::-1]]),
            fill='toself',
            fillcolor=fill_color,
            line=dict(color='rgba(0,0,0,0)'),
            name=f'Range Class {class_idx + 1}',
            showlegend=False,
            hoverinfo='skip'
        ))
        
        # Add mean points
        fig.add_trace(go.Scatter(
            x=masked_indices,
            y=mean[common_masked],
            mode='markers',
            marker=dict(
                size=6,
                color=color,
                symbol='circle'
            ),
            name=f'Class {class_idx + 1}',
            hovertemplate=(
                "Neuron: %{x}<br>" +
                "Mean: %{y:.3f}<br>" +
                f"Range: %{{customdata[0]:.3f}} to %{{customdata[1]:.3f}}<br>" +
                f"Class: {class_idx + 1}"
            ),
            customdata=np.column_stack((
                lower_bound[common_masked],
                upper_bound[common_masked]
            ))
        ))

    # Add vertical lines for better visualization of neuron positions
    for idx in masked_indices:
        fig.add_shape(
            type="line",
            x0=idx,
            x1=idx,
            y0=fig.layout.yaxis.range[0] if fig.layout.yaxis.range else min(lower_bound[common_masked]),
            y1=fig.layout.yaxis.range[1] if fig.layout.yaxis.range else max(upper_bound[common_masked]),
            line=dict(color="lightgray", width=1, dash="dot")
        )

    fig.update_layout(
        title=title,
        xaxis_title='Neuron Index',
        yaxis_title='Activation Value',
        width=1500,
        height=800,
        plot_bgcolor='white',
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        )
    )
    
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    
    return fig

def visualize_common_masked(all_fc_vals, all_masks, tao=0.2):
    """
    Create and display visualization for common masked neurons
    """
    out = Output()
    with out:
        fig = create_common_masked_plot(
            all_fc_vals,
            all_masks,
            tao=tao,
            title='Common Masked Neurons Across Classes'
        )
        if fig is not None:
            display(fig)
    
    display(out)


# Example usage:

# Assuming you have your data ready:
# all_fc_vals: list of activation values for each class
# all_masks: list of masks computed from compute_masks function

# First, compute the masks if you haven't already
all_masks = []
for fc1 in all_fc_vals:
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max = compute_masks(fc1, 0.5)
    all_masks.append(mask_max)  # or whichever mask type you want to visualize

# Then create the visualization
visualize_common_masked(all_fc_vals, all_masks, tao=2)

In [None]:
import torch
import numpy as np
import plotly.graph_objects as go
from ipywidgets import VBox, Output
from IPython.display import display

def find_common_masked_neurons(masks_per_class):
    """
    Find neurons that are masked (inactive) in all classes
    """
    bool_masks = []
    for mask in masks_per_class:
        if torch.is_tensor(mask):
            mask = mask.cpu().numpy()
        bool_masks.append(mask.astype(np.bool_))
    
    common_masked = bool_masks[0]
    for mask in bool_masks[1:]:
        common_masked = common_masked & mask
    
    return common_masked

def create_common_masked_boxplot(fc_vals_per_class, masks_per_class, title="Common Masked Neurons Box Plot"):
    """
    Create a box plot visualization of activation distributions for commonly masked neurons across classes.
    """
    # Find common masked neurons
    common_masked = find_common_masked_neurons(masks_per_class)
    masked_indices = np.where(common_masked)[0]
    
    if len(masked_indices) == 0:
        print("No common masked neurons found across all classes!")
        return None
        
    print(f"Found {len(masked_indices)} neurons masked across all classes")
    print(f"Indices of commonly masked neurons: {masked_indices}")
    
    # Colors for different classes
    colors = [
        'rgb(31, 119, 180)',   # blue
        'rgb(255, 127, 14)',   # orange
        'rgb(44, 160, 44)',    # green
        'rgb(214, 39, 40)',    # red
        'rgb(148, 103, 189)',  # purple
        'rgb(140, 86, 75)',    # brown
        'rgb(227, 119, 194)',  # pink
        'rgb(127, 127, 127)'   # gray
    ]
    
    fig = go.Figure()
    
    # Process each neuron
    for i, neuron_idx in enumerate(masked_indices):
        # Collect activation values for this neuron across all classes
        for class_idx, fc_vals in enumerate(fc_vals_per_class):
            if torch.is_tensor(fc_vals):
                fc_vals = fc_vals.cpu().numpy()
            
            # Convert fc_vals to numpy array if it's a list
            if isinstance(fc_vals, list):
                fc_vals = np.array(fc_vals)
            
            # Check if fc_vals is 1D or 2D
            if len(fc_vals.shape) == 1:
                neuron_activations = fc_vals
            else:
                neuron_activations = fc_vals[:, neuron_idx]
            
            fig.add_trace(go.Box(
                y=neuron_activations,
                x=[f"Neuron {neuron_idx}"] * len(neuron_activations),
                name=f"Class {class_idx + 1}",
                boxpoints='outliers',
                marker_color=colors[class_idx % len(colors)],
                showlegend=True if i == 0 else False,  # Show legend only for first neuron
                hovertemplate=(
                    "Neuron: %{x}<br>" +
                    "Q1: %{q1:.3f}<br>" +
                    "Median: %{median:.3f}<br>" +
                    "Q3: %{q3:.3f}<br>" +
                    "Min: %{lowerfence:.3f}<br>" +
                    "Max: %{upperfence:.3f}<br>" +
                    f"Class: {class_idx + 1}"
                )
            ))

    # Update layout
    fig.update_layout(
        title=title,
        xaxis_title='Neuron Index',
        yaxis_title='Activation Value',
        width=1500,
        height=800,
        plot_bgcolor='white',
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        ),
        boxmode='group'  # Group boxes for each neuron
    )
    
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    
    return fig

def visualize_common_masked_boxplot(all_fc_vals, all_masks):
    """
    Create and display box plot visualization for common masked neurons
    """
    out = Output()
    with out:
        fig = create_common_masked_boxplot(
            all_fc_vals,
            all_masks,
            title='Common Masked Neurons Distribution Across Classes'
        )
        if fig is not None:
            display(fig)
    
    display(out)

# Example usage:
# all_fc_vals: list of activation values for each class
# all_masks: list of masks computed from compute_masks function

# First, compute the masks if you haven't already
all_masks = []
for fc1 in all_fc_vals:
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max = compute_masks(fc1, 0.5)
    all_masks.append(mask_max)  # or whichever mask type you want to visualize

# Then create the visualization
visualize_common_masked_boxplot(all_fc_vals, all_masks)

In [None]:
import torch
import numpy as np
import plotly.graph_objects as go
from itertools import combinations
from ipywidgets import VBox, Output
from IPython.display import display

def find_common_masked_neurons(masks_per_class):
    """Find neurons that are masked in all classes"""
    bool_masks = [mask.cpu().numpy().astype(np.bool_) if torch.is_tensor(mask) 
                 else mask.astype(np.bool_) for mask in masks_per_class]
    common_masked = bool_masks[0]
    for mask in bool_masks[1:]:
        common_masked = common_masked & mask
    return common_masked

def compute_ranges(fc_vals_per_class, common_masked, tao=0.2):
    """Compute activation ranges for each class"""
    ranges = []
    for fc_vals in fc_vals_per_class:
        if torch.is_tensor(fc_vals):
            fc_vals = fc_vals.cpu().numpy()
        mean = np.mean(fc_vals, axis=0)[common_masked]
        std = np.std(fc_vals, axis=0)[common_masked]
        ranges.append({
            'lower': mean - tao * std,
            'upper': mean + tao * std
        })
    return ranges

def compute_overlap_percentage(range1, range2):
    """Compute overlap percentage between two ranges"""
    overlap_percentages = []
    for i in range(len(range1['lower'])):
        lower = max(range1['lower'][i], range2['lower'][i])
        upper = min(range1['upper'][i], range2['upper'][i])
        
        if lower > upper:
            overlap_percentages.append(0)
        else:
            total_range = max(range1['upper'][i], range2['upper'][i]) - min(range1['lower'][i], range2['lower'][i])
            overlap = upper - lower
            overlap_percentages.append((overlap / total_range) * 100)
    
    return np.mean(overlap_percentages)

def create_overlap_heatmap(fc_vals_per_class, masks_per_class, tao=0.2):
    """Create a heatmap showing range overlaps between all class combinations"""
    common_masked = find_common_masked_neurons(masks_per_class)
    masked_indices = np.where(common_masked)[0]
    
    if len(masked_indices) == 0:
        print("No common masked neurons found!")
        return None
        
    print(f"Analyzing {len(masked_indices)} commonly masked neurons")
    
    # Compute ranges for each class
    ranges = compute_ranges(fc_vals_per_class, common_masked, tao)
    num_classes = len(fc_vals_per_class)
    
    # Initialize overlap matrix
    overlap_matrix = np.zeros((num_classes, num_classes))
    
    # Compute overlaps for all class pairs
    for i in range(num_classes):
        for j in range(num_classes):
            if i == j:
                overlap_matrix[i, j] = 100
            else:
                overlap_matrix[i, j] = compute_overlap_percentage(ranges[i], ranges[j])
    
    # Create heatmap
    fig = go.Figure(data=go.Heatmap(
        z=overlap_matrix,
        x=[f'Class {i+1}' for i in range(num_classes)],
        y=[f'Class {i+1}' for i in range(num_classes)],
        colorscale='RdYlBu',
        colorbar=dict(title='Overlap %'),
        text=np.round(overlap_matrix, 1),
        texttemplate='%{text}%',
        textfont={"size": 12},
        hoverongaps=False,
    ))
    
    fig.update_layout(
        title='Range Overlap Percentage Between Classes (Common Masked Neurons)',
        width=800,
        height=800,
    )
    
    # Create bar plot for average overlap per class combination
    combinations_data = []
    for k in range(2, num_classes + 1):
        for combo in combinations(range(num_classes), k):
            overlap_sum = 0
            count = 0
            neuron_count = 0
            
            # For each neuron, check if it overlaps in all classes in the combination
            for neuron_idx in range(len(masked_indices)):
                overlaps = True
                min_upper = float('inf')
                max_lower = float('-inf')
                
                for class_idx in combo:
                    min_upper = min(min_upper, ranges[class_idx]['upper'][neuron_idx])
                    max_lower = max(max_lower, ranges[class_idx]['lower'][neuron_idx])
                
                if max_lower <= min_upper:
                    neuron_count += 1
            
            combo_classes = [f'Class {i+1}' for i in combo]
            combinations_data.append({
                'classes': ' + '.join(combo_classes),
                'num_classes': len(combo),
                'overlapping_neurons': neuron_count,
                'total_neurons': len(masked_indices),
                'percentage': (neuron_count / len(masked_indices)) * 100
            })
    
    # Create bar plot
    fig2 = go.Figure()
    
    # Sort combinations by number of classes and percentage
    combinations_data.sort(key=lambda x: (x['num_classes'], -x['percentage']))
    
    # Add bars
    fig2.add_trace(go.Bar(
        x=[d['classes'] for d in combinations_data],
        y=[d['percentage'] for d in combinations_data],
        text=[f"{d['percentage']:.1f}%<br>({d['overlapping_neurons']}/{d['total_neurons']})" 
              for d in combinations_data],
        textposition='auto',
        hovertemplate=(
            "Classes: %{x}<br>" +
            "Overlap: %{text}<br>" +
            "<extra></extra>"
        )
    ))
    
    fig2.update_layout(
        title='Percentage of Neurons with Overlapping Ranges by Class Combination',
        xaxis_title='Class Combinations',
        yaxis_title='Percentage of Neurons with Overlapping Ranges',
        width=1500,
        height=800,
        showlegend=False,
        xaxis_tickangle=45
    )
    
    return fig, fig2

def visualize_range_overlaps(all_fc_vals, all_masks, tao=0.2):
    """Create and display overlap visualizations"""
    out = Output()
    with out:
        figs = create_overlap_heatmap(all_fc_vals, all_masks, tao)
        if figs is not None:
            fig1, fig2 = figs
            print("\nHeatmap of pairwise range overlaps:")
            display(fig1)
            print("\nOverlap analysis for different class combinations:")
            display(fig2)
    
    display(out)
    
all_masks = []
for fc1 in all_fc_vals:
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max = compute_masks(fc1, 0.5)
    all_masks.append(mask_max)
    
    
visualize_range_overlaps(all_fc_vals, all_masks, tao=2)


In [3]:
import torch
import numpy as np
import plotly.graph_objects as go
from itertools import combinations
from ipywidgets import VBox, Output
from IPython.display import display

def compute_ranges_and_overlaps(fc_vals_per_class, masks_per_class, tao=0.2):
    """
    Compute ranges and find overlapping neurons between different class combinations
    """
    num_classes = len(fc_vals_per_class)
    # Get number of neurons from the first activation array
    first_fc = fc_vals_per_class[0]
    if isinstance(first_fc, list):
        num_neurons = len(first_fc[0])
    else:
        num_neurons = first_fc.shape[1] if hasattr(first_fc, 'shape') else len(first_fc[0])
    
    # Convert all data to numpy arrays and compute ranges in one pass
    ranges = []
    masks_np = []
    
    for fc_vals, mask in zip(fc_vals_per_class, masks_per_class):
        # Convert and compute fc_vals statistics
        if isinstance(fc_vals, list):
            fc_vals = np.array(fc_vals)
        elif torch.is_tensor(fc_vals):
            fc_vals = fc_vals.cpu().numpy()
            
        mean = np.mean(fc_vals, axis=0)
        std = np.std(fc_vals, axis=0)
        ranges.append({
            'mean': mean,
            'lower': mean - tao * std,
            'upper': mean + tao * std
        })
        
        # Convert masks
        if torch.is_tensor(mask):
            mask = mask.cpu().numpy()
        masks_np.append(mask.astype(bool))
    
    # Compute overlaps efficiently
    overlap_info = {
        'neurons': np.zeros(num_neurons, dtype=int),
        'combinations': {}
    }

    # Pre-compute active neurons for each class
    active_neurons = [~mask for mask in masks_np]
    
    # For each neuron, find all classes where it's active and check overlaps
    for neuron in range(num_neurons):
        # Get classes where this neuron is active
        active_classes = [i for i, active in enumerate(active_neurons) if active[neuron]]
        
        if len(active_classes) >= 2:
            # Check if ranges overlap for active classes
            min_upper = min(ranges[cls_idx]['upper'][neuron] for cls_idx in active_classes)
            max_lower = max(ranges[cls_idx]['lower'][neuron] for cls_idx in active_classes)
            
            if max_lower <= min_upper:  # Ranges overlap
                overlap_info['neurons'][neuron] = len(active_classes)
                overlap_info['combinations'][neuron] = tuple(active_classes)

    return ranges, overlap_info

def create_detailed_range_plot(fc_vals_per_class, masks_per_class, tao=0.2, title="Activation Ranges with Overlap Analysis"):
    """
    Create visualization showing activation ranges and overlaps
    """
    print("Computing ranges and overlaps...")
    ranges, overlap_info = compute_ranges_and_overlaps(fc_vals_per_class, masks_per_class, tao)
    num_classes = len(fc_vals_per_class)
    num_neurons = len(overlap_info['neurons'])
    
    print("Creating visualization...")
    
    # Create color scheme for different overlap levels
    overlap_colors = {
        0: 'rgba(200, 200, 200, 0.1)',  # No overlap
        2: 'rgba(255, 200, 200, 0.2)',  # 2 classes
        3: 'rgba(255, 150, 150, 0.2)',  # 3 classes
        4: 'rgba(255, 100, 100, 0.2)',  # 4 classes
        5: 'rgba(255, 50, 50, 0.2)',    # 5 classes
        6: 'rgba(255, 0, 0, 0.2)'       # 6 classes
    }
    
    class_colors = [
        'rgb(31, 119, 180)',   # blue
        'rgb(255, 127, 14)',   # orange
        'rgb(44, 160, 44)',    # green
        'rgb(214, 39, 40)',    # red
        'rgb(148, 103, 189)',  # purple
        'rgb(140, 86, 75)',    # brown
    ]
    
    fig = go.Figure()
    
    # Add background color for overlap regions efficiently
    overlap_neurons = np.where(overlap_info['neurons'] >= 2)[0]
    if len(overlap_neurons) > 0:
        for n_classes in range(2, num_classes + 1):
            neurons = np.where(overlap_info['neurons'] == n_classes)[0]
            if len(neurons) > 0:
                fig.add_trace(go.Scatter(
                    x=np.concatenate([neurons - 0.5, neurons + 0.5, [neurons[-1] + 0.5, neurons[0] - 0.5]]),
                    y=[fig.layout.yaxis.range[0] if fig.layout.yaxis.range else -1] * (len(neurons) * 2 + 2),
                    fill="toself",
                    fillcolor=overlap_colors[n_classes],
                    line=dict(width=0),
                    showlegend=True,
                    name=f'{n_classes}-Class Overlap',
                    hoverinfo='skip'
                ))
    
    # Add ranges for each class
    for class_idx in range(num_classes):
        mask = masks_per_class[class_idx]
        if torch.is_tensor(mask):
            mask = mask.cpu().numpy()
        mask = mask.astype(bool)
        
        active_indices = np.where(~mask)[0]
        
        if len(active_indices) > 0:
            fig.add_trace(go.Scatter(
                x=np.concatenate([active_indices, active_indices[::-1]]),
                y=np.concatenate([
                    ranges[class_idx]['upper'][active_indices],
                    ranges[class_idx]['lower'][active_indices][::-1]
                ]),
                fill='toself',
                fillcolor=class_colors[class_idx].replace('rgb', 'rgba').replace(')', ', 0.2)'),
                line=dict(color='rgba(0,0,0,0)'),
                name=f'Range Class {class_idx + 1}',
                showlegend=True,
                hoverinfo='skip'
            ))
            
            fig.add_trace(go.Scatter(
                x=active_indices,
                y=ranges[class_idx]['mean'][active_indices],
                mode='markers',
                marker=dict(
                    size=6,
                    color=class_colors[class_idx],
                    symbol='circle'
                ),
                name=f'Mean Class {class_idx + 1}',
                hovertemplate=(
                    "Neuron: %{x}<br>" +
                    "Mean: %{y:.3f}<br>" +
                    f"Range: %{{customdata[0]:.3f}} to %{{customdata[1]:.3f}}<br>" +
                    f"Class: {class_idx + 1}<br>" +
                    "Overlapping Classes: %{customdata[2]}"
                ),
                customdata=np.column_stack((
                    ranges[class_idx]['lower'][active_indices],
                    ranges[class_idx]['upper'][active_indices],
                    [', '.join([f'Class {i+1}' for i in overlap_info['combinations'].get(idx, [])]) 
                     if idx in overlap_info['combinations'] else 'None'
                     for idx in active_indices]
                ))
            ))
    
    fig.update_layout(
        title=title,
        xaxis_title='Neuron Index',
        yaxis_title='Activation Value',
        width=1500,
        height=800,
        plot_bgcolor='white',
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        )
    )
    
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='LightGrey')
    
    print("Visualization complete!")
    return fig

def visualize_ranges_with_overlaps(all_fc_vals, all_masks, tao=0.2):
    """
    Create and display the visualization
    """
    print("Starting visualization...")
    out = Output()
    with out:
        fig = create_detailed_range_plot(all_fc_vals, all_masks, tao)
        if fig is not None:
            display(fig)
    
    display(out)
    
all_masks = []
for fc1 in all_fc_vals:
    mask_max, mask_std, mask_intersection, mask_max_low_std, mask_max_high_std, mask_std_high_max = compute_masks(fc1, 0.05)
    all_masks.append(mask_max)
    
visualize_ranges_with_overlaps(all_fc_vals, all_masks, tao=2)

Starting visualization...


Output()

In [7]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
 
 
 
def find_common_neurons(all_masked_indices):
    """
    Find neuron indices common across all classes
    """
    # Use set intersection to find common indices
    common_neurons = set(all_masked_indices[0])
    for indices in all_masked_indices[1:]:
        common_neurons = common_neurons.intersection(indices)
    
    return list(common_neurons)
 
def prepare_neuron_data(all_fc_vals, neurons_to_plot,classes_chosen):
    """
    Prepare activation data for common neurons across all classes
    """
    plot_data = []
    
    # Extract masks for each class
    masks = [compute_masks(activations, 0.50)[0] for activations in [all_fc_vals[i] for i in classes_chosen]]
    
    for class_label, (activations, mask) in enumerate(zip([all_fc_vals[i] for i in classes_chosen], masks)):
        # Get neuron indices for this class based on mask where mask is zero (indicating selected neurons)
        class_neuron_indices = np.where(mask == 0)[0]  # These are the neuron indices where the mask is 0
        activations = np.array(activations)  # Ensure activations is a NumPy array
        
        print(neurons_to_plot)
        
        # Filter to only common neurons
        common_neuron_mask = np.isin(class_neuron_indices, neurons_to_plot)  # Find neurons in common_neuron_mask
        common_neuron_mask = np.where(common_neuron_mask, 0, 1).astype(int)  # Convert True/False to 0/1
        
        # Get the indices where mask is 0 for common neurons
        common_neuron_indices = class_neuron_indices[common_neuron_mask == 0]  # Only take indices where mask is 0
        
        # Extract activations for common neurons
        filtered_activations = activations[:, common_neuron_indices]  # Use these indices to filter activations
        
        # Create DataFrame with both neuron index and class label
        df = pd.DataFrame(filtered_activations)
        df = df.melt(var_name='Neuron_Local', value_name='Activation')
        
        # Map local neuron indices back to original neuron numbers
        df['Neuron'] = df['Neuron_Local'].apply(lambda x: common_neuron_indices[x])
        df['Class'] = class_label
        
        plot_data.append(df)
 
   
    
    # Combine data for all classes
    return pd.concat(plot_data, ignore_index=True)
 
def plot_overlapping_neurons(all_fc_vals, max_neurons_to_plot,classes_chosen):
    # Extract max masks for each class
    all_masked_indices = []
    
    for class_label, activations in enumerate([all_fc_vals[i] for i in classes_chosen]):
        # Compute masks for the current class's activations
        mask_max, _, _, _, _, _ = compute_masks(activations, 0.50)
        
        # Use max mask and get its neuron indices
        class_neuron_indices = np.where(mask_max == 0)[0]
        class_neuron_indices =[int(neuron) for neuron in class_neuron_indices]
        all_masked_indices.append(class_neuron_indices)
    
    # Find common neuron indices across classes
    common_neurons = find_common_neurons(all_masked_indices)
    common_neurons=[int(neuron) for neuron in common_neurons]
    print("Commo|n nuerons",common_neurons)
    # n_neurons_to_plot = np.random.randint(1, min(len(common_neurons), max_neurons_to_plot) + 1)
    
    # Randomly select neurons to plot
    neurons_to_plot = np.random.choice(common_neurons, size=min(len(common_neurons), max_neurons_to_plot), replace=False)
    # Select top n common neurons to plot
    # neurons_to_plot = common_neurons[:n_neurons_to_plot]
    neurons_to_plot = [int(neuron) for neuron in neurons_to_plot]
    print(neurons_to_plot)
    
    # Prepare data for plotting
    plot_df = prepare_neuron_data(all_fc_vals, neurons_to_plot,classes_chosen)
    
    # Create plot
    plt.figure(figsize=(25, 10))
    palette = sns.color_palette("Set2", n_colors=len(plot_df['Class'].unique()))
    
    sns.boxplot(data=plot_df, x='Neuron', y='Activation', hue='Class',
                showfliers=False, palette=palette)
    # neuron_indices = [0, 4, 6, 12]  # Example indices where dashed lines should appear
    for pos in range(len(neurons_to_plot)):  # Use `neurons_to_plot` here
        plt.axvline(x=pos+0.5, color='black', linestyle='--', linewidth=1)
        
    plt.title(f'Activation Box Plot for {len(neurons_to_plot)} Common Neurons', fontsize=18)
    plt.xlabel('Neuron Index', fontsize=14)
    plt.ylabel('Activation Value', fontsize=14)
    plt.legend(title='Class', fontsize=12, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()
    
    return neurons_to_plot
np.random.seed(2)#21
classes_chosen=np.random.choice([0,1,2,3,4,5,6,7,8,9,10,11,12,13], size=4, replace=False)
print(classes_chosen)
# Call the function with your activations
plot_neurons = plot_overlapping_neurons(all_fc_vals, max_neurons_to_plot=15,classes_chosen=classes_chosen)
print(f"Plotted neuron indices:", plot_neurons)

[11  4  5  0]


IndexError: list index out of range