# About This Notebook

This notebook is based on https://www.kaggle.com/konradb/model-train-efficientnet & https://www.kaggle.com/konradb/model-infer-efficientnet, with a final score of 8.90 achieved in the BMS competition.

# Import Libraries

In [None]:
!pip install timm

In [None]:
import os
import re
import cv2
import gc
import timm
import time
import math
import torch
import random
import Levenshtein
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold
from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, Blur
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
import urllib.request
import ssl

# SECURITY NOTE(review): this swaps the process-wide factory for the default
# HTTPS context with one that performs no certificate verification, so later
# HTTPS requests made through urllib/http.client in this session are not
# verified either. Presumably a workaround for certificate errors in this
# environment - confirm it is still needed and remove it otherwise.
ssl._create_default_https_context = ssl._create_unverified_context

# connectivity smoke test: fetch a page and print its HTML; the context manager
# closes the HTTP response once the body has been read
with urllib.request.urlopen('https://www.python.org') as response:
    print(response.read().decode('utf-8'))

# Read Input Data
> Import the train dataframe containing image IDs, InChI strings, their actual lengths and parsed sequences.

In [None]:
# read the train dataframe from the pickle file saved by an earlier run
# NOTE(review): unpickling can execute code embedded in the file, so only load
# pickles that come from a trusted source
train_df = pd.read_pickle('../input/lstm-model/results-6/results-6/train.pkl')
# show the first rows (displayed as the cell output)
train_df.head()

In [None]:
len(train_df)

# Add File Paths
> Make the process of reading the input data more efficient by storing paths to files in the train dataframe.

In [None]:
def get_file_path(image_id: str) -> str:

    """
    Build the path of the PNG file that belongs to an image ID.

    IDs containing a '-' are looked up in the flat augmentations directory;
    all other IDs are looked up in the original train directory, which is
    nested by the first three characters of the ID.

    :param image_id: ID of the image
    :type  image_id: str
    :return:         path to image
    :rtype:          str
    """

    augmented_template = '../input/augmented-set-of-chemical-structures/0. augmentations/0. augmentations/{}.png'
    original_template = '../input/bms-molecular-translation/train/{}/{}/{}/{}.png'

    # a '-' in the ID marks an augmented image, stored in a single flat folder
    if '-' in image_id:
        return augmented_template.format(image_id)

    # original images live under <char 0>/<char 1>/<char 2>/<id>.png
    level_1, level_2, level_3 = image_id[0], image_id[1], image_id[2]
    return original_template.format(level_1, level_2, level_3, image_id)

In [None]:
# derive the image path of every row from its image_id
train_df['file_path'] = train_df['image_id'].apply(get_file_path)
# index=False keeps the row index out of the file; otherwise reading the CSV
# back produces a spurious 'Unnamed: 0' column
train_df.to_csv('./train_df.csv', index=False)

In [None]:
# import the file back
# NOTE(review): a CSV round trip stores every value as text, so any non-scalar
# column of the pickled dataframe (e.g. list-valued) would come back as a
# string - confirm no such column is needed downstream
train_df = pd.read_csv('./train_df.csv')

# display
train_df.head()

In [None]:
# keep only the rows whose image_id is listed in new_dataset.csv
# (per the original note: roughly 300K images sub-selected by size 200-350, HxW)
valid_ids = pd.read_csv('../input/bmssmalldataset/new_dataset.csv')['image_id']
train_df  = train_df[train_df['image_id'].isin(valid_ids)]
# report the (rows, columns) that remain after filtering
print(train_df.shape)

In [None]:
# renumber the rows 0..n-1 after the filtering; because drop=True is not
# passed, the previous index is kept as a new 'index' column
train_df.reset_index(inplace=True)
train_df.head()

In [None]:
img_test = cv2.imread(train_df.loc[0, 'file_path'])
plt.imshow(img_test)

In [None]:
def change_fg_bg_colors(img: np.array) -> np.array:
    """Change foreground to white and background to black.

    The image is inverted by computing 255 - img with cv2.subtract.

    NOTE(review): the callers visible in this file pass the result of
    cv2.imread with default flags. Confirm that, for a multi-channel image,
    the installed OpenCV version applies the scalar 255 to every channel and
    not only to the first one - TODO verify.

    :param img: image array
    :type  img: np.array
    :return:    image with inverted colors
    :rtype:     np.array
    """
    recolored_img = cv2.subtract(255, img) 
    
    return recolored_img

In [None]:
recolored = change_fg_bg_colors(img_test)
plt.imshow(recolored)

# Read the Tokenizer
> Import the string to index mapping for each InChI token, saved previously.

In [None]:
class Tokenizer(object):

    """
    Space-delimited tokenizer that maps tokens to integer indices and back.

    The vocabulary is built by ``fit_on_texts``; the special tokens
    '<sos>', '<eos>' and '<pad>' are always appended after the sorted tokens.
    """

    def __init__(self):
        # string to integer mapping
        self.stoi = {}
        # integer to string mapping
        self.itos = {}

    def __len__(self) -> int:

        """
        This method returns the size of the token:index map.

        :return: number of tokens in the vocabulary (special tokens included)
        :rtype:  int
        """
        return len(self.stoi)

    def fit_on_texts(self, texts: list) -> None:

        """
        This method builds the vocabulary from all space-separated tokens of the
        provided texts and (re)creates both mappings: token to index, index to token.

        Both mappings are rebuilt from scratch on every call, so fitting a second
        time replaces the previous vocabulary instead of mixing old and new
        indices.

        :param texts: list of texts, each a string of space-separated tokens
        :type texts:  list
        """

        # collect the distinct tokens of all texts
        vocab = set()
        for text in texts:
            vocab.update(text.split(' '))

        # alphabetical order first, then start, end and pad symbols
        vocab = sorted(vocab)
        vocab.extend(['<sos>', '<eos>', '<pad>'])

        # the index of a token is its position in the vocabulary
        self.stoi = {token: index for index, token in enumerate(vocab)}

        # invert the mapping to get index -> token
        self.itos = {index: token for token, index in self.stoi.items()}

    def text_to_sequence(self, text: str) -> list:

        """
        This method converts the given text to the list of indices of its
        space-separated tokens, wrapped in the start and end of string indices.

        :param text: input text of space-separated tokens
        :type  text: str
        :return:     list of token indices
        :rtype:      list
        :raises KeyError: if a token (or a special symbol) is not in the vocabulary
        """

        return [
            self.stoi['<sos>'],
            *(self.stoi[token] for token in text.split(' ')),
            self.stoi['<eos>'],
        ]

    def texts_to_sequences(self, texts: list) -> list:

        """
        This method converts each text in the provided list into a sequence of
        token indices (see ``text_to_sequence``).

        :param texts: a list of input texts
        :type  texts: list
        :return:      a list of index sequences, one per text
        :rtype:       list
        """

        return [self.text_to_sequence(text) for text in texts]

    def sequence_to_text(self, sequence: list) -> str:

        """
        This method converts a sequence of token indices back into text.
        Every index is translated, special symbols included, and the tokens are
        joined without a separator.

        :param sequence: list of token indices
        :type  sequence: list
        :return:         text
        :rtype:          str
        """
        return ''.join(self.itos[index] for index in sequence)

    def sequences_to_texts(self, sequences: list) -> list:

        """
        This method converts each provided sequence into text and returns all texts inside a list.

        :param sequences: list of index sequences
        :type  sequences: list
        :return:          list of texts
        :rtype:           list
        """

        return [self.sequence_to_text(sequence) for sequence in sequences]

    def predict_caption(self, sequence: list) -> str:

        """
        This method builds the caption by translating each index of the sequence
        into its token, stopping at the first end of sentence or padding index.

        :param sequence: list of token indices
        :type  sequence: list
        :return:         image caption
        :rtype:          str
        """

        # collect the tokens and join once instead of growing a string
        tokens = []

        for index in sequence:
            # end of sentence or padding terminates the caption
            if index == self.stoi['<eos>'] or index == self.stoi['<pad>']:
                break
            tokens.append(self.itos[index])

        return ''.join(tokens)

    def predict_captions(self, sequences: list) -> list:

        """
        This method predicts the captions for each sequence in a list of sequences.

        :param sequences: list of sequences
        :type  sequences: list
        :return:          list of final image captions
        :rtype:           list
        """

        return [self.predict_caption(sequence) for sequence in sequences]

# load the saved tokenizer and print its string to index mapping
# NOTE(review): presumably the file holds a pickled Tokenizer instance, in which
# case the Tokenizer class must already be defined when this line runs - confirm
# NOTE(review): torch.load unpickles its input, so only load files from a trusted source
tokenizer = torch.load('../input/lstm-model/results-6/results-6/tokenizer.pth')
print(tokenizer.stoi)

# Setup Configurations
> Set configurations needed for modelling and training in a separate class.

In [None]:
class CFG:

    """
    Central place for the constants used for modelling and training.
    All settings are plain class attributes and are read as ``CFG.<name>``.
    """

    # ---- run control ----
    train       = True
    debug       = False   # when True, the debug cell sets epochs to 1 and samples samp_size rows
    samp_size   = 1000
    seed        = 42
    apex        = False
    print_freq  = 10000
    num_workers = 4

    # ---- cross-validation ----
    n_fold      = 5
    trn_fold    = [1]

    # ---- data ----
    size        = 288     # images are resized to size x size by get_transforms
    max_len     = 275
    batch_size  = 32

    # ---- model ----
    model_name    = 'efficientnet_b1'
    enc_size      = 1280
    attention_dim = 256
    embed_dim     = 256
    decoder_dim   = 512
    dropout       = 0.5
    prev_model    = '../input/lstm-model/efficientnet_b1_fold1_best.pth'

    # ---- optimisation ----
    epochs       = 5
    scheduler    = 'CosineAnnealingLR'
    T_max        = 4
    encoder_lr   = 1e-4
    decoder_lr   = 4e-4
    min_lr       = 1e-6
    weight_decay = 1e-6
    max_grad_norm               = 5
    gradient_accumulation_steps = 1

In [None]:
# debug mode: shorten the run so the whole pipeline can be checked quickly
if CFG.debug:
    
    # train for a single epoch
    CFG.epochs = 1
    
    # shrink the train set to a seeded random sample of CFG.samp_size rows
    # and renumber the rows from 0
    train_df   = train_df.sample(n=CFG.samp_size, random_state=CFG.seed).reset_index(drop=True)

# Utilities
> This is a set of utility functions used throughout the computations.

In [None]:
def get_score(y_true: str, y_pred: str) -> float:

    """
    Mean Levenshtein distance between true and predicted labels.

    The two arguments are walked in lockstep with ``zip``, so they are expected
    to be equally long collections of strings (one true and one predicted InChI
    per position); surplus items of the longer argument are ignored. For empty
    input the mean of an empty list is returned (NaN).

    :param y_true: true InChI labels
    :type  y_true: str
    :param y_pred: predicted InChI labels
    :type  y_pred: str
    :return:       average Levenshtein distance over all pairs
    :rtype:        float
    """

    # one edit distance per (true label, predicted label) pair
    distances = [
        Levenshtein.distance(reference, candidate)
        for reference, candidate in zip(y_true, y_pred)
    ]

    return np.mean(distances)

In [None]:
def init_logger(log_file: str ='./train.log'):

    """
    Initialize the logger used for training.

    Messages of level INFO and above are written, as the bare message text, to
    both the default stream of ``StreamHandler`` and to ``log_file``.

    Calling the function again reconfigures the same logger: handlers attached
    by an earlier call are removed and closed first, so re-running the notebook
    cell does not print (or write) every message several times.

    :param log_file: path of the log file
    :type  log_file: str
    :return:         logger
    :rtype:          object
    """

    # getLogger returns the same logger object for the same name on every call
    logger = getLogger(__name__)

    # specify lowest-severity log message to be handled by the logger
    logger.setLevel(INFO)

    # drop handlers left over from a previous call and release their files
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # stream output first, then the disk file; both log the message text only
    for handler in (StreamHandler(), FileHandler(filename=log_file)):
        handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler)

    return logger

# initialize the logger
LOGGER = init_logger()

In [None]:
def seed_torch(seed: int=42) -> None:

    """
    Seed every random number generator used in this notebook with the same
    value so that runs are repeatable.

    Seeds Python's ``random``, NumPy, and torch (CPU and current CUDA device),
    writes the seed to the PYTHONHASHSEED environment variable and switches
    cuDNN to deterministic mode.

    :param seed: seed number
    :type  seed: int
    """

    # exported through the environment (picked up by processes started later)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # python, numpy, torch-CPU and torch-CUDA generators, in that order
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # make cuDNN pick deterministic algorithms
    torch.backends.cudnn.deterministic = True
    
# seed torch with the seed from configs
seed_torch(seed=CFG.seed)

# Cross-Validation Split
> Here cross validation splits are created for the train dataframe.

In [None]:
# create a copy of the train dataframe to modify
folds = train_df.copy()

# stratified K-fold splitter: CFG.n_fold splits, shuffled with a fixed seed
Fold  = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)

# split(X, y): `folds` supplies the rows, folds['InChI_length'] is the label the
# splits are stratified on; each iteration yields (train indices, validation indices)
for n, (train_idx, val_idx) in enumerate(Fold.split(folds, folds['InChI_length'])):
    # mark the validation rows of split n with the fold number
    # NOTE(review): val_idx holds row positions while .loc selects by index label;
    # this presumably relies on the index being 0..n-1 - confirm
    folds.loc[val_idx, 'fold'] = int(n)

# convert fold number to integer
folds['fold'] = folds['fold'].astype(int)

# print the size of each fold
print(folds.groupby(['fold']).size())

# Dataset
> Return the input data in the following format: (image, label tensor, label length). The LSTM-Attention model expects its inputs in this format.

In [None]:
class TrainDataset(Dataset):

    """
    Dataset of (image, tokenized label, label length) triples for training.
    """

    def __init__(self, df, tokenizer, transform=None):

        """
        Initialize train dataset attributes.

        :param df:        train dataframe with 'file_path' and 'InChI_text' columns
        :type  df:        pd.DataFrame
        :param tokenizer: tokenizer providing text_to_sequence
        :type tokenizer:  object
        :param transform: callable invoked as transform(image=...) that returns a
                          dict with an 'image' entry, or None for no transform
        :type transform:  object
        """

        # inherit from parent class
        super().__init__()

        # keep the dataframe (its length defines the dataset size)
        self.df = df

        # tokenizer used to turn label text into index sequences
        self.tokenizer = tokenizer

        # image paths and label texts as arrays for positional access
        self.file_paths = df['file_path'].values
        self.labels     = df['InChI_text'].values

        # optional transformation pipeline
        self.transform  = transform

    def __len__(self) -> int:

        """
        Return size of the train dataframe.

        :return: length of the dataframe
        :rtype:  int
        """
        return len(self.df)

    def __getitem__(self, idx: int) -> tuple:

        """
        Get the image, tensor of its label and label length at the inputted index.

        :param idx: position of the sample
        :type  idx: int
        :return:    image, label as LongTensor of token indices, label length as
                    LongTensor with one element
        :rtype:     tuple
        :raises OSError: if the image file cannot be read
        """

        # get file path of the indexed item
        file_path = self.file_paths[idx]

        # cv2.imread signals an unreadable file by returning None rather than
        # raising, so fail here with the offending path
        image = cv2.imread(file_path)
        if image is None:
            raise OSError(f'cv2.imread could not read image file: {file_path}')

        # revert black to white and white to black
        image = change_fg_bg_colors(image)

        # convert the image from BGR to RGB channel order and to float32
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

        # apply the transformation pipeline, if any
        if self.transform:
            image = self.transform(image=image)['image']

        # tokenize the label text into a sequence of indices (with <sos>/<eos>)
        label = self.tokenizer.text_to_sequence(self.labels[idx])

        # label length as a one-element tensor
        label_length = torch.LongTensor([len(label)])

        return image, torch.LongTensor(label), label_length

In [None]:
class TestDataset(Dataset):

    """
    Dataset of images only (no labels), used for inference.
    """

    def __init__(self, df, transform=None):

        """
        Initialize test dataset attributes.

        :param df:        test dataframe with a 'file_path' column
        :type  df:        pd.DataFrame
        :param transform: callable invoked as transform(image=...) that returns a
                          dict with an 'image' entry, or None for no transform
        :type transform:  object
        """

        # inherit from parent class
        super().__init__()

        # keep the dataframe (its length defines the dataset size)
        self.df = df

        # image paths as an array for positional access
        self.file_paths = df['file_path'].values

        # optional transformation pipeline
        self.transform  = transform

        # always-applied transpose + vertical flip, used for images that are
        # taller than wide
        self.fix_transform = Compose([Transpose(p=1), VerticalFlip(p=1)])

    def __len__(self) -> int:

        """
        Return size of the test dataframe.

        :return: length of the dataframe
        :rtype:  int
        """
        return len(self.df)

    def __getitem__(self, idx: int) -> np.array:

        """
        Get the image at the inputted index.

        :param idx: position of the sample
        :type  idx: int
        :return:    image, as returned by the transform when one is set
        :rtype:     array
        :raises OSError: if the image file cannot be read
        """

        # get file path of the indexed item
        file_path = self.file_paths[idx]

        # cv2.imread signals an unreadable file by returning None rather than
        # raising, so fail here with the offending path
        image = cv2.imread(file_path)
        if image is None:
            raise OSError(f'cv2.imread could not read image file: {file_path}')

        # revert black to white and white to black
        image = change_fg_bg_colors(image)

        # convert the image from BGR to RGB channel order and to float32
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)

        # images that are taller than wide are turned to landscape orientation
        height, width, _ = image.shape
        if height > width:
            image = self.fix_transform(image=image)['image']

        # apply the transformation pipeline, if any
        if self.transform:
            image = self.transform(image=image)['image']

        return image

In [None]:
def bms_collate(batch: tuple) -> tuple:

    """
    Collate samples of the form (image, label, label length) into one batch.

    Images are stacked, labels are padded to the longest label of the batch
    with the <pad> index of the module-level ``tokenizer``, and label lengths
    are stacked into a column.

    :param batch: a collection of data points
    :type  batch: tuple
    :return:      stacked images, padded labels and label lengths of shape (-1, 1)
    :rtype:       tuple
    """

    samples = list(batch)

    # split the samples column-wise
    images    = [sample[0] for sample in samples]
    sequences = [sample[1] for sample in samples]
    lengths   = [sample[2] for sample in samples]

    # right-pad every label sequence with the <pad> index, batch dimension first
    pad_index = tokenizer.stoi["<pad>"]
    padded = pad_sequence(sequences, batch_first=True, padding_value=pad_index)

    return torch.stack(images), padded, torch.stack(lengths).reshape(-1, 1)

# Transformations
> Define basic torch transforms for the dataset, including resizing, normalizing and tensoring.

In [None]:
def get_transforms(*, data):

    """
    Build the albumentations pipeline for the requested dataset split.

    Both pipelines resize to CFG.size x CFG.size, normalize with the Imagenet
    mean/std and convert to a tensor; the 'train' pipeline additionally applies
    random flips, transposition and shift-scale-rotate between the resize and
    the normalization. Any other value of ``data`` yields None.

    :param data: name of the split, 'train' or 'valid' (keyword-only)
    :type  data: str
    :return:     composed transform, or None for an unknown split
    :rtype:      object
    """

    if data not in ('train', 'valid'):
        return None

    # every pipeline starts with the resize
    steps = [Resize(CFG.size, CFG.size)]

    # random augmentations are used for training only
    if data == 'train':
        steps += [
            HorizontalFlip(p=0.5),
            Transpose(p=0.5),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(p=0.5),
        ]

    # every pipeline ends with normalization and tensor conversion
    steps += [
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]

    return Compose(steps)

In [None]:
# define train dataset from the train dataframe, using the tokenizer and transformations
train_ds = TrainDataset(train_df, tokenizer, transform=get_transforms(data='train'))

In [None]:
# display an example from the train dataset
for i in range(1):
    
    # get image, label and label length from the train dataset
    image, label, label_length = train_ds[i]
    # convert the label indices back to text (start/end symbols are included in the text)
    text = tokenizer.sequence_to_text(label.numpy())
    
    # move the first axis of the image to the last position before plotting
    # (assumes the transformed image is channel-first, as produced by ToTensorV2 - TODO confirm)
    plt.imshow(image.transpose(0,1).transpose(1,2))
    plt.title(f'label: {label} text: {text} label_length: {label_length}')
    plt.show()

# Modeling
> Create encoder, attention and decoder classes for modeling.

In [None]:
def _inflate(tensor, times, dim):
    """Repeat each slice of `tensor` along `dim` `times` times, keeping the copies adjacent.

    With ``dim=0`` and rows a, b the result is a, a, ..., b, b, ... (not
    a, b, a, b, ... as ``Tensor.repeat`` would give).
    """
    return torch.repeat_interleave(tensor, repeats=times, dim=dim)


class TopKDecoder(torch.nn.Module):
    r"""
    Top-K decoding with beam search.

    Args:
        decoder_rnn (DecoderRNN): An object of DecoderRNN used for decoding.
        k (int): Size of the beam.

    Inputs: inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio
        - **inputs** (seq_len, batch, input_size): list of sequences, whose length is the batch size and within which
          each sequence is a list of token IDs.  It is used for teacher forcing when provided. (default is `None`)
        - **encoder_hidden** (num_layers * num_directions, batch_size, hidden_size): tensor containing the features
          in the hidden state `h` of encoder. Used as the initial hidden state of the decoder.
        - **encoder_outputs** (batch, seq_len, hidden_size): tensor with containing the outputs of the encoder.
          Used for attention mechanism (default is `None`).
        - **function** (torch.nn.Module): A function used to generate symbols from RNN hidden state
          (default is `torch.nn.functional.log_softmax`).
        - **teacher_forcing_ratio** (float): The probability that teacher forcing will be used. A random number is
          drawn uniformly from 0-1 for every decoding token, and if the sample is smaller than the given value,
          teacher forcing would be used (default is 0).

    Outputs: decoder_outputs, decoder_hidden, ret_dict
        - **decoder_outputs** (batch): batch-length list of tensors with size (max_length, hidden_size) containing the
          outputs of the decoder.
        - **decoder_hidden** (num_layers * num_directions, batch, hidden_size): tensor containing the last hidden
          state of the decoder.
        - **ret_dict**: dictionary containing additional information as follows {*length* : list of integers
          representing lengths of output sequences, *topk_length*: list of integers representing lengths of beam search
          sequences, *sequence* : list of sequences, where each sequence is a list of predicted token IDs,
          *topk_sequence* : list of beam search sequences, each beam is a list of token IDs, *inputs* : target
          outputs if provided for decoding}.
    """

    def __init__(self, decoder_rnn, k, decoder_dim, max_length, tokenizer):
        super(TopKDecoder, self).__init__()
        self.rnn = decoder_rnn
        self.k = k
        self.hidden_size = decoder_dim  # self.rnn.hidden_size
        self.V = len(tokenizer)
        self.SOS = tokenizer.stoi["<sos>"]
        self.EOS = tokenizer.stoi["<eos>"]
        self.max_length = max_length
        self.tokenizer = tokenizer

    def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None, function=F.log_softmax,
                teacher_forcing_ratio=0, retain_output_probs=True):
        """
        Run beam search with beam size ``self.k`` for ``self.max_length`` steps.

        Args:
            inputs: not used for decoding; only passed through as ``metadata['inputs']``.
            encoder_hidden: ``None`` or a tuple of tensors used as the initial decoder
                state; any other type raises ``RuntimeError``.
            encoder_outputs: encoder output tensor; its first dimension is taken as
                the batch size.
            function: forwarded to ``self.rnn.forward_step``.
            teacher_forcing_ratio: not used in this method.
            retain_output_probs: when true, the per-step outputs of ``forward_step``
                are stored and handed to ``_backtrack``.

        Returns:
            ``(decoder_outputs, decoder_hidden, metadata)``: the first beam of the
            back-tracked outputs and hidden state, plus a dict with the keys
            'inputs', 'output', 'h_t', 'score', 'topk_length', 'topk_sequence',
            'length' and 'sequence'.

        NOTE(review): tensors created here are moved with ``.cuda()``, so the
        method requires a CUDA device regardless of where ``encoder_outputs`` lives.
        NOTE(review): with ``retain_output_probs=False`` an empty list is passed to
        ``_backtrack``, which appears to index it per step - confirm whether that
        setting is supported.
        """

        # inputs, batch_size, max_length = self.rnn._validate_args(inputs, encoder_hidden, encoder_outputs,
        #                                                         function, teacher_forcing_ratio)

        batch_size = encoder_outputs.size(0)
        max_length = self.max_length

        # offset of the first beam of every batch item in the flattened (batch * k) layout
        self.pos_index = (torch.LongTensor(range(batch_size)) * self.k).view(-1, 1).cuda()

        # Inflate the initial hidden states to be of size: b*k x h
        # encoder_hidden = self.rnn._init_state(encoder_hidden)
        if encoder_hidden is None:
            hidden = None
        else:
            if isinstance(encoder_hidden, tuple):
                # hidden = tuple([_inflate(h, self.k, 1) for h in encoder_hidden])
                # drop dim 0, repeat every entry k times along the new dim 0, restore dim 0
                hidden = tuple([h.squeeze(0) for h in encoder_hidden])
                hidden = tuple([_inflate(h, self.k, 0) for h in hidden])
                hidden = tuple([h.unsqueeze(0) for h in hidden])
            else:
                # hidden = _inflate(encoder_hidden, self.k, 1)
                raise RuntimeError("Not supported")

        # ... same idea for encoder_outputs and decoder_outputs
        # (the else branch below is unreachable; attention inputs are always inflated)
        if True:  # self.rnn.use_attention:
            inflated_encoder_outputs = _inflate(encoder_outputs, self.k, 0)
        else:
            inflated_encoder_outputs = None

        # Initialize the scores; for the first step,
        # ignore the inflated copies to avoid duplicate entries in the top k
        # (only the first beam of every batch item starts at 0, the others at -inf)
        sequence_scores = torch.Tensor(batch_size * self.k, 1)
        sequence_scores.fill_(-float('Inf'))
        sequence_scores.index_fill_(0, torch.LongTensor([i * self.k for i in range(0, batch_size)]), 0.0)
        sequence_scores = sequence_scores.cuda()

        # Initialize the input vector: a (batch * k, 1) column filled with the <sos> index
        input_var = torch.transpose(torch.LongTensor([[self.SOS] * batch_size * self.k]), 0, 1).cuda()

        # Store decisions for backtracking
        stored_outputs = list()
        stored_scores = list()
        stored_predecessors = list()
        stored_emitted_symbols = list()
        stored_hidden = list()

        for i in range(0, max_length):

            # Run the RNN one step forward
            log_softmax_output, hidden, _ = self.rnn.forward_step(input_var, hidden,
                                                                  inflated_encoder_outputs, function=function)
            # If doing local backprop (e.g. supervised training), retain the output layer
            if retain_output_probs:
                stored_outputs.append(log_softmax_output)

            # To get the full sequence scores for the new candidates, add the local scores for t_i to the predecessor scores for t_(i-1)
            sequence_scores = _inflate(sequence_scores, self.V, 1)
            sequence_scores += log_softmax_output.squeeze(1)
            # per batch item, pick the k best among all k * V (beam, token) continuations
            scores, candidates = sequence_scores.view(batch_size, -1).topk(self.k, dim=1)

            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
            # candidate index modulo V is the token id
            input_var = (candidates % self.V).view(batch_size * self.k, 1)
            sequence_scores = scores.view(batch_size * self.k, 1)

            # Update fields for next timestep
            # candidate index // V is the beam within the batch item; pos_index turns it
            # into a row of the flattened (batch * k) layout
            predecessors = (candidates // self.V + self.pos_index.expand_as(candidates)).view(batch_size * self.k, 1)
            if isinstance(hidden, tuple):
                hidden = tuple([h.index_select(1, predecessors.squeeze()) for h in hidden])
            else:
                hidden = hidden.index_select(1, predecessors.squeeze())

            # Update sequence scores and erase scores for end-of-sentence symbol so that they aren't expanded
            stored_scores.append(sequence_scores.clone())
            eos_indices = input_var.data.eq(self.EOS)
            # NOTE(review): nonzero() returns a 2-D index tensor, so this guard is
            # presumably always true; masked_fill_ is a no-op when nothing matches
            if eos_indices.nonzero().dim() > 0:
                sequence_scores.data.masked_fill_(eos_indices, -float('inf'))

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            stored_emitted_symbols.append(input_var)
            stored_hidden.append(hidden)

        # Do backtracking to return the optimal values
        output, h_t, h_n, s, l, p = self._backtrack(stored_outputs, stored_hidden,
                                                    stored_predecessors, stored_emitted_symbols,
                                                    stored_scores, batch_size, self.hidden_size)

        # Build return objects
        # index 0 along the beam dimension selects the first beam returned by _backtrack
        decoder_outputs = [step[:, 0, :] for step in output]
        if isinstance(h_n, tuple):
            decoder_hidden = tuple([h[:, :, 0, :] for h in h_n])
        else:
            decoder_hidden = h_n[:, :, 0, :]
        metadata = {}
        metadata['inputs'] = inputs
        metadata['output'] = output
        metadata['h_t'] = h_t
        metadata['score'] = s
        metadata['topk_length'] = l
        metadata['topk_sequence'] = p
        metadata['length'] = [seq_len[0] for seq_len in l]
        metadata['sequence'] = [seq[0] for seq in p]
        return decoder_outputs, decoder_hidden, metadata

    def _backtrack(self, nw_output, nw_hidden, predecessors, symbols, scores, b, hidden_size):
        """Backtracks over batch to generate optimal k-sequences.

        The placeholder for the final hidden states is allocated on the same
        device as ``nw_hidden``, so the method works on CPU as well as on GPU.

        Args:
            nw_output [(batch*k, vocab_size)] * sequence_length: A Tensor of outputs from network
            nw_hidden [(num_layers, batch*k, hidden_size)] * sequence_length: A Tensor of hidden states from network
                (each entry is an ``(h, c)`` tuple when the network is an LSTM)
            predecessors [(batch*k)] * sequence_length: A Tensor of predecessors
            symbols [(batch*k)] * sequence_length: A Tensor of predicted tokens
            scores [(batch*k)] * sequence_length: A Tensor containing sequence scores for every token t = [0, ... , seq_len - 1]
            b: Size of the batch
            hidden_size: Size of the hidden state

        Returns:
            output [(batch, k, vocab_size)] * sequence_length: A list of the output probabilities (p_n)
            from the last layer of the RNN, for every n = [0, ... , seq_len - 1]

            h_t [(batch, k, hidden_size)] * sequence_length: A list containing the output features (h_n)
            from the last layer of the RNN, for every n = [0, ... , seq_len - 1]

            h_n(batch, k, hidden_size): A Tensor containing the last hidden state for all top-k sequences.

            score [batch, k]: A list containing the final scores for all top-k sequences

            length [batch, k]: A list specifying the length of each sequence in the top-k candidates

            p (batch, k, sequence_len): A Tensor containing predicted sequence
        """

        lstm = isinstance(nw_hidden[0], tuple)

        # initialize return variables given different types
        output = list()
        h_t = list()
        p = list()
        # Placeholder for last hidden state of top-k sequences.
        # If a (top-k) sequence ends early in decoding, `h_n` contains
        # its hidden state when it sees EOS.  Otherwise, `h_n` contains
        # the last hidden state of decoding.
        # The placeholder lives on the device of the recorded hidden states
        # instead of an unconditional `.cuda()`, which fails without a GPU.
        if lstm:
            state_size = nw_hidden[0][0].size()
            state_device = nw_hidden[0][0].device
            h_n = tuple([torch.zeros(state_size, device=state_device),
                         torch.zeros(state_size, device=state_device)])
        else:
            h_n = torch.zeros(nw_hidden[0].size(), device=nw_hidden[0].device)
        # Placeholder for lengths of top-k sequences; like `h_n`, an entry is
        # overwritten when the corresponding sequence is found to end early.
        l = [[self.max_length] * self.k for _ in range(b)]

        # the last step output of the beams are not sorted
        # thus they are sorted here
        sorted_score, sorted_idx = scores[-1].view(b, self.k).topk(self.k)
        # initialize the sequence scores with the sorted last step beam scores
        s = sorted_score.clone()

        # the number of EOS found in the backward loop below, per batch element
        batch_eos_found = [0] * b

        t = self.max_length - 1
        # initialize the back pointer with the sorted order of the last step beams.
        # add self.pos_index for indexing variable with b*k as the first dimension.
        t_predecessors = (sorted_idx + self.pos_index.expand_as(sorted_idx)).view(b * self.k)
        while t >= 0:
            # Re-order the variables with the back pointer
            current_output = nw_output[t].index_select(0, t_predecessors)
            if lstm:
                current_hidden = tuple([h.index_select(1, t_predecessors) for h in nw_hidden[t]])
            else:
                current_hidden = nw_hidden[t].index_select(1, t_predecessors)
            current_symbol = symbols[t].index_select(0, t_predecessors)
            # Re-order the back pointer of the previous step with the back pointer of
            # the current step
            t_predecessors = predecessors[t].index_select(0, t_predecessors).squeeze()

            # This tricky block handles dropped sequences that see EOS earlier.
            # The basic idea is summarized below:
            #
            #   Terms:
            #       Ended sequences = sequences that see EOS early and dropped
            #       Survived sequences = sequences in the last step of the beams
            #
            #       Although the ended sequences are dropped during decoding,
            #   their generated symbols and complete backtracking information are still
            #   in the backtracking variables.
            #   For each batch, everytime we see an EOS in the backtracking process,
            #       1. If there is survived sequences in the return variables, replace
            #       the one with the lowest survived sequence score with the new ended
            #       sequences
            #       2. Otherwise, replace the ended sequence with the lowest sequence
            #       score with the new ended sequence
            #
            eos_indices = symbols[t].data.squeeze(1).eq(self.EOS).nonzero()
            if eos_indices.dim() > 0:
                for i in range(eos_indices.size(0) - 1, -1, -1):
                    # Indices of the EOS symbol for both variables
                    # with b*k as the first dimension, and b, k for
                    # the first two dimensions
                    idx = eos_indices[i]
                    b_idx = int(idx[0] // self.k)
                    # The indices of the replacing position
                    # according to the replacement strategy noted above
                    res_k_idx = self.k - (batch_eos_found[b_idx] % self.k) - 1
                    batch_eos_found[b_idx] += 1
                    res_idx = b_idx * self.k + res_k_idx

                    # Replace the old information in return variables
                    # with the new ended sequence information
                    t_predecessors[res_idx] = predecessors[t][idx[0]]
                    current_output[res_idx, :] = nw_output[t][idx[0], :]
                    if lstm:
                        current_hidden[0][:, res_idx, :] = nw_hidden[t][0][:, idx[0], :]
                        current_hidden[1][:, res_idx, :] = nw_hidden[t][1][:, idx[0], :]
                        h_n[0][:, res_idx, :] = nw_hidden[t][0][:, idx[0], :].data
                        h_n[1][:, res_idx, :] = nw_hidden[t][1][:, idx[0], :].data
                    else:
                        current_hidden[:, res_idx, :] = nw_hidden[t][:, idx[0], :]
                        h_n[:, res_idx, :] = nw_hidden[t][:, idx[0], :].data
                    current_symbol[res_idx, :] = symbols[t][idx[0]]
                    s[b_idx, res_k_idx] = scores[t][idx[0]].data[0]
                    l[b_idx][res_k_idx] = t + 1

            # record the back tracked results
            output.append(current_output)
            h_t.append(current_hidden)
            p.append(current_symbol)

            t -= 1

        # Sort and re-order again as the added ended sequences may change
        # the order (very unlikely)
        s, re_sorted_idx = s.topk(self.k)
        for b_idx in range(b):
            l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]]

        re_sorted_idx = (re_sorted_idx + self.pos_index.expand_as(re_sorted_idx)).view(b * self.k)

        # Reverse the sequences and re-order at the same time
        # It is reversed because the backtracking happens in reverse time order
        output = [step.index_select(0, re_sorted_idx).view(b, self.k, -1) for step in reversed(output)]
        p = [step.index_select(0, re_sorted_idx).view(b, self.k, -1) for step in reversed(p)]
        if lstm:
            h_t = [tuple([h.index_select(1, re_sorted_idx).view(-1, b, self.k, hidden_size) for h in step]) for step in reversed(h_t)]
            h_n = tuple([h.index_select(1, re_sorted_idx.data).view(-1, b, self.k, hidden_size) for h in h_n])
        else:
            h_t = [step.index_select(1, re_sorted_idx).view(-1, b, self.k, hidden_size) for step in reversed(h_t)]
            h_n = h_n.index_select(1, re_sorted_idx.data).view(-1, b, self.k, hidden_size)
        s = s.data

        return output, h_t, h_n, s, l, p

    def _mask_symbol_scores(self, score, idx, masking_score=-float('inf')):
        """Overwrite the entries of ``score`` selected by ``idx`` with ``masking_score``.

        The assignment happens in place on ``score``; nothing is returned.
        """
        score[idx] = masking_score

    def _mask(self, tensor, idx, dim=0, masking_score=-float('inf')):
        """Fill slices of ``tensor`` along ``dim`` with ``masking_score``, in place.

        The positions to fill are taken from the first column of ``idx``.
        A zero-dimensional ``idx`` leaves ``tensor`` untouched.
        """
        if len(idx.size()) == 0:
            return
        tensor.index_fill_(dim, idx[:, 0], masking_score)

In [None]:
class Encoder(nn.Module):
    
    """
    Encodes an image batch into a feature map using a timm CNN backbone.
    """
    
    def __init__(self, model_name=CFG.model_name, pretrained=False):
        
        """
        Initialize the encoder with the chosen timm model.
        
        :param model_name: name of the timm model to create
        :type  model_name: str
        :param pretrained: whether to load pretrained weights
        :type  pretrained: Boolean
        """
        
        # inherit attributes and methods from parent class
        super().__init__()
        
        # create the chosen CNN model
        self.cnn = timm.create_model(model_name, pretrained=pretrained)
        
    def forward(self, x):
        
        """
        Propagate the input forward.
        
        :param x: batch of images
        :type  x: tensor
        :return:  feature map with the original dimension 1 moved to the last position
        :rtype:   tensor
        """
        
        # get image features from the CNN
        features = self.cnn.forward_features(x)
        
        # move dimension 1 to the end; assuming the backbone returns
        # (bs, channels, height, width), the result is (bs, height, width, channels)
        return features.permute(0, 2, 3, 1)

In [None]:
class Attention(nn.Module):
    
    """
    Additive attention over the encoded image, conditioned on the decoder state.
    """
    
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        
        """
        Build the layers of the attention network.
        
        :param encoder_dim:   feature size of the encoder output
        :type  encoder_dim:   int
        :param decoder_dim:   size of the decoder hidden state
        :type  decoder_dim:   int
        :param attention_dim: size of the shared attention space
        :type  attention_dim: int
        """
        
        super().__init__()
        
        # projection of the encoded image into the attention space
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        
        # projection of the decoder state into the attention space
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        
        # reduces each attention vector to a single score
        self.full_att    = nn.Linear(attention_dim, 1)
        
        # non-linearity applied to the summed projections
        self.relu        = nn.ReLU()
        
        # normalizes the scores across dim=1 (the pixel dimension)
        self.softmax     = nn.Softmax(dim=1)
        
    def forward(self, encoder_out, decoder_hidden):
        
        """
        Compute the attention weights and the attention-weighted encoding.
        
        :param encoder_out: encoded images of dimension (batch_size, num_pixels, encoder_dim)
        :type  encoder_out: tensor
        :param decoder_hidden: previous decoder output of dimension (batch_size, decoder_dim)
        :type  decoder_hidden: tensor
        :return: attention-weighted encoding of dimension (batch_size, encoder_dim) and
                 weights of dimension (batch_size, num_pixels)
        :rtype:  tuple
        """
        
        # (bs, num_pixels, attention_dim)
        projected_pixels = self.encoder_att(encoder_out)
        
        # (bs, 1, attention_dim), broadcast over the pixel dimension below
        projected_state = self.decoder_att(decoder_hidden).unsqueeze(1)
        
        # one score per pixel: (bs, num_pixels)
        activated = self.relu(projected_pixels + projected_state)
        pixel_scores = self.full_att(activated).squeeze(2)
        
        # weights that sum to one over the pixels
        alpha = self.softmax(pixel_scores)
        
        # weighted sum of the encoder output over the pixels: (bs, encoder_dim)
        pixel_weights = alpha.unsqueeze(2)
        attention_weighted_encoding = (encoder_out * pixel_weights).sum(dim=1)
        
        return attention_weighted_encoding, alpha

In [None]:
class DecoderWithAttention(nn.Module):
    
    """
    LSTM-cell decoder that attends over the encoded image at every step.
    """
    
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, device, encoder_dim=CFG.enc_size, dropout=CFG.dropout):
        
        """
        Initialize the decoder with attention network.
        
        :param attention_dim: size of the attention space
        :type  attention_dim: int
        :param embed_dim:     size of the token embeddings
        :type  embed_dim:     int
        :param decoder_dim:   size of the LSTM hidden and cell states
        :type  decoder_dim:   int
        :param vocab_size:    total number of tokens in the vocabulary
        :type  vocab_size:    int
        :param device:        device on which output buffers are allocated
        :type  device:        torch.device or str
        :param encoder_dim:   feature size of the encoder output
        :type  encoder_dim:   int
        :param dropout:       dropout rate applied before the output layer
        :type  dropout:       float
        """
        
        # inherit from parent class
        super().__init__()
        
        # set dimensions of the encoder, attention, embedder and decoder
        self.encoder_dim   = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim     = embed_dim
        self.decoder_dim   = decoder_dim
        
        # set vocabulary size and the used device
        self.vocab_size    = vocab_size
        self.device        = device
        
        # set attention network, embedding network and dropout layer
        # (layers are created in a fixed order so seeded initialization is reproducible)
        self.attention     = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding     = nn.Embedding(vocab_size, embed_dim)
        self.dropout       = nn.Dropout(p=dropout)
        
        # set the LSTM Cell decoder; its input is the token embedding
        # concatenated with the attention-weighted encoding
        self.decode_step   = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        
        # linear layer to find the initial hidden state of LSTM Cell
        self.init_h        = nn.Linear(encoder_dim, decoder_dim)
        
        # linear layer to find the initial cell state of LSTM Cell
        self.init_c        = nn.Linear(encoder_dim, decoder_dim)
        
        # linear layer to create a sigmoid-activated gate
        self.f_beta        = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid       = nn.Sigmoid()
        
        # linear layer to find scores over vocabulary
        self.fc            = nn.Linear(decoder_dim, vocab_size)
        
        # initialize some layers with uniform distribution
        self.init_weights()
        
    def init_weights(self):
        
        """
        Initialize the embedding and output-layer weights uniformly in [-0.1, 0.1]
        and set the output-layer bias to zero.
        """
        
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)
        
    def load_pretrained_embeddings(self, embeddings):
        
        """
        Replace the embedding weights with pre-trained embeddings.
        
        :param embeddings: pre-trained embeddings
        :type  embeddings: tensor
        """
        
        self.embedding.weight = nn.Parameter(embeddings)
        
    def fine_tune_embeddings(self, fine_tune=True):
        
        """
        Enable or disable gradient updates for the embedding layer.
        
        :param fine_tune: allow fine-tuning
        :type  fine_tune: Boolean
        """
        
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune
            
    def init_hidden_state(self, encoder_out):
        
        """
        Create the initial hidden and cell state for decoder's LSTM based on encoded images.
        
        :param encoder_out: encoded images, of size (bs, num_pixels, encoder_dim)
        :type  encoder_out: tensor
        :return:            initial hidden and cell states
        :rtype:             tuple
        """
        
        # average the encoding over dim=1 (the pixels)
        mean_encoder_out = encoder_out.mean(dim=1)
        
        # project the mean encoding to the initial hidden and cell states
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        
        return h, c
    
    def forward(self, encoder_out, encoded_captions, caption_lengths):
        
        """
        Decode with teacher forcing: every step is fed the ground-truth previous token.
        
        :param encoder_out: encoded images; flattened here to (bs, num_pixels, encoder_dim)
        :type  encoder_out: tensor
        :param encoded_captions: encoded captions of dimension (bs, max_caption_length)
        :type  encoded_captions: tensor
        :param caption_lengths: caption lengths of dimension (bs, 1)
        :type  caption_lengths: tensor
        :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        :rtype:  tuple
        """
        
        batch_size  = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        
        # flatten the image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)
        num_pixels  = encoder_out.size(1)
        
        # sort input data by decreasing lengths, so that the samples still being
        # decoded at step t are always the first `batch_size_t` rows
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]
        
        # embed the encoded captions
        embeddings = self.embedding(encoded_captions)
        
        # initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)
        
        # the last token of a caption is never fed as input, so each sample
        # is decoded for (length - 1) steps
        decode_lengths = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)
        
        # buffers for the vocabulary scores and the attention weights,
        # allocated directly on the target device
        predictions = torch.zeros(batch_size, max_decode_length, self.vocab_size, device=self.device)
        alphas      = torch.zeros(batch_size, max_decode_length, num_pixels, device=self.device)
        
        for t in range(max_decode_length):
            
            # number of samples whose caption has not ended at step t
            batch_size_t = sum(length > t for length in decode_lengths)
            
            # attention weighted encoder's output based on decoder's previous state output
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
            attention_weighted_encoding = gate * attention_weighted_encoding
            
            # generate a new word with previous word and attention weighted encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t])
            )
            preds = self.fc(self.dropout(h))
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
            
        return predictions, encoded_captions, decode_lengths, alphas, sort_ind
    
    def predict(self, encoder_out, decode_lengths, tokenizer):
        
        """
        Decode greedily, feeding back the highest-scoring token at every step.
        
        Decoding stops early once every sequence in the batch has produced "<eos>";
        the steps that were not decoded keep all-zero scores.
        
        :param encoder_out: encoded images; flattened here to (bs, num_pixels, encoder_dim)
        :type  encoder_out: tensor
        :param decode_lengths: maximum number of decoding steps
        :type  decode_lengths: int
        :param tokenizer: object whose `stoi` mapping contains "<sos>" and "<eos>"
        :type  tokenizer: Tokenizer
        :return: vocabulary scores of dimension (bs, decode_lengths, vocab_size)
        :rtype:  tensor
        """
        
        batch_size  = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        
        # flatten the image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)
        
        # look the special token ids up once instead of at every step
        sos_id = tokenizer.stoi["<sos>"]
        eos_id = tokenizer.stoi["<eos>"]
        
        # define start tokens and embed them
        start_tokens = torch.full((batch_size,), sos_id, dtype=torch.long, device=self.device)
        embeddings = self.embedding(start_tokens)
        
        # initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)
        
        # buffer for the scores and a per-sample flag recording whether <eos> was produced
        predictions = torch.zeros(batch_size, decode_lengths, self.vocab_size, device=self.device)
        end_condition = torch.zeros(batch_size, dtype=torch.bool, device=encoder_out.device)
        
        for t in range(decode_lengths):
            
            # applying attention to encoded image and decoder's previous state output
            attention_weighted_encoding, alpha = self.attention(encoder_out, h)
            gate = self.sigmoid(self.f_beta(h))  # (batch_size, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            
            # generate a new word with previous word and attention weighted encoding
            h, c = self.decode_step(
                torch.cat([embeddings, attention_weighted_encoding], dim=1),
                (h, c))  # (batch_size, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size, vocab_size)
            predictions[:, t, :] = preds
            
            # greedy choice, computed once and reused below
            next_tokens = torch.argmax(preds, -1)
            
            # stop as soon as every sequence has produced <eos> at some step
            end_condition |= (next_tokens == eos_id)
            if end_condition.all():
                break
                
            # embed the predictions as the next input
            embeddings = self.embedding(next_tokens)
            
        return predictions
    
    def forward_step(self, prev_tokens, hidden, encoder_out, function):
        
        """
        Run a single decoding step from explicit previous tokens and LSTM state.
        
        :param prev_tokens: tokens produced at the previous step
        :type  prev_tokens: tensor
        :param hidden: (h, c) pair; dimension 0 of each is squeezed before use and
                       re-added on the returned state
        :type  hidden: tuple
        :param encoder_out: encoded images of dimension (bs, num_pixels, encoder_dim)
        :type  encoder_out: tensor
        :param function: callable applied as `function(scores, dim=1)` to the vocabulary scores
        :type  function: callable
        :return: transformed scores, new (h, c) state, and None
        :rtype:  tuple
        """
        
        assert len(hidden) == 2
        h, c = hidden
        h, c = h.squeeze(0), c.squeeze(0)

        embeddings = self.embedding(prev_tokens)
        if embeddings.dim() == 3:
            # tokens arrived with an extra dimension at position 1; drop it
            embeddings = embeddings.squeeze(1)

        attention_weighted_encoding, alpha = self.attention(encoder_out, h)
        gate = self.sigmoid(self.f_beta(h))  # (batch_size, encoder_dim)
        attention_weighted_encoding = gate * attention_weighted_encoding
        h, c = self.decode_step(
            torch.cat([embeddings, attention_weighted_encoding], dim=1),
            (h, c))  # (batch_size, decoder_dim)
        preds = self.fc(self.dropout(h))  # (batch_size, vocab_size)

        hidden = (h.unsqueeze(0), c.unsqueeze(0))
        predicted_softmax = function(preds, dim=1)
        return predicted_softmax, hidden, None

# Helper Functions
> Utility classes and functions for metric tracking, time formatting, and the per-epoch training and validation routines.

In [None]:
class AverageMeter:
    
    """
    Keeps the latest value of a metric together with its running weighted average.
    """
    
    def __init__(self):
        
        """
        Start with all statistics cleared.
        """
        self.reset()
    
    def reset(self):
        
        """
        Clear the latest value, the average, the weighted sum and the count.
        """
        
        self.val = self.avg = self.sum = self.count = 0
        
    def update(self, val, n=1):
        
        """
        Record a new value that represents `n` observations and refresh the average.
        
        :param val: latest value of the metric
        :param n:   number of observations the value stands for
        :type  n:   int
        """
        
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def asMinutes(s):
    
    """
    Format a duration in seconds as whole minutes and seconds, e.g. "2m 5s".
    
    :param s: seconds
    :type  s: float
    :return:  formatted duration
    :rtype:   str
    """
    m  = math.floor(s/60)
    s -= m * 60
    # the minute unit is included so the two numbers cannot be misread
    return '%dm %ds' % (m, s)

In [None]:
def timeSince(since, percent):
    
    """
    Report the elapsed time and the estimated remaining time.
    
    :param since:   start timestamp, as returned by time.time()
    :type  since:   float
    :param percent: fraction of the work completed so far (must be non-zero)
    :type  percent: float
    :return:        elapsed and remaining time as text
    :rtype:         str
    """
    
    elapsed = time.time() - since
    
    # extrapolate the total duration from the completed fraction
    estimated_total = elapsed / (percent)
    remaining = estimated_total - elapsed
    
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
def train_fn(
    train_loader, encoder, decoder, 
    criterion, encoder_optimizer, decoder_optimizer, 
    epoch, encoder_scheduler, decoder_scheduler, device
):
    """
    Train the encoder and the decoder for one epoch.
    
    :param train_loader: data loader yielding (images, labels, label lengths)
    :type  train_loader: DataLoader
    :param encoder:      encoder model
    :type  encoder:      Encoder
    :param decoder:      decoder model
    :type  decoder:      Decoder
    :param criterion:    loss layer
    :type criterion:     Loss 
    :param encoder_optimizer: optimizer for encoder
    :type  encoder_optimizer: Optimizer
    :param decoder_optimizer: optimizer for decoder
    :type  decoder_optimizer: Optimizer
    :param epoch:             zero-based epoch number (printed as epoch + 1)
    :type  epoch:             int
    :param encoder_scheduler: encoder LR scheduler; not used inside this function
    :type  encoder_scheduler: Scheduler
    :param decoder_scheduler: decoder LR scheduler; not used inside this function
    :type  decoder_scheduler: Scheduler
    :param device:            device selection
    :type  device:            Device
    :return:                  Average loss
    :rtype:                   float
    """
    
    batch_time = AverageMeter()
    data_time  = AverageMeter()
    losses     = AverageMeter()
    
    # switch to train mode
    encoder.train()
    decoder.train()
    start = end = time.time()
    num_batches = len(train_loader)
    
    # for index and inputs (imgs, labels and label lengths) in train dataset, do
    for step, (images, labels, label_lengths) in enumerate(train_loader):
        
        # time spent waiting for the batch
        data_time.update(time.time() - end)
        
        # send inputs to device
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = label_lengths.to(device)
        
        # set batch size
        batch_size = images.size(0)
        
        # encode images to get features
        features = encoder(images)
        
        # return scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        predictions, caps_sorted, decode_lengths, alphas, sort_ind = decoder(features, labels, label_lengths)
        
        # the targets are the captions shifted by one: everything after the first token
        targets = caps_sorted[:, 1:]
        
        # packing drops the padded positions so they do not contribute to the loss
        predictions = pack_padded_sequence(predictions, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
        
        # compute loss between predictions and targets
        loss = criterion(predictions, targets)
        
        # record the unscaled loss
        losses.update(loss.item(), batch_size)
        
        # when gradients are accumulated, scale the loss by the number of accumulation steps
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
            
        if CFG.apex:
            # NOTE(review): `amp` is not imported in the visible part of the file;
            # confirm it is available before enabling CFG.apex
            with amp.scale_loss(loss, decoder_optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # backward propagation of loss
            loss.backward()
        
        # perform gradient clipping to avoid exploding gradients
        encoder_grad_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), CFG.max_grad_norm)
        decoder_grad_norm = torch.nn.utils.clip_grad_norm_(decoder.parameters(), CFG.max_grad_norm)
        
        # update weights once every `gradient_accumulation_steps` batches
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            
            # take a step based on gradients
            encoder_optimizer.step()
            decoder_optimizer.step()
            
            # clear all gradients from last step
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        
        # print progress every `print_freq` steps and on the last step
        if step % CFG.print_freq == 0 or step == (num_batches-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Encoder Grad: {encoder_grad_norm:.4f}  '
                  'Decoder Grad: {decoder_grad_norm:.4f}  '
                  .format(
                   epoch+1, step, num_batches, batch_time=batch_time,
                   data_time=data_time, loss=losses,
                   remain=timeSince(start, float(step+1)/num_batches),
                   encoder_grad_norm=encoder_grad_norm,
                   decoder_grad_norm=decoder_grad_norm,
                   ))
            
    return losses.avg


def valid_fn(valid_loader, encoder, decoder, tokenizer, criterion, device):
    
    """
    Predict captions for the validation set.
    
    :param valid_loader: data loader yielding batches of images
    :type  valid_loader: DataLoader
    :param encoder:      Encoder
    :type  encoder:      Encoder
    :param decoder:      Decoder
    :type  decoder:      Decoder
    :param tokenizer:    Tokenizer
    :type  tokenizer:    Tokenizer
    :param criterion:    Loss; not used inside this function
    :type  criterion:    Loss 
    :param device:       device selection
    :type  device:       Device
    :return:             predicted captions of all batches, concatenated
    :rtype:              np.ndarray
    """
    
    batch_time = AverageMeter()
    data_time = AverageMeter()
    
    # switch to evaluation mode
    encoder.eval()
    decoder.eval()
    
    # store predictions here
    text_preds = []
    
    start = end = time.time()
    num_batches = len(valid_loader)
    
    # for each batch of images in validation set
    for step, images in enumerate(valid_loader):
        
        # measure data loading time
        data_time.update(time.time() - end)
        
        # send images to device
        images = images.to(device)
        
        # inference only: no gradients are needed
        with torch.no_grad():
            
            # encode images
            features = encoder(images)
            
            # predict sequence using decoder
            predictions = decoder.predict(features, CFG.max_len, tokenizer)
            
        # pick the highest-scoring token at every step
        predicted_sequence = torch.argmax(predictions.detach().cpu(), -1).numpy()
        
        # predict captions from the predicted sequence and append to storage
        _text_preds = tokenizer.predict_captions(predicted_sequence)
        text_preds.append(_text_preds)
        
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        
        # print progress every `print_freq` steps and on the last step
        if step % CFG.print_freq == 0 or step == (num_batches-1):
            print('EVAL: [{0}/{1}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  .format(
                   step, num_batches, batch_time=batch_time,
                   data_time=data_time,
                   remain=timeSince(start, float(step+1)/num_batches),
                   ))
            
    # join the per-batch predictions into a single array
    text_preds = np.concatenate(text_preds)
    
    return text_preds

# Train Loop
> Performs training.

In [None]:
def train_loop(folds, fold):
    """
    Fine-tune the encoder/decoder for one cross-validation fold.

    Rows of `folds` whose 'fold' column equals `fold` form the validation split;
    all other rows form the training split. Model, optimizer and scheduler states
    are restored from the checkpoint at `CFG.prev_model`, then training runs for
    `CFG.epochs` epochs. After every epoch the validation predictions are scored
    with `get_score`; whenever the score improves (lower is better) a checkpoint
    is written to './{CFG.model_name}_fold{fold}_best.pth'.

    Args:
        folds: DataFrame with at least the 'fold' and 'InChI' columns, plus
            whatever TrainDataset / TestDataset read from it.
        fold: value of the 'fold' column that selects the validation split.

    Raises:
        ValueError: if `CFG.scheduler` is not one of the supported scheduler names.

    Relies on the notebook-level globals CFG, LOGGER, tokenizer and device.
    """

    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    valid_labels = valid_folds['InChI'].values

    train_dataset = TrainDataset(train_folds, tokenizer, transform=get_transforms(data='train'))
    # validation only needs images (labels are compared as strings afterwards),
    # hence the label-free TestDataset
    valid_dataset = TestDataset(valid_folds, transform=get_transforms(data='valid'))

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers,
                              pin_memory=True,
                              drop_last=True,
                              collate_fn=bms_collate)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers,
                              pin_memory=True,
                              drop_last=False)

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(optimizer):
        """Build the LR scheduler selected by CFG.scheduler for `optimizer`."""
        if CFG.scheduler == 'ReduceLROnPlateau':
            return ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
        if CFG.scheduler == 'CosineAnnealingLR':
            return CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
        if CFG.scheduler == 'CosineAnnealingWarmRestarts':
            return CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
        # fail loudly instead of hitting an UnboundLocalError further down
        raise ValueError(f"Unsupported CFG.scheduler: {CFG.scheduler!r}")

    def step_scheduler(scheduler, score):
        """Advance `scheduler` by one epoch."""
        # ReduceLROnPlateau needs the monitored metric; the cosine schedulers
        # (the only other types get_scheduler can return) step without arguments.
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(score)
        else:
            scheduler.step()

    # ====================================================
    # model & optimizer
    # ====================================================
    # load on CPU first; the modules are moved to `device` after their weights are restored
    states = torch.load(CFG.prev_model, map_location=torch.device('cpu'))

    # NOTE(review): the pretrained weights requested here are overwritten by the
    # checkpoint on the next line; pretrained=False would presumably avoid the
    # download — confirm against Encoder before changing.
    encoder = Encoder(CFG.model_name, pretrained=True)
    encoder.load_state_dict(states['encoder'])
    encoder.to(device)
    encoder_optimizer = Adam(encoder.parameters(), lr=CFG.encoder_lr, weight_decay=CFG.weight_decay, amsgrad=False)
    encoder_optimizer.load_state_dict(states['encoder_optimizer'])
    encoder_scheduler = get_scheduler(encoder_optimizer)
    encoder_scheduler.load_state_dict(states['encoder_scheduler'])

    decoder = DecoderWithAttention(attention_dim=CFG.attention_dim,
                                   embed_dim=CFG.embed_dim,
                                   decoder_dim=CFG.decoder_dim,
                                   vocab_size=len(tokenizer),
                                   dropout=CFG.dropout,
                                   device=device)
    decoder.load_state_dict(states['decoder'])
    decoder.to(device)
    decoder_optimizer = Adam(decoder.parameters(), lr=CFG.decoder_lr, weight_decay=CFG.weight_decay, amsgrad=False)
    decoder_optimizer.load_state_dict(states['decoder_optimizer'])
    decoder_scheduler = get_scheduler(decoder_optimizer)
    decoder_scheduler.load_state_dict(states['decoder_scheduler'])

    # everything has been copied out of the checkpoint; drop the extra reference
    del states

    # ====================================================
    # loop
    # ====================================================
    # padding positions must not contribute to the loss
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"])

    # lower score is better, so start from +inf
    best_score = np.inf

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        avg_loss = train_fn(train_loader, encoder, decoder, criterion,
                            encoder_optimizer, decoder_optimizer, epoch,
                            encoder_scheduler, decoder_scheduler, device)

        # eval
        text_preds = valid_fn(valid_loader, encoder, decoder, tokenizer, criterion, device)
        text_preds = [f"InChI=1S/{text}" for text in text_preds]
        LOGGER.info(f"labels: {valid_labels[:5]}")
        LOGGER.info(f"preds: {text_preds[:5]}")

        # scoring
        score = get_score(valid_labels, text_preds)

        # advance both schedulers once per epoch (encoder first, then decoder)
        step_scheduler(encoder_scheduler, score)
        step_scheduler(decoder_scheduler, score)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')

        if score < best_score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            # store everything needed to resume training, plus the predictions
            # that produced this score
            torch.save({'encoder': encoder.state_dict(),
                        'encoder_optimizer': encoder_optimizer.state_dict(),
                        'encoder_scheduler': encoder_scheduler.state_dict(),
                        'decoder': decoder.state_dict(),
                        'decoder_optimizer': decoder_optimizer.state_dict(),
                        'decoder_scheduler': decoder_scheduler.state_dict(),
                        'text_preds': text_preds,
                       },
                       f'./{CFG.model_name}_fold{fold}_best.pth')

# Main Program - Train

In [None]:
def main():
    """
    Run training for every selected fold.

    Does nothing unless CFG.train is truthy. Otherwise walks the fold numbers
    0 .. CFG.n_fold - 1 in order and, for each one listed in CFG.trn_fold,
    prints the fold number and calls train_loop on the global `folds` frame.
    """

    # training disabled: nothing to do
    if not CFG.train:
        return

    for fold in range(CFG.n_fold):
        # only the folds explicitly selected in the config are trained
        if fold not in CFG.trn_fold:
            continue
        print(fold)
        train_loop(folds, fold)

In [None]:
# if __name__ == '__main__':
#     main()

# Testing (Inference)

In [None]:
# def get_test_file_path(image_id):

#     return "../input/bms-molecular-translation/test/" + "{}/{}/{}/{}.png".format(
#         image_id[0], image_id[1], image_id[2], image_id 
#     )

# test = pd.read_csv('../input/bms-molecular-translation/sample_submission.csv')
# test['file_path'] = test['image_id'].apply(get_test_file_path)

# print(f'test.shape: {test.shape}')

In [None]:
# def inference(test_loader, encoder, decoder, tokenizer, device):
#     encoder.eval()
#     decoder.eval()
#     text_preds = []
    
#     # k = 2
#     topk_decoder = TopKDecoder(decoder, 2, CFG.decoder_dim, CFG.max_len, tokenizer)
    
#     tk0 = tqdm(test_loader, total=len(test_loader))
#     for images in tk0:
#         images = images.to(device)
#         predictions = []
#         with torch.no_grad():
#             encoder_out = encoder(images)
#             batch_size = encoder_out.size(0)
#             encoder_dim = encoder_out.size(-1)
#             encoder_out = encoder_out.view(batch_size, -1, encoder_dim)
#             h, c = decoder.init_hidden_state(encoder_out)
#             hidden = (h.unsqueeze(0), c.unsqueeze(0))
            
#             decoder_outputs, decoder_hidden, other = topk_decoder(None, hidden, encoder_out)
            
#             for b in range(batch_size):
#                 length = other['topk_length'][b][0]
#                 tgt_id_seq = [other['topk_sequence'][di][b, 0, 0].item() for di in range(length)]
#                 predictions.append(tgt_id_seq)
#             assert len(predictions) == batch_size
            
#         predictions = tokenizer.predict_captions(predictions)
#         predictions = ['InChI=1S/' + p.replace('<sos>', '') for p in predictions]
#         # print(predictions[0])
#         text_preds.append(predictions)
#     text_preds = np.concatenate(text_preds)
#     return text_preds

In [None]:
# states = torch.load(CFG.prev_model, map_location=torch.device('cpu'))

# encoder = Encoder(CFG.model_name, pretrained=False)
# encoder.load_state_dict(states['encoder'])
# encoder.to(device)

# decoder = DecoderWithAttention(attention_dim=CFG.attention_dim,
#                                embed_dim=CFG.embed_dim,
#                                decoder_dim=CFG.decoder_dim,
#                                vocab_size=len(tokenizer),
#                                dropout=CFG.dropout,
#                                device=device)

# decoder.load_state_dict(states['decoder'])
# decoder.to(device)

# del states; gc.collect()

# test_dataset = TestDataset(test, transform=get_transforms(data='valid'))
# test_loader = DataLoader(test_dataset, batch_size= 256, shuffle=False, num_workers=CFG.num_workers)
# predictions = inference(test_loader, encoder, decoder, tokenizer, device)

# del test_loader, encoder, decoder, tokenizer; gc.collect()

In [None]:
# # submission
# test['InChI'] = [f"InChI=1S/{text}" for text in predictions]
# test[['image_id', 'InChI']].to_csv('./submission.csv', index=False)

In [None]:
# load the previously generated beam-search submission and preview its first rows
submission = pd.read_csv(filepath_or_buffer='../input/submission-beam-search/submission.csv')
submission.head(n=2)

In [None]:
# Collapse any doubled 'InChI=1S/' prefix into a single one.
# regex=False pins this to a literal substring replacement, independent of the
# pandas version's default for `regex`.
submission['InChI'] = submission['InChI'].str.replace('InChI=1S/InChI=1S/', 'InChI=1S/', regex=False)
# make image_id the index so it becomes the first column when the frame is written with to_csv
submission.set_index('image_id', inplace=True)
submission.head(2)

In [None]:
# write the final submission file; the index is written as the first column
submission.to_csv('./submission.csv', index=True)