In [1]:

%load_ext autoreload
%autoreload 2

## Testing SmilesDataset: training a GNN regressor on BACE (pIC50)

In [47]:
from smiles_dataset import SmilesDataset
from smiles_lightning_data_module import SmilesDataModule
from lightning_model import LightningClassicGNN
import pytorch_lightning as pl
from torch_geometric.transforms import distance
from torch_geometric.loader import DataLoader
import os
import torch
# Ask PyTorch to be as deterministic as possible (paired with Trainer(deterministic=True) below).
torch.use_deterministic_algorithms(mode=True)
import numpy as np

from torch_geometric.data import Data
from typing import List, Callable
from functools import partial
from smiles_dataset import SmilesDataset
from torch_geometric.transforms import Compose, distance
from datasets import BaceDataset

# autoreload is already enabled (`%load_ext autoreload` / `%autoreload 2`) in the first cell;
# loading it again here only printed "The autoreload extension is already loaded".

def filter_target(target_names: List[str], target: str) -> Callable[[Data], Data]:
    """Build a SmilesDataset transform that keeps only one target column in each graph's ``y``.

    The other target columns are filtered out at runtime (when the transform is
    applied), so one processed dataset can serve several single-target tasks.

    Example: for BACE, ``target_names=['Class', 'pIC50']``; use ``target='Class'``
    to train a classifier, or ``target='pIC50'`` to train a regressor.

    Args:
        target_names: names of the target columns of ``graph.y``, in column order.
        target: name of the single column to keep.

    Returns:
        A callable mapping a ``Data`` graph to a new ``Data`` whose ``y`` holds only
        the selected column (``filter_target_with_idx`` with the index bound).

    Raises:
        ValueError: if ``target`` is not one of ``target_names``.
    """
    if target not in target_names:
        # list.index would also raise ValueError, but without telling the user which names are valid.
        raise ValueError(f"Unknown target {target!r}; expected one of {list(target_names)}")
    target_idx = target_names.index(target)

    return partial(filter_target_with_idx, target_idx=target_idx)

def filter_target_with_idx(graph:Data, target_idx:int) -> Data:
    """Return a new ``Data`` graph whose ``y`` keeps only column ``target_idx``.

    The slice ``target_idx:target_idx + 1`` keeps a column dimension on ``y``
    (assumes ``graph.y`` is 2-D, i.e. (..., num_targets) — TODO confirm).
    Only x, edge_index, edge_attr, pos, y, z, name and idx are copied into the
    new graph; any other attribute of ``graph`` is not carried over.
    """
    copied = {attr: getattr(graph, attr) for attr in ("x", "edge_index", "edge_attr", "pos")}
    copied["y"] = graph.y[:, target_idx:target_idx + 1]
    copied.update((attr, getattr(graph, attr)) for attr in ("z", "name", "idx"))
    return Data(**copied)


# ---- configuration ----
root = "../data/bace"
filename = "bace.csv"
seed = 42
regression_target = "pIC50"  # column of y to train on; must be one of BaceDataset's target names
num_epochs = 1

# Let PyTorch Lightning seed the global RNGs (workers=True also seeds dataloader workers).
pl.seed_everything(seed=seed, workers=True)

# To force SmilesDataset to re-process the raw CSV, delete the cached
# processed directory first, e.g. `!rm -rf {root}/processed`.

# BaceDataset returns (data, target_names); only the target names are needed here.
data, target_names = BaceDataset(root=root)
del data
assert regression_target in target_names, f"{regression_target!r} not in {list(target_names)}"

# Keep only the chosen target in each graph's y, then add edge distances (PyG Distance transform).
transforms = Compose([filter_target(target_names=target_names, target=regression_target), distance.Distance()])
dataset = SmilesDataset(root=root, filename=filename, transform=transforms)

# Wrap the torch dataset in a LightningDataModule so the train/val/test splits are the same for a given seed.
data_module = SmilesDataModule(dataset=dataset, seed=seed)

num_node_features = data_module.num_node_features
num_edge_features = data_module.num_edge_features

# Single scalar regression target, hence classification=False and output_dim=1.
gnn_model = LightningClassicGNN(
    classification=False,
    output_dim=1,
    num_node_features=num_node_features,
    num_edge_features=num_edge_features,
)

# Trainer notes:
# - default_root_dir: where logs and checkpoints are written (TensorBoardLogger by default;
#   e.g. `from pytorch_lightning import loggers; logger=loggers.WandbLogger()` to use W&B).
# - deterministic=True: complements torch.use_deterministic_algorithms set in the imports cell.
# - auto_lr_find=True: lets `trainer.tune(...)` search for a learning rate.
# - precision="bf16": bfloat16 mixed precision to reduce the memory footprint.
# - Other useful knobs: limit_train_batches (quick debugging), accumulate_grad_batches
#   (emulate a larger batch size), plugins (e.g. custom cluster environments).
trainer = pl.Trainer(
    deterministic=True,
    auto_lr_find=True,
    default_root_dir=os.getcwd(),
    precision="bf16",
    max_epochs=num_epochs,
)



Global seed set to 42
Global seed set to 42
Using bfloat16 Automatic Mixed Precision (AMP)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [48]:
# Run Lightning's tuner: with auto_lr_find=True on the Trainer, this searches a learning rate for gnn_model.
trainer.tune(model=gnn_model, datamodule=data_module)

In [5]:

# Train the model; to resume from a saved checkpoint instead, pass
# ckpt_path="some/path/to/my_checkpoint.ckpt" to trainer.fit.
trainer.fit(model=gnn_model, datamodule=data_module)

(5, 5)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
