In [5]:
!pip install torch-geometric
!pip install torchmetrics



In [6]:
import os
import numpy
import torch_geometric
import json
import torchmetrics
import torch

In [7]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

data_list = []
for file in os.listdir("./graphs"):
  with open("graphs/" + file, 'r') as f:
    j = json.load(f)
    x = torch.tensor(j["x"])
    edge_index = torch.tensor(j["edge_index"])
    y = torch.tensor(j["y"])
    data = Data(x=x, edge_index=edge_index, y=y)
    data_list.append(data)
train_set, validation_set, test_set = torch.utils.data.random_split(data_list, [0.6, 0.2, 0.2])
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=32, shuffle=True)
test_loader = DataLoader(test_set, batch_size=32, shuffle=True)
print(len(train_set))
print(len(validation_set))
print(len(test_set))

766
255
255


In [8]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GATConv
from torch_geometric.nn import global_mean_pool


class GraphConvModel(torch.nn.Module):
    def __init__(self):
        super(GraphConvModel, self).__init__()

        self.gat1 = GATConv(128, 256, add_self_loops=True)
        self.gat2 = GATConv(256, 256, add_self_loops=True)
        self.gat3 = GATConv(256, 256, add_self_loops=True)
        self.lin1 = Linear(256, 3)

    def forward(self, x, edge_index, batch):
        x = self.gat1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.gat2(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.gat3(x, edge_index)
        x = global_mean_pool(x, batch)
        x = self.lin1(x)
        return x

In [9]:
model = GraphConvModel()
print(model)

GraphConvModel(
  (gat1): GATConv(128, 256, heads=1)
  (gat2): GATConv(256, 256, heads=1)
  (gat3): GATConv(256, 256, heads=1)
  (lin1): Linear(in_features=256, out_features=3, bias=True)
)


In [30]:
from torchmetrics.classification import MulticlassAUROC
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.classification import MulticlassF1Score
from torchmetrics.classification import MulticlassRecall
from torchmetrics.classification import MulticlassSpecificity

def graph_test(model, loss_fn, loader):
      """
      model: pytorch GNN model
      loss_fn: loss function
      loader: DataLoader
      device: device used to bind the model and tensor
      return loss, weighted auroc, weighted accuracy
      """
      model.eval()
      total_loss = 0.0
      auroc_score = 0.0
      acc_score = 0.0

      pred = None
      target = None
      total_nodes_n = 0
      with torch.no_grad():
            for data in loader:
                  out = model(data.x, data.edge_index, data.batch)
                  loss = loss_fn(out, data.y)  # Compute the loss solely based on the training nodes.
                  total_loss += loss
                  if (pred is None):
                      pred = out
                  else:
                      pred = torch.cat((pred, out), 0)
                  if (target is None):
                      target = data.y
                  else:
                      target = torch.cat((target, data.y), 0)

      pred = pred
      target = target
      aucroc_metric = MulticlassAUROC(average=None, num_classes=pred.shape[1])
      acc_metric = MulticlassAccuracy(average=None, num_classes=pred.shape[1])
      f1_metric = MulticlassF1Score(average='weighted', num_classes=pred.shape[1])
      sensitivity_metric = MulticlassRecall(average='weighted', num_classes=pred.shape[1])
      specificity_metric = MulticlassSpecificity(average='weighted', num_classes=pred.shape[1])
      f1_metric.update(pred, target)
      f1_score = f1_metric.compute()
      sensitivity_metric.update(pred, target)
      sensitivity_score = sensitivity_metric.compute()
      specificity_metric.update(pred, target)
      specificity_score = specificity_metric.compute()
      aucroc_metric.update(pred, target)
      aucroc_classes = aucroc_metric.compute()
      acc_metric.update(pred, target)
      acc_classes = acc_metric.compute()
      freqs = torch.bincount(target, minlength=pred.shape[1])
      freqs = freqs / target.shape[0]
      auroc_score = torch.sum(aucroc_classes * freqs)
      acc_score = torch.sum(acc_classes * freqs)
      return total_loss / len(loader), auroc_score, acc_score, f1_score, sensitivity_score, specificity_score

In [31]:
learning_rate = 2e-3
epoch_num = 400
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.CrossEntropyLoss()

def train():
    model.train()
    total_loss = 0.0
    for data in train_loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = loss_fn(out, data.y)  # Compute the loss.
        #print(out)
        #print(data.y)
        total_loss += loss
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
    return total_loss / len(train_loader)

for epoch in range(epoch_num):
    train_loss = train()
    train_loss, train_auroc, train_acc, train_f1, train_sensitivity, train_specifity = graph_test(model, loss_fn, train_loader)
    validation_loss, validation_auroc, validation_acc, validation_f1, validation_sensitivity, validation_specifity = graph_test(model, loss_fn, validation_loader)
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.5f}, Train Auc: {train_auroc:.4f}, Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f}, Train Sensitivity: {train_sensitivity:.4f}, Train Specificity: {train_specifity:.4f}, Valid Auc: {validation_auroc:.4f}, Valid Acc: {validation_acc:.4f}, Valid F1: {validation_f1:.4f}, Valid Sensitivity: {validation_sensitivity:.4f}, Valid Specificity: {validation_specifity:.4f}')

Epoch: 000, Train Loss: 0.63828, Train Auc: 0.9089, Train Acc: 0.7507, Train F1: 0.7499, Train Sensitivity: 0.7507, Train Specificity: 0.8743, Valid Auc: 0.8077, Valid Acc: 0.6235, Valid F1: 0.6178, Valid Sensitivity: 0.6235, Valid Specificity: 0.8103
Epoch: 001, Train Loss: 0.63015, Train Auc: 0.9037, Train Acc: 0.7415, Train F1: 0.7411, Train Sensitivity: 0.7415, Train Specificity: 0.8724, Valid Auc: 0.8240, Valid Acc: 0.6588, Valid F1: 0.6584, Valid Sensitivity: 0.6588, Valid Specificity: 0.8298
Epoch: 002, Train Loss: 0.60473, Train Auc: 0.9115, Train Acc: 0.7232, Train F1: 0.7248, Train Sensitivity: 0.7232, Train Specificity: 0.8604, Valid Auc: 0.8312, Valid Acc: 0.6549, Valid F1: 0.6523, Valid Sensitivity: 0.6549, Valid Specificity: 0.8255
Epoch: 003, Train Loss: 0.53880, Train Auc: 0.9267, Train Acc: 0.7833, Train F1: 0.7825, Train Sensitivity: 0.7833, Train Specificity: 0.8912, Valid Auc: 0.8040, Valid Acc: 0.6588, Valid F1: 0.6576, Valid Sensitivity: 0.6588, Valid Specificity:

In [None]:
torch.save(model, 'model2.pt')

In [33]:
pred = None
target = None
for data in test_loader:
    out = model(data.x, data.edge_index, data.batch)
    #loss = loss_fn(out, data.y)  # Compute the loss solely based on the training nodes.
    #total_loss += loss
    if (pred is None):
        pred = out
    else:
        pred = torch.cat((pred, out), 0)
    if (target is None):
        target = data.y
    else:
        target = torch.cat((target, data.y), 0)
acc_metric = MulticlassAccuracy(average=None, num_classes=pred.shape[1])
acc_metric.update(pred, target)
acc_classes = acc_metric.compute()
freqs = torch.bincount(target, minlength=pred.shape[1])
freqs = freqs / target.shape[0]
acc_score = torch.sum(acc_classes * freqs)
print(acc_score)

tensor(0.7804)


In [34]:
test_loss, test_auroc, test_acc, test_f1, test_sensitivity, test_specifity = graph_test(model, loss_fn, test_loader)
print(test_loss)
print(test_auroc)
print(test_acc)
print(test_f1)
print(test_sensitivity)
print(test_specifity)

tensor(1.1326)
tensor(0.9208)
tensor(0.7804)
tensor(0.7795)
tensor(0.7804)
tensor(0.8893)
