In [2]:
import torch
import torch.nn as nn
from datasets import IndexedDataset, load_data
from torch.utils.data import DataLoader

from utils import get_args
from architectures import load_architecture

from architectures import load_architecture, add_lora, set_lora_gradients #load_statedict

from torch.utils.data import DataLoader
import numpy as np

import math
from functools import partial
from typing import Union, Tuple  # Import Union and Tuple for type annotations

import torch.nn.functional as F
from torch import optim

@torch.inference_mode()
def _get_masks(activations, tau:float, ineq_type:str) -> torch.Tensor:
    """
    Computes the ReDo mask for a given set of activations.
    The returned mask has True where neurons are dormant and False where they are active.
    """
    masks = []

    for name, activation in list( activations.items() ):
        # Taking the mean here conforms to the expectation under D in the main paper's formula
        if activation.ndim == 4:
            # Conv layer
            score = activation.abs().mean( dim=(0, 2, 3) )
        else:
            # Linear layer
            score = activation.abs().mean(dim=0)

        # print('score', score)
        # Divide by activation mean to make the threshold independent of the layer size
        # see https://github.com/google/dopamine/blob/ce36aab6528b26a699f5f1cefd330fdaf23a5d72/dopamine/labs/redo/weight_recyclers.py#L314
        # https://github.com/google/dopamine/issues/209

        normalized_score = score / (score.mean() + 1e-9)
        layer_mask = torch.zeros_like(normalized_score, dtype=torch.bool)

        if tau > 0.0 and ineq_type == 'leq':
            layer_mask[normalized_score <= tau] = 1
        elif tau > 0.0 and ineq_type == 'geq':
            layer_mask[normalized_score >= tau] = 1
        else:
            layer_mask[ torch.isclose( normalized_score, torch.zeros_like(normalized_score) ) ] = 1

        masks.append(layer_mask)

    return masks

@torch.inference_mode()
def _get_activation(name: str, activations):
    """Fetches and stores the activations of a network layer."""

    def hook(layer: Union[nn.Linear, nn.Conv2d], input: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
        """
        Get the activations of a layer with ReLU nonlinearity.
        ReLU has to be called explicitly here because the hook is attached to the conv/linear layer.
        """
        activations[name] = F.relu(output)

    return hook

def _count_masked(masks) -> Tuple[torch.Tensor, int]:
    """Returns (number of True entries, total number of entries) summed over a list of bool masks."""
    flagged = sum(torch.sum(mask) for mask in masks)
    total = sum(torch.numel(mask) for mask in masks)
    return flagged, total


@torch.inference_mode()
def run_redo(
    obs: torch.Tensor,
    model: nn.Module,
    dormant_tau: float = 0.01,
    overactive_tau: float = 3.0,
):
    """
    Measures zero, dormant and overactive neurons of `model` for one input batch.

    A forward hook (from `_get_activation`) is attached to every nn.Conv2d and nn.Linear
    module. Each hook records ReLU(layer output). ReLU is applied to every such layer
    regardless of the network's actual nonlinearity, so for e.g. GELU transformers the
    counts are a ReLU-based proxy.

    One forward pass on `obs` is then run. The hooks are removed afterwards, even if the
    forward pass raises. Three sets of masks are computed with `_get_masks`:
      * zero:       normalized score numerically 0
      * dormant:    normalized score <= dormant_tau
      * overactive: normalized score >= overactive_tau

    Args:
        obs: input passed unchanged to `model(obs)`; no batch dimension is added here.
        model: network to inspect. Its train/eval mode is not changed here.
            NOTE(review): callers presumably want `model.eval()` first so dropout and
            batch-norm behave deterministically.
        dormant_tau: threshold for dormant neurons (must be > 0 to use the threshold).
        overactive_tau: threshold for overactive neurons (must be > 0 to use the threshold).

    Returns:
        dict with "total_neurons" (int), plus "zero_*", "dormant_*" and "overactive_*"
        entries: the "*_count" values are ints and the "*_fraction" values are floats.

    Raises:
        ValueError: if no nn.Conv2d / nn.Linear layer ran during the forward pass.
    """
    activations = {}
    activation_getter = partial(_get_activation, activations=activations)

    # Temporarily hook every Conv2d / Linear layer to record its output.
    handles = [
        module.register_forward_hook(activation_getter(name))
        for name, module in model.named_modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]

    try:
        model(obs)
    finally:
        # Always detach the hooks. If they stayed registered after a failed forward pass,
        # every later call on `model` would stack another set of hooks.
        for handle in handles:
            handle.remove()

    if not activations:
        raise ValueError("run_redo: no nn.Conv2d or nn.Linear layer was executed during the forward pass")

    # tau = 0 selects neurons whose normalized score is numerically zero.
    zero_count, total_neurons = _count_masked(_get_masks(activations, 0.0, 'leq'))
    dormant_count, _ = _count_masked(_get_masks(activations, dormant_tau, 'leq'))
    overactive_count, _ = _count_masked(_get_masks(activations, overactive_tau, 'geq'))

    return {
        "total_neurons": total_neurons,
        "zero_fraction": (zero_count / total_neurons).item(),
        "zero_count": zero_count.item(),
        "dormant_fraction": (dormant_count / total_neurons).item(),
        "dormant_count": dormant_count.item(),
        "overactive_fraction": (overactive_count / total_neurons).item(),
        "overactive_count": overactive_count.item(),
    }


In [3]:
from itertools import islice

args = get_args()

args.iterations = 20
args.pruning_ratio = 0
args.delta = 1

args.dataset = 'EuroSAT'
args.selection_method = 'random'
args.data_dir = './data/EuroSAT'

# Number of test images per backbone over which the neuron statistics are averaged.
N_EVAL_SAMPLES = 100

# Per-image statistics returned by run_redo that are aggregated (mean/std) below.
NEURON_METRICS = (
    "total_neurons",
    "zero_fraction",
    "zero_count",
    "dormant_fraction",
    "dormant_count",
    "overactive_fraction",
    "overactive_count",
)

train_dataset, val_dataset, test_dataset, N, train_transform, transform = load_data(args)

test_dataset = IndexedDataset(args, test_dataset, transform, N,)

# batch_size=1: run_redo sees one image at a time, so every statistic is per image and
# the mean/std below are taken across images.
testloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

backbones = [
    # 'convnext_base',  'convnext_base.fb_in22k', 'robust_convnext_base',
    # 'convnext_tiny',  'convnext_tiny.fb_in22k', 'robust_convnext_tiny',

    # 'robust_wideresnet_28_10', 'wideresnet_28_10',

    'deit_small_patch16_224.fb_in1k',
    'robust_deit_small_patch16_224',

    'vit_base_patch16_224.augreg_in1k',
    'vit_base_patch16_224.augreg_in21k',
    'robust_vit_base_patch16_224',
]

backbones_result = {}
for backbone_name in backbones:
    args.backbone = backbone_name
    model = load_architecture(args, N, rank = 0)
    # Measure in inference behaviour: dropout disabled, batch-norm uses running statistics.
    # This is a no-op if load_architecture already returns an eval-mode model.
    model.eval()

    per_image = {metric: [] for metric in NEURON_METRICS}
    # Each batch unpacks as (obs, label, idx); only the image is needed here.
    for obs, _label, _idx in islice(testloader, N_EVAL_SAMPLES):
        result = run_redo(obs, model)  # zero, dormant and overactive neuron counts
        for metric in NEURON_METRICS:
            per_image[metric].append(result[metric])

    # Column order: all mean_* statistics first, then all std_* statistics.
    summary = {f"mean_{metric}": np.mean(per_image[metric]) for metric in NEURON_METRICS}
    summary.update({f"std_{metric}": np.std(per_image[metric]) for metric in NEURON_METRICS})
    backbones_result[backbone_name] = summary

    # Progress report for this backbone; the full results accumulate in backbones_result.
    print(backbone_name, summary)

./data
{'deit_small_patch16_224.fb_in1k': {'mean_total_neurons': 8170378.0, 'mean_zero_fraction': 0.6889565885066986, 'mean_zero_count': 5629035.74, 'mean_dormant_fraction': 0.6900521737337112, 'mean_dormant_count': 5637987.1, 'mean_overactive_fraction': 0.0880221626162529, 'mean_overactive_count': 719174.34, 'std_total_neurons': 0.0, 'std_zero_fraction': 0.0020436314216797278, 'std_zero_count': 16697.249924236025, 'std_dormant_fraction': 0.002075395120102153, 'std_dormant_count': 16956.76644970969, 'std_overactive_fraction': 0.0015207764060526888, 'std_overactive_count': 12425.316840402906}}
{'deit_small_patch16_224.fb_in1k': {'mean_total_neurons': 8170378.0, 'mean_zero_fraction': 0.6889565885066986, 'mean_zero_count': 5629035.74, 'mean_dormant_fraction': 0.6900521737337112, 'mean_dormant_count': 5637987.1, 'mean_overactive_fraction': 0.0880221626162529, 'mean_overactive_count': 719174.34, 'std_total_neurons': 0.0, 'std_zero_fraction': 0.0020436314216797278, 'std_zero_count': 16697.24

In [4]:
import pandas as pd

# Rows: backbones; columns: the mean_*/std_* statistics collected per backbone.
df = pd.DataFrame.from_dict(backbones_result).T

results_path = f"neurons_results_{args.dataset}.csv"
df.to_csv(results_path)

In [None]:
import pandas as pd

dataset = 'EuroSAT'

# Load the per-backbone summary written by the CSV export cell.
df = pd.read_csv(f"neurons_results_{dataset}.csv", index_col=0)

# '_' is a special character in LaTeX, so swap underscores for spaces in the
# column names, the cell values and the row labels.
df.columns = df.columns.str.replace('_', ' ')
df = df.replace('_', ' ', regex=True)
df.index = df.index.str.replace('_', ' ')

# Round to 3 decimal places for the table.
df = df.round(3)

# float_format="{}".format prints each rounded value as-is, without fixed-width padding.
# (The table has no "name" column, so the former formatters={"name": ...} had no effect.)
latex_code = df.to_latex(
    index=True,
    float_format="{}".format,
)

# Print so the LaTeX source appears with real line breaks.
print(latex_code)


\begin{tabular}{lrrrrrrrrrrrrrr}
\toprule
 & mean total neurons & mean zero fraction & mean zero count & mean dormant fraction & mean dormant count & mean overactive fraction & mean overactive count & std total neurons & std zero fraction & std zero count & std dormant fraction & std dormant count & std overactive fraction & std overactive count \\
\midrule
convnext base & 21280.0 & 0.047 & 1004.16 & 0.163 & 3469.93 & 0.056 & 1200.18 & 0.0 & 0.017 & 359.605 & 0.035 & 737.983 & 0.005 & 112.553 \\
convnext base.fb in22k & 21280.0 & 0.043 & 921.53 & 0.158 & 3354.2 & 0.056 & 1183.79 & 0.0 & 0.017 & 355.33 & 0.035 & 742.071 & 0.005 & 112.939 \\
robust convnext base & 21280.0 & 0.103 & 2194.19 & 0.257 & 5472.24 & 0.069 & 1463.75 & 0.0 & 0.053 & 1127.165 & 0.077 & 1641.495 & 0.012 & 262.539 \\
convnext tiny & 8872.0 & 0.041 & 360.27 & 0.126 & 1115.94 & 0.049 & 432.75 & 0.0 & 0.017 & 154.956 & 0.037 & 325.717 & 0.007 & 57.75 \\
convnext tiny.fb in22k & 8872.0 & 0.062 & 551.18 & 0.175 & 1550.68

In [5]:
# Display the results table currently bound to `df`. Depending on which cells ran last,
# this is the frame built for the CSV export or the renamed/rounded copy from the LaTeX cell.
df

Unnamed: 0,mean_total_neurons,mean_zero_fraction,mean_zero_count,mean_dormant_fraction,mean_dormant_count,mean_overactive_fraction,mean_overactive_count,std_total_neurons,std_zero_fraction,std_zero_count,std_dormant_fraction,std_dormant_count,std_overactive_fraction,std_overactive_count
convnext_base,21280.0,0.047139,1003.12,0.163013,3468.92,0.056383,1199.83,0.0,0.016902,359.664602,0.034687,738.129524,0.005287,112.49676
convnext_base.fb_in22k,21280.0,0.043351,922.5,0.157668,3355.18,0.055629,1183.79,0.0,0.01666,354.518956,0.03483,741.185609,0.005301,112.813678
robust_convnext_base,21280.0,0.103116,2194.3,0.25716,5472.37,0.068804,1464.14,0.0,0.052942,1126.612387,0.077115,1641.006323,0.012328,262.334444
convnext_tiny,8872.0,0.040629,360.46,0.125803,1116.12,0.048785,432.82,0.0,0.017429,154.629584,0.036688,325.500024,0.006506,57.721292
convnext_tiny.fb_in22k,8872.0,0.062261,552.38,0.174923,1551.92,0.057359,508.89,0.0,0.020163,178.889227,0.03988,353.814434,0.005824,51.673571
robust_convnext_tiny,8872.0,0.183944,1631.95,0.318894,2829.23,0.054776,485.97,0.0,0.062596,555.355244,0.067971,603.036,0.003891,34.522588
robust_wideresnet_28_10,10106.0,0.002412,24.38,0.029793,301.09,0.0154,155.63,0.0,0.000455,4.599522,0.002347,23.719652,0.001168,11.808179
wideresnet_28_10,10106.0,0.002266,22.9,0.026761,270.45,0.008736,88.29,0.0,0.000831,8.401786,0.001652,16.699925,0.000352,3.556107
deit_small_patch16_224.fb_in1k,8170378.0,0.688957,5629035.05,0.699874,5718238.12,0.088022,719174.09,0.0,0.002044,16697.380857,0.002359,19276.447727,0.001521,12425.288639
robust_deit_small_patch16_224,8170378.0,0.692757,5660087.78,0.705982,5768142.5,0.081712,667616.0,0.0,0.005185,42366.565605,0.00542,44279.603482,0.003662,29923.361248
