In [None]:
import gc
from tqdm import tqdm
import pandas as pd

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

from mlm_pytorch.mlm_pytorch.mlm_pytorch import MLM
from x_transformers.x_transformers import TransformerWrapper, Encoder, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
from SmilesPE.pretokenizer import atomwise_tokenizer

In [None]:
# Free memory left over from earlier runs in this kernel.
# Collect unreachable Python objects first so that any tensors they hold are
# returned to the CUDA caching allocator, then release the cached blocks.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Select the 'medium' float32 matmul precision setting for the whole session.
# NOTE(review): per the PyTorch docs this trades matmul precision for speed on supporting hardware — confirm that is acceptable here.
torch.set_float32_matmul_precision('medium')

In [None]:
class Zinc(Dataset):
    """SMILES dataset yielding token-id sequences wrapped in ``<sop>``/``<eop>``.

    Reads ``{data_dir}/x001.csv`` (the ``smiles`` and ``set`` columns are used)
    and ``{data_dir}/vocab.txt`` (one token per line; the line index is the
    token id).

    Args:
        data_dir: Directory holding the CSV and the vocabulary file.
        split: Value of the ``set`` column selecting the rows to keep.
        to_gen: Number of samples reported by ``len()``. A value ``<= 0`` means
            "use every row of the split". Whenever the resulting length differs
            from the number of rows, ``__getitem__`` ignores the requested
            index and returns a randomly drawn row instead.

    Raises:
        ValueError: If ``to_gen > 0`` but the selected split has no rows, since
            random sampling from an empty split can never succeed.
        KeyError: If the vocabulary has no ``<pad>`` entry.
    """

    def __init__(self, data_dir: str='/scratch/arihanth.srikar', split: str='train', to_gen: int=-1):
        # Keep only the SMILES strings of the requested split; the frame itself
        # is a temporary and is not retained.
        df = pd.read_csv(f'{data_dir}/x001.csv')
        self.smiles = df.loc[df['set'] == split, 'smiles'].tolist()
        del df

        # Fail here with a clear message instead of later inside
        # torch.randint(0, 0, ...) when the first item is requested.
        if to_gen > 0 and not self.smiles:
            raise ValueError(
                f"split '{split}' has no rows in {data_dir}/x001.csv; "
                f"cannot draw {to_gen} random samples from it"
            )

        # Number of samples exposed per epoch.
        self.to_gen = to_gen if to_gen > 0 else len(self.smiles)

        # token_decoder maps id -> token, token_encoder maps token -> id.
        with open(f'{data_dir}/vocab.txt', 'r') as f:
            self.token_decoder = f.read().splitlines()
        self.token_encoder = {tok: i for i, tok in enumerate(self.token_decoder)}

        self.vocab_size = len(self.token_decoder)
        self.pad_token_id = self.token_encoder['<pad>']

    def __len__(self):
        """Return the number of samples exposed per epoch (``to_gen``)."""
        return self.to_gen

    def __getitem__(self, idx):
        """Return ``(token_ids, mask)`` for one SMILES string.

        ``token_ids`` is a 1-D integer tensor ``[<sop>, tok..., <eop>]`` and
        ``mask`` is a tensor of ones with the same shape. When the dataset
        length differs from the number of stored SMILES, ``idx`` is replaced
        by a uniformly random row index.

        Raises:
            KeyError: If a token produced by ``atomwise_tokenizer`` (or one of
                ``<sop>``/``<eop>``) is missing from the vocabulary.
        """
        # Sub-/over-sampling mode: draw a random row instead of using idx.
        if self.to_gen != len(self.smiles):
            idx = torch.randint(0, len(self.smiles), (1,)).item()

        body = [self.token_encoder[tok] for tok in atomwise_tokenizer(self.smiles[idx])]
        token_ids = torch.tensor(
            [self.token_encoder['<sop>']] + body + [self.token_encoder['<eop>']]
        )

        return token_ids, torch.ones_like(token_ids)

    def collate_fn(self, batch):
        """Pad a list of ``(token_ids, mask)`` pairs into one batch.

        Returns:
            ``(smiles, mask)`` where ``smiles`` is the batch-first tensor of
            token ids right-padded with ``<pad>`` and ``mask`` is a boolean
            tensor that is ``True`` wherever ``smiles`` is not ``<pad>``.
            The per-item masks are discarded; the mask is rebuilt from the
            padded ids.
        """
        smiles, _ = zip(*batch)
        smiles = torch.nn.utils.rnn.pad_sequence(
            smiles, batch_first=True, padding_value=self.pad_token_id
        )
        # The comparison already yields a bool tensor.
        mask = smiles != self.pad_token_id
        return smiles, mask

In [None]:
# Training split, exposing to_gen = 100 * 384 = 38,400 samples per pass over the dataset.
# NOTE(review): whenever to_gen differs from the split size, Zinc.__getitem__ draws random rows instead of using the requested index.
train_dataset = Zinc(split='train', to_gen=100*384)

In [None]:
# Batches of 128 sequences, padded per batch by the dataset's own collate function.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
    num_workers=8,
    prefetch_factor=4,
    pin_memory=True,
)

In [None]:
# Number of entries in the vocabulary file; used below as num_tokens for both transformers.
train_dataset.vocab_size

In [None]:
# Pull one batch from the loader as a smoke test of the dataset / collate pipeline.
# NOTE(review): the same statement is repeated in a later cell before these values are first used.
smiles, mask = next(iter(train_dataloader))

In [None]:
# Attention stack for the masked-language-model side: 6 layers, width 512, 8 heads.
encoder_layers = Encoder(dim=512, depth=6, heads=8, rel_pos_bias=True)

# Token embedding + output head around the stack; sequences up to 512 tokens.
encoder_transformer = TransformerWrapper(
    num_tokens=train_dataset.vocab_size,
    max_seq_len=512,
    attn_layers=encoder_layers,
)

In [None]:
# Attention stack for the generative side: same size as the encoder stack,
# with cross attention enabled so it can consume an external context.
decoder_layers = Decoder(
    dim=512,
    depth=6,
    heads=8,
    rel_pos_bias=True,
    cross_attend=True,
)

# Token embedding + output head around the stack; sequences up to 512 tokens.
decoder = TransformerWrapper(
    num_tokens=train_dataset.vocab_size,
    max_seq_len=512,
    attn_layers=decoder_layers,
)

In [None]:
# Wrap the encoder transformer for masked-language-model training.
# Every sequence produced by the dataset starts with <sop> and ends with <eop>;
# they are the counterparts of [cls]/[sep], so they are excluded from masking
# instead of spending part of the masking budget on them.
encoder = MLM(
    encoder_transformer,
    mask_token_id=train_dataset.token_encoder['<mask>'],  # token id reserved for masking
    pad_token_id=train_dataset.token_encoder['<pad>'],    # token id used for padding
    mask_prob=0.15,     # masking probability for masked language modeling
    replace_prob=0.90,  # ~10% of selected tokens are left unmasked but still included in the loss, as detailed in the paper
    mask_ignore_token_ids=[  # other tokens to exclude from masking
        train_dataset.token_encoder['<sop>'],
        train_dataset.token_encoder['<eop>'],
    ],
)

In [None]:
# Wrap the decoder for autoregressive training; the <pad> id is used both as
# the padding value and as the index ignored by the loss.
# NOTE: this cell rebinds `decoder`, so running it a second time wraps the
# already-wrapped model again.
pad_id = train_dataset.token_encoder['<pad>']
decoder = AutoregressiveWrapper(decoder, pad_value=pad_id, ignore_index=pad_id)

In [None]:
# Fetch one padded batch and display the shapes of the token ids and the padding mask.
smiles, mask = next(iter(train_dataloader))
smiles.shape, mask.shape

In [None]:
# Move the MLM model and the sample batch onto the GPU.
encoder = encoder.to('cuda')
smiles = smiles.to('cuda')
mask = mask.to('cuda')

In [None]:
# Smoke test: one MLM forward pass on the sample batch without gradient tracking.
# With return_logits_and_embeddings=True the call is unpacked into logits, embeddings and loss.
with torch.no_grad():
    logits, enc, loss = encoder(smiles, mask=mask, return_logits_and_embeddings=True)

In [None]:
# Display the shapes of the MLM logits and embeddings together with the scalar loss.
logits.shape, enc.shape, loss.item()

In [None]:
# Move the autoregressive decoder onto the GPU.
decoder = decoder.to('cuda')

In [None]:
# Smoke test: one decoder forward pass without gradient tracking, passing the
# encoder embeddings as `context` and the padding mask as `context_mask`.
with torch.no_grad():
    decoder_logits, decoder_loss = decoder(smiles, context=enc, context_mask=mask)

In [None]:
# Display the shape of the decoder logits together with the scalar loss.
decoder_logits.shape, decoder_loss.item()

In [None]:
# Joint training of the MLM encoder and the autoregressive decoder.
device = 'cuda'
# One optimizer over the parameters of both models.
optimizer = AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)

encoder.to(device)
decoder.to(device)
for epoch in range(10):
    # Running sums of the per-batch losses for this epoch.
    avg_encoder_loss, avg_decoder_loss = 0, 0
    with tqdm(train_dataloader) as pbar:
        pbar.set_description(f'Epoch {epoch+1}')
        for i, (smiles, mask) in enumerate(pbar):
            smiles, mask = smiles.to(device), mask.to(device)

            encoder_logits, enc, encoder_loss = encoder(smiles, mask=mask, return_logits_and_embeddings=True)
            decoder_logits, decoder_loss = decoder(smiles, context=enc, context_mask=mask)

            # Back-propagate both objectives in a single pass. The decoder loss
            # depends on `enc`, so its graph runs through the encoder forward
            # pass; previously only encoder_loss was back-propagated, which
            # left every decoder parameter without a gradient (never updated).
            (encoder_loss + decoder_loss).backward()

            optimizer.step()
            optimizer.zero_grad()

            # Read each scalar once; every .item() call synchronizes with the GPU.
            encoder_loss_value = encoder_loss.item()
            decoder_loss_value = decoder_loss.item()
            avg_encoder_loss += encoder_loss_value
            avg_decoder_loss += decoder_loss_value

            pbar.set_postfix({
                'encoder_loss': encoder_loss_value,
                'decoder_loss': decoder_loss_value,
                'avg_encoder_loss': avg_encoder_loss/(i+1),
                'avg_decoder_loss': avg_decoder_loss/(i+1)
                })