# GNN Modeling

## Data and Setup

In [1]:
import numpy as np
import pandas as pd

import torch
import pytorch_lightning as pl

from torch_geometric.data import Data

# Seed every RNG the notebook relies on. NumPy's global RNG backs the pandas
# `.sample` negative sampling and the sklearn train/val/test splits (both are
# called without an explicit random_state). Torch's RNG backs GCN weight
# initialisation and dropout. Seeding only NumPy left training non-reproducible.
# NOTE(review): some CUDA scatter kernels used by PyG may still be
# nondeterministic on GPU.
SEED = 314159
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Load the edge list, transposed into the [2, n_edges] layout PyTorch Geometric expects.
# torch.tensor(..., dtype=int64) copies straight into a contiguous int64 tensor. The
# legacy torch.Tensor(...) constructor went through float32, which cannot represent
# node IDs above 2**24 exactly.
edge_list_path = 'data/edge_list.npy'
edge_list = torch.tensor(np.load(edge_list_path).T, dtype=torch.int64)

# Load the protein-ID dictionary (PyTorch Geometric needs node IDs starting at index 0).
# NOTE(review): allow_pickle=True unpickles the stored object, which can execute
# arbitrary code. Only load trusted files.
protein_id_dict = np.load('data/protein_ids_dict.npy', allow_pickle=True).item()  # custom myID -> Ensembl ID
protein_id_dict_inv = {ensembl_id: my_id for my_id, ensembl_id in protein_id_dict.items()}  # Ensembl ID -> custom myID

In [3]:
# NOTE: labels are generated from information in this dataset
data_path = 'data/HPAnode_PPInetwork_labels_v3.csv'
node_dataset = pd.read_csv(data_path, index_col=0)

# Re-key the rows by the custom 0-based myID used for graph nodes. reset_index()
# keeps the original Ensembl-ID index as an ordinary leading column.
myID = node_dataset.index.map(protein_id_dict_inv).rename('myID')
node_dataset = node_dataset.reset_index().set_index(myID)

In [4]:
# Row i of the feature matrix must be graph node i (the split masks index by myID),
# so the myID index must be exactly 0..n-1 in ascending order.
node_dataset = node_dataset.sort_index()  # should already be sorted, but just in case
assert np.array_equal(node_dataset.index.to_numpy(), np.arange(len(node_dataset))), \
    'myID index must be exactly 0..n-1 so that row i corresponds to graph node i'

In [5]:
# create positives
label_name = 'my_label'

# Positives: nodes whose `pos_label_col` value is truthy get label 1. All other nodes
# stay NA (unlabeled) until negatives are sampled. bool() applies the same truthiness
# test as the former per-row `if row[...]` loop, without building a Series per row.
pos_label_col = 'DisGenNet_thresh_pos'  # FIXME: figure out the meaning of the columns and determine an appropriate choice of positive labels
is_positive = node_dataset[pos_label_col].map(bool)
pos_labels = pd.array(np.where(is_positive, 1, None), dtype='Int32')
node_dataset[label_name] = pos_labels

# create negatives
def sample_negatives(PU_labels, random_state=None):
    '''Randomly draw negatives from the unlabeled entries of a positive-unlabeled label series.

    The number of draws equals the number of positives (label == 1), so the
    resulting labeled set is class-balanced.

    Args:
        PU_labels: pandas Series of nullable labels. 1 marks positives and NA marks
            unlabeled entries.
        random_state: forwarded to ``Series.sample``. The default None uses NumPy's
            global RNG, so results are reproducible under ``np.random.seed``.

    Returns:
        Index labels (IDs) of the sampled negative entries.

    Raises:
        ValueError: if there are more positives than unlabeled entries to draw from.
    '''
    num_pos = int((PU_labels == 1).sum())  # NA compares as NA and is skipped by sum()
    unlabeled = PU_labels[PU_labels.isna()]

    if num_pos > len(unlabeled):
        raise ValueError(
            f'cannot sample {num_pos} negatives from only {len(unlabeled)} unlabeled entries'
        )

    # TODO: more sophisticated sampling methods (e.g.: use mutation rate, unsupervised learning, etc.)

    return unlabeled.sample(num_pos, random_state=random_state).index

neg_label_inds = sample_negatives(node_dataset[label_name])

# Assign through one .loc call on the frame itself. The chained form
# `node_dataset[label_name].loc[...] = 0` writes into an intermediate Series and
# silently stops updating node_dataset under pandas Copy-on-Write.
node_dataset.loc[neg_label_inds, label_name] = 0

# TODO: save this data for reproducibility (not now, but once this is finalized and fixed)

node_dataset[label_name].value_counts()

0    191
1    191
Name: my_label, dtype: Int64

In [6]:
label_col = label_name
node_dataset[label_col] = node_dataset[label_col].astype('Int32')

# TODO: decide whether or not to include network embedding features...
num_node_feats = 100  # how many `node_{i}` columns to use as features
node_feat_cols = ['Tissue RNA - lung [NX]', 'Single Cell Type RNA - Mucus-secreting cells [NX]'] + [f'node_{i}' for i in range(num_node_feats)]

# get subset of node features + labels
node_data = node_dataset[node_feat_cols + [label_col]]

# Build tensors with explicit dtypes instead of the legacy torch.Tensor constructor,
# which always produced float32 (the labels then had to be cast back to int64).
X = torch.tensor(node_data[node_feat_cols].to_numpy(), dtype=torch.float32)

# Unlabeled nodes get -1 so the nullable column converts to an integer tensor. The
# loss and accuracy only read positions selected by the train/val/test masks.
y = torch.as_tensor(node_data[label_col].fillna(-1).to_numpy(dtype='int64'), dtype=torch.int64)

# restrict to data with labels
node_data_labeled = node_data[node_data[label_col].notna()]
node_data_labeled

Unnamed: 0_level_0,Tissue RNA - lung [NX],Single Cell Type RNA - Mucus-secreting cells [NX],node_0,node_1,node_2,node_3,node_4,node_5,node_6,node_7,...,node_91,node_92,node_93,node_94,node_95,node_96,node_97,node_98,node_99,my_label
myID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
69,-0.660176,-0.105079,2.097750,-1.520459,0.979743,0.102077,-0.064198,-0.095211,0.785779,0.626443,...,0.099556,0.097325,-0.352616,0.111681,-0.197486,-0.071428,0.550138,0.093125,-0.146437,0
80,0.150101,-0.069827,2.622879,0.092524,1.558535,-1.148822,0.606971,0.573626,0.106728,-0.357630,...,-0.327284,-0.087676,0.254183,-0.066311,-0.014220,-0.059492,0.095315,0.159288,-0.186821,1
110,-0.144076,-0.055349,1.143110,-0.375444,-0.293915,0.119900,-0.075660,-0.116266,-0.127803,-0.645784,...,0.001051,-0.058577,-0.081520,-0.095410,0.039419,0.000573,0.150117,-0.226045,0.051211,0
223,0.516533,-0.105079,1.114144,0.627415,-0.097759,1.059575,0.282671,0.423967,-0.340315,0.361438,...,-0.009328,-0.113221,-0.089241,-0.018135,-0.101621,-0.063740,-0.087179,-0.009082,0.027355,0
228,0.578465,-0.101302,1.677973,0.581332,-0.799185,-0.650435,0.386856,0.852724,0.152394,0.765213,...,-0.085160,0.043194,0.094463,-0.135246,0.052647,0.171408,0.179968,-0.194070,0.171541,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14334,-0.598244,-0.102718,1.523235,-0.099984,-0.281082,-0.048439,-0.502699,-0.269441,0.803980,-0.275743,...,-0.008828,0.055448,0.039220,-0.028579,-0.077158,0.033021,-0.021508,-0.024827,-0.047601,0
14421,-0.598244,-0.105079,1.057383,-0.401744,-0.391620,0.129470,-0.026114,-0.302946,-0.056680,-0.614087,...,-0.056147,0.078024,-0.046593,-0.103124,-0.125405,-0.000180,-0.054954,0.099468,-0.111486,0
14486,-0.655015,-0.104449,1.035287,-0.416219,-0.386367,0.113579,-0.004378,-0.294410,-0.050775,-0.610978,...,0.004381,-0.003438,0.007235,-0.006817,-0.000525,0.002603,-0.005936,0.000313,-0.001213,0
14492,0.640397,-0.103505,0.437030,0.302569,0.424052,-0.277284,-0.332591,1.518940,-0.449915,-0.439091,...,-0.005886,-0.056655,-0.098032,0.097823,0.012923,-0.074887,-0.041497,-0.130827,-0.107135,1


In [7]:
from sklearn.model_selection import train_test_split

X_myIDs = node_data_labeled.index.to_numpy()  # myIDs of labeled nodes (candidates for train/val/test)
labels = node_data_labeled[label_col].to_numpy()  # for stratification

test_size = 0.2
# val_size is a fraction of the remaining train+val pool. Rescaling by
# 1/(1 - test_size) makes validation 10% of all labeled nodes.
val_size = 0.1 * (1/(1-test_size))

myIDs_train_val, myIDs_test = train_test_split(X_myIDs, test_size=test_size, shuffle=True, stratify=labels)

labels_train_val = node_data_labeled.loc[myIDs_train_val, label_col].to_numpy()
myIDs_train, myIDs_val = train_test_split(myIDs_train_val, test_size=val_size, shuffle=True, stratify=labels_train_val)

# NOTE: train-val-test split is shuffled and stratified
# TODO: look into any special consideration necessary for train-test splits on graph-based models


def _ids_to_mask(node_ids, n_nodes):
    """Return a boolean tensor of length `n_nodes` that is True exactly at `node_ids`.

    myIDs are used directly as positions, which requires the myID index to be 0..n-1.
    """
    mask = torch.zeros(n_nodes, dtype=torch.bool)
    mask[torch.as_tensor(node_ids, dtype=torch.int64)] = True
    return mask


# create masks
n_nodes = len(node_data)
train_mask = _ids_to_mask(myIDs_train, n_nodes)
val_mask = _ids_to_mask(myIDs_val, n_nodes)
test_mask = _ids_to_mask(myIDs_test, n_nodes)

# Every labeled node must land in exactly one split.
assert not (train_mask & val_mask).any()
assert not (train_mask & test_mask).any()
assert not (val_mask & test_mask).any()
assert int(train_mask.sum() + val_mask.sum() + test_mask.sum()) == len(X_myIDs)

In [8]:
num_classes = 2  # binary node classification: negative (0) vs positive (1)
num_features = X.shape[1]

# Bundle features, labels, connectivity and the split masks into one PyG graph object.
data = Data(
    x=X,
    y=y,
    edge_index=edge_list,
    train_mask=train_mask,
    val_mask=val_mask,
    test_mask=test_mask,
)

## Graph Convolutional Network (GCN)

In [9]:
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

class GCN(torch.nn.Module):
    """Graph convolutional network for node classification.

    Architecture: ``num_layers`` GCNConv layers (ReLU + dropout after each), then a
    hidden Linear layer (ReLU + dropout), then a Linear output layer that produces
    one logit per class for every node.

    Args:
        hidden_channels: output width of every GCNConv layer and of the hidden Linear layer.
        num_layers: number of GCNConv layers. Values below 1 still build a single layer.
        dropout_rate: dropout probability, applied only in training mode.
        in_channels: number of input node features. When omitted, defaults to the
            notebook-level ``num_features`` (read at construction time).
        out_channels: number of output classes. When omitted, defaults to the
            notebook-level ``num_classes``.
    """

    def __init__(self, hidden_channels, num_layers, dropout_rate=0, in_channels=None, out_channels=None):
        super().__init__()
        # Fall back to the notebook globals so existing calls (and saved
        # hyperparameters) that omit these arguments keep working unchanged.
        if in_channels is None:
            in_channels = num_features
        if out_channels is None:
            out_channels = num_classes

        # First layer maps input features to hidden width; the remaining layers keep that width.
        self.convs = torch.nn.ModuleList([GCNConv(in_channels, hidden_channels)])
        self.convs.extend(GCNConv(hidden_channels, hidden_channels) for _ in range(num_layers - 1))

        self.dense1 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.dense_out = torch.nn.Linear(hidden_channels, out_channels)

        self.dropout_rate = dropout_rate

    def forward(self, x, edge_index):
        """Return per-node class logits.

        Args:
            x: node feature matrix, presumably (num_nodes, in_channels) per the PyG convention.
            edge_index: graph connectivity in PyG's [2, num_edges] format.
        """
        for conv in self.convs:
            x = conv(x, edge_index)
            x = x.relu()
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        x = self.dense1(x)
        x = x.relu()
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.dense_out(x)

        return x

In [10]:
import pytorch_lightning as pl

class LitGCN(pl.LightningModule):
    """PyTorch Lightning wrapper that trains ``GCN`` on a single full graph.

    Every stage (train/val/test) sees the same PyG ``Data`` object. The
    ``train_mask``/``val_mask``/``test_mask`` attributes on it select the nodes that
    contribute to the loss and accuracy.

    Args:
        model_name: run name, stored on the module and in its hyperparameters.
        **model_kwargs: forwarded unchanged to ``GCN`` (e.g. ``hidden_channels``,
            ``num_layers``, ``dropout_rate``).
    """

    def __init__(self, model_name, **model_kwargs):
        super().__init__()
        # Record init args so `load_from_checkpoint` can rebuild the module.
        self.save_hyperparameters()

        self.model_name = model_name
        self.model = GCN(**model_kwargs)
        self.loss_module = torch.nn.CrossEntropyLoss()

        # NOTE(review): reads the notebook-level `data` graph. Lightning uses it to
        # report input/output sizes in the model summary.
        self.example_input_array = data

    def forward(self, data, mode="train"):
        """Run the GCN over the whole graph and score the nodes selected by `mode`.

        Returns:
            (logits, loss, acc): logits for every node, plus the cross-entropy loss
            and accuracy over the nodes in the `mode` mask.

        Raises:
            ValueError: if `mode` is not "train", "val" or "test".
        """
        # Resolve the mask before the (expensive) forward pass so a bad mode fails fast.
        if mode == "train":
            mask = data.train_mask
        elif mode == "val":
            mask = data.val_mask
        elif mode == "test":
            mask = data.test_mask
        else:
            raise ValueError(f"Unknown forward mode: {mode}")

        x = self.model(data.x, data.edge_index)

        # Only calculate the loss and acc on the nodes corresponding to the mask.
        # TODO: add other metrics like recall, precision, f1, etc...
        loss = self.loss_module(x[mask], data.y[mask])
        acc = (x[mask].argmax(dim=-1) == data.y[mask]).sum().float() / mask.sum()
        return x, loss, acc

    def configure_optimizers(self):
        """Adam over all parameters with PyTorch's default hyperparameters (lr=1e-3).

        Alternative noted previously: SGD(lr=0.1, momentum=0.9, weight_decay=2e-3).
        """
        return torch.optim.Adam(self.parameters())

    def training_step(self, batch, batch_idx):
        x, loss, acc = self.forward(batch, mode="train")
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and return the validation nodes' logits."""
        logits, loss, acc = self.forward(batch, mode="val")
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log("val_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # forward() returns logits for every node in the graph. Keep only the
        # validation nodes so the epoch-end histogram describes validation data.
        return logits[batch.val_mask]

    def validation_epoch_end(self, validation_step_outputs):
        """Log a histogram of the validation-node logits to Weights & Biases.

        Skipped when no validation step produced output or when the trainer's
        logger is not a ``WandbLogger`` (e.g. the TensorBoard alternative).
        """
        # Local imports: the notebook imports these only in a later cell.
        import wandb
        from pytorch_lightning.loggers import WandbLogger

        if not validation_step_outputs or not isinstance(self.logger, WandbLogger):
            return

        # TODO: persist models in a portable format. Exporting with
        # torch.onnx.export(self, data, ...) was noted as not working for this PyG model.
        flattened_logits = torch.flatten(torch.cat(validation_step_outputs)).detach().cpu()
        self.logger.experiment.log({'val_logits': wandb.Histogram(flattened_logits.numpy()),
                                    'global_step': self.global_step})

    def test_step(self, batch, batch_idx):
        x, _, acc = self.forward(batch, mode="test")
        self.log("test_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)

In [11]:
import os
notebook_name = 'modeling_gnn.ipynb'
# Tell W&B which notebook launched the run (WANDB_NOTEBOOK_NAME environment variable).
os.environ.update(WANDB_NOTEBOOK_NAME=notebook_name)

In [12]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import wandb
import torch_geometric.loader

import datetime

# Date-stamped run name: my_gcn_YYYY-MM-DD
model_name = f'my_gcn_{datetime.date.today().isoformat()}'

# Alternative logger: TensorBoardLogger("tb_logs", name=model_name)
logger = WandbLogger(name=model_name, project="Project X", log_model="all")

# Use at most one GPU. Set AVAIL_GPUS = 0 to train on CPU when running out of VRAM.
AVAIL_GPUS = min(1, torch.cuda.device_count())

# Alternative (larger) configuration: hidden_channels=64, num_layers=10, dropout_rate=0
model = LitGCN(model_name, hidden_channels=16, num_layers=2)

# The whole graph is a single sample, so each epoch is one full-graph batch.
data_loader = torch_geometric.loader.DataLoader([data])

trainer = pl.Trainer(
    logger=logger,
    gpus=AVAIL_GPUS,
    max_epochs=500,
    # keep the checkpoint with the highest validation accuracy
    callbacks=[ModelCheckpoint(monitor="val_acc", mode="max", save_weights_only=False)],
)

# The same loader serves training and validation; the node masks decide which
# nodes each stage scores.
trainer.fit(model, data_loader, data_loader)

# Continue with the best checkpoint (by val_acc) rather than the final-epoch weights.
model = LitGCN.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mawni00[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name        | Type             | Params | In sizes                     | Out sizes 
---------------------------------------------------------------------------------------------
0 | model       | GCN              | 2.2 K  | [[14552, 102], [2, 4214097]] | [14552, 2]
1 | loss_module | CrossEntropyLoss | 0      | [[266, 2], [266]]            | ?         
---------------------------------------------------------------------------------------------
2.2 K     Trainable params
0         Non-trainable params
2.2 K     Total params
0.009     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/1 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                      

  rank_zero_warn(
  rank_zero_warn(


Epoch 499: 100%|██████████| 2/2 [00:00<00:00,  6.14it/s, loss=0.0873, v_num=iu9a, train_loss_step=0.082, train_acc_step=0.992, val_acc_step=0.718, val_acc_epoch=0.718, train_loss_epoch=0.0821, train_acc_epoch=0.989]


In [13]:
# evaluate

from sklearn.metrics import classification_report
# Inference: eval() disables dropout and no_grad() skips autograd bookkeeping, so the
# reported metrics are deterministic. Put model and data on the same device (CPU).
model = model.to('cpu')
model.eval()
with torch.no_grad():
    logits, _, _ = model.forward(data.to(device='cpu'))

preds_train = logits[data.train_mask].argmax(dim=-1)
preds_test = logits[data.test_mask].argmax(dim=-1)

y_train = data.y[data.train_mask]
y_test = data.y[data.test_mask]

report_kwargs = dict(labels=[0, 1], target_names=['negative', 'positive'])
train_report = classification_report(y_train, preds_train, **report_kwargs)
test_report = classification_report(y_test, preds_test, **report_kwargs)

print('training metrics')
print(train_report)
print()
print('testing metrics')
print(test_report)

training metrics
              precision    recall  f1-score   support

    negative       0.76      0.72      0.74       133
    positive       0.73      0.77      0.75       133

    accuracy                           0.74       266
   macro avg       0.74      0.74      0.74       266
weighted avg       0.74      0.74      0.74       266


testing metrics
              precision    recall  f1-score   support

    negative       0.79      0.69      0.74        39
    positive       0.72      0.82      0.77        38

    accuracy                           0.75        77
   macro avg       0.76      0.75      0.75        77
weighted avg       0.76      0.75      0.75        77

