In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import pickle 
import os
from attn_model import ResidualBlock, AttnEncoder, AttnDecoderRnn
from torch.autograd import Variable 
from torchvision import transforms 
from data_loader import build_vocab 
from PIL import Image


embed_size = 256
hidden_size = 512 
feature_size= 256
num_layers = 1 
decoder_path = './models/attn/1object/decoder-10-150.pkl'
encoder_path = './models/attn/1object/encoder-10-150.pkl'


def to_var(x, volatile=False):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, volatile=volatile)

def load_image(image_path, transform=None):
    image = Image.open(image_path).convert('RGB')
    image = image.resize([64, 64], Image.LANCZOS)
    
    if transform is not None:
        image = transform(image).unsqueeze(0)
    
    return image

transform = transforms.Compose([ 
    transforms.ToTensor(), 
    transforms.Normalize((0.033, 0.032, 0.033), 
                         (0.027, 0.027, 0.027))])



vocab_path = './data/attn/vocab1.pkl'
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)
len_vocab = vocab.idx

# Build Models
encoder = AttnEncoder(ResidualBlock, [3, 3, 3])
encoder.eval()  # evaluation mode (BN uses moving mean/variance)
decoder = AttnDecoderRnn(feature_size, hidden_size, 
                     len(vocab), num_layers)


# Load the trained model parameters
encoder.load_state_dict(torch.load(encoder_path))
decoder.load_state_dict(torch.load(decoder_path))

# If use gpu
if torch.cuda.is_available():
    encoder.cuda()
    decoder.cuda()

    

In [None]:
image_dir = 'data/circle_and_rect/bitmap/1.png'
test_image = load_image(image_dir,transform)
image_tensor = to_var(test_image)

In [None]:
features = encoder(image_tensor)
h,c = decoder.init_lstm(features)

In [None]:
sums = torch.sum(features, 1)
out = torch.mul(sums, 1/features.size(1))
out = out.squeeze(1).unsqueeze(0)
out = decoder.init_layer(out.squeeze(0)).unsqueeze(0)

In [None]:
import torch.nn.functional as F
de_hidden = decoder.attn.attn(h.squeeze(0)).unsqueeze(2)
attn_weight = torch.bmm(features, de_hidden)
attn_weight = attn_weight.squeeze(2)
attn_weight = F.softmax(attn_weight).unsqueeze(2)
#attn_weight = F.softmax(attn_weight)

In [None]:
context = torch.bmm(attn_weight.transpose(2,1),features)

In [None]:
context

In [None]:
x = Variable(torch.rand(1,1,256)).cuda()
#x = decoder.embed((temp))
context = decoder.attn(h, features)
lstm_input = torch.cat((context, x) ,2)
lstm_out, (h,c) = decoder.lstm(lstm_input, (h,c)) 

In [None]:
F.softmax(attn_weight.squeeze(2)).unsq

In [None]:
de_hidden = decoder.attn.attn(h.squeeze(0)).unsqueeze(2)
attn_weight = torch.bmm(features, de_hidden)

In [None]:
out_ = torch.mul(sums,1/features.size(1))
out2= decoder.init_layer(out_)

In [None]:
import torch.nn.functional as F
F.tanh(out2)

In [None]:
de_hidden = decoder.attn.attn(h.squeeze(0)).unsqueeze(2)
attn_weight = torch.bmm(features, de_hidden)

In [None]:
de_hidden

In [None]:
features = encoder(image_tensor)
h,c = decoder.init_lstm(features)
x = Variable(torch.rand(1,1,256)).cuda()
#x = decoder.embed((temp))
context = decoder.attn(h, features)
lstm_input = torch.cat((context, x) ,2)
lstm_out, (h,c) = decoder.lstm(lstm_input, (h,c))          # (batch_size, 1, hidden_size), 
out = decoder.decode_lstm(x, context, h, lstm_out)
predicted = out.max(1)[1]

In [None]:
de_hidden = decoder.attn.attn(h.squeeze(0)).unsqueeze(2)
attn_weight = torch.bmm(features, de_hidden)

In [None]:
attn_weight

In [None]:
temp = Variable(torch.Tensor([1]).cuda().unsqueeze(1).long())

In [None]:
decoder.embed(predicted)

In [2]:
test_dir = 'data/circle_and_rect/bitmap/'
cap_dir = 'data/circle_and_rect/caption/'
test_list = os.listdir(test_dir)
cnt = 0 

shape_match = 0
location_match = 0
radius_match = 0 
rect_x_match = 0 
rect_y_match = 0 
length_match = 0
for fname in test_list: 
    #cnt+=1 
    if cnt>2:
        break;
    test_path = test_dir + fname
    test_image = load_image(test_path, transform)
    image_tensor = to_var(test_image)
    
    # Generate caption from image
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    ids_arr = []
    for element in sampled_ids: 
        temp = element.cpu().data.numpy()
        ids_arr.append(int(temp))
    print(fname)

    # Decode word_ids to words
    sampled_caption = []
    for word_id in ids_arr:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    in_caption = sampled_caption[1:-1]
    print(in_caption)
    
    #read target caption
    cap_target = cap_dir+fname
    cap_path = cap_target.split('.')[0]+'.svg'
    trg_caption = [] 
    with open(cap_path, 'r') as f:
        target = f.readline().split()
        for word in target:
            trg_caption.append(word)
    print(trg_caption)
    print('-------------------------------------------------------------------')
    
    
    if(len(in_caption) == len(trg_caption)):        
        length_match += 1
        
print(length_match) 


748.png
['circle', '9', '5', '100', 'lime']
['circle', '9', '5', '100', 'deep_pink']
-------------------------------------------------------------------
1788.png
['rect', '10', '10', '60', '60', 'blue']
['rect', '10', '10', '110', '110', 'orange']
-------------------------------------------------------------------
1594.png
['circle', '10', '3', '100', 'lime']
['circle', '10', '3', '110', 'pink']
-------------------------------------------------------------------
1121.png
['circle', '1', '9', '60', 'orange']
['circle', '1', '9', '60', 'deep_pink']
-------------------------------------------------------------------
1288.png
['circle', '8', '6', '110', 'lime']
['circle', '8', '6', '100', 'purple']
-------------------------------------------------------------------
606.png
['rect', '5', '2', '70', '70', 'spring_green']
['circle', '6', '3', '40', 'pink']
-------------------------------------------------------------------
468.png
['circle', '6', '7', '70', 'cyan']
['rect', '5', '5', '120', '

KeyboardInterrupt: 

In [3]:
test_dir = 'data/circle_and_rect/bitmap/'
cap_dir = 'data/circle_and_rect/caption/'
test_list = os.listdir(test_dir)
cnt = 0 

shape_match = 0
location_match = 0
radius_match = 0 
rect_x_match = 0 
rect_y_match = 0 
length_match = 0
circle_cnt = 0 
rect_cnt = 0 
color_match= 0

location_err = [] 
rect_x_err = []
rect_y_err = []
radius_err = [] 

for fname in test_list: 
    #cnt+=1 
    if cnt>2:
        break;
    test_path = test_dir + fname
    test_image = load_image(test_path, transform)
    image_tensor = to_var(test_image)
    
    print(fname)
    
    # Generate caption from image
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    ids_arr = []
    for element in sampled_ids: 
        temp = element.cpu().data.numpy()
        ids_arr.append(int(temp))
    #print(fname)

    # Decode word_ids to words
    sampled_caption = []
    for word_id in ids_arr:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    in_caption = sampled_caption[1:-1]
    print(in_caption)
    
    #read target caption
    cap_target = cap_dir+fname
    cap_path = cap_target.split('.')[0]+'.svg'
    trg_caption = [] 
    with open(cap_path, 'r') as f:
        target = f.readline().split()
        for word in target:
            trg_caption.append(word)
    print(trg_caption)
    print('-------------------------------------------------------------------')
    
    if(trg_caption[0]) == 'circle':
        if in_caption[0] == trg_caption[0]:
            shape_match+=1
            circle_cnt+=1
        if (in_caption[1] == trg_caption[1]):
            location_match+=1
        if (in_caption[2] == trg_caption[2]):
            location_match+=1        
        if (in_caption[3] == trg_caption[3]):
            radius_match+=1
        else:
            error = int(trg_caption[3]) - int(in_caption[3])
            radius_err.append(abs(error))
        if (in_caption[4] == trg_caption[4]):
            color_match+=1
    elif(trg_caption[0]) == 'rect':
        if in_caption[0] == trg_caption[0]:
            shape_match+=1
            rect_cnt+=1
        if (in_caption[1] == trg_caption[1]):
            location_match+=1
        if (in_caption[2] == trg_caption[2]):
            location_match+=1        
        if (in_caption[3] == trg_caption[3]):
            rect_x_match+=1
        else: 
            error = int(trg_caption[3]) - int(in_caption[3])
            rect_x_err.append(abs(error))
        if (in_caption[4] == trg_caption[4]):
            rect_y_match+=1
        else: 
            error = int(trg_caption[4]) - int(in_caption[4])
            rect_y_err.append(abs(error))            
        if (in_caption[5] == trg_caption[5]):
            color_match+=1
            
    
    if(trg_caption[0] != in_caption[0]):
        print(fname)
    if(trg_caption[1] != in_caption[1]):
        try:
            error = int(trg_caption[1]) - int(in_caption[1])
        except:
            print(fname)
            print(in_caption)
            pass
        location_err.append(abs(error))
    if(trg_caption[2] != in_caption[2]):
        error = int(trg_caption[1]) - int(in_caption[2])
        location_err.append(abs(error))
            


748.png
['circle', '9', '5', '100', 'lime']
['circle', '9', '5', '100', 'deep_pink']
-------------------------------------------------------------------
1788.png
['rect', '10', '10', '60', '60', 'blue']
['rect', '10', '10', '110', '110', 'orange']
-------------------------------------------------------------------
1594.png
['circle', '10', '3', '100', 'lime']
['circle', '10', '3', '110', 'pink']
-------------------------------------------------------------------
1121.png
['circle', '1', '9', '60', 'orange']
['circle', '1', '9', '60', 'deep_pink']
-------------------------------------------------------------------
1288.png
['circle', '8', '6', '110', 'lime']
['circle', '8', '6', '100', 'purple']
-------------------------------------------------------------------
606.png
['rect', '5', '2', '70', '70', 'spring_green']
['circle', '6', '3', '40', 'pink']
-------------------------------------------------------------------
606.png
468.png
['circle', '6', '7', '70', 'cyan']
['rect', '5', '5', 

ValueError: invalid literal for int() with base 10: 'cyan'