In [3]:
import sys
import torch  
import gym
import numpy as np  
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
import pandas as pd
from torch.distributions.categorical import Categorical
import math
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from preprocess import mean, std, preprocess_input_function
from settings import train_dir, test_dir, train_push_dir, train_batch_size, test_batch_size, train_push_batch_size
from settings import base_architecture, img_size, prototype_shape, num_classes, prototype_activation_function, add_on_layers_type
from receptive_field import compute_rf_prototype
import cv2
from preference_model import construct_PrefNet, paired_cross_entropy_loss, PrefNet
from tqdm import tqdm
from settings import joint_optimizer_lrs, joint_lr_step_size
import skimage as sk
import skimage.io as skio
import train_and_test as tnt
from torch.utils.data import Subset
import time
import heapq

In [3]:
'''
Use PPnet's forward pass as the policy network (actor network); what about the network for value function (critic network)?
Since there are only determinant actions, this is essentially A2C...
'''
class A3C_PPnet(nn.Module):
    def __init__(self, PPnet, preference_model, k=3, p=5, learning_rate=1e-7, dummy_reward=False, train_batch_size=80):
        super(A3C_PPnet, self).__init__()
        
        
        self.PPnet = PPnet.cuda()
        #for param in self.PPnet.features.parameters():
        #    param.requires_grad = True
        self.k = k
        self.pf_model = preference_model.cuda()
        
        #self.PPnet_multi = torch.nn.DataParallel(self.PPnet)
        self.PPnet_multi = self.PPnet
        for p in self.PPnet_multi.module.features.parameters():
            p.requires_grad = False
        for p in self.PPnet_multi.module.add_on_layers.parameters():
            p.requires_grad = True
        self.PPnet_multi.module.prototype_vectors.requires_grad = True
        for p in self.PPnet_multi.module.last_layer.parameters():
            p.requires_grad = False
        #self.critic_model = self.construct_critic().cuda()
        self.p = p
        self.critic_model = Critic().cuda()
        self.train_batch_size = train_batch_size
        self.num_epoch = 0
        policy_optimizer_specs = [#{'params': self.PPnet.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3}, 
                                  #{'params': self.PPnet.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3},
                                  #{'params': self.PPnet.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']},
                                  #{'params': self.PPnet.add_on_layers.parameters(), 'lr': 1e-6, 'weight_decay': 0},
                                  {'params': self.PPnet.module.prototype_vectors, 'lr': 1e-4, 'weight_decay': 1e-3}
            
                                  ]
        self.policy_optimizer = torch.optim.Adam(policy_optimizer_specs)
        #self.policy_optimizer = torch.optim.Adam(self.PPnet.features.parameters())
        self.critic_optimizer = torch.optim.Adam(self.critic_model.parameters())
        self.num_iteration = 0
        
    def get_heatmaps(self, batch_x, labels, dummy=False, track=False, save_prototypes=[], save_epochs=[]):
        self.PPnet_multi.eval()
        n_prototypes = self.PPnet_multi.module.num_prototypes
        prototype_shape = self.PPnet_multi.module.prototype_shape
        max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]
        protoL_rf_info = self.PPnet_multi.module.proto_layer_rf_info
        
        batch_x = batch_x.cuda()
        protoL_input_torch, proto_dist_torch = self.PPnet_multi.module.push_forward(batch_x)
        
        proto_dist_ = proto_dist_torch.view(proto_dist_torch.shape[0], proto_dist_torch.shape[1], -1)
        distances = torch.amin(proto_dist_, axis=-1)
        #distances = torch.tensor(distances)
        #print("Distances grad: ", distances.grad)
        actions = self.sample_from_distances(distances, labels, track=track)
        proto_dist = torch.clone(proto_dist_torch)
        # Move to cpu and cast to numpy here
        # proto_dist shape: (1000, 80, 7, 7)
        proto_dist = torch.transpose(proto_dist, 0, 1)
        proto_dist = proto_dist.detach().cpu().numpy()
        heatmaps = []
        joint_log_probs = []
        r = []
        patch_idx_batch = []
        for action in actions:
            img_idx, probs, j, class_identity = action[0], action[1], action[2], action[3]
            heatmaps_j = []
            r_j = []
            #patch_idx_prototype
            for i in img_idx:
                # patch idx [0-6, 0-6]
                closest_patch_indices_in_distance_map_j = list(np.unravel_index(np.argmin(proto_dist[j][i],axis=None), proto_dist[j][i].shape))
                closest_patch_indices_in_distance_map_j = [0] + closest_patch_indices_in_distance_map_j
                #print(closest_patch_indices_in_distance_map_j)
                closest_patch_indices_in_img = compute_rf_prototype(batch_x.size(2), closest_patch_indices_in_distance_map_j, protoL_rf_info)
                closest_patch = \
                    batch_x[i, :, closest_patch_indices_in_img[1]:closest_patch_indices_in_img[2], closest_patch_indices_in_img[3]:closest_patch_indices_in_img[4]]
                closest_patch = closest_patch.cpu().numpy()
                closest_patch = np.transpose(closest_patch, (1, 2, 0))

                original_img = batch_x[i].cpu().numpy()
                original_img = np.transpose(original_img, (1, 2, 0))
                if self.PPnet_multi.module.prototype_activation_function == 'log':
                    act_pattern = np.log((proto_dist[j][i] + 1)/(proto_dist[j][i] + self.PPnet_multi.module.epsilon))
                elif self.PPnet_multi.module.prototype_activation_function == 'linear':
                    act_pattern = max_dist - proto_dist[j][i]
                else:
                    act_pattern = prototype_activation_function_in_numpy(proto_dist[j][i])

                patch_indices = closest_patch_indices_in_img[1:5]
                #print(j)
                #if j in [0, 1, 2, 3, 4]:
                #print(patch_indices)
                
                img_size = original_img.shape[0]
                
                # dummy 1: centralize
                #score = img_size * 2 - np.absolute(img_size//2 - patch_indices[0]) - np.absolute(img_size//2 - patch_indices[1]) - np.absolute(img_size//2 - patch_indices[2]) - np.absolute(img_size//2 - patch_indices[3])
                #score = -(patch_indices[1] - patch_indices[0]) * (patch_indices[3] - patch_indices[2])
            
                # dummy 2: maximize area of heatmap
                #score = (patch_indices[1]-patch_indices[0]) * (patch_indices[3] - patch_indices[2])
                
                #print("act_pattern: ", act_pattern.shape)
                upsampled_act_pattern = cv2.resize(act_pattern, dsize=(img_size, img_size), interpolation=cv2.INTER_CUBIC)
                rescaled_act_pattern = upsampled_act_pattern - np.amin(upsampled_act_pattern)
                rescaled_act_pattern = rescaled_act_pattern / np.amax(rescaled_act_pattern)
                
                heatmap = cv2.applyColorMap(np.uint8(255*rescaled_act_pattern), cv2.COLORMAP_JET)
                heatmap = np.float32(heatmap) / 255
                heatmap = heatmap[..., ::-1]
                overlayed_original_img = 0.5 * original_img + 1.0 * heatmap
                overlayed_original_img = overlayed_original_img - np.amin(overlayed_original_img)
                overlayed_original_img = overlayed_original_img / np.amax(overlayed_original_img)
                
                # dummy centralize:
                center_filter = np.zeros((224, 224))
                center_filter[90:134, 90:134] = 1
                
                center_filter[:30, :] = -0.01
                center_filter[:, :30] = -0.01
                center_filter[-30:, :] = -0.01
                center_filter[:, -30:] = -0.01
                
                center_filter = cv2.GaussianBlur(center_filter, (15, 15), 100)
                score = np.sum(rescaled_act_pattern * center_filter)
                
                #print(self.num_epoch, j)
                #if self.num_epoch in save_epochs and j in save_prototypes:
                    
                    #plt.imsave(r'./A3C_results/004_epoch_'+str(self.num_epoch)+'_prototype_'+str(j)+'_best_'+str(len(heatmaps_j)+1)+'.jpg', overlayed_original_img)
                if dummy:
                    heatmaps_j.append(overlayed_original_img)
                    r_j.append(score)
                else:
                    heatmaps_j.append(overlayed_original_img)
            joint_log_prob = torch.prod(probs) * math.factorial(self.k)
            #print(joint_log_prob.grad_fn)
            heatmaps.append(heatmaps_j)
            if dummy:
                r.append(r_j)
            joint_log_probs.append(joint_log_prob)
                
        # num_prototypes * self.k heatmaps in total
        # num_prototypes probs
        #for prob in joint_log_probs:
        #    print(prob.grad_fn)
        if dummy:
            r = np.sum(np.array(r), axis=1)
            r = torch.tensor(r)
            return heatmaps, joint_log_probs, distances, r
        return heatmaps, joint_log_probs, distances
    
    def sample_from_distances(self, distances, labels, track_iters=[], track=False):
        '''
        Takes in distances of shape (80, 1000)
        returns actions of shape (1000, ), one for each prototype
        '''
        
        distances = torch.clip(distances, min=1e-7, max=None)
        similarities = 1 / distances
        #print(similarities)
        softmax_dist = F.log_softmax(similarities, dim=0)
        softmax_dist = torch.transpose(softmax_dist, 0, 1)
        # Maybe using combinatorics?
        actions = []
        # For each of the 1000 prototypes...
        for i in range(softmax_dist.shape[0]):
            class_identity = torch.argmax(self.PPnet_multi.module.prototype_class_identity[i])
            #print(class_identity, class_identity.shape)
            class_dist = softmax_dist[i][labels==class_identity]
            
            #print(class_dist, class_dist.shape)
            if len(class_dist) > self.k:
                #print(class_dist)
                dist = Categorical(class_dist)
                img_idx = dist.sample(sample_shape=torch.tensor([self.k]))
                probs = dist.log_prob(img_idx)
                probs = torch.exp(probs)
                actions.append([img_idx, probs, i, class_identity])
        return actions
    
    def construct_critic(self):
        critic_model = nn.Sequential(
                        nn.Linear(512 * self.k * 7 * 7, 120),
                        nn.Sigmoid(),
                        nn.Linear(120, 20),
                        nn.Sigmoid(),
                        nn.Linear(20, 1)
                        )
        return critic_model
    
    # Currently just the same architecture as the pref_net
    '''
    def critic(self, heatmaps):
        values = torch.empty(len(heatmaps))
        for i in tqdm(range(len(heatmaps))):
            #x = torch.tensor(heatmaps[i])
            x = np.concatenate(heatmaps[i], axis=1)
            x = torch.tensor(x).cuda()
            x = torch.unsqueeze(x, axis=0)
            x = torch.transpose(x, 1, 3)
            with torch.no_grad():
                x = self.pf_model.conv_features(x)
                x = torch.flatten(x, 1) # flatten all dimensions except batch
            x = self.critic_model(x)
            values[i] = x
            #print(i)
        return values
    '''
    
    # Need to vectorize
    def get_critic_inputs(self, heatmaps, dummy=False):
        critic_inputs = []
        for i in range(len(heatmaps)):
            x = np.concatenate(heatmaps[i], axis=1)
            x = torch.tensor(x).cuda()
            x = torch.unsqueeze(x, axis=0)
            x = torch.transpose(x, 1, 3)
            with torch.no_grad():
                x = self.pf_model.conv_features(x)
                x = torch.flatten(x, 1) # flatten all dimensions except batch
            critic_inputs.append(x)
        #print(len(critic_inputs), critic_inputs[0].shape)
        critic_inputs = torch.stack(critic_inputs, dim=0)
        critic_inputs = critic_inputs.view(critic_inputs.shape[0], -1)
        #print(critic_inputs.shape)
        return critic_inputs
        
    def get_rewards(self, heatmaps, dummy=False):
        if dummy:
            h = heatmaps
            rewards = np.empty(len(h))
            for i in range(len(h)):
                score = -np.sum(np.square(h[i][0]-h[i][1])) - np.sum(np.square(h[i][1]-h[i][2])) - np.sum(np.square(h[i][0]-h[i][2]))
            rewards[i] = score
            return torch.tensor(rewards)
        with torch.no_grad():
            rewards = torch.empty(len(heatmaps))
            for i in range(len(heatmaps)):
                pf_input = torch.tensor(np.array(heatmaps[i])).cuda()
                pf_input = pf_input.view(pf_input.shape[0]*pf_input.shape[1], pf_input.shape[2], pf_input.shape[3])
                pf_input = torch.transpose(pf_input, 0, 2)
                pf_input = torch.transpose(pf_input, 1, 2)
                pf_input = torch.unsqueeze(pf_input, axis=0)
                reward = self.pf_model(pf_input)
                rewards[i] = reward
                #print(i)
        return rewards
        
    def update_v1(self, rewards, values, probs):
        self.policy_optimizer.zero_grad()
        for prob in probs:
            prob = prob.cuda()
        rewards = rewards.cuda()
        values = values.cuda()
        policy_loss = 0
        '''
        Customized A2C
        '''
        for i in range(len(rewards)):
            policy_loss -= probs[i] * (rewards[i] - values[i])    
        
        '''
        Reward Filtering
        '''
        #for i in range(len(rewards)):
        #    if rewards[i] > 3000:
        #        policy_loss -= probs[i] * rewards[i]
        #if policy_loss != 0:
        #    policy_loss.backward(retain_graph=True)
        self.policy_optimizer.step()
        self.critic_optimizer.zero_grad()
        #print(list(self.critic_model.parameters())[0])
        critic_loss = 0
        for i in range(len(rewards)):
            critic_loss += (rewards[i] - values[i]) ** 2
        #print(critic_loss.grad_fn)
        critic_loss.backward()
        self.critic_optimizer.step()
        #print(list(self.critic_model.parameters())[0])
        
        return 
    
    def update_v2(self, rewards, values, probs):
        
        return
    
    
    def run(self, batch_x, labels, save_prototypes=[], save_epochs=[], track=False):

        # action is n_prototypes * k heatmaps
        heatmaps, probs, img_distances = self.get_heatmaps(batch_x, labels, track=track)
        
        '''
        Not necessary for reward filtering
        '''
        critic_inputs = self.get_critic_inputs(heatmaps)
        values = self.critic_model(critic_inputs)
        
        #print("Finished calculating values")
        rewards = self.get_rewards(heatmaps)
        
        self.update_v1(rewards, values, probs)
        
        #print("Finished updating. Done.")
        #if len(save_prototypes) > 0 and self.num_epoch in save_epochs:
            #for p in save_prototypes:
                #print(len(heatmaps))
            #    for k in range(len(heatmaps[p])):
            #        plt.imsave(r'./A3C_results/iter_'+str(self.num_iteration)+'_prototype_'+str(p)+'_best_'+str(k+1)+'.jpg', heatmaps[p][k])
            
        self.num_iteration += 1
        if self.num_iteration == 75:
            self.num_iteration = 0
            self.num_epoch += 1
        
        return rewards, values, probs
    
    def run_v2():
        return
    
    def run_dummy(self, batch_x, labels, save_prototypes=[], save_epochs=[], track=False):
        heatmaps, probs, img_distances, rewards  = self.get_heatmaps(batch_x, labels, dummy=True, track=track, save_prototypes=save_prototypes, save_epochs=save_epochs)
        
        critic_inputs = self.get_critic_inputs(heatmaps)
        values = self.critic_model(critic_inputs)
        #values = []
        self.update_v1(rewards, values, probs)
        #if len(save_prototypes) > 0 and self.num_iteration in save_iters:
        #    for p in save_prototypes:
        #        
        #        for k in range(len(heatmaps[p])):
        #            plt.imsave(r'./A3C_results/epoch_'+str(self.num_epoch)+'_prototype_'+str(p)+'_best_'+str(k+1)+'_dummy.jpg', heatmaps[p][k])
            
        self.num_iteration += 1
        if self.num_iteration == 75:
            self.num_iteration = 0
            self.num_epoch += 1
        
        return rewards, values, probs, heatmaps
    
    
'''
Not very useful so far
'''
def visualize_prototypes(ppnet_multi, data_loader, save_prototypes=[], exp_num=0, dummy=True):
    ppnet_multi.eval()
    n_prototypes = ppnet_multi.module.num_prototypes
    prototype_shape = ppnet_multi.module.prototype_shape
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]
    protoL_rf_info = ppnet_multi.module.proto_layer_rf_info

    for i, (batch_x, labels) in tqdm(enumerate(dataloader)):
        batch_x = batch_x.cuda()
        protoL_input_torch, proto_dist_torch = ppnet_multi.module.push_forward(batch_x)

        proto_dist_ = proto_dist_torch.view(proto_dist_torch.shape[0], proto_dist_torch.shape[1], -1)
        distances = torch.amin(proto_dist_, axis=-1)
        distances = torch.transpose(distances, 0, 1)
        similarities = 1 / distances
        # print(distances.shape)
        
        for j in p_idx:
            top_idx = torch.topk(similarities[j], k).indices.cpu()
            class_dist = similarities[j][labels==j]
            
            #print(class_dist, class_dist.shape)
            if len(class_dist) > 3:
                print("Prototype "+str(j)+": ", labels[top_idx])
    

In [4]:
class Critic(nn.Module):
    def __init__(self, k=3, learning_rate=3e-4):
        super(Critic, self).__init__()
        
        '''
        self.critic_linear1 = nn.Linear(num_inputs, hidden_size)
        self.critic_linear2 = nn.Linear(hidden_size, 1)

        self.actor_linear1 = nn.Linear(num_inputs, hidden_size)
        self.actor_linear2 = nn.Linear(hidden_size, num_actions)
        '''
        
        self.k = k
        self.fc1 = nn.Linear(512 * k * 7 * 7, 120)
        self.fc2 = nn.Linear(120, 20)
        self.fc3 = nn.Linear(20, 1)
        
    
    def forward(self, x):

        out = torch.sigmoid(self.fc1(x))
        out = torch.sigmoid(self.fc2(out))
        out = self.fc3(out)

        return out

In [5]:
# rescaled pattern should be (224, 224)
def get_dummy_reward(rescaled_pattern):
    '''
    center
    '''
    '''
    center_filter = np.zeros((224, 224))
    center_filter[90:134, 90:134] = 1

    center_filter[:30, :] = -0.01
    center_filter[:, :30] = -0.01
    center_filter[-30:, :] = -0.01
    center_filter[:, -30:] = -0.01

    center_filter = cv2.GaussianBlur(center_filter, (15, 15), 100)
    score = np.sum(rescaled_pattern * center_filter)
    '''
    
    '''
    upper left corner
    '''
    '''
    corner_filter = np.zeros((224, 224))
    corner_filter[0:40, 0:40] = 1
    corner_filter = cv2.GaussianBlur(corner_filter, (15, 15), 100)
    score = np.sum(rescaled_pattern * corner_filter)
    '''
    
    '''
    bottom right corner
    '''
    corner_filter = np.zeros((224, 224))
    corner_filter[180:224, 180:224] = 1
    corner_filter = cv2.GaussianBlur(corner_filter, (15, 15), 100)
    score = np.sum(rescaled_pattern * corner_filter)
    
    return score

In [6]:
ppnet = torch.load(r'../saved_models/vgg19/004/100_7push0.7344.pth')
ppnet = torch.nn.DataParallel(ppnet)
pf_model = construct_PrefNet("resnet18")
pf_model.load_state_dict(torch.load("./human_comparisons/pref_model_009_65+35_ep50_adam_0.0001"))
#pf_model = torch.load(r'./human_comparisons/pref_model_009_65+35_ep50_adam_0.0001_1')

<All keys matched successfully>

In [7]:
normalize = transforms.Normalize(mean=mean, std=std)

train_dataset = datasets.ImageFolder(
        train_push_dir,
        transforms.Compose([
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ]))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=80, shuffle=False,
    num_workers=2, pin_memory=False)

test_dataset = datasets.ImageFolder(
        test_dir,
        transforms.Compose([
            transforms.Resize(size=(img_size, img_size)),
            transforms.ToTensor(),
            normalize,
        ]))
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False,
    num_workers=2, pin_memory=False)


#shuffled_dataset = Subset(train_dataset, shuffled_indices)

In [8]:
indices = []

'''
Each batch of size 80 consists of 16 shuffled blocks
'''
for i in range(200):
    class_i = [ind for ind, ele in enumerate(train_dataset.targets) if ele == i]
    indices.append(class_i[:5])
    indices.append(class_i[5:10])
    indices.append(class_i[10:15])
    indices.append(class_i[15:20])
    indices.append(class_i[20:25])
    indices.append(class_i[25:])

In [9]:
a3c = A3C_PPnet(ppnet, pf_model)

In [11]:
def reselect_prototypes(a3c, reward_threshold, data_loader, heatmaps):
    # get the heatmaps by searching for the closest images in the entire dataset
    # can use different k values
    # heatmaps: (1000, 1, 224, 224, 3)
    # rewards: (1000,)
    prototype_shape = a3c.PPnet_multi.module.prototype_shape
    bad_prototype_idx = []
    rewards = []
    for i in range(len(heatmaps)):
        reward = get_dummy_reward(heatmaps[i][0])
        rewards.append(reward)
        if reward < reward_threshold:
            bad_prototype_idx.append(i)
    
    global_max_rewards = np.zeros((200, 5))
    global_best_patches = torch.zeros((200, 5, 128, 1, 1)).cuda()
    
    class_comps = np.zeros((200, 5))
    
    for idx, (batch_x, labels) in tqdm(enumerate(data_loader)):
        a3c.PPnet_multi.eval()
        n_prototypes = a3c.PPnet_multi.module.num_prototypes
        prototype_shape = a3c.PPnet_multi.module.prototype_shape
        max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]
        protoL_rf_info = a3c.PPnet_multi.module.proto_layer_rf_info
        
        batch_x = batch_x.cuda()
        # conv_outs: (80, 128, 7, 7)
        conv_outs = a3c.PPnet_multi.module.conv_features(batch_x)

        # loop over each prototype
        for j in bad_prototype_idx:
            class_identity = torch.argmax(a3c.PPnet_multi.module.prototype_class_identity[j])
            class_outs = conv_outs[labels == class_identity]
            if class_outs.shape[0] == 0:
                continue
            height = class_outs.shape[2]
            width = class_outs.shape[3]
            
            for img_idx in range(class_outs.shape[0]):
                img_max_reward = 0
                img_best_patch = a3c.PPnet_multi.module.prototype_vectors.data[j]
                for h in range(height):
                    for w in range(width):
                        # actually old_vec might not be necessary
                        # old_vec = np.copy(a3c.PPnet_multi.module.prototype_vectors.data[j])
                        patch_candidate = class_outs[img_idx, :, h:h+1, w:w+1]
                        a3c.PPnet_multi.module.prototype_vectors.data[j] = patch_candidate
                        #a3c.PPnet_multi.module.prototype_vectors.data.copy_(vec)
                        # distances: (7, 7)
                        distances = a3c.PPnet_multi.module._l2_convolution(class_outs[img_idx])[j].detach().cpu().numpy()
                        
                        if a3c.PPnet_multi.module.prototype_activation_function == 'log':
                            act_pattern = np.log((distances + 1)/(distances + a3c.PPnet_multi.module.epsilon))
                        elif a3c.PPnet_multi.module.prototype_activation_function == 'linear':
                            act_pattern = max_dist - distances
                        else:
                            act_pattern = prototype_activation_function_in_numpy(distances)
                            
                        upsampled_act_pattern = cv2.resize(act_pattern, dsize=(img_size, img_size), interpolation=cv2.INTER_CUBIC)
                        rescaled_act_pattern = upsampled_act_pattern - np.amin(upsampled_act_pattern)
                        rescaled_act_pattern = rescaled_act_pattern / np.amax(rescaled_act_pattern)
                        
                        patch_reward = get_dummy_reward(rescaled_act_pattern)
                        
                        if patch_reward > img_max_reward:
                            img_max_reward = patch_reward
                            img_best_patch = patch_candidate
                        #else:
                        #    a3c.PPnet_multi.module.prototype_vectors.data[j] = old_vec
                            
                
                min_index = int(class_comps[class_identity][-1])
                if img_max_reward > global_max_rewards[class_identity][min_index]:
                    global_max_rewards[class_identity][min_index] = img_max_reward
                    global_best_patches[class_identity][min_index] = img_best_patch
                    class_comps[class_identity] = np.flip(np.argsort(global_max_rewards[class_identity]))
                   
    #print(class_comps)
    #print(global_max_rewards)
    for i in bad_prototype_idx:
        class_num = int(i // 5)
        p_num = int(class_comps[class_num][0])
        class_comps[class_num] = np.roll(class_comps[class_num], -1)
        a3c.PPnet_multi.module.prototype_vectors.data[i] = global_best_patches[class_num][p_num]
        
    return global_max_rewards
                        
            
            

In [13]:
heatmaps = find_k_nearest_patches_to_prototypes(train_loader, a3c.PPnet_multi, k=1)

find nearest patches
batch 0
batch 1
batch 2
batch 3
batch 4
batch 5
batch 6
batch 7
batch 8
batch 9
batch 10
batch 11
batch 12
batch 13
batch 14
batch 15
batch 16
batch 17
batch 18
batch 19
batch 20
batch 21
batch 22
batch 23
batch 24
batch 25
batch 26
batch 27
batch 28
batch 29
batch 30
batch 31
batch 32
batch 33
batch 34
batch 35
batch 36
batch 37
batch 38
batch 39
batch 40
batch 41
batch 42
batch 43
batch 44
batch 45
batch 46
batch 47
batch 48
batch 49
batch 50
batch 51
batch 52
batch 53
batch 54
batch 55
batch 56
batch 57
batch 58
batch 59
batch 60
batch 61
batch 62
batch 63
batch 64
batch 65
batch 66
batch 67
batch 68
batch 69
batch 70
batch 71
batch 72
batch 73
batch 74
	find nearest patches time: 	374.49913930892944


In [14]:
global_max_rewards = reselect_prototypes(a3c, 300, train_loader, heatmaps)

75it [35:53, 28.71s/it]


In [16]:
torch.save(a3c.PPnet_multi, r'./A3C_results/reselection_dummy_a3c_corner_bottom_right.pth')

In [17]:
#heatmaps.shape

In [9]:
def update_prototypes(a3c):
    return

In [4]:
class ImagePatch:

    def __init__(self, patch, label, distance,
                 original_img=None, act_pattern=None, patch_indices=None):
        self.patch = patch
        self.label = label
        self.negative_distance = -distance

        self.original_img = original_img
        self.act_pattern = act_pattern
        self.patch_indices = patch_indices

    def __lt__(self, other):
        return self.negative_distance < other.negative_distance


class ImagePatchInfo:

    def __init__(self, label, distance):
        self.label = label
        self.negative_distance = -distance

    def __lt__(self, other):
        return self.negative_distance < other.negative_distance

def find_k_nearest_patches_to_prototypes(dataloader, # pytorch dataloader (must be unnormalized in [0,1])
                                         prototype_network_parallel, # pytorch network with prototype_vectors
                                         k=3,
                                         preprocess_input_function=None, # normalize if needed
                                         full_save=False, # save all the images
                                         root_dir_for_saving_images='./nearest',
                                         log=print,
                                         prototype_activation_function_in_numpy=None, heatmap_ratio = 1.0):
    prototype_network_parallel.eval()
    '''
    full_save=False will only return the class identity of the closest
    patches, but it will not save anything.
    '''
    log('find nearest patches')
    start = time.time()
    n_prototypes = prototype_network_parallel.module.num_prototypes
    
    prototype_shape = prototype_network_parallel.module.prototype_shape
    max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3]

    protoL_rf_info = prototype_network_parallel.module.proto_layer_rf_info

    heaps = []
    # allocate an array of n_prototypes number of heaps
    for _ in range(n_prototypes):
        # a heap in python is just a maintained list
        heaps.append([])

    for idx, (search_batch_input, search_y) in tqdm(enumerate(dataloader)):
        #print('batch {}'.format(idx))
        if preprocess_input_function is not None:
            # print('preprocessing input for pushing ...')
            # search_batch = copy.deepcopy(search_batch_input)
            search_batch = preprocess_input_function(search_batch_input)

        else:
            search_batch = search_batch_input

        with torch.no_grad():
            search_batch = search_batch.cuda()
            protoL_input_torch, proto_dist_torch = \
                prototype_network_parallel.module.push_forward(search_batch)

        #protoL_input_ = np.copy(protoL_input_torch.detach().cpu().numpy())
        proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy())
        
        # proto_dist_: (80, 1000, 7, 7)

        for img_idx, distance_map in enumerate(proto_dist_):
            for j in range(n_prototypes):
                # find the closest patches in this batch to prototype j

                closest_patch_distance_to_prototype_j = np.amin(distance_map[j])



                closest_patch_indices_in_distance_map_j = \
                    list(np.unravel_index(np.argmin(distance_map[j],axis=None),
                                          distance_map[j].shape))
                closest_patch_indices_in_distance_map_j = [0] + closest_patch_indices_in_distance_map_j
                closest_patch_indices_in_img = \
                    compute_rf_prototype(search_batch.size(2),
                                         closest_patch_indices_in_distance_map_j,
                                         protoL_rf_info)
                closest_patch = \
                    search_batch_input[img_idx, :,
                                       closest_patch_indices_in_img[1]:closest_patch_indices_in_img[2],
                                       closest_patch_indices_in_img[3]:closest_patch_indices_in_img[4]]
                closest_patch = closest_patch.numpy()
                closest_patch = np.transpose(closest_patch, (1, 2, 0))

                original_img = search_batch_input[img_idx].numpy()
                original_img = np.transpose(original_img, (1, 2, 0))

                if prototype_network_parallel.module.prototype_activation_function == 'log':
                    act_pattern = np.log((distance_map[j] + 1)/(distance_map[j] + prototype_network_parallel.module.epsilon))
                elif prototype_network_parallel.module.prototype_activation_function == 'linear':
                    act_pattern = max_dist - distance_map[j]
                else:
                    act_pattern = prototype_activation_function_in_numpy(distance_map[j])

                # 4 numbers: height_start, height_end, width_start, width_end
                patch_indices = closest_patch_indices_in_img[1:5]

                # construct the closest patch object
                closest_patch = ImagePatch(patch=closest_patch,
                                           label=search_y[img_idx],
                                           distance=closest_patch_distance_to_prototype_j,
                                           original_img=original_img,
                                           act_pattern=act_pattern,
                                           patch_indices=patch_indices)
                '''
                else:
                
                closest_patch = ImagePatchInfo(label=search_y[img_idx],
                                                   distance=closest_patch_distance_to_prototype_j)
                '''

                # add to the j-th heap 
                if len(heaps[j]) < k:
                    heapq.heappush(heaps[j], closest_patch)
                else:
                    # heappushpop runs more efficiently than heappush
                    # followed by heappop
                    heapq.heappushpop(heaps[j], closest_patch)
                    

    # after looping through the dataset every heap will
    # have the k closest prototypes
    heatmaps = []
    for j in range(n_prototypes):
        # finally sort the heap; the heap only contains the k closest
        # but they are not ranked yet
        heaps[j].sort()
        heaps[j] = heaps[j][::-1]

        
        heatmaps_j = []
        for i, patch in enumerate(heaps[j]):
            
            img_size = patch.original_img.shape[0]
            upsampled_act_pattern = cv2.resize(patch.act_pattern,
                                               dsize=(img_size, img_size),
                                               interpolation=cv2.INTER_CUBIC)
            rescaled_act_pattern = upsampled_act_pattern - np.amin(upsampled_act_pattern)
            rescaled_act_pattern = rescaled_act_pattern / np.amax(rescaled_act_pattern)
            
            # No need for these if using dummy reward model
            '''
            heatmap = cv2.applyColorMap(np.uint8(255*rescaled_act_pattern), cv2.COLORMAP_JET)
            heatmap = np.float32(heatmap) / 255
            heatmap = heatmap[...,::-1]

            overlayed_original_img = 0.5 * patch.original_img + heatmap_ratio * heatmap
            overlayed_original_img = overlayed_original_img - np.amin(overlayed_original_img)
            overlayed_original_img = overlayed_original_img / np.amax(overlayed_original_img)
            '''
            heatmaps_j.append(rescaled_act_pattern)
        heatmaps.append(heatmaps_j)
    end = time.time()
    log('\tfind nearest patches time: \t{0}'.format(end - start))

    return heatmaps

In [None]:
def get_dummy_reward(rescaled_pattern):
    '''
    center
    '''
    '''
    center_filter = np.zeros((224, 224))
    center_filter[90:134, 90:134] = 1

    center_filter[:30, :] = -0.01
    center_filter[:, :30] = -0.01
    center_filter[-30:, :] = -0.01
    center_filter[:, -30:] = -0.01

    center_filter = cv2.GaussianBlur(center_filter, (15, 15), 100)
    score = np.sum(rescaled_pattern * center_filter)
    '''
    
    '''
    upper left corner
    '''
    '''
    corner_filter = np.zeros((224, 224))
    corner_filter[0:40, 0:40] = 1
    corner_filter = cv2.GaussianBlur(corner_filter, (15, 15), 100)
    score = np.sum(rescaled_pattern * corner_filter)
    '''
    
    '''
    bottom right corner
    '''
    corner_filter = np.zeros((224, 224))
    corner_filter[180:224, 180:224] = 1
    corner_filter = cv2.GaussianBlur(corner_filter, (15, 15), 100)
    score = np.sum(rescaled_pattern * corner_filter)
    
    return score

In [None]:
avg_reward_004 = []
reselection_epochs = [19, 59]

In [None]:
a3c = A3C_PPnet(ppnet, pf_model)

In [None]:
'''
Actual human feedback data & prototype reselection
'''
for epoch in range(100):
    order = np.random.permutation(1200)
    shuffled_idx = []
    for idx in order:
        shuffled_idx += indices[idx]
    #print(shuffled_idx)
    shuffled_dataset = Subset(train_dataset, shuffled_idx)
    
    dataloader = torch.utils.data.DataLoader(
    shuffled_dataset, batch_size=80, shuffle=False,
    num_workers=2, pin_memory=False)
    
    epoch_reward = 0
    
    for i, (batch, labels) in tqdm(enumerate(dataloader)):
        rewards, values, probs, heatmaps = a3c.run(batch, labels, save_prototypes=[], save_epochs=[])
        
        total_reward = 0
        mse_loss = 0
        for j in range(len(probs)):
            
            probs[j].detach().cpu().numpy()
            rewards[j].detach().cpu().numpy()
            total_reward += probs[j] * rewards[j]
            #mse_loss += (rewards[j] - values[j]) ** 2
            
            #probs[j].detach().cpu().numpy()
            #rewards[j].detach().cpu().numpy()
        epoch_reward += total_reward.item()
    
    if epoch in reselection_epochs:
        heatmaps = find_k_nearest_patches_to_prototypes(train_loader, a3c.PPnet_multi, k=3)
        
        
    
    print("Epoch "+str(a3c.num_epoch)+" "+str(i)+" average reward: ", epoch_reward/i)
    avg_reward_004.append(epoch_reward/i)