In [None]:
# Plotting 
#%pip install svgutils

import numpy as np
import os
from rdkit import Chem
from rdkit.Chem import rdChemReactions
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import IPythonConsole #Needed to show molecules
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions #Only needed if modifying defaults
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG

import torch
import torch.nn.functional as F

files = {e.split("-")[1].split(".")[0]: os.getenv("HOME")+"/data/USTPO/" 
         + e for e in os.listdir(os.getenv("HOME")+"/data/USTPO/")}
with open(files["valid"]) as f:
    lines = f.readlines()
lines = [line[:-1] for line in lines]


for i in range(30):
    rxn = rdChemReactions.ReactionFromSmarts
    reaction = rxn(lines[i])
    rs = list(reaction.GetReactants())
    ps = list(reaction.GetProducts())
    title = ""
    for r in rs:
        title += Chem.MolToSmiles(r) + "."

    title = title[:-1] + ">>"
    for p in ps:
        title += Chem.MolToSmiles(r) + "."
    title = title[:-1]

    d = Draw.MolDraw2DSVG(900, 400)
    opts = d.drawOptions()        
    opts.bgColor = None
    opts.clearBackground = False
    opts.bondLineWidth = 1
    d.DrawReaction(reaction)
    d.FinishDrawing()



    svg = d.GetDrawingText()
    s = svg.replace('svg:','')

    import svgutils.transform as sg
    fig = sg.fromstring(s)
    label = sg.TextElement(450-len(lines[i])*3.2, 80, lines[i], size=11, 
                           font='sans-serif', anchor='bottom', color='#000000')
    fig.append(label)
    display(SVG(fig.to_str()))

In [None]:
# Parsing
# Plotting 
#%pip install svgutils

import numpy as np
import os
from rdkit import Chem
from rdkit.Chem import rdChemReactions
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import IPythonConsole #Needed to show molecules
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions #Only needed if modifying defaults
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
import pickle

import torch
import torch.nn.functional as F
import time


#phase = "valid"
for phase in ["train", "valid", "test"]:
    files = {e.split("-")[1].split(".")[0]: os.getenv("HOME")+"/data/USTPO/" 
             + e for e in os.listdir(os.getenv("HOME")+"/data/USTPO/")}
    with open(files[phase]) as f:
        lines = f.readlines()
    lines = [line[:-1] for line in lines]


    chars = " ^#%()+-./0123456789=@ABCDEFGHIKLMNOPRSTVXYZ[\\]abcdefgilmnoprstuy$"
    char_to_ix = { ch:i for i,ch in enumerate(chars) }
    ix_to_char = { i:ch for i,ch in enumerate(chars) }
    max_len = 160
    vacant = " "*max_len

    lines_out = []
    lines_out_ = []
    for line in lines:
        rs_, ps_ = line.split(" >> ")
        if max(len(rs_),len(ps_)) > max_len: continue
        rs_, ps_ = rs_+vacant[len(rs_):], ps_+vacant[len(ps_):]
        rs, ps = torch.zeros(max_len),torch.zeros(max_len)
        for i in range(max_len): rs[i], ps[i] = char_to_ix[rs_[i]], char_to_ix[ps_[i]]
        lines_out.append({"rs": rs, "ps": ps})
        lines_out_.append({"rs": rs_, "ps": ps_})
        #print(rs_)
        #print(rs)
        #rs = F.one_hot(rs.to(dtype=torch.int64), num_classes=len(chars))
        #print(rs)
        #print(torch.argmax(rs, axis=1))
    
    print("finished parsing " + phase)

    with open("/home/arvid/data/USTPO_text/"+ phase +".pickle", 'wb') as handle:
        pickle.dump(lines_out_, handle, protocol=pickle.HIGHEST_PROTOCOL)



In [None]:
import pickle
data = {}
for phase in ["train", "valid", "test"]:
    with open("/home/arvid/data/USTPO_text/" + phase + ".pickle", 'rb') as handle:
        data[phase] = pickle.load(handle)

In [None]:
data

In [None]:
for key in data.keys(): print(key + " " + str(len(data[key])))

In [None]:
len(" ^#%()+-./0123456789=@ABCDEFGHIKLMNOPRSTVXYZ[\\]abcdefgilmnoprstuy$")

In [None]:
data["train"][10]

In [None]:
F.one_hot(data["train"][100]["rs"].to(device, dtype=torch.int64), num_classes=len(chars))

In [None]:
# Define ReactionDataset Class
# How to shuffle each epoch? Check later.


from torch.utils.data import Dataset
import torch.nn.functional as F

class ReactionDataset(Dataset):

    def __init__(self, data, split, rep=" ^#%()+-./0123456789=@ABCDEFGHIKLMNOPRSTVXYZ[\\]abcdefgilmnoprstuy$"):
        self.split = split
        self.data = data[self.split]
        self.rep = rep
        # Add augmentation methods here later
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        rs_smiles = self.data[index]["rs"]
        ps_smiles = self.data[index]["ps"]
        
        # Augment smiles here for train
        
        rs = 
        
        return {
            'rs': F.one_hot(self.data[index]["rs"].to(dtype=torch.int64), num_classes=len(self.rep)),
            'ps':  F.one_hot(self.data[index]["ps"].to(dtype=torch.int64), num_classes=len(self.rep))
        }

In [None]:
from torch.utils.data import DataLoader

datasets = {}
dataloaders = {}
for split in ['train', 'valid', 'test']:
    datasets[split] = ReactionDataset(data=data,
                                   split=split)

    dataloaders[split] = DataLoader(datasets[split],
                                    batch_size=32,
                                    shuffle=(split != 'test'),
                                    num_workers=4,
                                    pin_memory=False)# Was True before.

In [None]:
element = datasets["train"].__getitem__(index=100)
element["rs"].shape

In [None]:
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
src = torch.rand(10, 64, 512)
out = transformer_encoder(src)

In [None]:
from ...constants import *

In [1]:
with open("/home/arvid/data/USTPO_paper_5x/patents40k_x5MSShuf.csv") as f:
    lines = f.readlines()

In [35]:
with open("/home/arvid/data/USTPO_paper_5x/patents40k_x5MSShuf.csv") as f:
    lines = f.readlines()
import numpy as np
print(lines[0])
lines = lines[1:]
ix = np.arange(0,4000)
print(ix)
np.random.shuffle(ix)
print(ix)
ix = np.repeat(ix*10,10)
print(ix)
ix += np.tile(np.arange(0,10),4000)
print(ix)
split_ix = int((.75*40000)-(.75*40000)%10)
print(split_ix)
#data = {"train", "eval"}
#for line in lines[1:]:
data = {"train": [], "eval": []}
train_ix = ix[:split_ix]
val_ix = ix[split_ix:]
print(train_ix)
np.random.shuffle(train_ix)
np.random.shuffle(val_ix)
print(train_ix)
print(len(train_ix))

for e in train_ix:
    splits = lines[e].split(",")
    #print(splits)
    rs = splits[0]
    #print(splits[1])
    ps = splits[1][:-1]
    if max(len(rs),len(ps)) < 160:
        data["train"].append({"rs": rs, "ps": ps})

for e in val_ix:
    splits = lines[e].split(",")
    #print(splits)
    rs = splits[0]
    #print(splits[1])
    ps = splits[1][:-1]
    if max(len(rs),len(ps)) < 160:
        data["eval"].append({"rs": rs, "ps": ps})


input,target

[   0    1    2 ... 3997 3998 3999]
[3873  494 3565 ... 1005  770 3427]
[38730 38730 38730 ... 34270 34270 34270]
[38730 38731 38732 ... 34277 34278 34279]
30000
[38730 38731 38732 ...  4967  4968  4969]
[32110 17038  2751 ... 15860 28529 31564]
30000


In [31]:
data["eval"][0]

{'rs': '.C(Cl)(=O)C(C)=C.CC(C)(c1cc(cc(C(C)(C)C)c1O)CS)C',
 'ps': 'c1(C(C)(C)C)cc(cc(c1O)C(C)(C)C)CSC(C(C)=C)=O'}

In [36]:
import pickle
with open("/home/arvid/data/USTPO_paper_5x/USTPO_5x_parsed.pickle", 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)