In [None]:
import pickle
import numpy as np
from scipy.spatial.distance import pdist, squareform

import matplotlib
import matplotlib.pyplot as plt
import networkx as nx

from dataloader.QuickdrawDataset4dict_2nn4nnjnn import *

In [None]:
def plot_graph(p, coord, W, title=None, invert=False):
    """Draw a graph whose nodes sit at `coord` and whose edges come from `W`.

    Args:
        p: Matplotlib Axes (subplot) to draw into. If an object without Axes
            methods is passed (e.g. a Figure), the current Axes is used.
        coord: Array-like of node coordinates, one (x, y) row per node.
        W: Square adjacency mask (0 -> connected; any other value, e.g.
            -1e10 -> not connected). Entry (r, c) == 0 draws an edge r-c.
        title: Optional title of the subplot.
        invert: If True, invert the y-axis (QuickDraw coordinates are
            upside-down).

    Returns:
        p: The same object that was passed in, after drawing.
    """
    coord = np.asarray(coord)
    # Draw on the Axes we were given rather than whatever plt.gca() happens
    # to be; fall back to the current Axes for non-Axes inputs.
    ax = p if hasattr(p, 'invert_yaxis') else plt.gca()

    # networkx only needs the node set here, because edges are taken from W
    # below. (The previous distance-matrix construction used
    # nx.from_numpy_matrix, which was removed in networkx 3.0.)
    G = nx.Graph()
    G.add_nodes_from(range(len(coord)))
    # node_id -> (x, y), used to position both nodes and edges.
    pos = dict(enumerate(coord.tolist()))
    # Every (row, col) with W == 0, in row-major order (same order as a
    # nested row/column loop would produce).
    edgelist = [tuple(rc) for rc in np.argwhere(np.asarray(W) == 0).tolist()]

    # Draw graph nodes, then the connected edges.
    nx.draw_networkx_nodes(G, pos, ax=ax, node_color='black', node_size=10)
    nx.draw_networkx_edges(G, pos, edgelist=edgelist, ax=ax,
                           alpha=1, width=1, edge_color='grey')

    if title is not None:
        ax.set_title(title)

    # Invert the y-axis of *this* subplot (QuickDraw coordinates are upside-down).
    if invert:
        ax.invert_yaxis()

    return p

In [None]:
# Plot the candidate graphs for one saved sketch array, chosen by file stem
# (the array is read from images/load/<name>.npy).
name = 'cat_3'

def get_stroke_len(arr):
    """Return the number of sketch points in `arr` before its trailing padding.

    Rows that are entirely zero are treated as padding. The result is one
    past the index of the last row containing any non-zero value. An array
    with no padding rows therefore yields ``len(arr)`` (100 for the arrays
    used here), and an all-zero array yields 0.

    Args:
        arr: 2-D array of shape (n_points, n_features). Assumed to be padded
            with all-zero rows at the end (TODO confirm against saved data).

    Returns:
        int: number of leading rows that belong to the sketch.
    """
    rows = np.asarray(arr)
    # The old backward scan returned the index of the *last* zero row, which
    # for end-padded data is always len(arr) - 1 instead of the sketch length.
    # True for each row with at least one non-zero entry (NaN compares != 0).
    has_data = (rows != 0).any(axis=1)
    data_idx = np.flatnonzero(has_data)
    if data_idx.size == 0:
        return 0
    return int(data_idx[-1]) + 1

def get_flags(input_array):
    """Collapse two-column flag bits into one integer code per point.

    Row mapping: [1, 0] -> 100, [0, 1] -> 101, anything else -> 102.
    (Presumably these are the flag codes the dataloader's adjacency helpers
    expect; confirm against QuickdrawDataset4dict_2nn4nnjnn.)

    Args:
        input_array: NumPy array of shape (100, 2). Any other shape fails
            the assertion.

    Returns:
        Integer array of shape (100, 1) holding one code per row.
    """
    assert input_array.shape == (100, 2)
    is_first = np.all(input_array == [1, 0], axis=1)
    is_second = np.all(input_array == [0, 1], axis=1)
    # Start every row at the fallback code, then overwrite the matched rows.
    codes = np.full([100, 1], 102, int)
    codes[is_first, 0] = 100
    codes[is_second, 0] = 101
    return codes

# Load one sketch: columns 0-1 are point coordinates, columns 2-3 are flag
# bits (get_flags requires the latter to be (100, 2)).
arr = np.load(f'images/load/{name}.npy')
stroke_len = get_stroke_len(arr)
coord = arr[:, :2]
flags = get_flags(arr[:, 2:])

# Adjacency masks to compare (0 -> connected, -1e10 -> not connected): an
# edgeless baseline, then the 2-neighbor, 4-neighbor and joint-neighbor masks
# built by the dataloader helpers.
Ws = [np.ones([100, 100], int)*-1e10,
      produce_adjacent_matrix_2_neighbors(flags, stroke_len),
      produce_adjacent_matrix_4_neighbors(flags, stroke_len),
      produce_adjacent_matrix_joint_neighbors(flags, stroke_len)]

f_idx = 0
# Request tight layout at creation time; Figure.set_tight_layout is
# discouraged (pending deprecation) in newer Matplotlib releases.
f = plt.figure(f_idx, figsize=(len(Ws)*5, 5), tight_layout=True)
for idx, mask in enumerate(Ws):
    # (rows, cols, index) works for any panel count; the packed 3-digit
    # integer form breaks once there are 10 or more panels.
    ax = f.add_subplot(1, len(Ws), idx + 1)
    plot_graph(ax,
               coord=coord[:stroke_len],
               W=mask[:stroke_len, :stroke_len],
               title=f"Graph {idx+1}",
               invert=True)
# plt.savefig(f"images/png/{name}-graphs.png", format='png', dpi=600, bbox_inches="tight")
plt.show()
f_idx += 1