In [1]:
import numpy as np
import pandas as pd
import gc
import re
import cv2
import os
import time
import math
import random
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau


import torchvision.models as models
from torchvision import transforms

from tqdm.auto import tqdm
tqdm.pandas()

from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, Blur
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import timm
import Levenshtein
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

In [2]:
PATH = './data/'
TEST_DIR = PATH + 'test'

OUTPUT_DIR = './'
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

class CFG : 
    debug = False
    max_len = 275
    print_freq = 1000
    num_workers = 0
    model_name = 'resnet34'
    size = 224
    scheduler='CosineAnnealingLR'
    epochs = 1
    T_max = 4 
    encoder_lr = 1e-4
    decoder_lr = 4e-4
    min_lr = 1e-6
    batch_size = 16
    weight_decay = 1e-6
    gradient_accumulation_steps = 1
    max_grad_norm = 5
    attention_dim = 256
    embed_dim = 256
    decoder_dim = 512
    dropout = 0.5
    seed = 0
    n_fold = 5
    trn_fold = [0]
    train = True

In [3]:
if CFG.debug:
    CFG.epochs = 1
    train = train.sample(n=1000, random_state=CFG.seed).reset_index(drop=True)

In [4]:
if torch.cuda.is_available() : 
    DEVICE = torch.device('cuda')
else : 
    DEVICE = torch.device('cpu')
    
print(torch.__version__, DEVICE)

1.8.0 cuda


In [5]:
# Out of Memory 해결 법
gc.collect()
torch.cuda.empty_cache()

In [6]:
# Code From https://www.kaggle.com/yasufuminakama/inchi-resnet-lstm-with-attention-starter

class Tokenizer(object):
    
    def __init__(self):
        self.stoi = {}
        self.itos = {}

    def __len__(self):
        return len(self.stoi)
    
    def fit_on_texts(self, texts):
        vocab = set()
        for text in texts:
            vocab.update(text.split(' '))
        vocab = sorted(vocab)
        vocab.append('<sos>')
        vocab.append('<eos>')
        vocab.append('<pad>')
        for i, s in enumerate(vocab):
            self.stoi[s] = i
        self.itos = {item[1]: item[0] for item in self.stoi.items()}
        
    def text_to_sequence(self, text):
        sequence = []
        sequence.append(self.stoi['<sos>'])
        for s in text.split(' '):
            sequence.append(self.stoi[s])
        sequence.append(self.stoi['<eos>'])
        return sequence
    
    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            sequence = self.text_to_sequence(text)
            sequences.append(sequence)
        return sequences

    def sequence_to_text(self, sequence):
        return ''.join(list(map(lambda i: self.itos[i], sequence)))
    
    def sequences_to_texts(self, sequences):
        texts = []
        for sequence in sequences:
            text = self.sequence_to_text(sequence)
            texts.append(text)
        return texts
    
    def predict_caption(self, sequence):
        caption = ''
        for i in sequence:
            if i == self.stoi['<eos>'] or i == self.stoi['<pad>']:
                break
            caption += self.itos[i]
        return caption
    
    def predict_captions(self, sequences):
        captions = []
        for sequence in sequences:
            caption = self.predict_caption(sequence)
            captions.append(caption)
        return captions

tokenizer = torch.load('tokenizer2.pth')
print(f"tokenizer.stoi: {tokenizer.stoi}")

tokenizer.stoi: {'(': 0, ')': 1, '+': 2, ',': 3, '-': 4, '/b': 5, '/c': 6, '/h': 7, '/i': 8, '/m': 9, '/s': 10, '/t': 11, '0': 12, '1': 13, '10': 14, '100': 15, '101': 16, '102': 17, '103': 18, '104': 19, '105': 20, '106': 21, '107': 22, '108': 23, '109': 24, '11': 25, '110': 26, '111': 27, '112': 28, '113': 29, '114': 30, '115': 31, '116': 32, '117': 33, '118': 34, '119': 35, '12': 36, '120': 37, '121': 38, '122': 39, '123': 40, '124': 41, '125': 42, '126': 43, '127': 44, '128': 45, '129': 46, '13': 47, '130': 48, '131': 49, '132': 50, '133': 51, '134': 52, '135': 53, '136': 54, '137': 55, '138': 56, '139': 57, '14': 58, '140': 59, '141': 60, '142': 61, '143': 62, '144': 63, '145': 64, '146': 65, '147': 66, '148': 67, '149': 68, '15': 69, '150': 70, '151': 71, '152': 72, '153': 73, '154': 74, '155': 75, '156': 76, '157': 77, '158': 78, '159': 79, '16': 80, '161': 81, '163': 82, '165': 83, '167': 84, '17': 85, '18': 86, '19': 87, '2': 88, '20': 89, '21': 90, '22': 91, '23': 92, '24': 9

In [7]:
class TestDataset(Dataset) : 
    def __init__(self, df, transform = None) : 
        super().__init__()
        self.df = df
        self.transform = transform
        
    def __len__(self) : 
        return len(self.df)
    
    def __getitem__(self, idx) : 
        path = self.df.path.iloc[idx]
        image = cv2.imread(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        
        augmented = self.transform(image = image)
        image = augmented['image']
        return image

In [8]:
Transform = Compose([
    Resize(CFG.size, CFG.size),
    Normalize(
        mean = [0.485, 0.456, 0.406],
        std = [0.229, 0.224, 0.225],
    ),
    ToTensorV2(),
])

In [9]:
class Encoder(nn.Module) : 
    def __init__(self, model_name = 'resnet34', pretrained = False) :
        super().__init__()
        self.cnn = timm.create_model(model_name, pretrained = pretrained)
        self.n_features = self.cnn.fc.in_features
        self.cnn.global_pool = nn.Identity()
        self.cnn.fc = nn.Identity()
        
    def forward(self, x) : 
        bs = x.size(0)
        features = self.cnn(x)
        features = features.permute(0, 2, 3, 1)
        return features

In [10]:
class Attention(nn.Module) : 
    def __init__(self, encoder_dim, decoder_dim, attention_dim) : 
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim = 1)
        
    def forward(self, encoder_out, decoder_hidden) : 
        att1 = self.encoder_att(encoder_out)
        att2 = self.decoder_att(decoder_hidden)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)
        alpha = self.softmax(att)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim = 1)
        return attention_weighted_encoding, alpha

In [11]:
class DecoderWithAttention(nn.Module) : 
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, device, encoder_dim = 512, dropout = 0.5) : 
        super(DecoderWithAttention, self).__init__()
        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.device = device
        
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(p = self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias = True)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)
        self.init_weights()
        
    def init_weights(self) : 
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)
        
    def load_pretrained_embeddings(self, embeddings) : 
        self.embedding.weight = nn.Parameter(embeddings)
        
    def fine_tune_embeddings(self, fine_tune = True) : 
        for p in self.embedding.parameters() : 
            p.requires_grad = fine_tune
            
    def init_hidden_state(self, encoder_out) : 
        mean_encoder_out = encoder_out.mean(dim = 1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c
    
    def forward(self, encoder_out, encoded_captions, caption_lengths) : 
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)
        num_pixels = encoder_out.size(1)
        
        caption_length, sort_ind = caption_lengths.squeeze(1).sort(dim = 0, descending = True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]
        embeddings = self.embedding(encoded_captions)
        
        h, c = self.init_hidden_state(encoder_out)
        
        decode_lengths = (caption_length - 1).tolist()
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(self.device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(self.device)
        
        for t in range(max(decode_lengths)) : 
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim = 1),
                                   (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(self.dropout(h))
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
            
        return predictions, encoded_captions, decode_lengths, alphas, sort_ind
    
    def predict(self, encoder_out, decode_lengths, tokenizer) : 
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)
        num_pixels = encoder_out.size(1)
        
        start_tockens = torch.ones(batch_size, dtype = torch.long).to(self.device) * tokenizer.stoi["<sos>"]
        embeddings = self.embedding(start_tockens)
        h, c = self.init_hidden_state(encoder_out)
        predictions = torch.zeros(batch_size, decode_lengths, vocab_size).to(self.device)
        
        for t in range(decode_lengths) : 
            attention_weighted_encoding, alpha = self.attention(encoder_out, h)
            gate = self.sigmoid(self.f_beta(h))
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(torch.cat([embeddings, attention_weighted_encoding], dim = 1),
                                   (h, c))
            preds = self.fc(self.dropout(h))
            predictions[:, t, :] = preds
            
            if np.argmax(preds.detach().cpu().numpy() == tokenizer.stoi["<eos>"]) : 
                break
                
            embeddings = self.embedding(torch.argmax(preds, -1))
        return predictions    

In [12]:
def get_score(y_true, y_pred) : 
    scores = []
    
    for true, pred in zip(y_true, y_pred) : 
        score = Levenshtein.distance(true, pred)
        scores.append(score)
        
    avg_score = np.mean(scores)
    return avg_score

In [13]:
def init_logger(log_file=OUTPUT_DIR+'train.log'):
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = init_logger()

In [14]:
def seed_torch(seed = 0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(seed=CFG.seed)

In [15]:
class AverageMeter(object) : 
    def __init__(self) : 
        self.reset()
        
    def reset(self) : 
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        
    def update(self, val, n = 1) : 
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [16]:
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

In [17]:
def bms_collate(batch):
    imgs, labels, label_lengths = [], [], []
    for data_point in batch:
        imgs.append(data_point[0])
        labels.append(data_point[1])
        label_lengths.append(data_point[2])
    labels = pad_sequence(labels, batch_first=True, padding_value=tokenizer.stoi["<pad>"])
    return torch.stack(imgs), labels, torch.stack(label_lengths).reshape(-1, 1)

In [18]:
test = pd.read_csv('./data/sample_submission.csv')
print(f'test.shape: {test.shape}')

test.shape: (1616107, 2)


In [19]:
test.head()

Unnamed: 0,image_id,InChI
0,00000d2a601c,InChI=1S/H2O/h1H2
1,00001f7fc849,InChI=1S/H2O/h1H2
2,000037687605,InChI=1S/H2O/h1H2
3,00004b6d55b6,InChI=1S/H2O/h1H2
4,00004df0fe53,InChI=1S/H2O/h1H2


In [20]:
tmp_test =test.copy()
tmp_test.head()

Unnamed: 0,image_id,InChI
0,00000d2a601c,InChI=1S/H2O/h1H2
1,00001f7fc849,InChI=1S/H2O/h1H2
2,000037687605,InChI=1S/H2O/h1H2
3,00004b6d55b6,InChI=1S/H2O/h1H2
4,00004df0fe53,InChI=1S/H2O/h1H2


In [21]:
def get_test_path(img_name) : 
    return f"{TEST_DIR}/{img_name[0]}/{img_name[1]}/{img_name[2]}/{img_name}.png"

tmp_test['path'] = tmp_test['image_id'].apply(get_test_path)

In [22]:
tmp_test.head()

Unnamed: 0,image_id,InChI,path
0,00000d2a601c,InChI=1S/H2O/h1H2,./data/test/0/0/0/00000d2a601c.png
1,00001f7fc849,InChI=1S/H2O/h1H2,./data/test/0/0/0/00001f7fc849.png
2,000037687605,InChI=1S/H2O/h1H2,./data/test/0/0/0/000037687605.png
3,00004b6d55b6,InChI=1S/H2O/h1H2,./data/test/0/0/0/00004b6d55b6.png
4,00004df0fe53,InChI=1S/H2O/h1H2,./data/test/0/0/0/00004df0fe53.png


In [23]:
def valid_fn(valid_loader, encoder, decoder, tokenizer, criterion, device) : 
    batch_time = AverageMeter()
    data_time = AverageMeter()
    
    encoder.eval()
    decoder.eval()
    text_preds = []
    start = end = time.time()
    
    for step, (images) in enumerate(valid_loader) : 
        data_time.update(time.time() - end)
        images = images.to(device)
        batch_size = images.size(0)
        
        with torch.no_grad() : 
            features = encoder(images)
            predictions = decoder.predict(features, CFG.max_len, tokenizer)
        
        predicted_sequence = torch.argmax(predictions.detach().cpu(), -1).numpy()
        _text_preds = tokenizer.predict_captions(predicted_sequence)
        text_preds.append(_text_preds)
        
        batch_time.update(time.time() - end)
        end = time.time()
        
        if step % CFG.print_freq == 0 or step == (len(valid_loader) - 1) :
            print('EVAL: [{0}/{1}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  .format(
                   step, len(valid_loader), batch_time=batch_time,
                   data_time=data_time,
                   remain=timeSince(start, float(step+1)/len(valid_loader)),
                   ))
            
    text_preds = np.concatenate(text_preds)
    return text_preds

In [24]:
def test_fn() : 
    test_dataset = TestDataset(tmp_test, Transform)
    test_loader = DataLoader(test_dataset,
                             batch_size = CFG.batch_size,
                             shuffle = False,
                             num_workers = CFG.num_workers,
                             pin_memory = True,
                             drop_last = False)
    
    encoder = Encoder(CFG.model_name, pretrained = True)
    encoder.to(DEVICE)
    encoder_optimizer = Adam(encoder.parameters(), lr = CFG.encoder_lr,
                             weight_decay = CFG.weight_decay, amsgrad = False)
    
    decoder = DecoderWithAttention(attention_dim = CFG.attention_dim,
                                  embed_dim = CFG.embed_dim,
                                  decoder_dim = CFG.decoder_dim,
                                  vocab_size = len(tokenizer),
                                  dropout = CFG.dropout,
                                  device = DEVICE)
    decoder.to(DEVICE)
    decoder_optimizer = Adam(decoder.parameters(), lr = CFG.decoder_lr,
                            weight_decay = CFG.weight_decay, amsgrad = False)
    
    check_point = torch.load('./resnet34_fold0_best.pth')
    
    encoder.load_state_dict(check_point['encoder'])
    encoder_optimizer.load_state_dict(check_point['encoder_optimizer'])
    decoder.load_state_dict(check_point['decoder'])
    decoder_optimizer.load_state_dict(check_point['decoder_optimizer'])
    
    encoder.eval()
    decoder.eval()
    
    criterion = nn.CrossEntropyLoss(ignore_index = tokenizer.stoi["<pad>"])
    
    text_preds = valid_fn(test_loader, encoder, decoder, tokenizer, criterion, DEVICE)
    text_preds = [f"InChI=1S/{text}" for text in text_preds]
    
    return text_preds

In [25]:
text_preds = test_fn()

EVAL: [0/101007] Data 0.110 (0.110) Elapsed 0m 3s (remain 5256m 43s) 
EVAL: [1000/101007] Data 0.083 (0.086) Elapsed 5m 39s (remain 564m 52s) 
EVAL: [2000/101007] Data 0.086 (0.091) Elapsed 11m 54s (remain 589m 26s) 
EVAL: [3000/101007] Data 0.094 (0.091) Elapsed 18m 1s (remain 588m 23s) 
EVAL: [4000/101007] Data 0.092 (0.092) Elapsed 24m 12s (remain 586m 47s) 
EVAL: [5000/101007] Data 0.085 (0.092) Elapsed 30m 24s (remain 583m 47s) 
EVAL: [6000/101007] Data 0.081 (0.093) Elapsed 36m 47s (remain 582m 23s) 
EVAL: [7000/101007] Data 0.084 (0.094) Elapsed 43m 22s (remain 582m 28s) 
EVAL: [8000/101007] Data 0.100 (0.094) Elapsed 49m 55s (remain 580m 20s) 
EVAL: [9000/101007] Data 0.106 (0.094) Elapsed 56m 12s (remain 574m 33s) 
EVAL: [10000/101007] Data 0.097 (0.094) Elapsed 62m 20s (remain 567m 21s) 
EVAL: [11000/101007] Data 0.087 (0.094) Elapsed 68m 23s (remain 559m 36s) 
EVAL: [12000/101007] Data 0.092 (0.093) Elapsed 74m 26s (remain 552m 5s) 
EVAL: [13000/101007] Data 0.103 (0.093) El

In [26]:
for i in tqdm(range(test.shape[0])) :
    test['InChI'][i] = text_preds[i]

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
100000
101000
102000
103000
104000
105000
106000
107000
108000
109000
110000
111000
112000
113000
114000
115000
116000
117000
118000
119000
120000
121000
122000
123000
124000
125000
126000
127000
128000
129000
130000
131000
132000
133000
134000
135000
136000
137000
138000
139000
140000
141000
142000
143000
144000
145000
146000
147000
148000
149000
150000
151000
152000
153000
154000
155000
156000
157000
158000


1164000
1165000
1166000
1167000
1168000
1169000
1170000
1171000
1172000
1173000
1174000
1175000
1176000
1177000
1178000
1179000
1180000
1181000
1182000
1183000
1184000
1185000
1186000
1187000
1188000
1189000
1190000
1191000
1192000
1193000
1194000
1195000
1196000
1197000
1198000
1199000
1200000
1201000
1202000
1203000
1204000
1205000
1206000
1207000
1208000
1209000
1210000
1211000
1212000
1213000
1214000
1215000
1216000
1217000
1218000
1219000
1220000
1221000
1222000
1223000
1224000
1225000
1226000
1227000
1228000
1229000
1230000
1231000
1232000
1233000
1234000
1235000
1236000
1237000
1238000
1239000
1240000
1241000
1242000
1243000
1244000
1245000
1246000
1247000
1248000
1249000
1250000
1251000
1252000
1253000
1254000
1255000
1256000
1257000
1258000
1259000
1260000
1261000
1262000
1263000
1264000
1265000
1266000
1267000
1268000
1269000
1270000
1271000
1272000
1273000
1274000
1275000
1276000
1277000
1278000
1279000
1280000
1281000
1282000
1283000
1284000
1285000
1286000
1287000
1288000


In [27]:
test.head()

Unnamed: 0,image_id,InChI
0,00000d2a601c,InChI=1S/C10H14BrN5S/c1-6-10(11)9(16(3)15-6)4-...
1,00001f7fc849,InChI=1S/C15H18ClN3/c16-12-5-1-10(2-6-12)7-14-...
2,000037687605,InChI=1S/C16H13BrN2O/c1-11(20)12-6-7-14(10-18)...
3,00004b6d55b6,"InChI=1S/C14H19FN4O/c1-14(2,3)12-13(16)17-18-1..."
4,00004df0fe53,InChI=1S/C9H12O2/c10-3-1-2-4-5(6)7(11-9)3-8(4)...


In [29]:
test.to_csv('./submission.csv', index = False)