In [1]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.5.1-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.5.1


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#### Loading chunks of the dataset
##### Note: To preserve RAM, several unused variables are deleted throughout the norebook

In [3]:
import torch
chunk1 = torch.load('/content/drive/MyDrive/SCI_data/first.pt')

In [4]:
chunk2 = torch.load('/content/drive/MyDrive/SCI_data/second (1).pt')

In [9]:
from torch.utils.data import Dataset

class GraphDataset(Dataset):
  def __init__(self, data, preprocess):
    self.data = data
    self.preprocess = preprocess

  def __len__(self,):
    return len(self.data)

  def __getitem__(self, index):
    res = self.data[index]
    for p in self.preprocess:
      res = p(res)
    return res

chunk1+=chunk2
del chunk2

In [10]:
from sklearn.model_selection import train_test_split
rand_seed = 42
X_train, X_test = train_test_split(chunk1, test_size=0.1, random_state = rand_seed)
X_train, X_val = train_test_split(X_train, test_size=0.1, random_state = rand_seed)
print(len(X_train), len(X_val), len(X_val))

14580 1620 1620


In [12]:
device = 'cuda'

In [13]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(X_train, batch_size=32, shuffle=True)
val_loader = DataLoader(X_val, batch_size=32, shuffle=False)
test_loader = DataLoader(X_test, batch_size=32, shuffle=False)

In [14]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.3.1-py3-none-any.whl (840 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.4/840.4 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.1-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.10.1 torchmetrics-1.3.1


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, SAGEConv, SAGPooling
from torch_geometric.nn.conv import GraphConv
from torch_geometric.utils import to_undirected
from torch_geometric.data import DataLoader
from torchmetrics.classification import BinaryAUROC
auroc = BinaryAUROC()


class Network(nn.Module):
    def __init__(self, c_in, c_hidden, c_out, p=0.3):
        super().__init__()
        torch.manual_seed(123)
        self.conv1 = SAGEConv(c_in, c_hidden, aggr='mean')
        self.conv2 = SAGEConv(c_hidden,3*c_hidden, aggr='mean')
        self.conv3 = SAGEConv(3*c_hidden, 2*c_hidden, aggr='mean')
        self.conv4 = SAGEConv(2*c_hidden, c_hidden, aggr='mean')

        # self.pool = SAGPooling(c_hidden)

        self.lin1 = nn.Linear(c_hidden, 4*c_out)
        self.lin2 = nn.Linear(4*c_out, c_out)
        self.p = p

    def forward(self, x, edge_index, batch, is_train):
        x = self.conv1(x, edge_index)
        x = x.relu()

        x = self.conv2(x, edge_index)
        x = x.relu()

        x = self.conv3(x, edge_index)
        x = x.relu()

        x = self.conv4(x, edge_index)



        x = global_mean_pool(x, batch)

        # classifier

        x = F.dropout(x, p=self.p, training=is_train)
        x = self.lin1(x)

        x = F.dropout(x, p=self.p, training=is_train)
        x = self.lin2(x)

        return x


def evaluate(loader):
    model.eval()
    total_loss = 0.0
    correct = 0
    total_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in loader:
            batch.to(device)
            # print(batch.edge_index)
            pred = model(batch.x.float(), batch.edge_index, batch.batch, False)
            target = F.one_hot(batch.y, 2).float()
            loss = criterion(pred, target)
            total_loss += loss.item()

            pred_labels = torch.softmax(pred, -1).argmax(dim=-1)
            correct += (pred_labels == batch.y).sum().item()
            total_samples += len(batch.y)
            all_labels.append(batch.y)
            all_preds.append(pred_labels)

    pred = all_preds[0]
    label = all_labels[0]

    for p, l in zip(all_preds[1:], all_labels[1:]):
      pred = torch.cat([pred, p])
      label = torch.cat([label, l])

    return total_loss / len(loader), correct / total_samples, auroc(pred, label)

# Training loop
num_epochs = 50
embedding_dim = 64
model = Network(c_in=5, c_hidden=embedding_dim, c_out=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.BCEWithLogitsLoss()

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for idx, batch in enumerate(train_loader):

        batch = batch.to(device)

        pred = model(batch.x.float(), batch.edge_index, batch.batch, True)
        target = F.one_hot(batch.y, 2).float()
        loss = criterion(pred, target)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_train_loss = epoch_loss / len(train_loader)
    avg_val_loss, val_accuracy, val_auroc = evaluate(val_loader)

    print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Val AUROC: {val_auroc:.4f}')

Epoch 1/50, Train Loss: 0.6614, Val Loss: 0.6176, Val Accuracy: 0.6796, Val AUROC: 0.6809
Epoch 2/50, Train Loss: 0.6218, Val Loss: 0.6058, Val Accuracy: 0.6716, Val AUROC: 0.6749
Epoch 3/50, Train Loss: 0.6071, Val Loss: 0.5873, Val Accuracy: 0.7136, Val AUROC: 0.7143
Epoch 4/50, Train Loss: 0.6035, Val Loss: 0.5835, Val Accuracy: 0.7099, Val AUROC: 0.7081
Epoch 5/50, Train Loss: 0.5970, Val Loss: 0.5874, Val Accuracy: 0.7154, Val AUROC: 0.7134
Epoch 6/50, Train Loss: 0.5982, Val Loss: 0.5797, Val Accuracy: 0.7136, Val AUROC: 0.7123
Epoch 7/50, Train Loss: 0.5988, Val Loss: 0.5819, Val Accuracy: 0.7179, Val AUROC: 0.7165
Epoch 8/50, Train Loss: 0.5972, Val Loss: 0.5818, Val Accuracy: 0.7154, Val AUROC: 0.7150
Epoch 9/50, Train Loss: 0.5964, Val Loss: 0.5774, Val Accuracy: 0.7198, Val AUROC: 0.7193
Epoch 10/50, Train Loss: 0.5908, Val Loss: 0.5764, Val Accuracy: 0.7185, Val AUROC: 0.7179
Epoch 11/50, Train Loss: 0.5922, Val Loss: 0.5772, Val Accuracy: 0.7142, Val AUROC: 0.7129
Epoch 12

In [20]:
# Testing
test_loss, test_accuracy, test_auroc = evaluate(test_loader)
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {100*test_accuracy:.4f}%, Test AUROC: {test_auroc:.4f}')

Test Loss: 0.5703, Test Accuracy: 72.3333%, Test AUROC: 0.7239
