In [8]:
import torch
import torch.nn as nn
from PIL import Image
from ultralytics import YOLO
from collections import OrderedDict
from torchvision import models, transforms

# Load GoogLeNet with pretrained weights.
# .eval() switches dropout / batch-norm to inference behaviour: nn.Module starts in
# training mode, which would make the output comparisons later in this notebook
# non-deterministic (dropout) and would keep updating batch-norm running stats.
# NOTE(review): newer torchvision releases prefer
# `weights=models.GoogLeNet_Weights.DEFAULT` over `pretrained=True`; migrate once the
# installed torchvision version is confirmed.
model = models.googlenet(pretrained=True).eval()

In [6]:
# Print the dotted name of every submodule, one per line. The root module's own
# name is the empty string, which is why the output starts with a blank line.
print(*(qualified_name for qualified_name, _ in model.named_modules()), sep="\n")


conv1
conv1.conv
conv1.bn
maxpool1
conv2
conv2.conv
conv2.bn
conv3
conv3.conv
conv3.bn
maxpool2
inception3a
inception3a.branch1
inception3a.branch1.conv
inception3a.branch1.bn
inception3a.branch2
inception3a.branch2.0
inception3a.branch2.0.conv
inception3a.branch2.0.bn
inception3a.branch2.1
inception3a.branch2.1.conv
inception3a.branch2.1.bn
inception3a.branch3
inception3a.branch3.0
inception3a.branch3.0.conv
inception3a.branch3.0.bn
inception3a.branch3.1
inception3a.branch3.1.conv
inception3a.branch3.1.bn
inception3a.branch4
inception3a.branch4.0
inception3a.branch4.1
inception3a.branch4.1.conv
inception3a.branch4.1.bn
inception3b
inception3b.branch1
inception3b.branch1.conv
inception3b.branch1.bn
inception3b.branch2
inception3b.branch2.0
inception3b.branch2.0.conv
inception3b.branch2.0.bn
inception3b.branch2.1
inception3b.branch2.1.conv
inception3b.branch2.1.bn
inception3b.branch3
inception3b.branch3.0
inception3b.branch3.0.conv
inception3b.branch3.0.bn
inception3b.branch3.1
incepti

In [10]:
# Per-channel (R, G, B) mean/std used to preprocess input images into the format
# expected by the pretrained model. Both `normalize` and `inv_normalize` are built
# from these two lists so the forward and inverse transforms can never drift apart.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)

# Inverse transform that maps a normalized tensor back to its original form for
# visualization. Normalize computes (x - mean) / std, so choosing
# mean' = -mean / std and std' = 1 / std yields y * std + mean, which recovers x.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(NORMALIZE_MEAN, NORMALIZE_STD)],
    std=[1 / s for s in NORMALIZE_STD],
)

# Full preprocessing pipeline: resize to the 224x224 input size used in this
# notebook, convert the PIL image to a float tensor in [0, 1], then normalize.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
])

In [11]:
# Sample input image, relative to the notebook's working directory.
readImg = './eagle.jpg'

# Decode as 3-channel RGB (normalizing away alpha / grayscale modes). The context
# manager releases the file handle; `convert` returns an independent image object.
with Image.open(readImg) as raw_img:
    img0 = raw_img.convert("RGB")

In [12]:
def create_matching_sequential(original_model):
    """Chain the direct children of ``original_model`` into an ``nn.Sequential``.

    Children are added in registration order and keep their original names. The
    returned container holds the *same* child module objects (nothing is copied),
    so its parameters are shared with ``original_model``.

    Only the children are chained: any tensor operations that
    ``original_model.forward`` performs between children (e.g. a reshape before a
    final linear layer) are not reproduced, so outputs may differ from the original
    model or shapes may fail to line up.

    Args:
        original_model: the ``nn.Module`` whose children should be chained.

    Returns:
        An ``nn.Sequential`` whose named children mirror ``original_model``'s.
    """
    return nn.Sequential(OrderedDict(original_model.named_children()))

In [49]:
def replace_last_linear_with_identity(model):
    """Replace the last ``nn.Linear`` in ``model`` with ``nn.Identity``, in place.

    "Last" means last in ``model.named_modules()`` order, i.e. registration order.
    This is typically used to turn a classifier into a feature extractor by
    dropping its final projection layer.

    Args:
        model: any ``nn.Module``. The Linear layer may be a direct child (a dot-free
            name such as ``"fc"``, or ``"2"`` inside an ``nn.Sequential``) or nested
            deeper (``"head.1"``). The root module itself is never replaced.

    Returns:
        True if a Linear layer was found and replaced, False otherwise.
    """
    # Skip the root module (empty name): it cannot be swapped out in place.
    linear_names = [
        name
        for name, module in model.named_modules()
        if name and isinstance(module, nn.Linear)
    ]
    if not linear_names:
        return False

    target = linear_names[-1]
    # rpartition handles direct children too: "fc" -> ("", "", "fc"), whereas the
    # previous rsplit-and-unpack raised ValueError for dot-free names.
    parent_name, _, child_name = target.rpartition('.')
    parent = model
    for part in filter(None, parent_name.split('.')):
        parent = getattr(parent, part)
    # nn.Module.__setattr__ re-registers the child under the same key, keeping its
    # position (and hence nn.Sequential ordering) intact.
    setattr(parent, child_name, nn.Identity())
    print(f"Replaced {target} with Identity layer")
    return True

In [70]:
##Create a sequential model clone of the classifier
# create_matching_sequential re-uses the *same* child module objects, so the clone
# shares its parameters with `model` rather than holding a copy.
sequential_model = create_matching_sequential(model)

# Given the shared modules above this presumably copies weights onto themselves;
# its main value is that it raises if the clone's state_dict keys don't line up.
sequential_model.load_state_dict(model.state_dict()) # load the original model's state dict

##check matches on every layer
# NOTE(review): this cell appears to have been written for an ultralytics YOLO
# classifier (`model.model`, `model.transforms`, `model.device`, `.probs`); the
# recorded output below (names like `model.0.conv.weight`, then `KeyError: 1062`)
# comes from such a model. With the torchvision GoogLeNet loaded in the first cell,
# `model.model`, `model.transforms` and `model.device` are not attributes of a plain
# nn.Module — confirm which model this cell is meant for before re-running.
# NOTE(review): zip() stops at the shorter iterator, so a difference in parameter
# count between the two models would go unreported here.
for (name1, param1), (name2, param2) in zip(model.model.named_parameters(), sequential_model.named_parameters()):
    if not torch.allclose(param1, param2):
        print(f"Mismatch in {name1}")
    else:
        print(f"Match in {name1}")

##Check if our new model output agrees with original model
# NOTE(review): `model(img0)` passes the raw PIL image; a torchvision model needs a
# preprocessed tensor batch instead (e.g. the `transform` pipeline plus a batch dim).
with torch.no_grad():
    original_output = model(img0)
    sequential_output = sequential_model(model.transforms(img0).to(model.device).reshape(1, 3, 224, 224))

##Code below should give 'True'
# NOTE(review): `==` compares element-wise and yields a bool tensor, not a single
# True; exact float equality is also brittle — torch.allclose(...) is the robust check.
original_output[0].probs.data == sequential_output

## can also check on random input
#img = torch.rand(1, 3, 224, 224)
#with torch.no_grad():
#    original_output = model(img)
#    sequential_output = sequential_model(img.to(model.device))

## Final step: replace the last linear layer of our sequential model and verify its
## output length is what we expect (1280 for the author's model; the expected size
## depends on which model is loaded).
success = replace_last_linear_with_identity(sequential_model)
if not success:
    print("Could not find a linear layer to replace")
# NOTE(review): `img0` is a PIL Image (from the image-loading cell), which has no
# `.to` method; it presumably needs the same preprocessing as the model input above.
len(sequential_model(img0.to(model.device))[0])

Match in model.0.conv.weight
Match in model.0.conv.bias
Match in model.1.conv.weight
Match in model.1.conv.bias
Match in model.2.cv1.conv.weight
Match in model.2.cv1.conv.bias
Match in model.2.cv2.conv.weight
Match in model.2.cv2.conv.bias
Match in model.2.m.0.cv1.conv.weight
Match in model.2.m.0.cv1.conv.bias
Match in model.2.m.0.cv2.conv.weight
Match in model.2.m.0.cv2.conv.bias
Match in model.3.conv.weight
Match in model.3.conv.bias
Match in model.4.cv1.conv.weight
Match in model.4.cv1.conv.bias
Match in model.4.cv2.conv.weight
Match in model.4.cv2.conv.bias
Match in model.4.m.0.cv1.conv.weight
Match in model.4.m.0.cv1.conv.bias
Match in model.4.m.0.cv2.conv.weight
Match in model.4.m.0.cv2.conv.bias
Match in model.5.conv.weight
Match in model.5.conv.bias
Match in model.6.cv1.conv.weight
Match in model.6.cv1.conv.bias
Match in model.6.cv2.conv.weight
Match in model.6.cv2.conv.bias
Match in model.6.m.0.cv1.conv.weight
Match in model.6.m.0.cv1.conv.bias
Match in model.6.m.0.cv2.conv.we

KeyError: 1062

In [27]:
# Captured layer outputs as NumPy arrays, keyed by the name passed to getActivation.
activation = {}


def getActivation(name):
    """Return a forward-hook callable that records a layer's output under ``name``.

    The returned function takes ``(module, input, output)``, the forward-hook
    signature. It stores ``output`` in the module-level ``activation`` dict as a
    NumPy array, detached from the autograd graph and moved to CPU. Each call
    overwrites any previous entry for ``name``.

    Assumes ``output`` is a single tensor; a tuple output would fail on ``.detach``.
    """
    def _record(_module, _inputs, output):
        cpu_tensor = output.detach().cpu()
        activation[name] = cpu_tensor.numpy()

    return _record

In [31]:
# Display the network architecture. An ultralytics `YOLO` wrapper keeps its torch
# module under `.model` (the recorded output below is from such a model), whereas
# the plain torchvision GoogLeNet loaded in the first cell has no `.model`
# attribute — fall back to displaying the module itself instead of raising.
getattr(model, "model", model)

ClassificationModel(
  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (2): C3k2(
      (cv1): Conv(
        (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (cv2): Conv(
        (conv): Conv2d(48, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
  