## Try these two experiments first
1. Create original input image to inchi string translator and get accuracy and loss
2. Create inchi image to inchi string translator and get accuracy and loss (Use [AutoEncoders](https://www.youtube.com/watch?v=E28CVTbNoSA&ab_channel=PascalPoupart))

Now, compare the two. If the accuracy of inchi image -> inchi string is significantly higher than the original image -> inchi string then think about ***reconstruction experiments*** from original image to inchi image. [try this then](https://www.google.com/search?channel=fs&client=ubuntu&q=converting+shapes+from+one+to+another+using+deep+learning).

# InChI decoding 
[Source](https://link.springer.com/content/pdf/10.1186%2Fs13321-015-0068-4.pdf)

1. **Skeletal connections layer** This layer prefixed with `/c` represents connections between skeletal atoms by listing the canonical numbers in the chain of connected atoms. 
2. ***branches are given in parentheses***
3. The canonical atomic numbers, which are used throughout the InChI, are always given in the formula’s element order, i.e. atoms are numbered element by element following the order in which the elements appear in the formula layer (Hill order: carbon first, then the remaining non-hydrogen elements alphabetically). For example, `/C10H16N5O13P3` (the beginning of the InChI for adenosine triphosphate) implies that atoms numbered 1–10 are carbons, 11–15 are nitrogens, 16–28 are oxygens, and 29–31 are phosphorus atoms. Hydrogen atoms are not explicitly numbered.


## image to inchi string

In [1]:
# Special Packages
# !pip install PeriodicElements
# !pip install albumentations
# !pip install timm
# !pip install python-Levenshtein
# !pip install torchmetrics

In [2]:
%load_ext tensorboard
%load_ext autoreload
%autoreload 2

In [3]:
import torch, torchmetrics, timm, re, pickle, Levenshtein
import torch.nn as nn
import torchvision as tv
import pytorch_lightning as pl
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
from pathlib import Path
from functools import partial
from collections import defaultdict
from fastprogress import progress_bar
from typing import Optional, Union, Tuple
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from elements import elements
from albumentations.pytorch import ToTensorV2
from preprocessing import preprocess_image


# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
# Seed Python's `random`, NumPy and torch together. `torch.manual_seed` alone
# leaves NumPy unseeded, yet sklearn's `train_test_split` (without
# `random_state`) and the tokenizer sampling test draw from NumPy's global RNG.
pl.seed_everything(manualSeed)

# This monkey-patch is there to be able to plot tensors. Only install it when
# the running torch lacks a native `Tensor.ndim`, so the built-in attribute is
# not shadowed by a slower Python property.
if not hasattr(torch.Tensor, "ndim"):
    torch.Tensor.ndim = property(lambda x: len(x.shape))

Random Seed:  999


In [4]:
# --- Checkpoints, cached artefacts and logging ---------------------------
CHKPTDIR = Path("TranslationFULLChkpts")
CHKPTDIR.mkdir(parents=True, exist_ok=True)
VOCAB_FILEPATH = CHKPTDIR / "vocab_dict.pt"            # saved Vocab state
TRAINPATHS_PATH = CHKPTDIR / "train_paths.feather"     # cached train image paths
TESTPATHS_PATH = CHKPTDIR / "test_paths.feather"       # cached test image paths
tb_logger = pl.loggers.TensorBoardLogger(CHKPTDIR, name="InchINet")

# --- Raw data locations --------------------------------------------------
DATADIR = "data/bms-molecular-translation"
LABELS_CSV_PATH = "data/train_labels.csv"

# --- InChI layer handling ------------------------------------------------
# Order in which layers are emitted by `dissect_inchi` / rebuilt by `decode`.
LAYERS_SEQ = (
    'main_layer', 'c_layer', 'h_layer', 'b_layer',
    't_layer', 'm_layer', 's_layer', 'i_layer',
)
NULL_TOKEN = "99999"  # placeholder for empty / absent layers

# --- Data loading and training hyperparameters ---------------------------
N_WORKERS = 4
BATCH_SIZE = 256
PRECISION = 16
EMB_SIZE = 512
INP_SIZE = (128, 128)
LR = 1e-2
EPOCHS = 2
beta1 = 0.5

In [5]:
!ls {DATADIR}

sample_submission.csv  test  train  train_labels.csv


# Data Block

## Vocab and Tokenizer

#### *Fix this at end:* in some samples the /t and /m layers are repeated after the /i layer. For now, I merged the repeated part into the first occurrence (comma-joined). Either separate them in post-processing or add t2 and m2 layers.

In [6]:
def dissect_inchi(inchi_string:str) -> Union[dict, str]:
    """Split a standard InChI string into its layers, keyed by layer name.

    The formula (the first piece after ``InChI=1S``) is stored under
    ``main_layer``. Every other ``/x...`` piece is stored under ``x_layer``
    with its prefix letter removed. A piece that is only a prefix letter is
    stored as "99999". When a prefix occurs more than once, its contents are
    comma-joined onto the first occurrence.

    Returns:
        '' for an empty input string. Otherwise a dict with exactly the keys of
        LAYERS_SEQ, in that order, using "99999" for layers that are absent.

    Raises:
        AssertionError: if the string does not start with ``InChI=1S`` or the
            layer bookkeeping check fails.
    """
    if len(inchi_string) == 0:
        return ''

    parts = inchi_string.split("/")
    # One layer is expected per "/"-separated piece, minus the "InChI=1S" header.
    expected_layers = len(parts) - 1
    assert parts[0] == 'InChI=1S', "Error in `dissect_inchi` function, string must start with `InChI=1S`"
    layers = {"main_layer": parts[1]}
    remaining = parts[2:]

    for piece in remaining:
        if not piece:
            # Empty piece (e.g. from "//"): not a layer at all.
            expected_layers -= 1
            continue
        name = f"{piece[0]}_layer"
        # A bare prefix such as "/i" carries no content -> null token.
        content = piece[1:] or "99999"
        if name in layers:
            # Repeated prefix: merge into the first occurrence.
            layers[name] = layers[name] + "," + content
            expected_layers -= 1
        else:
            layers[name] = content

    assert expected_layers == len(layers.keys()), \
    f"""Error in `dissect_inchi` function. String is not fully analysed.
        Expected {expected_layers} layers but got {len(layers.keys())} layers.
        {remaining}
        """

    # Emit every known layer in canonical order, padding absent ones with the null token.
    return {name: layers.get(name, '99999') for name in LAYERS_SEQ}
    
# idx = 1
# print(dm.df.InChI[1][:-2])
# dissect_inchi(dm.df.InChI[1][:-2])

In [7]:
class Vocab:
    """Per-layer token vocabularies for InChI strings.

    Attributes:
        vocab_dict: layer name -> list of tokens; a token's list index is its id.
        max_lengths: layer name -> longest token sequence of that layer, plus 2
            (room for <bos>/<eos>) when special tokens are added.
        patterns_dict: regexes used by `main_layer_tokenizer` ("main_layer")
            and `slash_layer_tokenizer` ("slash_layer").
    """

    # Splitting on (captured) runs of digits keeps multi-digit numbers as one
    # token and, once empty pieces are dropped, every other character as its own token.
    SLASH_LAYER_PATTERN = r"([\d]*)"

    def __init__(self, vocab_dict:dict[list], max_lengths:dict[int], patterns_dict:dict[str], add_special_tokens:bool=True) -> None:
        # Shallow copies: the caller's dicts must not be modified in place.
        self.vocab_dict = dict(vocab_dict)
        self.max_lengths = dict(max_lengths)
        self.patterns_dict = patterns_dict
        self.ctoi_dict = {}
        if add_special_tokens:
            self.pad_token, self.unk_token, self.bos_token, self.eos_token = "<pad>", "<unk>", "<bos>", "<eos>"
            self.null_token = "99999" # for empty sublayers i.e. layers having prefixes only. e.g. /t, /i
            specials = [self.pad_token, self.null_token, self.unk_token, self.bos_token, self.eos_token]
            # Layers that already contain the special tokens (e.g. a vocab reloaded
            # through `from_file`) were adjusted before being saved. Adjusting them
            # again would grow max_lengths by 2 on every save/load cycle.
            already_adjusted = {k for k, v in self.vocab_dict.items() if self.pad_token in v}
            self.max_lengths = {k: (v if k in already_adjusted else v + 2) for k, v in self.max_lengths.items()}
            for k, v in self.vocab_dict.items():
                if k not in already_adjusted:
                    self.vocab_dict[k] = specials + v
                # Unknown tokens map to the <unk> id instead of raising KeyError.
                self.ctoi_dict[k] = defaultdict(self.handle_unk_char, {c:i for i, c in enumerate(self.vocab_dict[k])})
        else:
            # No <unk> token exists, so unknown tokens raise KeyError.
            for k, v in self.vocab_dict.items():
                self.ctoi_dict[k] = {c:i for i, c in enumerate(v)}

    def handle_unk_char(self):
        """Default factory for `ctoi_dict`: the id of the <unk> token."""
        return self.vocab_dict["main_layer"].index(self.unk_token)

    def main_layer_tokenizer(self, string:str) -> list[str]:
        """Split a formula (main layer) into element(+count) tokens."""
        tokens = re.split(self.patterns_dict["main_layer"], string)
        return list(filter(None, tokens))

    def slash_layer_tokenizer(self, string:str) -> list[str]:
        """Split a non-main layer into number and single-character tokens."""
        tokens = re.split(self.patterns_dict["slash_layer"], string)
        return list(filter(None, tokens))

    def ctoi(self, c:str, layer_name:str) -> int:
        """Token -> id within `layer_name`."""
        return self.ctoi_dict[layer_name][c]

    def itoc(self, i:int, layer_name:str) -> str:
        """Id -> token within `layer_name`."""
        return self.vocab_dict[layer_name][i]

    @property
    def vocab_size(self) -> dict:
        """Layer name -> number of tokens in that layer's vocabulary."""
        return {k:len(v) for k,v in self.vocab_dict.items()}

    def save_vocab(self, path:str) -> None:
        """Persist vocab, (already adjusted) max lengths and patterns with `torch.save`."""
        torch.save({"vocab_dict": self.vocab_dict, "max_lengths": self.max_lengths, "patterns_dict": self.patterns_dict}, path)
        print("Saved @", path)

    @classmethod
    def from_file(cls, path:str, add_special_tokens:bool=True) -> object:
        """Rebuild a Vocab saved by `save_vocab` (uses `torch.load`, i.e. pickle: trusted files only)."""
        vocab = torch.load(path)
        vocab_dict, max_lengths, patterns_dict = vocab["vocab_dict"], vocab["max_lengths"], vocab["patterns_dict"]
        return cls(vocab_dict, max_lengths, patterns_dict, add_special_tokens)

    @classmethod
    def from_dataframe_column(cls, inchi_column:pd.Series, add_special_tokens:bool=True, verbose=True) -> object:
        """Build a Vocab from a column of InChI strings."""
        corpus_dict = cls.get_inchi_corpus(inchi_column)
        return cls.from_corpus(corpus_dict, add_special_tokens, verbose)

    @classmethod
    def from_corpus(cls, corpus_dict:dict[str], add_special_tokens:bool=True, verbose=True) -> object:
        """Build a Vocab from a layer-wise corpus as returned by `get_inchi_corpus`."""
        names_list = list(corpus_dict.keys())
        names_list.remove("main_layer")
        vocab_dict, max_lengths = {}, {}
        if verbose: print("Creating main layer vocab...", end=' ')
        vocab_dict["main_layer"], max_lengths["main_layer"], main_layer_pattern = cls.create_main_layer_vocab(corpus_dict["main_layer"], True)
        if verbose: print("done!")

        # Default so a corpus with only a main layer still yields a complete patterns_dict.
        slash_layer_pattern = cls.SLASH_LAYER_PATTERN
        for layer_name in names_list:
            if verbose: print(f"Creating {layer_name} vocab...", end=' ')
            vocab_dict[layer_name], max_lengths[layer_name], slash_layer_pattern = cls.create_slash_layer_vocab(corpus_dict[layer_name], True)
            if verbose: print("done!")

        patterns_dict = {"main_layer": main_layer_pattern, "slash_layer": slash_layer_pattern}
        return cls(vocab_dict, max_lengths, patterns_dict, add_special_tokens)

    @staticmethod
    def get_inchi_corpus(inchi_column:pd.Series) -> dict[str]:
        """Dissect every InChI string and collect the layer strings per layer name.

        Strings that fail `dissect_inchi`'s assertions, and empty strings, are
        skipped and counted in a printed warning.
        """
        corpus, n_failed, n_empty = {}, 0, 0
        for string in progress_bar(inchi_column):
            try:
                dissected = dissect_inchi(string)
            except AssertionError:
                n_failed += 1
                continue
            if not dissected:  # dissect_inchi returns '' for an empty string
                n_empty += 1
                continue
            for k, v in dissected.items():
                corpus.setdefault(k, []).append(v)
        if n_failed > 0:
            print(f"[Warning] Found {n_failed} assertions in data while creating layer-wise corpus.")
        if n_empty > 0:
            print(f"[Warning] Skipped {n_empty} empty InChI strings while creating layer-wise corpus.")
        return corpus

    @staticmethod
    def create_main_layer_vocab(main_layer_corpus:list[str], return_pattern:bool=False) -> Union[Tuple[list, int], Tuple[list, int, str]]:
        """Build the formula-layer vocab using a regex over all element symbols."""
        def tokenize(string):
            tokens = re.split(pattern, string)
            return list(filter(None, tokens))

        data = elements.Elements
        # All elements in periodic table
        elems = sorted(data, key=lambda i:i.AtomicNumber)  # Based on their AtomicNumber
        # Longer symbols first so e.g. "Cl" is matched before "C".
        element_symbols = sorted([e.Symbol for e in elems], key=lambda e: len(e), reverse=True)
        pattern = f"({'|'.join([f'{e}[0-9]*' for e in element_symbols])})"
        vocab = [tokenize(string) for string in main_layer_corpus]
        max_length = max(map(len, vocab))
        vocab = np.unique([w for string in vocab for w in string]).tolist()
        if return_pattern:
            return vocab, max_length, pattern
        return vocab, max_length

    @staticmethod
    def create_slash_layer_vocab(layer_corpus:list[str], return_pattern:bool=False) -> Union[Tuple[list, int], Tuple[list, int, str]]:
        """Build a non-main-layer vocab (sorted unique tokens) and its max token count."""
        pattern = Vocab.SLASH_LAYER_PATTERN

        def tokenize(string):
            tokens = re.split(pattern, string)
            return list(filter(None, tokens))

        vocab = [tokenize(string) for string in layer_corpus]
        max_length = max(map(len, vocab))
        vocab = np.unique([w for string in vocab for w in string]).tolist()
        if return_pattern:
            return vocab, max_length, pattern
        return vocab, max_length

class Tokenizer:
    """Encode InChI strings into per-layer id sequences and decode them back.

    Tokenization and token<->id mapping are delegated to a `Vocab` instance.
    """

    def __init__(self, vocab:"Vocab"=None) -> None:
        self.vocab = vocab # Vocab class instance

    def tokenize(self, x:Union[str, dict], layer_name:str=None) -> Union[list, dict[list]]:
        """Tokenize a single layer string, or every layer of a dissected dict.

        For a string, `layer_name` selects the vocab method `<layer_name>_tokenizer`
        (e.g. "main_layer" or "slash_layer"). A dict is tokenized in place and returned.

        Raises:
            AttributeError: if `x` is a string and `layer_name` is None, or no
                such tokenizer method exists.
            ValueError: if `x` is neither a string nor a dict.
        """
        if isinstance(x, str):
            if layer_name is None:
                raise AttributeError("`layer_name` is required when `x` is a string")
            # getattr performs the same attribute lookup as the former eval()
            # without executing caller-supplied text as code.
            return getattr(self.vocab, f"{layer_name}_tokenizer")(x)
        if isinstance(x, dict):
            for k, v in x.items():
                if k == "main_layer":
                    x[k] = self.vocab.main_layer_tokenizer(v)
                else:
                    x[k] = self.vocab.slash_layer_tokenizer(v)
            return x
        raise ValueError("`x` must be either string or dict(layer_name=string)")

    def encode(self, inchi_string:str, pad_to_max_len:bool=True) -> dict[dict[list]]:
        """Encode an InChI string into per-layer id lists and attention masks.

        Returns:
            {"inp_seq_dict": {layer: ids}, "attn_mask_dict": {layer: 0/1 mask}}.
            With `pad_to_max_len`, sequences are right-padded with the pad id to
            the vocab's per-layer max length. Sequences already at or beyond that
            length are left unpadded.
        """
        dissected = dissect_inchi(inchi_string)
        tokens_dict = self.tokenize(dissected)
        if hasattr(self.vocab, "bos_token"): # Add start and end special tokens
            tokens_dict = {k: [self.vocab.bos_token] + v + [self.vocab.eos_token] for k, v in tokens_dict.items()}
        inp_seq_dict = {k: [self.vocab.ctoi(c, k) for c in v] for k, v in tokens_dict.items()}
        attn_mask_dict = {k: [1] * len(v) for k, v in inp_seq_dict.items()}

        if pad_to_max_len:
            for k, v in inp_seq_dict.items():
                extra_len = self.vocab.max_lengths[k] - len(v)
                inp_seq_dict[k] = v + [self.vocab.ctoi(self.vocab.pad_token, k)] * extra_len
                attn_mask_dict[k] = attn_mask_dict[k] + [0] * extra_len
        return {"inp_seq_dict": inp_seq_dict, "attn_mask_dict": attn_mask_dict}

    def decode(self, inp_seq_dict:dict[list]) -> str:
        """Rebuild an InChI string from per-layer id sequences.

        For each layer in LAYERS_SEQ order, tokens between the first <bos> and
        the next <eos>/<pad> are joined. A missing <bos> means "start at 0"; a
        missing <eos>/<pad> means "up to the end", so raw model output decodes
        without raising. Non-main layers that decode to the null token or to
        nothing are omitted: `dissect_inchi` fills absent layers with the null
        token, and emitting them would produce spurious "/x99999" layers.
        """
        vocab = self.vocab
        bos = getattr(vocab, "bos_token", None)
        null = getattr(vocab, "null_token", None)
        stop_tokens = {getattr(vocab, "eos_token", None), getattr(vocab, "pad_token", None)} - {None}

        decoded_string = 'InChI=1S/'
        for k in LAYERS_SEQ:
            if k not in inp_seq_dict:
                continue
            tokens = [vocab.itoc(t, k) for t in inp_seq_dict[k]]
            start = tokens.index(bos) + 1 if bos in tokens else 0
            end = next((i for i in range(start, len(tokens)) if tokens[i] in stop_tokens), len(tokens))
            substring = ''.join(tokens[start:end])
            if k == 'main_layer':
                decoded_string += substring
            elif substring and substring != null:
                decoded_string += f"/{k[0]}{substring}"
        return decoded_string

    @classmethod
    def from_file(cls, path:str, add_special_tokens:bool=True) -> object:
        """Build a Tokenizer around a Vocab saved with `Vocab.save_vocab`."""
        vocab = torch.load(path)
        vocab_dict, max_lengths, patterns_dict = vocab["vocab_dict"], vocab["max_lengths"], vocab["patterns_dict"]
        vocab_obj = Vocab(vocab_dict, max_lengths, patterns_dict, add_special_tokens)
        return cls(vocab_obj)

In [8]:
def vocab_unit_tests(vocab:Vocab, corpus_dict:dict[list], 
                     max_lengths:dict[int], add_special_tokens:bool=True) -> str:
    """Check each layer's stored max length against the longest tokenized string.

    The main layer is tokenized with `main_layer_tokenizer`; every other layer
    with `slash_layer_tokenizer`. `corpus_dict` is not modified.

    Args:
        max_lengths: per-layer max lengths as stored on the vocab. They include
            +2 for <bos>/<eos> when `add_special_tokens` is True.

    Returns:
        "ALL GOOD!" when every layer matches.

    Raises:
        AssertionError: for the first layer whose longest token sequence differs
            from the expected max length.
    """
    offset = 2 if add_special_tokens else 0

    def longest(strings, tokenizer_fn):
        return max((len(tokenizer_fn(s)) for s in progress_bar(strings)), default=0)

    print("Checking main string length...", end=' ')
    if longest(corpus_dict['main_layer'], vocab.main_layer_tokenizer) != max_lengths["main_layer"] - offset:
        raise AssertionError("Main string test failed :(")
    print("done!")

    for layer_name, strings in corpus_dict.items():
        if layer_name == "main_layer":
            continue
        print(f"Checking {layer_name} string length...", end=' ')
        if longest(strings, vocab.slash_layer_tokenizer) != max_lengths[layer_name] - offset:
            raise AssertionError(f"{layer_name} string test failed :(")
        print("done!")
    return "ALL GOOD!"

def tokenizer_unit_test(tokenizer:Tokenizer, df:pd.DataFrame, sample_size:int=10000) -> str:
    """Round-trip randomly sampled InChI strings through encode -> decode.

    Rows are sampled with replacement, by position, from the "InChI" column.

    Returns:
        "ALL GOOD!" when every sampled string decodes back to itself.

    Raises:
        ValueError: if `df` is empty.
        AssertionError: if any sampled string does not survive the round trip.
            The message gives the mismatch count and the first mismatching pair.
    """
    if len(df) == 0:
        raise ValueError("`df` is empty; nothing to test.")
    print("Running tokenizer test...")
    inchis = df["InChI"]
    mismatches = []
    for idx in progress_bar(np.random.randint(0, len(df), size=sample_size)):
        # idx is a position in [0, len(df)), so index positionally rather than by label.
        inp_string = inchis.iloc[idx]
        inp_seq_dict = tokenizer.encode(inp_string)['inp_seq_dict']
        out_string = tokenizer.decode(inp_seq_dict)
        if out_string != inp_string:
            mismatches.append((inp_string, out_string))
    if mismatches:
        first_in, first_out = mismatches[0]
        raise AssertionError(
            f"{len(mismatches)}/{sample_size} sampled InChI strings changed after encode/decode. "
            f"First mismatch:\n  input:  {first_in}\n  output: {first_out}"
        )
    return "ALL GOOD!"

# Unit Testing
# test_corpus_dict = Vocab.get_inchi_corpus(df.InChI)
# test_vocab = Vocab.from_corpus(test_corpus_dict)
# test_tokenizer = Tokenizer(test_vocab)
# tokenizer_unit_test(test_tokenizer, df)
# vocab_unit_tests(test_vocab, test_corpus_dict, test_vocab.max_lengths, add_special_tokens=True)

## Lightning Data Module
### LightningDataModule API

To define a DataModule define 5 methods:
1. prepare_data (how to download(), tokenize, etc…)
2. setup (how to split, etc…)
3. train_dataloader
4. val_dataloader(s)
5. test_dataloader(s)

#### prepare_data
Use this method to do things that might write to disk or that need to be done only from a single process in distributed settings.
1. download
2. tokenize
3. etc…

#### setup
There are also data operations you might want to perform on every GPU. Use setup to do things like:
1. count number of classes
2. build vocabulary
3. perform train/val/test splits
4. apply transforms (defined explicitly in your datamodule or assigned in init)
5. etc…


In [9]:
class ImgtoInChIDataset(Dataset):
    """Images paired with their InChI label (or a placeholder when unlabelled).

    Args:
        paths: image file paths; the file stem is the image id.
        df: optional labels frame with "image_id" and "InChI" columns. Without
            it, every item's target is "test_placeholder".
        tsfms: optional albumentations pipeline applied as `tsfms(image=img)`.
    """
    def __init__(self, paths:list, df:pd.DataFrame=None, tsfms:A.Compose=None) -> None:
        self.paths = paths
        if df is not None:
            # image id -> InChI lookup (later duplicates win, as with a dict literal).
            self.idtoinchi_dict = dict(zip(df["image_id"].values.tolist(), df["InChI"].values.tolist()))
        self.tsfms = tsfms

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx:int) -> Tuple[torch.Tensor, str]:
        path = self.paths[idx]
        # Preprocessed image, scaled from [0, 255] to [0, 1] as float32.
        pixels = preprocess_image(path, out_size=INP_SIZE)
        img = np.array(pixels, dtype=np.float32) / 255.
        if self.tsfms is not None:
            img = self.tsfms(image=img)["image"]

        lookup = getattr(self, "idtoinchi_dict", None)
        if lookup is None:
            return img, "test_placeholder"
        return img, lookup[Path(path).stem]
    
class ImgToInChIDataModule(pl.LightningDataModule):
    """Lightning data module for image -> InChI translation.

    Args:
        tb_logger: TensorBoard logger; a sample training batch is logged in `setup`.
        valset_ratio: fraction of training images held out for validation.
        split_seed: optional `random_state` for the train/val split. None keeps
            the previous behaviour (NumPy's global RNG).
    """
    def __init__(self, tb_logger, valset_ratio=0.05, split_seed:Optional[int]=None) -> None:
        super().__init__()
        self.tb_logger = tb_logger
        self.valset_ratio = valset_ratio
        self.split_seed = split_seed
        self.dims = (1, *INP_SIZE)
        
        self.train_tsfms = A.Compose([
#             A.Resize(*INP_SIZE, always_apply=True),
#             A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
#             A.RandomCrop(*INP_SIZE),
#             A.RandomBrightnessContrast(p=0.5),
#             A.Normalize(mean=(0.5), std=(0.229)),
            ToTensorV2(),
        ])
        self.test_tsfms = A.Compose([
#             A.Resize(*INP_SIZE, always_apply=True),
#             A.Normalize(mean=(0.5), std=(0.229)),
            ToTensorV2(),
        ])

    @staticmethod
    def _load_image_paths(split:str, cache_path:Path) -> list:
        """Return image paths under DATADIR/<split>, cached in a feather file (column `<split>_paths`)."""
        column = f"{split}_paths"
        if cache_path.exists():
            return pd.read_feather(cache_path)[column].tolist()
        # Sorted so a freshly built cache does not depend on filesystem walk order;
        # astype(str) replaces the deprecated DataFrame.applymap(str).
        paths = pd.DataFrame(sorted((Path(DATADIR)/split).rglob("*.*")), columns=[column]).astype(str)
        paths.to_feather(cache_path)
        return paths[column].tolist()

    def prepare_data(self, verbose:bool=False) -> None:
        """Load labels and image paths, and load or build the vocab/tokenizer.

        NOTE(review): Lightning runs `prepare_data` on a single process. Attributes
        assigned here will not exist on other ranks in distributed training.
        """
        if verbose: print("Loading labels data...", end=' ')
        self.df = pd.read_csv(LABELS_CSV_PATH)
        if verbose: print("DONE!")

        if verbose: print("Loading paths...", end=' ')
        self.train_paths = self._load_image_paths("train", TRAINPATHS_PATH)
        self.test_paths = self._load_image_paths("test", TESTPATHS_PATH)
        if verbose: print("DONE!")

        if verbose: print("Loading vocab and tokenizer...", end=' ')
        if VOCAB_FILEPATH.exists():
            self.tokenizer = Tokenizer.from_file(VOCAB_FILEPATH)
        else:
            self.tokenizer = Tokenizer(Vocab.from_dataframe_column(self.df.InChI))
            self.tokenizer.vocab.save_vocab(VOCAB_FILEPATH)
        # Expose the vocab on both branches (previously only set when building it).
        self.vocab = self.tokenizer.vocab
        self.vocab_size = self.vocab.vocab_size
        if verbose: print("DONE!")

    def setup(self, stage:Optional[str]=None) -> None:
        """Create train/val datasets (stage 'fit') and the test dataset (stage 'test')."""
        if stage in ('fit', None):
            trainpaths, valpaths = train_test_split(self.train_paths, test_size=self.valset_ratio,
                                                    random_state=self.split_seed)
            self.trainset = ImgtoInChIDataset(trainpaths, self.df, self.train_tsfms)
            self.valset = ImgtoInChIDataset(valpaths, self.df, self.test_tsfms)

            # Log one training batch of images for a visual sanity check.
            imgs, _, _ = next(iter(self.train_dataloader()))
            self.tb_logger.experiment.add_images("Sample images", imgs)

        if stage in ('test', None):
            self.testset = ImgtoInChIDataset(self.test_paths, tsfms=self.test_tsfms)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.trainset, BATCH_SIZE, shuffle=True, 
                          collate_fn=self.collate_fn, num_workers=N_WORKERS, pin_memory=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.valset, BATCH_SIZE, shuffle=False, 
                          collate_fn=self.collate_fn, num_workers=N_WORKERS, pin_memory=True)

    def collate_fn(self, batch:tuple) -> Tuple[torch.Tensor, dict[list], dict[list]]:
        """Stack images and encode targets into per-layer lists of id/mask sequences.

        Returns:
            (images stacked along a new batch dim,
             {layer: [ids per sample]}, {layer: [mask per sample]}) for every layer in LAYERS_SEQ.
        """
        imgs = torch.stack([img for img, _ in batch])
        encoded = [self.tokenizer.encode(target) for _, target in batch]
        batch_inp_seqs = {k: [e["inp_seq_dict"][k] for e in encoded] for k in LAYERS_SEQ}
        batch_attn_masks = {k: [e["attn_mask_dict"][k] for e in encoded] for k in LAYERS_SEQ}
        return imgs, batch_inp_seqs, batch_attn_masks

In [10]:
# imgs, inp_seqs, attn_masks = next(iter(trainloader))
# tb_logger.experiment.add_images("Sample images", imgs)


# print("SAMPLE BATCH =", imgs.shape)
# fig, axes = plt.subplots(4, 8, figsize=(18, 10))
# for i, ax in enumerate(axes.flat):
#     ax.imshow(imgs[i].squeeze(0), cmap="gray")
#     ax.axis("off")

# Model
### At the time of sentence prediction can we use HMMs or [Beam Search](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#overview) to make better decisions?

Use different LSTM layers for each inchi substring like /c /h
```
>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(1, 16, 10)
>>> h0 = torch.randn(2, 16, 20)
>>> c0 = torch.randn(2, 16, 20)
```
change number of layers here to num sublayers


[Model from here](https://www.kaggle.com/yasufuminakama/inchi-resnet-lstm-with-attention-starter)

The network commented out in this cell below DID NOT WORK WELL, so I fell back to the model in the next cell.
<!-- class Encoder(nn.Module):
    def __init__(self, model_name='resnet18', pretrained=False, out_channels=512):
        super().__init__()
        last_stride = 2 if INP_SIZE[0] == 256 else 1
        self.cnn = timm.create_model(model_name, pretrained=pretrained)
        self.cnn.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.out_channels = out_channels
        self.cnn.global_pool = nn.Identity()
        self.cnn.fc = nn.Identity()
        self.resout = nn.Sequential(
            nn.Conv2d(512, out_channels, kernel_size=3, stride=last_stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )
        self.encout = nn.Sequential(
            nn.Conv1d(16, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm1d(1),
        )
        
    def forward(self, x):
        out = self.cnn(x)
        out = self.resout(out)
        out = out.view(x.size(0), self.out_channels, -1)
        out = out.permute(0,2,1)
        out = self.encout(out).squeeze(1)
#         print("ENC out =", out.shape)
        return out

class Decoder(nn.Module):
    def __init__(self, vocab_size, enc_out_channels=512):
        super().__init__()
        self.embd = nn.Embedding(vocab_size, EMB_SIZE)
        self.lstm = nn.LSTM(enc_out_channels, vocab_size, batch_first=True)
        self.encfc = nn.Linear(enc_out_channels, vocab_size)
        self.embfc = nn.Linear(vocab_size, vocab_size)
    
    def forward(self, encoder_out, inp_seqs):
        encoder_out = self.encfc(encoder_out)
        encoder_out = encoder_out.unsqueeze(1)
#         print("encoder_out =", encoder_out.shape)
        
        emb = self.embd(inp_seqs)
        lstm_out, _ = self.lstm(emb)
#         lstm_out = lstm_out.reshape(encoder_out.size(0), -1)
#         print("lstm_out =", lstm_out.shape)
        enc_out = torch.repeat_interleave(encoder_out, lstm_out.size(1), dim=1)
#         print("enc_out =", enc_out.shape)
        out = lstm_out + enc_out
#         for i in range(lstm_out.size(1)):
#             lstm_out[:,i,:] = self.embfc(lstm_out[:,i,:])
#             lstm_out[:,i,:] = lstm_out[:,i,:] + encoder_out
#         print("lstm_out after for =", lstm_out.shape)
        return out
    
    def predict(self, encoder_out, tokenizer):
#         print(encoder_out.shape)
        encoder_out = self.encfc(encoder_out)
        bs = encoder_out.size(0)
        syn_inp_seqs = torch.tensor(tokenizer.vocab.ctoi(tokenizer.vocab.bos_token), device=encoder_out.device)
        syn_inp_seqs = torch.repeat_interleave(syn_inp_seqs, bs)
        syn_inp_seqs = syn_inp_seqs.view(bs, 1)
        
        # Predict next tokens to start token
        pred_emb = []
        for i in range(MAX_LEN):
            emb = self.embd(syn_inp_seqs)
            pred, _ = self.lstm(emb)
            pred = pred.squeeze(1)
            
#             pred = self.embfc(pred)
            pred = pred + encoder_out
            pred = pred.unsqueeze(1)
            pred_emb.append(pred)
            syn_inp_seqs = pred.argmax(dim=-1)
#         print("len =", len(pred_emb))
        pred_emb = torch.cat(pred_emb, dim=1)
        return pred_emb
            
class InChINet(pl.LightningModule):
    def __init__(self, vocab_size, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.encoder_net = Encoder()
        self.decoder_net = Decoder(vocab_size)
        self.loss_fn = nn.CrossEntropyLoss()
        
    def forward(self, imgs, inp_seqs):
        encoder_out = self.encoder_net(imgs)
#         print("Encoder =", encoder_out.shape)
        pred_tokens = self.decoder_net(encoder_out, inp_seqs)
        return pred_tokens
    
    def training_step(self, train_batch, batch_idx):
        imgs, inp_seqs, attn_masks = train_batch
        output = self.forward(imgs, inp_seqs)
        loss = self.loss_fn(output.permute(0,2,1).float(), inp_seqs)
        # Logging to TensorBoard by default
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        return loss
    
    def training_epoch_end(self, outputs):
        for name,params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)
    
    def validation_step(self, val_batch, batch_idx):
        imgs, inp_seqs, attn_masks = val_batch
        output = self.predict(imgs, self.tokenizer)
        loss = self.loss_fn(output.permute(0,2,1).float(), inp_seqs)
        self.log('val_loss', loss, on_step=True, on_epoch=True, logger=True)
        
        lv_metric = self.calculate_lvdistance(output, inp_seqs)
        self.logger.log_metrics({"LvDistance": lv_metric}, step=1)
        return loss
    
    def predict(self, imgs, tokenizer):
        encoder_out = self.encoder_net(imgs)
        pred_tokens = self.decoder_net.predict(encoder_out, tokenizer)
        return pred_tokens

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.OneCycleLR(optimizer, LR, epochs=EPOCHS, steps_per_epoch=8996),
            'name': 'OneCycleLR'
        }
        return [optimizer], [lr_scheduler]
    
    def inference(self, imgs):
        output = self.predict(imgs, self.tokenizer)
        return self.postprocessing(output)
    
    def calculate_lvdistance(self, output, target):
        pred_seqs = self.postprocessing(output)
        batch_distance = np.mean([
            Levenshtein.distance(pred_seq, self.tokenizer.decode(inp_seq))
            for pred_seq, inp_seq in zip(pred_seqs, target)
        ])
        return batch_distance
    
    def postprocessing(self, output):
        final_preds = []
        pred_tokens = output.argmax(dim=-1)
        for i in range(pred_tokens.size(0)): # iterate on each sample
            pred = pred_tokens[i].unique(dim=-1).tolist()
            pred = self.tokenizer.decode(pred)
            res = re.search(r'C', pred)
            if res:
                pred = pred[res.span()[0]:]
            final_preds.append(pred)
        return final_preds
    

# dm = ImgToInChIDataModule(tb_logger=tb_logger)
# dm.prepare_data(verbose=True)
# dm.setup('fit')
# imgs, inp_seqs, attn_masks = next(iter(dm.train_dataloader()))

# encoder_net = Encoder()
# encoder_out = encoder_net(imgs)
# print("Encoder =", encoder_out.shape)

# decoder_net = Decoder(dm.vocab_size)
# pred_tokens = decoder_net(encoder_out, inp_seqs)
# print(pred_tokens.shape)

# pred_tokens = decoder_net.predict(encoder_out, dm.tokenizer)
# pred_tokens.shape -->

In [43]:
class Encoder(nn.Module):
    """CNN image encoder returning a sequence of spatial feature vectors.

    A timm backbone built for single-channel input, with its classifier and
    global pooling removed, is followed by a 3x3 conv + BatchNorm + ReLU that
    projects to `out_channels`. `forward` returns shape
    (batch, H*W, out_channels), where H*W is the number of spatial positions.
    """
    def __init__(self, model_name='resnet18', pretrained=False, out_channels=512):
        super().__init__()
        last_stride = 2 if INP_SIZE[0] == 256 else 1
        # in_chans=1 makes timm build (and, when pretrained, adapt) the stem for
        # grayscale input instead of swapping in a randomly initialised conv1.
        # num_classes=0 + global_pool='' make the backbone return the unpooled feature map.
        self.cnn = timm.create_model(model_name, pretrained=pretrained, in_chans=1,
                                     num_classes=0, global_pool='')
        self.out_channels = out_channels
        self.outfc = nn.Sequential(
            # num_features is the backbone's final channel count (512 for resnet18).
            nn.Conv2d(self.cnn.num_features, out_channels, kernel_size=3, stride=last_stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, x):
        out = self.cnn(x)                 # (B, C_backbone, H, W)
        out = self.outfc(out)             # (B, out_channels, H', W')
        out = out.flatten(2)              # (B, out_channels, H'*W')
        return out.permute(0, 2, 1)       # (B, H'*W', out_channels)

class Decoder(nn.Module):
    """LSTM token decoder whose outputs are summed position-wise with encoder features.

    Tokens are embedded (size ``EMB_SIZE``, a notebook global) and run through
    a single-layer LSTM whose hidden size is ``vocab_size``, so each output
    vector doubles as token logits. ``encoder_out`` is added to those outputs.
    It must therefore have ``vocab_size`` channels, and either one position or
    as many positions as there are tokens.

    Args:
        vocab_size: number of tokens (also the LSTM hidden size).
        max_len: only sizes the ``decfc`` layer, which neither forward nor
            predict uses.
        enc_out_channels: kept for backward compatibility and no longer used.
            The LSTM input size now comes from the embedding.
    """

    def __init__(self, vocab_size, max_len, enc_out_channels=512):
        super().__init__()
        self.embd = nn.Embedding(vocab_size, EMB_SIZE)
        # The LSTM consumes the embeddings, so its input size must be the
        # embedding size. Using enc_out_channels only worked when EMB_SIZE
        # happened to equal it.
        self.lstm = nn.LSTM(self.embd.embedding_dim, vocab_size, batch_first=True)
        # NOTE: unused by forward/predict. Kept so existing state_dicts that
        # contain decoder_net.decfc.* still load.
        self.decfc = nn.Linear(64, max_len)

    def forward(self, encoder_out, inp_seqs):
        """Teacher-forced pass: LSTM over ``inp_seqs`` plus ``encoder_out``."""
        # as_tensor avoids the copy (and warning) torch.tensor makes when
        # inp_seqs is already a tensor.
        inp_seqs = torch.as_tensor(inp_seqs, device=encoder_out.device)
        lstm_out, _ = self.lstm(self.embd(inp_seqs))
        return lstm_out + encoder_out

    def predict(self, encoder_out, tokenizer):
        """Greedy decoding for ``MAX_LEN`` steps, starting from the BOS token.

        Returns the per-step logits (LSTM output + encoder features)
        concatenated along dim 1, matching what ``forward`` produces.

        Raises:
            ValueError: if ``encoder_out`` has neither 1 nor ``MAX_LEN``
                positions, so it cannot be aligned with the decoded steps.
        """
        bs, enc_len = encoder_out.size(0), encoder_out.size(1)
        n_steps = MAX_LEN
        if enc_len not in (1, n_steps):
            raise ValueError(
                f"encoder_out has {enc_len} positions; expected 1 or {n_steps} "
                "to align with the decoded tokens"
            )
        start = torch.tensor(tokenizer.vocab.ctoi(tokenizer.vocab.bos_token), device=encoder_out.device)
        tokens = torch.repeat_interleave(start, bs).view(bs, 1)

        state = None
        step_logits = []
        for step in range(n_steps):
            # Carry (h, c) across steps. Restarting from a zero state each step
            # reduced the decoder to a one-token lookup.
            lstm_out, state = self.lstm(self.embd(tokens), state)
            enc_step = encoder_out if enc_len == 1 else encoder_out[:, step:step + 1]
            # Choose the next token from the same logits forward() is trained
            # on (LSTM + image features), not from the LSTM output alone.
            logits = lstm_out + enc_step
            step_logits.append(logits)
            tokens = logits.argmax(dim=-1)
        return torch.cat(step_logits, dim=1)
            
class InChINet(pl.LightningModule):
    """Image-to-InChI model: ``Encoder`` features summed with ``Decoder`` LSTM outputs.

    Args:
        vocab_size: token vocabulary size. It is used both as the encoder's
            output channels and as the decoder's vocabulary, so the two can be
            added.
        tokenizer: provides ``decode`` and ``vocab`` (used for the BOS token).
        max_len: forwarded to ``Decoder``; defaults to the notebook-global
            ``MAX_LEN``.
    """

    def __init__(self, vocab_size, tokenizer, max_len=None):
        super().__init__()
        self.tokenizer = tokenizer
        self.encoder_net = Encoder(out_channels=vocab_size)
        # Decoder's max_len argument is required; Decoder(vocab_size) alone
        # raised a TypeError.
        if max_len is None:
            max_len = MAX_LEN
        self.decoder_net = Decoder(vocab_size, max_len)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, imgs, inp_seqs):
        """Teacher-forced pass: encode ``imgs`` and decode the given token ids."""
        encoder_out = self.encoder_net(imgs)
        return self.decoder_net(encoder_out, inp_seqs)

    def training_step(self, train_batch, batch_idx):
        imgs, inp_seqs, attn_masks = train_batch
        output = self.forward(imgs, inp_seqs)
        # Teacher forcing: the output at position t is computed after reading
        # token t, so it is scored against token t + 1. predict() feeds each
        # argmax back in as the next input. Scoring against token t itself
        # only teaches the decoder to copy its input.
        logits = output[:, :-1].permute(0, 2, 1).float()
        loss = self.loss_fn(logits, inp_seqs[:, 1:])
        # Logging to TensorBoard by default
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        return loss

    def training_epoch_end(self, outputs):
        # Once per epoch, write a histogram of every parameter tensor.
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)

    def validation_step(self, val_batch, batch_idx):
        imgs, inp_seqs, attn_masks = val_batch
        # Validation decodes greedily from the images (no teacher forcing).
        output = self.predict(imgs, self.tokenizer)
        # NOTE(review): predict() starts from the BOS token, so output[:, 0] is
        # its guess for the token that follows BOS. If inp_seqs also begins
        # with BOS, this comparison is shifted by one position -- confirm
        # against the tokenizer.
        loss = self.loss_fn(output.permute(0, 2, 1).float(), inp_seqs)
        self.log('val_loss', loss, on_step=True, on_epoch=True, logger=True)

        lv_metric = self.calculate_lvdistance(output, inp_seqs)
        # Let Lightning average this over the validation epoch. The previous
        # logger.log_metrics(..., step=1) wrote every batch of every
        # validation run onto step 1.
        self.log('LvDistance', float(lv_metric), on_step=False, on_epoch=True, logger=True)
        return loss

    def predict(self, imgs, tokenizer):
        """Encode ``imgs`` and greedily decode them; returns per-step logits."""
        encoder_out = self.encoder_net(imgs)
        return self.decoder_net.predict(encoder_out, tokenizer)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100),
            'name': 'AnnealingLR'
        }
        return [optimizer], [lr_scheduler]

    def inference(self, imgs):
        """Return predicted strings for a batch of images."""
        output = self.predict(imgs, self.tokenizer)
        return self.postprocessing(output)

    def calculate_lvdistance(self, output, target):
        """Mean Levenshtein distance between decoded predictions and decoded targets."""
        pred_seqs = self.postprocessing(output)
        batch_distance = np.mean([
            Levenshtein.distance(pred_seq, self.tokenizer.decode(inp_seq))
            for pred_seq, inp_seq in zip(pred_seqs, target)
        ])
        return batch_distance

    def postprocessing(self, output):
        """Turn decoder logits into predicted strings, one per sample.

        Each decoded string is trimmed to start at its first ``'C'`` when one
        is present.
        """
        final_preds = []
        pred_tokens = output.argmax(dim=-1)
        for sample_tokens in pred_tokens:  # iterate on each sample
            # Collapse runs of repeated tokens while keeping sequence order.
            # torch.unique sorts its result, which scrambled the token order.
            token_ids = sample_tokens.unique_consecutive().tolist()
            pred = self.tokenizer.decode(token_ids)
            start = pred.find('C')
            if start != -1:
                pred = pred[start:]
            final_preds.append(pred)
        return final_preds
    
    
# Scratch cell: check Encoder/Decoder output shapes on one batch.
# NOTE(review): relies on `dm`, `imgs`, `inp_seqs` and `encoder_out` already
# existing in the notebook session -- the lines that create them are commented out.
# dm = ImgToInChIDataModule(tb_logger=tb_logger)
# dm.prepare_data(verbose=True)
# dm.setup('fit')
# imgs, inp_seqs, attn_masks = next(iter(dm.train_dataloader()))

# encoder_net = Encoder()
# encoder_out = encoder_net(imgs)
print("Encoder =", encoder_out.shape)

# The Decoder's LSTM hidden size is the main-layer vocab size (427), while an
# Encoder() built with defaults emits 512 channels. The element-wise sum in
# Decoder.forward therefore fails with the RuntimeError recorded below.
# NOTE(review): even with matching channels, the 16 encoder positions would
# also have to match the token-sequence length (max_lengths['main_layer'] is
# 13) -- confirm how inp_seqs are padded.
decoder_net = Decoder(dm.vocab_size["main_layer"], dm.tokenizer.vocab.max_lengths["main_layer"])
pred_tokens = decoder_net(encoder_out, inp_seqs["main_layer"])
print(pred_tokens.shape)

# pred_tokens = decoder_net.predict(encoder_out, tokenizer)
# pred_tokens.shape

Encoder = torch.Size([256, 16, 512])


RuntimeError: The size of tensor a (427) must match the size of tensor b (512) at non-singleton dimension 2

In [33]:
# Display the tokenizer vocabulary's per-layer maximum lengths (a dict keyed by InChI layer name).
dm.tokenizer.vocab.max_lengths

{'main_layer': 13,
 'c_layer': 187,
 'h_layer': 112,
 't_layer': 45,
 'm_layer': 7,
 's_layer': 5,
 'b_layer': 68,
 'i_layer': 51}

In [34]:
# Display per-layer vocabulary sizes; 'main_layer' (427) is the size that clashes with 512 in the RuntimeError above.
dm.tokenizer.vocab.vocab_size

{'main_layer': 427,
 'c_layer': 101,
 'h_layer': 101,
 't_layer': 75,
 'm_layer': 8,
 's_layer': 6,
 'b_layer': 81,
 'i_layer': 55}

In [42]:
# Display the raw InChI string of the row labelled 2 in the data module's dataframe.
dm.df.InChI[2]

'InChI=1S/C24H23N5O4/c1-14-13-15(7-8-17(14)28-12-10-20(28)30)27-11-9-16-21(23(25)31)26-29(22(16)24(27)32)18-5-3-4-6-19(18)33-2/h3-8,13H,9-12H2,1-2H3,(H2,25,31)'

# Training and Validation

In [None]:
# %tensorboard --logdir {CHKPTDIR}

# The data module comes first: the model is built from its vocab size and tokenizer.
dm = ImgToInChIDataModule(tb_logger=tb_logger)
dm.prepare_data(verbose=True)

# Record the learning rate at every optimisation step.
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')

model = InChINet(dm.vocab_size, dm.tokenizer)
# To add the network graph to tensorboard:
# tb_logger.log_graph(model, [imgs[0].unsqueeze(0).to(model.device), inp_seqs[0].unsqueeze(0).to(model.device)])

# NOTE(review): in PyTorch Lightning 1.x, auto_lr_find is acted on by
# trainer.tune(...), not by trainer.fit(...) alone -- confirm for the
# installed version.
trainer = pl.Trainer(
    default_root_dir=CHKPTDIR,
    gpus=1,
    precision=PRECISION,
    max_epochs=1,
    auto_lr_find=True,
    profiler="simple",
    logger=tb_logger,
    callbacks=[lr_monitor],
)
trainer.fit(model, dm)

In [None]:
# Greedy-decode the validation set and report the mean Levenshtein distance
# between the predicted strings and the decoded reference sequences.
device = torch.device("cuda")
# Switch mode and move the model once, not on every batch.
model = model.to(device)
model.eval()
total_distance = []
# Inference only: don't build autograd graphs (saves GPU memory and time).
with torch.no_grad():
    for imgs, inp_seqs, attn_masks in progress_bar(dm.val_dataloader()):
        imgs = imgs.to(device)
        pred_seqs = model.inference(imgs)
        total_distance.extend(
            Levenshtein.distance(pred_seq, dm.tokenizer.decode(inp_seq))
            for pred_seq, inp_seq in zip(pred_seqs, inp_seqs)
        )
np.mean(total_distance)

In [None]:
# 2.3149080108901905