In [1]:
import numpy as np
import os
from data_loader import build_vocab, get_loader
from model import EncoderCNN, DecoderRNN 
from model import ResNet, ResidualBlock
import torch
from torch.autograd import Variable 
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
import pickle


def to_var(x, volatile=False, device_id=1):
    """Wrap a tensor in an autograd ``Variable``, moving it to a GPU when CUDA is available.

    Args:
        x: Tensor to wrap.
        volatile: Forwarded unchanged to ``Variable``.
        device_id: Index of the CUDA device that receives ``x`` when CUDA is
            available. Defaults to 1.

    Returns:
        A ``Variable`` wrapping ``x``. When CUDA is available the wrapped tensor
        is the copy of ``x`` on device ``device_id``; otherwise ``x`` is wrapped
        where it is.
    """
    if torch.cuda.is_available():
        # NOTE(review): the default index 1 needs at least two visible GPUs, and
        # the tensor must land on the same device as the model that consumes it.
        x = x.cuda(device_id)
    return Variable(x, volatile=volatile)

def rearrange_tensor(x, batch_size, caption_size):
    """Place ``caption_size`` consecutive row-chunks of ``x`` side by side.

    Chunk ``i`` is ``x[i*batch_size:(i+1)*batch_size]``. Every chunk is viewed
    as ``(batch_size, -1)`` and the chunks are joined along dimension 1.

    Args:
        x: Tensor whose leading dimension holds the chunks back to back.
        batch_size: Number of rows in each chunk (and in the result).
        caption_size: Number of chunks to take from ``x``.

    Returns:
        A tensor with ``batch_size`` rows whose columns are the flattened
        chunks in order.

    Raises:
        ValueError: If ``caption_size`` is smaller than 1, since there would be
            nothing to return.
    """
    if caption_size < 1:
        raise ValueError(
            "caption_size must be at least 1, got {}".format(caption_size))

    chunks = [
        x[i * batch_size:(i + 1) * batch_size].view(batch_size, -1)
        for i in range(caption_size)
    ]
    if len(chunks) == 1:
        # A single chunk is returned as-is, without going through torch.cat.
        return chunks[0]
    # One concatenation over all chunks instead of growing the result per step.
    return torch.cat(chunks, 1)

# Locations of the sample images and of the pickled vocabulary.
root_path = 'data/bitmap2svg_samples2/'
vocab_path = 'data/vocab.pkl'

# Settings handed to the data loader.
batch_size = 128
num_workers = 2

# NOTE(review): not referenced by the cells visible here; presumably sizes for
# the imported EncoderCNN / DecoderRNN models - confirm before removing.
embed_size = 256
hidden_size = 512
num_layers = 1

In [2]:
# Image preprocessing: convert to a tensor, then normalize with the given
# per-channel mean and standard deviation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

# Load the pickled vocabulary wrapper.
# NOTE(review): pickle.load can execute arbitrary code; only open vocabulary
# files that come from a trusted source.
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)

# NOTE(review): `vocab.idx` is presumably the vocabulary size - confirm against
# the Vocabulary class in data_loader.
len_vocab = vocab.idx

# Number of output classes of the encoder.
num_class = 9

data_loader = get_loader(root_path, vocab,
                         transform, batch_size,
                         shuffle=True, num_workers=num_workers)

# Pass the named constant rather than repeating the literal, so the class
# count is defined in exactly one place.
encoder = ResNet(ResidualBlock, [3, 3, 3], num_class)

if torch.cuda.is_available():
    # Same device index that to_var uses by default for the inputs.
    encoder.cuda(1)


In [4]:
# Display the index-to-word mapping held by the loaded vocabulary.
vocab.idx2word

{0: '<pad>',
 1: '<start>',
 2: '<end>',
 3: '<unk>',
 4: '70',
 5: '6',
 6: '30',
 7: '9',
 8: '120',
 9: '8',
 10: '90',
 11: '40',
 12: '2',
 13: '5',
 14: '7',
 15: '100',
 16: '1',
 17: '110',
 18: '80',
 19: '60',
 20: 'circle',
 21: '20',
 22: '50',
 23: '3',
 24: '4'}

In [12]:
# Run the encoder on at most the first two batches. After the loop, `images`,
# `captions`, `lengths`, `caption_loc` and `features` all belong to the last
# batch processed (the second one when the loader yields at least two).
for i, (images, captions, lengths) in enumerate(data_loader):
    images = to_var(images)
    # Column 1 of every caption.
    # NOTE(review): presumably the first token after '<start>' - confirm
    # against the caption layout produced by get_loader.
    caption_loc = captions[:, 1]
    features = encoder(images)
    if i >= 1:
        # Stop right after the second batch. Checking at the top of the next
        # iteration would first fetch and unpack a third batch, leaving
        # `images`/`captions` out of step with `features`/`caption_loc`.
        break

In [9]:

# Turn each caption token id into a zero-based class index: look up the token's
# word, parse it as an integer and subtract 1.
# `.tolist()` yields plain Python ints, which are valid keys for the
# `idx2word` dict; iterating the tensor directly can yield tensor elements
# that do not match the integer keys.
# NOTE(review): int() raises ValueError for non-numeric words such as
# '<pad>' or 'circle' - confirm column 1 always holds a numeric token.
idx_arr = [int(vocab.idx2word[token_id]) - 1 for token_id in caption_loc.tolist()]
# Pin the dtype so the resulting tensor is a LongTensor on every platform.
temp_arr = np.array(idx_arr, dtype=np.int64)
trg_arr = torch.from_numpy(temp_arr)
target = to_var(trg_arr)

In [15]:
# Display the target class indices built from the caption tokens.
target

Variable containing:
 5
 0
 5
 0
 5
 7
 7
 6
 4
 1
 5
 8
 1
 0
 0
 8
 4
 6
 2
 5
 6
 5
 1
 4
 2
 8
 5
 3
 3
 0
 2
 6
 4
 0
 6
 8
 1
 4
 1
 1
 1
 5
 6
 5
 3
 1
 3
 0
 1
 7
 5
 6
 6
 5
 1
 3
 5
 4
 5
 5
 5
 5
 7
 5
 1
 7
 3
 8
 7
 6
 6
 0
 7
 8
 1
 7
 2
 1
 4
 1
 4
 4
 2
 4
 3
 4
 7
 7
 5
 5
 4
 5
 4
 1
 6
 1
 1
 2
 0
 8
 3
 4
 7
 1
 7
 4
 5
 8
 1
 5
 3
 8
 5
 5
 7
 7
 4
 1
 1
 5
 7
 5
 2
 0
 0
 2
 3
 4
[torch.cuda.LongTensor of size 128 (GPU 1)]

In [14]:
# Position of the largest encoder output along dimension 1 for every sample
# (element [1] of the max result holds the indices).
features.max(1)[1]

Variable containing:
    8
    8
    8
    8
    1
    8
    8
    6
    8
    8
    8
    8
    8
    8
    1
    6
    6
    8
    8
    1
    1
    8
    1
    8
    5
    8
    8
    8
    8
    1
    1
    1
    8
    1
    8
    1
    8
    6
    8
    1
    8
    6
    8
    8
    8
    1
    8
    1
    1
    8
    1
    1
    8
    8
    8
    8
    1
    6
    1
    8
    1
    1
    8
    8
    1
    8
    8
    1
    8
    8
    8
    1
    8
    8
    8
    5
    1
    8
    5
    1
    8
    5
    8
    8
    6
    1
    1
    5
    5
    8
    8
    8
    8
    8
    8
    8
    8
    1
    8
    8
    8
    8
    5
    8
    1
    8
    8
    8
    5
    1
    1
    8
    1
    8
    8
    8
    8
    8
    8
    1
    8
    8
    1
    8
    8
    1
    8
    8
[torch.cuda.LongTensor of size 128x1 (GPU 1)]

In [None]:
# Reshape the per-sample argmax of the encoder output to (rows, 1).
# The row count comes from `features` itself, so a batch with fewer than
# `batch_size` rows (e.g. the final batch of a dataset) still reshapes.
re_out_max = rearrange_tensor(features.max(1)[1], features.size(0), 1)

In [None]:
# Display the rearranged argmax indices.
re_out_max

In [None]:
# Reshape the target indices to (rows, 1).
# The row count comes from `target` itself, so a batch with fewer than
# `batch_size` rows still reshapes.
re_target = rearrange_tensor(target, target.size(0), 1)

In [None]:
# One-hot encode the target indices: one row per caption, one column per class.
trg_ = trg_arr.unsqueeze(1)
# Column count comes from `num_class` rather than a repeated literal, so the
# one-hot width follows the class count configured for the encoder.
one_hot_ = torch.zeros(caption_loc.size(0), num_class)
# scatter_ writes in place and returns the same tensor, so `one_hot_` and
# `one_hot_trg` refer to one object.
one_hot_trg = one_hot_.scatter_(1, trg_, 1)
