#### Import libraries

In [1]:
import os
import torch
import pytorch_lightning as pl
from torch_geometric.datasets import Planetoid
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning.pytorch import loggers as pl_loggers
import torch_geometric.transforms as T
import torch_geometric.data as geom_data
from torch.utils.tensorboard import SummaryWriter

%reload_ext autoreload
%autoreload 2

#### Set logging paths

In [2]:
# TensorBoard logs live under <current working directory>/lightning_logs,
# grouped into one sub-directory per experiment name.
cwd = os.getcwd()
exp_name = "NC-Cora-GraphSAGE"
tb_logging_dir, exp_dir = (
    os.path.join(cwd, "lightning_logs"),
    os.path.join(cwd, "lightning_logs", exp_name),
)

#### Set device

In [3]:
# All tensors and the model run on the CPU in this notebook.
device = torch.device("cpu")
# Allow reduced-precision float32 matmuls on backends that support them
# (e.g. TF32 on CUDA); this is only a precision hint.
torch.set_float32_matmul_precision(precision="high")

#### Cora dataset

In [4]:
# Cora citation graph from the Planetoid collection, cached under ./dataset/Cora.
# split="full" selects PyG's "full" train/val/test split (see the Planetoid docs);
# ToDevice moves each loaded graph onto the configured `device`.
dataset = Planetoid(
    root="dataset/Cora",
    name="Cora",
    split="full",
    transform=T.ToDevice(device),
)

In [5]:
from utils.model.NC_GraphSAGE import GraphSAGE

# Seed Python, NumPy and torch RNGs (and dataloader workers) before the model is
# built, so weight initialisation and mini-batch order are reproducible.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Stop training when val_loss stops improving. `patience` counts *validation
# checks*, not epochs: with check_val_every_n_epoch=5 below, training stops
# after 5 * 5 = 25 epochs without improvement.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=5,
    verbose=False,
    mode='min'
)
# Hyper-parameters handed to GraphSAGE. The key names must match what the model
# reads. NOTE(review): "DROUPOUT_RATE" looks misspelled; confirm against
# utils/model/NC_GraphSAGE.py before renaming it.
hparams = {"DROUPOUT_RATE": 0,
           "BATCH_SIZE": 64,
           "LEARNING_RATE": 0.001,
           "HIDDEN_DIM": [112, 16],
           "AGGREGATOR_TYPE": "mean"}

# Take the logger from the same package as pl.Trainer; mixing `lightning.pytorch`
# and `pytorch_lightning` classes in one Trainer is fragile.
tb_logger = pl.loggers.TensorBoardLogger(tb_logging_dir, name=exp_name)
trainer = pl.Trainer(accelerator=device.type,  # follow the device chosen in the config cell
                     max_epochs=500,
                     callbacks=[early_stop_callback],
                     check_val_every_n_epoch=5,
                     logger=tb_logger)

# <tb_logging_dir>/<exp_name>/version_<N>, as resolved by the logger itself.
# Lightning's default ModelCheckpoint writes into its "checkpoints" sub-folder.
version_dir = tb_logger.log_dir
checkpoint_dir = os.path.join(version_dir, "checkpoints")
print("Saving checkpoints to", checkpoint_dir)

# Extra TensorBoard writers consumed by the model, pointed at the same run
# directory so their curves appear alongside Lightning's own metrics.
writer_acc = SummaryWriter(log_dir=version_dir)
writer_loss = SummaryWriter(log_dir=version_dir)
SAGEmodel = GraphSAGE(dataset=dataset, hparams=hparams, input_dim=dataset.num_features,
                      writer_acc=writer_acc, writer_loss=writer_loss).to(device)
try:
    trainer.fit(SAGEmodel)
    test_results = trainer.test(SAGEmodel)
finally:
    # Flush pending events and release file handles even if fit/test fails.
    # torch's SummaryWriter recreates its file writer on the next add_* call,
    # so these objects can still be reused later (e.g. when loading a checkpoint).
    writer_acc.close()
    writer_loss.close()
test_results

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


Saving checkpoints to D:\GitHub\aml-project\GNN\wsy\GCN\lightning_logs\NC-Cora-GraphSAGE\version_8\checkpoints



  | Name  | Type              | Params
--------------------------------------------
0 | aggr  | MeanAggregation   | 0     
1 | model | Sequential_cd6ea3 | 324 K 
--------------------------------------------
324 K     Trainable params
0         Non-trainable params
324 K     Total params
1.299     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.43383315205574036, 'test_accuracy': 0.8600000739097595}]

#### Load a trained model from checkpoint

In [17]:
# Reload the GraphSAGE model trained in an earlier run (version CKPT_VERSION of
# this experiment). The path is built from `exp_dir` so it resolves to the same
# lightning_logs tree the training cell writes to, with portable separators.
CKPT_VERSION = 7
checkpoint_dir = os.path.join(exp_dir, f"version_{CKPT_VERSION}", "checkpoints")

# os.listdir() order is arbitrary and the folder may hold other files, so only
# consider .ckpt files and take the most recently written one.
ckpt_files = [
    os.path.join(checkpoint_dir, fname)
    for fname in os.listdir(checkpoint_dir)
    if fname.endswith(".ckpt")
]
if not ckpt_files:
    raise FileNotFoundError(f"No .ckpt file found in {checkpoint_dir}")
checkpoint_file = max(ckpt_files, key=os.path.getmtime)

# The extra keyword arguments are forwarded to GraphSAGE's constructor; Lightning
# also uses them to override hyper-parameters stored in the checkpoint.
Loaded_model = GraphSAGE.load_from_checkpoint(
    checkpoint_file,
    dataset=dataset,
    hparams=hparams,
    input_dim=dataset.num_features,
    writer_acc=writer_acc,
    writer_loss=writer_loss,
)
# Hyper-parameters recorded for the version_7 run, and its test metrics:
# AGGREGATOR_TYPE: mean
# BATCH_SIZE: 64
# DROUPOUT_RATE: 0
# HIDDEN_DIM:
# - 112
# - 16
# LEARNING_RATE: 0.001
# NUM_NEIGHBORS:
# - 10
# - 10
# test_accuracy = 0.865, test_loss = 0.4468