In [1]:
# Start from a clean output folder so create_input_files below regenerates every data file.
# Pure-Python equivalent of `!rm -rf data/proc_data_files/*` that also works without a POSIX shell.
import shutil
from pathlib import Path


def clear_directory(folder):
    """
    Delete every non-hidden entry inside ``folder`` but keep the folder itself.

    Mirrors ``rm -rf folder/*``: a missing folder is a silent no-op, hidden
    (dot-prefixed) entries are kept because the shell ``*`` glob does not match
    them, and symlinks are removed without following them.

    :param folder: path (str or Path) of the directory to empty
    """
    root = Path(folder)
    if not root.is_dir():
        return  # `rm -rf` on a non-matching glob silently does nothing
    for entry in root.iterdir():
        if entry.name.startswith('.'):
            continue  # shell `*` skips dotfiles, so `rm -rf dir/*` leaves them in place
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)
        else:
            entry.unlink()  # regular files and symlinks (even ones pointing at directories)


clear_directory('data/proc_data_files/')

In [2]:
from utils import create_input_files

# Preprocess the raw caption CSV and the image folder into the data files that
# the training cells below read from data/proc_data_files/.
create_input_files(
    # Locations
    dataset='data/caption_data_0_100.csv',
    image_folder='data/images/',
    output_folder='data/proc_data_files/',
    # Caption / vocabulary settings
    captions_per_image=1000,
    min_word_freq=0,
    max_len=20,
    key_max_len=10,
    # Subset size
    num_images_to_train=10,
)

  from ._conv import register_converters as _register_converters


[nltk_data] Downloading package punkt to /home/as3ek/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


100%|██████████| 7/7 [00:00<00:00, 61.80it/s]


Reading TRAIN images and captions, storing to file...




100%|██████████| 1/1 [00:00<00:00, 91.68it/s]
100%|██████████| 4/4 [00:00<00:00, 84.33it/s]


Reading VAL images and captions, storing to file...


Reading TEST images and captions, storing to file...






In [3]:
# Initial Imports
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu

In [4]:
# Data parameters
data_folder = 'data/proc_data_files/'  # folder with data files saved by create_input_files
# Base name shared by the data files. It appears to encode captions_per_image=1000 and
# min_word_freq=0 from the preprocessing call, so keep it in sync if those arguments change.
data_name = 'meme_1000_cap_per_img_0_min_word_freq'

# Model parameters
emb_dim = 512  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.5  # dropout rate passed to DecoderWithAttention
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # sets device for model and PyTorch tensors
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Reproducibility: seed torch's RNGs (weight init, DataLoader shuffling, dropout) so reruns are
# comparable. cuDNN autotuning (benchmark=True above) can still make GPU runs non-bit-exact.
seed = 42
torch.manual_seed(seed)

# Training parameters
start_epoch = 0
epochs = 120  # number of epochs to train for (if early stopping is not triggered)
epochs_since_improvement = 0  # keeps track of number of epochs since there's been an improvement in validation BLEU
batch_size = 32
workers = 1  # for data-loading; right now, only 1 works with h5py
encoder_lr = 1e-4  # learning rate for encoder if fine-tuning
decoder_lr = 4e-4  # learning rate for decoder
grad_clip = 5.  # clip gradients at this absolute value (None disables clipping)
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.  # best validation BLEU-4 seen so far; updated by main()
print_freq = 100  # print training/validation stats every print_freq batches
fine_tune_encoder = False  # fine-tune encoder?
checkpoint = None  # path to a checkpoint to resume from, or None to train from scratch

In [5]:
def main():
    """
    Training and validation.

    Builds a fresh encoder/decoder pair (or restores one from the ``checkpoint``
    path), then runs ``train`` and ``validate`` once per epoch from
    ``start_epoch`` up to ``epochs``, saving a checkpoint after every epoch.
    Learning rates are shrunk by ``adjust_learning_rate(..., 0.8)`` whenever the
    number of epochs without a validation BLEU-4 improvement is a positive
    multiple of 8, and training stops once it reaches 20.

    Side effects: updates the notebook globals ``best_bleu4``,
    ``epochs_since_improvement``, ``start_epoch`` and ``word_map`` (the latter is
    read by ``validate``). The global ``checkpoint`` path is left untouched, so
    calling ``main()`` again in the same kernel still works.
    """
    # Import explicitly instead of relying on `os`/`json` leaking in through
    # `from utils import *` / `from datasets import *`.
    import json
    import os

    global best_bleu4, epochs_since_improvement, start_epoch, word_map

    # Read word map (maps tokens such as '<start>' and '<pad>' to ids)
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    # Initialize / load checkpoint
    if checkpoint is None:
        decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
        encoder = Encoder()
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None

    else:
        # Load into a local name: rebinding the global `checkpoint` to the loaded dict
        # would make a second main() call in the same kernel pass a dict to torch.load.
        # map_location lets a checkpoint saved on a GPU be resumed on a CPU-only host.
        # NOTE: the checkpoint pickles whole modules/optimizers; on PyTorch >= 2.6 (where
        # torch.load defaults to weights_only=True) pass weights_only=False for trusted files.
        ckpt = torch.load(checkpoint, map_location=device)
        start_epoch = ckpt['epoch'] + 1
        epochs_since_improvement = ckpt['epochs_since_improvement']
        best_bleu4 = ckpt['bleu-4']
        decoder = ckpt['decoder']
        decoder_optimizer = ckpt['decoder_optimizer']
        encoder = ckpt['encoder']
        encoder_optimizer = ckpt['encoder_optimizer']
        if fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

    # Move to GPU, if available
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    # Loss function
    criterion = nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
        batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            # Decay exactly the encoder optimizer that train() will step. Keying this on
            # fine_tune_encoder skipped a fine-tuning optimizer restored from a checkpoint.
            if encoder_optimizer is not None:
                adjust_learning_rate(encoder_optimizer, 0.8)

        # One epoch's training
        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                        decoder_optimizer, recent_bleu4, is_best)

In [6]:
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """
    Performs one epoch's training.

    Relies on the notebook-level globals ``device``, ``alpha_c``, ``grad_clip`` and
    ``print_freq``, and on the ``AverageMeter``, ``clip_gradient`` and ``accuracy``
    helpers (presumably provided by ``utils``).

    :param train_loader: DataLoader for training data; each batch is
        (imgs, caps, keys, caplens)
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer
    :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning), else None
    :param decoder_optimizer: optimizer to update decoder's weights
    :param epoch: epoch number (only used in the progress log)
    """

    decoder.train()  # train mode (dropout and batchnorm is used)
    encoder.train()

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss (per word decoded)
    top5accs = AverageMeter()  # top5 accuracy

    start = time.time()

    # Batches (keys are loaded but not used yet; kept for future work)
    for i, (imgs, caps, keys, caplens) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to GPU, if available
        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)

        # Forward prop.
        imgs = encoder(imgs)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads.
        # Take `.data` instead of tuple-unpacking the PackedSequence: since PyTorch 1.1 it has
        # four fields, so `x, _ = pack_padded_sequence(...)` raises ValueError there.
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, grad_clip)

        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()

        # Keep track of metrics, weighted by the number of words decoded in this batch
        n_words = sum(decode_lengths)
        top5 = accuracy(scores, targets, 5)
        losses.update(loss.item(), n_words)
        top5accs.update(top5, n_words)
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))

In [7]:
def validate(val_loader, encoder, decoder, criterion):
    """
    Performs one epoch's validation.

    Runs without gradient tracking. Computes the same loss and top-5 accuracy as
    ``train``, plus a corpus-level BLEU-4 of the per-step argmax predictions
    against all reference captions of each image.

    Relies on the notebook-level globals ``device``, ``alpha_c``, ``print_freq``
    and ``word_map`` (set by ``main``), and on the ``AverageMeter`` and
    ``accuracy`` helpers.

    :param val_loader: DataLoader for validation data; each batch is
        (imgs, caps, keys, caplens, allcaps, allkeys)
    :param encoder: encoder model, or None to feed imgs straight to the decoder
    :param decoder: decoder model
    :param criterion: loss layer
    :return: BLEU-4 score
    """
    decoder.eval()  # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    start = time.time()

    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)

    # Validation never calls backward(). Without no_grad every forward pass below
    # would build an autograd graph for nothing, costing memory and time.
    with torch.no_grad():
        # Batches (keys and allkeys are loaded but not used yet; kept for future work)
        for i, (imgs, caps, keys, caplens, allcaps, allkeys) in enumerate(val_loader):

            # Move to device, if available
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # Remove timesteps that we didn't decode at, or are pads.
            # Pack into new names so the padded `scores` stay available for the hypotheses
            # below (no clone needed). Use `.data` because tuple-unpacking a PackedSequence
            # raises on PyTorch >= 1.1.
            scores_packed = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            targets_packed = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            # Calculate loss
            loss = criterion(scores_packed, targets_packed)

            # Add doubly stochastic attention regularization
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # Keep track of metrics, weighted by the number of words decoded in this batch
            n_words = sum(decode_lengths)
            losses.update(loss.item(), n_words)
            top5 = accuracy(scores_packed, targets_packed, 5)
            top5accs.update(top5, n_words)
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))

            # Store references (true captions), and hypothesis (prediction) for each image
            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need -
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b], ...], hypotheses = [hyp1, hyp2, ...]

            # References
            # allcaps is never moved off the CPU, while sort_ind comes from the decoder and may live
            # on `device`. Recent PyTorch rejects cross-device indexing, so bring the index over.
            allcaps = allcaps[sort_ind.cpu()]  # because images were sorted in the decoder
            skip_tokens = {word_map['<start>'], word_map['<pad>']}
            for img_caps in allcaps.tolist():
                references.append([[w for w in c if w not in skip_tokens] for c in img_caps])  # remove <start> and pads

            # Hypotheses: argmax token at each step, truncated to each caption's decode length (remove pads)
            _, preds = torch.max(scores, dim=2)
            hypotheses.extend(p[:length] for p, length in zip(preds.tolist(), decode_lengths))

            assert len(references) == len(hypotheses)

    # Calculate BLEU-4 scores
    bleu4 = corpus_bleu(references, hypotheses)

    print(
        '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
            loss=losses,
            top5=top5accs,
            bleu=bleu4))

    return bleu4

In [None]:
# Train with the configuration above; main() validates and saves a checkpoint after every epoch (log below).
main()

Epoch: [0][0/219]	Batch Time 2.118 (2.118)	Data Load Time 0.166 (0.166)	Loss 11.3139 (11.3139)	Top-5 Accuracy 0.000 (0.000)
Epoch: [0][100/219]	Batch Time 0.290 (0.308)	Data Load Time 0.000 (0.002)	Loss 6.5221 (7.5644)	Top-5 Accuracy 40.052 (30.605)
Epoch: [0][200/219]	Batch Time 0.307 (0.299)	Data Load Time 0.000 (0.001)	Loss 6.1619 (7.0295)	Top-5 Accuracy 39.579 (34.368)
Validation: [0/32]	Batch Time 0.364 (0.364)	Loss 8.1212 (8.1212)	Top-5 Accuracy 19.536 (19.536)	


The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()



 * LOSS - 8.007, TOP-5 ACCURACY - 18.479, BLEU-4 - 1.262575531466194e-78

Epoch: [1][0/219]	Batch Time 0.449 (0.449)	Data Load Time 0.142 (0.142)	Loss 6.2111 (6.2111)	Top-5 Accuracy 37.469 (37.469)
Epoch: [1][100/219]	Batch Time 0.303 (0.291)	Data Load Time 0.000 (0.002)	Loss 5.5070 (6.0712)	Top-5 Accuracy 49.115 (41.713)
Epoch: [1][200/219]	Batch Time 0.295 (0.291)	Data Load Time 0.000 (0.001)	Loss 5.6932 (6.0007)	Top-5 Accuracy 47.087 (42.463)
Validation: [0/32]	Batch Time 0.348 (0.348)	Loss 7.4979 (7.4979)	Top-5 Accuracy 26.349 (26.349)	

 * LOSS - 7.829, TOP-5 ACCURACY - 21.678, BLEU-4 - 2.2475891527506337e-78

Epoch: [2][0/219]	Batch Time 0.449 (0.449)	Data Load Time 0.147 (0.147)	Loss 5.4376 (5.4376)	Top-5 Accuracy 51.282 (51.282)
Epoch: [2][100/219]	Batch Time 0.292 (0.291)	Data Load Time 0.000 (0.002)	Loss 5.8411 (5.6909)	Top-5 Accuracy 45.714 (45.461)
Epoch: [2][200/219]	Batch Time 0.291 (0.290)	Data Load Time 0.000 (0.001)	Loss 5.4883 (5.6675)	Top-5 Accuracy 46.512 (45.793)


Epoch: [15][200/219]	Batch Time 0.285 (0.290)	Data Load Time 0.000 (0.001)	Loss 3.6134 (3.7627)	Top-5 Accuracy 65.217 (63.728)
Validation: [0/32]	Batch Time 0.367 (0.367)	Loss 8.5565 (8.5565)	Top-5 Accuracy 24.047 (24.047)	

 * LOSS - 8.360, TOP-5 ACCURACY - 24.320, BLEU-4 - 0.028669435150576756


Epochs since last improvement: 1

Epoch: [16][0/219]	Batch Time 0.459 (0.459)	Data Load Time 0.146 (0.146)	Loss 3.5177 (3.5177)	Top-5 Accuracy 65.741 (65.741)
Epoch: [16][100/219]	Batch Time 0.291 (0.291)	Data Load Time 0.000 (0.002)	Loss 3.8503 (3.5939)	Top-5 Accuracy 64.216 (66.343)
Epoch: [16][200/219]	Batch Time 0.289 (0.290)	Data Load Time 0.000 (0.001)	Loss 3.5147 (3.6412)	Top-5 Accuracy 67.574 (65.729)
Validation: [0/32]	Batch Time 0.382 (0.382)	Loss 8.5344 (8.5344)	Top-5 Accuracy 24.138 (24.138)	

 * LOSS - 8.562, TOP-5 ACCURACY - 23.463, BLEU-4 - 0.028872872444183288


Epochs since last improvement: 2

Epoch: [17][0/219]	Batch Time 0.435 (0.435)	Data Load Time 0.146 (0.146)	Loss 3.59

Validation: [0/32]	Batch Time 0.367 (0.367)	Loss 9.7515 (9.7515)	Top-5 Accuracy 20.059 (20.059)	

 * LOSS - 9.470, TOP-5 ACCURACY - 21.449, BLEU-4 - 0.031838689444593546


Epochs since last improvement: 7

Epoch: [30][0/219]	Batch Time 0.446 (0.446)	Data Load Time 0.148 (0.148)	Loss 2.3428 (2.3428)	Top-5 Accuracy 84.697 (84.697)
Epoch: [30][100/219]	Batch Time 0.298 (0.291)	Data Load Time 0.000 (0.002)	Loss 2.2875 (2.3863)	Top-5 Accuracy 85.176 (84.872)
Epoch: [30][200/219]	Batch Time 0.286 (0.290)	Data Load Time 0.000 (0.001)	Loss 2.4969 (2.4213)	Top-5 Accuracy 85.024 (84.407)
Validation: [0/32]	Batch Time 0.359 (0.359)	Loss 9.5167 (9.5167)	Top-5 Accuracy 21.856 (21.856)	

 * LOSS - 9.522, TOP-5 ACCURACY - 21.539, BLEU-4 - 0.035088314514798026


Epochs since last improvement: 8


DECAYING learning rate.
The new learning rate is 0.000320

Epoch: [31][0/219]	Batch Time 0.453 (0.453)	Data Load Time 0.148 (0.148)	Loss 2.2985 (2.2985)	Top-5 Accuracy 85.279 (85.279)
Epoch: [31][100/219]	Bat

Validation: [0/32]	Batch Time 0.412 (0.412)	Loss 9.8847 (9.8847)	Top-5 Accuracy 22.356 (22.356)	

 * LOSS - 10.204, TOP-5 ACCURACY - 20.213, BLEU-4 - 0.03647374456474386


Epochs since last improvement: 9

Epoch: [44][0/219]	Batch Time 0.450 (0.450)	Data Load Time 0.148 (0.148)	Loss 1.9379 (1.9379)	Top-5 Accuracy 90.231 (90.231)
Epoch: [44][100/219]	Batch Time 0.285 (0.292)	Data Load Time 0.000 (0.002)	Loss 1.8453 (1.8879)	Top-5 Accuracy 90.152 (90.125)
Epoch: [44][200/219]	Batch Time 0.286 (0.290)	Data Load Time 0.000 (0.001)	Loss 1.9239 (1.9032)	Top-5 Accuracy 89.182 (89.930)
Validation: [0/32]	Batch Time 0.360 (0.360)	Loss 11.5109 (11.5109)	Top-5 Accuracy 15.987 (15.987)	

 * LOSS - 10.235, TOP-5 ACCURACY - 20.243, BLEU-4 - 0.031693227286111424


Epochs since last improvement: 10

Epoch: [45][0/219]	Batch Time 0.440 (0.440)	Data Load Time 0.148 (0.148)	Loss 1.8740 (1.8740)	Top-5 Accuracy 91.105 (91.105)
Epoch: [45][100/219]	Batch Time 0.296 (0.291)	Data Load Time 0.000 (0.002)	Loss 