In [7]:
import os
import numpy as np
import h5py
import json
import torch
from scipy.misc import imread, imresize
from tqdm import tqdm
from collections import Counter
from random import seed, choice, sample
import time
import numpy as np
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu
from torch.utils.tensorboard import SummaryWriter
import numpy as np

In [3]:
# from utils import create_input_files
# #Create input file hdf5
# create_input_files(dataset='coco',
#                    karpathy_json_path='../caption_datasets/dataset_coco.json',
#                    image_folder='../images_data/',
#                    captions_per_image=5,
#                    min_word_freq=5,
#                    output_folder='../output_train/',
#                    max_len=50)

In [4]:
# ---- Logging -------------------------------------------------------------
# train() and validate() send their TensorBoard scalars through this writer.
writer = SummaryWriter(log_dir="tensorboard/")

# ---- Data ----------------------------------------------------------------
data_folder = '../output_train/'                   # folder with the files written by create_input_files
data_name = 'coco_5_cap_per_img_5_min_word_freq'   # base name shared by those data files
# Checkpoint filename template: positional epoch, then keyword fields bleu / loss / acc.
ckpt_name = 'coco_{}_epochs__{bleu:.4f}_bleu__{loss:.4f}_loss__{acc:.4f}_accu'

# ---- Model ---------------------------------------------------------------
emb_dim = 512         # word-embedding size (the training cell overrides it to 100 for a fresh model)
attention_dim = 512   # size of the attention network's linear layers
decoder_dim = 512     # hidden size of the decoder RNN
dropout = 0.5
# Models and tensors are moved to this device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN benchmark and cache the fastest kernels. This only pays off when model input
# sizes are fixed; with varying sizes it adds overhead.
cudnn.benchmark = True

# ---- Training ------------------------------------------------------------
start_epoch = 0                 # first epoch to run (replaced when resuming from a checkpoint)
epochs = 120                    # upper bound on epochs if early stopping never triggers
epochs_since_improvement = 0    # epochs since validation BLEU-4 last improved
batch_size = 32
workers = 1                     # DataLoader workers; kept at 1 (noted as the only value that works with h5py)
encoder_lr = 1e-4               # encoder learning rate (used only when fine-tuning the encoder)
decoder_lr = 4e-4               # decoder learning rate
grad_clip = 5.                  # gradient-clipping threshold passed to clip_gradient (None disables clipping)
alpha_c = 1.                    # weight of the 'doubly stochastic attention' regularizer, as in the paper
best_bleu4 = 0.                 # best validation BLEU-4 seen so far
print_freq = 100                # print training/validation stats every this many batches
fine_tune_encoder = False       # also train the encoder?
checkpoint = None               # path to a checkpoint to resume from, or None to start fresh
# checkpoint = "checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar"

In [5]:
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """
    Performs one epoch's training.

    For every batch the loop does the following:
    - encodes the images and decodes the captions;
    - computes ``criterion`` on the non-padded timesteps, plus the doubly stochastic attention
      penalty weighted by the global ``alpha_c``;
    - back-propagates, clips gradients when the global ``grad_clip`` is not None, and steps the
      optimizer(s).

    Per-batch loss and top-5 accuracy are logged to TensorBoard through the global ``writer`` and
    printed every ``print_freq`` batches.

    Relies on notebook globals ``device``, ``alpha_c``, ``grad_clip``, ``print_freq`` and ``writer``.

    :param train_loader: DataLoader for training data, yielding ``(imgs, caps, caplens)``
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :param encoder_optimizer: optimizer to update encoder's weights, or None when not fine-tuning
    :param decoder_optimizer: optimizer to update decoder's weights
    :param epoch: epoch number; printed, and used to derive a per-batch TensorBoard step
    """

    decoder.train()  # train mode (dropout and batchnorm is used)
    encoder.train()

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss (per word decoded)
    top5accs = AverageMeter()  # top5 accuracy

    # Batches per epoch; used to give every batch its own TensorBoard step.
    num_batches = len(train_loader)

    start = time.time()

    # Batches
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to GPU, if available
        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)

        # Forward prop.
        imgs = encoder(imgs)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads;
        # packing keeps only the valid positions of each sequence.
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, grad_clip)

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        # Keep track of metrics, weighted by the number of words decoded in this batch
        num_words = sum(decode_lengths)
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), num_words)
        top5accs.update(top5, num_words)
        batch_time.update(time.time() - start)

        start = time.time()

        # Write to TensorBoard with a unique step per batch. Using `epoch` as the step
        # would stack every batch of an epoch onto the same x value.
        global_step = epoch * num_batches + i
        writer.add_scalar('training_loss', losses.val, global_step)
        writer.add_scalar('training_accuracy', top5accs.val, global_step)

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, num_batches,
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))

In [6]:
def validate(val_loader, encoder, decoder, criterion, epoch=None):
    """
    Performs one epoch's validation.

    Runs the model under ``torch.no_grad()`` and does the following:
    - accumulates per-word loss (``criterion`` plus the doubly stochastic attention penalty
      weighted by the global ``alpha_c``) and top-5 accuracy;
    - builds greedy (argmax) hypotheses and computes corpus BLEU-4 against all reference
      captions of each image;
    - logs the epoch-level averages to TensorBoard through the global ``writer``.

    Relies on notebook globals ``device``, ``alpha_c``, ``print_freq``, ``word_map`` and ``writer``.

    :param val_loader: DataLoader for validation data, yielding ``(imgs, caps, caplens, allcaps)``
    :param encoder: encoder model, or None to pass ``imgs`` to the decoder unchanged
    :param decoder: decoder model
    :param criterion: loss layer
    :param epoch: TensorBoard global step for the logged scalars. When omitted, falls back to
        the notebook-level ``epoch`` variable set by the training loop (None if it is undefined).
    :return: tuple ``(bleu4, average loss, average top-5 accuracy)``
    """
    if epoch is None:
        # Backward compatibility: the training loop calls validate() without `epoch`, relying
        # on the loop variable living in the notebook namespace.
        epoch = globals().get('epoch')

    decoder.eval()  # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    # Token ids stripped from reference captions; built once rather than once per caption.
    ignored_tokens = {word_map['<start>'], word_map['<pad>']}

    start = time.time()

    # explicitly disable gradient calculation to avoid CUDA memory error
    with torch.no_grad():
        # Batches
        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):

            # Move to device, if available
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # Greedy predictions come from the padded scores, so take them before packing
            # (this replaces keeping a full clone of `scores` around).
            _, preds_padded = torch.max(scores, dim=2)

            # Remove timesteps that we didn't decode at, or are pads
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            # Calculate loss
            loss = criterion(scores, targets)

            # Add doubly stochastic attention regularization
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # Keep track of metrics, weighted by the number of words decoded in this batch
            num_words = sum(decode_lengths)
            losses.update(loss.item(), num_words)
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, num_words)
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))

            # Store references (true captions) and the hypothesis (prediction) for each image:
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # References: reorder to match the decoder's sort. The index may live on the GPU
            # while `allcaps` comes from the DataLoader, so move the index to allcaps' device.
            if torch.is_tensor(sort_ind):
                sort_ind = sort_ind.to(allcaps.device)
            allcaps = allcaps[sort_ind]
            for img_caps in allcaps.tolist():
                # remove <start> and pads
                references.append([[w for w in c if w not in ignored_tokens] for c in img_caps])

            # Hypotheses: truncate each prediction to its decode length (removes pads)
            preds = preds_padded.tolist()
            hypotheses.extend(p[:length] for p, length in zip(preds, decode_lengths))

            assert len(references) == len(hypotheses)

    # Calculate BLEU-4 scores
    bleu4 = corpus_bleu(references, hypotheses)

    print(
        '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
            loss=losses,
            top5=top5accs,
            bleu=bleu4))

    # Write epoch-level averages to TensorBoard; `.val` would be only the final batch's value.
    writer.add_scalar('validation_loss', losses.avg, epoch)
    writer.add_scalar('validation_accuracy', top5accs.avg, epoch)
    writer.add_scalar('validation_bleu4', bleu4, epoch)

    return bleu4, losses.avg, top5accs.avg

# Training

In [None]:
"""
Training and validation.
"""

global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

# Read word map
word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
with open(word_map_file, 'r') as j:
    word_map = json.load(j)

# Initialize / load checkpoint
if checkpoint is None:

    # NOTE(review): overrides the config cell's emb_dim (512). Presumably the matrix returned by
    # create_pretrained_embedding_matrix is 100-dimensional; confirm. Remove this override if
    # not using pretrained embeddings.
    emb_dim = 100
    decoder = DecoderWithAttention(attention_dim=attention_dim,
                                   embed_dim=emb_dim,
                                   decoder_dim=decoder_dim,
                                   vocab_size=len(word_map),
                                   dropout=dropout)
    pretrained_embeddings = decoder.create_pretrained_embedding_matrix(word_map)
    # pretrained_embeddings should be of dimensions (len(word_map), emb_dim)
    decoder.load_pretrained_embeddings(pretrained_embeddings)
    decoder.fine_tune_embeddings(True)

    # Only parameters that require gradients are handed to the optimizers.
    decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                         lr=decoder_lr)
    encoder = Encoder()
    encoder.fine_tune(fine_tune_encoder)
    encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                         lr=encoder_lr) if fine_tune_encoder else None

else:
    # Load into a separate name so `checkpoint` stays a path and this cell can be re-run.
    # SECURITY NOTE: the checkpoint stores whole pickled modules/optimizers, and torch.load
    # unpickles arbitrary objects; only load checkpoint files from trusted sources.
    checkpoint_state = torch.load(checkpoint)
    start_epoch = checkpoint_state['epoch'] + 1
    epochs_since_improvement = checkpoint_state['epochs_since_improvement']
    best_bleu4 = checkpoint_state['bleu-4']
    decoder = checkpoint_state['decoder']
    decoder_optimizer = checkpoint_state['decoder_optimizer']
    encoder = checkpoint_state['encoder']
    encoder_optimizer = checkpoint_state['encoder_optimizer']
    # Start fine-tuning the encoder now if requested but the checkpoint had no encoder optimizer.
    if fine_tune_encoder is True and encoder_optimizer is None:
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr)

# Move to GPU, if available
decoder = decoder.to(device)
encoder = encoder.to(device)

# Loss function
criterion = nn.CrossEntropyLoss().to(device)

# Custom dataloaders. Images are normalized with these per-channel mean/std values
# (the standard torchvision ImageNet statistics).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
    batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
# Validation order does not change corpus BLEU, so keep it fixed (shuffle=False) for
# reproducible validation runs and logs.
val_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
    batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)

# Epochs. The TensorBoard writer is closed in `finally` so its buffered events are flushed
# even if training raises or is interrupted.
try:
    for epoch in range(start_epoch, epochs):

        # Terminate after 20 epochs without a BLEU-4 improvement, and decay the learning
        # rate(s) via adjust_learning_rate(..., 0.8) every 8 such epochs.
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4, val_loss_avg, val_accu_avg = validate(val_loader=val_loader,
                                                            encoder=encoder,
                                                            decoder=decoder,
                                                            criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint (filename formatted once and reused for the message and the save)
        ckpt_filename = ckpt_name.format(epoch, bleu=recent_bleu4, loss=val_loss_avg, acc=val_accu_avg)
        print("Saving model to file", ckpt_filename)
        save_checkpoint(ckpt_filename,
                        epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best)
finally:
    # close tensorboard writer
    writer.close()

Downloading: "https://download.pytorch.org/models/resnet152-b121ed2d.pth" to /tmp/xdg-cache/torch/checkpoints/resnet152-b121ed2d.pth
100%|██████████| 230M/230M [00:09<00:00, 24.4MB/s] 


Epoch: [0][0/17702]	Batch Time 1.836 (1.836)	Data Load Time 0.362 (0.362)	Loss 10.0520 (10.0520)	Top-5 Accuracy 0.000 (0.000)
Epoch: [0][100/17702]	Batch Time 0.244 (0.266)	Data Load Time 0.000 (0.004)	Loss 6.0038 (6.6746)	Top-5 Accuracy 38.043 (34.473)
Epoch: [0][200/17702]	Batch Time 0.256 (0.260)	Data Load Time 0.000 (0.002)	Loss 5.8753 (6.2542)	Top-5 Accuracy 39.733 (37.475)
Epoch: [0][300/17702]	Batch Time 0.247 (0.260)	Data Load Time 0.001 (0.004)	Loss 5.5189 (6.0203)	Top-5 Accuracy 44.011 (39.801)
Epoch: [0][400/17702]	Batch Time 0.257 (0.259)	Data Load Time 0.000 (0.003)	Loss 5.1371 (5.8373)	Top-5 Accuracy 48.209 (41.928)
Epoch: [0][500/17702]	Batch Time 0.246 (0.257)	Data Load Time 0.000 (0.003)	Loss 5.1928 (5.7003)	Top-5 Accuracy 49.721 (43.551)
Epoch: [0][600/17702]	Batch Time 0.245 (0.256)	Data Load Time 0.000 (0.002)	Loss 4.9300 (5.5874)	Top-5 Accuracy 52.394 (44.978)
Epoch: [0][700/17702]	Batch Time 0.251 (0.256)	Data Load Time 0.000 (0.002)	Loss 4.9594 (5.4900)	Top-5 Acc

Epoch: [0][6400/17702]	Batch Time 0.241 (0.252)	Data Load Time 0.000 (0.001)	Loss 3.6125 (4.3052)	Top-5 Accuracy 67.429 (61.501)
Epoch: [0][6500/17702]	Batch Time 0.264 (0.252)	Data Load Time 0.000 (0.001)	Loss 3.9032 (4.2983)	Top-5 Accuracy 67.488 (61.589)
Epoch: [0][6600/17702]	Batch Time 0.244 (0.252)	Data Load Time 0.001 (0.001)	Loss 3.7226 (4.2921)	Top-5 Accuracy 68.144 (61.668)
Epoch: [0][6700/17702]	Batch Time 0.257 (0.252)	Data Load Time 0.000 (0.001)	Loss 3.3581 (4.2856)	Top-5 Accuracy 72.849 (61.750)
Epoch: [0][6800/17702]	Batch Time 0.283 (0.252)	Data Load Time 0.001 (0.001)	Loss 4.1149 (4.2794)	Top-5 Accuracy 64.491 (61.832)
Epoch: [0][6900/17702]	Batch Time 0.252 (0.252)	Data Load Time 0.000 (0.001)	Loss 3.7311 (4.2731)	Top-5 Accuracy 67.213 (61.912)
Epoch: [0][7000/17702]	Batch Time 0.266 (0.252)	Data Load Time 0.001 (0.001)	Loss 3.8839 (4.2667)	Top-5 Accuracy 66.667 (61.996)
Epoch: [0][7100/17702]	Batch Time 0.249 (0.252)	Data Load Time 0.000 (0.001)	Loss 3.5835 (4.2609)

Epoch: [0][12800/17702]	Batch Time 0.254 (0.253)	Data Load Time 0.000 (0.001)	Loss 4.2495 (4.0377)	Top-5 Accuracy 63.788 (64.970)
Epoch: [0][12900/17702]	Batch Time 0.252 (0.253)	Data Load Time 0.000 (0.001)	Loss 3.6302 (4.0349)	Top-5 Accuracy 69.272 (65.006)
Epoch: [0][13000/17702]	Batch Time 0.251 (0.253)	Data Load Time 0.000 (0.001)	Loss 3.8809 (4.0322)	Top-5 Accuracy 65.782 (65.040)
Epoch: [0][13100/17702]	Batch Time 0.255 (0.253)	Data Load Time 0.000 (0.001)	Loss 3.9177 (4.0295)	Top-5 Accuracy 68.533 (65.076)
Epoch: [0][13200/17702]	Batch Time 0.251 (0.253)	Data Load Time 0.001 (0.001)	Loss 3.6724 (4.0269)	Top-5 Accuracy 70.492 (65.111)
Epoch: [0][13300/17702]	Batch Time 0.258 (0.253)	Data Load Time 0.000 (0.001)	Loss 3.5583 (4.0243)	Top-5 Accuracy 68.966 (65.145)
Epoch: [0][13400/17702]	Batch Time 0.255 (0.253)	Data Load Time 0.001 (0.001)	Loss 3.5800 (4.0214)	Top-5 Accuracy 69.072 (65.182)
Epoch: [0][13500/17702]	Batch Time 0.249 (0.253)	Data Load Time 0.000 (0.001)	Loss 3.5149 

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: [1][0/17702]	Batch Time 0.828 (0.828)	Data Load Time 0.420 (0.420)	Loss 3.5461 (3.5461)	Top-5 Accuracy 71.504 (71.504)
Epoch: [1][100/17702]	Batch Time 0.244 (0.263)	Data Load Time 0.000 (0.005)	Loss 3.5828 (3.5853)	Top-5 Accuracy 69.859 (70.608)
Epoch: [1][200/17702]	Batch Time 0.252 (0.258)	Data Load Time 0.000 (0.003)	Loss 3.5648 (3.5870)	Top-5 Accuracy 71.186 (70.600)
Epoch: [1][300/17702]	Batch Time 0.279 (0.257)	Data Load Time 0.000 (0.002)	Loss 3.7415 (3.5842)	Top-5 Accuracy 67.847 (70.676)
Epoch: [1][400/17702]	Batch Time 0.262 (0.257)	Data Load Time 0.000 (0.001)	Loss 3.5561 (3.5847)	Top-5 Accuracy 70.213 (70.674)
Epoch: [1][500/17702]	Batch Time 0.249 (0.257)	Data Load Time 0.000 (0.001)	Loss 3.9873 (3.5736)	Top-5 Accuracy 67.655 (70.785)
Epoch: [1][600/17702]	Batch Time 0.247 (0.256)	Data Load Time 0.000 (0.001)	Loss 3.3077 (3.5707)	Top-5 Accuracy 71.429 (70.781)
Epoch: [1][700/17702]	Batch Time 0.245 (0.256)	Data Load Time 0.000 (0.001)	Loss 3.7847 (3.5756)	Top-5 Acc

Epoch: [1][7000/17702]	Batch Time 0.245 (0.255)	Data Load Time 0.001 (0.001)	Loss 3.5590 (3.5433)	Top-5 Accuracy 69.399 (71.302)
Epoch: [1][7100/17702]	Batch Time 0.260 (0.255)	Data Load Time 0.000 (0.001)	Loss 3.6193 (3.5434)	Top-5 Accuracy 71.540 (71.300)
Epoch: [1][7200/17702]	Batch Time 0.248 (0.255)	Data Load Time 0.000 (0.001)	Loss 3.7979 (3.5428)	Top-5 Accuracy 69.741 (71.306)
Epoch: [1][7300/17702]	Batch Time 0.278 (0.255)	Data Load Time 0.001 (0.001)	Loss 3.8981 (3.5424)	Top-5 Accuracy 66.491 (71.313)
Epoch: [1][7400/17702]	Batch Time 0.247 (0.255)	Data Load Time 0.000 (0.001)	Loss 3.4411 (3.5425)	Top-5 Accuracy 74.521 (71.313)
Epoch: [1][7500/17702]	Batch Time 0.261 (0.255)	Data Load Time 0.000 (0.001)	Loss 3.9373 (3.5423)	Top-5 Accuracy 61.917 (71.319)
Epoch: [1][7600/17702]	Batch Time 0.254 (0.255)	Data Load Time 0.000 (0.001)	Loss 3.4290 (3.5421)	Top-5 Accuracy 74.242 (71.323)
Epoch: [1][7700/17702]	Batch Time 0.258 (0.255)	Data Load Time 0.000 (0.001)	Loss 3.6273 (3.5418)