In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import pickle 
import os
from torch.autograd import Variable 
from torchvision import transforms 
from data_loader import build_vocab 
from model import EncoderCNN, DecoderRNN
from model import ResNet, ResidualBlock
from PIL import Image
from attn_model import ResidualBlock, AttnEncoder, AttnDecoderRnn


embed_size = 128
hidden_size = 256 
num_layers = 1 
feature_size = 128 

#decoder_path = './models/2object_color/decoder-10-150.pkl'
#encoder_path = './models/2object_color/encoder-10-150.pkl'
decoder_path = './models/attn/2object/decoder-10-150.pkl'
encoder_path = './models/attn/2object/encoder-10-150.pkl'

def to_var(x, volatile=False):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, volatile=volatile)

def load_image(image_path, transform=None):
    image = Image.open(image_path).convert('RGB')
    image = image.resize([64, 64], Image.LANCZOS)
    
    if transform is not None:
        image = transform(image).unsqueeze(0)
    
    return image

transform = transforms.Compose([ 
    transforms.ToTensor(), 
    transforms.Normalize((0.033, 0.032, 0.033), 
                         (0.027, 0.027, 0.027))])



#vocab_path = './data/vocab_2object_color.pkl'
vocab_path = './data/attn/vocab2.pkl'
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)
len_vocab = vocab.idx

# Build Models
#encoder = ResNet(ResidualBlock, [3, 3, 3], len_vocab)
#encoder.eval()  # evaluation mode (BN uses moving mean/variance)
#decoder = DecoderRNN(len_vocab, hidden_size, len(vocab), num_layers)

encoder = AttnEncoder(ResidualBlock, [3, 3, 3])
encoder.eval()  # evaluation mode (BN uses moving mean/variance)
decoder = AttnDecoderRnn(feature_size, hidden_size, 
                     len(vocab), num_layers)



# Load the trained model parameters
encoder.load_state_dict(torch.load(encoder_path))
decoder.load_state_dict(torch.load(decoder_path))

# If use gpu
if torch.cuda.is_available():
    encoder.cuda()
    decoder.cuda()

    

In [2]:
test_dir = 'data/2object_test/bitmap/'
cap_dir = 'data/2object_test/caption/'
test_list = os.listdir(test_dir)
cnt = 0 

shape_match = 0
location_match = 0
location_match_rect= 0 
radius_match = 0 
rect_x_match = 0 
rect_y_match = 0 
color_match = 0

location_err = [] 
location_err_rect = [] 

rect_x_err = []
rect_y_err = []
radius_err = [] 


for fname in test_list: 
    #cnt+=1 
    if cnt>2:
        break;
    test_path = test_dir + fname
    test_image = load_image(test_path, transform)
    image_tensor = to_var(test_image)
    
    # Generate caption from image
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    ids_arr = []
    for element in sampled_ids: 
        temp = element.cpu().data.numpy()
        ids_arr.append(int(temp))
    print(fname)

    # Decode word_ids to words
    sampled_caption = []
    for word_id in ids_arr:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    in_caption = sampled_caption[1:-1]
    print(in_caption)
    
    #read target caption
    cap_target = cap_dir+fname
    cap_path = cap_target.split('.')[0]+'.svg'
    trg_caption = [] 
    with open(cap_path, 'r') as f:
        target = f.readline().split()
        for word in target:
            trg_caption.append(word)
    print(trg_caption)
    print('------------------------------------')
    
    try:
        if in_caption[0] == trg_caption[0]:
            shape_match+=1 
        if in_caption[5] == trg_caption[5]:
            shape_match+=1
        if in_caption[1] == trg_caption[1]:
            location_match+=1
        else:
            error = int(trg_caption[1]) - int(in_caption[1])
            location_err.append(abs(error))
        if in_caption[2] == trg_caption[2]:
            location_match+=1
        else:
            error = int(trg_caption[2]) - int(in_caption[2])
            location_err.append(abs(error))
        if in_caption[6] == trg_caption[6]:
            location_match_rect+=1
        else:
            error = int(trg_caption[6]) - int(in_caption[6])
            location_err_rect.append(abs(error))
        if in_caption[7] == trg_caption[7]:
            location_match_rect+=1    
        else:
            error = int(trg_caption[7]) - int(in_caption[7])
            location_err_rect.append(abs(error))
        if in_caption[8] == trg_caption[8]:
            rect_x_match+=1
        else:
            error = int(trg_caption[8]) - int(in_caption[8])
            rect_x_err.append(abs(error))
        if in_caption[9] == trg_caption[9]:
            rect_y_match+=1
        else:
            error = int(trg_caption[9]) - int(in_caption[9])
            rect_y_err.append(abs(error))
        if in_caption[3] == trg_caption[3]:
            radius_match+=1       
        else:
            error = int(trg_caption[3]) - int(in_caption[3])
            radius_err.append(abs(error))
        if in_caption[4] == trg_caption[4]:
            color_match+=1   
        if in_caption[10] == trg_caption[10]:
            color_match+=1  
    except:
        continue



9908.png
['circle', '7', '3', '80', 'skyblue', 'rect', '9', '9', '80', '80', 'spring_green']
['circle', '6', '3', '80', 'skyblue', 'rect', '9', '9', '80', '80', 'spring_green']
------------------------------------
9759.png
['circle', '4', '8', '20', 'orange', 'rect', '4', '6', '20', '20', 'cyan']
['circle', '4', '8', '20', 'orange', 'rect', '4', '6', '20', '20', 'cyan']
------------------------------------
9903.png
['circle', '1', '9', '110', 'yellow', 'rect', '1', '7', '110', '110', 'green']
['circle', '1', '8', '110', 'yellow', 'rect', '1', '9', '110', '110', 'blue']
------------------------------------
9980.png
['circle', '8', '7', '50', 'lime', 'rect', '8', '9', '50', '50', 'green']
['circle', '8', '7', '50', 'lime', 'rect', '9', '9', '50', '50', 'green']
------------------------------------
9728.png
['circle', '9', '10', '30', 'blue', 'rect', '7', '4', '30', '30', 'pink']
['circle', '10', '10', '50', 'blue', 'rect', '7', '3', '50', '50', 'pink']
-----------------------------------

In [3]:
        
print('shape match: %.4f, location match: %.4f,location match_rect: %.4f, radius_match: %.4f, rect_x_match: %.4f, rect_y_match: %.4f color_match: %.4f'
     %(shape_match/(len(test_list)*2), location_match/(len(test_list)*2), location_match_rect/(len(test_list)*2),
      radius_match/len(test_list), rect_x_match/len(test_list), rect_y_match/len(test_list), color_match/(len(test_list)*2)))


shape match: 1.0000, location match: 0.8467,location match_rect: 0.8450, radius_match: 0.8300, rect_x_match: 0.7900, rect_y_match: 0.7900 color_match: 0.8833


In [None]:
x_ = {x:rect_x_err.count(x) for x in rect_x_err}
y_ = {x:rect_y_err.count(x) for x in rect_y_err}
r_ = {x:radius_err.count(x) for x in radius_err}
loc_ ={x:location_err.count(x) for x in location_err}
loc_rect_ = {x:location_err_rect.count(x) for x in location_err_rect}

In [None]:
loc_

In [None]:
test_dir = 'data/2object_test/bitmap/'
cap_dir = 'data/2object_test/caption/'
test_list = os.listdir(test_dir)
cnt = 0 

for fname in test_list: 
    test_path = test_dir + fname
    test_image = load_image(test_path, transform)
    image_tensor = to_var(test_image)
    
    # Generate caption from image
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    ids_arr = []
    for element in sampled_ids: 
        temp = element.cpu().data.numpy()
        ids_arr.append(int(temp))


    # Decode word_ids to words
    sampled_caption = []
    for word_id in ids_arr:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)
    print(sentence)
    predict = sentence.split()[1]
    
    cap_target = cap_dir+fname
    cap_path = cap_target.split('.')[0]+'.svg'
    with open(cap_path, 'r') as f:
        target = f.readline()
        print('target: '+ target)
    if predict == target:
        print('corret')
    #    cnt+=1 
        
print(cnt)
    
    
    

In [None]:
feature.unsqueeze(1)

In [None]:
cap_path = img_list[0].split('.')[0]+'.svg'

In [None]:
cap_path