## Install Packages

In [1]:
%%capture

!pip install pycocotools
!pip install nltk

## Imports

In [2]:
# Python Modules
# --------------------------------------------------
import sys
import math
import os
import requests
import time
import nltk
import numpy as np


# Torch Modules
# --------------------------------------------------
import torch
import torch.nn         as nn
import torch.utils.data as data

from torchvision import transforms


# Third Party Modules
# --------------------------------------------------
from pycocotools.coco import COCO


# Custom Modules
# --------------------------------------------------
% load_ext autoreload
% autoreload 2

import config
import utils

from Model import (
    EncoderCNN, 
    DecoderRNN
)


# Settings
# --------------------------------------------------
# Make the COCO API location from the project config importable.
# NOTE(review): this runs after `from pycocotools.coco import COCO` above, so
# it can only affect imports performed later (e.g. lazily inside helpers) —
# confirm whether it is still needed or should precede that import.
sys.path.append(config.COCO_API_PATH)

# Fetch the NLTK "punkt" tokenizer package; presumably used by the project's
# caption tokenization — verify against utils. The call's return value is the
# cell's displayed output.
nltk.download("punkt")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

## Initialize Image Transformer

In [3]:
# Training-time image pipeline; the steps run in the order listed.
# NOTE(review): the resize target is read from config.EMBED_SIZE, whose name
# suggests an embedding dimension rather than an image size — confirm that
# this is the intended constant.
transform_train = transforms.Compose([
    # Resize, then take a random crop of the network's input size.
    transforms.Resize(size = config.EMBED_SIZE),
    transforms.RandomCrop(size = config.INPUT_SIZE),
    # Random left/right mirroring for augmentation.
    transforms.RandomHorizontalFlip(),
    # Convert to a tensor and normalize with the configured mean / std.
    transforms.ToTensor(),
    transforms.Normalize(mean = config.IMAGENET_MU, std = config.IMAGENET_SIGMA)
])

## Initialize Data Loader

In [4]:
# Build the training data loader through the project helper. All settings are
# passed by keyword: the split and annotation source first, then the batching
# and preprocessing options.
data_loader = utils.get_loader(
    mode            = "train",
    ann_file        = config.ANN_CAPTIONS_TRAIN_FILE,
    vocab_from_file = config.LOAD_VOCAB_FILE,
    batch_size      = config.TRAIN_BATCH_SIZE,
    transform       = transform_train
)

loading annotations into memory...
Done (t=0.89s)
creating index...


  0%|          | 383/414113 [00:00<01:48, 3826.87it/s]

index created!


100%|██████████| 414113/414113 [01:27<00:00, 4718.10it/s]


## Initialize Encoder and Decoder

In [5]:
# Get Vocabulary Size
# --------------------------------------------------
# The decoder's output layer is sized from the vocabulary attached to the
# training dataset.
vocab_size = len(data_loader.dataset.vocab)


# Initialize Models
# --------------------------------------------------
encoder = EncoderCNN()
decoder = DecoderRNN(vocab_size)


# Move Models to Device
# --------------------------------------------------
# Single source of truth for device placement; everything below (and the
# training loop's batches) should follow `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder.to(device)
decoder.to(device)


# Initialize Loss Function
# --------------------------------------------------
# Placed via `device` instead of repeating the CUDA availability check, so the
# criterion can never end up on a different device than the models.
criterion = nn.CrossEntropyLoss().to(device)


# Collect Trainable Parameters
# --------------------------------------------------
# Only the decoder plus the encoder's `embed` and `norm` sub-modules are handed
# to the optimizer; any other encoder parameters are not updated by it.
params = [
    *decoder.parameters(),
    *encoder.embed.parameters(),
    *encoder.norm.parameters()
]


# Initialize Optimizer
# --------------------------------------------------
# Adam with its library-default hyperparameters.
optimizer = torch.optim.Adam(params)


# Define Training Steps
# --------------------------------------------------
# Steps per epoch: number of captions divided by the batch size, rounded up.
total_step = math.ceil(len(data_loader.dataset.caption_lengths) / data_loader.batch_sampler.batch_size)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.torch/models/resnet50-19c8e357.pth
100%|██████████| 102502400/102502400 [00:03<00:00, 27570857.81it/s]


## Train Model

In [2]:
# Prepare Checkpoint Directory
# --------------------------------------------------
# torch.save does not create missing directories; make sure the target exists
# up front so an epoch of training is not lost to a failed save.
os.makedirs(config.MODEL_DIR, exist_ok = True)


# Initialize Time and Keep-Alive Token
# --------------------------------------------------
# `start_time` tracks when the last keep-alive ping was sent. The keep-alive
# mechanism is auxiliary, so it is best-effort: if the token cannot be fetched
# (e.g. when running outside the hosted workspace) training proceeds without it.
start_time = time.time()
response   = None

try:
    response = requests.request(
        "GET",
        "http://metadata.google.internal/computeMetadata/v1/instance/attributes/keep_alive_token",
        headers = {
            "Metadata-Flavor" : "Google"
        },
        timeout = 10
    )
except requests.RequestException as error:
    print(f"Workspace keep-alive disabled (token request failed): {error}")


# Train Model
# --------------------------------------------------
# The log file is managed by `with` so it is closed even when training raises
# or is interrupted from the notebook.
with open(config.TRAIN_LOG_FILE, "w") as f:
    for epoch in range(1, config.TRAIN_EPOCHS + 1):
        for i_step in range(1, total_step + 1):
            # Ping the workspace at most once every 60 seconds. A failed ping
            # is reported but must not abort training.
            if response is not None and time.time() - start_time > 60:
                start_time = time.time()

                try:
                    requests.request(
                        "POST",
                        "https://nebula.udacity.com/api/v1/remote/keep-alive",
                        headers = {
                            "Authorization": "STAR " + response.text
                        },
                        timeout = 10
                    )
                except requests.RequestException as error:
                    print(f"Keep-alive ping failed: {error}")


            # Sample Captions by Length to Generate Batch
            # --------------------------------------------------
            # The dataset picks the indices for this step; a fresh sampler over
            # them is installed and a single batch is drawn from a new iterator.
            indices     = data_loader.dataset.get_train_indices()
            new_sampler = data.sampler.SubsetRandomSampler(indices = indices)

            data_loader.batch_sampler.sampler = new_sampler
            images, captions                  = next(iter(data_loader))


            # Move Batch to Training Device
            # --------------------------------------------------
            images   = images.to(device)
            captions = captions.to(device)


            # Zero Gradients
            # --------------------------------------------------
            decoder.zero_grad()
            encoder.zero_grad()


            # Send Inputs to Encoder and Decoder
            # --------------------------------------------------
            features = encoder(images)
            outputs  = decoder(features, captions)


            # Calculate Loss
            # --------------------------------------------------
            # Flatten predictions to rows of `vocab_size` scores and the
            # captions to a 1-D vector of target ids.
            loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))


            # Update Weights
            # --------------------------------------------------
            loss.backward()
            optimizer.step()


            # Get Training Statistics
            # --------------------------------------------------
            # Perplexity is reported as exp(loss).
            stats = (
                f"--------------------------------------------------\n"
                f"Epoch:      [{epoch}/{config.TRAIN_EPOCHS}]\n"
                f"Step:       [{i_step}/{total_step}]\n"
                f"Loss:       {round(loss.item(), 4)}\n"
                f"Perplexity: {round(np.exp(loss.item()), 5)}\n"
            )


            # Print Stats to Console
            # --------------------------------------------------
            # NOTE(review): `stats` spans several lines, so the leading "\r"
            # cannot overwrite the previous block and every step adds output;
            # steps that hit TRAIN_PRINT_FREQ are printed twice. Confirm whether
            # a single-line in-place progress display was intended.
            print("\r" + stats, end = "")
            sys.stdout.flush()

            if i_step % config.TRAIN_PRINT_FREQ == 0:
                print("\r" + stats)


            # Save Stats to Logs
            # --------------------------------------------------
            # Flushed every step so the log stays current during long runs.
            f.write(stats + "\n")
            f.flush()


        # Save Epoch Weights
        # --------------------------------------------------
        if epoch % config.TRAIN_SAVE_FREQ == 0:
            torch.save(decoder.state_dict(), f"{config.MODEL_DIR}/decoder-{epoch}.pkl")
            torch.save(encoder.state_dict(), f"{config.MODEL_DIR}/encoder-{epoch}.pkl")