In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import cv2
import pickle
import csv
import math
import re
from Levenshtein import distance as levenshtein_distance
import transformers

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch import Tensor
import torch.nn.init as init
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR, CyclicLR, ExponentialLR, StepLR
import torch.nn.functional as F
from torch.nn import TransformerDecoder, TransformerDecoderLayer
from torchvision import transforms
from matplotlib import pyplot as plt
!pip install efficientnet_pytorch
from efficientnet_pytorch import EfficientNet
import efficientnet_pytorch
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.7.0.tar.gz (20 kB)
Building wheels for collected packages: efficientnet-pytorch
  Building wheel for efficientnet-pytorch (setup.py) ... [?25ldone
[?25h  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.0-py3-none-any.whl size=16033 sha256=8c3ef0d681f3d86811ff0638166fb4b05a89b9fb66c5d1837034f2c9ed7d8740
  Stored in directory: /root/.cache/pip/wheels/b7/cc/0d/41d384b0071c6f46e542aded5f8571700ace4f1eb3f1591c29
Successfully built efficientnet-pytorch
Installing collected packages: efficientnet-pytorch
Successfully installed efficientnet-pytorch-0.7.0


In [2]:
# Per-channel normalisation statistics (the three channels share the same value).
# NOTE(review): not referenced elsewhere in the visible notebook.
MEAN = [0.9871] * 3
STD = [0.9892] * 3
# Output size (rows, cols) produced by pad_resize.
IMG_HEIGHT, IMG_WIDTH = 300, 300

def remove_blobs(img, min_size=10, debug=False):
    """Erase small connected components ("blobs") from a single-channel image.

    Components are found with 8-connectivity; every non-background component whose pixel area
    is below ``min_size`` is set to 0. The input array is modified in place and also returned.

    Args:
        img: 2-D image (non-zero pixels are foreground).
        min_size: components with fewer pixels than this are removed.
        debug: if True, plot the image before and after removal.

    Returns:
        ``img`` with the small components zeroed.
    """
    if debug:
        fig, ax = plt.subplots(1, 2, figsize=(30, 8))
        ax[0].imshow(img)
        ax[0].set_title('original image', size=16)

    # Unpacking doubles as a check that the image is 2-D.
    _rows, _cols = img.shape

    _count, labels, stats, _centres = cv2.connectedComponentsWithStats(img, connectivity=8)
    # Label 0 is the background component, so skip stats row 0; the last stats column is the area.
    small_labels = [label for label, area in enumerate(stats[1:, -1], start=1) if area < min_size]
    img[np.isin(labels, small_labels)] = 0

    if debug:
        ax[1].imshow(img)
        ax[1].set_title('image with removed blobs', size=16)
        plt.show()

    return img


def crop(img, debug=False):
    """Crop an image to the bounding box that encloses all of its contours.

    The image is Otsu-thresholded, contours are extracted, and the union of their bounding
    rectangles is kept. If no contour is found (e.g. an all-black image), the image is returned
    uncropped. Previously the bounds stayed at ``np.inf`` in that case and the slice raised TypeError.

    Args:
        img: 2-D image.
        debug: if True, plot the image before and after cropping.

    Returns:
        The cropped image (a view into ``img``), or ``img`` itself when nothing was found.
    """
    if debug:
        fig, ax = plt.subplots(1, 2, figsize=(30, 8))
        ax[0].imshow(img)
        ax[0].set_title(f'original image, shape: {img.shape}', size=16)

    _, thresh = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # [-2:] picks (contours, hierarchy) whether findContours returns 2 or 3 values
    # (presumably to support both OpenCV 4 and OpenCV 3).
    contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2:]

    if len(contours) == 0:
        img_cropped = img
    else:
        # Each bounding rect is (x, y, w, h); keep the union of all of them.
        rects = [cv2.boundingRect(cnt) for cnt in contours]
        x_min = min(x for x, _, _, _ in rects)
        y_min = min(y for _, y, _, _ in rects)
        x_max = max(x + w for x, _, w, _ in rects)
        y_max = max(y + h for _, y, _, h in rects)
        img_cropped = img[y_min:y_max, x_min:x_max]

    if debug:
        ax[1].imshow(img_cropped)
        ax[1].set_title(f'cropped image, shape: {img_cropped.shape}', size=16)
        plt.show()

    return img_cropped


def pad_kernel(kernel, max_pad=np.inf):
    """Pad a 2-D kernel with -1 so its short side grows toward its long side.

    The short axis is padded symmetrically by ``(long - short) // 2`` on each side, capped at
    ``max_pad``. The long axis gets no padding.

    Args:
        kernel: nested list / array convertible to a 2-D numpy array.
        max_pad: upper bound on the padding added to each side.

    Returns:
        The padded numpy array.
    """
    arr = np.array(kernel)
    rows, cols = arr.shape
    side = max(rows, cols)
    pad_rows = min((side - rows) // 2, max_pad)
    pad_cols = min((side - cols) // 2, max_pad)
    return np.pad(arr, ((pad_rows, pad_rows), (pad_cols, pad_cols)),
                  mode='constant', constant_values=-1)


def create_mask(kernel, img_b):
    """Return a boolean mask of pixels whose neighbourhood matches ``kernel``.

    ``img_b`` is a float32 0/255 image (as built in fill_missing_pixels). The cv2.filter2D
    response is compared with the number of kernel taps equal to the module-level weight ``a``.
    A pixel is selected when the response lies strictly between ``threshold_ratio`` times that
    count and that count plus one.
    """
    tap_count = (kernel == a).sum()
    lower = tap_count * threshold_ratio
    upper = tap_count + 1
    response = cv2.filter2D(img_b, -1, kernel)
    above = response > lower
    below = response < upper
    return above & below


# make kernels
# Kernel convention (consumed by create_mask): taps equal to `a` mark pixels expected to belong to a
# stroke; -1 taps (including any border added by pad_kernel) mark pixels that must be background.
# Correlated with a 0/255 image, each white `a` tap adds ~1 and each white -1 tap subtracts 255, so
# only neighbourhoods whose -1 taps are all black can reach create_mask's threshold.
a = np.float32(1.0 / 255.0)
# fraction of `a` taps that must be white (strictly more than this) for the centre pixel to be flagged
threshold_ratio = 0.50
# single pixel width horizontal line with 1 pixel missing
# (max_pad=1 adds one -1 row above and below, so the stroke must be 1 px thick)
kernel_h_single_mono = pad_kernel([
    [ a, a,  a, -1,  a,  a, a ]
], max_pad=1)
# single pixel width horizontal line with 3 pixels missing
kernel_h_single_triple = pad_kernel([
    [ a, a, a, -1, -1, -1, a, a, a ]
], max_pad=1)

# 3-pixel-thick horizontal stroke with its centre pixel missing (plus one -1 row above and below)
kernel_h_multi = pad_kernel([
    [ a, a, a, a, a, a, a ],
    [ a, a, a,-1, a, a, a ],
    [ a, a, a, a, a, a, a ],
], max_pad=1)

# single pixel width vertical line with 1 pixel missing (plus one -1 column on each side)
kernel_v_single = pad_kernel([
    [ a],
    [ a],
    [ a],
    [-1],
    [ a],
    [ a],
    [ a],
], max_pad=1)

# 3-pixel-thick vertical stroke with its centre pixel missing (plus one -1 column on each side)
kernel_v_multi = pad_kernel([
    [ a, a, a ],
    [ a, a, a ],
    [ a, a, a ],
    [ a,-1, a ],
    [ a, a, a ],
    [ a, a, a ],
    [ a, a, a ],
], max_pad=1)

# 1-pixel diagonal from top-right to bottom-left with the centre pixel missing (already square, no padding)
kernel_lr_single = pad_kernel([
    [ -1,-1,-1,-1, a ],
    [ -1,-1,-1, a,-1 ],
    [ -1,-1,-1,-1,-1 ],
    [ -1, a,-1,-1,-1 ],
    [  a,-1,-1,-1,-1 ],
])

# thicker top-right-to-bottom-left diagonal with the centre pixel missing
kernel_lr_multi = pad_kernel([
    [ -1,-1,-1, a, a ],
    [ -1,-1, a, a, a ],
    [ -1, a,-1, a,-1 ],
    [  a, a, a,-1,-1 ],
    [  a, a,-1,-1,-1 ],
])

# 1-pixel diagonal from top-left to bottom-right with the centre pixel missing
kernel_rl_single = pad_kernel([
    [  a,-1,-1,-1,-1 ],
    [ -1, a,-1,-1,-1 ],
    [ -1,-1,-1,-1,-1 ],
    [ -1,-1,-1, a,-1 ],
    [ -1,-1,-1,-1, a ],
])

# thicker top-left-to-bottom-right diagonal with the centre pixel missing
kernel_rl_multi = pad_kernel([
    [ a, a,-1,-1,-1],
    [ a, a, a,-1,-1],
    [-1, a,-1, a,-1],
    [-1,-1, a, a, a],
    [-1,-1,-1, a, a],
])

def fill_missing_pixels(img, debug=False):
    """Fill small gaps in the strokes of a molecule image, in place.

    The image is binarised to 0/255 and matched against the module-level gap kernels: horizontal,
    vertical and both diagonal directions, for thin and thick strokes. Every pixel flagged by any
    kernel is set to 255 in ``img``.

    Args:
        img: 2-D image (background 0, strokes > 0); modified in place.
        debug: if True, plot each mask and an overlay of the filled pixels.

    Returns:
        ``img`` with the gap pixels set to 255.
    """
    # 0/255 float32 copy so filter2D responses count lit kernel taps (see create_mask)
    img_b = img.astype(np.float32)
    img_b[img_b > 0] = 255

    mask_h_single = (create_mask(kernel_h_single_mono, img_b)
                     | create_mask(kernel_h_single_triple, img_b))
    mask_h_multi = create_mask(kernel_h_multi, img_b)
    mask_v_single = create_mask(kernel_v_single, img_b)
    mask_v_multi = create_mask(kernel_v_multi, img_b)
    mask_lr_single = create_mask(kernel_lr_single, img_b)
    mask_lr_multi = create_mask(kernel_lr_multi, img_b)
    # Must use the right-to-left kernel: the previous code passed kernel_lr_single here, so gaps
    # in thin top-left-to-bottom-right diagonals were never detected.
    mask_rl_single = create_mask(kernel_rl_single, img_b)
    mask_rl_multi = create_mask(kernel_rl_multi, img_b)

    mask_single = mask_h_single | mask_v_single | mask_lr_single | mask_rl_single
    mask_multi = mask_h_multi | mask_v_multi | mask_lr_multi | mask_rl_multi
    mask = mask_single | mask_multi

    if debug:
        for panels in (
            [('mask_h_single', mask_h_single), ('mask_v_single', mask_v_single),
             ('mask_lr_single', mask_lr_single), ('mask_rl_single', mask_rl_single)],
            [('mask_h_multi', mask_h_multi), ('mask_v_multi', mask_v_multi),
             ('mask_lr_multi', mask_lr_multi), ('mask_rl_multi', mask_rl_multi)],
        ):
            fig, ax = plt.subplots(2, 2, figsize=(35, 20))
            for axis, (title, panel) in zip(ax.flat, panels):
                axis.imshow(panel)
                axis.set_title(title, size=16)
            plt.show()

        fig, ax = plt.subplots(2, 1, figsize=(15, 20))
        ax[0].imshow(img)
        ax[0].set_title('original image', size=16)

        # Float RGB overlay in [0, 1] (what imshow expects): red = filled pixels, green = strokes.
        img_rgb = np.stack([
            mask.astype(np.float32),
            img_b / 255,
            np.zeros(img.shape, dtype=np.float32),
        ], axis=2)

        ax[1].imshow(img_rgb)
        ax[1].set_title('image with filled missing pixels (red)', size=16)
        plt.show()

    # all pixels in the mask are filled up
    img[mask] = 255

    return img


def pad_resize(img):
    """Zero-pad an image to the IMG_HEIGHT:IMG_WIDTH aspect ratio, then resize to that size.

    Padding is split evenly between the two sides of the axis that is too short. The aspect
    ratios are compared by cross-multiplication, so the pad sizes are exact integers rather than
    truncated float ratios. Nearest-neighbour interpolation is used for the resize.

    Args:
        img: 2-D image.

    Returns:
        Array of shape (IMG_HEIGHT, IMG_WIDTH). A zero-size input yields an all-zero image
        instead of a ZeroDivisionError.
    """
    h, w = img.shape
    if h == 0 or w == 0:
        return np.zeros((IMG_HEIGHT, IMG_WIDTH), dtype=img.dtype)

    pad_rows, pad_cols = 0, 0
    if h * IMG_WIDTH < w * IMG_HEIGHT:
        # Too wide for the target ratio: pad above and below.
        pad_rows = (w * IMG_HEIGHT - h * IMG_WIDTH) // (2 * IMG_WIDTH)
    else:
        # Too tall (or exact): pad left and right.
        pad_cols = (h * IMG_WIDTH - w * IMG_HEIGHT) // (2 * IMG_HEIGHT)

    img = np.pad(img, [(pad_rows, pad_rows), (pad_cols, pad_cols)], mode='constant')
    # cv2.resize takes dsize as (width, height)
    img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)

    return img


def process_img(file_path, folder='train', debug=False):
    """Load one molecule image and run the full cleaning pipeline.

    Pipeline:
        1. Read the image as grayscale and invert it (black background, white strokes).
        2. Rotate portrait images 90 degrees counter-clockwise.
        3. Remove small blobs and crop to the content.
        4. Fill small gaps in the strokes.
        5. Pad to the target aspect ratio and resize.
        6. Rescale to 0-255 uint8.

    Args:
        file_path: path of the image file.
        folder: unused; kept so existing callers keep working.
        debug: if True, plot intermediate and final images.

    Returns:
        2-D uint8 array of shape (IMG_HEIGHT, IMG_WIDTH).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``file_path``.
    """
    raw = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None rather than raising.
    if raw is None:
        raise FileNotFoundError(f'Could not read image: {file_path}')
    # invert colors to get black background and white molecule
    img0 = 255 - raw

    # rotate counter clockwise to get horizontal images
    h, w = img0.shape
    if h > w:
        img0 = np.rot90(img0)

    # remove_blobs writes into its argument. Pass a contiguous copy so img0 (possibly a rot90
    # view) stays the untouched original for the debug comparison below.
    img = remove_blobs(img0.copy(), min_size=2, debug=debug)
    img = crop(img, debug=debug)
    img = fill_missing_pixels(img, debug=debug)
    img = pad_resize(img)

    if debug:
        fig, ax = plt.subplots(1, 2, figsize=(20, 10))
        ax[0].imshow(img0)
        ax[0].set_title('Original image', size=16)
        ax[1].imshow(img)
        ax[1].set_title('Fully processed image', size=16)
        plt.show()

    # Rescale to 0-255; an all-black image stays zeros instead of producing 0/0 -> NaN.
    peak = img.max()
    if peak > 0:
        img = (img / peak * 255).astype(np.uint8)
    else:
        img = img.astype(np.uint8)

    return img


class Tokenizer():
    """Tokenize InChI strings and build a token vocabulary from a labels CSV.

    The CSV must have a header row, the image id in column 0 and the InChI string in column 1.
    Everything up to the first '/' (the ``InChI=1S`` header) is dropped. The chemical formula is
    then split into atom symbols and counts. Each following layer, matched by its leading letter in
    ``prefix_list`` order, is split into whole numbers and single characters. ``'<eos>'`` closes
    the formula and every prefix slot, whether or not that layer is present.

    Attributes:
        vocab: sorted list of all tokens seen.
        stoi / itos: token <-> id maps. Ids start at 1, so 0 is never a token id.
        max_str_len: length of the longest tokenized label.
        label_dict: image id -> list of token ids.
    """

    def __init__(self, label_dir):

        # Layer letters, tried in this order after the formula. The repeated letters appear to
        # allow a second occurrence of those layers later in the string.
        self.prefix_list = ['c', 'h', 'b', 't', 'm', 's', 'i', 'h', 't', 'm', 's']
        self.ELEM_REGEX = re.compile(r"[A-Z][a-z]?[0-9]*")
        self.ATOM_REGEX = re.compile(r"[A-Z][a-z]?")
        self.NUM_REGEX = re.compile(r"[0-9]+|.")

        self.max_str_len = 0
        self.vocab = set()
        self.label_dict = dict()

        # First pass: collect the vocabulary and the longest token sequence.
        for _, tokenized_label in self._iter_labels(label_dir):
            self.vocab.update(tokenized_label)
            self.max_str_len = max(self.max_str_len, len(tokenized_label))
        self.vocab = sorted(self.vocab)
        self.stoi = {v: idx + 1 for idx, v in enumerate(self.vocab)}
        self.itos = {idx: v for v, idx in self.stoi.items()}

        # Second pass: encode every label with the finished vocabulary.
        for image_id, tokenized_label in self._iter_labels(label_dir):
            self.label_dict[image_id] = [self.stoi[t] for t in tokenized_label]

    def __len__(self):
        """Vocabulary size (number of distinct tokens; ids run 1..len)."""
        return len(self.stoi)

    def _iter_labels(self, label_dir):
        """Yield ``(image_id, tokens)`` for every data row of the labels CSV."""
        # newline='' is the csv module's recommended way to open files for reading.
        with open(label_dir, newline='') as csvfile:
            reader = csv.reader(csvfile)
            next(reader, None)  # skip header
            for row in reader:
                label = row[1].split('/', 1)[1]
                yield row[0], self._tokenize_InChI(label)

    def _tokenize_InChI(self, st):
        """Tokenize an InChI string with its ``InChI=1S/`` header already removed."""
        st_list = st.split('/')
        str_out = ['<sos>']
        str_out.extend(self._tokenize_formula(st_list.pop(0)))
        str_out.append('<eos>')
        for prefix in self.prefix_list:
            # Compare first characters with slices so an empty layer (e.g. a trailing '/')
            # does not raise IndexError.
            if st_list and st_list[0][:1] == prefix[:1]:
                str_out.extend(self.NUM_REGEX.findall(st_list.pop(0)[1:]))
            str_out.append('<eos>')
        return str_out

    def _tokenize_formula(self, st):
        """Split a formula such as ``C2H6O`` into ``['C', '2', 'H', '6', 'O']``."""
        st_out = []
        elem_list = self.ELEM_REGEX.findall(st)
        for elem in elem_list:
            atom = self.ATOM_REGEX.findall(elem)[0]
            st_out.append(atom)
            count_str = elem[elem.find(atom) + len(atom):]
            if count_str != '':
                st_out.append(count_str)
        return st_out

In [None]:
def preprocess(img_dir, label_dir):
    """Collect all file paths under ``img_dir`` and build a Tokenizer from ``label_dir``.

    Args:
        img_dir: root directory walked recursively with os.walk.
        label_dir: path of the labels CSV passed to Tokenizer.

    Returns:
        ``(img_list, tokenizer)``: the file paths in os.walk order and the fitted Tokenizer.
    """
    img_list = [
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(img_dir)
        for name in names
    ]
    print('Image names loaded')

    tokenizer = Tokenizer(label_dir)
    print('Labels loaded')

    return img_list, tokenizer

img_list, tokenizer = preprocess(img_dir='../input/bms-molecular-translation/train',
                                 label_dir='../input/bms-molecular-translation/train_labels.csv')
# Context managers guarantee the pickle files are flushed and closed.
with open("train_image_names.p", "wb") as f:
    pickle.dump(img_list, f)
with open("train_label_data.p", "wb") as f:
    pickle.dump(tokenizer, f)

In [None]:
class BMSdata (Dataset):
    """Map-style dataset yielding ``(img, src, trg)`` for the molecule images.

    Args:
        img_list: list of image file paths. Each file name without its extension must be a key
            of ``tokenizer.label_dict``.
        tokenizer: object providing ``label_dict`` (image id -> list of token ids) and
            ``max_str_len``.
        cleaning: optional callable that takes a file path and returns the processed image
            (e.g. ``process_img``). When omitted, the raw file is read with ``cv2.imread``.
    """

    def __init__(self, img_list, tokenizer, cleaning=None):
        print('Intantiating dataset')
        self.cleaning = cleaning
        self.img_list = img_list
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return the image at ``idx`` with zero-padded decoder input and target id arrays.

        ``src`` holds the full token-id sequence; ``trg`` is the same sequence shifted left by
        one (the next-token targets). Both have length ``tokenizer.max_str_len + 1``.
        """
        path = self.img_list[idx]
        if self.cleaning:
            # The cleaning callable loads the file itself; decoding it here as well was wasted I/O.
            img = self.cleaning(path)
        else:
            img = cv2.imread(path)

        img_id = os.path.splitext(os.path.basename(path))[0]
        label_ = np.array(self.tokenizer.label_dict[img_id])
        len_ = label_.shape[0]

        src = np.zeros(self.tokenizer.max_str_len + 1)
        src[:len_] = label_

        trg = np.zeros(self.tokenizer.max_str_len + 1)
        trg[:len_ - 1] = label_[1:]

        return img, src, trg

In [None]:
# NOTE: pickle.load can execute arbitrary code. Only load these files if they are trusted
# (presumably the outputs of the preprocessing cell above).
with open("../input/processed-data/train_image_names.p", "rb") as f:
    img_list = pickle.load(f)
with open("../input/processed-data/train_label_data.p", "rb") as f:
    tokenizer = pickle.load(f)
dataset = BMSdata(img_list, tokenizer, cleaning=process_img)
# Tokenizer.itos is keyed by ids 1..V in ascending order, so int_to_char[i] is the token whose
# id is i + 1 (the same shift the training loop applies with `trg - 1`).
int_to_char = np.array(list(dataset.tokenizer.itos.values()))
int_to_char

In [None]:
class Encoder(nn.Module):
    """EfficientNet-b0 image encoder adapted to single-channel input.

    Args:
        out_size: size of the feature vector produced per image.

    forward(input_): ``input_`` is a batch of 1-channel images; returns ``(batch, out_size)``.
    """

    def __init__(self, out_size):
        super(Encoder, self).__init__()
        self.image_encoder_model = EfficientNet.from_name('efficientnet-b0')
        # Replace the stem conv with a 1-input-channel version for grayscale images
        # (32 output channels presumably matches b0's original stem width).
        self.image_encoder_model._conv_stem = efficientnet_pytorch.utils.Conv2dStaticSamePadding(
                                                1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False, image_size=302)
        # 1000 is presumably the default EfficientNet classification-head size fed into this projection.
        self.fc = nn.Linear(1000, out_size)

    def forward(self, input_):
        # flatten(1) instead of squeeze(): squeeze() also removed the batch axis when the batch
        # size was 1, breaking the (batch, features) shape expected by self.fc and its callers.
        conv_out = self.image_encoder_model(input_).flatten(1)
        enc_out = self.fc(conv_out)
        return enc_out

    
def generate_square_subsequent_mask(sz):
    """Build an additive causal attention mask of shape (sz, sz).

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf when j > i.
    The result is float32.
    """
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float32)
    return blocked.triu(diagonal=1)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input, then apply dropout.

    Even feature columns hold ``sin(pos * f_i)`` and odd columns ``cos(pos * f_i)``. The buffer
    ``pe`` has shape (max_len, 1, d_model) and is sliced with ``x.size(0)``, so ``x`` is indexed
    by position along dim 0, i.e. (seq_len, batch, d_model).

    Args:
        d_model: feature size; odd sizes are supported.
        dropout: dropout probability applied after adding the encoding.
        max_len: maximum supported sequence length.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than frequencies; trim so shapes match
        # (previously this raised a shape-mismatch error).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class Decoder(nn.Module):
    """Transformer decoder head: positional encoding -> stacked decoder layers -> linear logits.

    Args:
        in_size: model width (d_model); also the positional-encoding size.
        nhead: attention heads per layer.
        hidden_size: feed-forward width of each decoder layer.
        nlayers: number of stacked decoder layers.
        out_size: number of logits produced per target position.
        dropout: dropout used by the positional encoding and the decoder layers.
    """

    def __init__(self, in_size, nhead, hidden_size, nlayers, out_size, dropout=0):
        super(Decoder, self).__init__()
        # Attribute names are kept unchanged: they form the state_dict keys.
        self.pos_encoder = PositionalEncoding(in_size, dropout)
        layer = TransformerDecoderLayer(d_model=in_size, nhead=nhead,
                                        dim_feedforward=hidden_size, dropout=dropout)
        self.transformer_decoder = TransformerDecoder(layer, num_layers=nlayers)
        self.linear = nn.Linear(in_size, out_size)

    def forward(self, tgt, memory, tgt_mask):
        """Return per-position logits for ``tgt`` attending to ``memory``.

        NOTE(review): ``tgt`` is presumably (seq_len, batch, in_size), sequence-first, since the
        positional encoding indexes positions on dim 0; confirm.
        """
        encoded_tgt = self.pos_encoder(tgt)
        hidden = self.transformer_decoder(encoded_tgt, memory, tgt_mask=tgt_mask)
        return self.linear(hidden)


def _inflate(tensor, times, dim):
    # repeat_dims = [1] * tensor.dim()
    # repeat_dims[dim] = times
    # return tensor.repeat(*repeat_dims)
    return torch.repeat_interleave(tensor, times, dim)


class TopKDecoder(torch.nn.Module):
    r"""
    Top-K decoding with beam search.

    Adapted from an RNN seq2seq beam-search decoder (``forward`` still points at
    ``seq2seq.models.DecoderRNN``), so some terms below (RNN, hidden state) come from that origin.

    NOTE(review): decoding calls ``self.rnn.forward_step(...)``; the ``Decoder`` class in this
    notebook does not define ``forward_step``, so confirm which decoder object is meant to be passed.
    NOTE(review): tensors are moved to CUDA unconditionally (``.cuda()``), so a GPU is required.

    Args:
        decoder_rnn: decoder exposing ``forward_step(input_var, hidden, encoder_outputs, function=...)``
            that returns ``(log_probs, hidden, _)``.
        k (int): Size of the beam.
        decoder_dim (int): hidden-state size, used to reshape hidden states in ``_backtrack``.
        max_length (int): number of decoding steps; decoding always runs exactly this many steps.
        tokenizer: must support ``len(tokenizer)`` (vocabulary size ``V``) and provide
            ``stoi["<sos>"]`` and ``stoi["<eos>"]``.

    Inputs: inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio, retain_output_probs
        - **inputs**: not used for decoding; only echoed back as ``metadata['inputs']``.
        - **encoder_hidden**: ``None`` or a tuple of tensors (e.g. LSTM ``(h, c)``). Each is squeezed at
          dim 0, repeated ``k`` times along dim 0, then unsqueezed again. Any other value raises
          ``RuntimeError``.
        - **encoder_outputs**: batch must be dim 0; it is repeated ``k`` times along dim 0.
        - **function**: passed through to ``forward_step`` (default ``F.log_softmax``).
        - **teacher_forcing_ratio**: accepted for API compatibility; not used.
        - **retain_output_probs** (bool): store per-step decoder outputs. NOTE(review): ``_backtrack``
          indexes the stored outputs unconditionally, so ``False`` would fail there.

    Outputs: decoder_outputs, decoder_hidden, metadata
        - **decoder_outputs**: list (one entry per step) of the best beam's step outputs, ``(batch, -1)``.
        - **decoder_hidden**: best beam's final hidden state(s), ``h_n[:, :, 0, :]`` (a tuple when the
          hidden state is a tuple).
        - **metadata**: dict with keys ``inputs``, ``output``, ``h_t``, ``score`` (sorted beam scores,
          ``(batch, k)``), ``topk_length``, ``topk_sequence``, ``length`` (best beam's length for each
          batch item) and ``sequence``. NOTE(review): ``sequence`` takes ``step[0]`` of each per-step
          ``(batch, k, 1)`` symbol tensor, i.e. all beams of batch item 0; ``step[:, 0]`` (best beam
          per batch item) may have been intended — verify.
    """

    def __init__(self, decoder_rnn, k, decoder_dim, max_length, tokenizer):
        super(TopKDecoder, self).__init__()
        self.rnn = decoder_rnn
        self.k = k
        self.hidden_size = decoder_dim  # self.rnn.hidden_size
        # NOTE(review): requires the tokenizer to define __len__; V is used to split flattened
        # (beam, token) candidate indices.
        self.V = len(tokenizer)
        # NOTE(review): these are the tokenizer's ids; confirm they match the index space of the
        # decoder's output logits (the training loop shifts targets by -1).
        self.SOS = tokenizer.stoi["<sos>"]
        self.EOS = tokenizer.stoi["<eos>"]
        self.max_length = max_length
        self.tokenizer = tokenizer

    def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None, function=F.log_softmax,
                teacher_forcing_ratio=0, retain_output_probs=True):
        """
        Run beam search for ``self.max_length`` steps and backtrack the top-k sequences.

        Adapted from :func:`seq2seq.models.DecoderRNN.DecoderRNN.forward_rnn`; see the class docstring
        for the inputs and outputs.
        """

        # inputs, batch_size, max_length = self.rnn._validate_args(inputs, encoder_hidden, encoder_outputs,
        #                                                         function, teacher_forcing_ratio)

        batch_size = encoder_outputs.size(0)
        max_length = self.max_length

        # Row offset of each batch item's first beam in the flattened (batch*k) layout: [[0], [k], [2k], ...]
        self.pos_index = (torch.LongTensor(range(batch_size)) * self.k).view(-1, 1).cuda()

        # Inflate the initial hidden states to be of size: b*k x h
        # encoder_hidden = self.rnn._init_state(encoder_hidden)
        if encoder_hidden is None:
            hidden = None
        else:
            if isinstance(encoder_hidden, tuple):
                # hidden = tuple([_inflate(h, self.k, 1) for h in encoder_hidden])
                hidden = tuple([h.squeeze(0) for h in encoder_hidden])
                hidden = tuple([_inflate(h, self.k, 0) for h in hidden])
                hidden = tuple([h.unsqueeze(0) for h in hidden])
            else:
                # hidden = _inflate(encoder_hidden, self.k, 1)
                raise RuntimeError("Not supported")

        # ... same idea for encoder_outputs and decoder_outputs
        if True:  # self.rnn.use_attention:
            inflated_encoder_outputs = _inflate(encoder_outputs, self.k, 0)
        else:
            inflated_encoder_outputs = None

        # Initialize the scores; for the first step,
        # ignore the inflated copies to avoid duplicate entries in the top k
        sequence_scores = torch.Tensor(batch_size * self.k, 1)
        sequence_scores.fill_(-float('Inf'))
        sequence_scores.index_fill_(0, torch.LongTensor([i * self.k for i in range(0, batch_size)]), 0.0)
        sequence_scores = sequence_scores.cuda()

        # Initialize the input vector
        # (batch*k, 1) column of SOS ids
        input_var = torch.transpose(torch.LongTensor([[self.SOS] * batch_size * self.k]), 0, 1).cuda()

        # Store decisions for backtracking
        stored_outputs = list()
        stored_scores = list()
        stored_predecessors = list()
        stored_emitted_symbols = list()
        stored_hidden = list()

        for i in range(0, max_length):

            # Run the RNN one step forward
            log_softmax_output, hidden, _ = self.rnn.forward_step(input_var, hidden,
                                                                  inflated_encoder_outputs, function=function)
            # If doing local backprop (e.g. supervised training), retain the output layer
            if retain_output_probs:
                stored_outputs.append(log_softmax_output)

            # To get the full sequence scores for the new candidates, add the local scores for t_i to the predecessor scores for t_(i-1)
            # (batch*k, 1) -> (batch*k, V); NOTE(review): assumes log_softmax_output.squeeze(1) is (batch*k, V)
            sequence_scores = _inflate(sequence_scores, self.V, 1)
            sequence_scores += log_softmax_output.squeeze(1)
            # Flatten each batch item's k*V candidates and keep the k best; candidate = beam * V + token
            scores, candidates = sequence_scores.view(batch_size, -1).topk(self.k, dim=1)

            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
            input_var = (candidates % self.V).view(batch_size * self.k, 1)
            sequence_scores = scores.view(batch_size * self.k, 1)

            # Update fields for next timestep
            # Source beam of each candidate, as an absolute row index in the (batch*k) layout
            predecessors = (candidates // self.V + self.pos_index.expand_as(candidates)).view(batch_size * self.k, 1)
            # NOTE(review): if forward_step returns hidden=None, the else branch below raises AttributeError.
            if isinstance(hidden, tuple):
                hidden = tuple([h.index_select(1, predecessors.squeeze()) for h in hidden])
            else:
                hidden = hidden.index_select(1, predecessors.squeeze())

            # Update sequence scores and erase scores for end-of-sentence symbol so that they aren't expanded
            stored_scores.append(sequence_scores.clone())
            eos_indices = input_var.data.eq(self.EOS)
            # NOTE(review): nonzero() returns a 2-D tensor, so this condition is always true; masking
            # with an all-False mask is simply a no-op.
            if eos_indices.nonzero().dim() > 0:
                sequence_scores.data.masked_fill_(eos_indices, -float('inf'))

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            stored_emitted_symbols.append(input_var)
            stored_hidden.append(hidden)

        # Do backtracking to return the optimal values
        output, h_t, h_n, s, l, p = self._backtrack(stored_outputs, stored_hidden,
                                                    stored_predecessors, stored_emitted_symbols,
                                                    stored_scores, batch_size, self.hidden_size)

        # Build return objects
        # Beam index 0 is the best beam after _backtrack's final sort.
        decoder_outputs = [step[:, 0, :] for step in output]
        if isinstance(h_n, tuple):
            decoder_hidden = tuple([h[:, :, 0, :] for h in h_n])
        else:
            decoder_hidden = h_n[:, :, 0, :]
        metadata = {}
        metadata['inputs'] = inputs
        metadata['output'] = output
        metadata['h_t'] = h_t
        metadata['score'] = s
        metadata['topk_length'] = l
        metadata['topk_sequence'] = p
        metadata['length'] = [seq_len[0] for seq_len in l]
        metadata['sequence'] = [seq[0] for seq in p]
        return decoder_outputs, decoder_hidden, metadata

    def _backtrack(self, nw_output, nw_hidden, predecessors, symbols, scores, b, hidden_size):
        """Backtracks over batch to generate optimal k-sequences.

        Args:
            nw_output [(batch*k, vocab_size)] * sequence_length: A Tensor of outputs from network
            nw_hidden [(num_layers, batch*k, hidden_size)] * sequence_length: A Tensor of hidden states from network
            predecessors [(batch*k)] * sequence_length: A Tensor of predecessors
            symbols [(batch*k)] * sequence_length: A Tensor of predicted tokens
            scores [(batch*k)] * sequence_length: A Tensor containing sequence scores for every token t = [0, ... , seq_len - 1]
            b: Size of the batch
            hidden_size: Size of the hidden state

        Returns:
            output [(batch, k, vocab_size)] * sequence_length: A list of the output probabilities (p_n)
            from the last layer of the RNN, for every n = [0, ... , seq_len - 1]

            h_t [(batch, k, hidden_size)] * sequence_length: A list containing the output features (h_n)
            from the last layer of the RNN, for every n = [0, ... , seq_len - 1]

            h_n(batch, k, hidden_size): A Tensor containing the last hidden state for all top-k sequences.

            score [batch, k]: A list containing the final scores for all top-k sequences

            length [batch, k]: A list specifying the length of each sequence in the top-k candidates

            p (batch, k, sequence_len): A Tensor containing predicted sequence

        Note:
            Shapes above are inherited from the original implementation; ``forward`` stores
            predecessors/symbols/scores as (batch*k, 1) column tensors.
        """

        # Tuple hidden states (e.g. LSTM (h, c)) are handled element-wise below.
        lstm = isinstance(nw_hidden[0], tuple)

        # initialize return variables given different types
        output = list()
        h_t = list()
        p = list()
        # Placeholder for last hidden state of top-k sequences.
        # If a (top-k) sequence ends early in decoding, `h_n` contains
        # its hidden state when it sees EOS.  Otherwise, `h_n` contains
        # the last hidden state of decoding.
        if lstm:
            state_size = nw_hidden[0][0].size()
            h_n = tuple([torch.zeros(state_size).cuda(), torch.zeros(state_size).cuda()])
        else:
            h_n = torch.zeros(nw_hidden[0].size()).cuda()
        l = [[self.max_length] * self.k for _ in range(b)]  # Placeholder for lengths of top-k sequences
        # Similar to `h_n`

        # the last step output of the beams are not sorted
        # thus they are sorted here
        sorted_score, sorted_idx = scores[-1].view(b, self.k).topk(self.k)
        # initialize the sequence scores with the sorted last step beam scores
        s = sorted_score.clone()

        batch_eos_found = [0] * b  # the number of EOS found
        # in the backward loop below for each batch

        t = self.max_length - 1
        # initialize the back pointer with the sorted order of the last step beams.
        # add self.pos_index for indexing variable with b*k as the first dimension.
        t_predecessors = (sorted_idx + self.pos_index.expand_as(sorted_idx)).view(b * self.k)
        while t >= 0:
            # Re-order the variables with the back pointer
            current_output = nw_output[t].index_select(0, t_predecessors)
            if lstm:
                current_hidden = tuple([h.index_select(1, t_predecessors) for h in nw_hidden[t]])
            else:
                current_hidden = nw_hidden[t].index_select(1, t_predecessors)
            current_symbol = symbols[t].index_select(0, t_predecessors)
            # Re-order the back pointer of the previous step with the back pointer of
            # the current step
            t_predecessors = predecessors[t].index_select(0, t_predecessors).squeeze()

            # This tricky block handles dropped sequences that see EOS earlier.
            # The basic idea is summarized below:
            #
            #   Terms:
            #       Ended sequences = sequences that see EOS early and dropped
            #       Survived sequences = sequences in the last step of the beams
            #
            #       Although the ended sequences are dropped during decoding,
            #   their generated symbols and complete backtracking information are still
            #   in the backtracking variables.
            #   For each batch, everytime we see an EOS in the backtracking process,
            #       1. If there is survived sequences in the return variables, replace
            #       the one with the lowest survived sequence score with the new ended
            #       sequences
            #       2. Otherwise, replace the ended sequence with the lowest sequence
            #       score with the new ended sequence
            #
            # NOTE(review): nonzero() always returns a 2-D tensor, so the dim() check is always true;
            # with no EOS rows the loop simply runs zero times.
            eos_indices = symbols[t].data.squeeze(1).eq(self.EOS).nonzero()
            if eos_indices.dim() > 0:
                for i in range(eos_indices.size(0) - 1, -1, -1):
                    # Indices of the EOS symbol for both variables
                    # with b*k as the first dimension, and b, k for
                    # the first two dimensions
                    idx = eos_indices[i]
                    b_idx = int(idx[0] // self.k)
                    # The indices of the replacing position
                    # according to the replacement strategy noted above
                    res_k_idx = self.k - (batch_eos_found[b_idx] % self.k) - 1
                    batch_eos_found[b_idx] += 1
                    res_idx = b_idx * self.k + res_k_idx

                    # Replace the old information in return variables
                    # with the new ended sequence information
                    t_predecessors[res_idx] = predecessors[t][idx[0]]
                    current_output[res_idx, :] = nw_output[t][idx[0], :]
                    if lstm:
                        current_hidden[0][:, res_idx, :] = nw_hidden[t][0][:, idx[0], :]
                        current_hidden[1][:, res_idx, :] = nw_hidden[t][1][:, idx[0], :]
                        h_n[0][:, res_idx, :] = nw_hidden[t][0][:, idx[0], :].data
                        h_n[1][:, res_idx, :] = nw_hidden[t][1][:, idx[0], :].data
                    else:
                        current_hidden[:, res_idx, :] = nw_hidden[t][:, idx[0], :]
                        h_n[:, res_idx, :] = nw_hidden[t][:, idx[0], :].data
                    current_symbol[res_idx, :] = symbols[t][idx[0]]
                    s[b_idx, res_k_idx] = scores[t][idx[0]].data[0]
                    l[b_idx][res_k_idx] = t + 1

            # record the back tracked results
            output.append(current_output)
            h_t.append(current_hidden)
            p.append(current_symbol)

            t -= 1

        # Sort and re-order again as the added ended sequences may change
        # the order (very unlikely)
        s, re_sorted_idx = s.topk(self.k)
        for b_idx in range(b):
            l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]]

        re_sorted_idx = (re_sorted_idx + self.pos_index.expand_as(re_sorted_idx)).view(b * self.k)

        # Reverse the sequences and re-order at the same time
        # It is reversed because the backtracking happens in reverse time order
        output = [step.index_select(0, re_sorted_idx).view(b, self.k, -1) for step in reversed(output)]
        p = [step.index_select(0, re_sorted_idx).view(b, self.k, -1) for step in reversed(p)]
        if lstm:
            h_t = [tuple([h.index_select(1, re_sorted_idx).view(-1, b, self.k, hidden_size) for h in step]) for step in reversed(h_t)]
            h_n = tuple([h.index_select(1, re_sorted_idx.data).view(-1, b, self.k, hidden_size) for h in h_n])
        else:
            h_t = [step.index_select(1, re_sorted_idx).view(-1, b, self.k, hidden_size) for step in reversed(h_t)]
            h_n = h_n.index_select(1, re_sorted_idx.data).view(-1, b, self.k, hidden_size)
        s = s.data

        return output, h_t, h_n, s, l, p

    # Not called by forward/_backtrack in this class; retained from the original implementation.
    def _mask_symbol_scores(self, score, idx, masking_score=-float('inf')):
        score[idx] = masking_score

    # Not called by forward/_backtrack in this class. Fills `tensor` at the indices in idx[:, 0]
    # along `dim` with `masking_score` (in place).
    def _mask(self, tensor, idx, dim=0, masking_score=-float('inf')):
        if len(idx.size()) > 0:
            indices = idx[:, 0]
            tensor.index_fill_(dim, indices, masking_score)

In [None]:
# Instantiate model
# NOTE(review): N_TOKENS is hard-coded. The embedding needs one row per token id plus padding
# id 0, and the decoder one logit per (id - 1) target class; confirm both against
# len(tokenizer.stoi).
D_MODEL = 128
N_TOKENS = 185
encoder = Encoder(out_size=D_MODEL).cuda()
decoder = Decoder(in_size=D_MODEL, nhead=4, hidden_size=256, nlayers=6, out_size=N_TOKENS).cuda()
embedding = nn.Embedding(N_TOKENS, D_MODEL, padding_idx=0).cuda()
model = (encoder, decoder, embedding)

In [None]:
def get_lr(optimizer):
    """Return the learning rate of ``optimizer``'s first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)


def _save_training_state(encoder, decoder, embedding, histories, ckpt_dir):
    """Write model weights and metric histories into ``ckpt_dir``.

    The ``*_best_model.ckpt`` file names are kept unchanged so existing paths
    keep working. No validation is done, so they hold the LATEST weights.
    Each entry of ``histories`` is pickled to ``<key>.p``.
    """
    for module_name, module in (('encoder', encoder), ('decoder', decoder), ('embedding', embedding)):
        torch.save(module.state_dict(), os.path.join(ckpt_dir, module_name + '_best_model.ckpt'))
    for name, values in histories.items():
        # `with` closes the handle deterministically (the original leaked it).
        with open(os.path.join(ckpt_dir, name + '.p'), 'wb') as fh:
            pickle.dump(values, fh)


def train(model, dataset, total_steps=60000, warmup_steps=500, val_fraction=0.2,
          batch_size=64, lr=0.005, log_every=100, ckpt_dir='/kaggle/working'):
    """Train an ``(encoder, decoder, embedding)`` triple with teacher forcing.

    Parameters
    ----------
    model : tuple
        ``(encoder, decoder, embedding)`` modules. Inputs are moved to CUDA
        here, so the modules are expected to already live on the GPU.
    dataset : torch.utils.data.Dataset
        Yields ``(img, src, trg)`` triples. ``trg`` is shifted by -1 so that
        id 0 becomes -1 and is ignored by the loss. Id 0 is presumably the
        padding id; confirm against the tokenizer.
    total_steps, warmup_steps : int
        Length of the cosine-with-warmup LR schedule. Training stops exactly
        after ``total_steps`` optimizer steps.
    val_fraction : float
        Fraction of ``dataset`` held out once, up front, and never trained on.
        No validation loop runs yet.
    batch_size : int
        Training batch size.
    lr : float
        Peak Adam learning rate.
    log_every : int
        Every ``log_every`` steps, the Levenshtein distance of the current
        batch is computed, a progress line is printed, and checkpoints and
        histories are saved.
    ckpt_dir : str
        Directory for the checkpoints and the pickled histories.

    Raises
    ------
    ValueError
        If the training split yields no batches (the loop would never end).
    """
    encoder, decoder, embedding = model

    n = len(dataset)
    num_eval = int(val_fraction * n)
    num_train = n - num_eval
    # Split once, before training. Previously the split was redrawn every
    # epoch AND then ignored (the loader wrapped the full dataset), so the
    # "held-out" samples were trained on.
    train_dataset, _held_out = random_split(dataset, [num_train, num_eval])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    if len(train_loader) == 0:
        raise ValueError('train(): the training split is empty; nothing to train on')

    parameters = [*encoder.parameters(), *decoder.parameters(), *embedding.parameters()]
    criterion = nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
    optimizer = torch.optim.Adam(parameters, lr=lr, weight_decay=0)
    scheduler = transformers.get_cosine_schedule_with_warmup(
        optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    softmax_ = nn.Softmax(2)

    encoder.train()
    decoder.train()
    embedding.train()

    histories = {'ld_list': [], 'loss_list': [], 'lr_list': []}
    trg_mask = None  # causal mask, rebuilt only when the sequence length changes
    step = 0
    while step < total_steps:
        for img, src, trg in train_loader:
            step += 1
            img = img.unsqueeze(1).float().cuda()
            src = src.long().cuda()
            # Swap the first two axes; src is assumed to be (batch, seq), so
            # this gives (seq, batch, emb).
            src_emb = embedding(src).permute(1, 0, 2)
            trg = trg.long().cuda() - 1

            # Size the mask from the batch instead of the hard-coded 286, and
            # reuse it while the length is unchanged.
            seq_len = src_emb.size(0)
            if trg_mask is None or trg_mask.size(0) != seq_len:
                trg_mask = generate_square_subsequent_mask(seq_len).cuda()

            encoder_out = encoder(img)
            decoder_out = decoder(src_emb, encoder_out, trg_mask)

            # Average over real (non-ignored) tokens only. torch.mean over the
            # 'none' reduction also counted ignored positions (loss 0), which
            # shrank the loss by each batch's padding fraction.
            token_losses = criterion(decoder_out.permute(1, 2, 0), trg)
            loss = token_losses.sum() / (trg != -1).sum().clamp(min=1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            if step % log_every == 0:
                with torch.no_grad():
                    ld = get_LD(softmax_(decoder_out), trg)
                loss_ = loss.item()
                lr_ = get_lr(optimizer)
                histories['ld_list'].append(ld)
                histories['loss_list'].append(loss_)
                histories['lr_list'].append(lr_)
                print('Step: {} | LD: {:.4f} | Loss: {:.4f} | learning rate: {:.6f}'.format(step, ld, loss_, lr_))
                _save_training_state(encoder, decoder, embedding, histories, ckpt_dir)

            # Stop exactly at total_steps. The cosine schedule is only meant
            # to run up to num_training_steps; beyond it the cosine rises again.
            if step >= total_steps:
                break
            


def _decode_until_eos(indices, vocab, eos='<eos>'):
    """Map a 1-D array of token ids through ``vocab`` and join the tokens, stopping before the first ``eos`` token."""
    tokens = list(vocab[indices])
    if eos in tokens:
        tokens = tokens[:tokens.index(eos)]
    return ''.join(tokens)


def get_LD(decoder_softmax, target, verbose=False):
    """Mean Levenshtein distance between greedy predictions and targets.

    Parameters
    ----------
    decoder_softmax : torch.Tensor
        Decoder probabilities. The argmax is taken over dim 2 and the result
        is read as ``(seq_len, batch)``, with column ``i`` being sample ``i``.
    target : torch.Tensor
        Target token ids. They are permuted to ``(seq_len, batch)`` here, so
        the input is expected as ``(batch, seq_len)``.
    verbose : bool
        If True, print each decoded prediction and target. This was the
        previous unconditional behaviour.

    Returns
    -------
    float
        Mean distance over the batch, or ``nan`` for an empty batch.

    Each sequence is mapped through the global ``int_to_char`` (numpy
    fancy-indexed) and truncated at its first ``'<eos>'`` token. Truncation is
    done on tokens rather than by replacing ``'<eos>'`` with ``'/'``, so a
    literal ``'/'`` token can no longer be mistaken for end-of-sequence.

    NOTE(review): train() passes targets shifted by -1, so padding appears as
    -1 and numpy indexing maps it to the LAST vocab entry. Confirm that
    ``int_to_char`` is aligned with these shifted ids.
    """
    pred_ids = torch.argmax(decoder_softmax, 2).detach().cpu().numpy()
    target_ids = target.permute(1, 0).detach().cpu().numpy()

    distances = []
    for i in range(pred_ids.shape[1]):
        output_ = _decode_until_eos(pred_ids[:, i], int_to_char)
        # The original sliced the numpy `target` array here instead of the decoded string.
        target_ = _decode_until_eos(target_ids[:, i], int_to_char)
        if verbose:
            print(output_)
            print(target_)
        distances.append(levenshtein_distance(output_, target_))

    if not distances:
        return float('nan')
    return sum(distances) / len(distances)

In [None]:
# Run training on the full dataset with the (encoder, decoder, embedding) triple.
train(model=model, dataset=dataset)