# load resnet50

In [1]:
import torchvision
import torch
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
# Build a ResNet-50 with a 2-way classification head and restore the fine-tuned weights.
n_classes = 2
# No ImageNet weights are requested: load_state_dict() below is strict by default,
# so every parameter and buffer is replaced by the checkpoint and a pretrained
# download would be thrown away immediately.
model = torchvision.models.resnet50()
d = model.fc.in_features  # width of the pooled features feeding the classifier head
model.fc = torch.nn.Linear(d, n_classes)

# The checkpoint is used directly as a state dict, so it can be read with the
# restricted unpickler (weights_only=True). map_location="cpu" makes the load
# independent of the device the checkpoint was saved from; the model is moved
# to the GPU afterwards.
checkpoint = torch.load("/content/tmp_checkpoint10(11).pt", map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint)

model.cuda()
model.eval()

  checkpoint = torch.load("/content/tmp_checkpoint10(11).pt")


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

# load autoencoder

In [14]:
"""
This file defines an AutoEncoder class, which also contains an implementation of neuron resampling.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class AutoEncoder(nn.Module):
    """
    Sparse autoencoder with a single ReLU hidden layer and dead-neuron resampling.

    The encoder sees the input with the decoder bias subtracted; the decoder maps
    the latents back to input space. The training loss is MSE + lam * L1(latents).
    `dead_neurons` is None whenever no dead-neuron investigation is in progress
    (after construction and after every resampling).
    """

    def __init__(self, n_inputs: int, n_latents: int, lam: float = 0.003, resampling_interval: int = 25000):
        """
        n_inputs: Number of inputs
        n_latents: Number of neurons in the hidden layer
        lam: L1-coefficient for Sparse Autoencoder
        resampling_interval: Number of training steps after which dead neurons will be resampled
        """
        super().__init__()
        self.n_inputs, self.n_latents = n_inputs, n_latents
        self.encoder = nn.Linear(n_inputs, n_latents)
        self.relu = nn.ReLU()
        self.decoder = nn.Linear(n_latents, n_inputs)
        self.lam = lam
        self.resampling_interval = resampling_interval
        self.dead_neurons = None
        self.normalize_decoder_columns()

    def forward(self, x):
        """
        Encode and reconstruct `x`.

        Returns a dict with 'loss' and 'latents'; in eval mode it additionally
        contains 'reconst_acts', 'mse_loss' and 'l1_loss'.
        """
        latents = self.encode(x)
        reconstructed = self.decode(latents)
        loss = self.calculate_loss(x, latents, reconstructed)

        if self.training:
            return {'loss': loss, 'latents': latents}
        else:
            return {
                'loss': loss,
                'latents': latents,
                'reconst_acts': reconstructed,
                'mse_loss': self.mse_loss(reconstructed, x),
                'l1_loss': self.l1_loss(latents)
            }

    def encode(self, x):
        """Return ReLU(encoder(x - decoder.bias))."""
        bias_corrected_input = x - self.decoder.bias
        return self.relu(self.encoder(bias_corrected_input))

    def decode(self, encoded):
        """Map latents back to input space with the decoder linear layer."""
        return self.decoder(encoded)

    def calculate_loss(self, x, encoded, reconstructed):
        """Return mse_loss(reconstructed, x) + lam * l1_loss(encoded)."""
        mse_loss = self.mse_loss(reconstructed, x)
        l1_loss = self.l1_loss(encoded)
        return mse_loss + self.lam * l1_loss

    def mse_loss(self, reconstructed, original):
        """Mean squared error averaged over all elements."""
        return F.mse_loss(reconstructed, original)

    def l1_loss(self, encoded):
        """Sum of absolute latent values divided by the size of the first dimension."""
        return F.l1_loss(encoded, torch.zeros_like(encoded), reduction='sum') / encoded.shape[0]

    @torch.no_grad()
    def get_feature_activations(self, inputs, start_idx, end_idx):
        """
        Computes the activations of a subset of features in the hidden layer.

        :param inputs: Input tensor of shape (..., n) where n = d_MLP. It includes batch dimensions.
        :param start_idx: Starting index (inclusive) of the feature subset.
        :param end_idx: Ending index (exclusive) of the feature subset.

        Returns the activations for the specified feature range, reducing computation by
        only processing the necessary part of the network's weights and biases.
        """
        adjusted_inputs = inputs - self.decoder.bias  # Adjust input to account for decoder bias
        weight_subset = self.encoder.weight[start_idx:end_idx, :].t()  # Transpose the subset of weights
        bias_subset = self.encoder.bias[start_idx:end_idx]

        activations = self.relu(adjusted_inputs @ weight_subset + bias_subset)

        return activations

    @torch.no_grad()
    def normalize_decoder_columns(self):
        """
        Normalize the decoder's weight vectors to have unit norm along the feature dimension.
        This normalization can help in maintaining the stability of the network's weights.
        """
        self.decoder.weight.copy_(F.normalize(self.decoder.weight, dim=0))

    @torch.no_grad()
    def remove_parallel_component_of_decoder_grad(self):
        """
        Remove the component of the gradient parallel to the decoder's weight vectors.

        Runs without autograd so the stored gradient is a plain tensor rather than
        one that drags a computation graph along. Does nothing when the decoder
        weight has no gradient yet.
        """
        grad = self.decoder.weight.grad
        if grad is None:
            return
        unit_weights = F.normalize(self.decoder.weight, dim=0) # \hat{b}
        proj = (grad * unit_weights).sum(dim=0) * unit_weights
        self.decoder.weight.grad = grad - proj

    @staticmethod
    def is_dead_neuron_investigation_step(step, resampling_interval, num_resamples):
        """
        Determine if the current step is the start of a phase for investigating dead neurons.
        According to Anthropic's specified policy, it occurs at odd multiples of half the resampling interval.
        Returns False when half the resampling interval rounds down to zero or less.
        """
        half_interval = resampling_interval // 2
        if half_interval <= 0:
            # Avoids a modulo-by-zero for resampling_interval < 2.
            return False
        return (
            step > 0
            and step % half_interval == 0
            and (step // half_interval) % 2 != 0
            and step < resampling_interval * num_resamples
        )

    @staticmethod
    def is_within_neuron_investigation_phase(step, resampling_interval, num_resamples):
        """
        Check if the current step is within a phase where active neurons are investigated.
        A phase covers the steps from `milestone - resampling_interval // 2` (inclusive) up to
        `milestone` (exclusive) for every milestone resampling_interval, 2 * resampling_interval, ...,
        num_resamples * resampling_interval.
        """
        return any(milestone - resampling_interval // 2 <= step < milestone
                   for milestone in range(resampling_interval, resampling_interval * (num_resamples + 1), resampling_interval))

    @torch.no_grad()
    def initiate_dead_neurons(self):
        """Start an investigation phase by treating every latent as dead."""
        self.dead_neurons = set(range(self.n_latents))

    @torch.no_grad()
    def update_dead_neurons(self, latents):
        """
        Update the set of dead neurons based on the current feature activations.
        If a neuron is active (has non-zero activation), it is removed from the dead neuron set.
        Does nothing when no investigation phase is in progress (`dead_neurons` is None).
        """
        if self.dead_neurons is None:
            return
        active_neurons = torch.nonzero(torch.count_nonzero(latents, dim=0), as_tuple=False).view(-1)
        self.dead_neurons.difference_update(active_neurons.tolist())

    @torch.no_grad()
    def resample_dead_neurons(self, data, optimizer, batch_size=8192):
        """
        Resample the dead neurons by resetting their weights and biases based on the characteristics
        of active neurons. Proceeds only if there are dead neurons to resample.

        data: indexable tensor of examples; rows are sampled from it with probability
            proportional to the squared per-example loss.
        optimizer: optimizer holding this module's parameters; its 'exp_avg' /
            'exp_avg_sq' state entries are zeroed for the resampled neurons.
        batch_size: number of examples per forward pass when computing the losses.
        """
        if not self.dead_neurons:
            return

        device = self._get_device()
        dead_neurons_t, alive_neurons = self._get_neuron_indices()
        average_enc_norm = self._compute_average_norm_of_alive_neurons(alive_neurons)
        probs = self._compute_loss_probabilities(data, batch_size, device)
        selected_examples = self._select_examples_based_on_probabilities(data, probs)

        self._resample_neurons(selected_examples, dead_neurons_t, average_enc_norm, device)
        self._update_optimizer_parameters(optimizer, dead_neurons_t)

        print('Dead neurons resampled successfully!')
        self.dead_neurons = None

    def _get_device(self):
        """Device of the module's first parameter."""
        return next(self.parameters()).device

    def _get_neuron_indices(self):
        """Return (dead, alive) latent indices as long tensors on the module's device."""
        device = self._get_device()
        # dtype is forced to long so that an empty list still yields a valid index tensor.
        dead_neurons_t = torch.tensor(list(self.dead_neurons), dtype=torch.long, device=device)
        alive_neurons = torch.tensor(
            [i for i in range(self.n_latents) if i not in self.dead_neurons],
            dtype=torch.long,
            device=device,
        )
        return dead_neurons_t, alive_neurons

    def _compute_average_norm_of_alive_neurons(self, alive_neurons):
        """
        Mean L2 norm of the encoder rows belonging to alive neurons.
        If no neuron is alive, the mean over all encoder rows is used instead so
        the resampled weights keep a finite scale.
        """
        rows = self.encoder.weight[alive_neurons] if alive_neurons.numel() > 0 else self.encoder.weight
        return torch.linalg.vector_norm(rows, dim=1).mean()

    def _compute_loss_probabilities(self, data, batch_size, device):
        """Return, on the CPU, the squared loss of every example in `data` (unnormalised sampling weights)."""
        num_batches = (len(data) + batch_size - 1) // batch_size
        probs = torch.zeros(len(data), device=device)
        for i in range(num_batches):
            batch_slice = slice(i * batch_size, (i + 1) * batch_size)
            x_batch = data[batch_slice].to(device)
            probs[batch_slice] = self._compute_batch_loss_squared(x_batch)
        return probs.cpu()

    def _compute_batch_loss_squared(self, x_batch):
        """Per-example (squared error summed over features + lam * L1 of latents), squared."""
        latents = self.encode(x_batch)
        reconst_acts = self.decode(latents)
        mselosses = F.mse_loss(reconst_acts, x_batch, reduction='none').sum(dim=1)
        l1losses = F.l1_loss(latents, torch.zeros_like(latents), reduction='none').sum(dim=1)
        return (mselosses + self.lam * l1losses).square()

    def _select_examples_based_on_probabilities(self, data, probs):
        """Draw one example per dead neuron (without replacement) with weights `probs`."""
        selection_indices = torch.multinomial(probs, num_samples=len(self.dead_neurons))
        return data[selection_indices].to(dtype=torch.float32)

    def _resample_neurons(self, examples, dead_neurons_t, average_enc_norm, device):
        """
        Point each dead neuron at one selected example: the decoder column becomes the
        unit-norm example, the encoder row the same direction scaled to 0.2 times the
        average alive-neuron norm, and the encoder bias is reset to zero.
        Expected to be called with autograd disabled (as resample_dead_neurons does).
        """
        examples_unit_norm = F.normalize(examples, dim=1).to(device)
        self.decoder.weight[:, dead_neurons_t] = examples_unit_norm.T

        # Renormalize examples to have a certain norm and reset encoder weights and biases
        adjusted_examples = examples_unit_norm * average_enc_norm * 0.2
        self.encoder.weight[dead_neurons_t] = adjusted_examples
        self.encoder.bias[dead_neurons_t] = 0

    def _update_optimizer_parameters(self, optimizer, dead_neurons_t):
        """
        Zero the 'exp_avg' / 'exp_avg_sq' optimizer state of the resampled neurons.

        Parameters are matched by identity rather than by their position in the
        optimizer, and parameters without state (no optimizer step taken yet) or
        without those state entries are skipped.
        """
        for group in optimizer.param_groups:
            for param in group['params']:
                # .get() avoids inserting an empty state entry into the optimizer.
                param_state = optimizer.state.get(param)
                if not param_state:
                    continue
                if param is self.encoder.weight or param is self.encoder.bias:
                    index = (dead_neurons_t,)  # neurons are rows / entries
                elif param is self.decoder.weight:
                    index = (slice(None), dead_neurons_t)  # neurons are columns
                else:
                    continue
                for key in ('exp_avg', 'exp_avg_sq'):
                    if key in param_state:
                        param_state[key][index] = 0

In [15]:
## MY

"""
Train a Sparse AutoEncoder model

Run on a macbook on a Shakespeare dataset as
python train.py --dataset=shakespeare_char --gpt_ckpt_dir=out_sc_1_2_32 --eval_iters=1 --eval_batch_size=16 --batch_size=128 --device=cpu --eval_interval=100 --n_features=1024 --resampling_interval=150 --wandb_log=True
"""
import os
import torch
import numpy as np
import time


## hyperparameters
# Settings carried over from the autoencoder training script. In this notebook
# n_features, l1_coeff and resampling_interval are passed to AutoEncoder(...),
# and batch_size / device are used by the data loaders and model placement.

# training
n_features = 8096 # number of autoencoder latents (passed as n_latents)
batch_size = 32 # batch size for autoencoder training
l1_coeff = 3e-3 # L1 coefficient (passed as lam)
learning_rate = 3e-4
resampling_interval = 25000 # number of training steps after which neuron resampling will be performed
num_resamples = 4 # number of times resampling is to be performed; it is done 4 times in Anthropic's paper
resampling_data_size = 819200 # presumably the number of examples used when resampling -- TODO confirm
# evaluation
eval_batch_size = 16 # batch size (number of GPT contexts) for evaluation
eval_iters = 200 # number of iterations in the evaluation loop
eval_interval = 1000 # number of training steps after which the autoencoder is evaluated
# I/O
save_checkpoint = True # whether to save model, optimizer, etc or not
save_interval = 10000 # number of training steps after which a checkpoint will be saved
out_dir = 'out' # directory containing trained autoencoder model weights
# wandb logging
wandb_log = True
# system
device = 'cuda'
# reproducibility
seed = 1442

# -----------------------------------------------------------------------------
# Collect the names of all globals that do not start with '_' and hold an
# int/float/bool/str. NOTE: this scans the whole namespace, so scalar or string
# globals defined by earlier notebook cells are captured too, not only the
# hyperparameters above. The comprehension form matters: its loop variables do
# not become globals, so globals() is not modified while it is being iterated.
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
#exec(open('configurator.py').read()) # overrides from command line or config file
config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------

In [16]:
# Instantiate the sparse autoencoder and restore its trained weights.
# n_inputs=2048 has to match the width of the pooled ResNet-50 features that
# are fed through the autoencoder during evaluation.
autoencoder = AutoEncoder(
    n_inputs=2048,
    n_latents=n_features,
    lam=l1_coeff,
    resampling_interval=resampling_interval,
).to(device)
# map_location remaps the stored tensors onto `device`, so the checkpoint loads
# even if it was saved from a different device (e.g. another GPU index).
# The checkpoint is a dict; only its 'autoencoder' entry (the state dict) is used.
autoencoder.load_state_dict(
    torch.load('/content/drive/MyDrive/ckpt_final.pt', map_location=device, weights_only=True)['autoencoder']
)
autoencoder.eval()

AutoEncoder(
  (encoder): Linear(in_features=2048, out_features=8096, bias=True)
  (relu): ReLU()
  (decoder): Linear(in_features=8096, out_features=2048, bias=True)
)

# load data

In [1]:
# !pip install kaggle
# !mkdir ~/.kaggle
# !touch ~/.kaggle/kaggle.json

# api_token = {"username":"username","key":"api-key"}

# import json

# with open('/root/.kaggle/kaggle.json', 'w') as file:
#     json.dump(api_token, file)

# !chmod 600 ~/.kaggle/kaggle.json
# !kaggle datasets download -d jessicali9530/celeba-dataset

# !unzip celeba-dataset.zip -d celeba-dataset


In [5]:
!git clone https://github.com/PolinaKirichenko/deep_feature_reweighting.git

Cloning into 'deep_feature_reweighting'...
remote: Enumerating objects: 55, done.[K
remote: Counting objects: 100% (25/25), done.[K
remote: Compressing objects: 100% (12/12), done.[K
remote: Total 55 (delta 19), reused 13 (delta 13), pack-reused 30 (from 1)[K
Receiving objects: 100% (55/55), 2.51 MiB | 5.33 MiB/s, done.
Resolving deltas: 100% (23/23), done.


In [6]:
!cp deep_feature_reweighting/celeba_metadata.csv celeba-dataset/img_align_celeba/metadata.csv

In [17]:
from deep_feature_reweighting.wb_data import WaterBirdsDataset, get_loader, get_transform_cub, log_data

# The images under img_align_celeba/ are read through the repository's
# WaterBirdsDataset class; all three splits are taken from the same directory.
basedir = '/content/celeba-dataset/img_align_celeba/'
test_wb_dir = '/content/celeba-dataset/img_align_celeba/'

# Data
target_resolution = (224, 224)
train_transform = get_transform_cub(target_resolution=target_resolution, train=True, augment_data=None)
test_transform = get_transform_cub(target_resolution=target_resolution, train=False, augment_data=None)


trainset = WaterBirdsDataset(basedir=basedir, split="train", transform=train_transform)
testset_dict = {
    'wb': WaterBirdsDataset(basedir=test_wb_dir, split="test", transform=test_transform),
    'wb_val': WaterBirdsDataset(basedir=test_wb_dir, split="val", transform=test_transform),
}

# Shared DataLoader settings; batch_size is the global defined with the hyperparameters.
loader_kwargs = {'batch_size': batch_size, 'num_workers': 4, 'pin_memory': True}
# The training split is only evaluated in this notebook, hence train=False and
# no reweighting (presumably this also disables shuffling -- confirm in wb_data.get_loader).
train_loader = get_loader(trainset, reweight_groups=None,
                          reweight_classes=None, reweight_places=None, train=False, **loader_kwargs)

202599
162770
202599
19962
202599
19867


In [18]:
# One evaluation loader per held-out split, keyed like testset_dict.
# Every loader gets the same options: evaluation mode, no reweighting,
# plus the shared DataLoader settings.
eval_loader_options = dict(
    train=False,
    reweight_groups=None,
    reweight_classes=None,
    reweight_places=None,
    **loader_kwargs,
)
test_loader_dict = {}
for test_name, testset_v in testset_dict.items():
    test_loader_dict[test_name] = get_loader(testset_v, **eval_loader_options)

# evaluate

In [9]:
from deep_feature_reweighting.utils import AverageMeter, get_results
import tqdm
import gc

In [10]:
def get_embed(m, x):
    """
    Run `x` through every stage of `m` except the final classifier.

    The stages conv1, bn1, relu, maxpool, layer1..layer4 and avgpool are applied
    in that order, and the result is flattened from dimension 1 onwards.
    """
    backbone_stages = (
        "conv1", "bn1", "relu", "maxpool",
        "layer1", "layer2", "layer3", "layer4",
        "avgpool",
    )
    for stage_name in backbone_stages:
        x = getattr(m, stage_name)(x)
    return torch.flatten(x, 1)

In [11]:
def model_and_autoencoder(m, x):
    """
    Classify `x` with `m`, routing the pooled features through the autoencoder.

    The backbone features from get_embed() are encoded and decoded by the global
    `autoencoder`, and the reconstruction is passed to the classifier head `m.fc`.

    encode()/decode() are called directly instead of reading
    autoencoder(x)['reconst_acts']: the reconstruction is the same tensor, but
    that key only exists when the autoencoder is in eval mode, and the full
    forward pass also computes losses that are not needed here.
    """
    features = get_embed(m, x)
    reconstructed = autoencoder.decode(autoencoder.encode(features))
    return m.fc(reconstructed)

In [10]:
# Per-group accuracy of the ResNet + autoencoder pipeline on the training split.
predict_place = False  # when True, the place label p is used as the target instead of y
acc_groups = {g_idx : AverageMeter() for g_idx in range(train_loader.dataset.n_groups)}
# Pure inference: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for x, y, g, p in tqdm.tqdm(train_loader):
        x, y, p = x.cuda(), y.cuda(), p.cuda()
        if predict_place:
            y = p

        logits = model_and_autoencoder(model, x)
        preds = torch.argmax(logits, dim=1)
        # Bring the per-sample results to the CPU once so that they live on the
        # same device as the group ids used to mask them.
        correct_batch = (preds == y).cpu()
        g = g.cpu()
        # .tolist() yields plain Python ints, matching the keys of acc_groups.
        for g_val in torch.unique(g).tolist():
            mask = g == g_val
            n = mask.sum().item()
            corr = correct_batch[mask].sum().item()
            # Record this batch's accuracy for the group, weighted by its n samples.
            acc_groups[g_val].update(corr / n, n)

100%|██████████| 5087/5087 [03:55<00:00, 21.62it/s]


In [16]:
# Pool the per-group meters into one overall accuracy:
# total of the meters' sums divided by total of their counts.
groups = acc_groups.keys()

all_correct = sum(meter.sum for meter in acc_groups.values())
all_total = sum(meter.count for meter in acc_groups.values())
print("mean_accuracy :", all_correct / all_total)


mean_accuracy : 0.9214843030042391


In [12]:

# Collect unreachable Python objects first so that any GPU tensors they hold are
# returned to PyTorch's caching allocator, then release the cached blocks.
# (In the opposite order, memory freed by the collector would stay cached.)
gc.collect()
torch.cuda.empty_cache()

0

In [19]:
test_loader_dict

{'wb': <torch.utils.data.dataloader.DataLoader at 0x7b644686a1d0>,
 'wb_val': <torch.utils.data.dataloader.DataLoader at 0x7b644686ab10>}

In [22]:
# Per-group and overall accuracy of the ResNet + autoencoder pipeline on the 'wb' (test) split.
predict_place = False  # when True, the place label p is used as the target instead of y
acc_groups = {g_idx : AverageMeter() for g_idx in range(test_loader_dict['wb'].dataset.n_groups)}
# Pure inference: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for x, y, g, p in tqdm.tqdm(test_loader_dict['wb']):
        x, y, p = x.cuda(), y.cuda(), p.cuda()
        if predict_place:
            y = p

        logits = model_and_autoencoder(model, x)
        preds = torch.argmax(logits, dim=1)
        # Bring the per-sample results to the CPU once so that they live on the
        # same device as the group ids used to mask them.
        correct_batch = (preds == y).cpu()
        g = g.cpu()
        # .tolist() yields plain Python ints, matching the keys of acc_groups.
        for g_val in torch.unique(g).tolist():
            mask = g == g_val
            n = mask.sum().item()
            corr = correct_batch[mask].sum().item()
            # Record this batch's accuracy for the group, weighted by its n samples.
            acc_groups[g_val].update(corr / n, n)

groups = acc_groups.keys()

# Overall accuracy: total of the meters' sums divided by total of their counts.
all_correct = sum([acc_groups[g].sum for g in groups])
all_total = sum([acc_groups[g].count for g in groups])
print("mean_accuracy :", all_correct / all_total)

100%|██████████| 624/624 [00:28<00:00, 21.67it/s]

mean_accuracy : 0.9197475202885482





In [23]:
# Per-group and overall accuracy of the ResNet + autoencoder pipeline on the 'wb_val' (validation) split.
predict_place = False  # when True, the place label p is used as the target instead of y
acc_groups = {g_idx : AverageMeter() for g_idx in range(test_loader_dict['wb_val'].dataset.n_groups)}
# Pure inference: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for x, y, g, p in tqdm.tqdm(test_loader_dict['wb_val']):
        x, y, p = x.cuda(), y.cuda(), p.cuda()
        if predict_place:
            y = p

        logits = model_and_autoencoder(model, x)
        preds = torch.argmax(logits, dim=1)
        # Bring the per-sample results to the CPU once so that they live on the
        # same device as the group ids used to mask them.
        correct_batch = (preds == y).cpu()
        g = g.cpu()
        # .tolist() yields plain Python ints, matching the keys of acc_groups.
        for g_val in torch.unique(g).tolist():
            mask = g == g_val
            n = mask.sum().item()
            corr = correct_batch[mask].sum().item()
            # Record this batch's accuracy for the group, weighted by its n samples.
            acc_groups[g_val].update(corr / n, n)

groups = acc_groups.keys()

# Overall accuracy: total of the meters' sums divided by total of their counts.
all_correct = sum([acc_groups[g].sum for g in groups])
all_total = sum([acc_groups[g].count for g in groups])
print("mean_accuracy :", all_correct / all_total)

100%|██████████| 621/621 [00:28<00:00, 21.47it/s]

mean_accuracy : 0.9171993758493985



