# Converter

Converts the PyTorch image-captioning encoder and decoder models to ONNX, and `vocab.pkl` to `vocab.txt`.

----

## Step 1
- Download the model files from: https://www.dropbox.com/s/ne0ixz5d58ccbbz/pretrained_model.zip?dl=0

- Download the vocab file from: https://www.dropbox.com/s/26adb7y9m98uisa/vocap.zip?dl=0

_(Unzip both archives and make sure the extracted files — `encoder-5-3000.pkl`, `decoder-5-3000.pkl` and `vocab.pkl` — are in the same directory as this notebook)_

## Step 2
- Install PyTorch and torchvision by running `pip install torch torchvision`

## Step 3
- Import the required Python modules

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence
from PIL import Image
import pickle

## Step 4
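- Define the `Vocabulary` class and load the `vocab.pkl` file (the class must be defined first, because unpickling needs it to rebuild the vocabulary object)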

In [None]:
class Vocabulary(object):
    """Bidirectional mapping between words and integer indices.

    The object stored in ``vocab.pkl`` is expected to be an instance of this
    class, so it must be defined under this name before the pickle is loaded.
    """

    def __init__(self):
        self.word2idx = {}  # word -> index
        self.idx2word = {}  # index -> word
        self.idx = 0        # next index to hand out

    def add_word(self, word):
        """Assign *word* the next free index; do nothing if it is already known."""
        if word in self.word2idx:
            return
        self.word2idx[word] = self.idx
        self.idx2word[self.idx] = word
        self.idx += 1

    def __call__(self, word):
        """Return the index of *word*, falling back to the index of ``'<unk>'``."""
        if word in self.word2idx:
            return self.word2idx[word]
        return self.word2idx['<unk>']

    def __len__(self):
        """Return the number of distinct words registered."""
        return len(self.word2idx)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load a vocab.pkl that comes from a source you trust.
with open('./vocab.pkl', 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

## Step 5
- Read the vocabulary into a list ordered by word index, and write it to a plain text file with one word per line

In [None]:
# Vocabulary words ordered by index: the i-th line (0-based) of vocab.txt is
# the word whose id is i.
words = [vocab.idx2word[i] for i in range(len(vocab))]

# Write explicitly as UTF-8 so the output does not depend on the platform's
# default text encoding. No trailing newline is added, matching the format
# this step has always produced.
with open('./vocab.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(words))

## Step 6
- Define the model classes and load the pretrained model files

In [None]:
class EncoderCNN(nn.Module):
    """ResNet-152 backbone followed by a linear projection and batch norm.

    The backbone's final fully-connected classifier is dropped; the remaining
    features are flattened and projected to ``embed_size`` dimensions.
    """

    def __init__(self, embed_size):
        """Build the encoder.

        Args:
            embed_size: dimensionality of the produced feature vector.
        """
        super().__init__()
        # Fetches ImageNet weights (torchvision caches them after first use).
        backbone = models.resnet152(pretrained=True)
        # Keep every child module except the last one (the fc classifier).
        *feature_layers, _classifier = backbone.children()
        self.resnet = nn.Sequential(*feature_layers)
        self.linear = nn.Linear(backbone.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Map a batch of images to feature vectors.

        The backbone runs under ``torch.no_grad()``, so only ``linear`` and
        ``bn`` take part in autograd.

        Args:
            images: image batch; presumably (batch, 3, H, W) — TODO confirm.

        Returns:
            Tensor of shape (batch, embed_size).
        """
        with torch.no_grad():
            pooled = self.resnet(images)
        flat = pooled.reshape(pooled.size(0), -1)
        return self.bn(self.linear(flat))


class DecoderRNN(nn.Module):
    """LSTM caption decoder that consumes an image feature vector."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Create the embedding, LSTM and output-projection layers.

        Args:
            embed_size: size of the word embeddings; the image feature vector
                fed as the first LSTM input must have this size too.
            hidden_size: LSTM hidden-state size.
            vocab_size: number of tokens in the vocabulary.
            num_layers: number of stacked LSTM layers.
            max_seq_length: number of tokens emitted by greedy decoding.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # Attribute spelled "seg" (not "seq"); kept as-is for existing callers.
        self.max_seg_length = max_seq_length

    def forward2(self, features, captions, lengths):
        """Teacher-forced decoding (training path).

        The image feature is prepended as timestep 0, followed by the embedded
        caption tokens; the sequence is then packed using ``lengths``, which
        must be in decreasing order (pack_padded_sequence's default).

        Returns:
            Vocabulary logits for every non-padded timestep, shape
            (sum(lengths), vocab_size).
        """
        token_vecs = self.embed(captions)
        sequence = torch.cat((features.unsqueeze(1), token_vecs), 1)
        packed_seq = pack_padded_sequence(sequence, lengths, batch_first=True)
        packed_out, _ = self.lstm(packed_seq)
        # packed_out[0] is the PackedSequence's flat data tensor.
        return self.linear(packed_out[0])

    def forward(self, features, states=None):
        """Greedy decoding: feed the highest-scoring token back at each step.

        Args:
            features: image features, (batch, embed_size).
            states: optional initial LSTM (h, c) state; None lets the LSTM
                start from zeros.

        Returns:
            Token ids of shape (batch, max_seg_length).
        """
        step_input = features.unsqueeze(1)                    # (batch, 1, embed_size)
        tokens = []
        for _ in range(self.max_seg_length):
            step_out, states = self.lstm(step_input, states)   # (batch, 1, hidden_size)
            logits = self.linear(step_out.squeeze(1))          # (batch, vocab_size)
            best = logits.max(1)[1]                            # (batch,)
            tokens.append(best)
            step_input = self.embed(best).unsqueeze(1)         # (batch, 1, embed_size)
        return torch.stack(tokens, 1)

In [None]:
# Build models. The sizes must match those used when the checkpoints were
# trained; otherwise load_state_dict raises a size-mismatch error.
encoder = EncoderCNN(256).eval()  # eval mode (batchnorm uses moving mean/variance)
decoder = DecoderRNN(256, 512, len(words), 1).eval()  # len(words) = 9956 in our case

# Load the trained model parameters. map_location='cpu' lets checkpoints that
# were saved from a GPU load on a machine without CUDA.
encoder.load_state_dict(torch.load('./encoder-5-3000.pkl', map_location='cpu'))
decoder.load_state_dict(torch.load('./decoder-5-3000.pkl', map_location='cpu'))

## Step 7
- Load and transform a sample image (`test.jpg`), which is used as the example input when exporting the models

In [None]:
# Sample input used to trace the models for ONNX export.
transform = transforms.Compose([
    transforms.ToTensor(),
    # ImageNet per-channel mean / standard deviation.
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
image = transform(
    Image.open('./test.jpg').convert('RGB').resize((224, 224), Image.LANCZOS)
).unsqueeze(0)  # add a batch dimension -> (1, 3, 224, 224)

## Step 8
- Convert the models to ONNX

In [None]:
# Export the encoder, traced with the sample image.
input_names = ['image']
output_names = ['feature']

torch.onnx.export(encoder, image, 'encoder.onnx', opset_version=12,
                  input_names=input_names, output_names=output_names, do_constant_folding=True)

# Sample encoder output -> decoder input. It is only an example tensor for
# tracing, so compute it under no_grad instead of recording an autograd graph.
with torch.no_grad():
    feature = encoder(image)

# Export the decoder. Its greedy loop runs a fixed number of steps
# (max_seg_length), which is unrolled into the graph when traced.
input_names = ['feature']
output_names = ['sample_ids']
torch.onnx.export(decoder, feature, 'decoder.onnx', opset_version=12,
                  input_names=input_names, output_names=output_names, do_constant_folding=True)

## Step 9
- Delete the original model and vocab files (`encoder-5-3000.pkl`, `decoder-5-3000.pkl`, `vocab.pkl`)