In [None]:
# GPU config: enumerate CUDA devices in PCI bus order and expose only device 0
# to this process (must run before torch initialises CUDA).
import os
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")

In [None]:
import collections
import pickle
import time
import numpy as np
from scipy.spatial.distance import pdist, squareform
from tqdm import tqdm

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import networkx as nx

import torch
from torch.utils.data import DataLoader

# from dataloader.QuickdrawDataset4dict import *
from dataloader.QuickdrawDataset4dict_2nn4nnjnn import *

from utils.AverageMeter import AverageMeter
from utils.accuracy import *

In [None]:
# DataLoader settings used for batched evaluation.
batch_size, num_workers = 200, 16
# exp = "gt_2nn"
# epoch = 10

# from network.graph_transformer import *
# from network.graph_mlp import *
# from network.graph_convnet import *
# from network.graph_attention_net import *
from network.final.graph_transformer import *

In [None]:
# Load the tiny QuickDraw test split; the unpickled dict is handed to QuickdrawDataset.
# NOTE: pickle.load can execute arbitrary code -- only load pickle files you created yourself.
# The data root can be overridden with the QUICKDRAW_TEST_DIR environment variable
# instead of editing the hard-coded default below.
QUICKDRAW_TEST_DIR = os.environ.get("QUICKDRAW_TEST_DIR", "/home/peng/dataset/tiny_quickdraw_coordinate/test/")
with open('./dataloader/tiny_test_dataset_dict.pickle', 'rb') as fh:  # closes the handle
    data_dict = pickle.load(fh)
dataset = QuickdrawDataset(QUICKDRAW_TEST_DIR, "./dataloader/tiny_test_set.txt", data_dict)

In [None]:
# Architecture hyper-parameters. They must match the checkpoint loaded below,
# otherwise load_state_dict (strict by default) raises on missing/mismatched keys.
network_configs = collections.OrderedDict()
network_configs['output_dim'] = 345
network_configs['n_heads'] = 8
network_configs['embed_dim'] = 256
network_configs['n_layers'] = 4
network_configs['feed_forward_hidden'] = 4 * network_configs['embed_dim']
network_configs['normalization'] = 'batch'
network_configs['dropout'] = 0.1

# n_classes is taken from the config so the two cannot drift apart.
# feat_dict_size=103 presumably covers the flag ids 100/101/102 -- TODO confirm against the dataloader.
net = make_model(n_classes=network_configs['output_dim'], coord_input_dim=2, feat_input_dim=2, feat_dict_size=103,
                 n_layers=network_configs['n_layers'], n_heads=network_configs['n_heads'],
                 embed_dim=network_configs['embed_dim'], feedforward_dim=network_configs['feed_forward_hidden'],
                 normalization=network_configs['normalization'], dropout=network_configs['dropout'])
net = net.cuda()
# net.load_state_dict(torch.load(f"./experimental_results/{exp}/checkpoints/{exp}_net_epoch{epoch}")["network"])
# map_location="cpu" makes loading independent of the CUDA device the checkpoint was saved on;
# load_state_dict then copies the weights into the (already CUDA-resident) parameters of `net`.
net.load_state_dict(torch.load("final_checkpoint", map_location="cpu")["network"])

In [None]:
def validate_function(dataloader, model=None):
    """Run a model over every batch of ``dataloader`` and report loss and top-k accuracy.

    Args:
        dataloader: Iterable yielding 9-tuples of tensors ``(coordinate, label, flag_bits,
            stroke_len, attention_mask1, attention_mask2, attention_mask3, padding_mask,
            position_encoding)``, as produced by the QuickdrawDataset used in this notebook.
        model: Network to evaluate. Defaults to the notebook-global ``net`` (the
            original behaviour); pass it explicitly to avoid relying on hidden state.

    Returns:
        Tuple ``(loss, acc@1, acc@5, acc@10)`` of ``AverageMeter`` objects, each
        updated with the batch size as weight.

    Note:
        Puts ``model`` into eval mode and leaves it there.
    """
    if model is None:
        model = net

    # Built once rather than per batch. Referenced through torch.nn because the bare
    # name `nn` was only in scope via a wildcard import.
    criterion = torch.nn.CrossEntropyLoss()

    validation_loss = AverageMeter()
    validation_acc_1 = AverageMeter()
    validation_acc_5 = AverageMeter()
    validation_acc_10 = AverageMeter()

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, ascii=True):
            (coordinate, label, flag_bits, stroke_len,
             attention_mask1, attention_mask2, attention_mask3,
             padding_mask, position_encoding) = (t.cuda() for t in batch)

            # Resize inputs in place to the shapes the network's forward expects.
            flag_bits.squeeze_(2)
            position_encoding.squeeze_(2)
            stroke_len.unsqueeze_(1)

            output = model(coordinate, flag_bits, position_encoding,
                           attention_mask1, attention_mask2, attention_mask3,
                           padding_mask, stroke_len)

            n = coordinate.size(0)
            validation_loss.update(criterion(output, label).item(), n)

            acc_1, acc_5, acc_10 = accuracy(output, label, topk=(1, 5, 10))
            validation_acc_1.update(acc_1, n)
            validation_acc_5.update(acc_5, n)
            validation_acc_10.update(acc_10, n)

    print("loss: {}  acc@1:{}  acc@5:{}  acc@10:{}".format(
        validation_loss.avg, validation_acc_1.avg, validation_acc_5.avg, validation_acc_10.avg))

    return validation_loss, validation_acc_1, validation_acc_5, validation_acc_10

In [None]:
# # Sanity check for loading checkpoint: test accuracy should match reported performance
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# validate_function(dataloader)

In [None]:
def plot_graph(p, coord, W, title=None):
    """Draw the graph encoded by adjacency matrix ``W`` over the points ``coord``.

    Args:
        p: Matplotlib Axes (subplot) to draw on. If ``p`` has no Axes methods
            (e.g. a Figure), drawing falls back to the current axes as before.
        coord: Array of node positions, one row per node (presumably (x, y)).
        W: Adjacency matrix (where 0 -> connected; -1e10 -> not connected).
            Every (r, c) with ``W[r][c] == 0`` is drawn as an edge.
        title: Title of figure/subplot

    Returns:
        p: Updated figure/subplot
    """
    ax = p if hasattr(p, "invert_yaxis") else plt.gca()

    # Distance-weighted complete graph; only its node set is used when drawing.
    # nx.from_numpy_array replaces nx.from_numpy_matrix, which NetworkX 3.0 removed.
    G = nx.from_numpy_array(squareform(pdist(coord, metric='euclidean')))
    pos = dict(zip(range(len(coord)), coord.tolist()))

    # Row-major list of connected (r, c) pairs, restricted to the leading len(W) x len(W) block.
    n = len(W)
    edgelist = [(int(r), int(c)) for r, c in np.argwhere(np.asarray(W)[:n, :n] == 0)]

    nx.draw_networkx_nodes(G, pos, node_color='black', node_size=10, ax=ax)
    nx.draw_networkx_edges(G, pos, edgelist, alpha=1, width=1, edge_color='grey', ax=ax)

    if title is not None:
        ax.set_title(title)

    # Flip the y-axis so the sketch appears upright (image-style coordinates).
    ax.invert_yaxis()
    return p

def plot_heatmap(p, coord, W, W_pred, title=None, symmetrize=True):
    """Draw the edges of ``W`` coloured by the matching entries of ``W_pred``.

    Args:
        p: Matplotlib Axes (subplot) to draw on. If ``p`` has no Axes methods
            (e.g. a Figure), drawing falls back to the current axes as before.
        coord: Array of node positions, one row per node (presumably (x, y)).
        W: Adjacency matrix (where 0 -> connected; -1e10 -> not connected)
        W_pred: Square edge prediction/attention matrix; its size sets how much of
            ``W`` is inspected.
        title: Title of figure/subplot
        symmetrize: If True (default), colour edges by ``(W_pred + W_pred.T) / 2``;
            if False, use the raw ``W_pred[r, c]`` values.

    Returns:
        p: Updated figure/subplot
    """
    W_pred = np.asarray(W_pred)
    if symmetrize:
        # Convert to symmetric matrix
        W_pred = (W_pred + W_pred.T) * 0.5

    ax = p if hasattr(p, "invert_yaxis") else plt.gca()

    # Distance-weighted complete graph; only its node set is used when drawing.
    # nx.from_numpy_array replaces nx.from_numpy_matrix, which NetworkX 3.0 removed.
    G = nx.from_numpy_array(squareform(pdist(coord, metric='euclidean')))
    pos = dict(zip(range(len(coord)), coord.tolist()))

    # Connected pairs and their predicted weights, both in row-major order.
    n = len(W_pred)
    connected = np.asarray(W)[:n, :n] == 0
    edgelist = [(int(r), int(c)) for r, c in np.argwhere(connected)]
    edgepred = W_pred[connected].tolist()

    nx.draw_networkx_nodes(G, pos, node_color='black', node_size=10, ax=ax)
    nx.draw_networkx_edges(G, pos, edgelist, edge_color=edgepred, edge_cmap=plt.cm.Reds, width=1, ax=ax)

    if title is not None:
        ax.set_title(title)

    # Flip the y-axis so the sketch appears upright (image-style coordinates).
    ax.invert_yaxis()
    return p

def plot_heatmap_old(p, coord, W, W_pred, title=None):
    """Like ``plot_heatmap`` but colours each edge by the raw, unsymmetrised ``W_pred[r, c]``.

    Args:
        p: Matplotlib figure/subplot
        coord: Coordinates of nodes
        W: Adjacency matrix (where 0 -> connected; -1e10 -> not connected)
        W_pred: Edge prediction/attention matrix
        title: Title of figure/subplot

    Returns:
        p: Updated figure/subplot
    """
    return plot_heatmap(p, coord, W, W_pred, title=title, symmetrize=False)

In [None]:
# # Manually plot for individual array
# name = 'cat_3'

# def get_stroke_len(arr):
#     for i in range(len(arr)-1,-1,-1):
#         if ((arr[i] == np.array([0, 0, 0, 0])).all()):
#             return i
#     return 100

# def get_flags(input_array):
#     out_array = np.zeros([100, 1], int)
#     assert input_array.shape == (100, 2)
#     for idx, bits in enumerate(input_array):
#         if ((bits == [1, 0]).all()):
#             out_array[idx] = 100
#         elif ((bits == [0, 1]).all()):
#             out_array[idx] = 101
#         else:
#             out_array[idx] = 102
#     return out_array

# arr = np.load(f'images/load/{name}.npy')
# stroke_len = get_stroke_len(arr)
# coord = arr[:, :2]
# flags = get_flags(arr[:, 2:])
# Ws = [np.ones([100, 100], int)*-1e10,
#      produce_adjacent_matrix_2_neighbors(flags, stroke_len), 
#      produce_adjacent_matrix_4_neighbors(flags, stroke_len),
#      produce_adjacent_matrix_joint_neighbors(flags, stroke_len)]

# f_idx = 0
# f = plt.figure(f_idx, figsize=(len(Ws)*5, 5))
# f.set_tight_layout(True)
# for idx, mask in enumerate(Ws):
#     ax = f.add_subplot(101 + len(Ws)*10 + idx)
#     plot_graph(ax, 
#                coord=coord[:stroke_len], 
#                W=mask[:stroke_len, :stroke_len],
#                title=f"Graph {idx+1}")
# plt.savefig(f"images/png/{name}-graphs.png", format='png', dpi=600, bbox_inches="tight")
# plt.show()
# f_idx += 1

In [None]:
# Pick a random test sample of TARGET_LABEL, show its three input graphs and, after
# interactive confirmation, save per-layer/per-head attention heatmaps for each graph.
TARGET_LABEL = 67
CLASS_NAMES = {
    5: 'clock',
    76: 'clock',
    67: 'cat',
    94: 'dog',
    98: 'dragon',
    26: 'bear',
    307: 'teddy',
    312: 'tiger',
    34: 'bird',
}
OUTPUT_DIR = "images/final"

dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=num_workers)

plt.rcParams['font.family'] = 'serif'
# Put Times New Roman first exactly once, so re-running this cell does not keep
# prepending duplicates to font.serif.
plt.rcParams['font.serif'] = ['Times New Roman'] + [
    font for font in plt.rcParams['font.serif'] if font != 'Times New Roman']


def _plot_attention_grids(masks, attentions, coord, n_points, plot_fn, suffix, name):
    """Save and show one (n_layers x n_heads) grid of ``plot_fn`` heatmaps per input graph.

    ``attentions[g][layer][head][0]`` is used as the attention matrix of graph ``g``
    (assumes each cached ``attn`` is indexed [head][batch] -- TODO confirm).
    """
    n_layers = network_configs['n_layers']
    n_heads = network_configs['n_heads']
    for g, (mask, attention) in enumerate(zip(masks, attentions)):
        print(f"Graph {g+1}")
        # Unnumbered figure so the per-graph figures never collide with each other.
        fig = plt.figure(figsize=(n_heads * 5, n_layers * 5))
        fig.set_tight_layout(True)
        for layer in range(n_layers):
            for head in range(n_heads):
                ax = fig.add_subplot(n_layers, n_heads, layer * n_heads + head + 1)
                plot_fn(ax,
                        coord=coord,
                        W=mask[0][:n_points, :n_points],
                        W_pred=attention[layer][head][0][:n_points, :n_points],
                        title=None)
        fig.savefig(f"{OUTPUT_DIR}/{name}-g{g+1}-{suffix}.pdf", format='pdf', dpi=1200, bbox_inches="tight")
        plt.show()
        plt.close(fig)  # release memory held by these large figures


net.eval()
with torch.no_grad():
    for batch in dataloader:
        if int(batch[1]) != TARGET_LABEL:
            continue
        name = CLASS_NAMES.get(TARGET_LABEL)

        (coordinate, label, flag_bits, stroke_len,
         attention_mask1, attention_mask2, attention_mask3,
         padding_mask, position_encoding) = (t.cuda() for t in batch)

        # Resize inputs in place to the shapes the network's forward expects.
        flag_bits.squeeze_(2)
        position_encoding.squeeze_(2)
        stroke_len.unsqueeze_(1)

        output = net(coordinate, flag_bits, position_encoding,
                     attention_mask1, attention_mask2, attention_mask3,
                     padding_mask, stroke_len)

        acc_1, acc_5, acc_10 = accuracy(output, label, topk=(1, 5, 10))
        print(f"Groundtruth: {int(label)}, Prediction: {int(output.argmax())}")
        print(f"Accuracy @1: {int(acc_1)}, @5: {int(acc_5)}, @10: {int(acc_10)}")

        # Attention weights read from each transformer layer's attention modules after the
        # forward pass above (presumably cached there during forward), one entry per layer.
        layers = net.encoder.transformer_layers
        attention_1 = [layer.self_attention1.module.attn.cpu().numpy() for layer in layers]
        attention_2 = [layer.self_attention2.module.attn.cpu().numpy() for layer in layers]
        attention_3 = [layer.self_attention3.module.attn.cpu().numpy() for layer in layers]

        # Convert back to numpy format for plotting
        coordinate = coordinate.cpu().numpy()
        label = label.cpu().numpy()
        flag_bits = flag_bits.cpu().numpy()
        attention_mask1 = attention_mask1.cpu().numpy()
        attention_mask2 = attention_mask2.cpu().numpy()
        attention_mask3 = attention_mask3.cpu().numpy()
        padding_mask = padding_mask.cpu().numpy()
        position_encoding = position_encoding.cpu().numpy()
        # Take a true scalar: int() on an ndim > 0 array is deprecated since NumPy 1.25.
        stroke_len = int(stroke_len.cpu().numpy().reshape(-1)[0])

        masks = [attention_mask1, attention_mask2, attention_mask3]
        attentions = [attention_1, attention_2, attention_3]
        coord = coordinate[0][:stroke_len]

        # The three input graphs side by side.
        fig = plt.figure(figsize=(len(masks) * 5, 5))
        fig.set_tight_layout(True)
        for g, mask in enumerate(masks):
            ax = fig.add_subplot(1, len(masks), g + 1)
            plot_graph(ax,
                       coord=coord,
                       W=mask[0][:stroke_len, :stroke_len],
                       title=f"Graph {g+1}")
        plt.show()
        plt.close(fig)

        if input("Plot attention heatmaps for this sample? [y/N] ").strip().lower() == "y":
            os.makedirs(OUTPUT_DIR, exist_ok=True)  # savefig fails if the directory is missing
            _plot_attention_grids(masks, attentions, coord, stroke_len, plot_heatmap_old, "old", name)
            _plot_attention_grids(masks, attentions, coord, stroke_len, plot_heatmap, "new", name)
            break
    else:
        print(f"No sample with label {TARGET_LABEL} was confirmed for plotting.")

In [None]:
# name = {
#     5: 'clock',
#     76: 'clock', 
#     67: 'cat', 
#     94: 'dog', 
#     98: 'dragon2',
#     26: 'bear',
#     307: 'bear',
#     312: 'tiger3',
#     34: 'bird'
# }.get(int(label))

# # Saving figures      
# f_idx = 0
# f = plt.figure(f_idx, figsize=(3*5, 5))
# f.set_tight_layout(True)
# for idx, mask in enumerate([attention_mask1, attention_mask2, attention_mask3]):
#     ax = f.add_subplot(131 + idx)
#     plot_graph(ax, 
#                coord=coordinate[0][:stroke_len], 
#                W=mask[0][:stroke_len, :stroke_len],
#                title=f"Graph {idx+1}")
# plt.savefig(f"images/{name}-graphs.pdf", format='pdf', dpi=1200, bbox_inches="tight")
# plt.savefig(f"images/png/{name}-graphs.png", format='png', dpi=300, bbox_inches="tight")
# plt.show()
# f_idx += 1


# # f = plt.figure(f_idx, figsize=(8*5, 4*3*5))
# # f.set_tight_layout(True)
# # subf_idx = 1
# # for idx, (mask, attention) in enumerate(zip([attention_mask1, attention_mask2, attention_mask3], [attention_1, attention_2, attention_3])):
# #     for layer in range(network_configs['n_layers']):
# #         for head in range(network_configs['n_heads']):
# #             ax = f.add_subplot(12, 8, subf_idx)
# #             plot_heatmap(ax,
# #                          coord=coordinate[0][:stroke_len], 
# #                          W=mask[0][:stroke_len, :stroke_len],
# #                          W_pred=attention[layer][head][0][:stroke_len, :stroke_len],
# #                          title=f"Graph {idx+1}, Layer {layer+1}, Head {head+1}")
# #             subf_idx += 1
# # plt.savefig(f"images/{name}_vert.pdf", format='pdf', dpi=1200, bbox_inches="tight")
# # plt.savefig(f"images/png/{name}_vert.png", format='png', dpi=300, bbox_inches="tight")
# # plt.show()
# # f_idx += 1


# f = plt.figure(f_idx, figsize=(3*4*5, 8*5))
# f.set_tight_layout(True)
# subf_idx = 1

# for head in range(network_configs['n_heads']):
#     for idx, (mask, attention) in enumerate(zip([attention_mask1, attention_mask2, attention_mask3], [attention_1, attention_2, attention_3])):
#         for layer in range(network_configs['n_layers']):
#             ax = f.add_subplot(8, 12, subf_idx)
#             plot_heatmap(ax,
#                          coord=coordinate[0][:stroke_len], 
#                          W=mask[0][:stroke_len, :stroke_len],
#                          W_pred=attention[layer][head][0][:stroke_len, :stroke_len],
#                          title=f"Graph {idx+1}, Layer {layer+1}, Head {head+1}")
#             subf_idx += 1
# plt.savefig(f"images/{name}_horz.pdf", format='pdf', dpi=1200, bbox_inches="tight")
# plt.savefig(f"images/png/{name}_horz.png", format='png', dpi=300, bbox_inches="tight")
# plt.show()
# f_idx += 1

In [None]:
# # Make GIFs
# def update(layer):
#     subf_idx = 1
#     for idx, (mask, attention) in enumerate(zip([attention_mask1, attention_mask2, attention_mask3], [attention_1, attention_2, attention_3])):
    
#         for head in range(network_configs['n_heads']):
#             ax = f.add_subplot(3, 8, subf_idx)
#             plot_heatmap(ax,
#                          coord=coordinate[0][:stroke_len], 
#                          W=mask[0][:stroke_len, :stroke_len],
#                          W_pred=attention[layer][head][0][:stroke_len, :stroke_len],
#                          title=f"Mask {idx+1}, Layer {layer+1}, Head {head+1}")
#             subf_idx += 1
            
# f = plt.figure(0, figsize=(8*5, 3*5))
# f.set_tight_layout(True)
# anim = FuncAnimation(f, update, frames=np.arange(0, network_configs['n_layers']), interval=1000)
# anim.save(f'images/gif/{name}.gif', dpi=300, writer='imagemagick')