In [1]:
import torch
from datasets import IndexedDataset, WeightedDataset, load_data
from torch.utils.data import DataLoader, DistributedSampler

from utils import get_args
from architectures import load_architecture

from samplers import DistributedCustomSampler
from losses import trades_loss, apgd_loss
from tqdm.notebook import tqdm
from architectures import load_architecture, add_lora, set_lora_gradients #load_statedict

import torch.nn as nn

from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
from datasets.semisupervised_dataset import SemiSupervisedDataset
from torch.utils.data import TensorDataset
import random
from torch.utils.data import Subset
from torch.utils.data import random_split
from datasets.eurosat import EuroSATDataset

from sklearn.model_selection import train_test_split

import math
from functools import partial
from typing import Union, Tuple  # Import Union and Tuple for type annotations

import torch.nn.functional as F
from torch import optim

from torch.utils.data import DataLoader
import numpy as np


In [11]:
@torch.inference_mode()
def _get_masks(activations, tau:float, ineq_type:str) -> torch.Tensor:
    """
    Computes the ReDo mask for a given set of activations.
    The returned mask has True where neurons are dormant and False where they are active.
    """
    masks = []

    for name, activation in list( activations.items() ):
        # Taking the mean here conforms to the expectation under D in the main paper's formula
        if activation.ndim == 4:
            # Conv layer
            score = activation.abs().mean( dim=(0, 2, 3) )
        else:
            # Linear layer
            score = activation.abs().mean(dim=0)

        # print('score', score)
        # Divide by activation mean to make the threshold independent of the layer size
        # see https://github.com/google/dopamine/blob/ce36aab6528b26a699f5f1cefd330fdaf23a5d72/dopamine/labs/redo/weight_recyclers.py#L314
        # https://github.com/google/dopamine/issues/209

        normalized_score = score / (score.mean() + 1e-9)
        layer_mask = torch.zeros_like(normalized_score, dtype=torch.bool)

        if tau > 0.0 and ineq_type == 'leq':
            layer_mask[normalized_score <= tau] = 1
        elif tau > 0.0 and ineq_type == 'geq':
            layer_mask[normalized_score >= tau] = 1
        else:
            layer_mask[ torch.isclose( normalized_score, torch.zeros_like(normalized_score) ) ] = 1

        masks.append(layer_mask)

    return masks

@torch.inference_mode()
def _get_activation(name: str, activations):
    """Fetches and stores the activations of a network layer."""

    def hook(layer: Union[nn.Linear, nn.Conv2d], input: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
        """
        Get the activations of a layer with ReLU nonlinearity.
        ReLU has to be called explicitly here because the hook is attached to the conv/linear layer.
        """
        activations[name] = F.relu(output)

    return hook

@torch.inference_mode()
def run_redo(
    obs: torch.Tensor,
    model: nn.Module,
    dormant_tau: float = 0.1,
    overactive_tau: float = 3.0,
):
    """
    Counts zero, dormant and over-active neurons of ``model`` on one input batch.

    A forward hook is attached to every ``nn.Conv2d`` and ``nn.Linear`` module.
    The model is run once on ``obs``, the hooks are removed, and the recorded
    activations are scored with ``_get_masks``:

    * zero:        normalized score numerically equal to 0
    * dormant:     normalized score <= ``dormant_tau``
    * over-active: normalized score >= ``overactive_tau``

    Args:
        obs: input passed to ``model`` as-is (no batch dimension is added here).
        model: network to inspect. Its train/eval mode is not changed.
        dormant_tau: threshold for the dormant count (0.1 as before).
        overactive_tau: threshold for the over-active count (3 as before).

    Returns:
        dict with ``"total_neurons"`` plus ``"<kind>_count"`` (int) and
        ``"<kind>_fraction"`` (float) for kind in zero / dormant / overactive.

    Raises:
        ValueError: if no Conv2d/Linear activations were recorded, so nothing was scored.
    """
    activations = {}
    activation_getter = partial(_get_activation, activations=activations)

    # Register hooks for all Conv2d and Linear layers so one forward pass fills `activations`.
    handles = [
        module.register_forward_hook(activation_getter(name))
        for name, module in model.named_modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]
    try:
        model(obs)
    finally:
        # Remove the hooks even if the forward pass raises; otherwise they would
        # stay registered and pile up on the model across calls.
        for handle in handles:
            handle.remove()

    def _tally(masks):
        """Returns (flagged neurons, total neurons) summed over all layer masks."""
        flagged = sum(int(mask.sum()) for mask in masks)
        total = sum(mask.numel() for mask in masks)
        return flagged, total

    # tau = 0 -> exactly-zero ("dead") neurons, reported for logging.
    zero_count, total_neurons = _tally(_get_masks(activations, 0.0, 'leq'))
    if total_neurons == 0:
        raise ValueError("run_redo: no Conv2d/Linear activations were recorded for this model")

    dormant_count, _ = _tally(_get_masks(activations, dormant_tau, 'leq'))
    overactive_count, _ = _tally(_get_masks(activations, overactive_tau, 'geq'))

    return {
        "total_neurons": total_neurons,
        "zero_fraction": zero_count / total_neurons,
        "zero_count": zero_count,
        "dormant_fraction": dormant_count / total_neurons,
        "dormant_count": dormant_count,
        "overactive_fraction": overactive_count / total_neurons,
        "overactive_count": overactive_count,
    }

In [12]:
# Sweep every backbone over the first N_EVAL_SAMPLES CIFAR-10 test images
# (batch_size=1, so one run_redo call per image) and aggregate mean/std per metric.
from itertools import islice

N_EVAL_SAMPLES = 100  # test images evaluated per backbone

args = get_args()

args.iterations = 20
args.pruning_ratio = 0
args.delta = 1

args.dataset = 'CIFAR10'
args.selection_method = 'random'

train_dataset, val_dataset, test_dataset, N, train_transform, transform = load_data(args)

test_dataset = IndexedDataset(args, test_dataset, transform, N,)

testloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

backbones = [
    'convnext_base', 'convnext_base.fb_in22k', 'robust_convnext_base',
    'convnext_tiny', 'convnext_tiny.fb_in22k', 'robust_convnext_tiny',

    'robust_wideresnet_28_10',

    'deit_small_patch16_224.fb_in1k',
    'robust_deit_small_patch16_224',

    'vit_base_patch16_224.augreg_in1k',
    'vit_base_patch16_224.augreg_in21k',
    'robust_vit_base_patch16_224',
]

# Keys returned by run_redo; this order also fixes the column order of the summary.
METRICS = (
    "total_neurons",
    "zero_fraction",
    "zero_count",
    "dormant_fraction",
    "dormant_count",
    "overactive_fraction",
    "overactive_count",
)

backbones_result = {}
for backbone_name in backbones:
    args.backbone = backbone_name
    model = load_architecture(args,)
    # Measure inference-time activations (dropout / BatchNorm in eval behaviour).
    # Harmless if load_architecture already did this.
    model.eval()

    per_sample = {metric: [] for metric in METRICS}
    for obs, _label, _idx in islice(testloader, N_EVAL_SAMPLES):
        result = run_redo(obs, model)
        for metric in METRICS:
            per_sample[metric].append(result[metric])

    backbones_result[backbone_name] = {
        **{f"mean_{metric}": np.mean(per_sample[metric]) for metric in METRICS},
        **{f"std_{metric}": np.std(per_sample[metric]) for metric in METRICS},
    }

    # Report only this backbone, not the whole accumulated dict every iteration.
    print(backbone_name, backbones_result[backbone_name])



./data
train size 47500 val size 2500
Files already downloaded and verified
{'convnext_base': {'mean_total_neurons': 21280.0, 'mean_zero_fraction': 0.036011278182268146, 'mean_zero_count': 766.32, 'mean_dormant_fraction': 0.13409398503601552, 'mean_dormant_count': 2853.52, 'mean_overactive_fraction': 0.05095300741493702, 'mean_overactive_count': 1084.28, 'std_total_neurons': 0.0, 'std_zero_fraction': 0.0046124065733678336, 'std_zero_count': 98.15201271497186, 'std_dormant_fraction': 0.011446242472055803, 'std_dormant_count': 243.57604479915506, 'std_overactive_fraction': 0.0026351938938786726, 'std_overactive_count': 56.07692573599234}}
{'convnext_base': {'mean_total_neurons': 21280.0, 'mean_zero_fraction': 0.036011278182268146, 'mean_zero_count': 766.32, 'mean_dormant_fraction': 0.13409398503601552, 'mean_dormant_count': 2853.52, 'mean_overactive_fraction': 0.05095300741493702, 'mean_overactive_count': 1084.28, 'std_total_neurons': 0.0, 'std_zero_fraction': 0.0046124065733678336, 'std

In [23]:

# Load a single backbone + the Aircraft test split for a quick sanity check.
# A separate `aircraft_args` is used so the global `args` (dataset='CIFAR10')
# from the sweep is not overwritten; otherwise, under Restart & Run All, the
# CSV export below would be named after Aircraft while holding CIFAR-10 results.
aircraft_args = get_args()

backbone_name = 'robust_vit_base_patch16_224'

aircraft_args.backbone = backbone_name

aircraft_args.iterations = 20
aircraft_args.pruning_ratio = 0
aircraft_args.delta = 1

aircraft_args.dataset = 'Aircraft'
aircraft_args.selection_method = 'random'

train_dataset, val_dataset, test_dataset, N, train_transform, transform = load_data(aircraft_args)

test_dataset = IndexedDataset(aircraft_args, test_dataset, transform, N,)

testloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

model = load_architecture(aircraft_args,)
model.eval()




./data


In [24]:
# Shape check: logits for the first test image, run as a batch of one.
# inference_mode avoids building an autograd graph for this throwaway forward pass.
with torch.inference_mode():
    sample_logits = model(testloader.dataset[0][0].unsqueeze(0))
sample_logits.shape

torch.Size([1, 100])

In [13]:
import pandas as pd

# Outer keys (backbone names) become rows; inner mean_*/std_* keys become columns.
df = pd.DataFrame.from_dict(backbones_result, orient='index')

results_csv_path = 'neurons_results_{}.csv'.format(args.dataset)
df.to_csv(results_csv_path)

In [14]:
# Per-backbone summary table: one row per backbone, with the mean_*/std_* metric columns.
df

Unnamed: 0,mean_total_neurons,mean_zero_fraction,mean_zero_count,mean_dormant_fraction,mean_dormant_count,mean_overactive_fraction,mean_overactive_count,std_total_neurons,std_zero_fraction,std_zero_count,std_dormant_fraction,std_dormant_count,std_overactive_fraction,std_overactive_count
convnext_base,21280.0,0.036011,766.32,0.134094,2853.52,0.050953,1084.28,0.0,0.004612,98.152013,0.011446,243.576045,0.002635,56.076926
convnext_base.fb_in22k,21280.0,0.031231,664.6,0.127022,2703.02,0.049787,1059.46,0.0,0.00426,90.645684,0.011726,249.533324,0.002725,57.997659
robust_convnext_base,21280.0,0.048816,1038.8,0.170796,3634.54,0.055089,1172.29,0.0,0.02989,636.049196,0.055938,1190.354371,0.009465,201.407413
convnext_tiny,8872.0,0.028376,251.75,0.096775,858.59,0.041843,371.23,0.0,0.004236,37.583607,0.009927,88.068507,0.00212,18.806837
convnext_tiny.fb_in22k,8872.0,0.049092,435.54,0.140845,1249.58,0.051425,456.24,0.0,0.004844,42.975439,0.010921,96.890988,0.002306,20.461241
robust_convnext_tiny,8872.0,0.117084,1038.77,0.228022,2023.01,0.049702,440.96,0.0,0.015022,133.272192,0.022983,203.905051,0.002287,20.286409
robust_wideresnet_28_10,10106.0,0.003137,31.7,0.032006,323.45,0.016716,168.93,0.0,0.000663,6.7,0.002625,26.529371,0.001281,12.949328
deit_small_patch16_224.fb_in1k,8170378.0,0.687746,5619148.32,0.69848,5706847.73,0.088602,723914.23,0.0,0.001164,9514.074414,0.001291,10546.018658,0.000839,6852.377055
robust_deit_small_patch16_224,8170378.0,0.685946,5604437.18,0.698974,5710879.07,0.086791,709114.76,0.0,0.002762,22566.92938,0.002888,23595.266711,0.001686,13775.531624
vit_base_patch16_224.augreg_in1k,16340746.0,0.686455,11217181.92,0.696979,11389157.83,0.090608,1480607.17,0.0,0.00161,26307.535625,0.001702,27815.624333,0.001261,20612.889583
