In [1]:
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchsummary
from torch import nn
from torch.nn import Sequential as Seq
from torchvision.transforms import Compose, Normalize, ToTensor

In [2]:
class Stem(nn.Module):
    """Convolutional stem: two (3x3 conv, stride 2) -> BatchNorm -> activation stages.

    Each conv uses stride 2 and padding 1, so every stage halves H and W
    (rounding up). Channels go ``input_dim -> output_dim // 2 -> output_dim``.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        activation: activation module *class*; a fresh instance is created after
            each BatchNorm. Defaults to ``nn.GELU``.
    """

    def __init__(self, input_dim, output_dim, activation=nn.GELU):
        super().__init__()
        # Keep the attribute name and the Sequential layout (indices 0..5) unchanged
        # so that state_dict keys (``stem.<index>.*``) still match saved checkpoints.
        self.stem = nn.Sequential(
            nn.Conv2d(input_dim, output_dim // 2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(output_dim // 2),
            activation(),
            nn.Conv2d(output_dim // 2, output_dim, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(output_dim),
            activation(),
        )

    def forward(self, x):
        """Apply the stem to a batch of images (N, input_dim, H, W)."""
        return self.stem(x)

In [3]:
class MobileViG(nn.Module):
    """Truncated MobileViG that contains only the convolutional stem.

    Only ``local_channels[0]`` is read: it is the stem's output channel count.
    The stem expects 3-channel input.

    Args:
        local_channels: sequence of channel widths; entries after the first are
            not used by this truncated model.
    """

    def __init__(self, local_channels):
        super().__init__()
        stem_channels = local_channels[0]
        self.stem = Stem(input_dim=3, output_dim=stem_channels)

    def forward(self, inputs):
        """Return the stem feature maps for ``inputs``."""
        return self.stem(inputs)

In [4]:
def preprocess_image(img: np.ndarray, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> torch.Tensor:
    """Turn an image array into a normalized tensor with a leading batch axis.

    Args:
        img: image array (the caller in this notebook passes an HxWx3 float32
            array already scaled to [0, 1]). It is copied, not modified.
        mean: per-channel mean passed to ``Normalize``.
        std: per-channel std passed to ``Normalize``.

    Returns:
        The transformed tensor with a batch dimension of size 1 prepended.
    """
    to_normalized_tensor = Compose([
        ToTensor(),
        Normalize(mean=mean, std=std),
    ])
    # Work on a copy; presumably this also guards against non-contiguous inputs.
    tensor = to_normalized_tensor(img.copy())
    return tensor.unsqueeze(0)

def deprocess_image(img):
    """Map an arbitrary-valued array to a displayable uint8 image.

    The array is standardized (zero mean, unit std), re-centered to mean 0.5
    with std 0.1, clipped to [0, 1] and scaled to [0, 255].
    """
    mu = np.mean(img)
    sigma = np.std(img) + 1e-5  # epsilon keeps constant inputs from dividing by zero
    rescaled = (img - mu) / sigma * 0.1 + 0.5
    return np.uint8(np.clip(rescaled, 0, 1) * 255)

In [5]:
 import torchsummary
# test model with (1, 3, 224, 224) input 
checkpoint_path = 'MobileViG_B_82_6.pth.tar'
checkpoint = torch.load(f'./{checkpoint_path}', map_location='cpu')

img = cv2.imread('./image_test/demo.jpg')
img = cv2.resize(img, (224, 224))
img = np.float32(img) / 255

img = preprocess_image(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
local_channels=[42, 84, 240]
model = MobileViG(local_channels)
state_dict = {k:v for k,v in checkpoint['state_dict'].items() if 'stem' in k}
model.load_state_dict(state_dict, strict=False)
model.eval()
out = model(img)
out = out.detach().numpy()
out = out[0]

# visualize feature map

for i in range(len(out)):
    img_result = deprocess_image(out[i])
    color_img = cv2.cvtColor(img_result, cv2.COLOR_GRAY2BGR)
    cv2.imwrite(f'./test/feature_map_{i}.jpg', color_img)
    print(f'feature map {i} is saved')
    


feature map 0 is saved
feature map 1 is saved
feature map 2 is saved
feature map 3 is saved
feature map 4 is saved
feature map 5 is saved
feature map 6 is saved
feature map 7 is saved
feature map 8 is saved
feature map 9 is saved
feature map 10 is saved
feature map 11 is saved
feature map 12 is saved
feature map 13 is saved
feature map 14 is saved
feature map 15 is saved
feature map 16 is saved
feature map 17 is saved
feature map 18 is saved
feature map 19 is saved
feature map 20 is saved
feature map 21 is saved
feature map 22 is saved
feature map 23 is saved
feature map 24 is saved
feature map 25 is saved
feature map 26 is saved
feature map 27 is saved
feature map 28 is saved
feature map 29 is saved
feature map 30 is saved
feature map 31 is saved
feature map 32 is saved
feature map 33 is saved
feature map 34 is saved
feature map 35 is saved
feature map 36 is saved
feature map 37 is saved
feature map 38 is saved
feature map 39 is saved
feature map 40 is saved
feature map 41 is saved
