#### Imports and notebook setup

In [1]:
import os
import torch
import pytorch_lightning as pl
from torch_geometric.datasets import Planetoid
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning.pytorch import loggers as pl_loggers
import torch_geometric.transforms as T
import torch_geometric.data as geom_data
from torch.utils.tensorboard import SummaryWriter
%reload_ext autoreload
%autoreload 2

#### Set Path

In [6]:
# Every path below is anchored at the notebook's current working directory.
cwd = os.getcwd()
exp_name = "NC-AG-GraphSAGE"

# Lightning/TensorBoard log root, and the dataset root handed to the dataset loader.
tb_logging_dir, dataset_dir = (
    os.path.join(cwd, "lightning_logs"),
    os.path.join(cwd, "dataset", "CUHKSZ_AcademicGraph"),
)

# Per-experiment folder under the log root (version_N sub-folders are created by the logger).
exp_dir = os.path.join(tb_logging_dir, exp_name)

#### Set Device

In [27]:
# Device selection: CPU by default. Set USE_CUDA = True to use the first GPU when one is
# available (falls back to CPU otherwise).
USE_CUDA = False
device = "cuda:0" if (USE_CUDA and torch.cuda.is_available()) else "cpu"

# Allow lower-precision internal float32 matmuls where the backend supports it (this mainly
# affects CUDA). Set unconditionally, as in the original run.
torch.set_float32_matmul_precision("high")

#### Download and load the dataset

In [25]:
# Project-local dataset wrapper (judging by its printed output, it prepares raw files under
# dataset_dir/raw on first use).
from utils.dataset.CUHKSZ_AcademicGraph import CUHKSZ_AcademicGraph

# Node features only (no paper titles), with class labels for node classification.
dataset_options = dict(with_title=False, with_label=True)
AGDataset = CUHKSZ_AcademicGraph(dataset_dir, **dataset_options)

D:\GitHub\aml-project\GNN\wsy\GCN\dataset\CUHKSZ_AcademicGraph\raw\CUHKSZ_AcademicGraph_Rawdata.zip
D:\GitHub\aml-project\GNN\wsy\GCN\dataset\CUHKSZ_AcademicGraph\raw\CUHKSZ_AcademicGraph-rawdata_released
test


#### Train and test the model

In [32]:
from pytorch_lightning.loggers import TensorBoardLogger
from utils.model.NC_GraphSAGE import GraphSAGE

# Stop training once val_loss has not improved for 10 consecutive validation runs.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=10,
    verbose=False,
    mode="min",
)

# Hyper-parameters passed to GraphSAGE.
# NOTE(review): "DROUPOUT_RATE" (sic) is presumably the exact key the model looks up;
# do not correct the spelling here without also changing the model.
hparams = {
    "DROUPOUT_RATE": 0,
    "BATCH_SIZE": 64,
    "LEARNING_RATE": 0.001,
    "HIDDEN_DIM": [112, 16],
    "AGGREGATOR_TYPE": "max",
}

# Take the logger from the same Lightning package as the Trainer and EarlyStopping
# (objects from `lightning.pytorch` and `pytorch_lightning` should not be mixed).
tb_logger = TensorBoardLogger(tb_logging_dir, name=exp_name)

# Derive the accelerator from the `device` setting instead of hard-coding it a second time.
accelerator = "gpu" if str(device).startswith("cuda") else "cpu"
trainer = pl.Trainer(
    accelerator=accelerator,
    devices=1,
    max_epochs=500,
    callbacks=[early_stop_callback],
    logger=tb_logger,
)

# lightning_logs/<exp_name>/version_<N>: the directory this run's logger writes to.
version_dir = tb_logger.log_dir
writer_acc = SummaryWriter(log_dir=version_dir)
writer_loss = SummaryWriter(log_dir=version_dir)

checkpoint_dir = os.path.join(version_dir, "checkpoints")
print("Saving checkpoints to", checkpoint_dir)

SAGEmodel = GraphSAGE(
    dataset=AGDataset,
    hparams=hparams,
    input_dim=AGDataset.num_features,
    writer_acc=writer_acc,
    writer_loss=writer_loss,
).to(device)

try:
    trainer.fit(SAGEmodel)
    test_results = trainer.test(SAGEmodel)
finally:
    # Flush and release the event files even if training is interrupted. A closed torch
    # SummaryWriter recreates its file writer on the next write, so later cells that reuse
    # these writers keep working.
    writer_acc.close()
    writer_loss.close()

test_results

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type              | Params
--------------------------------------------
0 | aggr  | MaxAggregation    | 0     
1 | model | Sequential_4d2598 | 175 K 
--------------------------------------------
175 K     Trainable params
0         Non-trainable params
175 K     Total params
0.703     Total estimated model params size (MB)


Saving checkpoints to D:\GitHub\aml-project\GNN\wsy\GCN\lightning_logs\NC-AG-GraphSAGE\version_19\checkpoints


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.6651288270950317, 'test_accuracy': 0.761148989200592}]

#### Load the best saved checkpoint

In [35]:
# Reload the saved "best" GraphSAGE run.
# NOTE(review): this folder is spelled "NC_AG_GraphSAGE" (underscores), whereas `exp_name`
# is "NC-AG-GraphSAGE" (hyphens), so it is not where the training cell writes. Presumably
# the best run was copied here by hand; confirm before relying on it.
checkpoint_dir = os.path.join("lightning_logs", "NC_AG_GraphSAGE", "best", "checkpoints")

# os.listdir() order is platform-dependent and may include non-checkpoint files, so pick the
# alphabetically first .ckpt file explicitly, and fail clearly if there is none.
ckpt_names = sorted(name for name in os.listdir(checkpoint_dir) if name.endswith(".ckpt"))
if not ckpt_names:
    raise FileNotFoundError(f"No .ckpt files found in {checkpoint_dir!r}")
if len(ckpt_names) > 1:
    print(f"Found {len(ckpt_names)} checkpoints in {checkpoint_dir}; loading {ckpt_names[0]}")
checkpoint_file = os.path.join(checkpoint_dir, ckpt_names[0])

# Extra keyword arguments are forwarded to the GraphSAGE constructor. `hparams`, `writer_acc`
# and `writer_loss` come from the training cell, so run that cell first.
Loaded_model = GraphSAGE.load_from_checkpoint(
    checkpoint_file,
    dataset=AGDataset,
    input_dim=AGDataset.num_features,
    hparams=hparams,
    writer_acc=writer_acc,
    writer_loss=writer_loss,
)

# Hyper-parameters of the saved "best" run:
#   AGGREGATOR_TYPE: max
#   BATCH_SIZE: 64
#   DROUPOUT_RATE: 0
#   HIDDEN_DIM: [112, 16]
#   LEARNING_RATE: 0.001
# Reported result: test_accuracy = 0.7611, test_loss = 0.6651