In [None]:
from dldd import ClassificationModel
from pytorch_lightning import Trainer
from torch_geometric.data import DataLoader
from dldd.utils import TwoGraphData
import torch
import pickle
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger

In [None]:
## Load the pickled train / validation / test splits
# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
def _read_pickle(path):
    """Return the object deserialized from the pickle file at `path`."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


train = _read_pickle('data/train.pkl')
val = _read_pickle('data/val.pkl')
test = _read_pickle('data/test.pkl')

In [None]:
# Create the DataLoaders
BATCH_SIZE = 128
NUM_WORKERS = 1
# Attributes for which the loader is asked to track batch membership (passed as `follow_batch`)
FOLLOW_BATCH = ['protein_x', 'drug_x']
SHUFFLE = True  # applies to the training loader only

# Arguments shared by all three loaders
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    follow_batch=FOLLOW_BATCH,
)

# Only the training data is shuffled; validation and test loaders keep a fixed
# order so that evaluation is deterministic (Lightning warns when they are shuffled).
train_dl = DataLoader(train, shuffle=SHUFFLE, **_loader_kwargs)
val_dl = DataLoader(val, shuffle=False, **_loader_kwargs)
test_dl = DataLoader(test, shuffle=False, **_loader_kwargs)

In [None]:
EARLY_STOP_PATIENCE = 30  # Number of epochs without val_loss improvement before training is stopped
GRADIENT_CLIP_VAL = 20  # Gradient clipping limits the size of the gradients, preventing exploding updates
NUM_GPUS = 1 if torch.cuda.is_available() else 0  # Fall back to CPU when no CUDA device is present

model = ClassificationModel()

## You can change the name of the logger, then it will be in a different directory
logger = CSVLogger("logs", name="cold_target")

callbacks = [
    ## Save 3 best models (lowest val loss)
    ModelCheckpoint(monitor="val_loss", save_top_k=3, mode="min"),
    ## If val loss doesn't decrease for EARLY_STOP_PATIENCE epochs - stop training
    EarlyStopping(monitor="val_loss", patience=EARLY_STOP_PATIENCE, mode="min"),
]

trainer = Trainer(
    gpus=NUM_GPUS,
    callbacks=callbacks,
    logger=logger,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    stochastic_weight_avg=True,
    num_sanity_val_steps=0,  # Skip the validation sanity check that normally runs before training
)

In [None]:
## The loss is not printed out anymore, but now you can view it in the logs_parsing.ipynb
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)
## Evaluate on the held-out test set. `dataloaders` replaces the `test_dataloaders`
## keyword, which was deprecated in the same Lightning release that introduced
## `train_dataloaders` and was removed later.
trainer.test(model, dataloaders=test_dl)