# Example of Graph Neural Network

In [2]:
%load_ext autoreload
%autoreload 2

# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# import seaborn as sns
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
# import wandb

import warnings

warnings.filterwarnings("ignore")
sys.path.append("../../..")
device = "cuda" if torch.cuda.is_available() else "cpu"

from LightningModules.GNN.Models.hetero_gnn import HeteroGNN
from LightningModules.GNN.Models.interaction_gnn import InteractionGNN

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Setup

In [3]:
with open("hetero_gnn.yaml") as f:
    hparams = yaml.load(f, Loader=yaml.FullLoader)

In [4]:
model = HeteroGNN(hparams)

## Understand Model Weight Sharing

In [13]:
import torch.nn as nn
from LightningModules.GNN.Models.submodels.convolutions import HeteroConv, HomoConv

In [72]:
model = HeteroConv(hparams)

In [77]:
model = nn.ModuleList([HeteroConv(hparams)]*3)
model = model.requires_grad_(False)

In [64]:
model = nn.ModuleList([HeteroConv(hparams) for _ in range(3)])
model = model.requires_grad_(False)

In [84]:
list(model[0].parameters())

[Parameter containing:
 tensor([[ 0.0209,  0.0441, -0.0505,  ..., -0.0070,  0.0477, -0.0489],
         [ 0.0159, -0.0508,  0.0448,  ...,  0.0236,  0.0128, -0.0097],
         [ 0.0287,  0.0209,  0.0118,  ..., -0.0255, -0.0196, -0.0341],
         ...,
         [ 0.0023, -0.0302, -0.0504,  ...,  0.0330,  0.0226, -0.0363],
         [-0.0393,  0.0485, -0.0498,  ...,  0.0148,  0.0471,  0.0305],
         [-0.0412, -0.0324,  0.0218,  ..., -0.0027, -0.0064, -0.0030]]),
 Parameter containing:
 tensor([-4.4499e-02, -9.3512e-03, -3.5530e-02,  1.4636e-02, -1.9537e-02,
         -3.1224e-02,  2.9622e-02, -2.0002e-02,  4.5058e-02, -1.4680e-02,
          3.7018e-02,  8.2067e-03,  4.0025e-02,  6.5809e-03, -2.8690e-02,
          1.6582e-02,  4.4137e-02,  1.4662e-03,  5.0867e-02, -9.6911e-03,
          6.0124e-03,  4.7642e-02,  9.8401e-03,  1.3704e-02, -2.4090e-02,
          3.1259e-02,  1.9301e-02, -4.2506e-03, -1.1817e-02,  4.3157e-02,
          3.2680e-05, -2.5641e-02,  4.3549e-02,  1.1252e-02, -1.3870

In [69]:
#Print model parameters
print(list(model[0].parameters())[0])

Parameter containing:
tensor([[ 0.0473,  0.0144,  0.0301,  ..., -0.0237,  0.0015,  0.0263],
        [ 0.0160, -0.0176,  0.0002,  ..., -0.0199, -0.0170, -0.0260],
        [-0.0272,  0.0017, -0.0345,  ...,  0.0405,  0.0497,  0.0279],
        ...,
        [-0.0384,  0.0482,  0.0409,  ..., -0.0478,  0.0338,  0.0466],
        [ 0.0483,  0.0466,  0.0213,  ..., -0.0335,  0.0020, -0.0196],
        [-0.0365,  0.0045,  0.0444,  ..., -0.0182, -0.0235,  0.0295]])


In [67]:
#Print model parameters
list(model[1].parameters())[0][0,0] = torch.tensor(1.0)

In [68]:
#Print model parameters
list(model[1].parameters())[0][0,0] 

tensor(1.)

## Train GNN

In [5]:
logger = WandbLogger(
    project=hparams["project"], group="InitialTest", save_dir=hparams["artifacts"]
)
trainer = Trainer(
    gpus=1, max_epochs=hparams["max_epochs"], logger=logger
)  # , precision=16)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type       | Params
------------------------------------------------------
0 | node_encoders          | ModuleList | 33.5 K
1 | edge_encoders          | ModuleList | 65.9 K
2 | edge_network           | Sequential | 82.3 K
3 | node_network           | Sequential | 82.3 K
4 | output_edge_classifier | Sequential | 82.4 K
------------------------------------------------------
346 K     Trainable params
0         Non-trainable params
346 K     Total params
1.386     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]