# Training of Encoder-Decoder network with attention mechanism
## Data preparation
* Create our input files from the Microsoft COCO dataset.
* By default, COCO comes with 5 captions per image.
* We are going to ignore the captions that are longer than 50 words. As we have seen in the Image-Captioning-Project, they are really rare.
* The output files are stored under the "output_folder" directory.
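
To make the two filters concrete, here is a minimal sketch of the caption-length and word-frequency rules listed above. The helper name `filter_and_build_word_map`, the "more than `min_word_freq` occurrences" rule, and the special-token layout are assumptions for illustration; the actual `create_input_files` in `utils.py` may differ.

```python
from collections import Counter

def filter_and_build_word_map(captions, max_len=50, min_word_freq=5):
    """Hypothetical helper: drop over-long captions, then index frequent words."""
    # Keep only captions with at most max_len tokens.
    kept = [c for c in captions if len(c) <= max_len]
    # Count word occurrences across the kept captions.
    freq = Counter(w for c in kept for w in c)
    # Assumed rule: words seen more than min_word_freq times get their own index.
    words = sorted(w for w, n in freq.items() if n > min_word_freq)
    word_map = {w: i + 1 for i, w in enumerate(words)}
    # Assumed special tokens; rarer words would be mapped to <unk>.
    word_map['<unk>'] = len(word_map) + 1
    word_map['<start>'] = len(word_map) + 1
    word_map['<end>'] = len(word_map) + 1
    word_map['<pad>'] = 0
    return kept, word_map

# Six short captions plus one caption of 62 tokens that exceeds max_len.
captions = [['a', 'dog', 'runs']] * 6 + [['a', 'cat'] + ['very'] * 60]
kept, word_map = filter_and_build_word_map(captions)
print(len(kept), word_map)
```

In this toy input the 62-token caption is dropped, so its words never enter the vocabulary.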

In [4]:
from utils import create_input_files

# This will create large files on your computer.
create_input_files(dataset='coco',
                   karpathy_json_path='dataset_coco.json',
                   image_folder='../cocoapi-master/images',
                   captions_per_image=5,
                   min_word_freq=5,
                   output_folder='../cocoapi-master/images',
                   max_len=50
                  )


Reading TRAIN images and captions, storing to file...



100%|██████████████████████████████████████████████████████████████████████████| 113287/113287 [48:08<00:00, 39.21it/s]



Reading VAL images and captions, storing to file...



100%|██████████████████████████████████████████████████████████████████████████████| 5000/5000 [03:26<00:00, 32.00it/s]



Reading TEST images and captions, storing to file...



100%|██████████████████████████████████████████████████████████████████████████████| 5000/5000 [02:57<00:00, 28.14it/s]


In [1]:
data_folder = '../cocoapi-master/images'         # the output_folder of created input_files
data_name = 'coco_5_cap_per_img_5_min_word_freq' # the same way we created them
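
The `data_name` string appears to encode the arguments passed to `create_input_files`. A minimal sketch of that naming pattern, inferred from the value above rather than taken from `utils.py` (the helper `build_data_name` is hypothetical):

```python
def build_data_name(dataset, captions_per_image, min_word_freq):
    """Hypothetical helper reproducing the naming pattern seen in data_name."""
    return f"{dataset}_{captions_per_image}_cap_per_img_{min_word_freq}_min_word_freq"

name = build_data_name('coco', 5, 5)
print(name)
```

If `captions_per_image` or `min_word_freq` changes, `data_name` most likely has to change with it, since the word map file name is built from `data_name` later in this notebook.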

## Training setup

In [2]:
# first import necessary packages
import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from models import Encoder, DecoderWithAttention
from datasets import *
from utils import *
from nltk.translate.bleu_score import corpus_bleu

### specify the model parameters

In [3]:
emb_dim = 512         # dimension of word embeddings
attention_dim = 512   # dimension of attention linear layers
decoder_dim = 512     # dimension of LSTM decoder
dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use the GPU if one is available
cudnn.benchmark = True # set to True only if model inputs have a fixed size; otherwise it adds a lot of computational overhead

### specify the training parameters

In [5]:
start_epoch = 0
epochs = 3                    # Max number of training epochs if early-stopping is not triggered
epochs_since_improvement = 0  # This tracks the number of epochs since the last improvement was made
batch_size = 32
encoder_lr = 1e-4             # This is used with the encoder if fine-tuning is enabled
decoder_lr = 4e-4             # LSTM learning rate
grad_clip = 5.0               # Clip the gradients at this value
alpha_c = 1.0                 # Regularization for attention; implementation from the original paper
best_bleu4 = 0.0              # This tracks the best BLEU-4 score
print_freq = 200              # Print status every print_freq batches
fine_tune_encoder = True      # Do we train the encoder?
checkpoint = None             # Path to checkpoint, "None" if none
torch.cuda.get_device_name(0) # show which GPU is in use

'GeForce GTX 1080 Ti'
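
For reference, `alpha_c` weights the doubly stochastic attention penalty that the `train` and `validate` functions in this notebook add to the cross-entropy loss. Written out from the expression `alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()`, with $\alpha_{b,t,i}$ the attention weight of caption $b$ at decoding step $t$ on pixel $i$, the total loss is as follows. The symbols $B$, $T$ and $P$ (batch size, number of decoding steps, number of pixels) and $\mathcal{L}_{\text{CE}}$ (the cross-entropy term) are notation introduced here, not taken from the code.

```latex
\mathcal{L} \;=\; \mathcal{L}_{\text{CE}}
\;+\; \alpha_c \cdot \frac{1}{B\,P} \sum_{b=1}^{B} \sum_{i=1}^{P}
\Bigl( 1 - \sum_{t=1}^{T} \alpha_{b,t,i} \Bigr)^{2}
```

The penalty is smallest when each pixel receives a total attention close to 1 over the whole caption, which discourages the decoder from ignoring parts of the image.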

## A walkthrough of what will happen later

In [15]:
# take the next batch of training data from the loader
imgs, caps, caplens = next(iter(train_loader))

In [16]:
imgs.size()

torch.Size([32, 3, 256, 256])

In [17]:
caps.size()

torch.Size([32, 52])

In [19]:
# load data into GPU device
imgs = imgs.to(device)
caps = caps.to(device)
caplens = caplens.to(device)

# Pass through CNN encoder
imgs = encoder(imgs)   # (batch_size, encoded_image_size, encoded_image_size, 2048)

# Pass through RNN decoder
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)

In [20]:
scores.size()

torch.Size([32, 26, 9490])

In [21]:
caps_sorted.size()

torch.Size([32, 52])

In [22]:
len(decode_lengths)

32

In [23]:
# pack the padded sequences to drop the pad positions before calculating the loss
scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)
targets = caps_sorted[:, 1:]
targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)
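
To see what packing does, here is a toy sketch with small integer tensors standing in for captions. The token ids and lengths are made up for illustration; only `pack_padded_sequence` and the `[:, 1:]` shift come from the cell above.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Two toy captions, sorted by decreasing length; 9 = <start>, 0 = <pad> (assumed ids).
caps_sorted = torch.tensor([[9, 1, 2, 3],
                            [9, 4, 5, 0]])
decode_lengths = [3, 2]  # number of words to predict per caption

# Drop <start> so that position t of targets holds the word predicted at step t.
targets = caps_sorted[:, 1:]

packed = pack_padded_sequence(targets, decode_lengths, batch_first=True)
print(packed.data)         # tensor([1, 4, 2, 5, 3]): pad removed, time-major order
print(packed.batch_sizes)  # tensor([2, 2, 1]): active captions at each step
```

Because `scores` and `targets` are packed with the same lengths, the rows of their `.data` tensors line up one-to-one, which is what the loss function needs.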

In [26]:
loss = criterion(scores.data, targets.data)

## Now we have the loss for backpropagation
PyTorch can handle the rest.

In [27]:
loss

tensor(2.2427, device='cuda:0', grad_fn=<NllLossBackward>)

## Define the training function

In [8]:
def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """
    train_loader: DataLoader for the training data
    encoder: encoder model
    decoder: decoder model
    criterion: loss function
    encoder_optimizer: optimizer for the encoder (if fine-tuning is on)
    decoder_optimizer: optimizer for the LSTM
    epoch: epoch number
    """
    
    decoder.train()    # Training mode
    encoder.train()
    
    # AverageMeter() is imported from utils.py to track the most recent, average, sum, and count of a metric.
    batch_time = AverageMeter()   # forward prop + back prop time
    data_time = AverageMeter()    # data loading time
    losses = AverageMeter()       # loss (per word decoded)
    top5accs = AverageMeter()     # top5 accuracy
    
    start = time.time()
    
    # Batches
    for i, (imgs, caps, caplens) in enumerate(train_loader):
        data_time.update(time.time() - start)
        
        # Move data to GPU
        imgs = imgs.to(device)
        caps = caps.to(device)
        caplens = caplens.to(device)
        
        # Forward prop.
        imgs = encoder(imgs)   # (batch_size, encoded_image_size, encoded_image_size, 2048)
        scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens) 
        # scores: (batch_size, max(decode_lengths), vocab_size)
        # caps_sorted: (batch_size, max_caption_length)
        # decode_lengths: caption lengths minus 1, one entry per caption
        # alphas: pixel-wise attention weights, (batch_size, max(decode_lengths), num_pixels)
        # sort_ind: indices that sort the captions by decreasing length
        
        # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
        targets = caps_sorted[:, 1:]
        
        # Remove timesteps that we didn't decode at, or are pads
        # pack_padded_sequence does this for us
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)
        
        
        # Compute the loss
        loss = criterion(scores.data, targets.data)
        
        # Add doubly stochastic attention regularization
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
        
        # BackProp.
        decoder_optimizer.zero_grad()
        if encoder_optimizer is not None:
            encoder_optimizer.zero_grad()
        loss.backward()
        
        # Clip gradients
        if grad_clip is not None:
            clip_gradient(decoder_optimizer, grad_clip)
            if encoder_optimizer is not None:
                clip_gradient(encoder_optimizer, grad_clip)
        
        # Update weights
        decoder_optimizer.step()
        if encoder_optimizer is not None:
            encoder_optimizer.step()
            
        # Keep track of metrics
        top5 = accuracy(scores.data, targets.data, 5)
        losses.update(loss.item(), sum(decode_lengths))
        top5accs.update(top5, sum(decode_lengths))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                          batch_time=batch_time,
                                                                          data_time=data_time, loss=losses,
                                                                          top5=top5accs))
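
The status line prints `.val` and `.avg` for each meter. Below is a minimal sketch of how such a meter could work, assuming the `update(value, n)` interface used in `train`. The actual `AverageMeter` in `utils.py` is not shown in this notebook and may differ.

```python
class AverageMeter:
    """Sketch of a running-metric tracker; the real class lives in utils.py."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values seen so far
        self.count = 0    # total weight seen so far
        self.avg = 0.0    # running weighted average

    def update(self, val, n=1):
        # n is the weight of this value, e.g. the number of decoded words.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

losses = AverageMeter()
losses.update(4.0, n=10)   # batch with 10 decoded words
losses.update(2.0, n=30)   # batch with 30 decoded words
print(losses.val, losses.avg)
```

Weighting by `sum(decode_lengths)` means the average is per decoded word rather than per batch, so batches with longer captions count for more.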

## Define the validation function

In [9]:
def validate(val_loader, encoder, decoder, criterion):
    """
    val_loader: DataLoader for validation data
    encoder: trained encoder model
    decoder: trained decoder model
    criterion: loss function
    
    return: BLEU-4 score
    """
    decoder.eval() # eval mode (no dropout or batchnorm)
    if encoder is not None:
        encoder.eval()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()
    start = time.time()
    
    references = list()  # references (true captions) for calculating BLEU-4 score
    hypotheses = list()  # hypotheses (predictions)
    
    # explicitly disable gradient calculation to avoid CUDA memory error
    with torch.no_grad():
        # Batches
        for i, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
            # move to GPU
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)
            
            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
            
            # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            # "caps_sorted" is the "caps" tensor sorted with descending caption_length order. (batch_size, max_caption_length)
            targets = caps_sorted[:, 1:]   # this aligns the targets with the predictions
            
            # Remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence can do this trick
            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)
            
            # Calculate loss
            loss = criterion(scores.data, targets.data)
            
            # Add doubly stochastic attention regularization
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
            
            # Keep track of metrics
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores.data, targets.data, 5)
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)
            
            start = time.time()
            
            if i % print_freq == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(val_loader), batch_time=batch_time,
                                                                                loss=losses, top5=top5accs))
                
            # Store references (true captions), and hypothesis (prediction) for each image
            # If for n images, we have n hypotheses, and references a, b, c... for each image, we need
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b],...], hypotheses = [hyp1, hyp2, ...]
            
            # References
            allcaps = allcaps[sort_ind] # the captions are sorted in the decoder
            for j in range(allcaps.shape[0]):
                img_caps = allcaps[j].tolist()
                img_captions = list(
                    map(lambda c: [w for w in c if w not in {word_map['<start>'], word_map['<pad>']}], 
                        img_caps))   # remove <start> and pads
                references.append(img_captions)
            
            # Hypotheses
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j,p in enumerate(preds):
                temp_preds.append(preds[j][:decode_lengths[j]])   # remove pads
            preds = temp_preds
            hypotheses.extend(preds)
            
            assert len(references) == len(hypotheses)
            
        # Calculate BLEU-4 scores
        bleu4 = corpus_bleu(references, hypotheses)
        
        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'.format(
                loss=losses,
                top5=top5accs,
                bleu=bleu4))

    return bleu4
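
`corpus_bleu` expects one list of reference captions per hypothesis. The sketch below mirrors the filtering and truncation steps of `validate` on plain Python lists. The tiny `word_map` and all token ids are invented for illustration.

```python
# Toy word map and data; all values are made up for illustration.
word_map = {'<pad>': 0, 'a': 1, 'dog': 2, 'cat': 3, '<start>': 4, '<end>': 5}

# One image with two reference captions, padded to the same length.
allcaps = [[[4, 1, 2, 5, 0], [4, 1, 3, 5, 0]]]
# Argmax predictions for that image, padded to the longest caption in the batch.
preds = [[1, 2, 5, 5, 5]]
decode_lengths = [3]

# References: strip <start> and <pad> from every caption of every image.
skip = {word_map['<start>'], word_map['<pad>']}
references = [[[w for w in cap if w not in skip] for cap in img_caps]
              for img_caps in allcaps]
# Hypotheses: keep only the steps that were actually decoded.
hypotheses = [p[:n] for p, n in zip(preds, decode_lengths)]

print(references)  # [[[1, 2, 5], [1, 3, 5]]]
print(hypotheses)  # [[1, 2, 5]]
```

Note that `<end>` is kept on both sides, so predicting the end of the caption at the right position contributes to the score.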

# Start training and validation

## Load our word_map

In [10]:
global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
with open(word_map_file, 'r') as j:
    word_map = json.load(j)
print(f"The total size of the vocabulary is :{len(word_map)}")

The total size of the vocabulary is :9490


## Initialize / load checkpoint

In [11]:
if checkpoint is None:
    decoder = DecoderWithAttention(attention_dim=attention_dim,
                                   embed_dim=emb_dim,
                                   decoder_dim=decoder_dim,
                                   vocab_size=len(word_map),
                                   dropout=dropout
                                  )
    decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                             lr=decoder_lr)
    encoder = Encoder()
    encoder.fine_tune(fine_tune_encoder)
    encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr) if fine_tune_encoder else None
else:
    checkpoint = torch.load(checkpoint)
    start_epoch = checkpoint['epoch'] + 1
    epochs_since_improvement = checkpoint['epochs_since_improvement']
    best_bleu4 = checkpoint['bleu-4']
    decoder = checkpoint['decoder']
    decoder_optimizer = checkpoint['decoder_optimizer']
    encoder = checkpoint['encoder']
    encoder_optimizer = checkpoint['encoder_optimizer']
    if fine_tune_encoder is True and encoder_optimizer is None:
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                                 lr=encoder_lr)

## The actual training

In [12]:
# Move to GPU, if available
decoder = decoder.to(device)
encoder = encoder.to(device)

# Loss function
criterion = nn.CrossEntropyLoss().to(device)

# Custom dataloaders
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
    batch_size=batch_size, shuffle=True, pin_memory=True)
val_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'VAL', transform=transforms.Compose([normalize])),
    batch_size=batch_size, shuffle=True, pin_memory=True)

In [13]:
# Epochs
for epoch in range(start_epoch, epochs):

    # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
    if epochs_since_improvement == 20:
        break
    if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)
        if fine_tune_encoder:
            adjust_learning_rate(encoder_optimizer, 0.8)

    # One epoch's training
    train(train_loader=train_loader,
          encoder=encoder,
          decoder=decoder,
          criterion=criterion,
          encoder_optimizer=encoder_optimizer,
          decoder_optimizer=decoder_optimizer,
          epoch=epoch)

    # One epoch's validation
    recent_bleu4 = validate(val_loader=val_loader,
                            encoder=encoder,
                            decoder=decoder,
                            criterion=criterion)

    # Check if there was an improvement
    is_best = recent_bleu4 > best_bleu4
    best_bleu4 = max(recent_bleu4, best_bleu4)
    if not is_best:
        epochs_since_improvement += 1
        print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
    else:
        epochs_since_improvement = 0

    # Save checkpoint
    save_checkpoint(data_name, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                    decoder_optimizer, recent_bleu4, is_best)
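
The bookkeeping in the loop above can be replayed on a list of scores to see when the learning rate would decay and when training would stop. This is a sketch only: `simulate_schedule` is a hypothetical helper, and it assumes that `adjust_learning_rate(optimizer, 0.8)` multiplies the learning rate by 0.8.

```python
def simulate_schedule(bleu_scores, lr=4e-4, decay_every=8, stop_after=20, shrink=0.8):
    """Replay the improvement bookkeeping on a list of per-epoch BLEU-4 scores."""
    best, since = 0.0, 0
    for epoch, bleu in enumerate(bleu_scores):
        if since == stop_after:
            return epoch, lr          # training would stop before this epoch
        if since > 0 and since % decay_every == 0:
            lr *= shrink              # assumed effect of adjust_learning_rate
        is_best = bleu > best
        best = max(bleu, best)
        since = 0 if is_best else since + 1
    return len(bleu_scores), lr

# One improvement followed by 24 epochs without any improvement.
epochs_run, final_lr = simulate_schedule([0.2] + [0.1] * 24)
print(epochs_run, final_lr)
```

In this toy run the rate is shrunk twice (after 8 and 16 stale epochs) and training stops after 21 epochs. With `epochs = 3` as set earlier in this notebook, neither threshold can be reached.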

Epoch: [0][0/17702]	Batch Time 6.600 (6.600)	Data Load Time 0.876 (0.876)	Loss 10.0792 (10.0792)	Top-5 Accuracy 0.000 (0.000)
Epoch: [0][200/17702]	Batch Time 1.141 (1.301)	Data Load Time 0.500 (0.553)	Loss 5.4749 (6.1580)	Top-5 Accuracy 44.180 (38.731)
Epoch: [0][400/17702]	Batch Time 1.156 (1.275)	Data Load Time 0.469 (0.540)	Loss 4.8203 (5.6753)	Top-5 Accuracy 54.717 (44.448)
Epoch: [0][600/17702]	Batch Time 1.172 (1.260)	Data Load Time 0.453 (0.531)	Loss 4.4033 (5.3984)	Top-5 Accuracy 59.524 (47.907)
Epoch: [0][800/17702]	Batch Time 1.172 (1.246)	Data Load Time 0.484 (0.519)	Loss 4.6233 (5.2172)	Top-5 Accuracy 57.143 (50.215)
Epoch: [0][1000/17702]	Batch Time 1.141 (1.228)	Data Load Time 0.422 (0.505)	Loss 4.2600 (5.0831)	Top-5 Accuracy 65.181 (51.967)
Epoch: [0][1200/17702]	Batch Time 1.094 (1.214)	Data Load Time 0.437 (0.493)	Loss 4.5515 (4.9761)	Top-5 Accuracy 58.575 (53.310)
Epoch: [0][1400/17702]	Batch Time 1.031 (1.203)	Data Load Time 0.359 (0.483)	Loss 4.4792 (4.8921)	Top-5 

Epoch: [0][12600/17702]	Batch Time 1.031 (1.104)	Data Load Time 0.375 (0.387)	Loss 3.5631 (3.9005)	Top-5 Accuracy 68.508 (67.119)
Epoch: [0][12800/17702]	Batch Time 1.047 (1.104)	Data Load Time 0.359 (0.387)	Loss 3.4716 (3.8947)	Top-5 Accuracy 71.467 (67.196)
Epoch: [0][13000/17702]	Batch Time 1.000 (1.103)	Data Load Time 0.344 (0.386)	Loss 3.8411 (3.8893)	Top-5 Accuracy 70.195 (67.270)
Epoch: [0][13200/17702]	Batch Time 1.031 (1.103)	Data Load Time 0.344 (0.386)	Loss 3.3406 (3.8842)	Top-5 Accuracy 72.581 (67.336)
Epoch: [0][13400/17702]	Batch Time 1.031 (1.102)	Data Load Time 0.375 (0.385)	Loss 3.1991 (3.8793)	Top-5 Accuracy 76.099 (67.398)
Epoch: [0][13600/17702]	Batch Time 1.125 (1.102)	Data Load Time 0.422 (0.385)	Loss 3.5131 (3.8741)	Top-5 Accuracy 72.441 (67.465)
Epoch: [0][13800/17702]	Batch Time 1.000 (1.102)	Data Load Time 0.359 (0.385)	Loss 3.6828 (3.8692)	Top-5 Accuracy 69.859 (67.530)
Epoch: [0][14000/17702]	Batch Time 1.062 (1.101)	Data Load Time 0.406 (0.385)	Loss 3.3423 

Epoch: [1][6600/17702]	Batch Time 1.047 (1.084)	Data Load Time 0.359 (0.370)	Loss 3.3746 (3.3816)	Top-5 Accuracy 75.562 (73.872)
Epoch: [1][6800/17702]	Batch Time 1.172 (1.084)	Data Load Time 0.359 (0.369)	Loss 3.4370 (3.3806)	Top-5 Accuracy 74.607 (73.882)
Epoch: [1][7000/17702]	Batch Time 1.141 (1.083)	Data Load Time 0.422 (0.369)	Loss 3.4291 (3.3804)	Top-5 Accuracy 72.299 (73.889)
Epoch: [1][7200/17702]	Batch Time 0.984 (1.083)	Data Load Time 0.328 (0.368)	Loss 3.6055 (3.3795)	Top-5 Accuracy 69.359 (73.904)
Epoch: [1][7400/17702]	Batch Time 1.000 (1.083)	Data Load Time 0.359 (0.369)	Loss 3.3419 (3.3792)	Top-5 Accuracy 73.926 (73.910)
Epoch: [1][7600/17702]	Batch Time 0.922 (1.083)	Data Load Time 0.266 (0.368)	Loss 3.1014 (3.3783)	Top-5 Accuracy 79.037 (73.923)
Epoch: [1][7800/17702]	Batch Time 1.125 (1.082)	Data Load Time 0.281 (0.368)	Loss 3.3390 (3.3774)	Top-5 Accuracy 74.615 (73.935)
Epoch: [1][8000/17702]	Batch Time 0.969 (1.082)	Data Load Time 0.312 (0.368)	Loss 3.1932 (3.3774)

Epoch: [2][600/17702]	Batch Time 1.125 (1.109)	Data Load Time 0.437 (0.394)	Loss 3.1873 (3.2129)	Top-5 Accuracy 77.446 (76.180)
Epoch: [2][800/17702]	Batch Time 1.078 (1.107)	Data Load Time 0.359 (0.391)	Loss 2.9688 (3.2193)	Top-5 Accuracy 80.863 (76.102)
Epoch: [2][1000/17702]	Batch Time 1.297 (1.106)	Data Load Time 0.391 (0.390)	Loss 3.2531 (3.2198)	Top-5 Accuracy 77.507 (76.067)
Epoch: [2][1200/17702]	Batch Time 1.094 (1.106)	Data Load Time 0.437 (0.390)	Loss 2.9060 (3.2192)	Top-5 Accuracy 79.730 (76.092)
Epoch: [2][1400/17702]	Batch Time 1.094 (1.107)	Data Load Time 0.422 (0.392)	Loss 3.3950 (3.2224)	Top-5 Accuracy 74.859 (76.078)
Epoch: [2][1600/17702]	Batch Time 1.234 (1.107)	Data Load Time 0.469 (0.392)	Loss 3.1844 (3.2231)	Top-5 Accuracy 77.454 (76.046)
Epoch: [2][1800/17702]	Batch Time 1.047 (1.106)	Data Load Time 0.375 (0.390)	Loss 3.1322 (3.2228)	Top-5 Accuracy 74.085 (76.067)
Epoch: [2][2000/17702]	Batch Time 1.078 (1.104)	Data Load Time 0.375 (0.388)	Loss 3.3495 (3.2228)	T

Epoch: [2][13200/17702]	Batch Time 1.250 (1.082)	Data Load Time 0.391 (0.366)	Loss 3.6423 (3.2207)	Top-5 Accuracy 71.693 (76.162)
Epoch: [2][13400/17702]	Batch Time 1.047 (1.082)	Data Load Time 0.359 (0.366)	Loss 3.1192 (3.2205)	Top-5 Accuracy 77.557 (76.165)
Epoch: [2][13600/17702]	Batch Time 0.953 (1.082)	Data Load Time 0.312 (0.366)	Loss 3.0581 (3.2202)	Top-5 Accuracy 77.437 (76.171)
Epoch: [2][13800/17702]	Batch Time 1.125 (1.082)	Data Load Time 0.375 (0.366)	Loss 3.3854 (3.2202)	Top-5 Accuracy 73.077 (76.171)
Epoch: [2][14000/17702]	Batch Time 1.187 (1.082)	Data Load Time 0.359 (0.366)	Loss 3.2290 (3.2202)	Top-5 Accuracy 75.833 (76.172)
Epoch: [2][14200/17702]	Batch Time 1.016 (1.082)	Data Load Time 0.328 (0.366)	Loss 3.2873 (3.2203)	Top-5 Accuracy 76.454 (76.173)
Epoch: [2][14400/17702]	Batch Time 1.031 (1.082)	Data Load Time 0.312 (0.365)	Loss 3.2381 (3.2205)	Top-5 Accuracy 74.788 (76.174)
Epoch: [2][14600/17702]	Batch Time 0.953 (1.082)	Data Load Time 0.281 (0.365)	Loss 3.2307 