# Image captioning

## Import libraries

In [1]:
# faster version of json
# !pip install ujson

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from models import Encoder, Decoder

from pathlib import Path
from get_loader import MSCOCODataset, get_loader

from data_prep_utils import *
from utils import train, save_model, load_model, plot_loss
import json
import numpy as np
import time

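###### download the COCO 2017 images we need
# (use the %cd magic: `!cd` runs in a throwaway subshell, so the wget/unzip lines below would not run in that directory)
# %cd ~/INM706-image-captioning/Datasets/coco/images/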
# !wget http://images.cocodataset.org/zips/train2017.zip
# !wget http://images.cocodataset.org/zips/val2017.zip
# !unzip train2017.zip
# !unzip val2017.zip
# !rm train2017.zip
# !rm val2017.zip

##### run the code below if NLTK hasn't been set up on the cloud instance yet
# !python -m nltk.downloader -d /usr/local/share/nltk_data all

###### run code below to save pre-trained weights if needed
# cd ~/INM706-image-captioning/model
# !wget https://download.pytorch.org/models/resnet152-394f9c45.pth
# !mv resnet152-394f9c45.pth resnet152_model.pth

### Choose hyperparameters

In [2]:
# ---- vocabulary ----
FREQ_THRESHOLD = 5                         # word-frequency cut-off used when building the vocabulary
CAPTIONS_FILE = 'captions_train2017.json'  # COCO caption file the vocabulary is built from

# ---- data loaders ----
BATCH_SIZE = 32
CAPS_PER_IMAGE = 5  # how many captions per training image are included in the dataset

# ---- encoder / decoder ----
EMBED_SIZE = 512    # dimension of the word-embedding vectors
HIDDEN_SIZE = 512
NUM_LAYERS = 1      # number of LSTM layers in the decoder

# ---- Adam optimiser ----
OPT_PARAMS = dict(lr=1e-3, weight_decay=1e-3, betas=(0.9, 0.999), eps=1e-08)

# ---- training loop ----
TOTAL_EPOCH = 15
CHECKPOINT = '../model/image_captioning_model_v10.pth'
# Running results are printed every PRINT_EVERY batches; with a bigger batch
# size, lower this number if you want to see output more regularly.
PRINT_EVERY = 300

## Load dataset and dataloader

In [None]:
root = Path('../Datasets/coco')
imgs_path = root/'images'/'train2017'
imgs_path_test = root/'images'/'val2017'

train_captions_path = root/'annotations'/'custom_captions_train.json'
val_captions_path = root/'annotations'/'custom_captions_val.json'
test_captions_path = root/'annotations'/'custom_captions_test.json'

# Build the custom train/val/test caption splits. This is slow, and if
# prepare_datasets samples its subsets randomly (NOTE(review): confirm), re-running
# it on every kernel restart would silently change the splits between sessions,
# e.g. when resuming from a checkpoint. So only rebuild when a split file is
# missing, or when REBUILD_SPLITS is set to True (do that after changing any of
# the arguments below).
# NOTE(review): assumes prepare_datasets writes the three custom_captions_*.json
# files listed above; confirm.
REBUILD_SPLITS = False
split_caption_files = (train_captions_path, val_captions_path, test_captions_path)
if REBUILD_SPLITS or not all(path.is_file() for path in split_caption_files):
    prepare_datasets(train_percent=0.87, super_categories=None,
                     max_train=45000, max_val=15000, max_test=5000)

#### build vocab using full original coco train
# build_vocab(freq_threshold=FREQ_THRESHOLD,
#             captions_file=CAPTIONS_FILE)

# Load the prebuilt word -> index vocabulary. JSON text is UTF-8, so say so
# explicitly instead of relying on the platform's default encoding.
with open('../vocabulary/string_to_index.json', encoding='utf-8') as json_file:
    word2idx = json.load(json_file)

In [None]:
# Optional CUDA performance boost, uncomment to enable:
# torch.backends.cudnn.benchmark = True

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Settings shared by the train and validation loaders. Each split overrides only
# what differs: its captions file and the number of captions per image.
# NOTE(review): the validation split also uses mode "train" (presumably so the
# loader yields captions for computing the validation loss); confirm.
shared_loader_params = dict(
    images_path=imgs_path,
    freq_threshold=FREQ_THRESHOLD,
    mode="train",
    transform=None,
    batch_size=BATCH_SIZE,
    shuffle=True,
    word2idx=word2idx,
)

train_loader_params = {
    **shared_loader_params,
    'captions_path': train_captions_path,
    'caps_per_image': CAPS_PER_IMAGE,
}

val_loader_params = {
    **shared_loader_params,
    'captions_path': val_captions_path,
    'caps_per_image': 1,
}

# get_loader returns (DataLoader, Dataset) for each split.
train_dl, train_dataset = get_loader(**train_loader_params)
val_dl, val_dataset = get_loader(**val_loader_params)

vocab_size = len(train_dataset.word2idx)

In [None]:
# Record the run's hyperparameters so they can be stored with the checkpoint.
# NOTE(review): `hyper_params` is not included in `train_params` in the training
# cell; confirm that train()/save_model actually persist it.


def _count_distinct_images(dataset):
    """Return how many distinct first fields (presumably image ids) ``dataset.img_deque`` holds."""
    return len({sample[0] for sample in dataset.img_deque})


hyper_params = {
    # vocabulary
    'vocab': train_dataset.idx2word,
    'vocab_size': len(train_dataset.idx2word),
    'vocab_captions_file': CAPTIONS_FILE,
    'freq_threshold': FREQ_THRESHOLD,
    # data: distinct images vs. total number of samples per split
    'train_imgs_size': _count_distinct_images(train_dataset),
    'train_sample_size': len(train_dataset.img_deque),
    'val_imgs_size': _count_distinct_images(val_dataset),
    'val_sample_size': len(val_dataset.img_deque),
    # model
    'embed_size': EMBED_SIZE,
    'hidden_size': HIDDEN_SIZE,
    'num_layers': NUM_LAYERS,
    # optimisation
    'optimizer_params': OPT_PARAMS,
    'batch_size': BATCH_SIZE,
}

### Encoder and decoder

In [None]:
# Image encoder. `model_weight_path` points at the ResNet-152 weights fetched in
# the setup cell. NOTE(review): confirm how Encoder combines `pretrained=False`
# with `model_weight_path`.
encoder = Encoder(
    embed_size=EMBED_SIZE,
    pretrained=False,
    model_weight_path="../model/resnet152_model.pth",
)

# Caption decoder (NUM_LAYERS LSTM layers, output sized by the vocabulary).
decoder = Decoder(
    embed_size=EMBED_SIZE,
    hidden_size=HIDDEN_SIZE,
    vocab_size=vocab_size,
    num_layers=NUM_LAYERS,
)

### Training parameters

In [None]:
# Cross-entropy loss. Targets equal to the <PAD> index are skipped via
# `ignore_index`, so padding tokens never contribute to the loss.
# NOTE(review): this stays necessary unless the decoder's targets already
# exclude padding (e.g. via packed sequences); confirm before removing.
pad_idx = train_dataset.word2idx["<PAD>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

# Only the decoder and the encoder's `embed` sub-module are handed to the
# optimiser; the rest of the encoder's parameters are not updated by it.
params = [*decoder.parameters(), *encoder.embed.parameters()]

# Adam, configured from OPT_PARAMS in the config cell.
optimizer = optim.Adam(params, **OPT_PARAMS)

#### Load checkpoints if they exist

In [None]:
# Resume from CHECKPOINT when it exists; otherwise start from scratch with no
# loss history (train() receives None for both histories).
if Path(CHECKPOINT).exists():
    # load_model returns the restored encoder/decoder, both loss histories and
    # the hyperparameters from the checkpoint; the latter replace the
    # `hyper_params` dict built earlier.
    # NOTE(review): the optimizer is not passed to load_model, so its state is
    # not restored here; confirm whether resuming should also reload it.
    encoder, decoder, training_loss, validation_loss, hyper_params = load_model(encoder, decoder, CHECKPOINT)
else:
    print(f'{CHECKPOINT} file does not exist, training starting from scratch')
    training_loss = None
    validation_loss = None

## Train the model

In [None]:
# Everything train() needs, gathered in one place. On a fresh run the two loss
# histories are None; after a resume they hold whatever load_model returned.
train_params = dict(
    device=device,
    encoder=encoder,
    decoder=decoder,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_dl,
    val_loader=val_dl,
    total_epoch=TOTAL_EPOCH,
    training_loss=training_loss,
    validation_loss=validation_loss,
    checkpoint_path=CHECKPOINT,
    print_every=PRINT_EVERY,
)

# train() hands back the training and validation loss histories.
training_loss, validation_loss = train(**train_params)

In [None]:
# Grab a single batch from the training loader for a quick sanity check.
train_batches = iter(train_dl)
batch = next(train_batches)

In [None]:
# Print the Python type of each element of the batch, one per line.
for element_type in map(type, batch):
    print(element_type)

In [None]:
# Scratch check of how caption token indices become a tensor. Indices fed to an
# embedding layer or used as CrossEntropyLoss class targets must be an integer
# tensor. `torch.Tensor(capts)` produces float32, so request int64 explicitly.
capts = [1, 2, 3, 4, 5]
capts_tensor = torch.tensor(capts, dtype=torch.long)
capts_tensor

In [None]:
# `torch.cat?` is IPython-only help syntax and is a SyntaxError in plain Python
# (e.g. when the notebook is exported to a script). help() shows the same
# docstring and works in both settings.
help(torch.cat)