In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import os
import random
from collections import defaultdict
from PIL import Image
from torch.utils.data import DataLoader, Dataset


The goal is to load a pretrained ResNet-50 and modify it in two ways. First, change all of its BatchNorm layers so that they normalize with the current batch's statistics even at inference time, following "Simple Cues Lead to a Strong Multi-Object Tracker" (2022). Second, replace its fully connected layer with the BNNeck from "Bag of Tricks and A Strong Baseline for Deep Person Re-identification" (2019).

In [8]:
# Pretrained ResNet-50; the explicit weights enum is the same checkpoint that
# the "DEFAULT" string resolves to.
resnet50 = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.DEFAULT
)

In [9]:
# Width of the features fed into the classification head.
resnet50.get_submodule("fc").in_features

2048

In [3]:
# Make every BatchNorm2d normalize with the current batch's statistics in both
# training and eval mode. Setting track_running_stats = False on an existing
# layer is not enough. In eval mode PyTorch only switches to batch statistics
# when running_mean and running_var are None, so the pretrained running
# buffers must be dropped as well. After that, the layer behaves like one built
# with track_running_stats=False.
for bn in resnet50.modules():
    if isinstance(bn, nn.BatchNorm2d):
        bn.track_running_stats = False
        bn.running_mean = None
        bn.running_var = None
        bn.num_batches_tracked = None

In [5]:
batch_norms = [m for m in resnet50.modules() if isinstance(m, nn.BatchNorm2d)]

# A layer still depends on running statistics if either of these holds:
# - it tracks them during training, or
# - it still holds running_mean/running_var buffers, which eval mode uses
#   instead of batch statistics.
# Expect False: no layer depends on running statistics.
# (any() also returns False for an empty list, whereas max() would raise.)
any(
    bn.track_running_stats
    or bn.running_mean is not None
    or bn.running_var is not None
    for bn in batch_norms
)

False

In [6]:
# Scratch check: on booleans, max() returns True if and only if at least one
# element is True, so it behaves like any().
any([True, False])

True

We first replace the fully connected (`fc`) layer with an identity mapping, because we want the backbone's features rather than its ImageNet class scores.

In [7]:
class Identity(nn.Identity):
    """Pass-through module: forward returns its input unchanged.

    Kept under this name for existing callers; the behavior comes from
    PyTorch's built-in ``nn.Identity`` instead of a hand-rolled copy.
    Used to replace a network's classification head so the network
    returns features.
    """


In [None]:
# Swap the classification head for a pass-through, so the model returns the
# backbone features instead of class scores. This replaces the existing "fc"
# submodule.
resnet50.add_module("fc", Identity())

In [None]:
def get_backbone():
    """Can be modified to change backbone."""
    resnet50 = torchvision.models.resnet50(weights="DEFAULT")
    
    for i, layer in enumerate(resnet50.modules()):
        if isinstance(layer, nn.BatchNorm2d):
            #use batch stats during inference
            layer.track_running_stats = False 
            
    #dummy variable to transport info
    resnet50._infeatures_temp = resnet50.fc.in_features
    #discard fc layer
    resnet50.fc = Identity()
    
    return resnet50
    

In [10]:
resnet50.train(False)  # equivalent to resnet50.eval()
resnet50.training

False

In [None]:
class Net(nn.Module):
    """The full Siamese network: a feature backbone followed by a BNNeck head.

    The BNNeck head works as follows:
    - The backbone features feed the metric losses (triplet / center).
    - A BatchNorm layer normalizes those features.
    - A bias-free linear classifier on the normalized features produces the
      logits for the identity-classification loss.
    - At inference, the normalized features are the output.

    NOTE(review): the reference BNNeck implementation also freezes the BN bias
    (``self.batch_norm.bias.requires_grad_(False)``); consider it if following
    the paper exactly.
    """

    def __init__(self, num_classes, backbone=None, in_features=None):
        """
        Args:
            num_classes: number of identities being classified.
            backbone: feature extractor that returns ``(batch, in_features)``
                tensors. Defaults to a fresh ``get_backbone()`` for each
                instance. A ``get_backbone()`` default in the signature would
                run once, at class definition, and be shared by every Net.
            in_features: feature width of ``backbone``. Defaults to the
                ``_infeatures_temp`` attribute that ``get_backbone`` sets.
        """
        super().__init__()  # must run before any submodule is assigned

        if backbone is None:
            backbone = get_backbone()
        if in_features is None:
            in_features = backbone._infeatures_temp

        # Gives features, used to calculate triplet and center loss.
        self.backbone = backbone

        # The features are flat (batch, in_features) vectors, hence BatchNorm1d;
        # BatchNorm2d would reject 2-D input.
        # track_running_stats=False normalizes with batch statistics in eval
        # mode too, consistent with the backbone. Its output is used during
        # inference.
        self.batch_norm = nn.BatchNorm1d(num_features=in_features,
                                         track_running_stats=False)

        self.fc = nn.Linear(in_features=in_features,
                            out_features=num_classes,
                            bias=False)

    def forward(self, x):
        """Run the backbone and the BNNeck head.

        In training mode, returns ``(logits, features)``:
        - ``logits``: raw classifier scores, with no softmax applied, for the
          identity loss.
        - ``features``: the pre-BatchNorm backbone features, for the triplet
          and center losses.

        In eval mode, returns the batch-normalized features.

        Because the BatchNorm uses batch statistics in both modes, each batch
        must contain more than one sample.
        """
        features = self.backbone(x)
        normalized = self.batch_norm(features)

        if self.training:
            logits = self.fc(normalized)
            return logits, features
        return normalized  # inference mode