# Image Captioning

---

## Training Setup

Customize the training of your CNN-RNN model by specifying hyperparameters and setting other options that are important to the training procedure.  

Set the following variables:
- `batch_size` - the batch size of each training batch, i.e. the number of image-caption pairs used to update the model weights in each training step. 
- `vocab_threshold` - the minimum word count threshold.  Note that a larger threshold will result in a smaller vocabulary, whereas a smaller threshold will include rarer words and result in a larger vocabulary.  
- `vocab_from_file` - a Boolean that decides whether to load the vocabulary from file. 
- `embed_size` - the dimensionality of the image and word embeddings.  
- `hidden_size` - the number of features in the hidden state of the RNN decoder.  
- `num_epochs` - the number of epochs to train the model.  We recommend that you set `num_epochs=3`, but feel free to increase or decrease this number as you wish.  [This paper](https://arxiv.org/pdf/1502.03044.pdf) trained a captioning model on a single state-of-the-art GPU for 3 days, but you'll soon see that you can get reasonable results in a matter of a few hours!  (_But of course, if you want your model to compete with current research, you will have to train for much longer._)
- `save_every` - determines how often to save the model weights.  We recommend that you set `save_every=1`, to save the model weights after each epoch.  This way, after the `i`th epoch, the encoder and decoder weights will be saved in the `models/` folder as `encoder-i.pkl` and `decoder-i.pkl`, respectively.
- `print_every` - determines how often to print the batch loss to the Jupyter notebook while training.  Note that the loss **will not** decrease monotonically during training; this is perfectly fine and completely expected, since every step uses a different randomly sampled batch.  You are encouraged to keep this at its default value of `100` to avoid clogging the notebook, but feel free to change it.
- `log_file` - the name of the text file that records the loss and perplexity at every training step.
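Both the printed statistics and the log file report perplexity, which the training loop computes as `np.exp(loss.item())`: the exponential of the mean per-token cross-entropy loss. A minimal standard-library sketch of that relationship (the helper name `perplexity` is illustrative, not part of the project code):

```python
import math

def perplexity(mean_cross_entropy: float) -> float:
    """Return exp(mean per-token cross-entropy), with the loss in nats as nn.CrossEntropyLoss reports it."""
    return math.exp(mean_cross_entropy)

# A perfect model (loss 0) has perplexity 1.
print(perplexity(0.0))                       # 1.0
# A uniform guess over V words has loss log(V), hence perplexity V.
print(round(perplexity(math.log(10)), 6))    # 10.0
# Loss value from the first logged step of the training run.
print(round(perplexity(3.7029), 2))          # 40.56
```

A perplexity of `V` can be read as the model being, on average, as uncertain as a uniform guess among `V` words.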

If you're not sure where to begin to set some of the values above, you can use [this paper](https://arxiv.org/pdf/1502.03044.pdf) and [this paper](https://arxiv.org/pdf/1411.4555.pdf) for useful guidance! 
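The effect of `vocab_threshold` is easiest to see on a toy corpus. The sketch below counts words and keeps only those that occur at least `vocab_threshold` times; the helper `build_vocab`, the lowercase whitespace tokenizer, and the special-token names are assumptions for illustration, not the project's own vocabulary code.

```python
from collections import Counter

def build_vocab(captions, vocab_threshold):
    """Hypothetical helper: map each word seen at least vocab_threshold times to an index."""
    counts = Counter(word for caption in captions for word in caption.lower().split())
    words = sorted(w for w, c in counts.items() if c >= vocab_threshold)
    # Assumed special tokens; they are always kept regardless of the threshold.
    specials = ["<pad>", "<start>", "<end>", "<unk>"]
    return {word: idx for idx, word in enumerate(specials + words)}

captions = [
    "a dog runs on the grass",
    "a dog plays with a ball",
    "a cat sleeps on the sofa",
]
for threshold in (1, 2, 3):
    # Higher threshold -> rarer words dropped -> smaller vocabulary.
    print(threshold, len(build_vocab(captions, threshold)))
# 1 16
# 2 8
# 3 5
```

Words that fall below the threshold would typically be mapped to the unknown-word token at training time.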

### CNN-RNN architecture

**The CNN encoder** is a pre-trained ResNet, an architecture whose residual (skip) connections help gradients flow through very deep networks and mitigate the vanishing-gradient problem. This pre-trained model encodes each image into a feature vector. Other pre-trained models are available [here](https://pytorch.org/docs/master/torchvision/models.html).

**The RNN decoder** follows the architecture described in the image-captioning papers linked above.
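The training loop passes the full `captions` tensor both to `decoder(features, captions)` and to the loss as the target, so the decoder must emit one prediction per caption token. One layout that satisfies this, in the spirit of the Show and Tell paper linked above, feeds the image features as the first input, followed by the caption without its last token. Since `model.py` is not shown, this layout is an assumption; the sketch uses plain Python lists and a hypothetical helper name.

```python
def decoder_inputs_and_targets(image_feature, caption_tokens):
    """Hypothetical helper: pair each decoder input with the token it must predict.

    Assumed layout: the image feature is the first input, then every caption
    token except the last, so inputs and targets have the same length.
    """
    inputs = [image_feature] + caption_tokens[:-1]
    targets = caption_tokens
    assert len(inputs) == len(targets)
    return list(zip(inputs, targets))

caption = ["<start>", "a", "dog", "runs", "<end>"]
for step, (x, y) in enumerate(decoder_inputs_and_targets("IMG", caption)):
    print(step, x, "->", y)
# 0 IMG -> <start>
# 1 <start> -> a
# 2 a -> dog
# 3 dog -> runs
# 4 runs -> <end>
```

Because every step yields a prediction aligned with a caption token, the outputs can be flattened to `(batch * length, vocab_size)` and compared with the flattened captions, as the loss line in the training loop does.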

```python
import torch
import torch.nn as nn
from torchvision import transforms
import sys
sys.path.append('/opt/cocoapi/PythonAPI')
from pycocotools.coco import COCO
from data_loader import get_loader
from model import EncoderCNN, DecoderRNN
import math
import nltk
nltk.download('punkt')


# Select values for the Python variables below.
batch_size = 32                # batch size
vocab_threshold = 5            # minimum word count threshold
vocab_from_file = True         # if True, load existing vocab file
embed_size = 512               # dimensionality of image and word embeddings
hidden_size = 512              # number of features in hidden state of the RNN decoder
num_epochs = 3                 # number of training epochs
save_every = 1                 # determines frequency of saving model weights
print_every = 100              # print the batch loss on a new line every print_every steps
log_file = 'training_log.txt'  # name of file with saved training loss and perplexity

# Amend the image transform below.
transform_train = transforms.Compose([ 
    transforms.Resize(256),                          # smaller edge of image resized to 256
    transforms.RandomCrop(224),                      # get 224x224 crop from random location
    transforms.RandomHorizontalFlip(),               # horizontally flip image with probability=0.5
    transforms.ToTensor(),                           # convert the PIL Image to a tensor
    transforms.Normalize((0.485, 0.456, 0.406),      # normalize image for pre-trained model
                         (0.229, 0.224, 0.225))])

# Build data loader.
data_loader = get_loader(transform=transform_train,
                         mode='train',
                         batch_size=batch_size,
                         vocab_threshold=vocab_threshold,
                         vocab_from_file=vocab_from_file)

# The size of the vocabulary.
vocab_size = len(data_loader.dataset.vocab)

# Initialize the encoder and decoder. 
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Move models to GPU if CUDA is available. 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
decoder.to(device)

# Define the loss function on the same device as the models.
criterion = nn.CrossEntropyLoss().to(device)

# Specify the learnable parameters of the model: the whole decoder plus the
# encoder's final embedding layer. The pre-trained ResNet layers are not passed
# to the optimizer, so their weights are not updated.
params = list(decoder.parameters()) + list(encoder.embed.parameters())

# Define the optimizer.
optimizer = torch.optim.Adam(params, lr=0.001)

# Set the total number of training steps per epoch (captions / batch size, rounded up).
total_step = math.ceil(len(data_loader.dataset.caption_lengths) / data_loader.batch_sampler.batch_size)
```
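The `Normalize` step uses the per-channel ImageNet mean and standard deviation that torchvision's pre-trained models expect. After `ToTensor()` scales 8-bit pixel values to `[0, 1]`, each channel is mapped to `(x - mean) / std`. The standard-library sketch below reproduces only that arithmetic for single RGB pixels; the helper name is hypothetical, and it ignores the channel-first layout that `ToTensor()` also produces.

```python
MEAN = (0.485, 0.456, 0.406)  # per-channel mean used in transform_train
STD = (0.229, 0.224, 0.225)   # per-channel std used in transform_train

def to_tensor_and_normalize(rgb_uint8):
    """Mimic ToTensor() + Normalize() on one RGB pixel given as 0-255 integers."""
    scaled = [c / 255.0 for c in rgb_uint8]                     # ToTensor: [0, 255] -> [0.0, 1.0]
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]  # Normalize: (x - mean) / std

print([round(v, 4) for v in to_tensor_and_normalize((255, 255, 255))])  # [2.2489, 2.4286, 2.64]
print([round(v, 4) for v in to_tensor_and_normalize((0, 0, 0))])        # [-2.1179, -2.0357, -1.8044]
```

Normalized inputs are therefore roughly zero-centred rather than confined to `[0, 1]`, matching the statistics the ResNet weights were trained with.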

```
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...
Done (t=1.03s)
creating index...
index created!
Obtaining caption lengths...
```
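The `total_step` formula in the cell above divides the number of training captions by `batch_size` and rounds up. The data loader in this notebook works over 414113 training captions, so with `batch_size = 32` the formula gives the 12942 steps per epoch that appear in the training log. A small check of the arithmetic (the helper name `steps_per_epoch` is hypothetical):

```python
import math

def steps_per_epoch(num_captions: int, batch_size: int) -> int:
    """Same formula as total_step: number of captions divided by batch size, rounded up."""
    return math.ceil(num_captions / batch_size)

print(steps_per_epoch(414113, 32))  # 12942
print(steps_per_epoch(414113, 64))  # 6471
```

Because every step samples a fresh random batch (see the comments in the training loop), an "epoch" here means `total_step` sampled batches rather than a guaranteed single pass over every caption.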

```python
import torch.utils.data as data
import numpy as np
import os
import requests
import time

# Make sure the checkpoint folder exists before the first save.
os.makedirs('./models', exist_ok=True)

# Open the training log file.
f = open(log_file, 'w')

# Keep-alive requests for the Udacity GPU workspace only; remove them when training elsewhere.
old_time = time.time()
response = requests.request("GET", 
                            "http://metadata.google.internal/computeMetadata/v1/instance/attributes/keep_alive_token", 
                            headers={"Metadata-Flavor":"Google"})

for epoch in range(1, num_epochs+1):
    
    for i_step in range(1, total_step+1):
        
        # Ping the workspace keep-alive endpoint at most once a minute.
        if time.time() - old_time > 60:
            old_time = time.time()
            requests.request("POST", 
                             "https://nebula.udacity.com/api/v1/remote/keep-alive", 
                             headers={'Authorization': "STAR " + response.text})
        
        # Randomly sample a caption length, and sample indices with that length.
        indices = data_loader.dataset.get_train_indices()
        # Create and assign a batch sampler to retrieve a batch with the sampled indices.
        new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        data_loader.batch_sampler.sampler = new_sampler
        
        # Obtain the batch.
        images, captions = next(iter(data_loader))

        # Move batch of images and captions to GPU if CUDA is available.
        images = images.to(device)
        captions = captions.to(device)
        
        # Zero the gradients.
        decoder.zero_grad()
        encoder.zero_grad()
        
        # Pass the inputs through the CNN-RNN model.
        features = encoder(images)
        outputs = decoder(features, captions)
        
        # Calculate the batch loss over all predicted tokens.
        loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))
        
        # Backward pass.
        loss.backward()
        
        # Update the parameters in the optimizer.
        optimizer.step()
            
        # Get training statistics.
        stats = 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f' % (epoch, num_epochs, i_step, total_step, loss.item(), np.exp(loss.item()))
        
        # Print training statistics (on same line).
        print('\r' + stats, end="")
        sys.stdout.flush()
        
        # Print training statistics to file.
        f.write(stats + '\n')
        f.flush()
        
        # Print training statistics (on different line).
        if i_step % print_every == 0:
            print('\r' + stats)
            
    # Save the weights.
    if epoch % save_every == 0:
        torch.save(decoder.state_dict(), os.path.join('./models', 'decoder-%d.pkl' % epoch))
        torch.save(encoder.state_dict(), os.path.join('./models', 'encoder-%d.pkl' % epoch))

# Close the training log file.
f.close()
```

```
Epoch [1/3], Step [100/12942], Loss: 3.7029, Perplexity: 40.5637
Epoch [1/3], Step [200/12942], Loss: 3.2600, Perplexity: 26.0492
Epoch [1/3], Step [300/12942], Loss: 3.1407, Perplexity: 23.1211
Epoch [1/3], Step [400/12942], Loss: 3.4171, Perplexity: 30.4800
Epoch [1/3], Step [500/12942], Loss: 2.9796, Perplexity: 19.6803
Epoch [1/3], Step [600/12942], Loss: 3.0412, Perplexity: 20.9305
Epoch [1/3], Step [700/12942], Loss: 3.0196, Perplexity: 20.4826
Epoch [1/3], Step [800/12942], Loss: 3.0193, Perplexity: 20.4765
Epoch [1/3], Step [900/12942], Loss: 3.1613, Perplexity: 23.6009
Epoch [1/3], Step [1000/12942], Loss: 2.8121, Perplexity: 16.6450
Epoch [1/3], Step [1100/12942], Loss: 2.6062, Perplexity: 13.5477
Epoch [1/3], Step [1200/12942], Loss: 2.4563, Perplexity: 11.6614
Epoch [1/3], Step [1300/12942], Loss: 3.3484, Perplexity: 28.4565
Epoch [1/3], Step [1400/12942], Loss: 2.9515, Perplexity: 19.1342
Epoch [1/3], Step [1500/12942], Loss: 2.6618, Perplexity: 14.3220
Epoch [1/3], S

Epoch [2/3], Step [11800/12942], Loss: 2.0734, Perplexity: 7.9514
Epoch [2/3], Step [11900/12942], Loss: 2.0568, Perplexity: 7.8206
Epoch [2/3], Step [12000/12942], Loss: 1.8501, Perplexity: 6.3607
Epoch [2/3], Step [12100/12942], Loss: 2.0179, Perplexity: 7.5228
Epoch [2/3], Step [12200/12942], Loss: 2.1506, Perplexity: 8.5900
Epoch [2/3], Step [12300/12942], Loss: 2.1804, Perplexity: 8.8500
Epoch [2/3], Step [12400/12942], Loss: 2.1925, Perplexity: 8.9576
Epoch [2/3], Step [12500/12942], Loss: 1.8679, Perplexity: 6.4747
Epoch [2/3], Step [12600/12942], Loss: 2.0025, Perplexity: 7.4076
Epoch [2/3], Step [12700/12942], Loss: 2.0324, Perplexity: 7.6327
Epoch [2/3], Step [12800/12942], Loss: 1.9940, Perplexity: 7.3450
Epoch [2/3], Step [12900/12942], Loss: 2.0166, Perplexity: 7.5131
Epoch [3/3], Step [100/12942], Loss: 2.2074, Perplexity: 9.0924
Epoch [3/3], Step [200/12942], Loss: 2.3474, Perplexity: 10.4586
Epoch [3/3], Step [300/12942], Loss: 2.1566, Perplexity: 8.6414
```
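Every batch in the training loop comes from `data_loader.dataset.get_train_indices()`, which, per the code comment, samples a caption length and then indices of captions with that length. All captions in a batch therefore share one length and can be stacked into a single tensor without padding. The sketch below mimics the idea in plain Python; the helper name, weighting lengths by how often they occur, and sampling indices with replacement are assumptions about `data_loader.py`, which is not shown here.

```python
import random

def sample_same_length_batch(caption_lengths, batch_size, rng):
    """Hypothetical sketch of length-bucketed batch sampling.

    1. Draw a caption length; choosing from the per-caption list weights each
       length by how often it occurs.
    2. Draw batch_size caption indices with exactly that length (with replacement).
    """
    sel_length = rng.choice(caption_lengths)
    candidates = [i for i, n in enumerate(caption_lengths) if n == sel_length]
    return sel_length, rng.choices(candidates, k=batch_size)

caption_lengths = [11, 12, 11, 13, 12, 11, 14, 11]  # toy data: tokens per caption
rng = random.Random(0)
length, indices = sample_same_length_batch(caption_lengths, batch_size=4, rng=rng)
print(length, indices)
assert all(caption_lengths[i] == length for i in indices)
```

In the notebook, the sampled indices are then wrapped in a `SubsetRandomSampler` and assigned to the loader's batch sampler, so `next(iter(data_loader))` returns exactly that batch.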