In [50]:
# Imports

import os
import re

import matplotlib.pyplot as plt
import matplotlib

# PyTorch Lightning
import pytorch_lightning as pl
import seaborn as sns

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torchtext.vocab import build_vocab_from_iterator

from tqdm.notebook import tqdm
from IPython.display import set_matplotlib_formats
from pytorch_lightning.callbacks import ModelCheckpoint

from nltk.tokenize import RegexpTokenizer

# GPU-specific setup
if torch.cuda.is_available():
    # import cupy as np
    # import cudf as pd

    # Ensure that all operations are deterministic on GPU (if used) for reproducibility.
    # The attribute name has to be spelled exactly: assigning to a misspelled name
    # raises no error and simply leaves the real flag untouched.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# else:
import numpy as np
import pandas as pd

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Plotting
plt.set_cmap("cividis")
%matplotlib inline
set_matplotlib_formats("svg", "pdf")  # For export
matplotlib.rcParams["lines.linewidth"] = 2.0
sns.reset_orig()

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/")

# Setting the seed
pl.seed_everything(42)

print('CUDA:', torch.cuda.is_available())
print("Device:", device)

  set_matplotlib_formats("svg", "pdf")  # For export
Global seed set to 42


CUDA: True
Device: cuda:0


In [27]:
# Names of the data files that check_missing_files() looks for under ./data/
files = [
    'data.csv',
    'edrug3d.sdf',
    'qm9-1.sdf',
    'qm9-2.sdf',
    'qm9-3.sdf',
    'qm9-4.sdf',
    'qm9-5.sdf',
    'qm9-6.sdf',
    'qm9-7.sdf',
    'qm9-8.sdf',
]


def check_missing_files():
    """Report whether the local data directory is complete.

    Returns:
        bool: True when every name listed in ``files`` exists under
        ``./data/``; False as soon as one of them is not found.
    """
    return all(os.path.exists('./data/' + name) for name in files)


# Download and unpack the data archive only when at least one expected file
# is not found locally (check_missing_files() returns False in that case).
if not check_missing_files():
    # Save the shared archive in the working directory as data.zip.
    !wget -nc -O data.zip "https://hochschulebonnrheinsieg-my.sharepoint.com/:u:/g/personal/nico_piel_365h-brs_de1/ESuGOTn_IflEk7I5HkOFpbwBZKeOk9Qf2nL5JEcq2om6_Q?e=sHYsTk&download=1"
    # Extract into the working directory.
    # NOTE(review): presumably the archive contains a top-level data/ folder,
    # since the check above looks under ./data/ -- confirm.
    !unzip -u data.zip
    # Remove the archive once it has been extracted.
    !rm data.zip

In [28]:
class CustomDataset(data.Dataset):
    """Map-style dataset backed by a CSV file held in a pandas DataFrame.

    Each item is the pair formed by the first two column values of one row.
    """

    def __init__(self, path):
        """Read the CSV file into memory.

        Args:
            path: Location of the CSV file; handed straight to ``pd.read_csv``.
        """
        super().__init__()
        self.data = pd.read_csv(path)

    def __len__(self):
        """Return the number of rows in the loaded table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the first two column values of row ``idx`` as a tuple.

        Args:
            idx: Integer row position.

        Returns:
            tuple: ``(first_column_value, second_column_value)``.

        Raises:
            IndexError: If ``idx`` is out of bounds. Plain ``for`` iteration
                over the dataset relies on this to terminate.
        """
        row = self.data.iloc[idx]
        # Use .iloc for positional access: ``row[0]`` is a label lookup that only
        # falls back to positions on a string-labelled Series, and that fallback
        # is deprecated in recent pandas releases.
        return row.iloc[0], row.iloc[1]

In [29]:
# Load the CSV into a dataset and wrap it in a shuffling loader that yields
# batches of up to 64 samples.
dataset = CustomDataset(path='./data/data.csv')
dataloader = data.DataLoader(dataset=dataset, shuffle=True, batch_size=64)

In [30]:
# Per-sample lengths (the maximum is taken by the caller)
def get_max_sentence_length() -> np.ndarray:
    """Collect the length of the first field of every sample in ``dataset``.

    Despite the name, no maximum is computed here: the function returns all
    individual lengths and leaves the reduction to the caller.

    Returns:
        np.ndarray: One entry per sample, equal to ``len(sample[0])``.
    """
    lengths = []
    for sample in dataset:
        lengths.append(len(sample[0]))
    return np.array(lengths)

In [31]:
# Largest first-field length over the whole dataset.
max_length = get_max_sentence_length().max()

In [51]:
# SMILES regex by Schwaller et al.
smiles_regex = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
# Every run of word characters is one token. The earlier pattern listed this
# same alternative twice ("\w+|\w+"), which matches exactly the same text.
at_regex = r"\w+"

smiles_tokenizer = RegexpTokenizer(smiles_regex)
at_tokenizer = RegexpTokenizer(at_regex)

# Sanity check: show the first sample raw and tokenized for both columns.
# The sample is fetched once instead of once per print.
first_smiles, first_atom_types = dataset[0]

print(first_smiles)
print(smiles_tokenizer.tokenize(first_smiles))

print(first_atom_types)
print(at_tokenizer.tokenize(first_atom_types))

S(=O)(=O)(Nc1ncccc1)c1ccc(N)cc1
['S', '(', '=', 'O', ')', '(', '=', 'O', ')', '(', 'N', 'c', '1', 'n', 'c', 'c', 'c', 'c', '1', ')', 'c', '1', 'c', 'c', 'c', '(', 'N', ')', 'c', 'c', '1']
sy|o|o|nu|nb|nv|ca|ca|ca|ca|ca|ca|ca|ca|ca|ca|ca|hn|ha|ha|ha|ha|ha|ha|hn|hn|ha|h4
['sy', 'o', 'o', 'nu', 'nb', 'nv', 'ca', 'ca', 'ca', 'ca', 'ca', 'ca', 'ca', 'ca', 'ca', 'ca', 'ca', 'hn', 'ha', 'ha', 'ha', 'ha', 'ha', 'ha', 'hn', 'hn', 'ha', 'h4']


In [53]:
# Build vocabs
def _tokenized_column(tokenizer, column):
    """Yield the token list of field ``column`` for every sample in ``dataset``."""
    for sample in dataset:
        yield tokenizer.tokenize(sample[column])


# One vocabulary per field: field 0 goes through the SMILES tokenizer,
# field 1 through the atom-type tokenizer.
smiles_vocab = build_vocab_from_iterator(_tokenized_column(smiles_tokenizer, 0))
at_vocab = build_vocab_from_iterator(_tokenized_column(at_tokenizer, 1))

In [64]:
class ATTransformer(pl.LightningModule):
    """LightningModule skeleton that owns its training dataset.

    Model construction and the forward pass are still placeholders that raise
    ``NotImplementedError``. Because ``__init__`` calls ``_create_model``,
    instantiating the class raises until that method is implemented.
    """

    def __init__(self):
        """Load the training data, record hyperparameters and build the model."""
        super().__init__()
        # Path fixed from '.data/data.csv' (missing slash), which pointed at a
        # different, hidden '.data' directory instead of './data'.
        self.train_dataset = CustomDataset('./data/data.csv')
        self.save_hyperparameters()
        self._create_model()

    def _create_model(self):
        """Build the network layers. Not implemented yet."""
        raise NotImplementedError

    def forward(self, X):
        """Run the model on ``X``. Not implemented yet."""
        raise NotImplementedError

    def train_dataloader(self):
        """Return a shuffling DataLoader over the training dataset (batch size 64)."""
        return data.DataLoader(self.train_dataset, batch_size=64, shuffle=True)