In [None]:
!pip install catalyst==20.12 python-Levenshtein

In [19]:
import pandas as pd
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, RandomVerticalFlip, RandomHorizontalFlip, ToPILImage
import albumentations as A
from albumentations.pytorch.transforms import ToTensor
import torchvision
import catalyst
import random
from catalyst import dl, utils
from catalyst.callbacks.scheduler import SchedulerCallback
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from torch.optim.lr_scheduler import OneCycleLR, StepLR
import Levenshtein
import pickle
import cv2
from PIL import Image
tqdm.pandas()

In [3]:
# Colab only: mount Google Drive, which holds the zipped data, checkpoints and logs used below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy/unpack the competition data from Drive onto the VM's local disk (much faster to read).
!unzip /content/drive/MyDrive/image2text/train.zip -d /content/train/
# Unpacks into the current working directory; TRAIN_PATH assumes that is /content — confirm.
!unzip /content/drive/MyDrive/image2text/train_labels.csv.zip

!unzip /content/drive/MyDrive/image2text/test.zip -d /content/test/
!cp /content/drive/MyDrive/image2text/sample_submission.csv /content/sample_submission.csv

In [20]:
# Fraction of labelled rows used for training; the remaining 2% is the validation split.
TRAIN_SIZE = 0.98
# Label table with `image_id` and `InChI` columns.
TRAIN_PATH = '/content/train_labels.csv'
# The sample submission doubles as the list of test image ids (read by TestDataset).
TEST_PATH = '/content/sample_submission.csv'
TRAIN_IMAGES = '/content/train/'
TEST_IMAGES = '/content/test/'
submission = pd.read_csv(TEST_PATH)

In [21]:
# Stratify the split by InChI-length decile so train and validation share the same length profile.
df = pd.read_csv(TRAIN_PATH)
df = df.assign(len=df['InChI'].apply(len))
df = df.assign(bucket=pd.qcut(df['len'], 10))
df_train, df_val = train_test_split(
    df,
    train_size=TRAIN_SIZE,
    random_state=42,
    stratify=df['bucket'],
)
# Former three-way split (a later evaluation cell still refers to df_test):
# df_val, df_test = train_test_split(df_val_test, train_size = 0.33, stratify = df_val_test['bucket'])

In [22]:
# Token vocabulary: four special tokens, then every character of the first 100k training
# InChIs numbered by first appearance. Model outputs are decoded through indices_dict, so a
# checkpoint is only valid with the exact ordering it was trained with.
char_dict = {'<PAD>': 0,
             'InChI=1S/': 1,
             '<UNK>': 2,
             '<EOS>': 3}
for inchi in tqdm(df_train['InChI'].iloc[:100000]):
  for ch in inchi:
    # len() is evaluated before insertion, so a new char gets the next free index
    char_dict.setdefault(ch, len(char_dict))

# Inverse mapping: index -> token.
indices_dict = {idx: token for token, idx in char_dict.items()}

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [23]:
# Hyperparameters used by the author of the public notebook this work started from (reference only):
# weight_decay=1e-6
# gradient_accumulation_steps=1
# max_grad_norm=5
# attention_dim=256
# embed_dim=256
# decoder_dim=512
# dropout=0.5
# num_layers = 1
class Config:
  """Hyperparameters and vocabulary shared by the model, the data pipeline and the training loop.

  Used as a plain namespace: `config = Config` below aliases the class itself, and every
  value is read as a class attribute.
  """
  max_len = 250            # token sequence length; collate_fn truncates/pads to this
  batch_size = 128
  hidden_size = 300        # LSTM hidden size
  cell_size = 300          # NOTE(review): not read anywhere in this notebook
  num_layers = 2           # LSTM layers (initial state is repeated per layer)
  emb_size = 300           # token embedding size
  attention_dim = 300      # additive-attention projection size
  teacher_forcing = 1      # training probability of feeding the ground-truth token (1 = always)
  dropout_emb = 0.3        # dropout on embeddings of fed-back tokens
  dropout_hidden = 0.2     # dropout on the LSTM hidden state between steps
  image_embedding = 512    # CNN feature channels; ResNet-34 produces 512
  # Fall back to CPU so the notebook still runs (slowly) on a machine without CUDA.
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  char_dict = char_dict
  indices_dict = indices_dict
  vocab_size = len(char_dict)

config = Config

In [24]:
import numpy as np
from random import shuffle
from torch.utils.data import Sampler
import torch
import math
from tqdm import tqdm


class SequenceLengthSampler(Sampler):
    """Batch sampler that puts samples of similar sequence length into the same batch.

    Samples are assigned to length buckets delimited by `bucket_boundaries`: bucket i
    holds lengths in [boundaries[i-1], boundaries[i]), and the first and last buckets
    are open-ended. Each epoch, every bucket is shuffled and cut into batches of
    `batch_size`, and all batches are then shuffled together. Each yielded item is a
    list of dataset indices, so use it as `DataLoader(batch_sampler=...)`.

    Args:
        data_source: the dataset (kept for reference only).
        bucket_boundaries: ascending length thresholds between buckets.
        batch_size: maximum number of samples per batch.
        drop_last: drop the final, smaller batch of every bucket.
        ind_n_len: iterable of (sample_index, sequence_length) pairs; loaded from
            `lens_path` when None.
        lens_path: pickle file holding `ind_n_len`.
    """

    def __init__(self, data_source,
                bucket_boundaries = [32, 86, 96, 104, 112, 121, 131, 142, 154, 172],
                batch_size=128, drop_last=False,
                ind_n_len=None,
                lens_path='/content/drive/MyDrive/image2text/sampler/sampler_lens'):
        self.data_source = data_source

        if ind_n_len is None:
            # NOTE(review): pickle.load can execute arbitrary code; only load a trusted file.
            with open(lens_path, 'rb') as f:
                ind_n_len = pickle.load(f)
        self.ind_n_len = ind_n_len
        self.bucket_boundaries = bucket_boundaries
        self.batch_size = batch_size
        self.drop_last = drop_last

        boundaries = list(self.bucket_boundaries)
        self.buckets_min = torch.tensor([np.iinfo(np.int32).min] + boundaries)
        self.buckets_max = torch.tensor(boundaries + [np.iinfo(np.int32).max])
        self.boundaries = torch.tensor(boundaries)
        self._boundaries_np = np.asarray(boundaries)

        # Bucket membership never changes, so compute it once instead of every epoch.
        self._buckets = self._group_by_bucket()

    def _group_by_bucket(self):
        """Return {bucket_id: LongTensor of sample indices}, vectorised over all samples."""
        pairs = list(self.ind_n_len)
        if not pairs:
            return {}
        ids = [p for p, _ in pairs]
        lens = np.asarray([seq_len for _, seq_len in pairs])
        # side='right' counts boundaries <= length, which is exactly the bucket index
        bucket_ids = np.searchsorted(self._boundaries_np, lens, side='right').tolist()
        grouped = {}
        for p, b in zip(ids, bucket_ids):
            grouped.setdefault(b, []).append(p)
        return {b: torch.tensor(members) for b, members in grouped.items()}

    def shuffle_tensor(self, t):
        return t[torch.randperm(len(t))]

    def _bucket_batch_count(self, n):
        """Number of batches a bucket of `n` samples contributes."""
        if self.drop_last:
            return n // self.batch_size
        return (n + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        batches = []
        for members in self._buckets.values():
            chunks = torch.split(self.shuffle_tensor(members), self.batch_size, dim=0)
            if self.drop_last and len(chunks[-1]) != self.batch_size:
                chunks = chunks[:-1]
            batches += chunks

        shuffle(batches)  # so batches are not ordered by bucket
        for batch in batches:
            yield batch.numpy().tolist()

    def __len__(self):
        # The sampler yields batches, so its length is the number of batches, not samples.
        return sum(self._bucket_batch_count(len(m)) for m in self._buckets.values())

    def element_to_bucket_id(self, x, seq_length):
        """Bucket index for a sample of length `seq_length` (`x` is unused)."""
        return int(np.searchsorted(self._boundaries_np, seq_length, side='right'))

In [25]:
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention of a decoder state over encoder feature vectors.

    scores = A(tanh(U(features) + W(query))), softmaxed over positions; the context is the
    alpha-weighted sum of the raw features.
    """
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.attention_dim = attention_dim
        # Project decoder state and every encoder position into a shared attention space.
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        # Collapse each combined projection to one unnormalised score.
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        """
        Args:
            features: (batch, positions, encoder_dim) encoder outputs.
            hidden_state: (num_layers, batch, decoder_dim); layers are averaged into one query.
        Returns:
            alpha: (batch, positions) attention weights summing to 1 over positions.
            context: (batch, encoder_dim) weighted sum of `features`.
        """
        keys = self.U(features)                                   # (batch, positions, attention_dim)
        query = self.W(hidden_state.mean(0)).unsqueeze(1)         # (batch, 1, attention_dim)
        scores = self.A(torch.tanh(keys + query)).squeeze(2)      # (batch, positions)
        alpha = F.softmax(scores, dim=1)
        context = (features * alpha.unsqueeze(2)).sum(dim=1)      # (batch, encoder_dim)
        return alpha, context


class CNN(nn.Module):
  """ResNet-34 backbone that turns an image into a sequence of spatial feature vectors.

  Args:
    n_channels: 1 builds a grayscale backbone (random init, new first conv); anything
      else uses the standard 3-channel ResNet-34.
    pretrained: load ImageNet weights (ignored when n_channels == 1).
    emb_size: recorded for downstream layers; NOTE(review): ResNet-34 always emits 512
      channels, so this must be 512 for Img2Text's projections to line up.
  """
  def __init__(self, n_channels = 3, pretrained = True, emb_size = 512):
    super().__init__()
    self.emb_size = emb_size
    if n_channels != 1:
      backbone = torchvision.models.resnet34(pretrained=pretrained)
    else:
      # ImageNet weights expect 3 input channels, so the grayscale variant starts from scratch.
      backbone = torchvision.models.resnet34(pretrained=False)
      backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias = False)

    # Drop the global average pool and classifier head to keep the spatial feature map.
    body = list(backbone.children())[:-2]
    self.resnet = nn.Sequential(*body)

  def forward(self, x):
    """(batch, C, H, W) image -> (batch, H' * W', channels) feature vectors."""
    fmap = self.resnet(x)                     # (batch, channels, H', W'), e.g. 8x8 for 256x256 input
    channels_last = fmap.permute(0, 2, 3, 1)  # (batch, H', W', channels)
    return channels_last.view(channels_last.size(0), -1, channels_last.size(-1))


class RNN(nn.Module):
  """LSTM decoder with additive attention over the CNN feature vectors.

  At every step, the embedding of the previous token is concatenated with an attention
  context vector over `features`. The result is fed to a `config.num_layers`-layer LSTM,
  and the top-layer output is projected to vocabulary logits.

  Args:
    config: hyperparameter namespace. Reads hidden_size, emb_size, image_embedding,
      attention_dim, num_layers, dropout_emb, dropout_hidden, vocab_size, max_len
      and char_dict.
    pretrained: initialise the token embedding from a saved .npy matrix.
  """
  def __init__(self, config, pretrained = False):
    super().__init__()
    # Stored so every method uses the config the model was built with, not a module-level global.
    self.config = config
    self.hidden_size = config.hidden_size
    self.embedding = self.get_embedding(pretrained)
    self.dropout_emb = nn.Dropout(config.dropout_emb)
    self.dropout_hidden = nn.Dropout(config.dropout_hidden)
    self.rnn = nn.LSTM(input_size=config.emb_size + config.image_embedding,
                       hidden_size=config.hidden_size,
                       num_layers = config.num_layers,
                       batch_first = True)
    self.attention = Attention(config.image_embedding, config.hidden_size, config.attention_dim)
    self.fc_score = nn.Linear(config.hidden_size, config.vocab_size)

  def get_embedding(self, pretrained = False):
    """Return the token embedding layer.

    With `pretrained`, weights come from a saved matrix on Drive and remain trainable.
    NOTE(review): its width must equal config.emb_size and it needs at least
    config.vocab_size rows — confirm. Otherwise a random nn.Embedding(vocab_size, emb_size)
    is created.
    """
    if pretrained:
      weights = np.load('/content/drive/MyDrive/image2text/models/transformer_emb_300.npy')
      weights = torch.from_numpy(weights)
      return nn.Embedding.from_pretrained(weights, freeze = False)
    return nn.Embedding(self.config.vocab_size, self.config.emb_size)

  def _decode_step(self, input, features, hidden, cell):
    """One decoder step: (batch, emb) token embedding -> (batch, vocab) logits, new state."""
    alpha, attn = self.attention(features, hidden)  # attn: (batch, image_embedding)
    rnn_input = torch.cat([input.unsqueeze(1), attn.unsqueeze(1)], dim = -1)
    output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))
    hidden = self.dropout_hidden(hidden)
    # Squeeze only the time axis: a bare .squeeze() would also drop the batch axis when batch == 1.
    return self.fc_score(output.squeeze(1)), hidden, cell

  def forward(self, indices, features, hidden, cell, teacher_forcing_ratio = 0.5):
    """Decode with scheduled teacher forcing.

    Args:
      indices: (batch, seq) target token ids whose column 0 is the start token; with
        teacher forcing, columns up to max_len - 2 are read.
      features: (batch, positions, image_embedding) encoder output.
      hidden, cell: initial LSTM state, each (num_layers, batch, hidden_size).
      teacher_forcing_ratio: per-step probability of feeding the ground-truth token
        instead of the model's own argmax prediction.

    Returns:
      (batch, max_len, vocab_size) logits; positions 0 and max_len - 1 stay zero.
    """
    max_len = self.config.max_len
    outputs = torch.zeros(indices.shape[0], max_len, self.config.vocab_size, device = features.device)
    input = self.embedding(indices[:, 0])  # (batch, emb_size)

    for t in range(1, max_len - 1):
      logits, hidden, cell = self._decode_step(input, features, hidden, cell)
      outputs[:, t] = logits

      teacher_force = random.random() < teacher_forcing_ratio
      next_token = indices[:, t] if teacher_force else logits.max(1)[1]
      input = self.dropout_emb(self.embedding(next_token))

    return outputs

  def generate(self, features, hidden, cell):
    """Greedy decoding from the 'InChI=1S/' start token; same output layout as forward()."""
    batch_size = features.shape[0]
    device = features.device
    max_len = self.config.max_len
    outputs = torch.zeros(batch_size, max_len, self.config.vocab_size, device = device)
    start = torch.full((batch_size,), self.config.char_dict['InChI=1S/'], dtype = torch.long, device = device)
    input = self.embedding(start)

    for t in range(1, max_len - 1):
      logits, hidden, cell = self._decode_step(input, features, hidden, cell)
      outputs[:, t] = logits
      input = self.dropout_emb(self.embedding(logits.max(1)[1]))

    return outputs


class Img2Text(nn.Module):
  """Encoder-decoder that maps a molecule image to InChI token logits.

  The encoder's feature vectors are mean-pooled and linearly projected to the decoder's
  initial LSTM hidden and cell states. The same projection is repeated for every layer.
  """
  def __init__(self, encoder, decoder, config):
    super().__init__()
    self.config = config
    self.fc_h0 = nn.Linear(encoder.emb_size, decoder.hidden_size)
    self.fc_c0 = nn.Linear(encoder.emb_size, decoder.hidden_size)

    self.encoder = encoder
    self.decoder = decoder

  def _init_state(self, img_vector):
    """Initial (hidden, cell), each (num_layers, batch, hidden_size), from (batch, positions, emb) features."""
    # No .squeeze() here: with a batch of one it would drop the batch axis and the
    # subsequent reshape would produce a (num_layers, hidden, 1) state.
    pooled = img_vector.mean(1)  # (batch, emb_size)
    hidden = self.fc_h0(pooled).unsqueeze(0).repeat(self.config.num_layers, 1, 1)
    cell = self.fc_c0(pooled).unsqueeze(0).repeat(self.config.num_layers, 1, 1)
    return hidden, cell

  def forward(self, image, text, teacher_forcing_ratio = 0.5):
    """Encode `image` and decode against `text` with the given teacher-forcing ratio."""
    img_vector = self.encoder(image)  # (batch, positions, emb_size)
    hidden, cell = self._init_state(img_vector)
    return self.decoder(text, img_vector, hidden, cell, teacher_forcing_ratio)

  def generate(self, image):
    """Greedy decoding without gradients.

    Note: switches the encoder and decoder to eval mode and leaves them there; call
    .train() again before further training.
    """
    self.encoder.eval()
    self.decoder.eval()

    with torch.no_grad():
      img_vector = self.encoder(image)
      hidden, cell = self._init_state(img_vector)
      outputs = self.decoder.generate(img_vector, hidden, cell)
    return outputs

In [26]:
def convert_to_indices(char_dict, string):
  """Tokenise an InChI string into vocabulary indices.

  The leading 'InChI=1S/' prefix becomes a single token. Each following character maps
  to its index, or to '<UNK>' if unseen, and '<EOS>' is appended.
  """
  prefix = 'InChI=1S/'
  head = [char_dict[prefix]]
  body = [char_dict[ch] if ch in char_dict else char_dict['<UNK>']
          for ch in string[len(prefix):]]
  return head + body + [char_dict['<EOS>']]

def convert_to_string(indices_dict, indices):
  """Decode a sequence of token indices back into an InChI string.

  The first position is skipped and replaced by the fixed 'InChI=1S/' prefix. Decoding
  stops at the first '<EOS>' or '<PAD>'. Accepts a tensor (moved to CPU) or any
  indexable sequence.
  """
  if isinstance(indices, torch.Tensor):
    indices = indices.detach().cpu().numpy()
  stop_tokens = ('<EOS>', '<PAD>')
  pieces = ['InChI=1S/']
  for idx in indices[1:]:
    token = indices_dict[idx]
    if token in stop_tokens:
      break
    pieces.append(token)
  return ''.join(pieces)

def get_path(mode, image_id):
    """Image file path for `image_id`, sharded by its first three characters.

    `mode` selects the root ('train' -> TRAIN_IMAGES, 'test' -> TEST_IMAGES); any
    other value raises NameError.
    """
    if mode not in ('train', 'test'):
      raise NameError
    root = TRAIN_IMAGES if mode == 'train' else TEST_IMAGES
    first, second, third = image_id[0], image_id[1], image_id[2]
    return f'{root}/{first}/{second}/{third}/{image_id}.png'

class Dataset(torch.utils.data.Dataset):
  """Labelled molecule images paired with tokenised InChI strings.

  Args:
    df: frame with `image_id` and `InChI` columns.
    mode: image root passed to get_path ('train' or 'test').
    char_dict: token -> index vocabulary.
    transforms: albumentations pipeline called as transforms(image=...)['image'].
  """
  def __init__(self, df, mode, char_dict, transforms):
    super().__init__()
    self.df = df
    self.mode = mode
    self.char_dict = char_dict
    self.transform = transforms

  def __len__(self):
    return len(self.df)

  def __getitem__(self, indx):
    """Return {'images': transformed image, 'indices': 1-D tensor of token ids}."""
    row = self.df.iloc[indx]
    image_path = get_path(self.mode, row['image_id'])
    indices = torch.tensor(convert_to_indices(self.char_dict, row['InChI']))
    return {'images': self.get_image(image_path),
            'indices': indices}

  def get_image(self, image_path):
    """Read an image with OpenCV and apply the transform pipeline."""
    img = cv2.imread(image_path)
    if img is None:
      # cv2.imread signals a missing/unreadable file by returning None instead of raising.
      raise FileNotFoundError(f'Could not read image: {image_path}')
    return self.transform(image=img)['image']

class TestDataset(torch.utils.data.Dataset):
  """Unlabelled test images; portrait images are rotated to landscape before transforming.

  Args:
    df: frame with an `image_id` column.
    transforms: albumentations pipeline called as transforms(image=...)['image'].
  """
  def __init__(self, df, transforms):
    super().__init__()
    self.df = df
    self.transform = transforms
    # Transpose followed by a vertical flip is a 90-degree rotation. Built once here
    # instead of for every item.
    self.rotate = A.Compose([A.Transpose(p=1),
                             A.VerticalFlip(p=1)
                            ])

  def __len__(self):
    return len(self.df)

  def __getitem__(self, indx):
    row = self.df.iloc[indx]
    image_path = get_path('test', row['image_id'])
    return {'images': self.get_image(image_path)}

  def get_image(self, image_path):
    """Read an image, rotate it if taller than wide, then apply the transform pipeline."""
    img = cv2.imread(image_path)
    if img is None:
      # cv2.imread signals a missing/unreadable file by returning None instead of raising.
      raise FileNotFoundError(f'Could not read image: {image_path}')
    h, w, _ = img.shape
    if h > w:
      img = self.rotate(image=img)['image']
    return self.transform(image=img)['image']

def collate_fn(batch, max_len=None):
  """Stack dataset items into a batch dict.

  Args:
    batch: list of dicts with an `images` tensor (all the same shape) and, for labelled
      data, an `indices` 1-D tensor of token ids.
    max_len: length each token sequence is truncated to / right-padded to with 0 (the
      '<PAD>' index). Defaults to config.max_len.

  Returns:
    {'images': (batch, *image_shape)}, plus 'indices': (batch, max_len) long tensor
    when the items carry token ids.
  """
  images = torch.stack([item['images'] for item in batch], dim = 0)
  if 'indices' not in batch[0]:
    # Unlabelled items (e.g. from TestDataset) have no token ids to pad.
    return {'images': images}

  if max_len is None:
    max_len = config.max_len
  indices = torch.zeros(len(batch), max_len, dtype = torch.long)
  for i, item in enumerate(batch):
    tokens = item['indices'][:max_len]
    indices[i, :len(tokens)] = tokens

  return {'images': images,
          'indices': indices}

In [27]:
def calculate_accuracy(logits, targets, mask):
    """Accuracy of argmax predictions at the positions selected by `mask`.

    Args:
        logits: (N, vocab) scores.
        targets: (N,) true token ids.
        mask: indices (or boolean mask) of the positions to score.
    """
    keep = mask.detach().cpu().numpy()
    y_pred = logits.argmax(dim = -1).detach().cpu().numpy()[keep]
    y_true = targets.detach().cpu().numpy()[keep]
    return accuracy_score(y_true, y_pred)

def calculate_f1(logits, targets, mask):
    """Micro-averaged F1 of argmax predictions at the positions selected by `mask`.

    For single-label multiclass data, micro-F1 equals accuracy.

    Args:
        logits: (N, vocab) scores.
        targets: (N,) true token ids.
        mask: indices (or boolean mask) of the positions to score.
    """
    keep = mask.detach().cpu().numpy()
    y_pred = logits.argmax(dim = -1).detach().cpu().numpy()[keep]
    y_true = targets.detach().cpu().numpy()[keep]
    return f1_score(y_true, y_pred, average = 'micro')

def calculate_levenstein(logits, targets, indices_dict):
  """Mean Levenshtein distance between decoded target and greedy-predicted InChI strings.

  Args:
    logits: (batch, seq, vocab) scores.
    targets: (batch, seq) true token ids aligned with `logits`.
    indices_dict: index -> token mapping used by convert_to_string.
  """
  # Copy both tensors to the host once per batch. Passing device tensors row by row
  # would make convert_to_string do one device->host copy per row.
  predictions = torch.argmax(logits, dim = -1).detach().cpu().numpy()
  references = targets.detach().cpu().numpy()
  distances = [Levenshtein.distance(convert_to_string(indices_dict, ref),
                                    convert_to_string(indices_dict, pred))
               for ref, pred in zip(references, predictions)]
  return np.array(distances).mean()

def calculate_levenstein_test(indices, target_indices):
  """Mean Levenshtein distance between predicted and true token-index tensors.

  Both are decoded with config.indices_dict. The first 6 characters ('InChI=') are
  dropped from each string before comparing.

  Args:
    indices: (batch, seq) predicted token ids (tensor).
    target_indices: (batch, seq) true token ids (tensor).
  """
  predicted_rows = indices.detach().cpu().numpy()
  true_rows = target_indices.cpu().numpy()
  scores = []
  for i, truth in enumerate(true_rows):
    true_str = convert_to_string(config.indices_dict, truth)[6:]
    pred_str = convert_to_string(config.indices_dict, predicted_rows[i])[6:]
    scores.append(Levenshtein.distance(true_str, pred_str))
  return np.array(scores).mean()

def predict(test_df, model, test_loader, indices_dict):
    """Greedy-decode every batch of `test_loader`.

    Returns a copy of `test_df` whose `InChI` column holds the predictions. Assumes the
    loader yields rows in `test_df` order (no shuffling).
    """
    test_df = test_df.copy()
    InChI = []
    # Send inputs to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
      for batch in tqdm(test_loader):
        images = batch["images"].to(device)
        logits = model.generate(images)
        # One device->host copy per batch rather than one per decoded row.
        pred_indices = torch.argmax(logits, dim = -1).cpu().numpy()
        InChI.extend(convert_to_string(indices_dict, row) for row in pred_indices)
    test_df['InChI'] = InChI
    return test_df

In [28]:
class CustomRunner(dl.Runner):
    """Catalyst runner for the image-to-InChI model: loss, metrics and optimisation per batch."""

    def _handle_batch(self, batch):
        images = batch["images"]
        indices = batch["indices"]

        # Non-train loaders decode from the model's own predictions (ratio 0), so their
        # metrics reflect free-running generation. Training uses the configured ratio.
        if self.is_train_loader:
            self.model.train()
            teacher_forcing_ratio = config.teacher_forcing
        else:
            self.model.eval()
            teacher_forcing_ratio = 0
        logits = self.model(images, indices[:, :-1], teacher_forcing_ratio)

        # logits[:, t] is the prediction for indices[:, t]; position 0 (start token) is never predicted.
        output = logits[:, 1:].reshape(-1, logits.shape[-1])
        trg = indices[:, 1:].reshape(-1)

        # Use the criterion handed to runner.train() rather than a notebook global.
        loss = self.criterion(output, trg)
        levenstein = calculate_levenstein(logits, indices, config.indices_dict)

        batch_metrics = {
              "loss": loss,
              # score only non-padding targets (PAD has index 0)
              "f1_score": calculate_f1(output, trg, torch.where(trg)[0]),
              "levenstein": levenstein
              }
        self.batch_metrics.update(batch_metrics)

        if self.is_train_loader:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

In [29]:
# --- Training setup: data pipeline, model, optimiser and schedule ---
EPOCHS = 1
lr = 1e-5
# Used both for the CNN's ImageNet weights and for loading the decoder embedding from .npy.
pretrained = True
# Warm-up fraction (pct_start) for the currently disabled OneCycleLR schedule below.
ratio = 0.1
n_channels = 3
# Train-time preprocessing: resize, random 90-degree rotation with p=0.5, ImageNet mean/std
# normalisation (matching the pretrained ResNet), then albumentations' ToTensor.
transform_train  = A.Compose([
        A.Resize(256, 256),
        A.RandomRotate90(p=.5),
        A.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                        always_apply=True),
        ToTensor()
        ])
# Deterministic preprocessing for validation/test: same as above without augmentation.
transform_test = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                        always_apply=True),
    ToTensor()
    ])

# Validation rows are labelled training images, hence mode 'train' for both datasets.
train_dataset = Dataset(df_train, 'train', char_dict, transform_train)
val_dataset = Dataset(df_val, 'train', char_dict, transform_test)
# test_dataset = Dataset(df_test, 'train', char_dict, transform)
# sampler = SequenceLengthSampler(train_dataset, batch_size =config.batch_size)
# TODO: cache the sampler's length index so it doesn't take ~2 hours to rebuild each time.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           config.batch_size,
                                           num_workers = 4, 
                                           collate_fn=collate_fn,
                                           shuffle = True)
val_loader = torch.utils.data.DataLoader(val_dataset, config.batch_size, num_workers = 4, collate_fn = collate_fn)
# test_loader = torch.utils.data.DataLoader(test_dataset, config.batch_size, collate_fn=collate_fn)

encoder = CNN(n_channels = n_channels, pretrained = pretrained, emb_size = config.image_embedding)
decoder = RNN(config, pretrained)
model = Img2Text(encoder, decoder, config)
optimizer = torch.optim.Adam(model.parameters(), lr = lr, weight_decay = 1e-4)
# Padding positions (index 0) do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index = 0).to(config.device)
# scheduler = OneCycleLR(optimizer, lr, pct_start = ratio, total_steps=len(train_loader)*EPOCHS)
# Stepped after every batch (SchedulerCallback mode='batch' in the training cell), so lr is
# multiplied by 0.99987 per batch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.99987)

In [30]:
# Resume model weights from the last Catalyst checkpoint. map_location="cpu" lets this load
# on a machine without a GPU. load_state_dict then copies the tensors into the model's
# existing parameters on whatever device they live on.
checkpoint = torch.load("/content/drive/MyDrive/image2text/logs/checkpoints/last.pth", map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
# To fully resume training, also restore optimiser/scheduler state and the epoch:
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# # criterion.load_state_dict(checkpoint['criterion_state_dict'])
# model.train()
# EPOCH = checkpoint['epoch']

<All keys matched successfully>

In [None]:
# Train with Catalyst. Checkpoints and logs are written to Drive; with load_best_on_end the
# checkpoint with the lowest validation `levenstein` is loaded back when training ends.
runner = CustomRunner(device = config.device)
runner.train(
    model=model,
    criterion= criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders={
        'train': train_loader,
        'valid': val_loader,
           },
    logdir="/content/drive/MyDrive/image2text/logs",
    num_epochs=EPOCHS,
    verbose=True,
    load_best_on_end=True,
    overfit=False,
    # step the LR scheduler after every batch instead of every epoch
    callbacks=[SchedulerCallback(reduced_metric = 'loss', mode = 'batch')],
    main_metric="levenstein",
    minimize_metric=True
)
# <Try on TPU; V100 ~1 s/it, P100 ~1.4 s/it>
# <Try with a single input channel>

1/1 * Epoch (train): 100% 18561/18561 [5:04:40<00:00,  1.02it/s, f1_score=0.914, levenstein=9.318, loss=0.227, lr=8.954e-07, momentum=0.900]
1/1 * Epoch (valid): 100% 379/379 [02:40<00:00,  2.36it/s, f1_score=0.631, levenstein=22.010, loss=3.303]
[2021-03-25 12:51:24,691] 
1/1 * Epoch 1 (train): f1_score=0.9290 | levenstein=8.7121 | loss=0.1994 | lr=3.773e-06 | momentum=0.9000
1/1 * Epoch 1 (valid): f1_score=0.6585 | levenstein=18.9016 | loss=3.0847
Top best models:
/content/drive/MyDrive/image2text/logs/checkpoints/train.1.pth	18.9016
=> Loading checkpoint /content/drive/MyDrive/image2text/logs/checkpoints/best_full.pth
loaded state checkpoint /content/drive/MyDrive/image2text/logs/checkpoints/best_full.pth (global epoch 1, epoch 1, stage train)


In [None]:
# Presumably a workaround for stale/duplicated progress bars from earlier interrupted loops.
# NOTE(review): `_instances` is a private tqdm attribute and may change between versions.
tqdm._instances.clear()

In [None]:
# Optionally load the best checkpoint explicitly instead of using the in-memory weights:
# model.load_state_dict(torch.load('/content/drive/MyDrive/image2text/logs/checkpoints/best.pth')['model_state_dict'])
# Move to the GPU and switch to inference mode (disables dropout).
model.cuda()
model.eval()

In [None]:
# FIXME(review): `df_test` only exists if the commented-out three-way split in the data-split
# cell is re-enabled, and `test_loader` must be a labelled loader over it (see the
# commented-out loader). As written, this cell fails on a fresh kernel.
pd_df = predict(df_test, model, test_loader, config.indices_dict)
pd_df['InChI_true'] = df_test['InChI']
# Per-row edit distance on the InChI strings themselves. calculate_levenstein_test expects
# index tensors, and assigning its scalar mean would fill every row with the same value.
pd_df['levenstein'] = [Levenshtein.distance(true_inchi, pred_inchi)
                       for true_inchi, pred_inchi in zip(pd_df['InChI_true'], pd_df['InChI'])]

In [None]:
# Mean of the `levenstein` column computed in the previous cell.
pd_df.levenstein.mean()

In [17]:
# Predict the test set and write the Kaggle submission to Drive.
test_dataset = TestDataset(submission, transform_test)
# Load images in parallel worker processes, like the train/val loaders. DataLoader keeps
# batch order with workers and there is no shuffling, so predictions stay aligned with
# the rows of `submission`.
test_loader = torch.utils.data.DataLoader(test_dataset, config.batch_size, num_workers = 4)
pd_df = predict(submission, model, test_loader, config.indices_dict)
pd_df.to_csv('/content/drive/MyDrive/image2text/submission.csv', index = False)

100%|██████████| 12626/12626 [2:31:58<00:00,  1.38it/s]


In [None]:
# Load an alternative saved model for inference.
# NOTE(review): the file name suggests a transformer decoder; strict load_state_dict raises
# if its keys don't match this LSTM model — confirm it was saved from the same architecture.
model.load_state_dict(torch.load('/content/drive/MyDrive/image2text/models/resnet34_transformer.pth')['model_state_dict'])
model.cuda()
model.eval()