In [None]:
# Environment setup: install the pinned catalyst release (provides dl.Runner and
# SchedulerCallback used below) and python-Levenshtein (edit-distance metric).
!pip install catalyst==20.12 python-Levenshtein

In [2]:
import pandas as pd
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, RandomVerticalFlip, RandomHorizontalFlip, ToPILImage
import torchvision
import catalyst
import random
from catalyst import dl, utils
from catalyst.callbacks.scheduler import SchedulerCallback
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import Levenshtein
import cv2
from PIL import Image

In [3]:
# Set CUDA_VISIBLE_DEVICES so that only device 0 is visible to CUDA in this process.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
# Unpack the training images and labels from Google Drive.
# The labels archive is extracted into the current working directory
# (presumably /content, since TRAIN_PATH points there - verify).
!unzip /content/drive/MyDrive/image2text/train.zip -d /content/train/
!unzip /content/drive/MyDrive/image2text/train_labels.csv.zip

# Unpack the test images and copy the sample submission next to them.
!unzip /content/drive/MyDrive/image2text/test.zip -d /content/test/
!cp /content/drive/MyDrive/image2text/sample_submission.csv /content/sample_submission.csv

In [7]:
# Fraction of the labelled data used for training; the remainder is validation.
TRAIN_SIZE = 0.9
# CSV files: training labels and the sample submission (list of test image ids).
TRAIN_PATH = '/content/train_labels.csv'
TEST_PATH = '/content/sample_submission.csv'
# Root folders of the unpacked images; get_path() appends the sharded sub-path.
TRAIN_IMAGES = '/content/train'
TEST_IMAGES = '/content/test'
# The sample submission doubles as the test-set index for inference.
submission = pd.read_csv(TEST_PATH)

In [9]:
# Load the labels and split them into train/validation.
df = pd.read_csv(TRAIN_PATH)
# Bucket the InChI string lengths into 10 quantiles so the split can be
# stratified by sequence length.
df['len'] = df['InChI'].apply(len)
df['bucket'] = pd.qcut(df['len'], 10)
df_train, df_val = train_test_split(df, train_size = TRAIN_SIZE, random_state = 42, stratify = df['bucket'])
# Earlier three-way split, currently disabled (no df_test is created):
# df_val, df_test = train_test_split(df_val_test, train_size = 0.33, stratify = df_val_test['bucket'])

In [10]:
# Character-level vocabulary. Ids 0-3 are reserved for the special tokens;
# every other id is assigned in order of first appearance in the training
# InChI strings.
char_dict = {'<PAD>': 0,
             'InChI=1S/': 1,
             '<UNK>': 2,
             '<EOS>': 3}
# Iterate over the column itself: DataFrame.iterrows() materialises a Series
# for every row, which is needlessly slow when only one column is read.
# The visiting order (and therefore every assigned id) is unchanged.
for inchi in tqdm(df_train['InChI']):
  for char in inchi:
    if char not in char_dict:
      char_dict[char] = len(char_dict)

# Reverse lookup (id -> token) used when decoding predictions.
indices_dict = {index: token for token, index in char_dict.items()}

2181767it [02:48, 12972.00it/s]


In [11]:
class Config:
  """Hyper-parameters and vocabulary shared by the model, data and training code."""
  # Padded/truncated token length used by collate_fn and by the decoder's output buffer.
  max_len = 250
  # Batch size passed to the DataLoaders.
  batch_size = 128
  # LSTM hidden size (also the output size of the initial-state projections).
  hidden_size = 300
  # Not referenced anywhere in the visible notebook code.
  cell_size = 300
  # Number of stacked LSTM layers; the initial state is repeated this many times.
  num_layers = 2
  # Character embedding size.
  emb_size = 200
  # Size of the shared projection space inside Attention.
  attention_dim = 300
  # Teacher-forcing ratio passed to the model on the training loader
  # (1 means the ground-truth token is always fed back).
  teacher_forcing = 1
  # Dropout applied to token embeddings (dropout_1) and to the LSTM hidden state (dropout_2).
  dropout_1 = 0.3
  dropout_2 = 0.2
  # Feature size handed to CNN(...) and used as the attention/LSTM image-feature width.
  image_embedding = 512
  # Vocabulary mappings built from the training split (token -> id and id -> token).
  char_dict = char_dict
  indices_dict = indices_dict
  vocab_size = len(char_dict)

# Note: this binds the class itself, not an instance; attributes are read as class attributes.
config = Config

In [12]:
class Attention(nn.Module):
    """Additive attention over encoder feature locations.

    Each location is scored with ``A(tanh(U(features) + W(hidden)))``; the
    scores are soft-maxed over locations and used to take a weighted sum of
    the features.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.attention_dim = attention_dim

        # W projects the decoder state, U projects the encoder features and
        # A turns the combined projection into one score per location.
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        """Return ``(alpha, context)`` for one decoding step.

        Args:
            features: encoder output, (batch, locations, encoder_dim).
            hidden_state: decoder hidden state with a leading layer axis,
                which is averaged away before projecting.

        Returns:
            alpha: attention weights, (batch, locations), summing to 1 per row.
            context: attention-weighted feature sum, (batch, encoder_dim).
        """
        query = self.W(hidden_state.mean(0))                 # (batch, attention_dim)
        keys = self.U(features)                              # (batch, locations, attention_dim)

        energy = torch.tanh(keys + query.unsqueeze(1))       # (batch, locations, attention_dim)
        scores = self.A(energy).squeeze(2)                   # (batch, locations)
        alpha = F.softmax(scores, dim=1)                     # (batch, locations)

        context = (features * alpha.unsqueeze(2)).sum(dim=1) # (batch, encoder_dim)
        return alpha, context


class CNN(nn.Module):
  """Image encoder: a pretrained ResNet-34 with its last two child modules removed.

  ``emb_size`` is only stored on the instance (other modules read it to size
  their layers); it does not change the backbone itself.
  """

  def __init__(self, emb_size = 512):
    super().__init__()
    self.emb_size = emb_size
    backbone = torchvision.models.resnet34(pretrained=True)
    # Keep everything except the final two children of the ResNet.
    self.resnet = nn.Sequential(*list(backbone.children())[:-2])

  def forward(self, x):
    """Return the feature map as a sequence of locations: (batch, H*W, channels)."""
    feature_map = self.resnet(x)                      # (batch, channels, H, W)
    channels_last = feature_map.permute(0, 2, 3, 1)   # (batch, H, W, channels)
    batch, channels = channels_last.size(0), channels_last.size(-1)
    return channels_last.view(batch, -1, channels)


class RNN(nn.Module):
  """Attention LSTM decoder producing one vector of character logits per step.

  At every step the previous token embedding is concatenated with the
  attention context over the image features and fed through the LSTM.
  """

  def __init__(self, config):
    super().__init__()
    # Keep the config on the module so decoding uses the config this decoder
    # was built with instead of a module-level global.
    self.config = config
    self.hidden_size = config.hidden_size
    self.embedding = nn.Embedding(config.vocab_size, config.emb_size)
    self.dropout_emb = nn.Dropout(config.dropout_1)
    self.dropout_hidden = nn.Dropout(config.dropout_2)
    self.rnn = nn.LSTM(input_size=config.emb_size + config.image_embedding,
                       hidden_size=config.hidden_size,
                       num_layers = config.num_layers,
                       batch_first = True)
    self.attention = Attention(config.image_embedding, config.hidden_size, config.attention_dim)
    self.fc_score = nn.Linear(config.hidden_size, config.vocab_size)

  def _decode_step(self, input, features, hidden, cell):
    """Run one attention + LSTM step and return ``(logits, hidden, cell)``."""
    _, attn = self.attention(features, hidden)  # attn: (batch, feature_dim)
    step_input = torch.cat([input.unsqueeze(1), attn.unsqueeze(1)], dim = -1)

    output, (hidden, cell) = self.rnn(step_input, (hidden, cell))
    hidden = self.dropout_hidden(hidden)

    # Drop only the length-1 time axis. A bare squeeze() would also drop the
    # batch axis for a single-sample batch and break the argmax below.
    logits = self.fc_score(output.squeeze(1))
    return logits, hidden, cell

  def forward(self, indices, features, hidden, cell, teacher_forcing_ratio = 0.5):
    """Decode with (optional) teacher forcing.

    Args:
      indices: token ids, (batch, seq); column 0 is the start token.
      features: encoder features, (batch, locations, feature_dim).
      hidden, cell: initial LSTM state, (num_layers, batch, hidden_size).
      teacher_forcing_ratio: per-step probability of feeding the ground-truth
        token instead of the model's own argmax.

    Returns:
      Logits of shape (batch, max_len, vocab_size). Positions 0 and
      max_len - 1 are never written and stay zero.
    """
    max_len, vocab_size = self.config.max_len, self.config.vocab_size
    embedded = self.embedding(indices)
    # Allocate on the inputs' device instead of hard-coding .cuda().
    # No requires_grad here: the in-place writes below attach the buffer to
    # the autograd graph on their own.
    outputs = torch.zeros(embedded.shape[0], max_len, vocab_size, device=embedded.device)
    input = embedded[:, 0]  # (batch, emb_size)

    for t in range(1, max_len-1):
      logits, hidden, cell = self._decode_step(input, features, hidden, cell)
      outputs[:, t] = logits

      teacher_force = random.random() < teacher_forcing_ratio
      next_token = indices[:, t] if teacher_force else logits.max(1)[1]
      input = self.dropout_emb(self.embedding(next_token))

    return outputs

  def generate(self, features, hidden, cell):
    """Greedy decoding from the start token; same output layout as ``forward``."""
    max_len, vocab_size = self.config.max_len, self.config.vocab_size
    batch_size = features.shape[0]

    outputs = torch.zeros(batch_size, max_len, vocab_size, device=features.device)
    start_tokens = torch.full((batch_size,), self.config.char_dict['InChI=1S/'],
                              dtype=torch.long, device=features.device)
    input = self.embedding(start_tokens)

    for t in range(1, max_len-1):
      logits, hidden, cell = self._decode_step(input, features, hidden, cell)
      outputs[:, t] = logits
      input = self.dropout_emb(self.embedding(logits.max(1)[1]))

    return outputs


class Img2Text(nn.Module):
  """Encoder/decoder wrapper that turns an image batch into character logits.

  The decoder's initial LSTM state is a linear projection of the encoder
  features averaged over locations, repeated for every LSTM layer.
  """

  def __init__(self, encoder, decoder, config):
    super().__init__()
    self.config = config
    self.fc_h0 = nn.Linear(encoder.emb_size, decoder.hidden_size)
    self.fc_c0 = nn.Linear(encoder.emb_size, decoder.hidden_size)

    self.encoder = encoder
    self.decoder = decoder

  def _init_state(self, img_vector):
    """Build the initial ``(hidden, cell)`` state, each (num_layers, batch, hidden)."""
    # mean(1) already yields (batch, emb). A further squeeze() would also drop
    # the batch axis for a single-image batch and scramble the state shape.
    pooled = img_vector.mean(1)
    num_layers = self.config.num_layers
    hidden = self.fc_h0(pooled).unsqueeze(0).repeat(num_layers, 1, 1)
    cell = self.fc_c0(pooled).unsqueeze(0).repeat(num_layers, 1, 1)
    return hidden, cell

  def forward(self, image, text, teacher_forcing_ratio = 0.5):
    """Encode ``image`` and decode against ``text`` with teacher forcing."""
    img_vector = self.encoder(image)  # (batch, locations, emb)
    hidden, cell = self._init_state(img_vector)
    return self.decoder(text, img_vector, hidden, cell, teacher_forcing_ratio)

  def generate(self, image):
    """Greedy inference without gradients.

    Switches the encoder and decoder to eval mode (they are left in eval
    mode afterwards).
    """
    self.encoder.eval()
    self.decoder.eval()

    with torch.no_grad():
      img_vector = self.encoder(image)
      hidden, cell = self._init_state(img_vector)
      outputs = self.decoder.generate(img_vector, hidden, cell)
    return outputs

In [13]:
def convert_to_indices(char_dict, string):
  """Encode an InChI string as a list of token ids.

  The result starts with the id of the ``'InChI=1S/'`` start token, followed
  by one id per character of ``string[6:]`` (everything after ``'InChI='``),
  and ends with the ``'<EOS>'`` id. Characters missing from ``char_dict`` map
  to the ``'<UNK>'`` id.
  """
  indices = [char_dict['InChI=1S/']]
  # The original called the undefined name `indices_append` for unknown
  # characters, raising NameError instead of emitting <UNK>.
  indices.extend(char_dict[char] if char in char_dict else char_dict['<UNK>']
                 for char in string[6:])
  indices.append(char_dict['<EOS>'])
  return indices

def convert_to_string(indices_dict, indices):
  """Decode a sequence of token ids back into an InChI string.

  The first id (the start token) is skipped and rendered as the literal
  ``'InChI='`` prefix; decoding stops at the first ``'<EOS>'`` or ``'<PAD>'``.
  Tensors are moved to the CPU and converted to numpy first.
  """
  if isinstance(indices, torch.Tensor):
    indices = indices.detach().cpu().numpy()

  stop_tokens = ('<EOS>', '<PAD>')
  pieces = []
  for indx in indices[1:]:
    token = indices_dict[indx]
    if token in stop_tokens:
      break
    pieces.append(token)
  return 'InChI=' + ''.join(pieces)

def get_path(mode, image_id):
    """Return the PNG path of ``image_id`` inside the train or test image tree.

    Images are sharded into three nested folders named after the first three
    characters of the id. Raises NameError for any mode other than
    ``'train'`` or ``'test'``.
    """
    if mode not in ('train', 'test'):
      raise NameError
    root = TRAIN_IMAGES if mode == 'train' else TEST_IMAGES
    return f'{root}/{image_id[0]}/{image_id[1]}/{image_id[2]}/{image_id}.png'

class Dataset(torch.utils.data.Dataset):
  """Labelled dataset yielding a transformed image and its encoded InChI ids."""

  def __init__(self, df, mode, char_dict, transforms):
    super().__init__()
    self.df, self.mode = df, mode
    self.char_dict, self.transform = char_dict, transforms

  def __len__(self):
    return len(self.df)

  def __getitem__(self, indx):
    """Return ``{'images': transformed image, 'indices': 1-D tensor of token ids}``."""
    record = self.df.iloc[indx]
    path = get_path(self.mode, record['image_id'])
    token_ids = convert_to_indices(self.char_dict, record['InChI'])
    indices = torch.tensor(token_ids)
    image = self.get_image(path)
    return {'images': image, 'indices': indices}

  def get_image(self, image_path):
    """Read the image with OpenCV and run it through the configured transform."""
    return self.transform(cv2.imread(image_path))

class TestDataset(torch.utils.data.Dataset):
  """Unlabelled test-image dataset; portrait images are rotated to landscape."""

  def __init__(self, df, transforms):
    super().__init__()
    self.df = df
    self.transform = transforms

  def __len__(self):
    return len(self.df)

  def __getitem__(self, indx):
    """Return ``{'images': transformed image}`` for the row at position ``indx``."""
    row = self.df.iloc[indx]
    image_path = get_path('test', row['image_id'])
    return {'images': self.get_image(image_path)}

  @staticmethod
  def _to_landscape(img):
    """Rotate an (H, W, C) array by 90 degrees counter-clockwise when H > W.

    This is a transpose followed by a vertical flip, i.e. the same operation
    the original albumentations pipeline (Transpose + VerticalFlip) described.
    It is done with numpy because albumentations (`A`) is never imported in
    this notebook. Other images are returned unchanged.
    """
    height, width = img.shape[:2]
    if height > width:
      img = np.ascontiguousarray(np.swapaxes(img, 0, 1)[::-1])
    return img

  def get_image(self, image_path):
    """Read an image, normalise its orientation and apply the transform.

    Raises:
      FileNotFoundError: if OpenCV cannot read ``image_path`` (cv2.imread
        returns None instead of raising).
    """
    img = cv2.imread(image_path)
    if img is None:
      raise FileNotFoundError(f'Could not read image: {image_path}')
    return self.transform(self._to_landscape(img))

def collate_fn(batch):
  """Stack images and right-pad token ids with 0 up to ``config.max_len``.

  Sequences longer than ``config.max_len`` are truncated.
  Returns ``{'images': (batch, ...) tensor, 'indices': (batch, max_len) long tensor}``.
  """
  images = torch.stack([sample['images'] for sample in batch], dim = 0)

  padded = torch.zeros(len(batch), config.max_len, dtype=torch.long)
  for position, sample in enumerate(batch):
    token_ids = sample['indices'][:config.max_len]
    padded[position, :len(token_ids)] = token_ids

  return {'images': images, 'indices': padded}

In [14]:
def calculate_accuracy(logits, targets, mask):
    """Accuracy of the argmax predictions, restricted to the positions selected by ``mask``."""
    selector = mask.detach().cpu().numpy()
    predicted = logits.argmax(dim = -1).detach().cpu().numpy()
    expected = targets.detach().cpu().numpy()
    return accuracy_score(expected[selector], predicted[selector])

def calculate_f1(logits, targets, mask):
    """Micro-averaged F1 of the argmax predictions at the positions selected by ``mask``."""
    selector = mask.detach().cpu().numpy()
    predicted = logits.argmax(dim = -1).detach().cpu().numpy()
    expected = targets.detach().cpu().numpy()
    return f1_score(expected[selector], predicted[selector], average = 'micro')

def calculate_levenstein(logits, targets, indices_dict):
  """Mean Levenshtein distance between decoded targets and decoded argmax predictions."""
  predictions = torch.argmax(logits, dim = -1)
  references = targets.detach()
  distances = [
      Levenshtein.distance(convert_to_string(indices_dict, references[position]),
                           convert_to_string(indices_dict, predictions[position]))
      for position in range(len(references))
  ]
  return np.array(distances).mean()

def calculate_levenstein_test(true_string, target_string):
  """Per-sample Levenshtein distances between two string sequences.

  The first six characters (the ``'InChI='`` prefix) of every string are
  ignored. Returns a numpy array with one distance per entry of
  ``target_string``.
  """
  return np.array([
      Levenshtein.distance(true_string[position][6:], predicted[6:])
      for position, predicted in enumerate(target_string)
  ])

def predict(test_df, model, test_loader, indices_dict):
    """Greedy-decode every batch of ``test_loader`` into InChI strings.

    Returns a copy of ``test_df`` whose ``'InChI'`` column holds the
    predictions, in loader order. The input frame is not modified.
    Images are moved to the GPU with ``.cuda()``.
    """
    result_df = test_df.copy()
    model.eval()

    decoded = []
    with torch.no_grad():
      for batch in tqdm(test_loader):
        logits = model.generate(batch["images"].cuda())
        token_ids = torch.argmax(logits, dim = -1)
        decoded.extend(convert_to_string(indices_dict, row) for row in token_ids)

    result_df['InChI'] = decoded
    return result_df

In [15]:
class CustomRunner(dl.Runner):
    """Catalyst runner implementing one train/validation step of the captioning model."""

    def _handle_batch(self, batch):
        """Run the model on one batch, record metrics, and optimise on the train loader.

        On the ``'valid'`` loader the model is put in eval mode and decodes
        without teacher forcing; otherwise it is put in train mode and uses
        ``config.teacher_forcing``.
        """
        images, indices = batch["images"], batch["indices"]

        validating = self.loader_key == 'valid'
        if validating:
          self.model.eval()
        else:
          self.model.train()
        forcing_ratio = 0 if validating else config.teacher_forcing
        logits = self.model(images, indices[:, :-1], forcing_ratio)

        # Position 0 is the start token, so both logits and targets are
        # compared from position 1 onwards, flattened over batch and time.
        flat_logits = logits[:, 1:].reshape(-1, logits.shape[-1])
        flat_targets = indices[:, 1:].reshape(-1)

        loss = criterion(flat_logits, flat_targets)
        levenstein = calculate_levenstein(logits, indices, config.indices_dict)
        # torch.where(...)[0] selects the non-zero (non-padding) target positions.
        f1 = calculate_f1(flat_logits, flat_targets, torch.where(flat_targets)[0])

        self.batch_metrics.update({
            "loss": loss,
            "f1_score": f1,
            "levenstein": levenstein,
        })

        if self.is_train_loader:
          loss.backward()
          self.optimizer.step()
          self.optimizer.zero_grad()

In [16]:
from torch.optim.lr_scheduler import OneCycleLR

# Number of passes over the training set; also sizes the OneCycleLR schedule below.
EPOCHS = 3
# Training-time preprocessing: resize to 256x256, random vertical flip, then
# normalise with the mean/std commonly used for torchvision's pretrained models.
# NOTE(review): images are read with cv2.imread (BGR order by OpenCV convention)
# but are declared 'RGB' here - confirm the channel order is intended.
# NOTE(review): a vertical flip mirrors the drawn structure - confirm this keeps
# the InChI label valid for molecules with stereochemistry.
transform = Compose([
    ToPILImage('RGB'),
    Resize((256,256)),
    RandomVerticalFlip(p = 0.5),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

# Evaluation-time preprocessing: same as above without the random flip.
transform_test = Compose([
    ToPILImage('RGB'),
    Resize((256,256)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

# Validation rows come from the labelled training data, hence mode 'train'
# (images are looked up under TRAIN_IMAGES) for both datasets.
train_dataset = Dataset(df_train, 'train', char_dict, transform)
val_dataset = Dataset(df_val, 'train', char_dict, transform_test)
# test_dataset = Dataset(df_test, 'train', char_dict, transform)

train_loader = torch.utils.data.DataLoader(train_dataset, config.batch_size, num_workers = 4, collate_fn=collate_fn, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, config.batch_size, num_workers = 4, collate_fn = collate_fn)
# test_loader = torch.utils.data.DataLoader(test_dataset, config.batch_size, collate_fn=collate_fn)

# Model, optimiser, loss and LR schedule.
encoder = CNN(config.image_embedding)
decoder = RNN(config)
model = Img2Text(encoder, decoder, config)
optimizer = torch.optim.Adam(model.parameters(), weight_decay = 1e-4)
# Index 0 is '<PAD>' in char_dict, so padded positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index = 0)
# One scheduler step per training batch (see the batch-mode SchedulerCallback in the training cell).
scheduler = OneCycleLR(optimizer, 1e-3, total_steps=len(train_loader)*EPOCHS)

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


HBox(children=(FloatProgress(value=0.0, max=87306240.0), HTML(value='')))




In [None]:
# Train with catalyst. Logs are written to Google Drive (see the commented-out
# load below for the expected best-checkpoint path). The best epoch is chosen
# by the lowest "levenstein" metric and reloaded when training ends.
runner = CustomRunner()
runner.train(
    model=model,
    criterion= criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders={'train': train_loader,
             'valid': val_loader
             },
    logdir="/content/drive/MyDrive/image2text/logs",
    num_epochs=EPOCHS,
    verbose=True,
    load_best_on_end=True,
    overfit=False,
    # Step the LR scheduler once per batch, matching OneCycleLR's total_steps.
    callbacks=[SchedulerCallback(reduced_metric = 'loss', mode = 'batch')],
    main_metric="levenstein",
    minimize_metric=True,
)

In [17]:
# Clear tqdm's internal registry of progress-bar instances
# (presumably to get rid of stale bars left over from training).
tqdm._instances.clear()

In [None]:
# Optionally restore the best checkpoint from Drive instead of using the in-memory weights:
# model.load_state_dict(torch.load('/content/drive/MyDrive/image2text/logs/checkpoints/best.pth')['model_state_dict'])
# predict() moves image batches to the GPU, so the model must live there too.
model.cuda()
model.eval()

In [None]:
# Held-out evaluation with greedy decoding.
# The notebook never creates `df_test` / `test_loader` before this point (the
# three-way split is commented out), so evaluate on the validation split.
# `val_loader` is not shuffled, so predictions line up with the rows of `df_val`.
pd_df = predict(df_val, model, val_loader, config.indices_dict)
pd_df['InChI_true'] = df_val['InChI']
pd_df = pd_df.assign(levenstein = lambda x: calculate_levenstein_test(x.InChI_true.values, x.InChI.values))

In [None]:
# Average per-sample Levenshtein distance over the evaluated rows (lower is better).
pd_df.levenstein.mean()

In [77]:
# Inference on the competition test set: the sample submission supplies the
# image ids, predictions replace its 'InChI' column, and the result is saved to Drive.
# Note: this rebinds `test_loader` and `pd_df`.
test_dataset = TestDataset(submission, transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset, config.batch_size)
pd_df = predict(submission, model, test_loader, config.indices_dict)
pd_df.to_csv('/content/drive/MyDrive/image2text/submission.csv', index = False)

100%|██████████| 12626/12626 [2:32:43<00:00,  1.38it/s]
