In [None]:
import os
import random

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.sampler import Sampler
from torchvision import transforms

from _code.color_lib import RGBmean,RGBstdv
from visualization.Reader2 import ImageReader
from visualization.Resnet import resnet18

seed = 2  # re-applied to `random` and `np.random` at the top of each visualization cell so BalanceSampler draws the same batch every run

In [None]:
# Per-epoch training records for the two runs being compared.
# NOTE(review): column 2 is plotted as training accuracy and column 3 as validation
# accuracy; confirm against the script that wrote these files.
# map_location='cpu' lets records saved from a GPU run load on a CPU-only machine.
record1 = torch.load('./record/record1.pth', map_location='cpu')    #EPSHN
record2 = torch.load('./record/record2.pth', map_location='cpu')    #EPSHN with background

#plot training accuracy
# The epoch axis is derived from each record's length (equal to len(recordN.T[k]))
# instead of hard-coding 61 epochs, so runs of any length plot without a
# shape-mismatch error.
epochs1 = np.arange(len(record1))
epochs2 = np.arange(len(record2))

plt.xlabel("epoch")
plt.ylabel("training accuracy")
plt.plot(epochs1, record1.T[2].tolist(), label='EPSHN')
plt.plot(epochs2, record2.T[2].tolist(), label='EPSHN with background')
plt.legend()
plt.title('Training Accuracy')
plt.show()

In [None]:
#plot validation accuracy
# The epoch axis is derived from each record's length (equal to len(recordN.T[3]))
# instead of hard-coding 61 epochs, so runs of any length plot correctly.
val_epochs1 = np.arange(len(record1))
val_epochs2 = np.arange(len(record2))

plt.xlabel("epoch")
plt.ylabel("validation accuracy")
plt.plot(val_epochs1, record1.T[3].tolist(), label='EPSHN')
plt.plot(val_epochs2, record2.T[3].tolist(), label='EPSHN with background')
plt.legend()
# Fixed y-range. NOTE(review): hard-coded; points outside [0.30, 0.60] fall off the plot.
plt.ylim(0.30,0.60)
plt.title('Validation Accuracy')
plt.show()

In [None]:
Data='CUB'
# Dataset root. Set the BACKGROUND_DATA_ROOT environment variable to point at your own
# copy instead of editing this absolute, machine-specific default.
DATA_ROOT = os.environ.get('BACKGROUND_DATA_ROOT', '/Users/yinjia/Downloads/Background')
CUB_data_dir = os.path.join(DATA_ROOT, 'CUB_200_2011', 'CUB_200_2011', 'images')
Branch_data_dir = os.path.join(DATA_ROOT, 'branch_images')

# Resize the shorter side to 256, center-crop to 256x256, convert to a tensor and
# normalize with the per-channel statistics color_lib provides for `Data`.
data_transforms = transforms.Compose([transforms.Resize(256),
                                      transforms.CenterCrop(256),
                                      transforms.ToTensor(),
                                      transforms.Normalize(RGBmean[Data], RGBstdv[Data])])

CUB_dsets = ImageReader(CUB_data_dir, data_transforms)
Branch_dsets = ImageReader(Branch_data_dir, data_transforms)

out_dim = 64  # output size of the replacement fc layer
avg = 8       # kernel size of the replacement average-pooling layer

# load model: ResNet-18 with fc and avgpool replaced so the layer shapes match the
# saved state dict (load_state_dict is strict by default).
model = resnet18(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, out_dim)
model.avgpool = nn.AvgPool2d(avg)
# map_location='cpu': this notebook never moves the model to a GPU, and it lets a
# checkpoint saved from a CUDA run load on a machine without CUDA.
model.load_state_dict(torch.load('./models/model_params.pth', map_location='cpu'))

In [None]:
#Sampler 
class BalanceSampler(Sampler):
    """Class-balanced sampler that emits dataset indices in same-class runs of ``GSize``.

    Each entry of ``intervals`` is a ``(start, stop)`` pair; the indices
    ``np.arange(start, stop)`` form one class. Every class is resampled to the same
    count (see ``__init__``), the class order is shuffled, and the indices are
    interleaved so that each consecutive group of ``GSize`` indices comes from a single
    class.

    All randomness comes from ``np.random`` and ``random``; seeding both makes the
    resulting index order reproducible.
    """

    def __init__(self, intervals, GSize=2):
        n_classes = len(intervals)
        class_ranges = [np.arange(b[0], b[1]) for b in intervals]

        # Per-class target count: the largest interval length ...
        target = max([b[1] - b[0] for b in intervals])
        # ... capped at 100 when it exceeds 1000.
        # NOTE(review): the threshold is 1000 but the cap is 100; confirm this is intended.
        if target > 1000:
            target = 100
        # ... rounded up to a multiple of GSize so each class splits into GSize equal chunks.
        leftover = target % GSize
        if leftover != 0:
            target += GSize - leftover

        per_class = []
        for members in class_ranges:
            size = members.shape[0]
            if size < target:
                # Too small: keep every index and pad with random repeats.
                padding = np.random.choice(members, target - size)
                picked = np.random.permutation(np.concatenate((members, padding), axis=0))
            elif size > target:
                # Too large: draw a distinct subset.
                picked = np.random.permutation(np.random.choice(members, target, replace=False))
            else:
                picked = np.random.permutation(members)
            per_class.append(picked)

        random.shuffle(per_class)

        # (n_classes, target) -> (GSize * n_classes, target // GSize) splits each class row
        # into GSize contiguous chunks; reading the transpose row by row then visits, for
        # each column, one index from every chunk in turn, i.e. runs of GSize same-class indices.
        chunks = np.vstack(per_class).reshape((GSize * n_classes, -1))
        self.idx = chunks.T.ravel().tolist()

    def __iter__(self):
        return iter(self.idx)

    def __len__(self):
        return len(self.idx)

In [None]:
def embed(dsets, whichModel, batch_size, num_workers=12):
    """Run ``whichModel`` on a single class-balanced batch drawn from ``dsets``.

    Switches the model to inference mode (``train(False)``), takes the FIRST batch from a
    DataLoader that samples with ``BalanceSampler(dsets.intervals)``, and runs it through
    the model with gradients disabled.

    Args:
        dsets: dataset exposing ``intervals`` (passed to ``BalanceSampler``) whose batches
            unpack into three items. NOTE(review): presumably
            ``(original image, normalized input, label)``; confirm against ImageReader.
        whichModel: model whose forward returns ``(fvec, fmap1, fmap2, fmap3, fmap4)``.
        batch_size: number of samples in the batch.
        num_workers: DataLoader worker processes (default 12, the previously hard-coded value).

    Returns:
        ``(Images, V, M1, M2, M3, M4, L)``: ``Images`` is the first item of the batch as
        collated by the DataLoader; ``V``, ``M1``-``M4`` and ``L`` are lists with one entry
        per sample of ``fvec``, ``fmap1``-``fmap4`` and the labels.

    Raises:
        ValueError: if the DataLoader yields no batch (previously an UnboundLocalError).
    """
    whichModel.train(False)
    dataLoader = torch.utils.data.DataLoader(
        dsets, batch_size, sampler=BalanceSampler(dsets.intervals), num_workers=num_workers)

    # Only the first batch is needed; next() replaces the original loop-then-break.
    batch = next(iter(dataLoader), None)
    if batch is None:
        raise ValueError('embed: the DataLoader produced no batches (empty dataset or sampler)')
    origins, inputs_bt, labels_bt = batch

    with torch.set_grad_enabled(False):
        # fvec is the unnormalized feature after fc layer
        # fmap is the magnitude feature before fc layer
        fvec, fmap1, fmap2, fmap3, fmap4 = whichModel(inputs_bt)

    # list(t) splits a batched tensor along dim 0, exactly like the original list.extend(t).
    return (origins, list(fvec), list(fmap1), list(fmap2), list(fmap3), list(fmap4),
            list(labels_bt))

In [None]:
#Visualization of CUB Dataset (norm)
# One row per image of a single class-balanced batch: the image, then the per-pixel RMS
# over dim 0 (presumably channels), sqrt(mean(x^2)), of fmap4, fmap3, fmap2 and fmap1,
# each on a fixed colour scale so rows are comparable.
# Re-seeding makes BalanceSampler draw the same batch on every run.
random.seed(seed)
np.random.seed(seed)
batch_size = 20
Images, feats, maps1, maps2, maps3, maps4, labels = embed(CUB_dsets, model, batch_size)
print(Images[0].size())

num = len(Images)
# Height scales with the row count: 5 in per row, i.e. the original 100 in for 20 rows.
plt.figure(figsize=(32, 5 * num))
norm1 = matplotlib.colors.Normalize(vmin=0, vmax=0.9)
norm2 = matplotlib.colors.Normalize(vmin=0, vmax=0.45)
norm3 = matplotlib.colors.Normalize(vmin=0, vmax=0.35)
norm4 = matplotlib.colors.Normalize(vmin=0, vmax=5)
# Columns 2-5 of each row, left to right.
panels = [(maps4, norm4), (maps3, norm3), (maps2, norm2), (maps1, norm1)]

for i in range(num):
    plt.subplot(num, 5, 5 * i + 1)
    plt.imshow(Images[i].permute(1, 2, 0))
    plt.axis('off')
    for col, (maps, norm) in enumerate(panels, start=2):
        plt.subplot(num, 5, 5 * i + col)
        plt.imshow(maps[i].pow(2).mean(0).sqrt(), cmap='rainbow', norm=norm)
        plt.colorbar()
        plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
#Visualization of CUB Dataset (topk)
# One row per image of a single class-balanced batch: the image, then, per pixel, the
# mean of the 3 largest values along dim 0 (presumably channels) of fmap4, fmap3, fmap2
# and fmap1, each on a fixed colour scale so rows are comparable.
# Re-seeding makes BalanceSampler draw the same batch on every run.
random.seed(seed)
np.random.seed(seed)
batch_size = 20
Images, feats, maps1, maps2, maps3, maps4, labels = embed(CUB_dsets, model, batch_size)
print(Images[0].size())

num = len(Images)
# Height scales with the row count: 5 in per row, i.e. the original 100 in for 20 rows.
plt.figure(figsize=(32, 5 * num))
norm1 = matplotlib.colors.Normalize(vmin=0, vmax=10)
norm2 = matplotlib.colors.Normalize(vmin=0, vmax=10)
norm3 = matplotlib.colors.Normalize(vmin=0, vmax=10)
norm4 = matplotlib.colors.Normalize(vmin=0, vmax=15)
# Columns 2-5 of each row, left to right.
panels = [(maps4, norm4), (maps3, norm3), (maps2, norm2), (maps1, norm1)]

for i in range(num):
    plt.subplot(num, 5, 5 * i + 1)
    plt.imshow(Images[i].permute(1, 2, 0))
    plt.axis('off')
    for col, (maps, norm) in enumerate(panels, start=2):
        plt.subplot(num, 5, 5 * i + col)
        plt.imshow(maps[i].topk(3, dim=0)[0].mean(0), cmap='rainbow', norm=norm)
        plt.colorbar()
        plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
#Visualization of Branch Images (norm)
# One row per image of a single class-balanced batch: the image, then the per-pixel RMS
# over dim 0 (presumably channels), sqrt(mean(x^2)), of fmap4, fmap3, fmap2 and fmap1,
# each on a fixed colour scale so rows are comparable.
# Re-seeding makes BalanceSampler draw the same batch on every run.
random.seed(seed)
np.random.seed(seed)
batch_size = 20
Images, feats, maps1, maps2, maps3, maps4, labels = embed(Branch_dsets, model, batch_size)
print(Images[0].size())

num = len(Images)
# Height scales with the row count: 5 in per row, i.e. the original 100 in for 20 rows.
plt.figure(figsize=(32, 5 * num))
norm1 = matplotlib.colors.Normalize(vmin=0, vmax=0.9)
norm2 = matplotlib.colors.Normalize(vmin=0, vmax=0.45)
norm3 = matplotlib.colors.Normalize(vmin=0, vmax=0.35)
norm4 = matplotlib.colors.Normalize(vmin=0, vmax=5)
# Columns 2-5 of each row, left to right.
panels = [(maps4, norm4), (maps3, norm3), (maps2, norm2), (maps1, norm1)]

for i in range(num):
    plt.subplot(num, 5, 5 * i + 1)
    plt.imshow(Images[i].permute(1, 2, 0))
    plt.axis('off')
    for col, (maps, norm) in enumerate(panels, start=2):
        plt.subplot(num, 5, 5 * i + col)
        plt.imshow(maps[i].pow(2).mean(0).sqrt(), cmap='rainbow', norm=norm)
        plt.colorbar()
        plt.axis('off')
plt.tight_layout()
plt.show()

In [None]:
#Visualization of Branch Images (topk)
# One row per image of a single class-balanced batch: the image, then, per pixel, the
# mean of the 3 largest values along dim 0 (presumably channels) of fmap4, fmap3, fmap2
# and fmap1, each on a fixed colour scale so rows are comparable.
# Re-seeding makes BalanceSampler draw the same batch on every run.
random.seed(seed)
np.random.seed(seed)
batch_size = 20
Images, feats, maps1, maps2, maps3, maps4, labels = embed(Branch_dsets, model, batch_size)
print(Images[0].size())

num = len(Images)
# Height scales with the row count: 5 in per row, i.e. the original 100 in for 20 rows.
plt.figure(figsize=(32, 5 * num))
norm1 = matplotlib.colors.Normalize(vmin=0, vmax=10)
norm2 = matplotlib.colors.Normalize(vmin=0, vmax=10)
norm3 = matplotlib.colors.Normalize(vmin=0, vmax=10)
norm4 = matplotlib.colors.Normalize(vmin=0, vmax=15)
# Columns 2-5 of each row, left to right.
panels = [(maps4, norm4), (maps3, norm3), (maps2, norm2), (maps1, norm1)]

for i in range(num):
    plt.subplot(num, 5, 5 * i + 1)
    plt.imshow(Images[i].permute(1, 2, 0))
    plt.axis('off')
    for col, (maps, norm) in enumerate(panels, start=2):
        plt.subplot(num, 5, 5 * i + col)
        plt.imshow(maps[i].topk(3, dim=0)[0].mean(0), cmap='rainbow', norm=norm)
        plt.colorbar()
        plt.axis('off')
plt.tight_layout()
plt.show()