In [7]:
import sys
import cv2
import os
import math
import pickle
import copy
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("C:\\Users\\user\\Documents\\Workspace\\MeronymNet-PyTorch\\src")

import torch
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import DataLoader
import torch.utils.data as data_utils
from torch.utils.tensorboard import SummaryWriter

from losses import BoxVAE_losses as loss
from components.AutoEncoder import GCNAutoEncoder
from components.Decoder import Decoder


In [8]:
# Palette used by plot_bbx: one colour triple per box index, written as 0-1
# floats and rescaled to the 0-255 range below.
# NOTE(review): channel order is whatever the consumer of the canvas assumes
# (presumably RGB, since the images are logged to TensorBoard) - confirm.
colors = [(1, 0, 0),
          (0.737, 0.561, 0.561),
          (0.255, 0.412, 0.882),
          (0.545, 0.271, 0.0745),
          (0.98, 0.502, 0.447),
          (0.98, 0.643, 0.376),
          (0.18, 0.545, 0.341),
          (0.502, 0, 0.502),
          (0.627, 0.322, 0.176),
          (0.753, 0.753, 0.753),
          (0.529, 0.808, 0.922),
          (0.416, 0.353, 0.804),
          (0.439, 0.502, 0.565),
          (0.784, 0.302, 0.565),
          (0.867, 0.627, 0.867),
          (0, 1, 0.498),
          (0.275, 0.51, 0.706),
          (0.824, 0.706, 0.549),
          (0, 0.502, 0.502),
          (0.847, 0.749, 0.847),
          (1, 0.388, 0.278),
          (0.251, 0.878, 0.816),
          (0.933, 0.51, 0.933),
          (0.961, 0.871, 0.702)]
colors = (np.asarray(colors)*255)

# Default side length, in pixels, of the square canvas drawn by plot_bbx.
canvas_size = 550


def plot_bbx(bbx, size=None):
    """Draw bounding boxes as coloured rectangles on a white square canvas.

    Parameters
    ----------
    bbx : array-like of shape (n, 4)
        One row per box as ``(x_min, y_min, x_max, y_max)``. The values are
        multiplied by the canvas side length, so they are expected to be
        fractions of the canvas. Rows whose four values are all zero are
        skipped.
    size : int, optional
        Side length of the canvas in pixels. Defaults to the module-level
        ``canvas_size``.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(size, size, 3)`` with a white background.
        Box ``i`` is outlined (thickness 6) with ``colors[i % len(colors)]``.
    """
    side = canvas_size if size is None else int(size)
    # np.asarray: multiplying a plain Python list by an int would repeat the
    # list instead of scaling the coordinates.
    boxes = np.asarray(bbx) * side
    canvas = np.full((side, side, 3), 255, dtype=np.uint8)
    for i, coord in enumerate(boxes):
        if not np.any(coord):
            # all-zero row: nothing to draw
            continue
        x_min, y_min, x_max, y_max = (int(v) for v in coord)
        # Wrap the palette index so more boxes than palette entries cannot
        # raise IndexError; hand cv2 a plain tuple of Python floats.
        color = tuple(float(c) for c in colors[i % len(colors)])
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), color, 6)
    return canvas


In [3]:
def inference(decoder, nodes, obj_class, latent_dims, batch_size):
    """Decode a batch of random latent vectors conditioned on nodes and class.

    A standard-normal latent of shape ``(batch_size, latent_dims)`` is sampled,
    concatenated as ``[obj_class, nodes, z]`` along the last axis and passed to
    ``decoder``.

    Parameters
    ----------
    decoder : torch.nn.Module
        Module exposing ``num_nodes`` and ``class_size`` attributes and
        returning four values when called. It is moved to ``nodes.device``.
    nodes : torch.Tensor
        Reshaped to ``(batch_size, decoder.num_nodes)``.
    obj_class : torch.Tensor
        Reshaped to ``(batch_size, decoder.class_size)``.
    latent_dims : int
        Size of the sampled latent vector.
    batch_size : int
        Number of samples to generate; must match the contents of ``nodes``
        and ``obj_class`` or the reshapes fail.

    Returns
    -------
    tuple
        ``(x_bbx, x_lbl, x_edge, class_pred)`` exactly as returned by
        ``decoder``, computed without gradient tracking.
    """
    # Follow the device of the conditioning input instead of forcing CUDA, so
    # the function also works on CPU-only machines.
    device = nodes.device
    decoder.to(device)

    # NOTE(review): the decoder is not switched to eval mode here; confirm
    # whether it contains dropout / batch-norm layers that would need it.
    with torch.no_grad():
        # Sampled on the CPU and then moved, as before.
        z_latent = torch.normal(torch.zeros([batch_size, latent_dims])).to(device)
        nodes = torch.reshape(nodes, (batch_size, decoder.num_nodes))
        obj_class = torch.reshape(obj_class, (batch_size, decoder.class_size))
        conditioned_z = torch.cat([obj_class, nodes, z_latent], dim=-1)
        x_bbx, x_lbl, x_edge, class_pred = decoder(conditioned_z)

    return x_bbx, x_lbl, x_edge, class_pred

In [5]:
# Load the pickled test split, draw a random subset of it and wrap the subset
# in a torch_geometric DataLoader.
batch_size = 64
data_dir = 'D:/meronym_data'


def _load_pickle(path):
    """Return the object unpickled from the file at ``path``."""
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; only point this at files you produced yourself.
    with open(path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


X_test = _load_pickle(data_dir + '/X_test.np')
adj_test = _load_pickle(data_dir + '/adj_test.np')
class_v_test = _load_pickle(data_dir + '/class_v_test.np')

# Drawn only after X_test exists (previously this line ran before the load and
# raised NameError on a fresh kernel). The lower bound is 0 so the first sample
# can be selected too; indices are drawn with replacement.
test_idx = np.random.randint(0, len(X_test), 3000)

val_list = []
# deepcopy keeps the sampled arrays independent of the loaded ones.
for x_nodes, y_class, adjacency in zip(copy.deepcopy(X_test[test_idx]),
                                       copy.deepcopy(class_v_test[test_idx]),
                                       copy.deepcopy(adj_test[test_idx])):
    # dense_to_sparse returns (edge_index, edge_attr); only the index is kept.
    edge_index = dense_to_sparse(torch.from_numpy(adjacency).cuda().float())[0]
    val_list.append(Data(x=torch.from_numpy(x_nodes).cuda().float(),
                         y=torch.from_numpy(y_class).cuda().float(),
                         edge_index=edge_index))
batch_val_loader = DataLoader(val_list, batch_size=batch_size)




In [9]:
# Testing loop: decode one test batch with the trained decoder and log the
# ground-truth and generated box layouts to TensorBoard.

# Hyper-parameters. learning_rate, batch_size and hidden1..3 are also part of
# the run name below, so they select which checkpoint directory is read.
latent_dims = 64
batch_size = 64
num_nodes = 24
bbx_size = 4
num_classes = 10
label_shape = 1
nb_epochs = 200
# Not used further down in this cell; kept in case later cells rely on it.
klw = loss.frange_cycle_linear(nb_epochs)
learning_rate = 0.000062
hidden1 = 32
hidden2 = 16
hidden3 = 128

# One run name shared by the checkpoint directory and the TensorBoard log dir.
run_name = ('GCN-lr-'
            + str(learning_rate)
            + '-batch-' + str(batch_size)
            + '-h1-' + str(hidden1)
            + '-h2-' + str(hidden2)
            + '-h3-' + str(hidden3) + '-test')
model_path = 'D:/meronym_data/model/' + run_name
summary_path = 'D:/meronym_data/runs/' + run_name

vae = GCNAutoEncoder(latent_dims, num_nodes, bbx_size, num_classes, label_shape, hidden1, hidden2, hidden3)
vae.load_state_dict(torch.load(model_path + '/model_weights.pth'))

decoder = vae.decoder
image_shape = [num_nodes, bbx_size]
global_step = 250000

# Only the first batch of the loader is visualised.
val_data = next(iter(batch_val_loader))
val_data.cuda()
node_data_true = val_data.x
# First feature column; used below as a per-node multiplier on the box
# coordinates (presumably a part-presence flag - confirm).
label_true = node_data_true[:, :1]
class_true = val_data.y
output = inference(decoder, label_true, class_true, latent_dims, batch_size)
node_data_pred_test, _, _, _ = output

# The context manager closes (and thereby flushes) the writer even if logging
# fails part-way through.
with SummaryWriter(summary_path) as writer:
    for i in range(batch_size):
        # Rows of the flattened node tensor that belong to graph i.
        rows = slice(num_nodes * i, num_nodes * (i + 1))
        mask = label_true[rows]

        true_boxes = (node_data_true[rows, 1:] * mask).detach().to("cpu").numpy()
        image = plot_bbx(np.reshape(true_boxes, image_shape)).astype(float) / 255
        writer.add_image('test_result/images/'+str(i)+'/input/', image, global_step, dataformats='HWC')

        pred_boxes = (node_data_pred_test[i] * mask).detach().to("cpu").numpy()
        image = plot_bbx(pred_boxes).astype(float) / 255
        writer.add_image('test_result/images/'+str(i)+'/generated/', image, global_step, dataformats='HWC')