In [None]:
import torch
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from timm.data.transforms import _pil_interp
from smt import SMT, build_transforms, build_transforms4display

import os
import numpy as np
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as image

In [2]:
'''
build model
'''
img_size = 224

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook does not crash on machines without CUDA. With CUDA present this is
# equivalent to calling `.cuda()` on the model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SMT
# The hyper-parameters below must match the checkpoint loaded later in the
# notebook; the per-stage lists each hold four entries (one per stage).
model = SMT(
    embed_dims=[64, 128, 256, 512],
    ca_num_heads=[4, 4, 4, -1],
    sa_num_heads=[-1, -1, 8, 16],
    mlp_ratios=[4, 4, 4, 2],
    qkv_bias=True,
    depths=[3, 4, 18, 2],
    ca_attentions=[1, 1, 1, 0],
    head_conv=3,
    expand_ratio=2,
).to(device)


In [None]:
'''
build data transform
'''
# Both pipelines are built with the same arguments: `eval_transforms` produces
# the tensor that is fed to the model, `display_transforms` produces the tensor
# that is shown as the reference image in the visualisation cell.
eval_transforms, display_transforms = (
    make_transform(img_size, center_crop=False)
    for make_transform in (build_transforms, build_transforms4display)
)

In [None]:
'''
load checkpoint
'''
ckpt_path = "path/to/smt_small.pth"

# map_location='cpu' makes loading independent of the device the checkpoint was
# saved from (e.g. a GPU index that does not exist on this machine).
# load_state_dict copies the values into the model's existing parameters, so
# the model stays on whatever device it was placed on.
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source.
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt['model'])

# Switch to evaluation mode before running inference.
model.eval()

In [None]:
# visualize modulator
# For every image in `img_folder`, draw one row of four panels: the image as
# produced by `display_transforms`, followed by the modulator heat-maps taken
# from one attention module in each of stages 1, 2 and 3.
upsampler = nn.Upsample(scale_factor=4, mode='bilinear')

img_folder = "./vis_images/"
# Sorted so that the figure index `i` is reproducible between runs. Hidden
# entries and sub-directories (e.g. `.ipynb_checkpoints`, `.DS_Store`) are
# skipped because they are not images and would make Image.open fail.
img_paths = sorted(
    name for name in os.listdir(img_folder)
    if not name.startswith('.') and os.path.isfile(os.path.join(img_folder, name))
)


def _modulator_heatmap(attn):
    """Return the modulator stored on ``attn`` as a NumPy array for ``imshow``.

    The absolute value of ``attn.modulator`` is averaged over dim 1 and
    enlarged 4x with bilinear interpolation; the final permute moves the
    leading (batch) axis last so the array is laid out as (H, W, 1).
    """
    heat = torch.abs(attn.modulator).mean(1, keepdim=True)
    heat = upsampler(heat)
    return heat.squeeze(1).permute(1, 2, 0).cpu().detach().contiguous().numpy()


# Feed the input on whatever device the model lives on instead of assuming CUDA.
model_device = next(model.parameters()).device

for i, img_path in enumerate(img_paths):
    # Force three channels: grayscale / RGBA / palette files would otherwise
    # reach the transforms with the wrong number of channels.
    with Image.open(os.path.join(img_folder, img_path)) as img_file:
        img = img_file.convert('RGB')
    img_t = eval_transforms(img)
    img_d = display_transforms(img)

    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        # The modulators are read right after this forward pass; presumably the
        # pass is what stores them on the attention modules -- TODO confirm in smt.py.
        out = model(img_t.unsqueeze(0).to(model_device))
        heatmaps = [
            _modulator_heatmap(attn)
            for attn in (
                model.block1[-1].attn,  # stage 1
                model.block2[-1].attn,  # stage 2
                # Second-to-last block in stage 3; presumably the last block
                # there does not expose a modulator -- TODO confirm.
                model.block3[-2].attn,  # stage 3
            )
        ]

    # ori image (channels moved last for imshow)
    img2d = img_d.permute(1, 2, 0).cpu().detach().contiguous().numpy()

    fig, axes = plt.subplots(1, 4, figsize=(36, 8))
    for ax, panel in zip(axes, [img2d, *heatmaps]):
        ax.imshow(panel)
        ax.axis('off')
    fig.subplots_adjust(wspace=0, hspace=0)

    # fig.savefig('./figures/img_modulator_{}.png'.format(i), dpi=600)
    plt.show()
    # Release the figure so that looping over many images does not accumulate
    # open figures in memory.
    plt.close(fig)