#### Imports and notebook setup

In [1]:
import os
import torch
import pytorch_lightning as pl
from torch_geometric.datasets import Planetoid
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning.pytorch import loggers as pl_loggers
import torch_geometric.transforms as T
import torch_geometric.data as geom_data
from torch.utils.tensorboard import SummaryWriter
%reload_ext autoreload
%autoreload 2

#### Set paths

In [2]:
# Every path below is anchored at the notebook's current working directory.
cwd = os.getcwd()
exp_name = "NC-AG-GCN"

# Root for TensorBoard / Lightning logs, and the raw/processed dataset location.
tb_logging_dir, dataset_dir = (
    os.path.join(cwd, "lightning_logs"),
    os.path.join(cwd, "dataset", "CUHKSZ_AcademicGraph"),
)

# Per-experiment log folder: <cwd>/lightning_logs/<exp_name>
exp_dir = os.path.join(tb_logging_dir, exp_name)

#### Set device

In [3]:
# "high" lets PyTorch use faster, lower-precision internal algorithms for
# float32 matrix multiplications where the backend supports them.
torch.set_float32_matmul_precision("high")

# Everything runs on the CPU. To use a GPU instead, select
# torch.device("cuda:0" if torch.cuda.is_available() else "cpu").
device = "cpu"

#### Download and load dataset

In [11]:
from utils.dataset.CUHKSZ_AcademicGraph import CUHKSZ_AcademicGraph

# Build the CUHK-SZ academic graph dataset under `dataset_dir`. Judging by the
# printed raw-data paths, it fetches/extracts the raw archive on first use.
# Flags (semantics presumed from their names — confirm in utils.dataset):
#   with_label=True  -> attach node class labels (needed for node classification)
#   with_title=False -> leave paper-title text out
AGDataset = CUHKSZ_AcademicGraph(
    dataset_dir,
    with_label=True,
    with_title=False,
)

d:\GitHub\GNN-Cora-CUHKSZAG\dataset\CUHKSZ_AcademicGraph\raw\CUHKSZ_AcademicGraph_Rawdata.zip
d:\GitHub\GNN-Cora-CUHKSZAG\dataset\CUHKSZ_AcademicGraph\raw\CUHKSZ_AcademicGraph-rawdata_released


In [28]:
from pytorch_lightning.loggers import TensorBoardLogger

from utils.model.NC_GCN import GCN

# Seed Python / NumPy / torch RNGs so weight init and dropout are reproducible.
SEED = 42
pl.seed_everything(SEED)

MAX_EPOCHS = 500
VAL_EVERY_N_EPOCHS = 5

# `patience` counts *validation checks* without improvement, not epochs.
# Since validation only runs every VAL_EVERY_N_EPOCHS epochs, patience=10 allows
# up to 10 * VAL_EVERY_N_EPOCHS training epochs without a val_loss improvement.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=10,
    verbose=False,
    mode="min",
)

# NOTE: the key spelling "DROUPOUT_RATE" (sic) is kept because GCN presumably
# looks it up under this exact name — confirm in utils.model.NC_GCN before renaming.
hparams = {
    "DROUPOUT_RATE": 0.3,
    "BATCH_SIZE": 64,
    "NUM_CLASSES": 8,
    "NUM_FEATURES": AGDataset.num_features,
    "LEARNING_RATE": 0.001,
}

# Use the logger from the same package as pl.Trainer / EarlyStopping rather than
# mixing in objects from the separate `lightning.pytorch` namespace.
tb_logger = TensorBoardLogger(tb_logging_dir, name=exp_name)

trainer = pl.Trainer(
    accelerator="cpu",
    max_epochs=MAX_EPOCHS,
    callbacks=[early_stop_callback],
    logger=tb_logger,
    check_val_every_n_epoch=VAL_EVERY_N_EPOCHS,
)

# <tb_logging_dir>/<exp_name>/version_<N>: the directory this run's logger writes
# to, so the custom writers' event files sit next to Lightning's own.
version_dir = tb_logger.log_dir
writer_acc = SummaryWriter(log_dir=version_dir)
writer_loss = SummaryWriter(log_dir=version_dir)

# Lightning's default ModelCheckpoint saves under <log_dir>/checkpoints.
checkpoint_dir = os.path.join(version_dir, "checkpoints")
print("Saving checkpoints to", checkpoint_dir)

GCNmodel = GCN(dataset=AGDataset, hparams=hparams,
               writer_acc=writer_acc, writer_loss=writer_loss).to(device)
try:
    trainer.fit(GCNmodel)
    test_results = trainer.test(GCNmodel)
finally:
    # Flush and close the event files so TensorBoard sees every logged point,
    # even if training is interrupted.
    writer_acc.close()
    writer_loss.close()

test_results

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type    | Params
----------------------------------
0 | conv1 | GCNConv | 12.3 K
1 | conv2 | GCNConv | 136   
----------------------------------
12.4 K    Trainable params
0         Non-trainable params
12.4 K    Total params
0.050     Total estimated model params size (MB)


Saving checkpoints to D:\GitHub\aml-project\GNN\wsy\GCN\lightning_logs\NC-AG-GCN\version_18\checkpoints


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.684421181678772, 'test_accuracy': 0.7558578848838806}]

#### Load the best model from its checkpoint

In [8]:
from utils.model.NC_GCN import GCN
from torch.utils.tensorboard import SummaryWriter

In [14]:
# Hyperparameters of the best recorded run (listed at the bottom of this cell).
# They are forwarded to GCN.__init__ by load_from_checkpoint. BATCH_SIZE is
# included in case the model's dataloaders read it — confirm in utils.model.NC_GCN.
hparams = {
    "DROUPOUT_RATE": 0.3,
    "BATCH_SIZE": 64,
    "NUM_CLASSES": 8,
    "NUM_FEATURES": AGDataset.num_features,
    "LEARNING_RATE": 0.001,
}

# Event files for evaluating the loaded model go to <exp_dir>/best.
version_dir = os.path.join(exp_dir, "best")
writer_acc = SummaryWriter(log_dir=version_dir)
writer_loss = SummaryWriter(log_dir=version_dir)

# NOTE(review): this folder is "NC_AG_GCN" (underscores) while exp_name is
# "NC-AG-GCN" (hyphens), so it is NOT the same directory as version_dir above.
# Kept as-is to match the existing on-disk layout — confirm and unify.
checkpoint_dir = os.path.join(tb_logging_dir, "NC_AG_GCN", "best", "checkpoints")

# Only consider Lightning checkpoint files; sort so the choice is deterministic
# (a single .ckpt is expected; with several, the alphabetically first is used).
ckpt_files = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(".ckpt"))
if not ckpt_files:
    raise FileNotFoundError(f"No .ckpt file found in {checkpoint_dir}")
checkpoint_file = os.path.join(checkpoint_dir, ckpt_files[0])

Loaded_model = GCN.load_from_checkpoint(
    checkpoint_file,
    dataset=AGDataset,
    hparams=hparams,
    writer_acc=writer_acc,
    writer_loss=writer_loss,
)

# Best recorded run:
#   BATCH_SIZE: 64, DROUPOUT_RATE: 0.3, LEARNING_RATE: 0.001,
#   NUM_CLASSES: 8, NUM_FEATURES: 768
#   test_accuracy: 0.7566, test_loss: 0.6618