<a href="https://colab.research.google.com/github/aryanasadianuoit/pytorch_tricks/blob/master/resnet_retrain_torch_use_logits_using_hooks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Exploiting Logits in a Pre-trained ResNet Model**

In this code snippet, I use a ResNet18 model pre-trained on ImageNet for CIFAR10 classification. The main purpose of this code is to show how to extract the logits or, more generally, the activations of any intermediate layer in PyTorch.

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models, transforms,datasets
import matplotlib.pyplot as plt
import numpy as np
from torch import cuda

Since the model has already been trained on ImageNet, I resize the CIFAR10 images to the ImageNet input size, i.e., 224 × 224.

In [None]:
# ImageNet channel statistics: the pre-trained network expects inputs that are
# normalized the same way as the data it was originally trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32


def build_transform():
    """Return the resize -> tensor -> normalize pipeline shared by both splits."""
    return transforms.Compose(
        [
            transforms.Resize(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ]
    )


# Both splits use the same preprocessing; two separate objects are kept so
# that either one can later be changed (e.g. augmentation for training only).
train_transform = build_transform()
test_transform = build_transform()

train_set = datasets.CIFAR10("./dataset", train=True, transform=train_transform, download=True)
test_set = datasets.CIFAR10("./dataset", train=False, transform=test_transform, download=True)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the per-batch
# printouts reproducible between runs.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./dataset/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./dataset/cifar-10-python.tar.gz to ./dataset
Files already downloaded and verified


**This is the scheme of pre-trained ResNet-18.**

In [None]:
# Load ResNet-18 with its pre-trained weights; progress=True shows a progress
# bar while the checkpoint is being downloaded.
base_resnet_model = models.resnet18(pretrained= True, progress= True)

# Print the module tree to inspect the architecture layer by layer.
print(base_resnet_model)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
 

As you can see, the last layer of the pre-trained ResNet18 was trained for a 1000-class classification task. In our case, however, we have to classify the input images into the 10 classes of CIFAR10. For this purpose, we replace the last layer of ResNet18 with a new layer that has 10 output features.

In [None]:
# Swap the 1000-way classification head for a 10-way one. The input width is
# read from the layer being replaced instead of being hard-coded, so the line
# keeps working if a backbone with a different feature size is used.
base_resnet_model.fc = nn.Linear(in_features=base_resnet_model.fc.in_features, out_features=10)


In [None]:
# Print the model again to check that `fc` is now the new 10-output layer.
print(base_resnet_model)

In [None]:
# Walk the direct sub-modules of the network and print each one, followed by
# a row of asterisks so consecutive entries are easy to tell apart.
separator = "\n***************************\n"
for submodule in base_resnet_model.children():
    print("CHILD ====> ", submodule, separator)

CHILD ====>  Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 
***************************

CHILD ====>  BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 
***************************

CHILD ====>  ReLU(inplace=True) 
***************************

CHILD ====>  MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) 
***************************

CHILD ====>  Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(6

The model, except for the newly replaced last layer, has already been trained on the giant ImageNet dataset, so its learned features should generalize well. Therefore, I freeze those layers' parameters to prevent them from being re-trained.

In [None]:
# Freeze every pre-trained layer so that only the replacement `fc` head is
# updated during training.
#
# `named_children()` is used because the layer name is needed to skip `fc`
# (a module object never compares equal to the string "fc"), and the flag is
# set on each parameter because autograd reads `requires_grad` from the
# parameter tensors, not from the module that owns them.
for name, child in base_resnet_model.named_children():
    if name == "fc":
        continue
    for param in child.parameters():
        param.requires_grad = False

In [None]:
from torch import optim


# Cross-entropy loss for the multi-class classification objective.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over the parameters of the model.
optimizer = optim.SGD(
    params=base_resnet_model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
previous_saved_loss = 0.0  # not read in this cell; kept in case later cells use it

# Fall back to the CPU when no GPU is present so the cell still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_resnet_model.to(device)

# Number of mini-batches between two progress reports.
LOG_EVERY = 1562

for epoch in range(5):  # loop over the dataset multiple times
    base_resnet_model.train()
    running_loss = 0.0
    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = base_resnet_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if step % LOG_EVERY != LOG_EVERY - 1:
            continue

        # Report the mean training loss since the previous report.
        print('[%d, %5d] loss: %.3f' %
              (epoch + 1, step + 1, running_loss / LOG_EVERY))
        running_loss = 0.0

        # Evaluate on the whole test set. eval() makes BatchNorm use its
        # running statistics (and stops it from updating them on test data);
        # no_grad() skips building the autograd graph.
        correct = 0
        total = 0
        base_resnet_model.eval()
        with torch.no_grad():
            # Separate loop variables so the training loop's `step`,
            # `inputs` and `labels` are not overwritten.
            for test_step, (images, test_labels) in enumerate(test_loader):
                images, test_labels = images.to(device), test_labels.to(device)
                test_outputs = base_resnet_model(images)
                _, predicted = torch.max(test_outputs, 1)
                if test_step == 1:
                    # Show the predictions of one batch as a quick sanity check.
                    print(predicted)
                # Accumulate over every batch, not only the printed one.
                total += test_labels.size(0)
                correct += (predicted == test_labels).sum().item()
        if total:
            print('Accuracy of the network on the 10000 test images: %d %%' % (
                100 * correct / total))

        # Switch back to training mode before the next optimization step.
        base_resnet_model.train()
      
        

      


[1,  1562] loss: 0.152
tensor([8, 1, 4, 4, 4, 5, 5, 4, 2, 4, 8, 3, 1, 6, 9, 1, 2, 9, 3, 5, 7, 9, 5, 3,
        4, 0, 2, 6, 6, 9, 4, 5], device='cuda:0')
Accuracy of the network on the 10000 test images: 93 %
[2,  1562] loss: 0.081
tensor([9, 0, 1, 6, 0, 3, 7, 7, 9, 8, 5, 8, 4, 2, 6, 5, 6, 9, 4, 9, 6, 9, 3, 4,
        7, 3, 2, 5, 2, 4, 2, 6], device='cuda:0')
Accuracy of the network on the 10000 test images: 87 %
[3,  1562] loss: 0.044
tensor([8, 3, 2, 9, 2, 3, 8, 7, 6, 2, 4, 8, 1, 9, 6, 6, 6, 9, 9, 6, 8, 6, 6, 8,
        2, 8, 0, 0, 0, 4, 4, 8], device='cuda:0')
Accuracy of the network on the 10000 test images: 93 %
[4,  1562] loss: 0.027
tensor([3, 6, 1, 1, 1, 9, 4, 6, 5, 1, 0, 2, 7, 1, 4, 1, 2, 4, 7, 8, 8, 7, 0, 9,
        2, 3, 2, 3, 7, 7, 2, 8], device='cuda:0')
Accuracy of the network on the 10000 test images: 96 %
[5,  1562] loss: 0.017
tensor([6, 7, 9, 2, 7, 7, 5, 5, 2, 1, 3, 5, 9, 6, 1, 8, 8, 6, 2, 4, 5, 6, 1, 8,
        3, 1, 9, 8, 1, 1, 1, 0], device='cuda:0')
Accuracy of the

Now it's time to use a **hook** to capture the output of any intermediate layer we are interested in.

In [None]:
# Maps a caller-chosen label to the most recent output captured under it.
activation = {}


def get_activation(name):
    """Build a forward hook that stores a module's output under ``name``.

    The returned callable takes the ``(module, inputs, output)`` arguments
    passed to forward hooks. Each time it runs, it writes the output,
    detached from the autograd graph, into the module-level ``activation``
    dict, replacing any value previously stored under the same key.
    """

    def _record(_module, _inputs, output):
        captured = output.detach()
        activation[name] = captured

    return _record



# Attach the hook to the final layer. The handle is kept so the hook can be
# detached later with `fc_hook_handle.remove()`.
fc_hook_handle = base_resnet_model.fc.register_forward_hook(get_activation('fc'))

logits = None  # filled with the captured `fc` output of the inspected batch

# Start the counters from zero so the accuracy does not include numbers left
# over from an earlier cell.
total = 0
correct = 0

# Run inference in eval mode without gradient tracking, then restore the mode
# the model was in so later cells see it unchanged.
was_training = base_resnet_model.training
base_resnet_model.eval()
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = base_resnet_model(images)
        _, predicted = torch.max(outputs, 1)

        # Accumulate over every batch so the reported accuracy covers the
        # whole test set.
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if i == 1:
            # Inspect one batch in detail: the hook has just stored the
            # output of `fc` for this forward pass.
            print(predicted)
            logits = activation['fc']
            print(logits)
            print("*****************\n\n\n")
            # Softmax turns the logits into class probabilities; taking the
            # arg-max of those yields the same classes as `predicted`.
            probs = F.softmax(logits, dim=1)
            print(probs)
            print("*****************\n\n\n")
            _, preds = torch.max(probs, 1)
            print(preds)
base_resnet_model.train(was_training)

if total:
    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))
      

tensor([6, 2, 9, 0, 8, 1, 0, 8, 4, 6, 9, 7, 7, 7, 0, 1, 5, 0, 5, 5, 6, 2, 9, 2,
        1, 6, 4, 4, 8, 9, 8, 9], device='cuda:0')
tensor([[-3.7331e+00, -1.3730e+00,  1.1711e+00,  2.7021e+00, -1.7922e+00,
         -1.0524e+00,  1.6344e+01, -1.8513e+00, -4.2315e+00, -4.3397e+00],
        [ 7.6460e-01, -1.4557e+00,  6.9749e+00, -3.9320e-01, -2.9899e+00,
         -2.1957e+00,  4.5609e+00, -2.7747e+00,  1.4060e+00, -4.1557e+00],
        [-2.1042e+00,  2.6190e+00, -4.5435e+00,  4.2004e+00, -5.1583e+00,
         -1.7244e+00, -5.8664e+00, -1.4226e+00,  5.1304e+00,  9.2242e+00],
        [ 1.1929e+01, -8.8369e-01,  4.1632e+00,  1.5992e+00, -1.9542e+00,
         -2.4676e+00, -5.1772e+00, -1.4144e+00, -1.5531e+00, -3.2480e+00],
        [ 1.7679e-01, -3.1631e+00, -1.5332e+00, -1.5622e+00,  2.6856e+00,
         -2.0327e+00, -2.1232e+00, -3.1344e+00,  1.2942e+01, -3.3473e+00],
        [-2.7848e+00,  1.0757e+01,  4.6447e-01, -3.2248e+00, -1.8374e-01,
         -1.0210e+00, -3.6931e+00, -9.0454e-01, -1.