In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import cv2
from collections import Counter
import matplotlib.pyplot as plt 
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import torch.optim as torch_optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import random
import os

In [2]:
class Encoder(nn.Module):
    """Image encoder: a truncated ResNet-34 followed by a linear projection.

    The first eight child modules of a torchvision ``resnet34`` (built with
    ``pretrained=True``) are kept and split into two sequential stages. The
    resulting feature map is passed through ReLU, globally average-pooled to
    1x1, flattened, projected to ``emb_dim`` values and batch-normalised.
    """

    def __init__(self, emb_dim):
        """Create the backbone stages and the projection head.

        Args:
            emb_dim: Number of features in the embedding returned for each
                input sample.
        """
        super().__init__()
        backbone = models.resnet34(pretrained=True)
        stages = list(backbone.children())[:8]
        # The backbone is kept as two separately addressable stages:
        # children 0-5 and children 6-7.
        self.features1 = nn.Sequential(*stages[:6])
        self.features2 = nn.Sequential(*stages[6:])
        # Collapse whatever spatial size is left down to a single cell.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Input width is taken from the ResNet's own (unused) classifier layer.
        self.linear = nn.Linear(backbone.fc.in_features, emb_dim)
        self.bn = nn.BatchNorm1d(emb_dim, momentum=0.01)

    def forward(self, x):
        """Encode a batch into embeddings of shape ``(batch, emb_dim)``.

        Args:
            x: Input batch fed to the first backbone stage.

        Returns:
            The batch-normalised linear projection of the pooled features.
        """
        feature_map = F.relu(self.features2(self.features1(x)))
        pooled = self.pool(feature_map)
        # Drop the 1x1 spatial dims: (batch, C, 1, 1) -> (batch, C).
        flat = pooled.view(pooled.shape[0], -1)
        return self.bn(self.linear(flat))

In [4]:
# Smoke test: build an encoder that emits 20-dimensional embeddings.
# Constructing it instantiates the pretrained ResNet-34 backbone.
enc = Encoder(emb_dim=20)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /home/ubuntu/.cache/torch/checkpoints/resnet34-333f7ec4.pth


HBox(children=(FloatProgress(value=0.0, max=87306240.0), HTML(value='')))




In [7]:
class Decoder(nn.Module):
    """LSTM caption decoder driven by an image feature vector.

    The image feature is fed to the LSTM as the first time step, followed by
    the embedded caption tokens; a linear layer maps every LSTM output to
    vocabulary scores.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers.

        Args:
            embed_dim: Size of the token embeddings. The image features passed
                to ``forward``/``sample`` must have this size too, because they
                are fed to the same LSTM input.
            hidden_dim: Size of the LSTM hidden state.
            vocab_size: Number of tokens in the vocabulary.
            num_layers: Number of stacked LSTM layers.
            max_seq_length: Number of tokens ``sample`` generates.
        """
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_size)
        # The attribute is spelled "seg" (not "seq"); the name is kept as-is so
        # existing code that reads it keeps working.
        self.max_seg_length = max_seq_length

    def forward(self, features, captions, lengths):
        """Decode image feature vectors and generate caption scores.

        Args:
            features: Image features, ``(batch_size, embed_dim)``.
            captions: Token ids, ``(batch_size, seq_len)``.
            lengths: Number of valid time steps per sample, counted on the
                sequence *after* the image feature has been prepended. With
                the default ``enforce_sorted=True`` of
                ``pack_padded_sequence`` the batch must be ordered by
                decreasing length.

        Returns:
            Vocabulary scores for every valid time step in packed order,
            shape ``(sum(lengths), vocab_size)``.
        """
        embeddings = self.embed(captions)
        # Prepend the image feature as time step 0 of every sequence.
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        # Referenced through the already-imported ``nn`` namespace: the bare
        # name ``pack_padded_sequence`` was never imported and raised NameError.
        packed = nn.utils.rnn.pack_padded_sequence(embeddings, lengths, batch_first=True)
        hiddens, _ = self.lstm(packed)
        # ``hiddens`` is a PackedSequence; ``.data`` holds the flattened
        # per-step outputs with the padding removed.
        outputs = self.linear(hiddens.data)
        return outputs

    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search.

        Always produces ``self.max_seg_length`` tokens per sample; there is no
        early stop on an end-of-sequence token.

        Args:
            features: Image features, ``(batch_size, embed_dim)``.
            states: Optional initial LSTM state, forwarded to the first step.

        Returns:
            Token ids of shape ``(batch_size, max_seq_length)``.
        """
        sampled_ids = []
        inputs = features.unsqueeze(1)                           # inputs: (batch_size, 1, embed_dim)
        for _ in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)          # hiddens: (batch_size, 1, hidden_dim)
            outputs = self.linear(hiddens.squeeze(1))            # outputs: (batch_size, vocab_size)
            _, predicted = outputs.max(1)                        # predicted: (batch_size)
            sampled_ids.append(predicted)
            # Feed the chosen token back in as the next input step.
            inputs = self.embed(predicted).unsqueeze(1)          # inputs: (batch_size, 1, embed_dim)
        sampled_ids = torch.stack(sampled_ids, 1)                # sampled_ids: (batch_size, max_seq_length)
        return sampled_ids

In [8]:
# Smoke test: build a small decoder and display its layer summary.
Decoder(embed_dim=20, hidden_dim=50, vocab_size=1000, num_layers=2)

Decoder(
  (embed): Embedding(1000, 20)
  (lstm): LSTM(20, 50, num_layers=2, batch_first=True)
  (linear): Linear(in_features=50, out_features=1000, bias=True)
)