### For training EfficientNet-LSTM model

In [None]:
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from dataset import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu

In [2]:
# Convert the Flickr8k image folder and captions to HDF5 and JSON files.
# The generated JSON files are already provided, so there is no need to run this code.
# create_input_files(dataset='flickr8k',
#                    karpathy_json_path='../data/dataset_flickr8k.json',
#                    image_folder='../data/flickr8k/Images/',
#                    captions_per_image=5,
#                    min_word_freq=5, # words with less than 5 freq will be <unk>
#                    output_folder='../data/flickr8k/',
#                    max_len=50)


Reading TRAIN images and captions, storing to file...



100%|██████████| 6000/6000 [00:49<00:00, 121.78it/s]



Reading VAL images and captions, storing to file...



100%|██████████| 1000/1000 [00:09<00:00, 100.38it/s]



Reading TEST images and captions, storing to file...



100%|██████████| 1000/1000 [00:11<00:00, 88.12it/s]


In [2]:
# Where the preprocessed dataset lives and how its files are named
data_folder = "../data/flickr8k/"  # directory holding the data files
data_name = "flickr8k_5_cap_per_img_5_min_word_freq"  # common base name of those files

In [3]:
# Model parameters
emb_dim = 256  # dimension of word embeddings
attention_dim = 256  # dimension of attention linear layers
decoder_dim = 256  # dimension of decoder LSTM
dropout = 0.5  # dropout value passed to the decoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU when one is available
cudnn.benchmark = True  # let cuDNN auto-tune its algorithms

# Training parameters
start_epoch = 0  # first epoch index (overwritten when resuming from a checkpoint)
epochs = 100  # max epochs
epochs_since_improvement = 0  # consecutive epochs without a better validation BLEU-4
batch_size = 32
workers = 12  # DataLoader worker processes
decoder_lr = 1e-3  # learning rate of the decoder optimizer
decay_lr = 0.5  # factor handed to decay_learning_rate when validation BLEU-4 stalls
grad_clip = 5.  # clip gradients at 5
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention'
early_stop_epochs = 10  # stop training after this many epochs without improvement
decay_lr_epochs = 4  # decay the learning rate every this many epochs without improvement
best_bleu4 = 0.  # best validation BLEU-4 seen so far
print_freq = 100  # print training/validation stats every 100 batches
checkpoint = None  # path to a checkpoint to resume from, or None to start fresh

# Every name above is already a module-level global, so no `global` declaration
# is needed here. Declaring names as global after assigning them in the same
# block is rejected as a SyntaxError when this cell is compiled as one module.

In [4]:
# Read word map
word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
with open(word_map_file, 'r') as j:
    word_map = json.load(j)

In [5]:
# Initialize / load saved model
if checkpoint is None:
    decoder = DecoderWithAttention(attention_dim=attention_dim,
                                   embed_dim=emb_dim,
                                   decoder_dim=decoder_dim,
                                   vocab_size=len(word_map),
                                   dropout=dropout)
    decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                         lr=decoder_lr)
    encoder = Encoder()

else:
    checkpoint = torch.load(checkpoint)
    start_epoch = checkpoint['epoch'] + 1
    epochs_since_improvement = checkpoint['epochs_since_improvement']
    best_bleu4 = checkpoint['bleu-4']
    decoder = checkpoint['decoder']
    decoder_optimizer = checkpoint['decoder_optimizer']
    encoder = checkpoint['encoder']

In [6]:
def train(train_loader, encoder, decoder, criterion, decoder_optimizer, epoch):
    """
    Runs one pass over the training data, updating the decoder only.

    Parameters:-
    train_loader: DataLoader for training data, yielding (imgs, caps, caplens) batches
    encoder: encoder model (no optimizer steps it here, so its weights are not updated)
    decoder: decoder model
    criterion: loss
    decoder_optimizer: optimizer for decoder
    epoch: index of the current epoch (only used in the progress printout)
    """

    decoder.train()  # train mode (dropout and batchnorm is used)
    # NOTE(review): the encoder is never optimized in this function, yet train
    # mode may still change its behaviour (e.g. batchnorm statistics) -- confirm
    # against the Encoder implementation whether eval mode is intended.
    encoder.train()

    batch_time = AverageMeter()  # forward + back propagation time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss per word
    top5accs = AverageMeter()  # top5 accuracy

    start = time.time()

    # Batches
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move to GPU
        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)

        # Forward propagation.
        # Only decoder_optimizer exists, so gradients w.r.t. the encoder would
        # never be applied (nor zeroed); skip building its autograd graph.
        # Remove this no_grad block if an encoder optimizer is ever added.
        with torch.no_grad():
            imgs = encoder(imgs)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]

        # Remove timesteps that we didn't decode at, or are pads
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0]

        # Calculate loss
        loss = criterion(scores, targets)

        # Add doubly stochastic attention regularization
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        # Back prop.
        decoder_optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)

        # Update weights
        decoder_optimizer.step()

        # Metrics, weighted by the number of decoded words in the batch
        top5 = accuracy(scores, targets, 5)  # is the target among the 5 highest-scoring words?
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print training status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))

In [7]:
def validate(val_loader, encoder, decoder, criterion):
    """
    Performs validation

    Parameters:-
    val_loader: DataLoader for validation data, yielding (imgs, caps, caplens, allcaps) batches
    encoder: encoder model
    decoder: decoder model
    criterion: loss

    Returns:- BLEU-4 score
    """
    decoder.eval()  # eval mode (turn off dropout or batchnorm)
    encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    start = time.time()

    references = list()  # targets
    hypotheses = list()  # predictions

    with torch.no_grad():
        # Token ids stripped from the reference captions; built once, not per caption
        special_tokens = {word_map['<start>'], word_map['<pad>']}

        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):

            # Move to gpu, if available
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)
            allcaps = allcaps.to(device)

            # Forward propagation
            imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

            # Remove <start>
            targets = caps_sorted[:, 1:]

            # Remove pads. The packed results get their own names so the padded
            # `scores` stays available for the predictions below (no clone needed).
            packed_scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            packed_targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            # Calculate loss
            loss = criterion(packed_scores, packed_targets)

            # Add doubly stochastic attention regularization
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            # Keep track of metrics, weighted by the number of decoded words
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(packed_scores, packed_targets, 5)  # is the target among the 5 highest-scoring words?
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))

            # References: every caption of each image, reordered with sort_ind
            # to line up with the decoder's outputs
            allcaps = allcaps[sort_ind]
            for img_caps in allcaps.tolist():
                references.append(
                    [[w for w in cap if w not in special_tokens] for cap in img_caps])  # remove <start> and pads

            # Hypotheses: highest-scoring word at each step, cut to the decoded length
            _, preds = torch.max(scores, dim=2)
            hypotheses.extend(
                pred[:length] for pred, length in zip(preds.tolist(), decode_lengths))  # remove pads

            assert len(references) == len(hypotheses)

        # Calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)

        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
                loss=losses,
                top5=top5accs,
                bleu=bleu4))

    return bleu4

In [None]:
# Move to GPU, if available
decoder = decoder.to(device)
encoder = encoder.to(device)

# Loss function
criterion = nn.CrossEntropyLoss().to(device)

# Transform the images: per-channel normalization only
# (presumably the ImageNet statistics the encoder was pretrained with -- confirm against models.Encoder)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose([normalize])

# Dataloaders
train_loader = torch.utils.data.DataLoader(CaptionDataset(data_folder, data_name, 'TRAIN', transform=transform),
                                           batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
# Validation only aggregates metrics over the whole split, so shuffling adds
# nothing; a fixed order keeps the per-batch printouts comparable across epochs.
val_loader = torch.utils.data.DataLoader(CaptionDataset(data_folder, data_name, 'VAL', transform=transform),
                                         batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)

for epoch in range(start_epoch, epochs):

    # Early stopping. `>=` also stops a run resumed from a checkpoint whose
    # counter is already past the limit.
    if epochs_since_improvement >= early_stop_epochs:
        break

    # Decay learning rate after every `decay_lr_epochs` epochs without improvement
    if epochs_since_improvement > 0 and epochs_since_improvement % decay_lr_epochs == 0:
        decay_learning_rate(decoder_optimizer, decay_lr)

    # Train
    train(train_loader=train_loader,
          encoder=encoder,
          decoder=decoder,
          criterion=criterion,
          decoder_optimizer=decoder_optimizer,
          epoch=epoch)

    # Get bleu score from validation data
    recent_bleu4 = validate(val_loader=val_loader,
                            encoder=encoder,
                            decoder=decoder,
                            criterion=criterion)

    # Check for improvement in bleu score
    is_best = recent_bleu4 > best_bleu4
    best_bleu4 = max(recent_bleu4, best_bleu4)
    if not is_best:
        epochs_since_improvement += 1
        print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
    else:
        epochs_since_improvement = 0

    # Save model parameters
    save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder,
                    decoder_optimizer, recent_bleu4, is_best)