In [2]:
import sys
import torch  
import numpy as np  
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import pandas as pd
from torch.distributions.categorical import Categorical
import math
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from preprocess import mean, std, preprocess_input_function
from settings import train_dir, test_dir, train_push_dir, train_batch_size, test_batch_size, train_push_batch_size
from settings import base_architecture, img_size, prototype_shape, num_classes, prototype_activation_function, add_on_layers_type
from receptive_field import compute_rf_prototype
import cv2
#from reward_model import construct_PrefNet, paired_cross_entropy_loss, PrefNet
from tqdm import tqdm
from settings import joint_optimizer_lrs, joint_lr_step_size
import skimage as sk
import skimage.io as skio
import train_and_test as tnt
from torch.utils.data import Subset
import time
import heapq
import model
from PIL import Image
import protopformer

In [3]:
# Preprocessing pipelines and data loaders shared by the evaluation cells below.
normalize = transforms.Normalize(mean=mean, std=std)


def _make_pipeline(*extra_steps):
    """Return a Compose that resizes to (img_size, img_size), converts to a tensor, then applies extra_steps."""
    return transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        *extra_steps,
    ])


def _make_loader(dataset, batch_size):
    """Return an unshuffled, single-worker DataLoader over dataset with the given batch size."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=1, pin_memory=False)


# Per-image transforms applied by hand to PIL images in later cells:
# photos are normalized, segmentation masks are only resized and converted.
img_preprocess = _make_pipeline(normalize)
mask_preprocess = _make_pipeline()

# Push-set images.
train_dataset = datasets.ImageFolder(train_push_dir, _make_pipeline(normalize))
train_loader = _make_loader(train_dataset, 80)

# Test images, served one at a time.
test_dataset = datasets.ImageFolder(test_dir, _make_pipeline(normalize))
test_loader = _make_loader(test_dataset, 1)

# Segmentation masks.
# NOTE(review): masks read through this dataset are normalized with the image
# mean/std, unlike mask_preprocess above — confirm that this is intended.
mask_dir = './activation_mask/segmentations/'
mask_dataset = datasets.ImageFolder(mask_dir, _make_pipeline(normalize))
mask_loader = _make_loader(mask_dataset, test_batch_size)

In [8]:
# Build the ProtoPFormer network and restore its trained weights.
# Run directories noted for reference:
#   ProtoPFormer/output_cosine/CUB2011U/deit_tiny_patch16_224
#   ProtoPFormer/output_cosine/CUB2011U/deit_tiny_patch16_224/1028--adamw-0.05-200-protopformer/checkpoints
# (Disabled) preference model checkpoint:
#   ./human_comparisons/pref_model_700_random_rating_split0.7_acc0.915.pth

# Settings passed to the model constructor; these override the values imported
# from `settings` for the remainder of the notebook.
base_architecture = 'cait_xxs24_224'
num_classes = 200
prototype_shape = [2000, 192, 1, 1]
reserve_layers = [11]
reserve_token_nums = [81]
use_global = True
use_ppc_loss = True
ppc_cov_thresh = 1.
ppc_mean_thresh = 2.
global_coe = 0.5
global_proto_per_class = 10
prototype_activation_function = 'log'
add_on_layers_type = 'regular'

ppnet_kwargs = dict(
    base_architecture=base_architecture,
    pretrained=True,
    img_size=img_size,
    prototype_shape=prototype_shape,
    num_classes=num_classes,
    reserve_layers=reserve_layers,
    reserve_token_nums=reserve_token_nums,
    use_global=use_global,
    use_ppc_loss=use_ppc_loss,
    ppc_cov_thresh=ppc_cov_thresh,
    ppc_mean_thresh=ppc_mean_thresh,
    global_coe=global_coe,
    global_proto_per_class=global_proto_per_class,
    prototype_activation_function=prototype_activation_function,
    add_on_layers_type=add_on_layers_type,
)

# The checkpoint is read before the model is built, so a missing file fails early.
ckpt = torch.load('../ProtoPFormer/output_cosine/CUB2011U/cait_xxs24_224/1028--adamw-0.05-200-protopformer/checkpoints/epoch-best.pth')

# Move to the GPU, restore the weights stored under the 'model' key, then wrap
# in DataParallel (later cells reach the bare model through `ppnet.module`).
ppnet = protopformer.construct_PPNet(**ppnet_kwargs).cuda()
ppnet.load_state_dict(ckpt['model'])
ppnet = torch.nn.DataParallel(ppnet)

In [9]:
# Activation precision on the test set: for every evaluated image, measure what
# fraction of the (thresholded) activation mass of the ground-truth class's
# prototypes falls inside the segmentation mask, and the same for the global
# activation pattern. Both are reported as means over the evaluated images.
#
# NOTE(review): the run recorded below this cell stopped on the first image with
# an UnboundLocalError for 'cls_token_attn'. That name does not occur in this
# cell, so the error is presumably raised inside `protopformer` and has to be
# fixed there — confirm.
bad_img_idx = [193, 1764, 2472, 3082]  # test-set indices excluded from the evaluation
#bad_img_idx = []
percentile_threshold = 95  # per activation map, keep only values above this percentile
num_imgs = len(test_dataset.imgs)

# Loop-invariant setup, done once instead of once per image.
ppnet.eval()
n_prototypes = ppnet.module.num_prototypes
num_per_class = n_prototypes // num_classes  # was hard-coded as 200

total_local_overlap = 0
total_global_overlap = 0
num_scored = 0  # images actually evaluated; skipped ones must not dilute the means

# Inference only: no autograd graph is needed for any of the outputs below.
with torch.no_grad():
    for i in tqdm(range(num_imgs)):
        if i in bad_img_idx:
            continue
        test_img_dir, class_identity = test_dataset.imgs[i]
        # '<class folder>/<file stem>' of the test image, used to locate its mask.
        # Derived from the dataset root instead of a fixed character offset.
        sub_dir = os.path.splitext(os.path.relpath(test_img_dir, test_dataset.root))[0]

        # Context managers close the underlying files once the tensors are built.
        with Image.open(test_img_dir) as raw_img:
            test_img = img_preprocess(raw_img)
        mask_img_dir = os.path.join(mask_dir, sub_dir + '.png')
        with Image.open(mask_img_dir) as raw_mask:
            # assumes the mask is single-channel, so broadcasting it against the
            # (img_size, img_size) activation maps counts each pixel once — TODO confirm
            mask_img = mask_preprocess(raw_mask).numpy()

        test_img = test_img.unsqueeze(0).cuda()
        _, proto_acts = ppnet.module.push_forward(test_img)
        proto_dist = proto_acts.detach().cpu().numpy()

        # --- local prototypes of the ground-truth class ---
        local_img_overlap = 0
        class_offset = class_identity * num_per_class
        for j in range(num_per_class):
            dist_map = proto_dist[0][class_offset + j]
            act_pattern = np.log((dist_map + 1) / (dist_map + ppnet.module.epsilon))
            upsampled_act_pattern = cv2.resize(act_pattern, dsize=(img_size, img_size), interpolation=cv2.INTER_CUBIC)
            th = np.percentile(upsampled_act_pattern, percentile_threshold)
            # Zero out everything at or below the percentile threshold.
            upsampled_act_pattern = (upsampled_act_pattern > th) * upsampled_act_pattern

            act_mass = np.sum(upsampled_act_pattern)
            if act_mass > 0:
                prototype_overlap = np.sum(np.multiply(mask_img, upsampled_act_pattern)) / act_mass
            else:
                # Nothing survived the threshold: count as zero overlap rather
                # than letting 0/0 turn the running total into NaN.
                prototype_overlap = 0.0
            local_img_overlap += prototype_overlap

        local_img_overlap /= num_per_class
        total_local_overlap += local_img_overlap

        # --- global precision ---
        cls_tokens, _ = ppnet.module.prototype_distances(test_img)
        global_activations, _ = ppnet.module.get_activations(cls_tokens, ppnet.module.prototype_vectors_global)
        global_activations = global_activations.detach().cpu().numpy()
        # NOTE(review): this reshape only succeeds if global_activations[0] holds
        # exactly img_size * img_size values — confirm against get_activations.
        global_activation_pattern = global_activations[0].reshape(img_size, img_size)
        global_overlap = np.sum(np.multiply(mask_img, global_activation_pattern)) / np.sum(global_activation_pattern)
        total_global_overlap += global_overlap

        num_scored += 1

# Average over the images that were evaluated, not over the whole test set.
avg_local_overlap = total_local_overlap / max(num_scored, 1)
avg_global_overlap = total_global_overlap / max(num_scored, 1)

print("Average Local Overlap: ", avg_local_overlap)
print("Average Global Overlap: ", avg_global_overlap)

  0%|          | 0/5794 [00:00<?, ?it/s]


UnboundLocalError: cannot access local variable 'cls_token_attn' where it is not associated with a value

In [12]:
# Local-prototype-only variant of the overlap evaluation: for every evaluated
# test image, average over the ground-truth class's prototypes the fraction of
# thresholded activation mass that falls inside the segmentation mask, then
# report the mean over the evaluated images.
bad_img_idx = [193, 1764, 2472, 3082]  # test-set indices excluded from the evaluation
#bad_img_idx = []
percentile_threshold = 95  # per activation map, keep only values above this percentile
num_imgs = len(test_dataset.imgs)

# Loop-invariant setup, done once instead of once per image.
ppnet.eval()
n_prototypes = ppnet.module.num_prototypes
num_per_class = n_prototypes // num_classes  # was hard-coded as 200
prototype_shape = ppnet.module.prototype_shape

total_local_overlap = 0
num_scored = 0  # images actually evaluated; skipped ones must not dilute the mean

# Inference only: no autograd graph is needed for the outputs below.
with torch.no_grad():
    for i in tqdm(range(num_imgs)):
        if i in bad_img_idx:
            continue
        test_img_dir, class_identity = test_dataset.imgs[i]
        # '<class folder>/<file stem>' of the test image, used to locate its mask.
        # Derived from the dataset root instead of a fixed character offset.
        sub_dir = os.path.splitext(os.path.relpath(test_img_dir, test_dataset.root))[0]

        # Context managers close the underlying files once the tensors are built.
        with Image.open(test_img_dir) as raw_img:
            test_img = img_preprocess(raw_img)
        mask_img_dir = os.path.join(mask_dir, sub_dir + '.png')
        with Image.open(mask_img_dir) as raw_mask:
            # assumes the mask is single-channel, so broadcasting it against the
            # (img_size, img_size) activation maps counts each pixel once — TODO confirm
            mask_img = mask_preprocess(raw_mask).numpy()

        test_img = test_img.unsqueeze(0).cuda()
        _, proto_dist_torch = ppnet.module.push_forward(test_img)
        proto_dist = proto_dist_torch.detach().cpu().numpy()

        local_img_overlap = 0
        class_offset = class_identity * num_per_class
        for j in range(num_per_class):
            dist_map = proto_dist[0][class_offset + j]
            act_pattern = np.log((dist_map + 1) / (dist_map + ppnet.module.epsilon))
            upsampled_act_pattern = cv2.resize(act_pattern, dsize=(img_size, img_size), interpolation=cv2.INTER_CUBIC)
            th = np.percentile(upsampled_act_pattern, percentile_threshold)
            # Zero out everything at or below the percentile threshold.
            upsampled_act_pattern = (upsampled_act_pattern > th) * upsampled_act_pattern

            act_mass = np.sum(upsampled_act_pattern)
            if act_mass > 0:
                prototype_overlap = np.sum(np.multiply(mask_img, upsampled_act_pattern)) / act_mass
            else:
                # Nothing survived the threshold: count as zero overlap rather
                # than letting 0/0 turn the running total into NaN.
                prototype_overlap = 0.0
            local_img_overlap += prototype_overlap

        local_img_overlap = local_img_overlap / num_per_class
        total_local_overlap += local_img_overlap
        num_scored += 1

# `total_overlap` used to be accumulated from names that were never initialised
# (`total_overlap`, `img_overlap`); it is now simply the local-overlap total, so
# the name stays available to any cell that reads it afterwards.
total_overlap = total_local_overlap

# Average over the images that were evaluated, not over the whole test set.
print("Final score: ", total_overlap / max(num_scored, 1))
    

  0%|          | 0/5794 [00:00<?, ?it/s]


UnboundLocalError: cannot access local variable 'cls_token_attn' where it is not associated with a value

In [88]:
# Scratch check: accumulated overlap divided by a hard-coded image count (5794
# is the total shown by the progress bars recorded above).
# NOTE(review): the divisor also counts the images skipped via bad_img_idx —
# presumably the number of evaluated images is what is wanted; confirm.
total_overlap / 5794

5

In [27]:
# Scratch check of an array's shape.
# NOTE(review): `rescaled_act_pattern` is not assigned in any cell visible here —
# presumably left over from a removed cell; confirm, or drop this cell.
rescaled_act_pattern.shape

(224,)

In [9]:
# Scratch probe of a single mask value at index [100][150].
# NOTE(review): `mask0` is not assigned in any cell visible here — presumably
# left over from a removed cell; confirm, or drop this cell.
mask0[100][150]

1.0