## Try these two experiments first
1. Train a translator from the original input images to InChI strings and record its accuracy and loss.
2. Train a translator from InChI images to InChI strings and record its accuracy and loss.

Then compare the two. If the InChI image -> InChI string accuracy is significantly higher than the original image -> InChI string accuracy, consider ***reconstruction experiments*** that map original images to InChI images. [Possible starting point](https://www.google.com/search?channel=fs&client=ubuntu&q=converting+shapes+from+one+to+another+using+deep+learning).
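To keep the comparison fair, both experiments can report the same two numbers: exact-match accuracy and mean Levenshtein distance (the distance this notebook later computes with python-Levenshtein). Below is a dependency-free sketch. `levenshtein` and `score_translator` are hypothetical helpers, and `preds`/`targets` are assumed to be decoded InChI strings.

```python
from typing import List, Tuple


def levenshtein(a: str, b: str) -> int:
    """Edit distance using the classic two-row dynamic-programming table."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # deletion
                curr[j - 1] + 1,           # insertion
                prev[j - 1] + (ca != cb),  # substitution (free when characters match)
            ))
        prev = curr
    return prev[-1]


def score_translator(preds: List[str], targets: List[str]) -> Tuple[float, float]:
    """Return (exact-match accuracy, mean Levenshtein distance) for one experiment."""
    if not preds or len(preds) != len(targets):
        raise ValueError("preds and targets must be non-empty and equally long")
    exact = sum(p == t for p, t in zip(preds, targets)) / len(preds)
    mean_dist = sum(levenshtein(p, t) for p, t in zip(preds, targets)) / len(preds)
    return exact, mean_dist
```

Running `score_translator` on the predictions from both translators, against the same targets, gives two directly comparable pairs of numbers.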

# InChI decoding 
[Source](https://link.springer.com/content/pdf/10.1186%2Fs13321-015-0068-4.pdf)

1. **Skeletal connections layer:** this layer, prefixed with `/c`, represents connections between skeletal atoms by listing the canonical numbers of the atoms along each chain of connected atoms.
2. ***branches are given in parentheses***
3. The canonical atom numbers, which are used throughout the InChI, always follow the element order of the formula layer: atoms are numbered element by element in the order the elements appear in the (Hill-ordered) formula, not by position in the periodic table. For example, `/C10H16N5O13P3` (the beginning of the InChI for adenosine triphosphate) implies that atoms 1–10 are carbons, 11–15 are nitrogens, 16–28 are oxygens, and 29–31 are phosphorus. Hydrogen atoms are not explicitly numbered.
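A minimal sketch of points 1 to 3, assuming a single-component molecule: no `;` component separators, no `*` multipliers, and ring closures written simply as repeated atom numbers. `number_atoms` numbers the non-hydrogen atoms in the order the elements appear in the formula string. `parse_connections` turns a `/c` string into a bond list, using a stack for the parenthesised branches and treating commas as separators between branches off the same atom. Both helper names are hypothetical.

```python
import re
from typing import Dict, List, Optional, Tuple


def number_atoms(formula: str) -> Dict[int, str]:
    """Map canonical atom numbers to element symbols for a formula like 'C10H16N5O13P3'.

    Hydrogens are skipped because InChI does not number them.
    """
    numbering: Dict[int, str] = {}
    next_num = 1
    for symbol, count in re.findall(r"([A-Z][a-z]?)(\d*)", formula):
        if symbol == "H":
            continue
        for _ in range(int(count) if count else 1):
            numbering[next_num] = symbol
            next_num += 1
    return numbering


def parse_connections(layer: str) -> List[Tuple[int, int]]:
    """Turn a connection string such as '1-4(2)3' (without the '/c') into bonds."""
    bonds: List[Tuple[int, int]] = []
    branch_points: List[int] = []  # atoms to return to when a branch closes
    prev: Optional[int] = None
    for tok in re.findall(r"\d+|[()\-,]", layer):
        if tok.isdigit():
            atom = int(tok)
            if prev is not None:
                bonds.append((prev, atom))
            prev = atom
        elif tok == "(":
            branch_points.append(prev)   # remember the atom the branch hangs off
        elif tok == ",":
            prev = branch_points[-1]     # next branch starts from the same atom
        elif tok == ")":
            prev = branch_points.pop()   # resume the main chain after the branch
        # '-' only separates consecutive atoms in a chain, so it needs no action
    return bonds
```

For instance, `parse_connections("1-5(2,3)4")` yields four bonds that all involve atom 5, and `number_atoms` reproduces the ATP numbering from point 3.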


## image to inchi string

In [1]:
# Special Packages
# !pip install PeriodicElements
# !pip install albumentations
# !pip install timm
# !pip install python-Levenshtein
# !pip install torchmetrics

In [2]:
%load_ext tensorboard
%load_ext autoreload
%autoreload 2

In [3]:
import torch, torchmetrics, timm, re, pickle, Levenshtein
import torch.nn as nn
import torchvision as tv
import pytorch_lightning as pl
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
from pathlib import Path
from functools import partial
from collections import defaultdict
from fastprogress import progress_bar
from typing import Optional, Union
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from elements import elements
from albumentations.pytorch import ToTensorV2
from preprocessing import preprocess_image


# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
torch.manual_seed(manualSeed);

# This monkey-patch is there to be able to plot tensors
torch.Tensor.ndim = property(lambda x: len(x.shape))

Random Seed:  999


In [4]:
CHKPTDIR = Path("TranslationChkpts")
DATADIR = "data/bms-molecular-translation"
LABELS_CSV_PATH = "data/train_labels.csv"
VOCAB_FILEPATH = CHKPTDIR/"vocab.pt"
TRAINPATHS_PATH = CHKPTDIR/"train_paths.feather"
TESTPATHS_PATH = CHKPTDIR/"test_paths.feather"
CHKPTDIR.mkdir(parents=True, exist_ok=True)

tb_logger = pl.loggers.TensorBoardLogger(CHKPTDIR, name="InchINet")

N_WORKERS = 4
BATCH_SIZE = 256
PRECISION = 16
MAX_LEN = 16 # max([len(vocab.tokenize(c)) for c in corpus]) = 9, + 2 enclosing tokens (<bos>, <eos>), rest <pad>; must also equal the encoder's 4x4 = 16 output positions
EMB_SIZE = 512
HDN_SIZE = 10
INP_SIZE = (128, 128)
N_INP_CH = 1
N_OUT_CH = 3
LR = 1e-2
EPOCHS = 10
beta1 = 0.5

In [5]:
!ls {DATADIR}

sample_submission.csv  test  train  train_labels.csv


# Data Block

### LightningDataModule API

To define a DataModule, implement 5 methods:
1. prepare_data (how to download, tokenize, etc…)
2. setup (how to split, etc…)
3. train_dataloader
4. val_dataloader(s)
5. test_dataloader(s)

#### prepare_data
Use this method to do things that might write to disk or that need to be done only from a single process in distributed settings.
1. download
2. tokenize
3. etc…

#### setup
There are also data operations you might want to perform on every GPU. Use setup to do things like:
1. count number of classes
2. build vocabulary
3. perform train/val/test splits
4. apply transforms (defined explicitly in your datamodule or assigned in init)
5. etc…


## Vocab and Tokenizer

In [6]:
class RawDataset(Dataset):
    def __init__(self, datadir, df=None):
        super().__init__()
        self.paths = list(Path(datadir).rglob("*.*"))
        if df is not None:
            self.idtoinchi_dict = {
                _id:_inchi for _id, _inchi in
                zip(df["image_id"].values.tolist(), df["InChI"].values.tolist())
            }
        if len(self.paths) == 0:
            print("No paths found.")
        self.piltotensor = tv.transforms.PILToTensor()  # PIL image -> uint8 tensor (not used in __getitem__)
    
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, idx):
        imgpath = self.paths[idx]
        imgid = imgpath.stem
        img = preprocess_image(imgpath, out_size=INP_SIZE)
        img = torch.from_numpy(np.array(img))
        if hasattr(self, "idtoinchi_dict"):
            target = self.idtoinchi_dict[imgid]
            return img, target
        return img, "test_placeholder"

In [7]:
class Vocab:
    def __init__(self, vocab, add_special_tokens=True):
        self.vocab = vocab
        # Get all elements sorted according to Atomic number
        data = elements.Elements
        # All elements in periodic table
        self.elements = sorted(data, key=lambda i:i.AtomicNumber)  # Based on their AtomicNumber
        # Sort longer names first for regex pattern formation
        self.element_symbols = sorted([e.Symbol for e in self.elements], key=lambda e: len(e), reverse=True)
        # Create regex pattern
        self.pattern = f"({'|'.join([f'{e}[0-9]*' for e in self.element_symbols])})"
        
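        # Special token names are always defined so the Tokenizer can refer to them
        self.pad_token, self.unk_token = "<pad>", "<unk>"
        self.bos_token, self.eos_token = "<bos>", "<eos>"
        if vocab is not None:
            specials = [self.pad_token, self.unk_token, self.bos_token, self.eos_token]
            self.vocab = list(vocab)
            # Prepend special tokens unless already present (a vocab reloaded from disk has them)
            if add_special_tokens and self.vocab[:4] != specials:
                self.vocab = specials + self.vocab
            
            # Adding elements names into vocab and sort according to atomic number
#             self.vocab = np.unique(self.vocab.tolist() + self.element_symbols)
#             self.vocab = sorted(self.vocab.tolist(), key=lambda x: eval(f"elements.{''.join(re.findall(r'[A-Za-z]', x))}.AtomicNumber"))
            # Token-to-index mapping; unknown tokens fall back to the <unk> index (not the <unk> string)
            unk_idx = self.vocab.index(self.unk_token) if self.unk_token in self.vocab else 0
            self._ctoi = defaultdict(lambda : unk_idx, {c:i for i, c in enumerate(self.vocab)})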
                
    def tokenize(self, string):
        tokens = re.split(self.pattern, string)
        tokens = list(filter(None, tokens))
        return tokens
    
    def ctoi(self, c):
        return self._ctoi[c]
    
    def itoc(self, i):
        return self.vocab[i]
    
    def __len__(self):
        return len(self.vocab)
    
    def save_vocab(self, path):
        torch.save(self.vocab, path)
        print("Saved @", path)
    
    @classmethod
    def from_corpus(cls, corpus):
        v = cls(None)
        vocab = np.unique([w for s in corpus for w in v.tokenize(s)])
        return cls(vocab)
    
    @classmethod
    def load_vocab(cls, path):
        vocab = torch.load(path)
        return cls(vocab)
    
# Reference - https://huggingface.co/transformers/v2.11.0/main_classes/tokenizer.html#pretrainedtokenizer
class Tokenizer:
    def __init__(self, vocab=None):
        self.vocab = vocab # Vocab class instance
    
    def tokenizer(self, x):
        return self.vocab.tokenize(x)
    
    def encode(self, s, max_len=None):
        tokens = self.vocab.tokenize(s)
        seq = [self.vocab.ctoi(t) for t in tokens]
        attn_mask = [1]*len(seq)
        
        if max_len:
            # Add padding to input
            extra_len = max_len - len(seq) - 2 # 2 for start and end tokens
            # Add start input token
            seq = [self.vocab.ctoi(self.vocab.bos_token)] + seq
            # Add end input token
            seq += [self.vocab.ctoi(self.vocab.eos_token)]
            attn_mask += [1, 1]
            # Add padding token
            seq += [self.vocab.ctoi(self.vocab.pad_token)]*extra_len
            attn_mask += [0]*extra_len
            
        return {"inp_seq": seq, "attn_mask": attn_mask}
    
    def decode(self, tokens, inp_seq_name="inp_seq"):
        if isinstance(tokens, dict):
            seq = tokens[inp_seq_name]
            if isinstance(seq, (torch.Tensor, np.ndarray)):
                seq = seq.tolist()
        else:
            seq = tokens
        seq = ''.join([self.vocab.itoc(t) for t in seq])
        # remove special tokens
        for special_token in [self.vocab.pad_token, self.vocab.bos_token, self.vocab.eos_token]:
            seq = seq.replace(special_token, '')
        return seq
    
    @classmethod
    def fit(cls, corpus):
        vocab = Vocab.from_corpus(corpus)
        return cls(vocab)
    
    @classmethod
    def load_from_file(cls, path):
        vocab = Vocab.load_vocab(path)  # wrap the saved token list in a Vocab instance
        return cls(vocab)
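The tokenize/encode scheme above can be checked without the `elements` package or torch. Below is a plain-Python sketch that mirrors `Vocab.tokenize` and `Tokenizer.encode`. `SYMBOLS` is a hypothetical, truncated element list standing in for the full periodic table, and unknown tokens map to the index of `<unk>`.

```python
import re
from typing import Dict, List

# Hypothetical, truncated symbol list; the notebook builds the full list from `elements`.
SYMBOLS = ["Br", "Cl", "Si", "B", "C", "F", "H", "I", "N", "O", "P", "S"]
# Longer symbols first so that 'Cl' is tried before 'C'.
PATTERN = "(" + "|".join(f"{s}[0-9]*" for s in sorted(SYMBOLS, key=len, reverse=True)) + ")"
SPECIALS = ["<pad>", "<unk>", "<bos>", "<eos>"]


def tokenize(formula: str) -> List[str]:
    # re.split with a capturing group keeps the matched tokens; drop the empty gaps.
    return [t for t in re.split(PATTERN, formula) if t]


def build_vocab(corpus: List[str]) -> Dict[str, int]:
    words = sorted({t for s in corpus for t in tokenize(s)})
    return {w: i for i, w in enumerate(SPECIALS + words)}


def encode(formula: str, ctoi: Dict[str, int], max_len: int) -> Dict[str, List[int]]:
    ids = [ctoi.get(t, ctoi["<unk>"]) for t in tokenize(formula)]
    seq = [ctoi["<bos>"]] + ids + [ctoi["<eos>"]]
    n_pad = max(max_len - len(seq), 0)  # max_len counts <bos> and <eos>
    return {"inp_seq": seq + [ctoi["<pad>"]] * n_pad,
            "attn_mask": [1] * len(seq) + [0] * n_pad}
```

Each element symbol plus its count becomes one token (`"C10"`, `"H16"`, ...), so the vocabulary grows with every distinct count seen in the corpus.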

## Lightning Data Module

In [8]:
class ImgtoInChIDataset(Dataset):
    def __init__(self, paths, df=None, tsfms=None):
        self.paths = paths
        if df is not None:
            self.idtoinchi_dict = {
                _id:_inchi for _id, _inchi in
                zip(df["image_id"].values.tolist(), df["InChI"].values.tolist())
            }
        self.tsfms = tsfms
    
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, idx):
        imgpath = self.paths[idx]
        imgid = Path(imgpath).stem
        img = np.array(preprocess_image(imgpath, out_size=INP_SIZE), dtype=np.float32)/255.
        if self.tsfms is not None:
            img = self.tsfms(image=img)["image"]
        
        if hasattr(self, "idtoinchi_dict"):
            target = self.idtoinchi_dict[imgid]
            target = target.split("/")[1]
            return img, target
        
        return img, "test_placeholder"
    
class ImgToInChIDataModule(pl.LightningDataModule):
    def __init__(self, tb_logger, valset_ratio=0.05) -> None:
        super().__init__()
        self.tb_logger = tb_logger
        self.valset_ratio = valset_ratio
        self.dims = (1, *INP_SIZE)
        
        self.train_tsfms = A.Compose([
#             A.Resize(*INP_SIZE, always_apply=True),
#             A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
#             A.RandomCrop(*INP_SIZE),
#             A.RandomBrightnessContrast(p=0.5),
#             A.Normalize(mean=(0.5), std=(0.229)),
            ToTensorV2(),
        ])
        self.test_tsfms = A.Compose([
#             A.Resize(*INP_SIZE, always_apply=True),
#             A.Normalize(mean=(0.5), std=(0.229)),
            ToTensorV2(),
        ])
        
    def prepare_data(self, verbose=False):
        """Use this method to do things that might write to disk or that
        need to be done only from a single process in distributed settings."""
        # Load labels in DataFrame
        if verbose: print("Loading labels data...", end=' ')
        self.df = pd.read_csv(LABELS_CSV_PATH)
        if verbose: print("DONE!")
        
        # Load image paths
        if verbose: print("Loading paths...", end=' ')
        if TRAINPATHS_PATH.exists():
            self.train_paths = pd.read_feather(TRAINPATHS_PATH)
            self.train_paths = self.train_paths.train_paths.tolist()
        else:
            self.train_paths = pd.DataFrame(list((Path(DATADIR)/"train").rglob("*.*")), columns=["train_paths"])
            self.train_paths = self.train_paths.applymap(lambda x: str(x))
            self.train_paths.to_feather(TRAINPATHS_PATH)
        if TESTPATHS_PATH.exists():
            self.test_paths = pd.read_feather(TESTPATHS_PATH)
            self.test_paths = self.test_paths.test_paths.tolist()
        else:
            self.test_paths = pd.DataFrame(list((Path(DATADIR)/"test").rglob("*.*")), columns=["test_paths"])
            self.test_paths = self.test_paths.applymap(lambda x: str(x))
            self.test_paths.to_feather(TESTPATHS_PATH)
        if verbose: print("DONE!")
        
        # Get Vocab and Tokenizer
        if verbose: print("Loading vocab and tokenizer...", end=' ')
        if Path(VOCAB_FILEPATH).exists():
            vocab = Vocab.load_vocab(VOCAB_FILEPATH)
        else:
            corpus = [s.split("/")[1] for s in self.df.InChI.tolist()]
            vocab = Vocab.from_corpus(corpus)
#             print("# words =", len(vocab))
            vocab.save_vocab(VOCAB_FILEPATH)
        self.vocab_size = len(vocab)
        self.tokenizer = Tokenizer(vocab)
        if verbose: print("DONE!")
                
    def setup(self, stage:Optional[str]=None) -> None:
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            trainpaths, valpaths = train_test_split(self.train_paths, test_size=self.valset_ratio)
            self.trainset = ImgtoInChIDataset(trainpaths, self.df, self.train_tsfms)
            self.valset = ImgtoInChIDataset(valpaths, self.df, self.test_tsfms)
            
            # Sample batch
            imgs, inp_seqs, attn_masks = next(iter(self.train_dataloader()))
            self.tb_logger.experiment.add_images("Sample images", imgs)

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.testset = ImgtoInChIDataset(self.test_paths, tsfms=self.test_tsfms)

    def train_dataloader(self):
        return DataLoader(self.trainset, BATCH_SIZE, shuffle=True, 
                          collate_fn=self.collate_fn, num_workers=N_WORKERS, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.valset, BATCH_SIZE, shuffle=False, 
                          collate_fn=self.collate_fn, num_workers=N_WORKERS, pin_memory=True)
    
    def collate_fn(self, batch):
        imgs = torch.cat([ins[0].unsqueeze(0) for ins in batch])
        targets = [ins[1] for ins in batch]
        targets = [self.tokenizer.encode(t, MAX_LEN) for t in targets]
        inp_seqs = torch.Tensor([t["inp_seq"] for t in targets]).long()
        attn_masks = torch.Tensor([t["attn_mask"] for t in targets]).float()
        return imgs, inp_seqs, attn_masks
    

In [9]:
# imgs, inp_seqs, attn_masks = next(iter(trainloader))
# tb_logger.experiment.add_images("Sample images", imgs)


# print("SAMPLE BATCH =", imgs.shape)
# fig, axes = plt.subplots(4, 8, figsize=(18, 10))
# for i, ax in enumerate(axes.flat):
#     ax.imshow(imgs[i].squeeze(0), cmap="gray")
#     ax.axis("off")

# Model
### At prediction time, could HMMs or [Beam Search](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning#overview) help pick better output sequences?
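Beam search keeps the `beam_width` best partial sequences at every step instead of committing to the argmax token, which is what both `predict` methods in this notebook do. Below is a framework-free sketch. `step_log_probs` is a hypothetical callback that would wrap one decoder step. Scores are summed log-probabilities with no length normalisation, and the toy table is invented only to show a case where greedy decoding loses.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple


def beam_search(
    step_log_probs: Callable[[Sequence[int]], Dict[int, float]],
    bos: int,
    eos: int,
    beam_width: int = 3,
    max_len: int = 16,
) -> List[int]:
    """Return the highest-scoring token sequence found with a beam of `beam_width`."""
    beams: List[Tuple[float, List[int]]] = [(0.0, [bos])]
    finished: List[Tuple[float, List[int]]] = []
    for _ in range(max_len):
        candidates = []
        for score, seq in beams:
            for tok, lp in step_log_probs(seq).items():
                candidates.append((score + lp, seq + [tok]))
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates[:beam_width]:
            # Hypotheses that emitted <eos> stop growing; the rest stay in the beam.
            (finished if seq[-1] == eos else beams).append((score, seq))
        if not beams:
            break
    finished.extend(beams)  # sequences that reached max_len without <eos>
    return max(finished, key=lambda c: c[0])[1]


# Hypothetical toy model: tokens 0=<bos>, 1, 2, 3=<eos>.
TABLE = {
    (0,): {1: math.log(0.6), 2: math.log(0.4)},
    (0, 1): {3: math.log(0.55), 1: math.log(0.45)},
    (0, 1, 1): {3: 0.0},
    (0, 2): {3: 0.0},
}


def toy_step(prefix: Sequence[int]) -> Dict[int, float]:
    return TABLE[tuple(prefix)]


print(beam_search(toy_step, bos=0, eos=3, beam_width=1))  # greedy: [0, 1, 3], p = 0.33
print(beam_search(toy_step, bos=0, eos=3, beam_width=2))  # beam:   [0, 2, 3], p = 0.40
```

With `beam_width=1` this reduces to greedy decoding. A wider beam recovers `[0, 2, 3]` because it keeps the less likely first token alive long enough to see its certain continuation.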

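Idea: use a separate LSTM layer for each InChI sublayer, such as `/c` and `/h`.
```
>>> import torch
>>> import torch.nn as nn
>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(1, 16, 10)  # (seq_len, batch, input_size)
>>> h0 = torch.randn(2, 16, 20)     # (num_layers, batch, hidden_size)
>>> c0 = torch.randn(2, 16, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
```
Change the number of layers here (`num_layers`, the third argument of `nn.LSTM` and the first dimension of `h0`/`c0`) to the number of sublayers.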
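Routing each sublayer to its own LSTM first requires splitting the target string at `/`; `ImgtoInChIDataset` currently keeps only the formula part (`target.split("/")[1]`). Below is a minimal sketch that assumes a standard InChI string of the form `InChI=1S/formula/c.../h...`. `split_sublayers` is a hypothetical helper. It keeps only the first occurrence of each prefix letter, so a prefix that appears again in a later layer group is ignored.

```python
from typing import Dict


def split_sublayers(inchi: str) -> Dict[str, str]:
    """Split an InChI string into its '/'-separated layers, keyed by prefix letter."""
    parts = inchi.split("/")
    layers = {"version": parts[0], "formula": parts[1]}
    for part in parts[2:]:
        if part:
            # e.g. 'c1-2-3' -> key 'c', body '1-2-3'; first occurrence wins
            layers.setdefault(part[0], part[1:])
    return layers
```

Each value (for example `layers["c"]`) could then be tokenized and fed to the LSTM assigned to that sublayer.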


Model adapted from [this Kaggle notebook](https://www.kaggle.com/yasufuminakama/inchi-resnet-lstm-with-attention-starter).

In [10]:
class Encoder(nn.Module):
    def __init__(self, model_name='resnet18', pretrained=False):
        super().__init__()
        self.cnn = timm.create_model(model_name, pretrained=pretrained)
        self.cnn.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.n_features = self.cnn.fc.in_features
        self.cnn.global_pool = nn.Identity()
        self.cnn.fc = nn.Identity()

    def forward(self, x):
        bs = x.size(0)
        features = self.cnn(x)
        features = features.permute(0, 2, 3, 1)
        return features
    
class Attention(nn.Module):
    """
    Attention network for calculate attention value
    """
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: input size of encoder network
        :param decoder_dim: input size of decoder network
        :param attention_dim: input size of attention network
        """
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # linear layer to transform encoded image
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # linear layer to transform decoder's output
        self.full_att = nn.Linear(attention_dim, 1)  # linear layer to calculate values to be softmax-ed
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # softmax layer to calculate weights

    def forward(self, encoder_out, decoder_hidden):
        att1 = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att)  # (batch_size, num_pixels)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)
        return attention_weighted_encoding, alpha
    
class DecoderWithAttention(nn.Module):
    """Decoder network with attention network used for training"""

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, device, encoder_dim=512, dropout=0.5):
        """
        :param attention_dim: input size of attention network
        :param embed_dim: input size of embedding network
        :param decoder_dim: input size of decoder network
        :param vocab_size: total number of characters used in training
        :param encoder_dim: input size of encoder network
        :param dropout: dropout rate
        """
        super().__init__()
        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.device = device
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network
        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        self.dropout = nn.Dropout(p=self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # linear layer to find scores over vocabulary
        self.init_weights()  # initialize some layers with the uniform distribution

    def init_weights(self):
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        :param encoder_out: output of encoder network
        :param encoded_captions: transformed sequence from character to integer
        :param caption_lengths: length of transformed sequence
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]
        # embedding transformed sequence for vector
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)
        # initialize hidden state and cell state of LSTM cell
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)
        # set decode length by caption length - 1 because of omitting start token
        decode_lengths = (caption_lengths - 1).tolist()
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(self.device)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(self.device)
        # predict sequence
        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
        return predictions, encoded_captions, decode_lengths, alphas, sort_ind
    
    def predict(self, encoder_out, decode_lengths, tokenizer):
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)
        # embed start token for LSTM input
        bos_idx = tokenizer.vocab.ctoi(tokenizer.vocab.bos_token)
        eos_idx = tokenizer.vocab.ctoi(tokenizer.vocab.eos_token)
        start_tokens = torch.full((batch_size,), bos_idx, dtype=torch.long, device=self.device)
        embeddings = self.embedding(start_tokens)
        # initialize hidden state and cell state of LSTM cell
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)
        predictions = torch.zeros(batch_size, decode_lengths, vocab_size).to(self.device)
        # predict sequence
        for t in range(decode_lengths):
            attention_weighted_encoding, alpha = self.attention(encoder_out, h)
            gate = self.sigmoid(self.f_beta(h))  # gating scalar, (batch_size, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings, attention_weighted_encoding], dim=1),
                (h, c))  # (batch_size, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size, vocab_size)
            predictions[:, t, :] = preds
            # stop early only when every sample in the batch predicts <eos> at this step
            if (preds.argmax(dim=-1) == eos_idx).all():
                break
            embeddings = self.embedding(torch.argmax(preds, -1))
        return predictions
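To see what `Attention` computes for one sample, here is a NumPy transcription of the same additive attention, with biases omitted and no batch dimension. A second helper reproduces the `batch_size_t` bookkeeping in `DecoderWithAttention.forward`. Because captions are sorted by length in descending order, the captions still being decoded at step `t` are always the first `batch_size_t` rows, which is why slicing with `[:batch_size_t]` is valid. Both function names are hypothetical.

```python
from typing import List, Tuple

import numpy as np


def additive_attention(encoder_out: np.ndarray, decoder_hidden: np.ndarray,
                       W_enc: np.ndarray, W_dec: np.ndarray,
                       w_full: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """encoder_out (P, E), decoder_hidden (D,), W_enc (E, A), W_dec (D, A), w_full (A,).

    W_enc, W_dec and w_full play the roles of encoder_att, decoder_att and full_att.
    """
    att1 = encoder_out @ W_enc                       # (P, A)
    att2 = decoder_hidden @ W_dec                    # (A,), broadcast over pixels
    scores = np.maximum(att1 + att2, 0.0) @ w_full   # ReLU, then project -> (P,)
    weights = np.exp(scores - scores.max())          # shift for numerical stability
    alpha = weights / weights.sum()                  # softmax over pixels, sums to 1
    context = alpha @ encoder_out                    # (E,) attention-weighted encoding
    return context, alpha


def active_rows_per_step(caption_lengths: List[int]) -> List[int]:
    """Number of captions still decoding at each step (the batch_size_t schedule)."""
    decode_lengths = sorted((l - 1 for l in caption_lengths), reverse=True)  # drop <bos>
    return [sum(l > t for l in decode_lengths) for t in range(max(decode_lengths))]
```

For caption lengths `[5, 3, 4]`, the schedule shrinks from 3 active rows to 1 over four steps.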

In [11]:
# encoder_net = Encoder()
# encoder_net(imgs).shape

In [12]:
class Encoder(nn.Module):
    def __init__(self, model_name='resnet18', pretrained=False, out_channels=512):
        super().__init__()
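        # resnet18 downsamples by 32: a 128x128 input gives a 4x4 map, a 256x256 input gives 8x8,
        # so outfc uses stride 2 for 256 to keep 4x4 = 16 positions (= MAX_LEN)
        last_stride = 2 if INP_SIZE[0] == 256 else 1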
        self.cnn = timm.create_model(model_name, pretrained=pretrained)
        self.cnn.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.out_channels = out_channels
        self.cnn.global_pool = nn.Identity()
        self.cnn.fc = nn.Identity()
        self.outfc = nn.Sequential(
            nn.Conv2d(512, out_channels, kernel_size=3, stride=last_stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )
        self.maxpool = nn.MaxPool2d(4,1,ceil_mode=True)

    def forward(self, x):
        out = self.cnn(x)
        out = self.outfc(out)
        out = out.view(x.size(0), self.out_channels, -1)
        out = out.permute(0, 2, 1)
        
#         out = self.maxpool(out)
#         out = out.view(out.size(0), -1)
        return out

class Decoder(nn.Module):
    def __init__(self, vocab_size, enc_out_channels=512, device=torch.device("cuda")):
        super().__init__()
        
        
        self.embd = nn.Embedding(vocab_size, EMB_SIZE)
        
        self.init_h = nn.Linear(enc_out_channels, MAX_LEN)  # linear layer to find initial hidden state
        self.init_c = nn.Linear(enc_out_channels, MAX_LEN)  # linear layer to find initial cell state
        self.lstm = nn.LSTM(EMB_SIZE, vocab_size, batch_first=True)  # input: token embeddings; output: per-token vocab scores
        self.decfc = nn.Linear(64, MAX_LEN)
        
        self.softmax = nn.Softmax(dim=-1)
        self.device = device
    
    def init_hidden_state(self, encoder_out):
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c
    
    def forward(self, encoder_out, inp_seqs):
#         print("inp_seq before emb =", inp_seqs.shape)
        emb = self.embd(inp_seqs)
#         print(f"emb = {emb.shape}, encoder_out = {encoder_out.shape}")
        lstm_out, _ = self.lstm(emb)
        out = lstm_out + encoder_out  # requires encoder positions == sequence length (MAX_LEN)
        return out
    
    def predict(self, encoder_out, tokenizer):
        bs = encoder_out.size(0)
        syn_inp_seqs = torch.tensor(tokenizer.vocab.ctoi(tokenizer.vocab.bos_token), device=encoder_out.device)
        syn_inp_seqs = torch.repeat_interleave(syn_inp_seqs, bs)
        syn_inp_seqs = syn_inp_seqs.view(bs, 1)
#         print("syn_inp_seqs before emb =", syn_inp_seqs.shape)
        
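        # Predict next tokens to start token, carrying the LSTM state between steps
        pred_emb = []
        state = None  # nn.LSTM initialises (h0, c0) to zeros when no state is passed
        for i in range(MAX_LEN):
            emb = self.embd(syn_inp_seqs)
            pred, state = self.lstm(emb, state)
            pred_emb.append(pred)
            syn_inp_seqs = pred.argmax(dim=-1)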
#         print("len =", len(pred_emb))
        pred_emb = torch.cat(pred_emb, dim=1)
#         print("pred_emb =", pred_emb.shape)
        out = pred_emb + encoder_out
        return out
            
class InChINet(pl.LightningModule):
    def __init__(self, vocab_size, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.encoder_net = Encoder(out_channels=vocab_size)
        self.decoder_net = Decoder(vocab_size)
        self.loss_fn = nn.CrossEntropyLoss()
        
    def forward(self, imgs, inp_seqs):
        encoder_out = self.encoder_net(imgs)
#         print("Encoder =", encoder_out.shape)
        pred_tokens = self.decoder_net(encoder_out, inp_seqs)
        return pred_tokens
    
    def training_step(self, train_batch, batch_idx):
        imgs, inp_seqs, attn_masks = train_batch
        output = self.forward(imgs, inp_seqs)
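        # NOTE: the target is the same sequence that is fed to the decoder (no one-step shift),
        # so the LSTM can learn to copy its input; teacher forcing usually feeds inp_seqs[:, :-1] and predicts inp_seqs[:, 1:]
        loss = self.loss_fn(output.permute(0,2,1).float(), inp_seqs)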
        # Logging to TensorBoard by default
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        return loss
    
    def training_epoch_end(self, outputs):
        for name,params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)
    
    def validation_step(self, val_batch, batch_idx):
        imgs, inp_seqs, attn_masks = val_batch
        output = self.predict(imgs, self.tokenizer)
        loss = self.loss_fn(output.permute(0,2,1).float(), inp_seqs)
        self.log('val_loss', loss, on_step=True, on_epoch=True, logger=True)
        
        lv_metric = self.calculate_lvdistance(output, inp_seqs)
        self.log('LvDistance', lv_metric, on_step=False, on_epoch=True, logger=True)  # a fixed step=1 would overwrite every value
        return loss
    
    def predict(self, imgs, tokenizer):
        encoder_out = self.encoder_net(imgs)
        pred_tokens = self.decoder_net.predict(encoder_out, tokenizer)
        return pred_tokens

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=LR)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100),
            'name': 'AnnealingLR'
        }
        return [optimizer], [lr_scheduler]
    
    def inference(self, imgs):
        output = self.predict(imgs, self.tokenizer)
        return self.postprocessing(output)
    
    def calculate_lvdistance(self, output, target):
        pred_seqs = self.postprocessing(output)
        batch_distance = np.mean([
            Levenshtein.distance(pred_seq, self.tokenizer.decode(inp_seq))
            for pred_seq, inp_seq in zip(pred_seqs, target)
        ])
        return batch_distance
    
    def postprocessing(self, output):
        final_preds = []
        pred_tokens = output.argmax(dim=-1)
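        for i in range(pred_tokens.size(0)): # iterate on each sample
            # NOTE: Tensor.unique returns sorted values, so tokens come out in vocab-index order, not prediction order
            pred = pred_tokens[i].unique(dim=-1).tolist()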
            pred = self.tokenizer.decode(pred)
            res = re.search(r'C', pred)
            if res:
                pred = pred[res.span()[0]:]
            final_preds.append(pred)
        return final_preds
    
# encoder_net = Encoder()
# decoder_net = Decoder(len(vocab))
# encoder_out = encoder_net(imgs)
# print("Encoder =", encoder_out.shape)
# print(inp_seqs.shape)
# pred_tokens = decoder_net(encoder_out, inp_seqs, tokenizer)
# print(pred_tokens.shape)
# pred_tokens = decoder_net.predict(encoder_out, tokenizer)
# pred_tokens.shape
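The `Decoder` adds `encoder_out` (one row per spatial position of the CNN feature map) to the LSTM output (one row per token), so the encoder must produce exactly `MAX_LEN` positions. Here is a quick arithmetic check, assuming the standard ResNet-18 downsampling path: the replaced stride-2 `conv1`, the stride-2 stem max-pool, and the stride-2 first convolutions of layer2 to layer4, followed by the `outfc` convolution. The helper names are hypothetical.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution or pooling along one axis (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


def encoder_positions(inp_size: int) -> int:
    """Number of spatial positions the second Encoder emits for a square input."""
    n = conv_out(inp_size, 7, 2, 3)     # conv1 (1 -> 64 channels)
    n = conv_out(n, 3, 2, 1)            # stem max-pool
    for _ in range(3):                  # first conv of layer2, layer3, layer4
        n = conv_out(n, 3, 2, 1)
    last_stride = 2 if inp_size == 256 else 1
    n = conv_out(n, 3, last_stride, 1)  # outfc conv
    return n * n


print(encoder_positions(128), encoder_positions(256))  # 16 16
```

Both supported input sizes give 16 positions, matching `MAX_LEN = 16`. Any other input size would need `MAX_LEN` or `last_stride` adjusted accordingly.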

# Training and Validation

In [None]:
%tensorboard --logdir {CHKPTDIR}

dm = ImgToInChIDataModule(tb_logger=tb_logger)
dm.prepare_data(verbose=True)
model = InChINet(dm.vocab_size, dm.tokenizer)
# Add network graph to tensorboard
# tb_logger.log_graph(model, [imgs[0].unsqueeze(0).to(model.device), inp_seqs[0].unsqueeze(0).to(model.device)])
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')

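# auto_lr_find only takes effect when trainer.tune(model, dm) is called; trainer.fit alone ignores it
trainer = pl.Trainer(gpus=1, auto_lr_find=True, max_epochs=EPOCHS, precision=PRECISION, profiler="simple", 
                     default_root_dir=CHKPTDIR, logger=tb_logger, callbacks=[lr_monitor])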

trainer.fit(model, dm)

Reusing TensorBoard on port 6006 (pid 106034), started 0:21:59 ago. (Use '!kill 106034' to kill it.)

Loading labels data... DONE!
Loading paths... DONE!
Loading vocab and tokenizer... DONE!


GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type             | Params
-------------------------------------------------
0 | encoder_net | Encoder          | 13.2 M
1 | decoder_net | Decoder          | 1.9 M 
2 | loss_fn     | CrossEntropyLoss | 0     
-------------------------------------------------
15.0 M    Trainable params
0         Non-trainable params
15.0 M    Total params
60.055    Total estimated model params size (MB)



In [None]:
# device = torch.device("cuda")
# model = model.to(device).eval()
# total_distance = []
# with torch.no_grad():
#     for (imgs, inp_seqs, attn_masks) in progress_bar(dm.val_dataloader()):
#         imgs = imgs.to(device)
#         pred_seqs = model.inference(imgs)
#         batch_distance = np.mean([
#             Levenshtein.distance(pred_seq, dm.tokenizer.decode(inp_seq))
#             for pred_seq, inp_seq in zip(pred_seqs, inp_seqs)
#         ])
#         print(batch_distance)
#         total_distance.append(batch_distance)  # append: += would try to iterate over a float
# np.nanmean(total_distance)