In [1]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.5.1-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.5.1


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import torch
chunk1 = torch.load('/content/drive/MyDrive/SCI_data/first.pt')

In [4]:
chunk2 = torch.load('/content/drive/MyDrive/SCI_data/second (1).pt')

In [5]:
# chunk3 = torch.load('/content/drive/MyDrive/SCI_data/third.pt')

In [6]:
chunk1+=chunk2
# chunk1+=chunk3

In [7]:
len(chunk1)

18000

In [8]:
from torch.utils.data import Dataset

class GraphDataset(Dataset):
  def __init__(self, data, preprocess):
    self.data = data
    self.preprocess = preprocess

  def __len__(self,):
    return len(self.data)

  def __getitem__(self, index):
    res = self.data[index]
    for p in self.preprocess:
      res = p(res)
    return res


# dataset = GraphDataset(chunk1, [drop_self_edges, drop_nodes_with_no_edges])

In [9]:
from sklearn.model_selection import train_test_split
rand_seed = 42
X_train, X_test = train_test_split(chunk1, test_size=0.1, random_state = rand_seed)
X_train, X_val = train_test_split(X_train, test_size=0.1, random_state = rand_seed)
print(len(X_train), len(X_val), len(X_val))

14580 1620 1620


In [10]:
device = 'cuda'

In [11]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(X_train, batch_size=32, shuffle=True)
val_loader = DataLoader(X_val, batch_size=32, shuffle=False)
test_loader = DataLoader(X_test, batch_size=32, shuffle=False)

In [12]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.3.1-py3-none-any.whl (840 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/840.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.0/840.4 kB[0m [31m8.7 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m839.7/840.4 kB[0m [31m13.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.4/840.4 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.1-py3-none-any.whl (24 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.10.0->torchmetrics)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m24.9 MB/s[0m eta [36m0:00:00

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.nn.conv import GraphConv
from torch_geometric.utils import to_undirected
from torch_geometric.data import DataLoader
from torchmetrics.classification import BinaryAUROC
auroc = BinaryAUROC()


class Network(nn.Module):
    def __init__(self, c_in, c_hidden, c_out, p=0.3):
        super().__init__()
        torch.manual_seed(123)
        self.conv1 = GraphConv(c_in, c_hidden)
        self.conv2 = GraphConv(c_hidden,3*c_hidden)

        self.conv3 = GraphConv(3*c_hidden, c_hidden)

        # self.pool = SAGPooling(c_hidden)

        self.lin1 = nn.Linear(c_hidden, 4*c_out)
        self.lin2 = nn.Linear(4*c_out, c_out)
        self.p = p

    def forward(self, x, edge_index, batch, is_train):
        x = self.conv1(x, edge_index)
        x = x.relu()

        x = self.conv2(x, edge_index)
        x = x.relu()

        x = self.conv3(x, edge_index)
        x = x.relu()

        x = global_mean_pool(x, batch)

        # classifier

        x = F.dropout(x, p=self.p, training=is_train)
        x = self.lin1(x)

        x = F.dropout(x, p=self.p, training=is_train)
        x = self.lin2(x)

        return x


def evaluate(loader):
    model.eval()
    total_loss = 0.0
    correct = 0
    total_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in loader:
            batch.to(device)
            # print(batch.edge_index)
            pred = model(batch.x.float(), batch.edge_index, batch.batch, False)
            target = F.one_hot(batch.y, 2).float()
            loss = criterion(pred, target)
            total_loss += loss.item()

            # Calculate accuracy
            pred_labels = torch.softmax(pred, -1).argmax(dim=-1)
            correct += (pred_labels == batch.y).sum().item()
            total_samples += len(batch.y)
            all_labels.append(batch.y)
            all_preds.append(pred_labels)

    pred = all_preds[0]
    label = all_labels[0]

    for p, l in zip(all_preds[1:], all_labels[1:]):
      pred = torch.cat([pred, p])
      label = torch.cat([label, l])

    return total_loss / len(loader), correct / total_samples, auroc(pred, label)

# Training loop with validation
num_epochs = 50
best_auroc = 0
model = Network(c_in=5, c_hidden=64, c_out=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.BCEWithLogitsLoss()

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for idx, batch in enumerate(train_loader):

        batch = batch.to(device)

        pred = model(batch.x.float(), batch.edge_index, batch.batch, True)
        target = F.one_hot(batch.y, 2).float()
        loss = criterion(pred, target)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_train_loss = epoch_loss / len(train_loader)
    avg_val_loss, val_accuracy, val_auroc = evaluate(val_loader)
    if val_auroc> best_auroc:
      best_auroc = val_auroc
      best_epoch = epoch
      torch.save(model.state_dict(), 'best_model_gnn.pth')
    print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Val AUROC: {val_auroc:.4f}')

Epoch 1/50, Train Loss: 85.6892, Val Loss: 0.6639, Val Accuracy: 0.6093, Val AUROC: 0.6147
Epoch 2/50, Train Loss: 0.6774, Val Loss: 0.6755, Val Accuracy: 0.6722, Val AUROC: 0.6714
Epoch 3/50, Train Loss: 0.6538, Val Loss: 0.6288, Val Accuracy: 0.6870, Val AUROC: 0.6872
Epoch 4/50, Train Loss: 0.6482, Val Loss: 0.6108, Val Accuracy: 0.6907, Val AUROC: 0.6909
Epoch 5/50, Train Loss: 0.6276, Val Loss: 0.6154, Val Accuracy: 0.6747, Val AUROC: 0.6724
Epoch 6/50, Train Loss: 0.6304, Val Loss: 0.6075, Val Accuracy: 0.6833, Val AUROC: 0.6831
Epoch 7/50, Train Loss: 0.6268, Val Loss: 0.6011, Val Accuracy: 0.7000, Val AUROC: 0.7004
Epoch 8/50, Train Loss: 0.6227, Val Loss: 0.6034, Val Accuracy: 0.7006, Val AUROC: 0.7009
Epoch 9/50, Train Loss: 0.6206, Val Loss: 0.5982, Val Accuracy: 0.6975, Val AUROC: 0.6984
Epoch 10/50, Train Loss: 0.6215, Val Loss: 0.6071, Val Accuracy: 0.6858, Val AUROC: 0.6879
Epoch 11/50, Train Loss: 0.6201, Val Loss: 0.6013, Val Accuracy: 0.6951, Val AUROC: 0.6939
Epoch 1

In [17]:
# Testing
test_loss, test_accuracy, test_auroc = evaluate(test_loader)
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {100*test_accuracy:.4f}%, Test AUROC: {test_auroc:.4f}')

Test Loss: 0.5759, Test Accuracy: 72.4444%, Test AUROC: 0.7235
