In [1]:
import torch
from datasets import IndexedDataset, WeightedDataset, load_data
from torch.utils.data import DataLoader, DistributedSampler

from utils import get_args
from architectures import load_architecture

from samplers import DistributedCustomSampler
from losses import trades_loss, apgd_loss
from tqdm.notebook import tqdm
from architectures import load_architecture, add_lora, set_lora_gradients #load_statedict

import torch.nn as nn

from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
from datasets.semisupervised_dataset import SemiSupervisedDataset
from torch.utils.data import TensorDataset
import random
from torch.utils.data import Subset
from torch.utils.data import random_split
from datasets.eurosat import EuroSATDataset

from sklearn.model_selection import train_test_split

import math
from functools import partial
from typing import Union, Tuple  # Import Union and Tuple for type annotations

import torch.nn.functional as F
from torch import optim

from torch.utils.data import DataLoader
import numpy as np


In [3]:
@torch.inference_mode()
def _get_redo_masks(activations, tau: float) -> torch.Tensor:
    """
    Computes the ReDo mask for a given set of activations.
    The returned mask has True where neurons are dormant and False where they are active.
    """
    masks = []

    for name, activation in list( activations.items() ):
        # Taking the mean here conforms to the expectation under D in the main paper's formula
        if activation.ndim == 4:
            # Conv layer
            score = activation.abs().mean( dim=(0, 2, 3) )
        else:
            # Linear layer
            score = activation.abs().mean(dim=0)

        # print('score', score)
        # Divide by activation mean to make the threshold independent of the layer size
        # see https://github.com/google/dopamine/blob/ce36aab6528b26a699f5f1cefd330fdaf23a5d72/dopamine/labs/redo/weight_recyclers.py#L314
        # https://github.com/google/dopamine/issues/209

        normalized_score = score / (score.mean() + 1e-9)

        # print(normalized_score)

        layer_mask = torch.zeros_like(normalized_score, dtype=torch.bool)

        if tau > 0.0:
            layer_mask[normalized_score <= tau] = 1
        else:
            layer_mask[ torch.isclose( normalized_score, torch.zeros_like(normalized_score) ) ] = 1

        masks.append(layer_mask)

    return masks

@torch.inference_mode()
def _get_activation(name: str, activations):
    """Fetches and stores the activations of a network layer."""

    def hook(layer: Union[nn.Linear, nn.Conv2d], input: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
        """
        Get the activations of a layer with ReLU nonlinearity.
        ReLU has to be called explicitly here because the hook is attached to the conv/linear layer.
        """
        activations[name] = F.relu(output)

    return hook

@torch.inference_mode()
def run_redo(
    obs: torch.Tensor,
    model: nn.Module,
    tau: float,
):
    """
    Count dormant neurons of ``model`` for a single forward pass on ``obs``.

    A forward hook is attached to every ``nn.Conv2d`` and ``nn.Linear`` module
    (via ``_get_activation``), ``model(obs)`` is run once, and the recorded
    activations are scored with ``_get_redo_masks`` twice: with ``tau=0`` to
    count neurons whose normalized score is ~0, and with the given ``tau`` to
    count dormant neurons.

    Args:
        obs: input passed as-is to ``model(obs)``; no batch dimension is added.
        model: network to inspect. All hooks registered here are removed
            before returning, even if the forward pass raises.
        tau: dormancy threshold on the normalized neuron score.

    Returns:
        dict with keys ``total_neurons`` (int), ``zero_fraction``,
        ``zero_count``, ``dormant_fraction`` and ``dormant_count`` (Python
        numbers obtained via ``.item()``).

    Raises:
        ValueError: if the forward pass did not run any hooked Conv2d/Linear
            layer, so there are no neurons to score.
    """
    activations = {}

    # Register hooks for all Conv2d and Linear layers to capture activations
    handles = [
        module.register_forward_hook(_get_activation(name, activations))
        for name, module in model.named_modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]

    try:
        model(obs)
    finally:
        # Always detach the hooks so a failing forward pass does not leave them
        # attached to the model object that the caller keeps using.
        for handle in handles:
            handle.remove()

    if not activations:
        raise ValueError("no Conv2d/Linear activations were captured; cannot compute dormant neurons")

    # Masks for tau=0 logging, and the masks for the requested threshold
    zero_masks = _get_redo_masks(activations, 0.0)
    masks = _get_redo_masks(activations, tau)

    # Both mask lists have identical shapes, so the neuron total is shared.
    total_neurons = sum(mask.numel() for mask in masks)
    zero_count = sum(mask.sum() for mask in zero_masks)
    dormant_count = sum(mask.sum() for mask in masks)

    zero_fraction = zero_count / total_neurons
    dormant_fraction = dormant_count / total_neurons

    return {
        "total_neurons": total_neurons,
        "zero_fraction": zero_fraction.item(),
        "zero_count": zero_count.item(),
        "dormant_fraction": dormant_fraction.item(),
        "dormant_count": dormant_count.item(),
    }

In [4]:
args = get_args()

# Experiment overrides on top of the default arguments.
args.iterations = 20
args.pruning_ratio = 0
args.delta = 1

args.dataset = 'CIFAR10'
args.selection_method = 'random'

train_dataset, val_dataset, test_dataset, N, train_transform, transform = load_data(args)

test_dataset = IndexedDataset(args, test_dataset, transform, N,)

# batch_size=1: dormancy statistics are computed per test image and then
# summarised (mean / std) across images below.
testloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

backbones = [
              'convnext_base',  'convnext_base.fb_in22k', 'robust_convnext_base',
              'convnext_tiny',  'convnext_tiny.fb_in22k', 'robust_convnext_tiny',

              'robust_wideresnet_28_10',

              'deit_small_patch16_224.fb_in1k',
              'robust_deit_small_patch16_224',

              'vit_base_patch16_224.augreg_in1k',
              'vit_base_patch16_224.augreg_in21k',
              'robust_vit_base_patch16_224' ]

tau = 0.1              # ReDo dormancy threshold on the normalized neuron score
N_EVAL_SAMPLES = 100   # number of test images evaluated per backbone

# Metrics returned by run_redo, in the column order used for the summary table.
METRIC_KEYS = ("total_neurons", "zero_fraction", "zero_count", "dormant_fraction", "dormant_count")

backbones_result = {}
for backbone_name in backbones:
    args.backbone = backbone_name
    model = load_architecture(args,)
    # Measure activations with inference-time behaviour (no dropout, running
    # BatchNorm statistics); a no-op if load_architecture already returns an
    # eval-mode model.
    model.eval()

    backbone_result = {key: [] for key in METRIC_KEYS}
    for n_seen, (obs, _label, _idx) in enumerate(testloader, start=1):
        result = run_redo(obs, model, tau)
        for key in METRIC_KEYS:
            backbone_result[key].append(result[key])
        if n_seen >= N_EVAL_SAMPLES:
            break

    summary = {f"mean_{key}": np.mean(backbone_result[key]) for key in METRIC_KEYS}
    summary.update({f"std_{key}": np.std(backbone_result[key]) for key in METRIC_KEYS})
    backbones_result[backbone_name] = summary

    # Report only the backbone just evaluated, not the whole cumulative dict.
    print(backbone_name, summary)



./data
train size 47500 val size 2500
Files already downloaded and verified
{'convnext_base': {'mean_total_neurons': 21280.0, 'mean_zero_fraction': 0.03604793233796954, 'mean_zero_count': 767.1, 'mean_dormant_fraction': 0.13412405982613562, 'mean_dormant_count': 2854.16, 'std_total_neurons': 0.0, 'std_zero_fraction': 0.004613859975186449, 'std_zero_count': 98.18294149189053, 'std_dormant_fraction': 0.011445894136221137, 'std_dormant_count': 243.5686235950764}}
{'convnext_base': {'mean_total_neurons': 21280.0, 'mean_zero_fraction': 0.03604793233796954, 'mean_zero_count': 767.1, 'mean_dormant_fraction': 0.13412405982613562, 'mean_dormant_count': 2854.16, 'std_total_neurons': 0.0, 'std_zero_fraction': 0.004613859975186449, 'std_zero_count': 98.18294149189053, 'std_dormant_fraction': 0.011445894136221137, 'std_dormant_count': 243.5686235950764}, 'convnext_base.fb_in22k': {'mean_total_neurons': 21280.0, 'mean_zero_fraction': 0.03122086448594928, 'mean_zero_count': 664.38, 'mean_dormant_frac

In [74]:
import pandas as pd

# Rows: backbones; columns: the mean_/std_ statistics collected per backbone.
df = pd.DataFrame(backbones_result).transpose()

df.to_csv(f"dormant_results_{args.dataset}.csv")