In [2]:
import os
import torch
import pickle
import numpy as np
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

In [3]:
#sample.py
from torchvision import transforms 
from build_vocab import Vocabulary

In [4]:
#train.py
from data_loader import get_loader

In [5]:
def load_image(image_path, transform=None):
    """Load an image from disk as a 224x224 RGB image.

    Args:
        image_path: Path of the image file to open.
        transform: Optional callable applied to the resized PIL image. Its
            result must support ``.unsqueeze(0)`` (e.g. a tensor), which is
            used to add a leading batch dimension.

    Returns:
        The resized RGB PIL image when ``transform`` is None, otherwise
        ``transform(image).unsqueeze(0)``.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # releases it once the pixel data has been copied by convert().
    with Image.open(image_path) as raw_image:
        # Force three channels: grayscale / RGBA / palette files would
        # otherwise break a 3-channel Normalize and the CNN input layer.
        image = raw_image.convert('RGB')
    image = image.resize([224, 224], Image.LANCZOS)

    if transform is not None:
        image = transform(image).unsqueeze(0)

    return image

In [6]:
class EncoderCNN(nn.Module):
    """Image encoder: frozen pretrained ResNet-152 backbone plus a trainable head.

    The backbone (every ResNet child except the final fc layer) is run under
    ``torch.no_grad()``; only ``linear`` and ``bn`` receive gradients.
    """

    def __init__(self, embed_size):
        """Load the pretrained ResNet-152 and replace top fc layer.

        Args:
            embed_size: Size of the feature vector produced per image.
        """
        super(EncoderCNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]      # delete the last fc layer.
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        # no_grad() does not stop BatchNorm layers from updating their running
        # statistics in train mode, so the "frozen" backbone would silently
        # drift. Keep it in eval mode from the start.
        self.resnet.eval()

    def train(self, mode=True):
        """Switch the trainable head between train/eval; the backbone stays in eval.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super(EncoderCNN, self).train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Extract feature vectors from input images.

        Args:
            images: Batch of images accepted by the ResNet backbone.

        Returns:
            Tensor with one ``embed_size`` feature vector per image.
        """
        with torch.no_grad():
            features = self.resnet(images)
        # Flatten everything after the batch dimension before the linear head.
        features = features.reshape(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features
    
class DecoderRNN(nn.Module):
    """LSTM caption decoder that is primed with an image feature vector."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Create the embedding, LSTM and output projection layers.

        Args:
            embed_size: Dimension of word embeddings (and of the image features).
            hidden_size: LSTM hidden state size.
            vocab_size: Number of tokens in the vocabulary.
            num_layers: Number of stacked LSTM layers.
            max_seq_length: Number of tokens generated by ``sample``.
        """
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)
        self.max_seg_length = max_seq_length

    def forward(self, features, captions, lengths):
        """Run the decoder over ground-truth captions.

        The image feature is prepended as the first time step, the sequence is
        packed according to ``lengths``, run through the LSTM, re-padded and
        projected onto the vocabulary.

        Returns:
            Padded logits of shape (batch, padded_length, vocab_size).
        """
        image_step = features.unsqueeze(1)
        word_steps = self.embed(captions)
        sequence = torch.cat((image_step, word_steps), dim=1)

        packed_input = pack_padded_sequence(sequence, lengths, batch_first=True)
        packed_hidden, _ = self.lstm(packed_input)
        padded_hidden, _ = pad_packed_sequence(packed_hidden, batch_first=True)

        return self.linear(padded_hidden)

    def sample(self, features, states=None):
        """Greedily decode ``max_seg_length`` token ids for each image feature."""
        token_ids = []
        step_input = features.unsqueeze(1)                        # (batch_size, 1, embed_size)
        for _ in range(self.max_seg_length):
            step_hidden, states = self.lstm(step_input, states)   # (batch_size, 1, hidden_size)
            logits = self.linear(step_hidden.squeeze(1))          # (batch_size, vocab_size)
            next_token = logits.max(1)[1]                         # (batch_size)
            token_ids.append(next_token)
            # The chosen token becomes the next step's input.
            step_input = self.embed(next_token).unsqueeze(1)      # (batch_size, 1, embed_size)
        return torch.stack(token_ids, 1)                          # (batch_size, max_seq_length)

## Notes

 - The vocabulary contains 9956 words.
 - The vocabulary is wrapped in a vocab class with `word_to_idx` and `idx_to_word` attributes (dicts), built using pycocotools.
 - A minimum word-count threshold of 4 is used as the default value to identify and remove rare words.

# Let's try some model training

###  Define parameters and filepaths

In [7]:
class Args:
    """Static configuration for the training run (paths and hyper-parameters)."""

    # --- paths ---
    model_path = 'nb_models'                                    # checkpoint directory
    vocab_path = 'data/vocab.pkl'                               # pickled vocabulary
    image_dir = 'data/resized2014'                              # training images
    caption_path = 'data/annotations/captions_train2014.json'   # caption annotations

    # --- preprocessing ---
    crop_size = 224

    # --- model ---
    embed_size = 256
    hidden_size = 512
    num_layers = 1

    # --- optimisation / data loading ---
    num_epochs = 1
    batch_size = 128
    num_workers = 4
    learning_rate = 0.001

    # --- logging / checkpointing (in steps) ---
    log_step = 10
    save_step = 1000


args = Args()

In [8]:
# Create the checkpoint directory; exist_ok avoids the check-then-create race
# and makes the cell safe to re-run.
os.makedirs(args.model_path, exist_ok=True)

In [9]:
# Use the first CUDA GPU when one is present, otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Image preprocessing, normalization for the pretrained resnet

In [10]:
# Training-time preprocessing: random crop + horizontal flip for augmentation,
# then tensor conversion and per-channel normalization for the pretrained ResNet.
preprocessing_steps = [
    transforms.RandomCrop(size=args.crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
]
transform = transforms.Compose(preprocessing_steps)

### Load vocabulary wrapper

In [11]:
# The vocabulary is stored as a pickle; only unpickle files you trust.
with open(args.vocab_path, mode='rb') as vocab_file:
    vocab = pickle.load(vocab_file)
print("Size of Vocabulary: %d" % len(vocab))

Size of Vocabulary: 9956


### Build data loader

In [12]:
%%time
data_loader = get_loader(args.image_dir, args.caption_path, vocab, 
                         transform, args.batch_size,
                         shuffle=True, num_workers=args.num_workers)

loading annotations into memory...
Done (t=1.86s)
creating index...
index created!
CPU times: user 708 ms, sys: 168 ms, total: 876 ms
Wall time: 2.15 s


### Build the models

In [13]:
# Instantiate both networks and move them to the selected device.
encoder = EncoderCNN(embed_size=args.embed_size).to(device)
decoder = DecoderRNN(
    embed_size=args.embed_size,
    hidden_size=args.hidden_size,
    vocab_size=len(vocab),
    num_layers=args.num_layers,
).to(device)

### Loss and optimizer

In [14]:
criterion = nn.CrossEntropyLoss()
# Only the decoder and the encoder's linear/bn head are optimised;
# the ResNet backbone is not included.
params = [
    *decoder.parameters(),
    *encoder.linear.parameters(),
    *encoder.bn.parameters(),
]
optimizer = torch.optim.Adam(params, lr=args.learning_rate)

In [15]:
%%time
for i, (images, captions, lengths) in enumerate(data_loader):
    print(i)
    if i > 0:
        break

0
1
CPU times: user 28 ms, sys: 142 ms, total: 170 ms
Wall time: 17.9 s


20.46 p.m.

In [16]:
# Summarise the shapes of the mini-batch fetched above.
batch_summary = (
    ("input -> minibatch of images: ", images.shape),
    ("input -> minibatch of captions: ", captions.shape),
    ("input -> minibatch of lengths: ", len(lengths)),
)
for label, value in batch_summary:
    print(label, value)

input -> minibatch of images:  torch.Size([128, 3, 224, 224])
input -> minibatch of captions:  torch.Size([128, 27])
input -> minibatch of lengths:  128


In [17]:
# Move the batch to the device and pack the captions by their lengths.
images, captions = images.to(device), captions.to(device)
targets = pack_padded_sequence(captions, lengths, batch_first=True)
print("target -> :", targets.data.shape)

target -> : torch.Size([1693])


In [18]:
# Forward, backward and optimize
features = encoder(images)
outputs = decoder(features, captions, lengths)

In [19]:
# Encoder feature shape, then decoder output shape, one per line.
print(features.shape, outputs.shape, sep="\n")

torch.Size([128, 256])
torch.Size([128, 27, 9956])


In [20]:
# Pack the padded decoder outputs so they line up with the packed targets,
# then evaluate the loss on the packed data.
packed = pack_padded_sequence(outputs, lengths, batch_first=True)
packed_logits = packed.data
print(packed_logits.shape)
print(criterion(packed_logits, targets.data))

torch.Size([1693, 9956])
tensor(9.2124, device='cuda:0', grad_fn=<NllLossBackward>)


In [21]:
from fede.distance_senteces import Distance_Sentences
from fede.distance_image import Distance_Image

In [23]:
class Discriminator(nn.Module):
    """
    Scores sets of captions against image features.

    Captions are embedded with a ``CaptioningModel``; the last time step of
    each embedding is kept and the batch is grouped into sets of ``set_size``
    sentences. ``Distance_Sentences`` and ``Distance_Image`` turn each set
    (and the image features) into distance features, which are concatenated
    and projected onto 2 classes with a log-softmax.

    NOTE(review): ``CaptioningModel`` is not imported in the visible part of
    this notebook -- it must be in scope before this class is instantiated.

    Args:
        word_to_idx: Vocabulary mapping; only its length and the value passed
            on to ``CaptioningModel`` are used here.
        input_Dim, wordvec_Dim, hidden_Dim, num_layers: Forwarded to
            ``CaptioningModel``.
        N, O: Forwarded to both distance layers; ``O`` also sizes the
            projection input, ``(set_size + 1) * O``.
        image_input: Forwarded to ``Distance_Image``.
        set_size: Number of sentences grouped into one set.
        use_cuda: Request CUDA; honoured only when CUDA is available.
        device: Device used when CUDA is both requested and available.
    """

    def __init__(self, 
                 word_to_idx, 
                 input_Dim=1, 
                 wordvec_Dim=128,
                 hidden_Dim=128,
                 num_layers=1,
                 N=128, 
                 O=128, 
                 image_input=512, 
                 set_size=10, 
                 use_cuda = False,
                 device = torch.device("cpu")):
        
        super(Discriminator, self).__init__()

        # Always define use_cuda: it was previously set only on the CUDA path,
        # so the default (CPU) construction raised AttributeError below.
        self.use_cuda = bool(use_cuda) and torch.cuda.is_available()
        self.device = device if self.use_cuda else torch.device("cpu")
            
        self.sentence_embedding= CaptioningModel(word_to_idx, 
                                                 input_dim=input_Dim, 
                                                 wordvec_dim=wordvec_Dim,
                                                 num_layers = num_layers,
                                                 hidden_dim=hidden_Dim, 
                                                 use_cuda = self.use_cuda, 
                                                 device = self.device)         
        
        vocab_size=len(word_to_idx)
        self.distance_layer_sentences = Distance_Sentences(vocab_size, N, O)
        self.distance_layer_images = Distance_Image(vocab_size, N, O, image_input)
        
        self.set_size = set_size
        self.projection = nn.Linear((self.set_size+1)*O, 2)
        
    def forward(self, captions, features):
        """Return log-probabilities over 2 classes for each set of captions.

        Args:
            captions: 2-D tensor of captions; its first dimension is the
                number of sentences.
            features: Image features passed to the image distance layer.

        Returns:
            Log-softmax (over dim 1) of the 2-way projection.

        Raises:
            ValueError: If the number of embedded sentences is not a multiple
                of ``set_size``.
        """
        nsamples, _ = captions.shape
        
        # Placeholder (all-zero) first input for the sentence embedding model.
        ft = torch.zeros(nsamples,1).to(self.device)
        print("input to embedding", ft.shape)
        S = self.sentence_embedding.forward(ft, captions)
        print("Sentence Embedding Shape", S.shape)
        # Keep only the last time step of every sentence embedding.
        S = S[:,-1,:]
        print("Sentence Embedding Modified Shape", S.shape)
        
        # Without this check the view below either fails with an opaque shape
        # error or silently mixes sentences from different sets.
        if S.shape[0] % self.set_size != 0:
            raise ValueError(
                "number of sentences (%d) must be a multiple of set_size (%d)"
                % (S.shape[0], self.set_size))
        S = S.view(S.shape[0] // self.set_size, self.set_size, -1)
        print(S.shape)
        
        o_sentence = self.distance_layer_sentences.forward(S)
        print("sentence distance", o_sentence.shape)

        o_image = self.distance_layer_images.forward(S,features)
        print("image distance", o_image.shape)
        
        o = torch.cat((o_image, o_sentence), 1).to(self.device)        
        print("distances concatenated", o.shape)
        
        D = nn.functional.log_softmax(self.projection(o),dim=1)
        print("logmax performed", D.shape)
        
        return D

### Train the models


In [1]:
total_step = len(data_loader)
for epoch in range(args.num_epochs):
    for i, (images, captions, lengths) in enumerate(data_loader):
        # Set mini-batch dataset
        images = images.to(device)
        captions = captions.to(device)
        # Packed targets: one entry per real (non-padded) caption token.
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

        # Forward, backward and optimize
        features = encoder(images)
        outputs = decoder(features, captions, lengths)
        # The decoder returns padded (batch, max_len, vocab) logits, so they
        # must be packed the same way as the targets before computing the
        # loss; feeding the padded tensor to the criterion is a shape mismatch.
        packed_outputs = pack_padded_sequence(outputs, lengths, batch_first=True)[0]
        loss = criterion(packed_outputs, targets)
        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()
        optimizer.step()

        # Print log info
        if i % args.log_step == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                  .format(epoch, args.num_epochs, i, total_step, loss.item(), np.exp(loss.item()))) 

        # Save the model checkpoints
        if (i+1) % args.save_step == 0:
            torch.save(decoder.state_dict(), os.path.join(
                args.model_path, 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
            torch.save(encoder.state_dict(), os.path.join(
                args.model_path, 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))