# Graphstorm PyTorch Lightning Demonstration - Node Classification

In this notebook, we'll demonstrate how to use Graphstorm with [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable) for Node Classification.

---

## Setup 

Please follow the README.md in graphstorm-lightning/examples.

In [1]:
import yaml
import graphstorm as gs
import graphstorm_lightning as gsl
import pytorch_lightning as pl
import requests
from pathlib import Path

In [2]:
num_nodes = 1

---

## Data Preparation

In this notebook we'll create ACM graph dataset following [this guide](https://graphstorm.readthedocs.io/en/latest/tutorials/own-data.html).

In [3]:
acm_raw = Path("/tmp/acm_raw")
if not acm_raw.exists():
    acm_raw.mkdir(parents=True)
    
    # get dataset creation script
    url = "https://raw.githubusercontent.com/awslabs/graphstorm/main/examples/acm_data.py"
    acm_data = acm_raw / "acm_data.py"
    response = requests.get(url)
    assert response.status_code == 200
    with open(acm_data, "wb") as f:
        f.write(response.content)

    !python {acm_data} --output-path {acm_raw}

Namespace(download_path='/tmp/ACM.mat', dataset_name='acm', output_type='raw', output_path='/tmp/acm_raw')
Graph(num_nodes={'author': 17431, 'paper': 12499, 'subject': 73},
      num_edges={('author', 'writing', 'paper'): 37055, ('paper', 'cited', 'paper'): 30789, ('paper', 'citing', 'paper'): 30789, ('paper', 'is-about', 'subject'): 12499, ('paper', 'written-by', 'author'): 37055, ('subject', 'has', 'paper'): 12499},
      metagraph=[('author', 'paper', 'writing'), ('paper', 'paper', 'cited'), ('paper', 'paper', 'citing'), ('paper', 'subject', 'is-about'), ('paper', 'author', 'written-by'), ('subject', 'paper', 'has')])

 Number of classes: 14

 Paper node labels: torch.Size([12499])

 ('paper', 'citing', 'paper') edge labels:30789
Saving ACM data to /tmp/acm.dgl ......
/tmp/acm.dgl saved.
Saving ACM node text to /tmp/acm_text.pkl ......
/tmp/acm_text.pkl saved.
author nodes have: Index(['node_id', 'feat'], dtype='object') columns ......
paper nodes have: Index(['node_id', 'label', 'f

In [4]:
acm_gs = "/tmp/acm_gs"
if not Path(acm_gs).exists():
    !python -m graphstorm.gconstruct.construct_graph \
              --conf-file {acm_raw}/config.json \
              --output-dir {acm_gs} \
              --num-parts {num_nodes} \
              --graph-name acm

---

## Model Training

In [5]:
# Works in Jupyter, Colab and Kaggle!
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=2)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/root/.cache/pypoetry/virtualenvs/graphstorm-lightning-5me8nBHW-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard sup

In [6]:
config = yaml.safe_load(f"""
  gsf:
    basic:
      graph_name: acm
      part_config: /tmp/acm_nc_{num_nodes}p/acm.json
      model_encoder_type: rgcn
    gnn:
      fanout: "15,10"
      num_layers: 2
      hidden_size: 128
      use_mini_batch_infer: false
    input:
      restore_model_path: null
    output:
      save_model_path: null
      save_embed_path: null
    hyperparam:
      dropout: 0.5
      lr: 0.001
      num_epochs: 10
      batch_size: 1024
      wd_l2norm: 0
    rgcn:
      num_bases: -1
      use_self_loop: true
      sparse_optimizer_lr: 1e-2
      use_node_embeddings: false
    node_classification:
      node_feat_name:
        - paper:feat
        - author:feat
        - subject:feat
      target_ntype: paper
      label_field: label
      multilabel: false
      num_classes: 40
      eval_metric:
        - accuracy
""")

In [7]:
datamodule = gsl.datamodule.GSgnnNodeTrainDataModule(trainer=trainer, config=config, graph_data_uri=acm_gs)

In [8]:
model = gsl.module.GSgnnNodeModel(datamodule=datamodule, config=config)

In [9]:
trainer.fit(model=model, datamodule=datamodule)

/root/.cache/pypoetry/virtualenvs/graphstorm-lightning-5me8nBHW-py3.10/lib/python3.10/site-packages/dgl/distributed/dist_context.py:248: net_type is deprecated and will be removed in future release.
INFO:root:Start to load partition from /tmp/acm_nc_1p/part0/graph.dgl which is 5055161 bytes. It may take non-trivial time for large partition.
INFO:root:Finished loading partition from /tmp/acm_nc_1p/part0/graph.dgl.
INFO:root:Finished loading node data.
INFO:root:Finished loading edge data.
INFO:root:part 0, train: 9999, val: 1249, test: 1249
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Initialize the distributed services with graphbolt: False



  | Name  | Type           | Params
-----------------------------------------
0 | model | GSgnnNodeModel | 333 K 
-----------------------------------------
333 K     Trainable params
0         Non-trainable params
333 K     Total params
1.332     Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

INFO:root:[Rank 0] dist_inference: finishes 0 iterations.
INFO:root:[Rank 0] dist_inference: finishes 0 iterations.
/root/.cache/pypoetry/virtualenvs/graphstorm-lightning-5me8nBHW-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

INFO:root:[Rank 0] dist_inference: finishes 0 iterations.
INFO:root:[Rank 0] dist_inference: finishes 0 iterations.


Validation: |                                                                                                 …

INFO:root:[Rank 0] dist_inference: finishes 0 iterations.
INFO:root:[Rank 0] dist_inference: finishes 0 iterations.
`Trainer.fit` stopped: `max_epochs=2` reached.
