In [1]:
import os
import time
import json
import h5py
import numpy as np
from six.moves import cPickle
from nltk.translate.bleu_score import corpus_bleu

import torch
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pack_padded_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau

References:

- https://github.com/incredible-vision/show-and-tell/blob/master/train.py
- https://github.com/muggin/show-and-tell/blob/master/models.py
- https://gist.github.com/williamFalcon/f27c7b90e34b4ba88ced042d9ef33edd

# Helpers

In [2]:
class AverageMeter(object):
    """
    Tracks the latest reading of a metric together with its running,
    sample-weighted mean.

    Reference: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running total, the sample count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """
        Record ``val`` as the latest reading, weighted by ``n`` samples, and
        refresh the running mean.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

# MobileNet V2 (pre-trained)

In [3]:
"""
Creates a MobileNetV2 model as defined in the paper: M. Sandler, 
A. Howard, M. Zhu, A. Zhmoginov, L.-C. Chen. "MobileNetV2: Inverted 
Residuals and Linear Bottlenecks.", arXiv:1801.04381, 2018."

Code reference: https://github.com/tonylins/pytorch-mobilenet-v2
ImageNet pretrained weights: https://drive.google.com/file/d/1jlto6HRVD3ipNkAl1lNhDbkBp7HylaqR
"""
import math
import torch
import torch.nn as nn



def conv_bn(inp, oup, stride):
    """
    3x3 convolution (padding 1, no bias) -> BatchNorm -> ReLU6.

    Parameters
    ----------
    inp: int, number of input channels.
    oup: int, number of output channels.
    stride: int, convolution stride.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_1x1_bn(inp, oup):
    """
    1x1 pointwise convolution (stride 1, no padding, no bias) -> BatchNorm -> ReLU6.

    Parameters
    ----------
    inp: int, number of input channels.
    oup: int, number of output channels.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)


class InvertedResidual(nn.Module):
    """
    MobileNetV2 inverted residual block.

    Layout: optional 1x1 expansion (only when ``expand_ratio != 1``), a 3x3
    depthwise convolution, then a linear 1x1 projection without activation.
    The input is added back to the output when the stride is 1 and the
    input and output channel counts match.
    """
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = (stride == 1) and (inp == oup)

        layers = []
        if expand_ratio != 1:
            # pw: 1x1 expansion up to the hidden width
            layers.extend([
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ])
        # dw: 3x3 depthwise convolution (one group per channel). When there is
        # no expansion, hidden_dim == inp, so this consumes the input directly.
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
        ])
        # pw-linear: 1x1 projection to the output width, no activation
        layers.extend([
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out


class MobileNetV2(nn.Module):
    """
    MobileNetV2 classifier: ``features`` (stem conv, inverted residual
    stages, final 1x1 conv) followed by ``classifier`` (Dropout + Linear).

    NOTE: the order in which submodules are appended below fixes the
    ``state_dict`` keys (``features.<i>.*``, ``classifier.1.*``) that
    pretrained weight files are loaded against with ``load_state_dict``.
    With the default setting ``features`` has 19 entries (indices 0-18).
    """
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        """
        Parameters
        ----------
        n_class: int, number of output classes of the final Linear layer.
        input_size: int, nominal input resolution; only checked to be a
            multiple of 32 (the network's overall stride).
        width_mult: float, channel width multiplier. The final 1x1 conv
            width (1280) is only scaled when width_mult > 1.
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            # t: expand_ratio, c: output channels (before width_mult),
            # n: number of repeated blocks, s: stride of the first repeat
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        # Stem (stride 2) plus four stride-2 stages give an overall stride of 32
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first block of a stage downsamples; the rest use stride 1
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        """Return class scores of size [M, n_class] for images x of size [M, 3, H, W]."""
        x = self.features(x)
        # Global average pooling: mean over width (dim 3), then height (dim 2)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """
        Initialise parameters in ``self.modules()`` order:
        Conv2d weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels))) with zero
        bias; BatchNorm2d weight 1 / bias 0; Linear weights ~ N(0, 0.01)
        with zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Fan-out based std (kernel area * output channels)
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # NOTE(review): n is computed but unused in this branch
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
                
def MobileNet(pretrained=True, **kwargs):
    """
    Constructs a MobileNet V2 model.

    Parameters
    ----------
    pretrained: bool, use ImageNet pretrained model or not.
    weight_file: str, path to pretrained weights. Required when
        ``pretrained`` is True.
    **kwargs: any remaining keyword arguments (e.g. n_class, input_size,
        width_mult) are forwarded to ``MobileNetV2``. n_class defaults to
        1000 classes (ImageNet).

    Returns
    -------
    MobileNetV2 instance. When ``pretrained``, the weights are loaded onto
    the CPU; callers can move the model to another device afterwards.

    Raises
    ------
    ValueError: if ``pretrained`` is True but no ``weight_file`` is given.
    """
    weight_file = kwargs.pop('weight_file', '')
    if pretrained and not weight_file:
        # Fail early with a clear message instead of torch.load('') erroring
        raise ValueError('MobileNet(pretrained=True) requires a weight_file path')
    model = MobileNetV2(**kwargs)
    if pretrained:
        # map_location='cpu' lets weights that were saved from a GPU load on
        # any machine; load_state_dict then copies them into the model.
        state_dict = torch.load(weight_file, map_location='cpu')
        model.load_state_dict(state_dict)
    return model

# Encoder

In [4]:
class EncoderCNN(nn.Module):
    """
    Image encoder built on a pretrained MobileNetV2 feature extractor.

    The classifier head is dropped, the feature map is average-pooled to a
    fixed ``feature_size x feature_size`` grid and returned channels-last.
    """
    def __init__(self, weight_file, feature_size=14, tune_layer=15, finetune=True):
        """
        Parameters
        ----------
        weight_file: str, path to MobileNetV2 pretrained weights.
        feature_size: int, side length of the pooled feature map.
        tune_layer: int, index of the first feature layer allowed to train.
            For MobileNetV2 pick an integer from 0 (early) to 18 (final).
        finetune: bool, whether layers from ``tune_layer`` onwards train.
        """
        super(EncoderCNN, self).__init__()
        self.weight_file = weight_file
        self.feature_size = feature_size
        self.tune_layer = tune_layer
        self.finetune = finetune
        self.pretrained = True

        backbone = MobileNet(pretrained=self.pretrained, weight_file=self.weight_file)

        # Keep every top-level child except the last (the classifier head).
        feature_modules = list(backbone.children())[:-1]
        self.cnn = nn.Sequential(*feature_modules)

        # A fixed output grid lets images of different resolutions produce
        # equally sized encodings.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((self.feature_size, self.feature_size))

        self.fine_tune()

    def forward(self, images):
        """
        Parameters
        ----------
        images: PyTorch tensor, size: [M, 3, H, W]

        Returns
        -------
        PyTorch tensor, size: [M, fs, fs, C] (channels last; C is 1280 for
        the default MobileNetV2).
        """
        pooled = self.adaptive_pool(self.cnn(images))  # size: [M, C, fs, fs]
        return pooled.permute(0, 2, 3, 1)

    def fine_tune(self):
        """
        Freeze the whole backbone, then set ``requires_grad`` to
        ``self.finetune`` for the feature layers from ``tune_layer`` onwards.
        """
        for param in self.cnn.parameters():
            param.requires_grad = False

        feature_layers = self.cnn[0]
        for layer in feature_layers[self.tune_layer:]:
            for param in layer.parameters():
                param.requires_grad = self.finetune

# Attention Mechanism

In [5]:
class AttentionMechanism(nn.Module):
    """
    Additive soft attention over encoder feature-map locations.

    Each location is scored by a one-hidden-layer MLP applied to the sum of
    the projected location feature and the projected decoder state. The
    scores are softmax-normalised over locations and used to average the
    encoder features.
    """
    def __init__(self, encoder_size, decoder_size, attention_size):
        """
        Parameters
        ----------
        encoder_size: int, number of channels in encoder CNN output feature
            map (for MobileNetV2 it is 1280)
        decoder_size: int, number of features in the hidden state, i.e. LSTM 
            output size
        attention_size: int, size of MLP used to compute attention scores
        """
        super(AttentionMechanism, self).__init__()
        self.encoder_size = encoder_size
        self.decoder_size = decoder_size
        self.attention_size = attention_size

        # Project encoder features and decoder state into a shared k-dim space
        self.encoder_attn = nn.Linear(in_features=encoder_size, out_features=attention_size)
        self.decoder_attn = nn.Linear(in_features=decoder_size, out_features=attention_size)
        self.relu = nn.ReLU()
        # Collapse each projected location to one unnormalised score
        self.fc_attn = nn.Linear(in_features=attention_size, out_features=1)
        # Normalise the scores across the location axis
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_out):
        """
        Parameters
        ----------
        encoder_out: PyTorch tensor, size: [M, L, D] (L feature-map
            locations, D encoder channels).
        decoder_out: PyTorch tensor, size: [M, h], previous decoder hidden state.

        Returns
        -------
        attn_weighted_encoding: PyTorch tensor, size: [M, D]
        alpha: PyTorch tensor, size: [M, L], attention weights (rows sum to 1)
        """
        projected_locations = self.encoder_attn(encoder_out)           # [M, L, k]
        projected_state = self.decoder_attn(decoder_out).unsqueeze(1)  # [M, 1, k]

        # Paper eq. 4: unnormalised relevance score of each location
        hidden = self.relu(projected_locations + projected_state)      # [M, L, k]
        location_scores = self.fc_attn(hidden).squeeze(2)              # [M, L]

        # Paper eq. 5: probability of attending to each location
        alpha = self.softmax(location_scores)                          # [M, L]

        # Paper eq. 6: attention-weighted annotation vector
        weighted = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)       # [M, D]
        return weighted, alpha

# Decoder (w/ Attention)

In [6]:
class DecoderAttentionRNN(nn.Module):
    """
    RNN (LSTM) decoder to decode encoded images and generate sequences.

    At every step the previous hidden state drives an attention mechanism
    over the encoder feature-map locations. The attention-weighted encoding
    is gated by a learned scalar beta and concatenated with the embedding of
    the ground-truth word for that step (teacher forcing). The result is fed
    to an ``nn.LSTMCell``, whose new hidden state is projected (after
    dropout) to vocabulary scores.
    """
    def __init__(self, encoder_size, decoder_size, attention_size, embedding_size, vocab_size, dropout_prob=0.5):
        """
        Parameters
        ----------
        encoder_size: int, number of channels in encoder CNN output feature
            map (for MobileNetV2 it is 1280)
        decoder_size: int, number of features in the hidden state, i.e. LSTM 
            output size
        attention_size: int, size of MLP used to compute attention scores
        embedding_size: int, size of embedding
        vocab_size: int, vocabulary size
        dropout_prob: float, dropout probability applied to the LSTM hidden
            state before the vocabulary projection
        """
        super(DecoderAttentionRNN, self).__init__()
        self.encoder_size = encoder_size
        self.decoder_size = decoder_size
        self.attention_size = attention_size
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.dropout_prob = dropout_prob
        
        # Create attention mechanism
        self.attention = AttentionMechanism(self.encoder_size, self.decoder_size, self.attention_size)
        
        # Create embedding layer
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)  # size: [V, E]
        
        # Create dropout module
        self.dropout = nn.Dropout(p=self.dropout_prob)
        
        # Create LSTM cell (uses for loop) for decoding
        # Input at each step is [word embedding ; gated context vector]
        self.rnn = nn.LSTMCell(input_size=self.embedding_size + self.encoder_size, 
                               hidden_size=self.decoder_size, bias=True)
        
        # MLPs for LSTM cell's initial states
        self.init_h = nn.Linear(self.encoder_size, self.decoder_size)
        self.init_c = nn.Linear(self.encoder_size, self.decoder_size)
        
        # MLP to compute beta (gating scalar, paper section 4.2.1)
        self.f_beta = nn.Linear(self.decoder_size, 1) # scalar
        
        # Sigmoid to compute beta
        self.sigmoid = nn.Sigmoid()
        
        # FC layer to compute scores over vocabulary
        self.fc = nn.Linear(self.decoder_size, self.vocab_size)
        
        # Initialize embedding and FC layer weights
        self.init_weights()
        
    def init_weights(self):
        """
        Initialise embedding and output-layer weights uniformly in [-0.1, 0.1]
        and set the output-layer bias to 0.

        Ref: https://github.com/ruotianluo/ImageCaptioning.pytorch/blob/master/models/Att2inModel.py
        """
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)
        
    def init_lstm_states(self, encoder_out):
        """
        Initialize LSTM's initial hidden and cell memory states based on encoded
        feature representation. NOTE: Encoded feature map locations mean is used.

        encoder_out: PyTorch tensor, size: [M, L, D]
        Returns (h0, c0), each of size [M, h].
        """
        # Compute mean of encoder output locations
        mean_encoder_out = torch.mean(encoder_out, dim=1)  # size: [M, L, D] -> [M, D]
        
        # Initialize LSTMs hidden state
        h0 = self.init_h(mean_encoder_out)  # size: [M, h]
        
        # Initialize LSTMs cell memory state
        c0 = self.init_c(mean_encoder_out)  # size: [M, h]
        
        return h0, c0
    
    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Parameters
        ----------
        encoder_out: PyTorch tensor, size: [M, fs, fs, D] where, fs is feature
            map size, and D is channels of encoded CNN feature map.
        encoded_captions: PyTorch long tensor of token indices, size [M, T_max]
        caption_lengths: PyTorch tensor, size [M, 1] (squeezed on dim 1 below)

        Returns
        -------
        pred_scores: PyTorch tensor, size [M, T, V] with T = max decode length.
            Entries past an example's decode length stay zero.
        encoded_captions: the captions reordered by descending length.
        decode_lengths: list of int, caption length - 1 per (sorted) example.
        alphas: PyTorch tensor, size [M, T, L], attention weights per step.
        sorted_idx: PyTorch tensor, permutation applied to the batch.

        NOTE(review): pred_scores / alphas are created with torch.zeros and
        no device argument, so they live on the default device (CPU unless
        changed) even when the model runs on a GPU; every step copies
        results across devices.
        """
        batch_size = encoder_out.size(0)
        
        # Flatten encoded feature maps from size [M, fs, fs, D] to [M, L, D]
        encoder_out = encoder_out.view(batch_size, -1, self.encoder_size)
        num_locations = encoder_out.size(1)
        
        # Sort caption lengths in descending order
        caption_lengths, sorted_idx = torch.sort(caption_lengths.squeeze(1), dim=0, 
                                                 descending=True)
        
        # Compute decode lengths to decode. Sequence generation ends when <END> token
        # is generated. A typical caption is [<START>, ..., <END>, <PAD>, <PAD>], caption
        # lengths only considers [<START>, ..., <END>], so when <END> is generated there
        # is no need to decode further. Decode lengths = caption lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()
        
        # Sort encoded feature maps and captions as per caption lengths. REASON: Since a 
        # batch contains different caption lengths (and decode lengths). At each time step 
        # up to max decode length T in a batch we need to apply attention mechanism to only 
        # those images in batch whose decode length is greater than current time step
        encoder_out = encoder_out[sorted_idx]
        encoded_captions = encoded_captions[sorted_idx]
        
        # Get embeddings for encoded captions
        embeddings = self.embedding(encoded_captions) # size: [M, T, E]
        
        # Initialize LSTM's states
        h, c = self.init_lstm_states(encoder_out) # sizes: [M, h], [M, h]
        
        # Compute max decode length
        T = int(max(decode_lengths))
        
        # Create placeholders to store predicted scores and alphas (alphas for doubly stochastic attention)
        pred_scores = torch.zeros(batch_size, T, self.vocab_size) # size: [M, T, V]
        alphas = torch.zeros(batch_size, T, num_locations) # size: [M, T, L]
        
        # Decoding step: (1) Compute attention weighted encoding and attention weights
        # using encoder output, and initial hidden state; (2) Generate a new encoded word
        for t in range(T):
            # Compute batch size at step t (At step t how many decoding lengths are greater than t)
            batch_size_t = sum([dl > t for dl in decode_lengths])
            
            # Encoder output and encoded captions are already sorted by caption lengths
            # in descending order. So based on the number of decoding lengths that are 
            # greater than current t, extract data from encoded output and initial hidden state
            # as input to attention mechanism. 
            attn_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                           h[:batch_size_t])
            
            # Compute gating scalar beta (paper section: 4.2.1)
            beta_t = self.sigmoid(self.f_beta(h[:batch_size_t])) # size: [M, 1]
            
            # Multiply gating scalar beta to attention weighted encoding
            context_vector = beta_t * attn_weighted_encoding  # size: [M, D]
            
            # Concatenate embeddings and context vector, size: [M, E] and [M, D] -> [M, E+D]
            # (the embedding of ground-truth token t is the input at step t)
            concat_input = torch.cat([embeddings[:batch_size_t, t, :], context_vector], dim=1) # size: [M, E+D]
            
            # LSTM input states from time step t-1
            previous_states = (h[:batch_size_t], c[:batch_size_t])
            
            # Generate decoded word
            # h and c shrink to batch_size_t rows here; slicing them on later
            # steps stays valid because batch_size_t never increases.
            h, c = self.rnn(concat_input, previous_states)
            
            # Compute scores over vocabulary
            scores = self.fc(self.dropout(h)) # size: [M, V]
            
            # Populate placeholders for predicted scores and alphas
            pred_scores[:batch_size_t, t, :] = scores
            alphas[:batch_size_t, t, :] = alpha # alpha size: [M, L]
            
        return pred_scores, encoded_captions, decode_lengths, alphas, sorted_idx

# COCO Dataset

In [7]:
import os
import json
import h5py
import numpy as np

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class COCODataset(Dataset):
    """
    COCO Dataset to be used in DataLoader for creating batches 
    during training.

    Images come from the ``'images'`` dataset of an HDF5 file. Captions come
    from a JSON file in which entry ``idx`` is the pair
    ``[encoded_captions, caption_lengths]`` for image ``idx``.
    """
    def __init__(self, config, split='TRAIN', transform=None):
        """
        Parameters
        ----------
        config: object exposing ``train_hdf5`` / ``train_captions`` and
            ``val_hdf5`` / ``val_captions`` paths.
        split: str, 'TRAIN' selects the training files; any other value
            selects the validation files.
        transform: optional callable applied to each image tensor.
        """
        self.config = config
        self.split = split
        self.transform = transform
        
        # Open files where images are stored in HDF5 data format, captions & their lengths
        if self.split == 'TRAIN':
            self.hdf5 = h5py.File(name=self.config.train_hdf5, mode='r')
            self.captions = self.read_json(self.config.train_captions)
        else:
            self.hdf5 = h5py.File(name=self.config.val_hdf5, mode='r')
            self.captions = self.read_json(self.config.val_captions)
            
        # Get image data
        self.images = self.hdf5['images']
                    
    def read_json(self, json_path):
        """Parse and return the JSON document at ``json_path``."""
        with open(json_path, 'r') as j:
            json_data = json.load(j)
        return json_data
        
    def __len__(self):
        # One sample per image; its caption is drawn when the item is accessed
        return len(self.captions)
    
    def __getitem__(self, idx):
        """
        Returns
        -------
        TRAIN split: (img, caption, length)
        other splits: (img, caption, length, captions), where ``captions``
            holds every reference caption of the image.
        """
        # NOTE(review): pixel values are used as stored in the HDF5 file (no
        # rescaling here) — confirm they match what `transform` expects.
        img = torch.FloatTensor(self.images[idx])
        if self.transform is not None:
            img = self.transform(img)

        encoded_caps, cap_lengths = self.captions[idx][0], self.captions[idx][1]

        # Randomly sample one of this image's captions. The count comes from
        # the data rather than a hard-coded 5, so images with fewer captions
        # do not raise IndexError and extra captions can be sampled too.
        cap_idx = np.random.randint(0, high=len(encoded_caps))
        caption = torch.LongTensor(encoded_caps[cap_idx])
        length = torch.LongTensor([cap_lengths[cap_idx]])
        
        if self.split == 'TRAIN':
            return img, caption, length
        else:
            captions = torch.LongTensor(encoded_caps)
            return img, caption, length, captions

# Config

In [18]:
class Config(object):
    """
    Hyper-parameters, file paths and runtime options for training the
    MobileNetV2 encoder + attention LSTM decoder captioning model.
    """
    def __init__(self):
        # Encoder parameters
        # ------------------
        self.cnn_weight_file = './mobilenet_v2.pth.tar'
        self.feature_size = 14
        self.tune_layer = 15  # First MobileNetV2 feature layer (0-18) left trainable
        self.finetune = True
        
        # Normalizing constants
        # ---------------------
        # ImageNet RGB mean / std (the backbone is ImageNet-pretrained)
        self.img_mean = [0.485, 0.456, 0.406]
        self.img_std = [0.229, 0.224, 0.225]
        
        # Decoder parameters
        # ------------------
        self.encoder_size = 1280  # MobileNetV2 output channels (do not change!) 2048 for ResNet
        self.decoder_size = 512  # LSTM output size (hidden state vector size)
        self.attention_size = 512  # Size of MLP used to compute attention scores
        self.embedding_size = 256  # Word embedding size
        self.dropout_prob = 0.5
        
        # Training config
        # ---------------
        self.use_gpu = True
        self.batch_size = 64
        self.start_epoch = 0  # Defined once (was previously assigned twice)
        self.num_epochs = 12
        self.encoder_lr = 0.0001 # Learning rate for encoder
        self.decoder_lr = 0.001  # Learning rate for decoder
        self.lr_multiplier = 0.9 # Learning rate decay
        self.alpha_c = 1.0  # Weight of the doubly stochastic attention regularizer
        self.clip_value = 5.0  # Gradients are clamped to [-clip_value, clip_value]
        self.k = 5 # Top k accuracy
        self.device_id = 0  # Index of the CUDA device to use
        # Honour use_gpu: run on the CPU when GPU use is disabled
        self.device = ('cuda:' + str(self.device_id)) if self.use_gpu else 'cpu'
        self.best_bleu = 0
        
        # Word to index mapping
        # ---------------------
        self.word2idx_file = './WORD2IDX_COCO.json'
        
        # Training data
        # -------------
        self.train_hdf5 = './TRAIN_IMAGES_COCO.hdf5'
        self.train_captions = './TRAIN_CAPTIONS_COCO.json'
        
        # Validation data
        # ---------------
        self.val_hdf5 = './VAL_IMAGES_COCO.hdf5'
        self.val_captions = './VAL_CAPTIONS_COCO.json'
        
        # Terminal display
        # ----------------
        self.display_interval = 100
        
        # Checkpoint config
        # -----------------
        self.start_from = 10  # Epoch of the checkpoint to resume from; None trains from epoch 0
        self.checkpoint_path = './checkpoints'
        self.load_best_model = False  # True: resume from the best checkpoint instead of start_from
        
config = Config()

# Trainer

In [19]:
class Trainer(object):
    def __init__(self, opt):
        self.opt = opt
        self.word2idx = self.read_json(self.opt.word2idx_file)
        self.vocab_size = len(self.word2idx)
        
        # Start training
        self.start()
        
    # Helpers
    def read_json(self, file):
        with open(file, 'r') as f:
            data = json.load(f)
        return data
    
    @staticmethod
    def get_optimizer(opt, net, coder='decoder'):
        if coder == 'decoder':
            lr = opt.decoder_lr
        else:
            lr = opt.encoder_lr
            
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, net.parameters()), lr=lr)
        
        return optimizer
    
    @staticmethod
    def decay_learning_rate(optimizer, lr_multiplier):
        """
        Decays learning rate by a multiplier.
        
        optimizer: PyTorch optim object
        lr_multiplier: float value in range (0, 1)
        """
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_multiplier
        print('Learning rate has been reduced!')
        
    @staticmethod
    def clip_gradient(optimizer, clip_value):
        """
        Clip gradients computed during back propagation (to solve exploding
        gradients)
        """
        for group in optimizer.param_groups:
            for param in group['params']:
                if param.grad is not None:
                    param.grad.data.clamp_(-clip_value, clip_value)
    
    @staticmethod
    def top_k_accuracy(scores, targets, k):
        """
        scores and targets are PyTorch tensors, k is int.
        """
        num_elements = targets.numel()
        
        # Get indices of the k largest elements
        _, topk_idx = scores.data.topk(k, dim=1) # size: [num_elements, k]
        
        # Compute element wise equality
        correct = torch.eq(topk_idx, targets.view(-1, 1).cpu()) # targets size: [num_elements]
        
        # Total correct
        tot_correct = torch.sum(correct)
        
        return tot_correct.float().item() * 100.0 / num_elements
    
    @staticmethod
    def prepare_bleu_data(captions, sorted_idx, scores, decode_lengths, word2idx):
        temp_references = []
        temp_hypotheses = []
        
        # Prepare y_true i.e. references for BLEU
        captions = captions[sorted_idx] # Sort captions based on sorted indices from decoder
        remove_idx = [word2idx['<START>'], word2idx['<PAD>']]
        for c in range(captions.size(0)):
            img_caps = captions[c].tolist()
            # Remove indices corresponding to <START> and <PAD>
            img_caps = [[ix for ix in cap if ix not in remove_idx] for cap in img_caps]
            temp_references.append(img_caps)
            
        # Prepare y_pred i.e. hypotheses for BLEU
        scores_clone = scores.clone()
        _, preds = torch.max(scores_clone, dim=2) # Get indixes of words with max score
        preds = preds.tolist() # Convert PyTorch tensor to list
        for i, pred in enumerate(preds):
            img_hyp = preds[i][:decode_lengths[i]]
            temp_hypotheses.append(img_hyp)
            
        return temp_references, temp_hypotheses
    
    @staticmethod
    def _optimizer_state_to_device(optimizer, device):
        """Move every tensor in ``optimizer``'s per-parameter state to ``device``."""
        # Reference: https://github.com/pytorch/pytorch/issues/2830
        for state in optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(device)

    def create_model(self):
        """
        Build the encoder, the decoder and their optimizers, optionally
        restoring all of them from a checkpoint.

        Returns
        -------
        (encoder, decoder, encoder_optimizer, decoder_optimizer, info)
        ``encoder_optimizer`` is None unless ``opt.finetune``. ``info`` is
        empty unless a checkpoint was loaded. In that case it holds 'epoch'
        (the checkpoint epoch + 1), 'epochs_since_improvement' and 'best_bleu'.
        """
        info = {}

        # Encoder and its optimizer. The optimizer is built only once, and only
        # when the encoder is fine-tuned (it used to be constructed twice).
        encoder = EncoderCNN(weight_file=self.opt.cnn_weight_file, 
                             tune_layer=self.opt.tune_layer, 
                             finetune=self.opt.finetune)

        encoder_optimizer = self.get_optimizer(self.opt, encoder, coder='encoder') if self.opt.finetune else None

        # Decoder and its optimizer
        decoder = DecoderAttentionRNN(encoder_size=self.opt.encoder_size, 
                                      decoder_size=self.opt.decoder_size, 
                                      attention_size=self.opt.attention_size, 
                                      embedding_size=self.opt.embedding_size, 
                                      vocab_size=self.vocab_size, 
                                      dropout_prob=self.opt.dropout_prob)

        decoder_optimizer = self.get_optimizer(self.opt, decoder)

        if self.opt.start_from:
            if self.opt.load_best_model == 1:
                model_path = os.path.join(self.opt.checkpoint_path, 'MobileNetV2_Show_Attend_Tell.pth.tar')
            else:
                epoch = self.opt.start_from
                model_path = os.path.join(self.opt.checkpoint_path, 
                                          'MobileNetV2_Show_Attend_Tell_{}.pth.tar'.format(epoch))

            # Load onto the CPU so checkpoints written on any GPU can be
            # restored here; optimizer state is moved to opt.device below.
            checkpoint = torch.load(model_path, map_location='cpu')
            info['epoch'] = checkpoint['epoch'] + 1
            info['epochs_since_improvement'] = checkpoint['epochs_since_improvement']
            info['best_bleu'] = checkpoint['best_bleu']

            # Load state dicts for encoder, decoder, and their optimizers
            encoder.load_state_dict(checkpoint['encoder'])
            decoder.load_state_dict(checkpoint['decoder'])
            decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer'])
            self._optimizer_state_to_device(decoder_optimizer, self.opt.device)

            # The checkpoint stores None here when the encoder was not fine-tuned
            if encoder_optimizer is not None and checkpoint.get('encoder_optimizer'):
                encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer'])
                self._optimizer_state_to_device(encoder_optimizer, self.opt.device)

        return encoder, decoder, encoder_optimizer, decoder_optimizer, info
    
    def save_checkpoint(self, epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer, decoder_optimizer, 
                        best_bleu, best_flag=False):
        if not os.path.exists(self.opt.checkpoint_path):
            os.makedirs(self.opt.checkpoint_path)
            
        checkpoint_name = 'MobileNetV2_Show_Attend_Tell_{}.pth.tar'.format(epoch)
            
        state = {
            'epoch': epoch,
            'epochs_since_improvement': epochs_since_improvement,
            'best_bleu': best_bleu,
            'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict(),
            'encoder_optimizer': encoder_optimizer.state_dict() if self.opt.finetune else None,
            'decoder_optimizer': decoder_optimizer.state_dict()}

        torch.save(state, os.path.join(self.opt.checkpoint_path, checkpoint_name))
        
        if best_flag:
            best_checkpoint_name = 'MobileNetV2_Show_Attend_Tell.pth.tar'
            torch.save(state, os.path.join(self.opt.checkpoint_path, best_checkpoint_name))
            
    def train(self, train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
        """
        Run one training epoch over ``train_loader``.

        Parameters
        ----------
        train_loader: iterable yielding (imgs, caps, cap_lengths) batches.
        encoder, decoder: models to train (both are switched to train mode).
        criterion: loss applied to packed scores [N, V] and packed targets [N].
        encoder_optimizer: optimizer for the encoder, or None when it is frozen.
        decoder_optimizer: optimizer for the decoder.
        epoch: int, current epoch (used for display only).
        """
        # Display string
        display = """>>> step: {}/{} (epoch: {}), loss: {ls.val:f}, avg loss: {ls.avg:f}, 
        time/batch: {proc_time.val:.3f}, avg time/batch: {proc_time.avg:.3f}, top-5 acc: {acc.val:f}, 
        avg top-5 acc: {acc.avg:f}"""

        # Training mode
        encoder.train()
        decoder.train()

        # Stats
        batch_time = AverageMeter() # Wall time per iteration (data loading + forward + backward + update)
        losses = AverageMeter() # Loss 
        top5_accs = AverageMeter() # Top-k accuracy (k = opt.k)

        start = time.time()

        # Training loop for one epoch
        for i, (imgs, caps, cap_lengths) in enumerate(train_loader):

            # Move the batch to the configured device
            imgs = imgs.to(self.opt.device)
            encoded_caps = caps.to(self.opt.device)
            cap_lengths = cap_lengths.to(self.opt.device)

            # Forward pass
            encoder_out = encoder(imgs)
            pred_scores, sorted_caps, decode_lengths, alphas, sorted_idx = decoder(encoder_out, 
                                                                                   encoded_caps, 
                                                                                   cap_lengths)

            # Select all words after <START> till <END>
            target_caps = sorted_caps[:, 1:]

            # Pack padded sequences. Before computing Cross Entropy Loss (Log Softmax and Negative Log
            # Likelihood Loss) we do not want to take into account padded items in the predicted scores.
            # Use `.data`: tuple-unpacking a PackedSequence fails on PyTorch >= 1.1,
            # where it has four fields instead of two.
            scores = pack_padded_sequence(pred_scores, decode_lengths, batch_first=True).data
            targets = pack_padded_sequence(target_caps, decode_lengths, batch_first=True).data

            # Compute loss
            loss = criterion(scores.to(self.opt.device), targets.to(self.opt.device))

            # Add doubly stochastic attention regularization
            loss += (self.opt.alpha_c * ((1.0 - alphas.sum(dim=1))**2).mean()).to(self.opt.device)

            # Backward propagation
            decoder_optimizer.zero_grad()
            if encoder_optimizer is not None:
                encoder_optimizer.zero_grad()

            loss.backward()

            # Clip gradients
            if self.opt.clip_value is not None:
                self.clip_gradient(decoder_optimizer, self.opt.clip_value)
                if encoder_optimizer is not None:
                    self.clip_gradient(encoder_optimizer, self.opt.clip_value)

            # Update weights
            decoder_optimizer.step()
            if encoder_optimizer is not None:
                encoder_optimizer.step()

            # Compute top accuracy for top k words
            top5_acc = self.top_k_accuracy(scores.data, targets.data, k=self.opt.k)

            # Update metrics (weighted by the number of predicted words)
            losses.update(loss.item(), sum(decode_lengths))
            top5_accs.update(top5_acc, sum(decode_lengths))
            batch_time.update(time.time() - start)
            start = time.time() # Restart timer

            if i % self.opt.display_interval == 0 and i != 0:
                print(display.format(i, len(train_loader), epoch, ls=losses, 
                                     proc_time=batch_time, acc=top5_accs))
                
    def validate(self, val_loader, encoder, decoder, criterion, epoch):
        """
        Evaluate the encoder/decoder on ``val_loader`` and return the corpus BLEU score.

        Loss and top-k accuracy are tracked the same way as during training,
        but no parameters are updated. The whole pass runs under
        ``torch.no_grad()``, so no autograd graph is built. Hypotheses and
        references are collected via ``self.prepare_bleu_data`` and scored with
        NLTK ``corpus_bleu`` using weights ``(0.5, 0.5)``, i.e. BLEU-2.

        Args:
            val_loader: iterable yielding ``(imgs, caps, cap_lengths, captions)``
                batches; ``captions`` are presumably the reference captions per
                image consumed by ``prepare_bleu_data`` -- confirm.
            encoder: image encoder module.
            decoder: attention decoder (same call convention as in ``train``).
            criterion: loss applied to the packed scores / targets.
            epoch: current epoch index, used only for logging.

        Returns:
            The BLEU score computed over the whole validation split.
        """
        # Display string
        display = """>>> step: {}/{} (epoch: {}), loss: {ls.val:f}, avg loss: {ls.avg:f}, 
        time/batch: {proc_time.val:.3f}, avg time/batch: {proc_time.avg:.3f}, top-5 acc: {acc.val:f}, 
        avg top-5 acc: {acc.avg:f}"""

        # Stats (weighted by the number of decoded words per batch)
        batch_time = AverageMeter()  # Forward propagation
        losses = AverageMeter()  # Loss
        top5_accs = AverageMeter()  # Top-k accuracy, k = self.opt.k (logged as "top-5")

        # Evaluation mode
        encoder.eval()
        decoder.eval()

        # Caches for BLEU score computation
        references = []  # y_true
        hypotheses = []  # y_pred

        start = time.time()

        # Validation loop over the whole split; gradients are never needed here,
        # so skip building the autograd graph (saves memory and time).
        with torch.no_grad():
            for i, (imgs, caps, cap_lengths, captions) in enumerate(val_loader):

                # Move the batch to the configured device
                imgs = imgs.to(self.opt.device)
                encoded_caps = caps.to(self.opt.device)
                cap_lengths = cap_lengths.to(self.opt.device)

                # Forward pass
                encoder_out = encoder(imgs)
                pred_scores, sorted_caps, decode_lengths, alphas, sorted_idx = decoder(encoder_out,
                                                                                       encoded_caps,
                                                                                       cap_lengths)

                # Select all words after <START> till <END>
                target_caps = sorted_caps[:, 1:]

                # Pack padded sequences so padded positions do not contribute to the loss.
                # `.data` works across PyTorch versions (PackedSequence has 4 fields since 1.1).
                scores = pack_padded_sequence(pred_scores, decode_lengths, batch_first=True).data
                targets = pack_padded_sequence(target_caps, decode_lengths, batch_first=True).data

                # Compute loss
                loss = criterion(scores.to(self.opt.device), targets.to(self.opt.device))

                # Add doubly stochastic attention regularization
                loss += (self.opt.alpha_c * ((1.0 - alphas.sum(dim=1)) ** 2).mean()).to(self.opt.device)

                # Compute top accuracy for top k words
                top5_acc = self.top_k_accuracy(scores, targets, k=self.opt.k)

                # Update metrics
                num_words = sum(decode_lengths)
                losses.update(loss.item(), num_words)
                top5_accs.update(top5_acc, num_words)
                batch_time.update(time.time() - start)
                start = time.time()  # Restart timer

                if i % self.opt.display_interval == 0 and i != 0:
                    print(display.format(i, len(val_loader), epoch, ls=losses, proc_time=batch_time,
                                         acc=top5_accs))

                # Prepare data to compute BLEU score (sorted_idx maps the decoder's
                # re-ordered batch back to the original `captions` order)
                temp_refs, temp_hyps = self.prepare_bleu_data(captions, sorted_idx, pred_scores, decode_lengths,
                                                              self.word2idx)
                assert len(temp_refs) == len(temp_hyps)

                # Extend the caches
                references.extend(temp_refs)
                hypotheses.extend(temp_hyps)

        # Compute BLEU score (BLEU-2: equal weights on 1- and 2-gram precision)
        bleu = corpus_bleu(references, hypotheses, weights=(0.5, 0.5))
        show = '>>> epoch: {}, avg loss: {ls.avg:f}, avg top-5 acc: {acc.avg:f}, bleu: {bleu}'
        print(show.format(epoch, ls=losses, acc=top5_accs, bleu=bleu))

        return bleu
    
    def start(self):
        """
        Build the models and data loaders, then train and validate until ``self.opt.num_epochs``.

        Each epoch runs one training pass and one validation pass. A checkpoint
        is then saved through ``save_checkpoint``, flagged as "best" when the
        validation BLEU score beats the best seen so far. Whenever the number of
        consecutive epochs without improvement is a positive multiple of
        ``lr_decay_patience``, the learning rates are decayed by
        ``self.opt.lr_multiplier``. ``lr_decay_patience`` is
        ``self.opt.lr_decay_patience`` if present, otherwise 10.
        """
        # Create model; `info` carries resume metadata whose keys may be missing
        encoder, decoder, encoder_optimizer, decoder_optimizer, info = self.create_model()

        # train()/validate() always move batches to self.opt.device, so the models
        # and the loss must live there as well, independently of `use_gpu`.
        # nn.Module.to() modifies the module in place, so the optimizers stay valid.
        encoder = encoder.to(self.opt.device)
        decoder = decoder.to(self.opt.device)
        criterion = nn.CrossEntropyLoss().to(self.opt.device)

        # Normalize image
        normalize = transforms.Normalize(mean=self.opt.img_mean, std=self.opt.img_std)

        # Data loaders
        train_data = COCODataset(self.opt, split='TRAIN', transform=transforms.Compose([normalize]))
        train_loader = DataLoader(train_data, batch_size=self.opt.batch_size, shuffle=True)
        # Corpus-level BLEU does not depend on sample order, so validation is not
        # shuffled (keeps the validation pass deterministic).
        val_data = COCODataset(self.opt, split='VAL', transform=transforms.Compose([normalize]))
        val_loader = DataLoader(val_data, batch_size=self.opt.batch_size, shuffle=False)

        # Resume bookkeeping; fall back to the configured values when the stored
        # value is missing or falsy.
        # NOTE(review): if save_checkpoint stores the epoch that just finished,
        # resuming re-runs that epoch -- confirm whether create_model adds 1.
        epochs_since_improvement = info.get('epochs_since_improvement', 0)
        start_epoch = info.get('epoch') or self.opt.start_epoch
        best_bleu = info.get('best_bleu') or self.opt.best_bleu

        # Number of non-improving epochs between learning-rate decays
        lr_decay_patience = getattr(self.opt, 'lr_decay_patience', 10)

        # Train for epochs
        for epoch in range(start_epoch, self.opt.num_epochs):

            if epochs_since_improvement > 0 and epochs_since_improvement % lr_decay_patience == 0:
                self.decay_learning_rate(decoder_optimizer, self.opt.lr_multiplier)
                if self.opt.finetune:
                    self.decay_learning_rate(encoder_optimizer, self.opt.lr_multiplier)

            # One epoch training
            self.train(train_loader=train_loader, encoder=encoder, decoder=decoder,
                       encoder_optimizer=encoder_optimizer, decoder_optimizer=decoder_optimizer,
                       criterion=criterion, epoch=epoch)

            # One epoch validation
            val_bleu = self.validate(val_loader=val_loader, encoder=encoder, decoder=decoder,
                                     criterion=criterion, epoch=epoch)

            # Check for best bleu score
            best_flag = val_bleu > best_bleu
            best_bleu = max(val_bleu, best_bleu)
            if best_flag:
                epochs_since_improvement = 0
            else:
                epochs_since_improvement += 1
                print('Number of epochs since last improvement: ', epochs_since_improvement)

            # Save checkpoint
            self.save_checkpoint(epoch, epochs_since_improvement, encoder, decoder, encoder_optimizer,
                                 decoder_optimizer, best_bleu, best_flag=best_flag)

# Training Test

In [20]:
# Constructing the Trainer runs the training loop (see the logged output below;
# presumably __init__ calls start()). Bind the instance so it stays accessible
# afterwards and its bare repr is not printed as the cell result.
trainer = Trainer(config)

>>> step: 100/1771 (epoch: 11), loss: 3.415995, avg loss: 3.339985, 
        time/batch: 0.884, avg time/batch: 0.859, top-5 acc: 73.618785, 
        avg top-5 acc: 73.883323
>>> step: 200/1771 (epoch: 11), loss: 3.170666, avg loss: 3.351551, 
        time/batch: 0.848, avg time/batch: 0.857, top-5 acc: 74.576271, 
        avg top-5 acc: 73.692721
>>> step: 300/1771 (epoch: 11), loss: 3.427770, avg loss: 3.360301, 
        time/batch: 0.891, avg time/batch: 0.862, top-5 acc: 73.463687, 
        avg top-5 acc: 73.575842
>>> step: 400/1771 (epoch: 11), loss: 3.274088, avg loss: 3.364825, 
        time/batch: 0.818, avg time/batch: 0.866, top-5 acc: 74.831763, 
        avg top-5 acc: 73.512639
>>> step: 500/1771 (epoch: 11), loss: 3.571495, avg loss: 3.362202, 
        time/batch: 0.856, avg time/batch: 0.867, top-5 acc: 71.095890, 
        avg top-5 acc: 73.535365
>>> step: 600/1771 (epoch: 11), loss: 3.346618, avg loss: 3.359778, 
        time/batch: 0.984, avg time/batch: 0.870, top-5 

<__main__.Trainer at 0x7f1f7077ef60>

In [None]:
# https://discuss.pytorch.org/t/loading-a-saved-model-for-continue-training/17244/4