In [None]:
!pip install catalyst==20.12 python-Levenshtein

In [2]:
import pandas as pd
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, RandomVerticalFlip, RandomHorizontalFlip, ToPILImage
import torchvision
import catalyst
import random
from catalyst import dl, utils
from catalyst.callbacks.scheduler import SchedulerCallback
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import Levenshtein
import cv2
from PIL import Image

In [2]:
# Expose only GPU 0 to this process (has to happen before CUDA is first initialised).
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
# Unpack the competition data from Google Drive (Colab mount under /content/drive) into /content.
!unzip /content/drive/MyDrive/image2text/train.zip -d /content/train/
# No -d here: the CSV is extracted into the current working directory; TRAIN_PATH assumes that is /content.
!unzip /content/drive/MyDrive/image2text/train_labels.csv.zip

!unzip /content/drive/MyDrive/image2text/test.zip -d /content/test/
!cp /content/drive/MyDrive/image2text/sample_submission.csv /content/sample_submission.csv

In [3]:
# Fraction of labelled rows used for training; the rest forms the validation split.
TRAIN_SIZE = 0.98

# Training labels and images (produced by the unzip cell above).
TRAIN_PATH = '/content/train_labels.csv'
TRAIN_IMAGES = '/content/train'

# Test images; the sample submission provides their image_ids and the output layout.
TEST_PATH = '/content/sample_submission.csv'
TEST_IMAGES = '/content/test'

submission = pd.read_csv(TEST_PATH)

In [4]:
# Stratify the split on InChI length (10 quantile buckets) so train and validation
# contain short and long molecules in the same proportions.
df = (
    pd.read_csv(TRAIN_PATH)
    .assign(len=lambda frame: frame['InChI'].apply(len))
    .assign(bucket=lambda frame: pd.qcut(frame['len'], 10))
)
df_train, df_val = train_test_split(
    df, train_size=TRAIN_SIZE, random_state=42, stratify=df['bucket'])

In [5]:
# Character-level vocabulary built from the training split.
# Special tokens take ids 0-3. Every other character gets the next free id in
# order of first appearance, so the mapping is deterministic for a given df_train.
# Embedding rows are indexed by these ids, so saved weights only match this mapping.
char_dict = {'<PAD>': 0,
             'InChI=1S/': 1,
             '<UNK>': 2,
             '<EOS>': 3}
# Iterate the column directly: DataFrame.iterrows builds a Series per row and is far slower.
# dict.fromkeys(inchi) yields each distinct character of a string once, in first-seen order,
# so the ids assigned are exactly those of a plain character-by-character scan.
for inchi in tqdm(df_train['InChI']):
  for char in dict.fromkeys(inchi):
    if char not in char_dict:
      char_dict[char] = len(char_dict)

# Inverse mapping: id -> token.
indices_dict = {idx: token for token, idx in char_dict.items()}

2375702it [04:37, 8552.81it/s]


In [6]:
class Config:
  """Hyper-parameters shared by the model, the data pipeline and decoding.

  Used as a plain namespace: ``config`` below is the class itself, not an instance.
  """
  # --- vocabulary (built from df_train in the previous cell) ---
  char_dict = char_dict
  indices_dict = indices_dict
  vocab_size = len(char_dict)

  # --- sequences / batching ---
  max_len = 250            # token budget per target incl. SOS/EOS; collate_fn truncates to this
  batch_size = 128

  # --- decoder ---
  emb_size = 300           # d_model; must be divisible by n_heads
  n_heads = 6
  n_layers = 4
  dim_feedforward = 1024
  dropout_emb = 0.3        # rate for Decoder.dropout_emb

  # --- encoder ---
  image_embedding = 512    # channel width of the ResNet-34 feature map produced by CNN

config = Config

In [12]:
def generate_square_subsequent_mask(sz):
  """Additive causal attention mask of shape (sz, sz), dtype float32.

  Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf when
  j > i, i.e. the float-mask form accepted by ``tgt_mask`` in nn.Transformer*.
  """
  future = torch.ones(sz, sz).triu(diagonal=1).bool()
  return torch.zeros(sz, sz).masked_fill(future, float('-inf'))

class CNN(nn.Module):
  """ResNet-34 backbone (ImageNet weights) without its pooling/classifier head.

  Maps a (batch, 3, H, W) image batch to a sequence of spatial feature vectors
  of shape (batch, H' * W', 512); for 256x256 inputs H' = W' = 8.
  """
  def __init__(self):
    super().__init__()
    backbone = torchvision.models.resnet34(pretrained=True)
    # Drop the last two children (global average pool and fully-connected layer).
    self.resnet = nn.Sequential(*list(backbone.children())[:-2])

  def forward(self, x):
    feature_map = self.resnet(x)                   # (batch, 512, H', W')
    return feature_map.flatten(2).transpose(1, 2)  # (batch, H' * W', 512)


class Decoder(nn.Module):
  """Transformer decoder producing next-token logits for InChI token sequences.

  Args:
    config: namespace providing vocab_size, emb_size, max_len, dropout_emb,
      n_heads, n_layers, dim_feedforward and char_dict (see ``Config``).
  """
  def __init__(self, config):
    super().__init__()
    # Decoding constants live on the module so generate() does not depend on the
    # notebook-global ``config``.
    self.max_len = config.max_len
    self.sos_idx = config.char_dict['InChI=1S/']
    self.pad_idx = config.char_dict['<PAD>']

    self.trg_emb = nn.Embedding(config.vocab_size, config.emb_size)
    # weight = np.load('/content/drive/MyDrive/image2text/models/transformer_emb_300.npy')
    # weight = torch.from_numpy(weight)
    # self.trg_emb = nn.Embedding.from_pretrained(weight)
    self.trg_pos_emb = nn.Embedding(config.max_len, config.emb_size)
    self.dropout_emb = nn.Dropout(config.dropout_emb)
    self.decoder_trs = self.initialize_decoder(config.emb_size, config.n_heads, config.n_layers, config.dim_feedforward)
    self.fc_logits = nn.Linear(config.emb_size, config.vocab_size)

  def forward(self, tgt, src):
    """Teacher-forced decoding.

    Args:
      tgt: (batch, tgt_len) long tensor of token ids; tgt_len <= max_len.
        Positions holding the <PAD> id are excluded as attention keys.
      src: (batch, src_len, emb_size) encoder memory (projected image features).

    Returns:
      (batch, tgt_len, vocab_size) logits; position t predicts token t + 1.
    """
    batch_size, trg_seq_len = tgt.shape
    device = tgt.device
    trg_positions = torch.arange(trg_seq_len, device=device).expand(batch_size, trg_seq_len)
    trg_mask = generate_square_subsequent_mask(trg_seq_len).to(device)

    # Dropout on the summed token + position embeddings (identity in eval mode).
    embed_trg = self.dropout_emb(self.trg_emb(tgt) + self.trg_pos_emb(trg_positions))
    tgt_padding_mask = tgt == self.pad_idx

    # nn.TransformerDecoder here is sequence-first: (seq, batch, emb).
    output = self.decoder_trs(
            embed_trg.permute(1, 0, 2),
            src.permute(1, 0, 2),
            tgt_mask=trg_mask,
            tgt_key_padding_mask=tgt_padding_mask
        ).permute(1, 0, 2)
    return self.fc_logits(output)

  def generate(self, images):
    """Greedy autoregressive decoding.

    Args:
      images: (batch, src_len, emb_size) encoded image features (the projected
        encoder output, despite the parameter name).

    Returns:
      (batch, max_len - 1) long tensor of predicted token ids. The leading SOS
      token is NOT included: column 0 is the first real token.

    The module is put in eval mode for decoding and its previous mode is restored.
    """
    was_training = self.training
    self.eval()
    try:
      with torch.no_grad():
        batch_size = images.shape[0]
        sos = torch.full((batch_size, 1), self.sos_idx, dtype=torch.long, device=images.device)
        tokens = sos
        preds = sos.new_empty((batch_size, 0))  # result when max_len == 1
        for _ in range(self.max_len - 1):
          # Re-run the whole prefix each step; causal masking keeps earlier predictions stable.
          preds = self(tokens, images).argmax(dim=-1)
          tokens = torch.cat([sos, preds], dim=1)
        return preds
    finally:
      self.train(was_training)

  def initialize_decoder(self, d_model, n_heads, n_layers, dim_feedforward):
    """Stack ``n_layers`` copies of a TransformerDecoderLayer."""
    decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=dim_feedforward)
    transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
    return transformer_decoder


class Img2Text(nn.Module):
  """Image-to-InChI model: CNN encoder followed by a transformer decoder.

  The CNN yields (batch, n_patches, config.image_embedding) features. ``fc_emb``
  projects them to config.emb_size so they can act as the decoder's memory.
  """
  def __init__(self, config):
    super().__init__()
    self.config = config
    self.fc_emb = nn.Linear(config.image_embedding, config.emb_size)

    self.encoder = CNN()
    self.decoder = Decoder(config)

  def _encode(self, images):
    # (batch, n_patches, image_embedding) -> (batch, n_patches, emb_size)
    return self.fc_emb(self.encoder(images))

  def forward(self, text, image):
    """Teacher-forced logits of shape (batch, seq_len, vocab_size)."""
    return self.decoder(text, self._encode(image))

  def generate(self, images):
    """Greedy decoding; see ``Decoder.generate`` for the returned ids."""
    return self.decoder.generate(self._encode(images))

In [13]:
def convert_to_indices(char_dict, string):
  """Encode an InChI string as token ids.

  Result: [SOS ('InChI=1S/' id)] + one id per character after the 6-char
  'InChI=' prefix (unknown characters -> '<UNK>') + [EOS].
  """
  sos = char_dict['InChI=1S/']
  body = [char_dict[ch] if ch in char_dict else char_dict['<UNK>'] for ch in string[6:]]
  return [sos] + body + [char_dict['<EOS>']]

def convert_to_string(indices_dict, indices):
  """Decode token ids back to an InChI string.

  Position 0 is assumed to be the SOS token and is always skipped. Decoding
  stops at the first '<EOS>' or '<PAD>'. The result is prefixed with 'InChI='.
  ``indices`` may be a sequence, a numpy array or a torch tensor.
  """
  if isinstance(indices, torch.Tensor):
    indices = indices.detach().cpu().numpy()
  pieces = ['InChI=']
  for token_id in indices[1:]:
    token = indices_dict[token_id]
    if token in ('<EOS>', '<PAD>'):
      break
    pieces.append(token)
  return ''.join(pieces)

def get_path(mode, image_id):
    """Return the PNG path of ``image_id``.

    Images are sharded by the first three characters of the id, e.g.
    'abc123' -> '<root>/a/b/c/abc123.png', where root is TRAIN_IMAGES or TEST_IMAGES.

    Raises:
      NameError: if ``mode`` is neither 'train' nor 'test'.
    """
    if mode not in ('train', 'test'):
      raise NameError
    root = TRAIN_IMAGES if mode == 'train' else TEST_IMAGES
    return f'{root}/{image_id[0]}/{image_id[1]}/{image_id[2]}/{image_id}.png'

class Dataset(torch.utils.data.Dataset):
  """Labelled dataset yielding an image tensor and its encoded InChI.

  Args:
    df: DataFrame with ``image_id`` and ``InChI`` columns.
    mode: 'train' or 'test'; selects the image root via ``get_path``.
    char_dict: token -> id mapping passed to ``convert_to_indices``.
    transforms: callable applied to the array returned by ``cv2.imread``.
  """
  def __init__(self, df, mode, char_dict, transforms):
    super().__init__()
    self.df = df
    self.mode = mode
    self.char_dict = char_dict
    self.transform = transforms

  def __len__(self):
    return len(self.df)

  def __getitem__(self, indx):
    row = self.df.iloc[indx]
    image_path = get_path(self.mode, row['image_id'])
    indices = torch.tensor(convert_to_indices(self.char_dict, row['InChI']))
    return {'images': self.get_image(image_path),
            'indices': indices}

  def get_image(self, image_path):
    """Read ``image_path`` with OpenCV (BGR channel order) and apply the transform.

    Raises:
      FileNotFoundError: if OpenCV cannot read the file.
    """
    img = cv2.imread(image_path)
    if img is None:
      # cv2.imread reports a missing/unreadable file by returning None, not by raising.
      raise FileNotFoundError(f'Could not read image: {image_path}')
    return self.transform(img)

class TestDataset(torch.utils.data.Dataset):
  """Unlabelled test dataset yielding only image tensors.

  Portrait images (height > width) are rotated 90 degrees counter-clockwise
  (transpose followed by a vertical flip) before ``transforms`` is applied.

  Args:
    df: DataFrame with an ``image_id`` column.
    transforms: callable applied to the (possibly rotated) array from ``cv2.imread``.
  """
  def __init__(self, df, transforms):
    super().__init__()
    self.df = df
    self.transform = transforms

  def __len__(self):
    return len(self.df)

  def __getitem__(self, indx):
    row = self.df.iloc[indx]
    image_path = get_path('test', row['image_id'])
    return {'images': self.get_image(image_path)}

  @staticmethod
  def _restore_orientation(img):
    """Rotate portrait (h > w) images 90 degrees CCW; return others unchanged."""
    h, w = img.shape[:2]
    if h > w:
      # np.rot90 == transpose + vertical flip. Copy to drop the negative strides
      # before the PIL-based transforms see the array.
      img = np.ascontiguousarray(np.rot90(img))
    return img

  def get_image(self, image_path):
    """Read ``image_path``, fix its orientation and apply the transform.

    Raises:
      FileNotFoundError: if OpenCV cannot read the file.
    """
    img = cv2.imread(image_path)
    if img is None:
      raise FileNotFoundError(f'Could not read image: {image_path}')
    return self.transform(self._restore_orientation(img))

def collate_fn(batch):
  """Stack images and right-pad/truncate token ids to ``config.max_len``.

  Returns a dict with 'images' (stacked along a new batch dim) and 'indices'
  (batch, max_len) long tensor padded with 0 (<PAD>). Sequences longer than
  max_len are cut off, losing their EOS token.
  """
  images = torch.stack([sample['images'] for sample in batch])
  max_len = config.max_len
  indices = torch.zeros(len(batch), max_len, dtype=torch.long)
  for padded_row, sample in zip(indices, batch):
    seq = sample['indices'][:max_len]
    padded_row[:len(seq)] = seq  # padded_row is a view into ``indices``
  return {'images': images, 'indices': indices}

In [14]:
def calculate_accuracy(logits, targets, mask):
    """Accuracy of argmax predictions at the positions selected by ``mask``.

    ``mask`` may be a boolean mask or an integer index tensor; it is applied
    with numpy indexing. Presumably logits are (N, vocab) and targets (N,),
    as for calculate_f1 in the runner — TODO confirm if used elsewhere.
    """
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    keep = to_numpy(mask)
    predicted = to_numpy(logits.argmax(dim=-1))
    expected = to_numpy(targets)
    return accuracy_score(expected[keep], predicted[keep])

def calculate_f1(logits, targets, mask):
    """Micro-averaged F1 of argmax predictions at the positions selected by ``mask``.

    The runner passes flattened (N, vocab) logits, (N,) targets and the integer
    positions of non-padding targets as ``mask``.
    """
    keep = mask.detach().cpu().numpy()
    y_pred = logits.argmax(dim=-1).detach().cpu().numpy()[keep]
    y_true = targets.detach().cpu().numpy()[keep]
    return f1_score(y_true, y_pred, average='micro')

def calculate_levenstein(logits, targets, indices_dict):
  """Mean Levenshtein distance between teacher-forced argmax predictions and targets.

  Args:
    logits: (batch, seq_len - 1, vocab) scores where position t predicts
      targets[:, t + 1] (the model is fed targets[:, :-1]).
    targets: (batch, seq_len) token ids whose first column is the SOS token.
    indices_dict: id -> token mapping.

  Returns:
    Mean distance over the batch.
  """
  preds = torch.argmax(logits, dim=-1)
  targets = targets.detach()
  # Predictions never contain SOS, but convert_to_string always skips position 0.
  # Re-attach the targets' SOS column so the first predicted token is kept.
  preds = torch.cat([targets[:, :1].to(preds.device), preds], dim=1)
  distances = [
      Levenshtein.distance(convert_to_string(indices_dict, target_row),
                           convert_to_string(indices_dict, pred_row))
      for target_row, pred_row in zip(targets, preds)
  ]
  return np.array(distances).mean()

def calculate_levenstein_test(indices, target_indices, indices_dict=None):
  """Mean Levenshtein distance between generated token ids and reference ids.

  Args:
    indices: (batch, gen_len) ids from ``model.generate``. These start at the
      first real token, i.e. the SOS token is not included.
    target_indices: (batch, seq_len) reference ids whose first column is SOS.
    indices_dict: id -> token mapping; defaults to ``config.indices_dict``.

  Returns:
    Mean distance over the batch, computed on the strings after 'InChI='.
  """
  if indices_dict is None:
    indices_dict = config.indices_dict
  preds = indices.detach().cpu().numpy()
  targets = target_indices.detach().cpu().numpy()
  # convert_to_string always skips position 0 (SOS). Give the generated rows the
  # targets' SOS column so their first token is not dropped.
  preds = np.concatenate([targets[:, :1], preds], axis=1)
  distances = [
      Levenshtein.distance(convert_to_string(indices_dict, target_row)[6:],
                           convert_to_string(indices_dict, pred_row)[6:])
      for target_row, pred_row in zip(targets, preds)
  ]
  return np.array(distances).mean()

def predict(test_df, model, test_loader, indices_dict):
    """Greedy-decode every batch and return a copy of ``test_df`` with predicted InChIs.

    Args:
      test_df: DataFrame whose row order matches ``test_loader`` (no shuffling).
      model: object with ``generate(images)``. It returns either token ids
        (batch, seq_len), as ``Img2Text.generate`` does, or logits
        (batch, seq_len, vocab), which are reduced with argmax.
      test_loader: iterable of dicts with an 'images' tensor.
      indices_dict: id -> token mapping; must contain the 'InChI=1S/' SOS token.

    Returns:
      Copy of ``test_df`` whose 'InChI' column holds the predictions.
    """
    test_df = test_df.copy()
    sos_idx = next(idx for idx, token in indices_dict.items() if token == 'InChI=1S/')
    # Feed images to wherever the model's weights live instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    inchis = []
    model.eval()
    with torch.no_grad():
      for batch in tqdm(test_loader):
        images = batch["images"]
        if device is not None:
          images = images.to(device)
        preds = model.generate(images)
        if preds.dim() == 3:  # logits -> token ids
          preds = torch.argmax(preds, dim=-1)
        # generate() omits SOS, while convert_to_string skips position 0; prepend it.
        sos = torch.full((preds.shape[0], 1), sos_idx, dtype=preds.dtype, device=preds.device)
        rows = torch.cat([sos, preds], dim=1).cpu().numpy()
        for row in rows:
          inchis.append(convert_to_string(indices_dict, row))
    test_df['InChI'] = inchis
    return test_df

In [15]:
class CustomRunner(dl.Runner):
    """Catalyst runner: teacher-forced training, greedy-decoding validation.

    * 'valid' loader: full greedy decoding via ``model.generate``, scored with
      Levenshtein distance; loss and F1 are reported as 0.
    * other loaders: cross-entropy on next-token logits plus token F1 and a
      teacher-forced Levenshtein estimate. Backward and optimizer step run only
      on the train loader.
    """

    def _handle_batch(self, batch):
        images = batch["images"]
        indices = batch["indices"]

        if self.loader_key == 'valid':
            self.model.eval()
            output = self.model.generate(images)
            levenstein = calculate_levenstein_test(output, indices)
            loss = 0
            f1 = 0
        else:
            self.model.train()
            # Teacher forcing: feed tokens [0, L-1) and predict tokens [1, L).
            logits = self.model(indices[:, :-1], images)
            output = logits.reshape(-1, logits.shape[-1])
            trg = indices[:, 1:].reshape(-1)

            # Use the criterion handed to runner.train rather than a notebook global.
            loss = self.criterion(output, trg)
            # Score only positions whose target is non-zero (0 is <PAD>).
            f1 = calculate_f1(output, trg, torch.where(trg)[0])
            levenstein = calculate_levenstein(logits, indices, config.indices_dict)

        batch_metrics = {
              "loss": loss,
              "f1_score": f1,
              "levenstein": levenstein
              }
        self.batch_metrics.update(batch_metrics)

        if self.is_train_loader:
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), 3.0, 2.0)
            self.optimizer.step()
            self.optimizer.zero_grad()

In [16]:
from torch.optim.lr_scheduler import OneCycleLR

EPOCHS = 3


def make_transform():
  """Resize to 256x256 and normalise with ImageNet channel statistics."""
  return Compose([
      ToPILImage('RGB'),
      Resize((256, 256)),
      ToTensor(),
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])


# No augmentation yet: both pipelines are identical, but they are kept as
# separate objects so either can change without affecting the other.
transform = make_transform()
transform_test = make_transform()

# The validation rows come from the labelled set, so their images sit under the train root.
train_dataset = Dataset(df_train, 'train', char_dict, transform)
val_dataset = Dataset(df_val, 'train', char_dict, transform_test)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=config.batch_size, shuffle=True,
    num_workers=4, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=config.batch_size,
    num_workers=4, collate_fn=collate_fn)

model = Img2Text(config)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # id 0 is <PAD>
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)
# One-cycle schedule peaking at lr=1e-3, sized for EPOCHS epochs of batches.
scheduler = OneCycleLR(optimizer, 1e-3, total_steps=len(train_loader) * EPOCHS)

In [None]:
# Optional: resume an earlier training run from its last full checkpoint in the logdir.
# Uncomment to restore model/optimizer/scheduler/criterion state before calling runner.train.
# checkpoint = torch.load("/content/drive/MyDrive/image2text/logs/checkpoints/last_full.pth")
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# criterion.load_state_dict(checkpoint['criterion_state_dict'])
# model.train()
# EPOCH = checkpoint['epoch']

In [None]:
# NOTE(review): num_epochs=1 while OneCycleLR was sized for EPOCHS epochs. Presumably
# training is resumed across sessions via the checkpoint cell above — TODO confirm.
runner = CustomRunner()
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders={'train': train_loader, 'valid': val_loader},
    # Step the LR scheduler after every batch rather than every epoch.
    callbacks=[SchedulerCallback(reduced_metric='loss', mode='batch')],
    num_epochs=1,
    # Model selection on validation Levenshtein distance (lower is better).
    main_metric="levenstein",
    minimize_metric=True,
    load_best_on_end=True,
    logdir="/content/drive/MyDrive/image2text/logs",
    overfit=False,
    verbose=True,
)

In [30]:
# Clear tqdm's internal (private) registry of progress-bar instances, presumably to
# drop stale bars left over from an interrupted loop before the next long run.
tqdm._instances.clear()

In [None]:
# runner.train(load_best_on_end=True) already restores the best weights in this kernel.
# In a fresh kernel, uncomment to reload them from disk:
# model.load_state_dict(torch.load('/content/drive/MyDrive/image2text/logs/checkpoints/best.pth')['model_state_dict'])
model.cuda().eval()

In [None]:
# Held-out evaluation on the validation split (no separate test split is created above).
# Each row's levenstein is the edit distance between the true and predicted InChI.
pd_df = predict(df_val, model, val_loader, config.indices_dict)
pd_df['InChI_true'] = df_val['InChI']
pd_df = pd_df.assign(levenstein=lambda x: [
    Levenshtein.distance(true_inchi, pred_inchi)
    for true_inchi, pred_inchi in zip(x.InChI_true, x.InChI)
])

In [None]:
# Mean edit distance over the evaluated rows (lower is better).
pd_df['levenstein'].mean()

In [77]:
# Decode the competition test set and write the submission CSV.
# TestDataset returns only images, so the default collate function suffices.
# num_workers matches the train/val loaders so image decoding doesn't starve the GPU.
test_dataset = TestDataset(submission, transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, config.batch_size, num_workers=4)
pd_df = predict(submission, model, test_loader, config.indices_dict)
pd_df.to_csv('/content/drive/MyDrive/image2text/submission.csv', index=False)

100%|██████████| 12626/12626 [2:32:43<00:00,  1.38it/s]
