In [1]:
import torch
from datasets import IndexedDataset, WeightedDataset, load_data
from torch.utils.data import DataLoader, DistributedSampler

from utils import get_args
from architectures import load_architecture

from samplers import DistributedCustomSampler
from losses import trades_loss, apgd_loss
from tqdm.notebook import tqdm
from architectures import load_architecture, add_lora, set_lora_gradients #load_statedict

import torch.nn as nn

from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
from datasets.semisupervised_dataset import SemiSupervisedDataset
from torch.utils.data import TensorDataset
import random
from torch.utils.data import Subset
from torch.utils.data import random_split
from datasets.eurosat import EuroSATDataset

from sklearn.model_selection import train_test_split

import math
from functools import partial
from typing import Union, Tuple  # Import Union and Tuple for type annotations

import torch.nn.functional as F
from torch import optim

from torch.utils.data import DataLoader
import numpy as np


In [2]:
def load_data(args):

    if args.dataset == 'MNIST':

        # Load the dataset
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))    ])

        N = 10
        
        dataset = datasets.MNIST(root=args.data_dir, train=True, download=True, )

        train_size = int(0.95 * len(dataset))
        val_size = len(dataset) - train_size

        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
        test_dataset = datasets.MNIST(root=args.data_dir, train=False, download=True, )


    elif args.dataset == 'CIFAR10':

        transform = transforms.Compose([transforms.Resize((224, 224)),
                                        transforms.ToTensor(),
                                        transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ),])
        
        N = 10
        

        dataset = datasets.CIFAR10(root=args.data_dir, train=True, download=False, )

        train_size = int(0.95 * len(dataset))
        val_size = len(dataset) - train_size
        print("train size", train_size, "val size", val_size)

        generator1 = torch.Generator().manual_seed(42)
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator = generator1)  
        test_dataset = datasets.CIFAR10(root=args.data_dir, train=False, download=True, )

    elif args.dataset == 'CIFAR100':

        transform = transforms.Compose([transforms.Resize((224, 224)),
                                        transforms.ToTensor(),
                                        transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ),])


        N = 10
        
        dataset = datasets.CIFAR100(root=args.data_dir, train=True, download=False, )

        train_size = int(0.95 * len(dataset))
        val_size = len(dataset) - train_size
        print("train size", train_size, "val size", val_size)

        generator1 = torch.Generator().manual_seed(42)
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator = generator1)  
        test_dataset = datasets.CIFAR100(root=args.data_dir, train=False, download=True, )

    elif args.dataset == 'Aircraft':
        
        transform = transforms.Compose([
                                        transforms.Resize((224, 224)),  # Resize images to 224x224
                                        transforms.ToTensor(), 
                                        transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ),])


        N = 100
        
        train_dataset = datasets.FGVCAircraft(root=args.data_dir, split='train', download=False, )
        val_dataset =   datasets.FGVCAircraft(root=args.data_dir, split='val', download=False, )
        test_dataset = datasets.FGVCAircraft(root=args.data_dir, split='test', download=False, )

    elif args.dataset == 'EuroSAT':
        
        transform = transforms.Compose([
                                        transforms.Resize((224, 224)),  
                                        transforms.ToTensor(),
                                        transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ),])


        # Load the dataset
        path = args.data_dir + '/2750'
        dataset = EuroSATDataset(root_dir=path)

        # Extract labels from the dataset
        labels = [label for _, label in dataset]

        # Split the dataset into train+val and test, keeping stratification
        train_val_indices, test_indices = train_test_split( range(len(labels)), test_size=0.2, stratify=labels, random_state=42 )

        # Extract labels for the train+val set for further stratification
        train_val_labels = [labels[i] for i in train_val_indices]

        # Split the train+val set into train and validation, keeping stratification
        train_indices, val_indices = train_test_split( train_val_indices, test_size=0.15, stratify=train_val_labels, random_state=42 )  # 0.25 * 0.8 = 0.2 of the dataset

        # Create subsets for train, validation, and test

        N = 10

        train_dataset = torch.utils.data.Subset(dataset, train_indices)
        val_dataset = torch.utils.data.Subset(dataset, val_indices)
        test_dataset = torch.utils.data.Subset(dataset, test_indices)
        
    return train_dataset, val_dataset, test_dataset, N, None, transform

def to_rgb(x):
    return x.convert("RGB")


In [3]:
@torch.inference_mode()
def _get_redo_masks(activations, tau: float) -> torch.Tensor:
    """
    Computes the ReDo mask for a given set of activations.
    The returned mask has True where neurons are dormant and False where they are active.
    """
    masks = []

    for name, activation in list( activations.items() ):
        # Taking the mean here conforms to the expectation under D in the main paper's formula
        if activation.ndim == 4:
            # Conv layer
            score = activation.abs().mean( dim=(0, 2, 3) )
        else:
            # Linear layer
            score = activation.abs().mean(dim=0)

        # print('score', score)
        # Divide by activation mean to make the threshold independent of the layer size
        # see https://github.com/google/dopamine/blob/ce36aab6528b26a699f5f1cefd330fdaf23a5d72/dopamine/labs/redo/weight_recyclers.py#L314
        # https://github.com/google/dopamine/issues/209

        normalized_score = score / (score.mean() + 1e-9)

        # print(normalized_score)

        layer_mask = torch.zeros_like(normalized_score, dtype=torch.bool)

        if tau > 0.0:
            layer_mask[normalized_score <= tau] = 1
        else:
            layer_mask[ torch.isclose( normalized_score, torch.zeros_like(normalized_score) ) ] = 1

        masks.append(layer_mask)

    return masks

@torch.inference_mode()
def _get_activation(name: str, activations):
    """Fetches and stores the activations of a network layer."""

    def hook(layer: Union[nn.Linear, nn.Conv2d], input: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None:
        """
        Get the activations of a layer with ReLU nonlinearity.
        ReLU has to be called explicitly here because the hook is attached to the conv/linear layer.
        """
        activations[name] = F.relu(output)

    return hook

@torch.inference_mode()
def run_redo(
    obs: torch.Tensor,
    model: nn.Module,
    tau: float,
):
    """
    Checks the number of dormant neurons for a given model.

    Returns the number of dormant neurons.
    """

    #print('step1')

    activations = {}
    activation_getter = partial(_get_activation, activations=activations)

    #print('step2')

    # Register hooks for all Conv2d and Linear layers to calculate activations
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
            handles.append( module.register_forward_hook(activation_getter(name)) )
    # print(handles)
    #print('step3')

    # Calculate activations
    _ = model( obs )  # Add batch dimension if necessary,  .unsqueeze(0)

    #print('step4')

    # Masks for tau=0 logging
    zero_masks = _get_redo_masks(activations, 0.0)

    total_neurons = sum([torch.numel(mask) for mask in zero_masks])
    zero_count = sum([torch.sum(mask) for mask in zero_masks])
    zero_fraction = zero_count / total_neurons

    # print(zero_count, total_neurons, zero_fraction)

    # Calculate the masks actually used for resetting
    masks = _get_redo_masks(activations, tau)
    dormant_count = sum([torch.sum(mask) for mask in masks])
    total_neurons = sum([torch.numel(mask) for mask in masks])
    dormant_fraction = dormant_count / total_neurons 
    
    # print(dormant_count, total_neurons, dormant_fraction)

    # Remove the hooks again
    for handle in handles:
        handle.remove()

    return {
        "total_neurons":total_neurons,
        "zero_fraction": zero_fraction.item(),
        "zero_count": zero_count.item(),
        "dormant_fraction": dormant_fraction.item(),
        "dormant_count": dormant_count.item(),
    }

In [4]:
args = get_args()

args.iterations = 20
args.pruning_ratio = 0
args.delta = 1

args.dataset = 'CIFAR10'
args.selection_method = 'random'

train_dataset, val_dataset, test_dataset, N, train_transform, transform = load_data(args) 

test_dataset = IndexedDataset(args, test_dataset, transform, N,)  

testloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)


backbones = [ 
              'convnext_base',  'convnext_base.fb_in22k', 'robust_convnext_base',
              'convnext_tiny',  'convnext_tiny.fb_in22k', 'robust_convnext_tiny',

              'robust_wideresnet_28_10', 

              'deit_small_patch16_224.fb_in1k',
              'robust_deit_small_patch16_224',
              
              'vit_base_patch16_224.augreg_in1k',
              'vit_base_patch16_224.augreg_in21k',
              'robust_vit_base_patch16_224' ]

# Assuming 'train_dataset' is defined and contains your training data
# If 'train_dataset[0][0]' is an image tensor, make sure it's properly preprocessed


tau = 0.1

backbones_result = {}
for backbone_name in backbones:
    # print(backbone_name)

    backbone_result = {
        "total_neurons":[],
        "zero_fraction": [],
        "zero_count": [],
        "dormant_fraction": [],
        "dormant_count": [],
    }

    args.backbone = backbone_name
    model = load_architecture(args,)

    # print(model)
    count = 0
    for obs,labal,idx in testloader:

        result = run_redo(obs, model, tau)
        backbone_result["total_neurons"].append(result["total_neurons"])
        backbone_result["zero_fraction"].append(result["zero_fraction"])
        backbone_result["zero_count"].append(result["zero_count"])
        backbone_result["dormant_fraction"].append(result["dormant_fraction"])
        backbone_result["dormant_count"].append(result["dormant_count"])

        count+=1
        if count==100:
            break

    backbones_result[backbone_name] = { "mean_total_neurons":np.mean(backbone_result["total_neurons"]),
                                   "mean_zero_fraction":np.mean(backbone_result["zero_fraction"]),
                                   "mean_zero_count":np.mean(backbone_result["zero_count"]),
                                   "mean_dormant_fraction":np.mean(backbone_result["dormant_fraction"]),
                                   "mean_dormant_count": np.mean(backbone_result["dormant_count"]),
                                   "std_total_neurons":np.std(backbone_result["total_neurons"]),
                                   "std_zero_fraction":np.std(backbone_result["zero_fraction"]),
                                   "std_zero_count":np.std(backbone_result["zero_count"]),
                                   "std_dormant_fraction":np.std(backbone_result["dormant_fraction"]),
                                   "std_dormant_count": np.std(backbone_result["dormant_count"]) }
    
    print(backbones_result)



./data
train size 47500 val size 2500
Files already downloaded and verified
{'convnext_base': {'mean_total_neurons': 21280.0, 'mean_zero_fraction': 0.03606156019493938, 'mean_zero_count': 767.39, 'mean_dormant_fraction': 0.1341353379935026, 'mean_dormant_count': 2854.4, 'std_total_neurons': 0.0, 'std_zero_fraction': 0.004606935865090335, 'std_zero_count': 98.03559506628191, 'std_dormant_fraction': 0.011442157080328938, 'std_dormant_count': 243.48909626511002}}
{'convnext_base': {'mean_total_neurons': 21280.0, 'mean_zero_fraction': 0.03606156019493938, 'mean_zero_count': 767.39, 'mean_dormant_fraction': 0.1341353379935026, 'mean_dormant_count': 2854.4, 'std_total_neurons': 0.0, 'std_zero_fraction': 0.004606935865090335, 'std_zero_count': 98.03559506628191, 'std_dormant_fraction': 0.011442157080328938, 'std_dormant_count': 243.48909626511002}, 'convnext_base.fb_in22k': {'mean_total_neurons': 21280.0, 'mean_zero_fraction': 0.03125563904643059, 'mean_zero_count': 665.12, 'mean_dormant_frac

In [74]:
import pandas as pd

df = pd.DataFrame.from_dict(backbones_result).T

df.to_csv( 'dormant_results_{}.csv'.format(args.dataset) )