# Training Pipeline

In this notebook, we will train the CNN-RNN model.  


### Parameters & Hyperparameters

Begin by setting the following variables:
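- `batch_size` - the number of image-caption pairs in each training batch. Because the training loop accumulates gradients over `accumulation_steps` batches, the model weights are updated once every few batches rather than after every batch.
- `vocab_threshold` - the minimum number of times a word must occur in the training captions to be added to the vocabulary.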
- `vocab_from_file` - a Boolean that decides whether to load the vocabulary from file. 
- `embed_size` - the dimensionality of the image and word embeddings.  
- `hidden_size` - the number of features in the hidden state of the RNN decoder.  
- `num_epochs` - the number of epochs to train the model.
- `save_every` - how often, in epochs, to save the model weights.
- `print_every` - how often, in training steps, to print the batch loss on its own line in the Jupyter notebook while training.
- `log_file` - the name of the text file to which the loss and perplexity are written at every training step.
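The perplexity written to `log_file` is the exponential of the per-token cross-entropy loss (natural log, as used by `nn.CrossEntropyLoss`), which is what the `np.exp(...)` call in the training loop computes. A minimal sketch that parses one log line in the format produced by that loop and re-derives the perplexity from the loss; `parse_log_line` is a hypothetical helper, not part of the project:

```python
import math
import re

# Matches lines such as "Epoch [3/3], Step [313/313], Loss: 3.1445, Perplexity: 23.2089"
LOG_PATTERN = re.compile(
    r"Epoch \[(\d+)/(\d+)\], Step \[(\d+)/(\d+)\], "
    r"Loss: ([\d.]+), Perplexity:\s*([\d.]+)"
)


def parse_log_line(line):
    """Hypothetical helper: return (epoch, step, loss, perplexity) from one log line."""
    match = LOG_PATTERN.search(line)
    if match is None:
        raise ValueError("not a training-log line: %r" % line)
    epoch, _, step, _, loss, ppl = match.groups()
    return int(epoch), int(step), float(loss), float(ppl)


epoch, step, loss, ppl = parse_log_line(
    "Epoch [3/3], Step [313/313], Loss: 3.1445, Perplexity: 23.2089"
)
# exp(3.1445) is about 23.208; the small gap to the logged value comes from
# the loss being rounded to 4 decimals before it is written.
print(loss, ppl, math.exp(loss))
```

Because the logged loss is rounded, a recomputed perplexity should only be compared with a relative tolerance.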

### CNN-RNN Architecture

The CNN-RNN architecture uses a pre-trained ResNet18 as the encoder to extract image features, followed by a linear layer that projects them to `embed_size=256`. The decoder is an LSTM-based RNN with an embedding layer (vocabulary to `embed_size`), a single-layer LSTM (input size `embed_size`, hidden size 512), and a linear output layer to `vocab_size`. During training, the image feature vector is fed as the first input of the LSTM sequence, followed by the embedded caption words.
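Because the image feature occupies the first time step, the LSTM would see one more input than there are caption tokens. A common way to keep outputs aligned with targets, assumed here since `model.py` is not shown, is to drop the final caption token (typically `<end>`) from the decoder input, so the output at step t is scored against caption token t. The bookkeeping can be sketched with plain lists; `build_decoder_inputs` is a hypothetical name:

```python
def build_decoder_inputs(feature, caption_tokens):
    """Assumed layout: [image feature] + caption[:-1], one entry per LSTM time step."""
    return [feature] + caption_tokens[:-1]


caption = ["<start>", "a", "dog", "runs", "<end>"]
inputs = build_decoder_inputs("IMG", caption)

# Each time step's input is paired with the caption token it should predict.
for t, (x, target) in enumerate(zip(inputs, caption)):
    print(f"t={t}: input={x}, target={target}")
```

Under this layout the decoder output has exactly one prediction per caption token, which is what allows the training loop to flatten `outputs` and `captions` and pass them to a single `CrossEntropyLoss` call.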

### Trainable Parameters

All parameters of the decoder (embedding, LSTM, and linear layers) are trained, together with only the final linear embedding layer of the encoder. This is a reasonable choice: the pre-trained ResNet18 backbone stays frozen to leverage transfer learning, while the trainable embedding layer adapts its features to the decoder's input space.
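The size of this trainable set can be worked out by hand, assuming the layer shapes described above: ResNet18's pooled feature vector has 512 entries, `embed_size=256`, `hidden_size=512`, and the decoder uses a single `nn.LSTM` layer, which in PyTorch holds two bias vectors per layer. `count_trainable` is a hypothetical helper, and `vocab_size` is left as an argument because it depends on `vocab.pkl` and is not printed in this notebook:

```python
def count_trainable(vocab_size, feat_dim=512, embed_size=256, hidden_size=512):
    """Count parameters given to the optimizer (decoder + encoder.embed), assumed shapes."""
    encoder_embed = feat_dim * embed_size + embed_size        # Linear(512 -> 256): weight + bias
    word_embed = vocab_size * embed_size                      # Embedding(V, 256)
    # One LSTM layer: weight_ih (4H x E), weight_hh (4H x H), bias_ih and bias_hh (4H each)
    lstm = 4 * hidden_size * (embed_size + hidden_size) + 2 * 4 * hidden_size
    output = hidden_size * vocab_size + vocab_size            # Linear(512 -> V): weight + bias
    return encoder_embed + word_embed + lstm + output


print(count_trainable(10_000))  # 9398288 for a hypothetical 10,000-word vocabulary
```

Under these assumptions the fixed part is 1,708,288 parameters and each vocabulary word adds 769 more (256 in the word embedding, 513 in the output layer); the frozen ResNet18 backbone contributes none of them.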

### Optimizer

Adam is used because it is widely adopted in image captioning for its adaptive per-parameter learning rates, momentum, and stability on noisy gradients. A learning rate of `lr=0.001` is a standard starting point.
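For intuition only (the notebook simply calls `torch.optim.Adam`), here is a scalar sketch of the standard Adam update with PyTorch's default `betas=(0.9, 0.999)` and `eps=1e-8`. It shows the adaptive part: on early steps the update magnitude stays close to `lr` regardless of the gradient's scale.

```python
import math


def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter at step t (1-based); returns (param, m, v)."""
    m = beta1 * m + (1 - beta1) * grad            # running mean of gradients (momentum)
    v = beta2 * v + (1 - beta2) * grad * grad     # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                  # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = 1.0, 0.0, 0.0
for t, g in enumerate([0.5, 0.5, -2.0], start=1):
    p, m, v = adam_step(p, g, m, v, t)
    print(t, round(p, 6))
```

With a constant gradient of 0.5 the parameter moves by almost exactly `lr` on each of the first two steps (to about 0.999, then 0.998); a gradient of 50 would produce the same first step, which is why Adam is less sensitive to the raw gradient scale than plain SGD.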

```python
import nltk
nltk.download('punkt')
```
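`punkt` is NLTK's tokenizer data, presumably used by the data loader to split captions into words. Since this notebook loads an existing `vocab.pkl` (`vocab_from_file=True`), the vocabulary-building step is not shown. As an assumption about how `vocab_threshold` is typically applied, the sketch below keeps only words that occur at least `vocab_threshold` times and maps all others to an unknown-word token. To avoid external dependencies it replaces `nltk.word_tokenize` with lowercase `str.split()`; `build_vocab`, `encode`, and the special-token names are hypothetical:

```python
from collections import Counter


def build_vocab(captions, vocab_threshold):
    """Map each sufficiently frequent word to an index; rarer words are left out."""
    counts = Counter(word for caption in captions for word in caption.lower().split())
    itos = ["<pad>", "<start>", "<end>", "<unk>"]  # hypothetical special tokens
    itos += sorted(w for w, c in counts.items() if c >= vocab_threshold)
    return {word: idx for idx, word in enumerate(itos)}


def encode(caption, stoi):
    """Wrap a caption in <start>/<end> and replace out-of-vocabulary words with <unk>."""
    unk = stoi["<unk>"]
    body = [stoi.get(w, unk) for w in caption.lower().split()]
    return [stoi["<start>"]] + body + [stoi["<end>"]]


captions = ["A dog runs", "a dog sleeps", "A cat runs"]
vocab = build_vocab(captions, vocab_threshold=2)
print(sorted(vocab))
print(encode("a cat runs", vocab))  # [1, 4, 3, 6, 2]: "cat" occurs once, so it becomes <unk>
```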

```python
import torch
import torch.nn as nn
from torchvision import transforms
import sys
sys.path.append('/mnt/disks/legacy-jupytergpu-data/cocoapi/PythonAPI')
from pycocotools.coco import COCO
from data_loader import get_loader
from model import EncoderCNN, DecoderRNN
import math

batch_size = 64           # batch size
vocab_threshold = 5       # minimum word count threshold
vocab_from_file = True    # if True, load existing vocab file
embed_size = 256          # dimensionality of image and word embeddings
hidden_size = 512         # number of features in hidden state of the RNN decoder
num_epochs = 3            # number of training epochs
save_every = 1            # save model weights every `save_every` epochs
print_every = 100         # print the batch loss on its own line every `print_every` steps
log_file = 'training_log.txt'       # file that stores the loss and perplexity of every step

transform_train = transforms.Compose([
    transforms.Resize(256),                          # smaller edge of image resized to 256
    transforms.RandomCrop(224),                      # get 224x224 crop from random location
    transforms.ToTensor(),                           # convert the PIL Image to a tensor in [0, 1]
    transforms.Normalize((0.485, 0.456, 0.406),      # normalize with the ImageNet statistics
                         (0.229, 0.224, 0.225))])    # expected by the pre-trained model

# subset_size restricts training to 20,000 captions;
# a very large value would increase training time considerably.
data_loader = get_loader(transform=transform_train,
                         mode='train',
                         batch_size=batch_size,
                         vocab_threshold=vocab_threshold,
                         vocab_from_file=vocab_from_file,
                         subset_size=20000)

# The size of the vocabulary.
vocab_size = len(data_loader.dataset.vocab)

# Initialize the encoder and decoder.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Move models to GPU if CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
decoder.to(device)

# Define the loss function on the same device as the models.
criterion = nn.CrossEntropyLoss().to(device)

# Specify the learnable parameters: the whole decoder plus only the encoder's embedding layer.
params = list(decoder.parameters()) + list(encoder.embed.parameters())

# Define the optimizer.
optimizer = torch.optim.Adam(params, lr=0.001)

# Number of training steps per epoch: ceil(number of captions / batch size).
total_step = math.ceil(len(data_loader.dataset.caption_lengths) / data_loader.batch_sampler.batch_size)
```

```text
Vocabulary successfully loaded from vocab.pkl file!
loading annotations into memory...
Done (t=1.02s)
creating index...
index created!
Obtaining caption lengths for the subset...
100%|██████████| 20000/20000 [00:01<00:00, 11810.84it/s]
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/student/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 93.7MB/s]
```
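The `Normalize` step in `transform_train` standardizes each channel as `(x - mean) / std` using the ImageNet channel statistics that the pre-trained ResNet18 expects; `ToTensor` has already scaled pixel values to [0, 1]. A per-pixel sketch of that arithmetic in plain Python (no torchvision; `normalize_pixel` is a hypothetical helper):

```python
MEAN = (0.485, 0.456, 0.406)  # ImageNet RGB means, as in transform_train
STD = (0.229, 0.224, 0.225)   # ImageNet RGB standard deviations


def normalize_pixel(rgb):
    """Apply (x - mean) / std channel by channel to one RGB pixel with values in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))


print(normalize_pixel((0.0, 0.0, 0.0)))  # black pixel
print(normalize_pixel((1.0, 1.0, 1.0)))  # white pixel
```

After normalization, network inputs span roughly -2.1 to 2.6 instead of 0 to 1, matching the input distribution ResNet18 saw during its ImageNet pre-training.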


## Training the Model

Due to hardware and time constraints, this project only needs to demonstrate that the model has learned **_something_** when it generates captions for the test data. We train for 3 epochs and then move to the next notebook in the sequence (**inference.ipynb**) to see how the model performs on the test data.
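Each training step starts by calling `data_loader.dataset.get_train_indices()`, which, according to the comment in the loop, samples a caption length and then indices of captions with that length. Every caption in the batch therefore has the same number of tokens and can be stacked without padding. The data loader source is not shown, so here is a minimal sketch of that idea, with `sample_batch_indices` as a hypothetical stand-in:

```python
import random


def sample_batch_indices(caption_lengths, batch_size, rng=random):
    """Pick one caption length at random, then batch_size indices that share it."""
    # Choosing from the full list weights each length by how often it occurs.
    sel_length = rng.choice(caption_lengths)
    candidates = [i for i, n in enumerate(caption_lengths) if n == sel_length]
    # Sample with replacement so rare lengths can still fill a whole batch.
    return [rng.choice(candidates) for _ in range(batch_size)]


lengths = [10, 12, 10, 11, 10, 12, 13]
indices = sample_batch_indices(lengths, batch_size=4, rng=random.Random(0))
print(indices, {lengths[i] for i in indices})  # the set holds a single length
```

Because batches are drawn at random on every step, the `total_step` steps of an "epoch" cover the captions only in expectation, not exactly once each.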

```python
import os

import numpy as np
import torch.utils.data as data

# Open the training log file.
f = open(log_file, 'w')

# Make sure the checkpoint directory exists before saving weights.
os.makedirs('./models', exist_ok=True)

accumulation_steps = 4  # number of batches whose gradients are summed before each optimizer step
optimizer.zero_grad()   # start from zero gradients

for epoch in range(1, num_epochs+1):

    for i_step in range(1, total_step+1):

        # Randomly sample a caption length, and sample indices with that length.
        indices = data_loader.dataset.get_train_indices()
        # Create and assign a batch sampler to retrieve a batch with the sampled indices.
        new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        data_loader.batch_sampler.sampler = new_sampler

        # Obtain the batch.
        images, captions = next(iter(data_loader))

        # Move batch of images and captions to GPU if CUDA is available.
        images = images.to(device)
        captions = captions.to(device)

        # Pass the inputs through the CNN-RNN model.
        features = encoder(images)
        outputs = decoder(features, captions)

        # Calculate the batch loss (reshape also works on non-contiguous tensors).
        loss = criterion(outputs.reshape(-1, vocab_size), captions.reshape(-1))
        loss = loss / accumulation_steps  # scale so the accumulated gradient is an average

        # Backward pass: gradients add up in .grad until optimizer.zero_grad() is called.
        loss.backward()

        if i_step % accumulation_steps == 0 or i_step == total_step:
            # Update the weights, then clear the accumulated gradients.
            optimizer.step()
            optimizer.zero_grad()

        # Undo the scaling so the log shows the true per-batch loss.
        batch_loss = loss.item() * accumulation_steps
        stats = 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Perplexity: %5.4f' % (
            epoch, num_epochs, i_step, total_step, batch_loss, np.exp(batch_loss))

        # Print training statistics (on same line).
        print('\r' + stats, end="")
        sys.stdout.flush()

        # Print training statistics to file.
        f.write(stats + '\n')
        f.flush()

        # Print training statistics (on different line).
        if i_step % print_every == 0:
            print('\r' + stats)

    # Save the weights.
    if epoch % save_every == 0:
        torch.save(decoder.state_dict(), os.path.join('./models', 'decoder-%d.pkl' % epoch))
        torch.save(encoder.state_dict(), os.path.join('./models', 'encoder-%d.pkl' % epoch))

# Close the training log file.
f.close()
```

```text
Epoch [1/3], Step [100/313], Loss: 4.2907, Perplexity: 73.02008
Epoch [1/3], Step [200/313], Loss: 3.9915, Perplexity: 54.13332
Epoch [1/3], Step [300/313], Loss: 3.9909, Perplexity: 54.10365
Epoch [2/3], Step [100/313], Loss: 3.6277, Perplexity: 37.62434
Epoch [2/3], Step [200/313], Loss: 3.4773, Perplexity: 32.37148
Epoch [2/3], Step [300/313], Loss: 3.3537, Perplexity: 28.6071
Epoch [3/3], Step [100/313], Loss: 4.3922, Perplexity: 80.8159
Epoch [3/3], Step [200/313], Loss: 3.2772, Perplexity: 26.5027
Epoch [3/3], Step [300/313], Loss: 3.3029, Perplexity: 27.1919
Epoch [3/3], Step [313/313], Loss: 3.1445, Perplexity: 23.2089
```
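The step counts in the log follow from the settings above. With `subset_size=20000` captions and `batch_size=64`, `total_step = ceil(20000 / 64) = 313`. With `accumulation_steps = 4`, the optimizer updates the weights only on every fourth step, plus once at the last step of each epoch, so each update combines gradients from up to 4 × 64 = 256 captions. A small sketch of that bookkeeping, mirroring the `if` condition in the loop (`optimizer_update_steps` is a hypothetical helper):

```python
import math


def optimizer_update_steps(num_captions, batch_size, accumulation_steps):
    """Return (total_step, list of i_step values at which optimizer.step() runs)."""
    total_step = math.ceil(num_captions / batch_size)
    updates = [i for i in range(1, total_step + 1)
               if i % accumulation_steps == 0 or i == total_step]
    return total_step, updates


total_step, updates = optimizer_update_steps(20000, 64, 4)
print(total_step, len(updates), updates[-2:])  # 313 79 [312, 313]
```

One side effect worth knowing: the update at step 313 uses a single batch whose loss was still divided by 4, so the last update of each epoch is effectively smaller than the others.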