In [14]:
import torch
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import pickle 
import os
from torch.autograd import Variable 
from torchvision import transforms 
from data_loader import build_vocab 
from model import EncoderCNN, DecoderRNN
from PIL import Image


def to_var(x, volatile=False):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, volatile=volatile)

def load_image(image_path, transform=None):
    image = Image.open(image_path).convert('RGB')
    image = image.resize([224, 224], Image.LANCZOS)
    
    if transform is not None:
        image = transform(image).unsqueeze(0)
    
    return image

embed_size = 256
hidden_size = 512 
num_layers = 1 
in_image = 'data/bitmap2_test/bitmap/902.png'
decoder_path = './models/decoder-5-15.pkl'
encoder_path = './models/encoder-5-15.pkl'

transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), 
                         (0.229, 0.224, 0.225))])

# Load vocabulary wrapper
# Build vocab
vocab_path = './data/vocab.pkl'
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)
#vocab = build_vocab('bitmap2svg_samples2/', threshold=0)

# Build Models
encoder = EncoderCNN(embed_size)
encoder.eval()  # evaluation mode (BN uses moving mean/variance)
decoder = DecoderRNN(embed_size, hidden_size, 
                     len(vocab), num_layers)


# Load the trained model parameters
encoder.load_state_dict(torch.load(encoder_path))
decoder.load_state_dict(torch.load(decoder_path))

# Prepare Image
image = load_image(in_image, transform)
image_tensor = to_var(image, volatile=True)
#print(image_tensor)
image_tensor = image_tensor[:,0:-1,:,:]
#print(image_tensor)

# If use gpu
if torch.cuda.is_available():
    encoder.cuda()
    decoder.cuda()

    

In [22]:
import torch.nn as nn
temp = list(encoder.parameters())


In [24]:
temp.grad

AttributeError: 'list' object has no attribute 'grad'

In [2]:
test_dir = 'data/bitmap2svg_samples2/bitmap/'
cap_dir = 'data/bitmap2svg_samples2/caption/'
test_list = os.listdir(test_dir)
cnt = 0 

for fname in test_list: 
    test_path = test_dir + fname
    #print(test_path)
    test_image = load_image(test_path, transform)
    image_tensor = to_var(test_image,volatile=True)
    #image_tensor = image_tensor[:,0:-1,:,:]
    
    # Generate caption from image
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    ids_arr = []
    for element in sampled_ids: 
        temp = element.cpu().data.numpy()
        ids_arr.append(int(temp))

    #sampled_ids = sampled_ids.cpu().data.numpy()

    # Decode word_ids to words
    sampled_caption = []
    for word_id in ids_arr:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    sentence = ' '.join(sampled_caption)
    #print(sentence)
    predict = sentence.split()[1]
    
    cap_target = cap_dir+fname
    cap_path = cap_target.split('.')[0]+'.svg'
    with open(cap_path, 'r') as f:
        target = f.readline().split()[0]
        #print(f.readline())
    if predict == target:
        cnt+=1 
        
print(cnt)
    
    
    

<start> 60 circle 90 <end>
<start> 120 circle 90 <end>
<start> 60 circle 90 <end>
<start> circle circle 8 <end>
<start> 60 circle 90 <end>
<start> 7 circle 40 <end>
<start> 7 circle 40 <end>
<start> 7 circle 40 <end>
<start> 60 circle 90 <end>
<start> 7 circle 40 <end>
<start> 120 circle 90 <end>
<start> 60 circle 90 <end>
<start> 120 circle 40 <end>
<start> 7 circle 40 <end>
<start> 50 circle 40 <end>
<start> 7 circle 40 <end>
<start> 60 circle 90 <end>
<start> 7 circle 40 <end>
<start> 60 circle 90 <end>
<start> 7 circle 40 <end>
<start> 60 circle 90 <end>
<start> 7 circle 40 <end>
<start> 60 circle 90 <end>
<start> 120 circle 90 <end>
<start> 7 circle 40 <end>
<start> 7 circle 40 <end>
<start> 7 circle 40 <end>
<start> 7 circle 40 <end>
<start> 7 circle 40 <end>
<start> 60 circle 90 <end>
<start> 120 circle 90 <end>
<start> 7 circle 40 <end>
<start> 60 circle 90 <end>
<start> 50 circle 40 <end>
<start> 60 circle 90 <end>
<start> 60 circle 90 <end>
<start> 60 circle 90 <end>
<start> 