# ⚒️ Intervention of ResNet MNIST.

In [1]:
from imports import *
from transformers import ResNetForImageClassification, ResNetConfig


In [2]:
resnet = ResNetForImageClassification(ResNetConfig(num_labels=10))
resnet.config

ResNetConfig {
  "depths": [
    3,
    4,
    6,
    3
  ],
  "downsample_in_bottleneck": false,
  "downsample_in_first_stage": false,
  "embedding_size": 64,
  "hidden_act": "relu",
  "hidden_sizes": [
    256,
    512,
    1024,
    2048
  ],
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6",
    "7": "LABEL_7",
    "8": "LABEL_8",
    "9": "LABEL_9"
  },
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5,
    "LABEL_6": 6,
    "LABEL_7": 7,
    "LABEL_8": 8,
    "LABEL_9": 9
  },
  "layer_type": "bottleneck",
  "model_type": "resnet",
  "num_channels": 3,
  "out_features": [
    "stage4"
  ],
  "out_indices": [
    4
  ],
  "pooler_shape": [
    1,
    1,
    14,
    14
  ],
  "stage_names": [
    "stem",
    "stage1",
    "stage2",
    "stage3",
    "stage4"
  ],
  "transformers_version": "4.36.2"
}

In [3]:
resnet

ResNetForImageClassification(
  (resnet): ResNetModel(
    (embedder): ResNetEmbeddings(
      (embedder): ResNetConvLayer(
        (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (encoder): ResNetEncoder(
      (stages): ModuleList(
        (0): ResNetStage(
          (layers): Sequential(
            (0): ResNetBottleNeckLayer(
              (shortcut): ResNetShortCut(
                (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (layer): Sequential(
                (0): ResNetConvLayer(
                  (convolution): Conv2d(64

In [4]:
image = Image.open('pvr_mnist_dataset/images/combined_image_0.png').convert('RGB')

Converted the code:
``` 
# returning un-intervened output without gradients
with torch.inference_mode():
    base_outputs = self.model(base)
```

and also the code 

```
# run intervened forward
counterfactual_outputs = self.model(**base)
set_handlers_to_remove.remove()
```

I removed `(**base)` and converted to just `(base)`

In [18]:
import pyvene as pv

config = resnet.config

pv_resnet = pv.IntervenableModel({
    "component": "resnet.embedder.pooler.output", 
    "source_representation": torch.zeros(config.pooler_shape, dtype = torch.float32)}, model=resnet)


intervened_outputs = pv_resnet(
    base = torch.tensor(np.array(image), dtype = torch.float32).reshape(1,-1,56,56), 
    # unit_locations={"base": ([[[h for h in range(12)]]])}
    unit_locations={"base": [0,1,2,3]}, 
    subspaces=[0,1,2,3]
    )[0][-1][0]

'''
The unit locations has been defined in such a way that it corresponds to the 
language models having the dimension in the order (batch_size, sequence_length, hidden_size).
Meanwhile, the images have the dimension in the order 
(batch_size, channels, height, width), hence creating an issue.
'''

# I think intervened output gives both the output of the model, i.e. the prediction of counterfactual and also of factual model.

# print("Now finally, the model has ran and this is the intervened output" +str( intervened_outputs))

print(intervened_outputs)
predicted_indices = torch.argmax(intervened_outputs)
class_names = [str(i) for i in range(10)]  # Class names from 0 to 9
# predicted_classes = [class_names[idx] for idx in predicted_indices.cpu().numpy()]
class_names[int(predicted_indices)]



torch.Size([1, 3, 56, 56])
The shape of the embedding is:  torch.Size([1, 64, 14, 14])
The pooled output is tensor([[[[ 413.1888]],

         [[ 763.1083]],

         [[   0.0000]],

         ...,

         [[1268.4442]],

         [[ 197.8016]],

         [[1551.6110]]]])
The logits are: tensor([[ -704.1973, -1682.2889,  -191.5321,   473.5226,  -230.1412,  -428.9193,
           295.9371,  -100.1536,   248.2686,  -146.5026]])
torch.Size([1, 3, 56, 56])


ValueError: source with shape torch.Size([1, 1, 14, 14]) cannot be broadcasted into base with shape torch.Size([1, 4, 14, 14]).

In [6]:
intervened_outputs

tensor([ -704.1973, -1682.2889,  -191.5321,   473.5226,  -230.1412,  -428.9193,
          295.9371,  -100.1536,   248.2686,  -146.5026])