In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm.notebook import tqdm

import cv2, os, pickle


import joblib
from multiprocessing import cpu_count

In [2]:
# Mean ratio of first 10K images is 1.73, this image width/height result in a ratio of 1.75
DEBUG = False
IMG_HEIGHT = 256
IMG_WIDTH = 448
VAL_SIZE = int(100) if DEBUG else int(100e3) # 100K validation molecules
CHUNK_SIZE = 40000 # to get ~100MB TFRecords

MAX_INCHI_LEN = 200 # maximum InChI length to prevent to much padding

In [3]:
if DEBUG:
    train = pd.read_csv('/kaggle/input/bms-molecular-translation/train_labels.csv', dtype={ 'image_id': 'string', 'InChI': 'string' }).head(int(1e3))
else:
    train = pd.read_csv('/kaggle/input/bms-molecular-translation/train_labels.csv', dtype={ 'image_id': 'string', 'InChI': 'string' })

# Drop all InChI longer than MAX_INCHI_LEN - 2,  <start>InChI <end>, remove 'InChI=1S/' at start
train['InChI_len'] = train['InChI'].apply(len).astype(np.uint16)
train = train.loc[train['InChI_len'] <= MAX_INCHI_LEN - 2 + 9].reset_index(drop=True)
train.head(3)

Unnamed: 0,image_id,InChI,InChI_len
0,000011a64c74,InChI=1S/C13H20OS/c1-9(2)8-15-13-6-5-10(3)7-12...,81
1,000019cc0cd2,InChI=1S/C21H30O4/c1-12(22)25-14-6-8-20(2)13(1...,155
2,0000252b6d2b,InChI=1S/C24H23N5O4/c1-14-13-15(7-8-17(14)28-1...,158


In [4]:
if DEBUG:
    test = pd.read_csv('/kaggle/input/bms-molecular-translation/sample_submission.csv', usecols=['image_id'], dtype={ 'image_id': 'string' }).head(int(1e3))
else:
    test = pd.read_csv('/kaggle/input/bms-molecular-translation/sample_submission.csv', usecols=['image_id'], dtype={ 'image_id': 'string' })
test.head(3)

Unnamed: 0,image_id
0,00000d2a601c
1,00001f7fc849
2,000037687605


In [5]:
def get_vocabulary():
    tokens = ['<start>', '<end>', '<pad>']
    vocabulary = set()
    for s in tqdm(train['InChI'].values):
        vocabulary.update(s)
    return tokens + list(vocabulary)

vocabulary = get_vocabulary()

  0%|          | 0/2371184 [00:00<?, ?it/s]

In [6]:
# Save vocabulary mappings
# character -> integer
vocabulary_to_int = dict(zip(vocabulary, np.arange(len(vocabulary), dtype=np.int8)))
with open('vocabulary_to_int.pkl', 'wb') as handle:
    pickle.dump(vocabulary_to_int, handle)

#  integer -> character
int_to_vocabulary = dict(zip(np.arange(len(vocabulary), dtype=np.int8), vocabulary))
with open('int_to_vocabulary.pkl', 'wb') as handle:
    pickle.dump(int_to_vocabulary, handle)

In [7]:
train['InChIClean'] = train['InChI'].apply(lambda InChI: '/'.join(InChI.split('=')[1].split('/')[1:]))
train.head()

Unnamed: 0,image_id,InChI,InChI_len,InChIClean
0,000011a64c74,InChI=1S/C13H20OS/c1-9(2)8-15-13-6-5-10(3)7-12...,81,C13H20OS/c1-9(2)8-15-13-6-5-10(3)7-12(13)11(4)...
1,000019cc0cd2,InChI=1S/C21H30O4/c1-12(22)25-14-6-8-20(2)13(1...,155,C21H30O4/c1-12(22)25-14-6-8-20(2)13(10-14)11-1...
2,0000252b6d2b,InChI=1S/C24H23N5O4/c1-14-13-15(7-8-17(14)28-1...,158,C24H23N5O4/c1-14-13-15(7-8-17(14)28-12-10-20(2...
3,000026b49b7e,InChI=1S/C17H24N2O4S/c1-12(20)18-13(14-7-6-10-...,147,C17H24N2O4S/c1-12(20)18-13(14-7-6-10-24-14)11-...
4,000026fc6c36,InChI=1S/C10H19N3O2S/c1-15-10(14)12-8-4-6-13(7...,96,C10H19N3O2S/c1-15-10(14)12-8-4-6-13(7-8)5-2-3-...


In [8]:
# convert the InChI strings to integer lists
# start/end/pad tokens are used
def inchi_str2int(InChI):
    res = []
    res.append(vocabulary_to_int.get('<start>'))
    for c in InChI:
        res.append(vocabulary_to_int.get(c))
    
    res.append(vocabulary_to_int.get('<end>'))
    while len(res) < MAX_INCHI_LEN: 
        res.append(vocabulary_to_int.get('<pad>'))
        
    return np.array(res, dtype=np.uint8)

tqdm.pandas() # progress_apply
train['InChI_int'] = train['InChIClean'].progress_apply(inchi_str2int)
train.head()

  0%|          | 0/2371184 [00:00<?, ?it/s]

Unnamed: 0,image_id,InChI,InChI_len,InChIClean,InChI_int
0,000011a64c74,InChI=1S/C13H20OS/c1-9(2)8-15-13-6-5-10(3)7-12...,81,C13H20OS/c1-9(2)8-15-13-6-5-10(3)7-12(13)11(4)...,"[0, 38, 26, 36, 17, 40, 10, 37, 8, 16, 33, 26,..."
1,000019cc0cd2,InChI=1S/C21H30O4/c1-12(22)25-14-6-8-20(2)13(1...,155,C21H30O4/c1-12(22)25-14-6-8-20(2)13(10-14)11-1...,"[0, 38, 40, 26, 17, 36, 10, 37, 12, 16, 33, 26..."
2,0000252b6d2b,InChI=1S/C24H23N5O4/c1-14-13-15(7-8-17(14)28-1...,158,C24H23N5O4/c1-14-13-15(7-8-17(14)28-12-10-20(2...,"[0, 38, 40, 12, 17, 40, 36, 15, 3, 37, 12, 16,..."
3,000026b49b7e,InChI=1S/C17H24N2O4S/c1-12(20)18-13(14-7-6-10-...,147,C17H24N2O4S/c1-12(20)18-13(14-7-6-10-24-14)11-...,"[0, 38, 26, 24, 17, 40, 12, 15, 40, 37, 12, 8,..."
4,000026fc6c36,InChI=1S/C10H19N3O2S/c1-15-10(14)12-8-4-6-13(7...,96,C10H19N3O2S/c1-15-10(14)12-8-4-6-13(7-8)5-2-3-...,"[0, 38, 26, 10, 17, 26, 19, 15, 36, 37, 40, 8,..."


In [9]:
val = train.iloc[-VAL_SIZE:].reset_index(drop=True)
train = train.iloc[:-VAL_SIZE].reset_index(drop=True)
N_IMGS = len(train)

In [11]:
def remove_blobs(img, min_size=10, debug=False):
    if debug:
        fig, ax = plt.subplots(1,2, figsize=(30,8))
        ax[0].imshow(img)
        ax[0].set_title('original image', size=16)
    
    height, width = img.shape

    # find all the connected components (white blobs in your image)
    nb_components, output, stats, centroids = cv2.connectedComponentsWithStats(img, connectivity=8)
    # Removes background, which is seen as a big component
    sizes = stats[1:, -1]
  
    blob_idxs = []    
    for idx, s in enumerate(sizes):
        if s < min_size:
            blob_idxs.append(idx+1)
    
    img[np.isin(output, blob_idxs)] = 0
    
    if debug:
        ax[1].imshow(img)
        ax[1].set_title('image with removed blobs', size=16)
        plt.show()
    
    return img

In [12]:
def crop(img, debug=False):
    if debug:
        fig, ax = plt.subplots(1,2, figsize=(30,8))
        ax[0].imshow(img)
        ax[0].set_title(f'original image, shape: {img.shape}', size=16)
        
    _, thresh = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(thresh,cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE)[-2:]
    
    x_min, y_min, x_max, y_max = np.inf, np.inf, 0, 0
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        x_min = min(x_min, x)
        y_min = min(y_min, y)
        x_max = max(x_max, x + w)
        y_max = max(y_max, y + h)

    img_cropped = img[y_min:y_max, x_min:x_max]
    
    if debug:
        ax[1].imshow(img_cropped)
        ax[1].set_title(f'cropped image, shape: {img_cropped.shape}', size=16)
        plt.show()
    
    return img_cropped

In [13]:
def pad_kernel(kernel, max_pad=np.inf):
    kernel = np.array(kernel)
    h, w = kernel.shape
    pad_h = min((max(h, w) - h) // 2, max_pad)
    pad_w = min((max(h, w) - w) // 2, max_pad)
    return np.pad(kernel, ([pad_h, pad_h], [pad_w, pad_w]), 'constant', constant_values=-1)

# creates a mask of missing pixels to be filled using
def create_mask(kernel, img_b):
    mask = cv2.filter2D(img_b, -1, kernel)
    kernel_flat_sum = (kernel == a).flatten().sum()
    threshold_min = kernel_flat_sum * threshold_ratio
    threshold_max = kernel_flat_sum + 1
    return (mask > threshold_min) & (mask < threshold_max)

# make kernels
a = np.float32(1.0 / 255.0)
threshold_ratio = 0.50
# single pixel width horizontal line with 1 pixel missing
kernel_h_single_mono = pad_kernel([
    [ a, a,  a, -1,  a,  a, a ]
], max_pad=1)
# single pixel width horizontal line with 3 pixels missing
kernel_h_single_triple = pad_kernel([
    [ a, a, a, -1, -1, -1, a, a, a ]
], max_pad=1)

kernel_h_multi = pad_kernel([
    [ a, a, a, a, a, a, a ],
    [ a, a, a,-1, a, a, a ],
    [ a, a, a, a, a, a, a ],
], max_pad=1)

kernel_v_single = pad_kernel([
    [ a],
    [ a],
    [ a],
    [-1],
    [ a],
    [ a],
    [ a],
], max_pad=1)

kernel_v_multi = pad_kernel([
    [ a, a, a ],
    [ a, a, a ],
    [ a, a, a ],
    [ a,-1, a ],
    [ a, a, a ],
    [ a, a, a ],
    [ a, a, a ],
], max_pad=1)

kernel_lr_single = pad_kernel([
    [ -1,-1,-1,-1, a ],
    [ -1,-1,-1, a,-1 ],
    [ -1,-1,-1,-1,-1 ],
    [ -1, a,-1,-1,-1 ],
    [  a,-1,-1,-1,-1 ],
])

kernel_lr_multi = pad_kernel([
    [ -1,-1,-1, a, a ],
    [ -1,-1, a, a, a ],
    [ -1, a,-1, a,-1 ],
    [  a, a, a,-1,-1 ],
    [  a, a,-1,-1,-1 ],
])

kernel_rl_single = pad_kernel([
    [  a,-1,-1,-1,-1 ],
    [ -1, a,-1,-1,-1 ],
    [ -1,-1,-1,-1,-1 ],
    [ -1,-1,-1, a,-1 ],
    [ -1,-1,-1,-1, a ],
])

kernel_rl_multi = pad_kernel([
    [ a, a,-1,-1,-1],
    [ a, a, a,-1,-1],
    [-1, a,-1, a,-1],
    [-1,-1, a, a, a],
    [-1,-1,-1, a, a],
])

def fill_missing_pixels(img, debug):
    img_b = img.astype(np.float32)
    img_b[img_b > 0] = 255

    mask_h_single_mono = create_mask(kernel_h_single_mono, img_b)

    mask_h_single_triple = create_mask(kernel_h_single_triple, img_b)

    mask_h_single = mask_h_single_mono | mask_h_single_triple

    mask_h_multi = create_mask(kernel_h_multi, img_b)


    mask_v_single = create_mask(kernel_v_single, img_b)


    mask_v_multi = create_mask(kernel_v_multi, img_b)


    mask_lr_single = create_mask(kernel_lr_single, img_b)


    mask_lr_multi = create_mask(kernel_lr_multi, img_b)


    mask_rl_single = create_mask(kernel_lr_single, img_b)


    mask_rl_multi = create_mask(kernel_rl_multi, img_b)

    mask_single = mask_h_single | mask_v_single | mask_lr_single | mask_rl_single
    mask_multi = mask_h_multi  | mask_v_multi |mask_lr_multi | mask_rl_multi
    mask = mask_single | mask_multi

    if debug:
        fig, ax = plt.subplots(2, 2 ,figsize=(35,20))
        ax[0,0].imshow(mask_h_single)
        ax[0,0].set_title('mask_h_single', size=16)
        ax[0,1].imshow(mask_v_single)
        ax[0,1].set_title('mask_v_single', size=16)
        ax[1,0].imshow(mask_lr_single)
        ax[1,0].set_title('mask_lr_single', size=16)
        ax[1,1].imshow(mask_lr_single)
        ax[1,1].set_title('mask_lr_single', size=16)
        plt.show()

        fig, ax = plt.subplots(2, 2, figsize=(35,20))
        ax[0,0].imshow(mask_h_multi)
        ax[0,0].set_title('mask_h_multi', size=16)
        ax[0,1].imshow(mask_v_multi)
        ax[0,1].set_title('mask_v_multi', size=16)
        ax[1,0].imshow(mask_lr_multi)
        ax[1,0].set_title('mask_lr_multi', size=16)
        ax[1,1].imshow(mask_rl_multi)
        ax[1,1].set_title('mask_rl_multi', size=16)
        plt.show()

        fig, ax = plt.subplots(2, 1 ,figsize=(15,20))
        ax[0].imshow(img)
        ax[0].set_title('original image', size=16)
        
        img_fill = mask.copy()
        img_fill[img_fill > 0] = 255

        img_rgb = np.stack([
            img_fill,
            img_b,
            np.zeros(img.shape),
        ], axis=2)

        ax[1].imshow(img_rgb)
        ax[1].set_title('image with filled missing pixels (red)', size=16)
        plt.show()    

    # all pixels in the mask are filled up
    img[mask] = 255

    return img


In [14]:
def pad_resize(img):
    h, w = img.shape
    s = max(w, h)
    pad_h, pad_v = 0, 0
    hw_ratio = (h / w) - (IMG_HEIGHT / IMG_WIDTH)
    if hw_ratio < 0:
        pad_h = int(abs(hw_ratio) * w / 2)
    else:
        wh_ratio = (w / h) - (IMG_WIDTH / IMG_HEIGHT)
        pad_v = int(abs(wh_ratio) * h // 2)

    img = np.pad(img, [(pad_h, pad_h), (pad_v, pad_v)], mode='constant')
    img = cv2.resize(img,(IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)

    return img

In [15]:
def process_img(image_id, folder='train', debug=False):
    # read image and invert colors to get black background and white molecule
    file_path =  f'/kaggle/input/bms-molecular-translation/{folder}/{image_id[0]}/{image_id[1]}/{image_id[2]}/{image_id}.png'
    img0 = 255 - cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
    
    # rotate counter clockwise to get horizontal images
    h, w = img0.shape
    if h > w:
        img0 = np.rot90(img0)
    
    # remove blobs, crop, fill missing pixels, pad and resize
    img = remove_blobs(img0, min_size=2, debug=debug)
    img = crop(img, debug=debug)
    img = fill_missing_pixels(img, debug=debug)
    img = pad_resize(img)
    
    if debug:
        fig, ax = plt.subplots(1, 2, figsize=(20,10))
        ax[0].imshow(img0)
        ax[0].set_title('Original image', size=16)
        ax[1].imshow(img)
        ax[1].set_title('Fully processed image', size=16)
    
    # normalize to range 0-255 and encode as png
    img = (img / img.max() * 255).astype(np.uint8)
    img = cv2.imencode('.png', img)[1].tobytes()

    return img

In [16]:
def split_in_chunks(data):
    return [data[i:i + CHUNK_SIZE] for i in range(0, len(data), CHUNK_SIZE)]

train_data_chunks = {
    'train': {
        'image_id': split_in_chunks(train['image_id'].values),
        'InChI': split_in_chunks(train['InChI_int'].values),
    },
    'val': {
        'image_id': split_in_chunks(val['image_id'].values),
        'InChI': split_in_chunks(val['InChI_int'].values),
    }
}

test_data_chunks = {
    'test': {
        'image_id': split_in_chunks(test['image_id'].values),
    }
}

In [17]:
def make_tfrecords(data_chunks, folder='train'):
    # Try to make output folder
    try:
        os.makedirs(f'./train')
        os.makedirs(f'./val')
        os.makedirs(f'./test')
    except:
        print(f'folders already created')

    for k, v in data_chunks.items():
        for chunk_idx, image_id_chunk in tqdm(enumerate(v['image_id']), total=len(v['image_id'])):
            # process images in parallel
            jobs = [joblib.delayed(process_img)(fp, folder) for fp in image_id_chunk]
            bs = 10
            processed_images_chunk = joblib.Parallel(
                n_jobs=cpu_count(),
                verbose=0,
                require='sharedmem',
                batch_size=bs,
                backend='threading',
            )(jobs)

            # Create the TFRecords from the processed images
            with tf.io.TFRecordWriter(f'./{k}/batch_{chunk_idx}.tfrecords') as file_writer:
                if 'InChI' in v.keys(): # TRAIN/VAL, InChI included
                    for image, InChI in zip(processed_images_chunk, v['InChI'][chunk_idx]):
                        record_bytes = tf.train.Example(features=tf.train.Features(feature={
                            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                            'InChI': tf.train.Feature(int64_list=tf.train.Int64List(value=InChI)),
                        })).SerializeToString()
                        file_writer.write(record_bytes)
                else: # TEST, image_id included for submission file
                    for image, image_id in zip(processed_images_chunk, image_id_chunk):
                        record_bytes = tf.train.Example(features=tf.train.Features(feature={
                            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                            'image_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[str.encode(image_id)])),
                        })).SerializeToString()
                        file_writer.write(record_bytes)

make_tfrecords(train_data_chunks)
make_tfrecords(test_data_chunks, 'test')

  0%|          | 0/57 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

folders already created


  0%|          | 0/41 [00:00<?, ?it/s]

In [18]:
# convert in int encoded InChI to string
def inchi_int2char(InChI):
    res = []
    for i in InChI:
        c = int_to_vocabulary.get(i)
        if c not in ['<start>', '<end>', '<pad>']:
            res.append(c)
    return ''.join(res)

In [19]:
# Check train TFRecords
def decode_tfrecord(record_bytes):
    fea_dict= {
        'image': tf.io.FixedLenFeature([], tf.string),
        'InChI': tf.io.FixedLenFeature([MAX_INCHI_LEN], tf.int64),}
    
    features = tf.io.parse_single_example(record_bytes, fea_dict)

    image = tf.io.decode_jpeg(features['image'])    
    image = tf.reshape(image, [IMG_HEIGHT, IMG_WIDTH, 1])
    image = tf.cast(image, tf.float32)  / 255.0
    
    InChI = features['InChI']
    InChI = tf.reshape(InChI, [MAX_INCHI_LEN])
    
    return image, InChI


# Check test TFRecords
def decode_test_tfrecord(record_bytes):
    features = tf.io.parse_single_example(record_bytes, {
        'image': tf.io.FixedLenFeature([], tf.string),
        'image_id': tf.io.FixedLenFeature([], tf.string),
    })

    image = tf.io.decode_jpeg(features['image'])
    image = tf.reshape(image, [IMG_HEIGHT, IMG_WIDTH, 1])
    image = tf.cast(image, tf.float32)  / 255.0
    
    image_id = features['image_id']
    
    return image, image_id
