This notebook borrows much of its code from:
1. [PyTorch Tutorial](https://pytorch.org/tutorials/beginner/torchtext_translation.html)
2. [Blog: solving-an-image-captioning-task-using-deep-learning](https://www.analyticsvidhya.com/blog/2018/04/solving-an-image-captioning-task-using-deep-learning/)

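This is a baseline model that uses a CNN encoder and an LSTM decoder to translate an image into its chemical structure (an InChI string). The full training set contains 2,424,186 images, which is enormous, so I train on a random sample of it; the sampled fraction is set by `sample_ratio` in the config cell below (currently 0.1, i.e. 10%).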
![](https://github.com/yunjey/pytorch-tutorial/raw/master/tutorials/03-advanced/image_captioning/png/model.png)

In [None]:
!pip install timm

In [None]:
import os
import pandas as pd
import torch
import numpy as np
import random

from torch.utils.data import DataLoader
import cv2

import timm
from pprint import pprint
# model_names = timm.list_models(pretrained=True)
# pprint(model_names)

In [None]:
def seed_everything(seed=99):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then configures cuDNN for reproducible runs.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Harmless when CUDA is unavailable; makes the multi-GPU case explicit.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick kernels at run time, so the
    # chosen algorithm may differ between runs. That works against the
    # `deterministic` flag above, so it is switched off.
    torch.backends.cudnn.benchmark = False


seed_everything()

In [None]:
# Central configuration shared by every cell below: data paths, model size and training schedule.
cfg = {
    'version':'version2',  # also used as the name of the output directory
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'train_img':'../input/bms-molecular-translation/train',
    'train_anno':'../input/bms-molecular-translation/train_labels.csv',
    'sample_ratio': 0.1,  # fraction of the labelled rows that is sampled for this run
    'backbone':'efficientnet_b0',
    'pretrianed':True,  # NOTE: the key is misspelled, but Generator reads it under exactly this name
    'vocab_size':38, # placeholder only: overwritten with len(lang.vocab) once the vocab is built
    'max_len':150,  # default number of decoding steps in Generator.greedy_decode
    'embed_size':16,
    'hidden_size':64,
    'image_size':[128, 128],  # passed to the resize transform
    'batch_size':64,
    'num_workers':6,
    'n_epochs':10,
    'lr':1e-3,
    'min_lr':1e-6,
    'patience':3,  # patience of the ReduceLROnPlateau scheduler
    'TTA':5,  # NOTE(review): not referenced in the visible cells -- confirm whether it is still needed
}

# Create the output directory on first run; the final checkpoint is written into it.
if not os.path.isdir(cfg['version']):
    os.mkdir(cfg['version'])
    print('create dir')
# Last expression of the cell, so the resolved config is displayed.
cfg

#### build vocab
Add the padding token 'PAD', the start token 'SOS' and the end token 'EOS' to the vocab.

In [None]:
class Lang():
    """Character-level vocabulary for InChI strings.

    seq: chemical structure, shape like: 'InChI=1S...'

    The vocabulary is three special tokens ('PAD', 'SOS', 'EOS') followed by
    the single characters that can appear after the 'InChI=1S' head. 'PAD' sits
    at index 0, which matches the zero fill used when sequences are padded.
    """
    HEAD = 'InChI=1S'

    def __init__(self):
        start_end_token = ['PAD', 'SOS', 'EOS']
        self.vocab = start_end_token + ['(', ')', '+', ',', '-', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'B', 'C', 'D', 'F', 'H', 'I', 'N', 'O', 'P', 'S', 'T', 'b', 'c', 'h', 'i', 'l', 'm', 'r', 's', 't']
        self.token_to_idx = {token: idx for idx, token in enumerate(self.vocab)}

    def seq_to_idx(self, seq):
        """Encode an InChI string as decoder input and target index lists.

        Args:
            seq: InChI string; the leading 'InChI=1S' head is dropped if present.

        Returns:
            (source, target): ``source`` is SOS followed by the character
            indices, ``target`` is the character indices followed by EOS, so
            ``target`` is ``source`` shifted left by one step.

        Raises:
            KeyError: if ``seq`` contains a character that is not in the vocab.
        """
        # Strip the head only where it belongs (at the start of the string).
        if seq.startswith(self.HEAD):
            seq = seq[len(self.HEAD):]
        idxs = [self.token_to_idx[token] for token in seq]
        source = [self.token_to_idx['SOS']] + idxs
        target = idxs + [self.token_to_idx['EOS']]
        return source, target

    def idx_to_seq(self, idxs):
        """Decode an index sequence back into an InChI string.

        Decoding stops at the first EOS or PAD, SOS markers are skipped and
        the 'InChI=1S' head is put back in front.

        Args:
            idxs: iterable of integer-like indices (ints, NumPy integers or
                0-d tensors).

        Returns:
            The decoded string, always starting with 'InChI=1S'.
        """
        sos = self.token_to_idx['SOS']
        # Filter special tokens by index. Removing the text 'SOS' from the
        # joined string would also delete genuine 'S', 'O', 'S' character runs.
        tokens = [self.vocab[int(idx)] for idx in self.rm_re_idxs(idxs) if int(idx) != sos]
        return self.HEAD + ''.join(tokens)

    def rm_re_idxs(self, idxs):
        """Return the leading part of ``idxs`` up to (excluding) the first EOS or PAD."""
        stop_idxs = (self.token_to_idx['EOS'], self.token_to_idx['PAD'])
        new_idxs = []
        for idx in idxs:
            if int(idx) in stop_idxs:
                break
            new_idxs.append(idx)
        return new_idxs
    
# Shared vocabulary instance; the collate function and the decoder look tokens up through this global.
lang = Lang()
# Replace the placeholder in cfg with the real vocabulary size (special tokens included).
cfg['vocab_size'] = len(lang.vocab)

Convert a sample InChI string to indices, then decode it again as a round-trip check.

In [None]:
# Encode one sample InChI string into (source, target) index lists ...
sentence = 'InChI=1S/C13H20OS/c1-9(2)8-15-13-6-5-10(3)7-12(13)11(4)14/h5-7,9,11,14H,8H2,1-4H3'
source, target = lang.seq_to_idx(sentence)
# ... and decode the source indices again; as the last expression of the cell the result is displayed.
lang.idx_to_seq(source)

Sample the data, then split the sample into training and validation sets.

In [None]:
from sklearn.model_selection import train_test_split

# Full label table: one row per image with its `image_id` and `InChI` string.
full_df = pd.read_csv(cfg['train_anno'])

# Draw the working subset. An explicit random_state keeps the sample identical
# even when this cell is re-run on its own; without it the draw depends on the
# current state of the global NumPy generator.
sample_df = full_df.sample(frac=cfg['sample_ratio'], random_state=42).reset_index(drop=True)

# Hold out 33% of the sampled rows for validation.
X = list(range(len(sample_df)))
train_indexs, val_indexs = train_test_split(X, test_size=0.33, random_state=42)
# The split indices are positions, so select by position rather than by label.
train_df, val_df = sample_df.iloc[train_indexs], sample_df.iloc[val_indexs]
train_df.head()

Image preprocessing pipeline (resize, normalize, convert to tensor; no random augmentation is applied)

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Preprocessing applied to every image by the collate function; no random augmentation is configured.
transform = A.Compose([
    # target size comes from cfg['image_size']
    A.Resize(cfg['image_size'][0], cfg['image_size'][1]),
    # presumably the usual ImageNet channel statistics expected by the pretrained backbone -- confirm
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # converts the image array into a torch tensor (the plotting cell permutes it back with (1, 2, 0))
    ToTensorV2()
])

dataloader pipeline

In [None]:
def convert_image_id_2_path(img_dir:str, image_id: str) -> str:
    """Build the PNG path that belongs to an image id.

    Images are nested three directories deep, one level per leading character
    of the id: ``<img_dir>/<c0>/<c1>/<c2>/<image_id>.png``.

    Args:
        img_dir: Root directory of the image tree.
        image_id: Image identifier with at least three characters.

    Returns:
        The path as a '/'-separated string.
    """
    first, second, third = image_id[0], image_id[1], image_id[2]
    return f"{img_dir}/{first}/{second}/{third}/{image_id}.png"

def data_process(df, img_dir):
    """Turn a label frame into a list of ``(image_path, InChI)`` pairs.

    Args:
        df: DataFrame with an ``image_id`` and an ``InChI`` column.
        img_dir: Root directory of the image tree.

    Returns:
        One ``(img_path, seq)`` tuple per row, in row order. An empty frame
        yields an empty list.
    """
    # Walk the two columns directly. Building a row Series with `df.iloc[idx]`
    # (twice per row) is very slow on a frame of this size.
    return [
        (convert_image_id_2_path(img_dir, image_id), seq)
        for image_id, seq in zip(df['image_id'], df['InChI'])
    ]
    
def generate_batch(data_batch, tfs=transform, train=True):
    """Collate function: load the images of a batch and encode their labels.

    Args:
        data_batch: Iterable of ``(img_path, seq)`` pairs.
        tfs: Optional transform called as ``tfs(image=img)['image']``;
            pass a falsy value to keep the raw image array.
        train: Currently unused; kept so existing callers keep working.

    Returns:
        ``(img_batch, source_batch, target_batch)`` as three parallel lists.
        Nothing is stacked or padded here, because the sequences have
        different lengths.

    Raises:
        FileNotFoundError: if an image cannot be read.
    """
    img_batch, source_batch, target_batch = [], [], []
    for img_path, seq in data_batch:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports a missing or unreadable file by returning None;
            # fail here with the path instead of deep inside the transform.
            raise FileNotFoundError(f"could not read image: {img_path}")
        # OpenCV decodes to BGR, while the normalisation statistics of the
        # default transform are given in RGB order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if tfs:
            img = tfs(image=img)['image']

        img_batch.append(img)
        source, target = lang.seq_to_idx(seq)
        source_batch.append(torch.tensor(source, dtype=torch.long))
        target_batch.append(torch.tensor(target, dtype=torch.long))
    return img_batch, source_batch, target_batch


In [None]:
# Materialise the (image path, InChI) pairs up front; the images themselves are read later, per batch.
train_data = data_process(train_df, img_dir=cfg['train_img'])
val_data = data_process(val_df, img_dir=cfg['train_img'])
# generate_batch is the collate function, so it loads and encodes each batch.
train_iter = DataLoader(train_data, batch_size=cfg['batch_size'],
                        shuffle=True, collate_fn=generate_batch, num_workers=cfg['num_workers'])
# Validation order is kept fixed (no shuffling).
valid_iter = DataLoader(val_data, batch_size=cfg['batch_size'],
                        shuffle=False, collate_fn=generate_batch, num_workers=cfg['num_workers'])

display samples

In [None]:
import matplotlib.pyplot as plt

def visualize_batch(image, labels):
    """Show images on a 3x3 grid, each titled with the start of its label.

    Args:
        image: Iterable of image tensors; each is permuted with ``(1, 2, 0)``
            before plotting.
        labels: Iterable of labels, parallel to ``image``; the first ten
            entries of each label are used as the subplot title.
    """
    plt.figure(figsize=(16, 12))
    pairs = zip(image, labels)
    # subplot positions are 1-based
    for position, (img, label) in enumerate(pairs, start=1):
        plt.subplot(3, 3, position)
        plt.imshow(img.permute(1, 2, 0))
        plt.title(f"{label[:10]}...", fontsize=10)
        plt.axis("off")

    plt.show()

# Pull a single batch from the training loader and plot its first three images.
for imgs, source, target in train_iter:
    # the target index sequences serve as the subplot titles
    visualize_batch(imgs[:3], target[:3])
    break

In [None]:
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence
import torch.nn as nn

# Conv + LSTM
class Generator(nn.Module):
    """
    Conv encoder LSTM decoder

    The backbone maps every image to a ``hidden_size`` vector which becomes the
    initial hidden state of the LSTM; the initial cell state is all zeros.

    x: imgs
    seqs: padded sequence of idxs that is a batch of InChl
    lengths: batch of seqs length
    """
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # num_classes=hidden_size: the backbone head directly produces the LSTM's initial hidden state
        self.backbone = timm.create_model(cfg['backbone'], pretrained=cfg['pretrianed'], num_classes=cfg['hidden_size'])
        self.emb_layer = nn.Embedding(cfg['vocab_size'], cfg['embed_size'])
        self.lstm = nn.LSTM(cfg['embed_size'], cfg['hidden_size'])
        self.out_layer = nn.Linear(cfg['hidden_size'], cfg['vocab_size'])

    def forward(self, x, source_padded, lens):
        """Teacher-forced pass used for training.

        Args:
            x: Batch of images, ordered by decreasing sequence length.
            source_padded: Padded source indices, shape (len, batch).
            lens: Sequence lengths in decreasing order, as required by
                ``pack_padded_sequence``.

        Returns:
            Logits for the packed (unpadded) time steps, shape
            (sum(lens), vocab_size).
        """
        batch_size = x.size(0)
        features = self.backbone(x)
        _, c = self.init_state(batch_size, device=features.device)
        states = (features.unsqueeze(0), c)
        input_embedding = self.emb_layer(source_padded)    # embedding: len, batch, size
        packed = pack_padded_sequence(input_embedding, lens)
        out_packed, _ = self.lstm(packed, states)
        return self.out_layer(out_packed.data)

    def greedy_decode(self, x, lens=None):
        """Greedy search.

        Args:
            x: Batch of images.
            lens: Maximum number of decoding steps (int or 0-d tensor);
                defaults to ``cfg['max_len']``.

        Returns:
            (sampled_ids, logits): ``sampled_ids`` has shape (steps, batch)
            with ``steps <= max_len``; ``logits`` has shape
            (max_len, batch, vocab_size) and stays zero for steps that were
            never decoded.
        """
        max_len = int(lens) if lens is not None else self.cfg['max_len']
        batch_size = x.size(0)
        features = self.backbone(x)
        device = features.device
        logits = torch.zeros(max_len, batch_size, self.cfg['vocab_size'], device=device)
        _, c = self.init_state(batch_size, device=device)
        states = (features.unsqueeze(0), c)

        sos_idx = lang.token_to_idx['SOS']
        eos_idx = lang.token_to_idx['EOS']
        pad_idx = lang.token_to_idx['PAD']
        # every sequence starts from SOS
        curr = torch.full((1, batch_size), sos_idx, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        sampled_ids = []
        for i in range(max_len):
            hiddens, states = self.lstm(self.emb_layer(curr), states)   # (1, batch, hidden_size)
            # squeeze only the time axis so a batch of one keeps its batch axis
            outputs = self.out_layer(hiddens.squeeze(0))                # (batch, vocab_size)
            curr = torch.argmax(outputs, dim=-1).reshape(1, batch_size)

            sampled_ids.append(curr)
            logits[i] = outputs
            # Stop only once EVERY sequence has produced EOS/PAD. Looking at the
            # first sample alone would cut all other sequences short.
            finished |= (curr[0] == eos_idx) | (curr[0] == pad_idx)
            if bool(finished.all()):
                break
        if sampled_ids:
            sampled_ids = torch.cat(sampled_ids)
        else:
            sampled_ids = torch.empty(0, batch_size, dtype=torch.long, device=device)
        return sampled_ids, logits  # (seq, batch)

    def init_state(self, batch_size, device=None):
        """Return zero ``(h, c)`` LSTM states of shape (1, batch, hidden_size).

        Args:
            batch_size: Number of sequences in the batch.
            device: Target device; defaults to ``cfg['device']``.
        """
        device = self.cfg['device'] if device is None else device
        shape = (1, batch_size, self.cfg['hidden_size'])
        return (
            torch.zeros(shape, device=device),
            torch.zeros(shape, device=device),
        )

In [None]:
import Levenshtein

# metric
def get_score(y_pred, y_true):
    """Mean Levenshtein distance between paired predicted and true strings.

    Args:
        y_pred: Iterable of predicted strings.
        y_true: Iterable of reference strings, parallel to ``y_pred``.

    Returns:
        The average edit distance over all pairs (lower is better).
    """
    distances = [
        Levenshtein.distance(true, pred)
        for true, pred in zip(y_true, y_pred)
    ]
    return np.mean(distances)

def get_accuracy(y_pred, y_true):
    """Fraction of positions at which prediction and target agree.

    Both arguments are compared element-wise with ``==``; the number of
    matches is divided by ``len(y_pred)``.
    """
    n_correct = (y_pred == y_true).sum()
    n_total = len(y_pred)
    return n_correct / n_total

In [None]:
def pad_packed(packed_data, batch_sizes):
    """Unpack flat, time-major packed data into a zero-padded 2-D array.

    Args:
        packed_data: 1-D sequence holding step 0 of all sequences, then
            step 1 of the sequences that are still running, and so on.
        batch_sizes: Number of running sequences at each step (non-increasing).

    Returns:
        (padded, lens): ``padded`` is an int64 array of shape
        (len(batch_sizes), batch_sizes[0]) filled with 0 where a sequence has
        already ended; ``lens`` holds the length of every sequence.
    """
    n_steps = len(batch_sizes)
    n_seqs = batch_sizes[0]
    padded = np.zeros((n_steps, n_seqs), dtype=np.int64)
    seq_lens = np.zeros(n_seqs, dtype=np.int64)

    offset = 0
    for t in range(n_steps):
        width = batch_sizes[t]
        stop = offset + width
        # the first `width` sequences are the ones still running at step t
        padded[t, :width] = packed_data[offset:stop]
        seq_lens[:width] += 1
        offset = stop

    return padded, seq_lens

# preprocess data
def load_data(data):
    """Prepare one collated batch for the model.

    Stacks the images, pads the source sequences and reorders everything by
    decreasing sequence length.

    Args:
        data: ``(imgs_batch, source_batch, target_batch)`` as produced by the
            collate function (three parallel lists of tensors).

    Returns:
        (imgs, source_padded, sorted_lens, targets): stacked images and padded
        sources (len, batch) on ``cfg['device']``, the sorted lengths, and the
        target tensors as a list in the same order.
    """
    device = cfg['device']
    imgs, sources, targets = data
    imgs = torch.stack(imgs, dim=0).to(device)
    lens = torch.tensor(list(map(len, sources)))

    # longest sequence first
    sorted_lens, order = torch.sort(lens, descending=True)
    imgs_sorted = imgs[order].to(device)
    padded = pad_sequence(sources)  # (len, batch), padded with zeros
    source_padded = padded[:, order].to(device)
    targets_sorted = [targets[i] for i in order]
    return imgs_sorted, source_padded, sorted_lens, targets_sorted

In [None]:
import math

class Eearly_Stopping():
    """Early-stopping helper that also saves the best model seen so far.

    Args:
        mode: 'min' if a lower monitored value is better, 'max' if higher is
            better.
        patience: Number of consecutive non-improving steps after which
            ``step`` asks to stop. Must be positive.
        apply: When False the best model is still saved, but ``step`` never
            asks to stop.
        save_path: File the best model is written to.

    Raises:
        ValueError: for an unknown ``mode`` or a non-positive ``patience``.
    """
    def __init__(self, mode: str='min', patience: int=10, apply=True, save_path: str='model.pth'):
        if mode not in ('min', 'max'):
            raise ValueError("Invalid mode!", mode)
        if patience <= 0:
            raise ValueError("Invalid patience!", patience)
        self.mode = mode
        # start from the worst possible value for the chosen direction
        self.best = math.inf if mode == 'min' else -math.inf
        self.base_patience = patience
        self.patience = patience
        self.apply = apply
        self.save_path = save_path

    def step(self, model, monitor):
        """Record one evaluation result.

        Saves ``model`` and resets the patience counter when ``monitor``
        improves on the best value so far; otherwise counts down.

        Returns:
            True when training should stop, else False.
        """
        if self.mode == 'min':
            improved = monitor < self.best
        else:
            improved = monitor > self.best

        if improved:
            self.best = monitor
            torch.save(model, self.save_path)  # save model
            self.patience = self.base_patience
        elif self.apply:
            self.patience -= 1

        return bool(self.apply and self.patience <= 0)

In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau 

# training setting
# Model, optimizer and loss. The loss is applied to packed (unpadded) time steps.
generator = Generator(cfg).to(cfg['device'])
optimizer = optim.Adam(generator.parameters(), lr=cfg['lr'])
criterion = nn.CrossEntropyLoss()
# Halve the learning rate when the validation score stops improving, but never
# go below cfg['min_lr'] (the config defines this floor; it was not wired in before).
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=cfg['patience'], factor=0.5,
                              min_lr=cfg['min_lr'], verbose=True)
# Stop after 5 evaluations without improvement; the best model is saved along the way.
es = Eearly_Stopping(mode='min', patience=5, apply=True)

In [None]:
def evaluate(val_loader, model):
    """Greedy-decode a validation loader and report the mean edit distance.

    For every batch the loss, the token accuracy and the Levenshtein score are
    shown on the progress bar; the score over all decoded strings is printed
    and returned.

    Args:
        val_loader: Loader yielding collated batches.
        model: Model exposing ``greedy_decode``; it is switched to eval mode.

    Returns:
        Mean Levenshtein distance over the whole loader (lower is better).
    """
    all_preds, all_gts = [], []
    bar = tqdm(val_loader, desc='eval')
    model.eval()
    with torch.no_grad():
        for batch in bar:
            imgs_batch, source_padded, lens, target_batch = load_data(batch)
            # decode exactly as many steps as the longest target in the batch
            pred_batch, logits = model.greedy_decode(imgs_batch, lens.max())

            # loss over the real (unpadded) time steps only
            outs_packed = pack_padded_sequence(logits, lens)
            target_packed = pack_sequence(target_batch).to(cfg['device'])
            loss = criterion(outs_packed.data, target_packed.data)

            # token-level accuracy
            probs = torch.softmax(outs_packed.data, dim=-1)
            pred_tokens = torch.argmax(probs, dim=-1).cpu().detach().numpy()
            true_tokens = target_packed.data.cpu().numpy()
            acc = get_accuracy(pred_tokens, true_tokens)

            # string-level metric; (seq, batch) -> (batch, seq)
            pred_rows = pred_batch.permute(1, 0).cpu().numpy()
            preds = [lang.idx_to_seq(row) for row in pred_rows]
            gts = [lang.idx_to_seq(tgt) for tgt in target_batch]
            batch_score = get_score(preds, gts)

            bar.set_description(f'Eval loss: {loss.item():.3f} score: {batch_score:.3f} acc: {acc:.3f}')

            all_preds.extend(preds)
            all_gts.extend(gts)

    score = get_score(all_preds, all_gts)
    print('val score:', score)
    return score

def train(train_iter, epoch):
    """Run one teacher-forced training epoch on the global ``generator``.

    Uses the notebook-level ``generator``, ``optimizer`` and ``criterion``.
    Per batch, the loss, token accuracy and Levenshtein score are shown on the
    progress bar; the score over the whole epoch is printed at the end.

    Args:
        train_iter: Loader yielding collated training batches.
        epoch: Zero-based epoch number, used for the progress display only.

    Returns:
        The (updated) global ``generator``.
    """
    epoch_preds, epoch_gts = [], []
    generator.train()
    bar = tqdm(train_iter, desc='training')
    for batch in bar:
        imgs_batch, source_padded, lens, target_batch = load_data(batch)
        outs_packed = generator(imgs_batch, source_padded, lens)

        # pack the targets the same way the model output is packed
        packed_targets = pack_sequence(target_batch)
        batch_sizes = packed_targets.batch_sizes
        target_packed = packed_targets.data.to(cfg['device'])
        loss = criterion(outs_packed, target_packed)

        ### Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        ## token-level accuracy
        probs = torch.softmax(outs_packed, dim=-1)
        pred_packed = torch.argmax(probs, dim=-1).cpu().detach().numpy()
        target_packed = target_packed.cpu().numpy()
        acc = get_accuracy(pred_packed, target_packed)

        ## edit distance on the decoded strings; (seq, N) -> (N, seq)
        pred_padded, _ = pad_packed(pred_packed, batch_sizes)
        pred_text = [lang.idx_to_seq(row) for row in pred_padded.transpose(1, 0)]
        gt_text = [lang.idx_to_seq(tgt) for tgt in target_batch]
        score = get_score(pred_text, gt_text)

        ## log
        bar.set_description(f'Train epoch: {epoch+1} loss: {loss.item():.3f} score: {score:.3f} acc: {acc:.3f}')
        epoch_preds += pred_text
        epoch_gts += gt_text

    score = get_score(epoch_preds, epoch_gts)
    print('train score:', score)
    return generator

In [None]:
from tqdm import tqdm

# One pass per epoch: fit, validate, adapt the learning rate, then check early stopping.
for epoch in range(cfg['n_epochs']):
    generator = train(train_iter, epoch)
    score = evaluate(valid_iter, generator)
    # the scheduler monitors the validation score
    scheduler.step(score)
    # es.step also saves the model whenever the score improves
    if es.step(generator, score):
        print('early stop!')
        break

In [None]:
# Reload the best model that the early-stopping helper saved to 'model.pth'.
# NOTE(review): the whole module object was saved (not a state_dict), so loading it needs the
# Generator class to be defined; newer PyTorch releases may also require weights_only=False -- confirm.
model = torch.load('model.pth')
# Re-package it as a plain checkpoint: the weights plus the config used for training.
checkpoint = {}
checkpoint['net'] = model.state_dict()
checkpoint['train_cfg'] = cfg
# NOTE(review): 'checkpont.pth' looks like a typo for 'checkpoint.pth'; left unchanged because
# other code may load the file under this exact name.
torch.save(checkpoint, f"{cfg['version']}/checkpont.pth")  # save model