# GNN Modeling

## Data and Set up

In [1]:
# # set up code if running in colab or externel environment...

# from google.colab import drive
# drive.mount('/content/drive/')

# cd ../content/drive/MyDrive/ProjectX/

# %%capture
# !pip install pytorch_lightning;
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html;
# !pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html;
# !pip install torch-geometric;
# !pip install wandb;

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [4]:
import numpy as np
import pandas as pd

np.random.seed(314159) # set random seed

import torch
import pytorch_lightning as pl

from torch_geometric.data import Data

import wandb
wandb.init() # initialize W&B experiment tracking

[34m[1mwandb[0m: Currently logged in as: [33mawni00[0m (use `wandb login --relogin` to force relogin)


In [5]:
# load edge list
edge_list_path = 'data/edge_list.npy'
edge_list = torch.Tensor(np.load(edge_list_path).T).type(torch.int64) # read in format expected by pytorch geometric [2, n_edges]

# load protein-ID dictionary (need new ID system starting at index 0 for pytorch geometric)
protein_id_dict = np.load('data/protein_ids_dict.npy', allow_pickle=True).item() # maps my custom ID system to Ensembl IDs
protein_id_dict_inv = {Ensembl: id_ for id_, Ensembl in protein_id_dict.items()} # maps Ensembl IDs to my custom ID system

In [6]:
data_path = 'data/HPAnode_PPInetwork_labels_v3.csv' #NOTE: labels are generated from infomation in this dataset
node_dataset = pd.read_csv(data_path, index_col=0)

# map dataset
myID = node_dataset.index.map(protein_id_dict_inv).rename('myID')
node_dataset.insert(loc=0, column='myID', value=myID)
node_dataset = node_dataset.reset_index().set_index('myID')

In [7]:
# make sure dataset with myID is of correct order and format
node_dataset.sort_index(inplace=True) # should already be sorted, but just in case
assert((node_dataset.index.to_numpy() == np.arange(len(node_dataset))).all())

In [8]:
# create positives
label_name = 'my_label'

# find positives
pos_label_col = 'DisGenNet_thresh_pos' #FIXME: figure out meaning of columns and determing appropriate choice of positive labels
pos_labels = pd.array([1 if row[pos_label_col] else None for id_, row in node_dataset.iterrows()], dtype='Int32')
node_dataset[label_name] = pos_labels

# create negatives
def sample_negatives(PU_labels):
    '''randomly samples from the unlabeled samples'''

    # sample same # as positives
    num_pos = (PU_labels==1).sum()
    neg_inds = PU_labels[PU_labels.isna()].sample(num_pos).index

    # TODO: more sophisticated methods for sampling methods. (e.g.: use mutation rate, unsupervised learning, etc.)

    return neg_inds # returns ID's of negative samples

neg_label_inds = sample_negatives(node_dataset[label_name])
node_dataset[label_name].loc[neg_label_inds] = 0

# TODO: save this data for reproducibility (not now, but once this is finalized and fixed)

node_dataset[label_name].value_counts()

0    191
1    191
Name: my_label, dtype: Int64

In [9]:
label_col = label_name
node_dataset[label_col] = node_dataset[label_col].astype('Int32')

# TODO: decide whether or not to include network embedding features...
num_node_feats = 100
node_feat_cols = ['Tissue RNA - lung [NX]', 'Single Cell Type RNA - Mucus-secreting cells [NX]'] + [f'node_{i}' for i in range(num_node_feats)]

# get subset of node features features + labels
node_data = node_dataset[node_feat_cols + [label_col]]

X = torch.Tensor(node_data[node_feat_cols].to_numpy())#.type(torch.float64)

y = node_data[label_col].fillna(-1).astype('int') # fill NaN with -1 so that it can be converted to pytorch tensor
y = torch.Tensor(y).type(torch.int64)

# restrict to data with labels
node_data_labeled = node_data[node_data[label_col].notna()]
node_data_labeled

Unnamed: 0_level_0,Tissue RNA - lung [NX],Single Cell Type RNA - Mucus-secreting cells [NX],node_0,node_1,node_2,node_3,node_4,node_5,node_6,node_7,node_8,node_9,node_10,node_11,node_12,node_13,node_14,node_15,node_16,node_17,node_18,node_19,node_20,node_21,node_22,node_23,node_24,node_25,node_26,node_27,node_28,node_29,node_30,node_31,node_32,node_33,node_34,node_35,node_36,node_37,...,node_61,node_62,node_63,node_64,node_65,node_66,node_67,node_68,node_69,node_70,node_71,node_72,node_73,node_74,node_75,node_76,node_77,node_78,node_79,node_80,node_81,node_82,node_83,node_84,node_85,node_86,node_87,node_88,node_89,node_90,node_91,node_92,node_93,node_94,node_95,node_96,node_97,node_98,node_99,my_label
myID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1
69,-0.660176,-0.105079,2.097750,-1.520459,0.979743,0.102077,-0.064198,-0.095211,0.785779,0.626443,0.230450,-0.087699,0.070256,0.129571,-0.102219,0.038228,0.086179,0.388514,0.244929,0.177068,-0.124961,-0.170738,-0.162408,0.291555,0.107341,0.073149,0.064578,0.289876,-0.011252,0.341365,-0.412306,0.070327,-0.080478,0.155356,-0.005560,-0.006526,-0.064627,-0.025518,-0.028840,0.425118,...,-0.048727,-0.067543,-0.022368,-0.035638,0.044015,-0.064344,0.009336,-0.201068,-0.018671,0.053095,-0.099977,-0.006505,-0.089804,-0.077025,-0.042594,0.027789,0.018138,-0.015677,0.048587,-0.031364,-0.031258,0.067396,0.051931,-0.079723,0.023503,-0.041777,0.062489,0.107803,0.087040,0.057689,0.099556,0.097325,-0.352616,0.111681,-0.197486,-0.071428,0.550138,0.093125,-0.146437,0
80,0.150101,-0.069827,2.622879,0.092524,1.558535,-1.148822,0.606971,0.573626,0.106728,-0.357630,-0.088841,0.545125,-0.446062,-0.006933,-0.077243,-0.304660,-0.023302,-0.140114,-0.393294,0.438554,0.552702,-0.190259,-0.090898,-0.337363,0.191093,-0.416474,0.284222,-0.046262,0.218801,-0.236490,0.113643,0.193755,-0.092640,-0.150932,0.922676,0.277999,0.637841,0.407525,0.149785,0.374702,...,-0.077789,-0.028746,-0.009711,-0.108646,-0.229950,0.068370,-0.194103,-0.053455,0.112825,-0.008329,-0.044336,0.039903,-0.005522,0.099053,-0.115262,-0.286986,-0.363197,0.199954,0.148879,-0.022436,-0.110743,-0.264602,0.075054,0.380742,-0.328094,0.073103,0.328778,-0.275994,0.341445,0.403339,-0.327284,-0.087676,0.254183,-0.066311,-0.014220,-0.059492,0.095315,0.159288,-0.186821,1
110,-0.144076,-0.055349,1.143110,-0.375444,-0.293915,0.119900,-0.075660,-0.116266,-0.127803,-0.645784,-0.204914,-0.310386,-0.043875,0.101189,0.042456,0.099138,0.091941,1.354582,0.638918,0.382970,-0.225512,0.001574,-0.383027,0.353084,0.236215,0.064633,0.140875,0.203618,0.166814,0.008469,0.069700,0.011852,0.034762,0.003330,-0.043733,-0.131317,0.058932,-0.037719,-0.017670,-0.081419,...,-0.060173,-0.265954,0.044442,0.207561,-0.128969,-0.236510,-0.076689,-0.018329,0.002775,-0.149684,0.042980,-0.122852,0.088874,0.039923,-0.128713,-0.003320,0.021708,0.058523,-0.059747,-0.036001,-0.001651,0.030332,0.043805,0.000257,-0.002485,-0.023343,0.051042,-0.001805,0.060330,0.019302,0.001051,-0.058577,-0.081520,-0.095410,0.039419,0.000573,0.150117,-0.226045,0.051211,0
223,0.516533,-0.105079,1.114144,0.627415,-0.097759,1.059575,0.282671,0.423967,-0.340315,0.361438,-0.404729,-0.871845,-0.223907,-0.309378,1.009465,1.087018,0.122369,-0.242121,-0.075401,0.272586,0.127249,-0.212193,0.667621,-0.394318,-0.298565,0.257163,0.014282,0.854887,-0.015084,-0.205926,-0.053955,0.101142,-0.070306,-0.435964,0.091965,-0.132697,0.698036,0.105351,-0.006197,-0.349521,...,-0.148243,0.007647,-0.004145,0.019249,0.027515,-0.022858,-0.033339,0.129889,0.064537,-0.163029,-0.176990,-0.252601,0.042586,-0.188929,0.005801,-0.018098,0.002150,-0.165400,0.188480,0.328509,-0.431322,-0.010545,-0.199571,-0.020562,-0.057336,-0.028942,-0.049069,-0.042420,0.051243,-0.041040,-0.009328,-0.113221,-0.089241,-0.018135,-0.101621,-0.063740,-0.087179,-0.009082,0.027355,0
228,0.578465,-0.101302,1.677973,0.581332,-0.799185,-0.650435,0.386856,0.852724,0.152394,0.765213,0.269906,0.069149,-0.927797,0.261491,0.707229,-0.037015,0.089576,0.405424,0.073973,0.533180,-0.175869,-0.320044,0.374951,0.062490,-0.316758,0.416350,0.334756,0.659739,-0.289682,-0.188616,0.169840,0.161243,-0.052012,-0.067168,0.086249,0.038677,0.384264,0.389687,-0.029247,0.261561,...,0.638762,-0.173297,-0.242189,-0.000403,0.242779,-0.192429,-0.026964,0.499232,0.217707,-0.372652,0.422246,-0.518930,0.012992,-0.372772,-0.228256,-0.086302,0.089449,-0.278769,0.278735,0.269680,-0.609697,0.057039,-0.243661,-0.101867,0.153069,0.151782,-0.085245,-0.035737,0.105956,0.095485,-0.085160,0.043194,0.094463,-0.135246,0.052647,0.171408,0.179968,-0.194070,0.171541,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14334,-0.598244,-0.102718,1.523235,-0.099984,-0.281082,-0.048439,-0.502699,-0.269441,0.803980,-0.275743,-0.113198,-0.154879,-0.015396,0.006290,-0.146122,0.090833,0.002239,0.135563,0.010683,0.012410,-0.121702,-0.127543,-0.018346,-0.001107,0.123678,-0.082830,0.049454,0.242458,-0.127734,0.436509,-0.312913,0.017383,-0.116716,0.120866,-0.063963,0.003101,-0.020417,0.097747,0.102915,0.163512,...,-0.033254,0.723845,0.099078,-0.237178,0.130144,0.712608,0.109773,-0.242300,0.231657,0.177417,-0.023918,-0.445134,0.102184,0.141951,0.183487,-0.163422,0.177213,-0.220893,0.065263,-0.151679,0.053321,-0.076083,-0.105861,0.098925,-0.045893,-0.011595,-0.036980,-0.017011,0.066338,-0.032756,-0.008828,0.055448,0.039220,-0.028579,-0.077158,0.033021,-0.021508,-0.024827,-0.047601,0
14421,-0.598244,-0.105079,1.057383,-0.401744,-0.391620,0.129470,-0.026114,-0.302946,-0.056680,-0.614087,-0.186207,-0.119515,-0.111065,-0.096851,0.080481,0.042539,-0.036076,0.050940,-0.036410,-0.065572,-0.085201,-0.079086,-0.085306,-0.071278,0.050447,-0.123501,0.014909,0.025751,-0.147131,-0.014617,0.014513,0.014960,-0.015077,-0.013874,-0.058387,-0.033015,0.044213,0.001831,-0.022036,0.055225,...,0.035144,0.032227,0.091320,-0.030062,0.096764,-0.012176,0.072770,0.020230,0.029948,-0.053388,0.050049,-0.013599,0.027267,-0.008128,-0.082385,0.056074,-0.016766,-0.040439,-0.036650,0.064133,0.069550,0.011994,-0.000009,-0.021670,0.001861,-0.118348,-0.073638,-0.040428,0.087722,0.058590,-0.056147,0.078024,-0.046593,-0.103124,-0.125405,-0.000180,-0.054954,0.099468,-0.111486,0
14486,-0.655015,-0.104449,1.035287,-0.416219,-0.386367,0.113579,-0.004378,-0.294410,-0.050775,-0.610978,-0.165169,-0.095164,-0.078233,-0.100342,0.029460,0.067828,-0.057355,0.041356,-0.045776,-0.058906,-0.078640,-0.065485,-0.057730,-0.078738,0.086327,-0.095175,0.011306,0.036057,-0.088348,-0.019747,0.004530,0.015076,0.016280,0.019458,-0.042180,-0.033788,0.024354,-0.015004,-0.011662,-0.018582,...,0.000105,-0.005356,-0.008855,-0.007397,0.007733,-0.008760,-0.013004,0.000191,-0.012769,0.000302,0.000316,0.006804,-0.000919,-0.009635,-0.008778,-0.012939,-0.011474,-0.003295,-0.006659,-0.007175,0.001826,-0.012760,-0.004621,-0.003358,0.007523,0.000784,-0.001065,-0.001341,-0.000007,0.007326,0.004381,-0.003438,0.007235,-0.006817,-0.000525,0.002603,-0.005936,0.000313,-0.001213,0
14492,0.640397,-0.103505,0.437030,0.302569,0.424052,-0.277284,-0.332591,1.518940,-0.449915,-0.439091,-0.130798,0.189508,0.086972,0.002282,-0.226984,-0.149930,-0.222510,0.544899,0.129155,-0.104642,-0.114074,0.211953,0.084092,-0.238644,0.106295,-0.148933,-0.182047,-0.264109,0.168537,-0.045219,0.076041,-0.074551,0.040553,0.081457,0.020118,-0.039456,-0.027891,-0.012482,0.048164,0.008899,...,-0.105166,-0.001291,0.012321,-0.036130,0.062269,-0.050980,0.014405,-0.165106,0.002664,0.143078,0.205902,0.113253,-0.013829,-0.034615,0.044395,-0.008373,-0.004195,-0.038226,0.138579,0.077532,-0.131487,0.035756,-0.111372,0.015379,0.012109,0.030853,-0.039365,-0.006260,0.000165,-0.006203,-0.005886,-0.056655,-0.098032,0.097823,0.012923,-0.074887,-0.041497,-0.130827,-0.107135,1


In [10]:
from sklearn.model_selection import train_test_split

X_myIDs = node_data_labeled.index.to_numpy() # myIDs for nodes with labels for training/testing
labels = node_data_labeled[label_col].to_numpy() # for stratification

test_size = 0.2
val_size = 0.1 * (1/(1-test_size))

myIDs_train_val, myIDs_test = train_test_split(X_myIDs, test_size=test_size, shuffle=True, stratify=labels)

labels_train_val = node_data_labeled.loc[myIDs_train_val][label_col].to_numpy()
myIDs_train, myIDs_val = train_test_split(myIDs_train_val, test_size=val_size, shuffle=True, stratify=labels_train_val)

# NOTE: train-val-test split is shuffled and stratified
# TODO: look into any special consideration necessary for train-test splits on graph-based models

# create masks
n_nodes = len(node_data)
train_mask = np.zeros(n_nodes, dtype=bool)
train_mask[myIDs_train] = True
train_mask = torch.Tensor(train_mask).type(torch.bool)

val_mask = np.zeros(n_nodes, dtype=bool)
val_mask[myIDs_val] = True
val_mask = torch.Tensor(val_mask).type(torch.bool)

test_mask = np.zeros(n_nodes, dtype=bool)
test_mask[myIDs_test] = True
test_mask = torch.Tensor(test_mask).type(torch.bool)

In [11]:
data = Data(x=X, y=y, edge_index=edge_list)
num_classes = 2
num_features = X.shape[1]

data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask

## Graph Convolutional Neural Network

In [12]:
from torch_geometric.nn import GCNConv, GATConv
import torch.nn.functional as F

# define GNN architecture
class GNNModel(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_classes, hidden_dense, GNN_conv_layer=GCNConv, dropout_rate=0.1, **kwargs):
        """
        Args:
            num_features (int): Dimension of input features
            hidden_channels (List[int]): Dimension of hidden features
            num_classes (int): Dimension of the output.
            hidden_dense (int): number of units in hidden dense layer following convolutions.
            GNN_conv_layer: Class of the graph convolutional layer to use.
            dropout_rate (float): Dropout rate to apply throughout the network
            kwargs: Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()

        self.convs = []
        self.convs.append(GNN_conv_layer(in_channels=num_features, out_channels=hidden_channels[0], **kwargs)) # first GNN Conv layer

        # for c1, c2 in zip(hidden_channels[1:-1], hidden_channels[2:]): # middle layers
        #     self.convs.append(GNN_conv_layer(in_channels=c1, out_channels=c2, **kwargs))

        self.convs = torch.nn.ModuleList(self.convs)

        self.dense1 = torch.nn.Linear(hidden_channels[-1], hidden_dense)
        self.dense_out = torch.nn.Linear(hidden_dense, num_classes)

        self.dropout_rate = dropout_rate

    def forward(self, x, edge_index):
        """
        Args:
            x: node features
            edge_index: edge list
        """

        for i,conv in enumerate(self.convs):
            x = conv(x, edge_index)
            x = x.relu()
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        x = self.dense1(x)
        x = x.relu()
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.dense_out(x)

        return x

In [13]:
import pytorch_lightning as pl

# define Pytorch Lightning model
class LitGNN(pl.LightningModule):
    def __init__(self, model_name, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()

        self.model_name = model_name
        self.model = GNNModel(**model_kwargs)
        self.loss_module = torch.nn.CrossEntropyLoss()

        self.example_input_array = data

    def forward(self, data, mode="train"):
        x, edge_index = data.x, data.edge_index
        x = self.model(x, edge_index)

        # Only calculate the loss and acc on the nodes corresponding to the mask
        if mode == "train":
            mask = data.train_mask
        elif mode == "val":
            mask = data.val_mask
        elif mode == "test":
            mask = data.test_mask
        else:
            raise ValueError(f"Unknown forward mode: {mode}")

        #TODO: add other metrics like recall, precision, f1, etc...
        loss = self.loss_module(x[mask], data.y[mask])
        acc = (x[mask].argmax(dim=-1) == data.y[mask]).sum().float() / mask.sum()
        return x, loss, acc

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())#SGD(self.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-3)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, loss, acc = self.forward(batch, mode="train")
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, _, acc = self.forward(batch, mode="val")
        self.log("val_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return logits

    def validation_epoch_end(self, validation_step_outputs):
        # NOTE: can't save non-standard GNN model like this
        # TODO: look into how to save torch geometric models
        # dummy_input = data
        # model_filename = f'{self.model_name}_{str(self.global_step).zfill(5)}.onnx'
        # torch.onnx.export(self, dummy_input, model_filename)
        # wandb.save(model_filename)

        flattened_logits = torch.flatten(torch.cat(validation_step_outputs))
        self.logger.experiment.log({'val_logits': wandb.Histogram(flattened_logits.to('cpu')), 
                                    'global_step': self.global_step})

    def test_step(self, batch, batch_idx):
        x, _, acc = self.forward(batch, mode="test")
        self.log("test_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    # def test_epoch_end(self, test_step_outputs):
    #     # save model as onnx format
    #     pass

In [14]:
# import os
# notebook_name = 'modeling_gnn.ipynb'
# os.environ['WANDB_NOTEBOOK_NAME'] = notebook_name

In [15]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import torch_geometric.loader

import datetime

model_name = f'bigboi_gat_{str(datetime.datetime.today())[:10]}'

logger = WandbLogger(name=model_name, project="Project X", log_model="all")#, version=...)


AVAIL_GPUS = min(1, torch.cuda.device_count())

model = LitGNN(model_name, num_features=num_features, hidden_channels=[128, 256, 128], 
               num_classes=num_classes, hidden_dense=64, GNN_conv_layer=GCNConv, dropout_rate=0.1)

data_loader = torch_geometric.loader.DataLoader([data], batch_size=1)

MAX_EPOCHS=500
trainer = pl.Trainer(
        callbacks=[ModelCheckpoint(save_weights_only=False, mode="max", monitor="val_acc"), 
                   EarlyStopping(monitor="val_acc_epoch", patience=20, verbose=False, mode="max")],
        gpus=AVAIL_GPUS,
        max_epochs=MAX_EPOCHS,
        logger=logger,
        # progress_bar_refresh_rate=0,
    )

trainer.fit(model, data_loader, data_loader)
model = LitGNN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"

  | Name        | Type             | Params | In sizes                     | Out sizes 
---------------------------------------------------------------------------------------------
0 | model       | GNNModel         | 21.6 K | [[14552, 102], [2, 4214097]] | [14552, 2]
1 | loss_module | CrossEntropyLoss | 0      | [[266, 2], [266]]            | ?         
---------------------------------------------------------------------------------------------
21.6 K    Trainable params
0         Non-trainable params
21.6 K    Total params
0.086     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [16]:
# evaluate model

from sklearn.metrics import classification_report
logits, _, _ = model.forward(data.to(device='cpu'))

preds_train = logits[data.train_mask].argmax(dim=-1)
preds_test = logits[data.test_mask].argmax(dim=-1)

y_train = data.y[data.train_mask]
y_test = data.y[data.test_mask]

train_report = classification_report(y_train, preds_train, labels=[0,1], target_names=['negative', 'positive'])
test_report = classification_report(y_test, preds_test, labels=[0,1], target_names=['negative', 'positive'])

print('training metrics')
print(train_report)
print()
print('testing metrics')
print(test_report)

training metrics
              precision    recall  f1-score   support

    negative       0.76      0.77      0.77       133
    positive       0.77      0.76      0.77       133

    accuracy                           0.77       266
   macro avg       0.77      0.77      0.77       266
weighted avg       0.77      0.77      0.77       266


testing metrics
              precision    recall  f1-score   support

    negative       0.76      0.72      0.74        39
    positive       0.72      0.76      0.74        38

    accuracy                           0.74        77
   macro avg       0.74      0.74      0.74        77
weighted avg       0.74      0.74      0.74        77



In [17]:
wandb.finish()

VBox(children=(Label(value=' 3.98MB of 3.98MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…



0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_acc_epoch,▃▂▂▁▂▁▂▃▃▄▄▄▄▄▅▅▅▆▅▅▅▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇███
train_acc_step,▁
train_loss_epoch,████▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train_loss_step,▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
val_acc_epoch,▄▂▁▁▂▂▃▃▃▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████▇▇▇▇▇▇▇██
val_acc_step,▄▂▁▁▂▂▃▃▃▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████▇▇▇▇▇▇▇██

0,1
epoch,62.0
global_step,62.0
train_acc_epoch,0.78947
train_acc_step,0.7782
train_loss_epoch,0.43928
train_loss_step,0.47397
trainer/global_step,62.0
val_acc_epoch,0.82051
val_acc_step,0.82051
