In [16]:
# Some model scores change each run by 1-5%. The paper reports the best of 5 runs.

In [17]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.neighbors import kneighbors_graph
from sklearn.manifold import TSNE

import torch
import torch.nn.functional as F
from torch.nn import Linear, Dropout
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import SAGEConv

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [19]:
data1 = pd.read_csv(r"C:\Users\arnav\Downloads\WholeBloodTranscriptomicsALS+MIM+CTRL.csv")
data1 = data1.T
data1.columns = data1.iloc[0]
data1 = data1[1:]
data1['Diagnosis'] = data1['Diagnosis'].replace({
    'diagnosis: CON': 0,
    'diagnosis: ALS': 1,
    'diagnosis: MIM': 2,
})

data1_X = data1.drop('Diagnosis', axis=1)
data1_y = data1['Diagnosis']

data1_X = data1_X.iloc[:, :-1]
data1_y = pd.to_numeric(data1_y, errors='coerce').astype('int')

train_xs, test_xs, train_y, test_y = train_test_split(data1_X, data1_y, test_size=0.2, random_state=42, stratify=data1_y)

train_xs = train_xs.to_numpy() if hasattr(train_xs, 'to_numpy') else np.array(train_xs)
train_y = train_y.to_numpy() if hasattr(train_y, 'to_numpy') else np.array(train_y)
test_xs = test_xs.to_numpy() if hasattr(test_xs, 'to_numpy') else np.array(test_xs)
test_y = test_y.to_numpy() if hasattr(test_y, 'to_numpy') else np.array(test_y)

if train_xs.shape[1] > 300:  
    pca = PCA(n_components=300)
    reduced_train_xs = pca.fit_transform(train_xs)
    reduced_test_xs = pca.transform(test_xs)
else:
    reduced_train_xs = train_xs
    reduced_test_xs = test_xs

scaler = StandardScaler()
normalized_train_xs = scaler.fit_transform(reduced_train_xs)
normalized_test_xs = scaler.transform(reduced_test_xs)

combined_xs = np.concatenate([normalized_train_xs, normalized_test_xs], axis=0)
combined_y = np.concatenate([train_y, test_y], axis=0)

num_train = train_xs.shape[0]
train_mask = torch.zeros(combined_xs.shape[0], dtype=torch.bool)
test_mask = torch.zeros(combined_xs.shape[0], dtype=torch.bool)
train_mask[:num_train] = True
test_mask[num_train:] = True

x = torch.tensor(combined_xs, dtype=torch.float)
y = torch.tensor(combined_y, dtype=torch.long)

  data1 = pd.read_csv(r"C:\Users\arnav\Downloads\WholeBloodTranscriptomicsALS+MIM+CTRL.csv")
  data1['Diagnosis'] = data1['Diagnosis'].replace({


In [20]:
k = 5 
knn_adj_matrix = kneighbors_graph(combined_xs, n_neighbors=k, mode='connectivity').toarray()
edge_index = np.array(np.nonzero(knn_adj_matrix))
edge_index = torch.tensor(edge_index, dtype=torch.long)
edge_index, _ = add_self_loops(edge_index)

data = Data(x=x, y=y, edge_index=edge_index, train_mask=train_mask, test_mask=test_mask).to(device)

class GraphSAGEModel(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.5):
        super(GraphSAGEModel, self).__init__()
        self.fc1 = Linear(input_dim, hidden_dim)
        self.sage1 = SAGEConv(hidden_dim, hidden_dim)
        self.sage2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = Dropout(dropout_rate)
        self.fc2 = Linear(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.fc1(x))
        x = self.dropout(F.relu(self.sage1(x, edge_index)))
        x = self.dropout(F.relu(self.sage2(x, edge_index)))
        x = self.fc2(x)
        return x

class_weights = torch.tensor([1.0, 1.0, 1.0], dtype=torch.float).to(device)
criterion = lambda out, y: F.cross_entropy(out, y, weight=class_weights)

input_dim = 300
hidden_dim = 256
output_dim = len(torch.unique(data.y))
dropout_rate = 0.5

model = GraphSAGEModel(input_dim, hidden_dim, output_dim, dropout_rate).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

In [21]:
for epoch in range(50):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 1 == 0:
        model.eval()
        _, pred = out.max(dim=1)
        acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Test Accuracy: {acc:.4f}")

Epoch 1, Loss: 1.1271, Test Accuracy: 0.2500
Epoch 2, Loss: 1.1737, Test Accuracy: 0.4342
Epoch 3, Loss: 1.0107, Test Accuracy: 0.3816
Epoch 4, Loss: 0.9848, Test Accuracy: 0.5132
Epoch 5, Loss: 0.8827, Test Accuracy: 0.6316
Epoch 6, Loss: 0.6916, Test Accuracy: 0.6447
Epoch 7, Loss: 0.4768, Test Accuracy: 0.7368
Epoch 8, Loss: 0.3311, Test Accuracy: 0.7105
Epoch 9, Loss: 0.2434, Test Accuracy: 0.7105
Epoch 10, Loss: 0.1670, Test Accuracy: 0.7368
Epoch 11, Loss: 0.1225, Test Accuracy: 0.8026
Epoch 12, Loss: 0.0567, Test Accuracy: 0.7632
Epoch 13, Loss: 0.0326, Test Accuracy: 0.7500
Epoch 14, Loss: 0.0106, Test Accuracy: 0.7763
Epoch 15, Loss: 0.0086, Test Accuracy: 0.7105
Epoch 16, Loss: 0.0022, Test Accuracy: 0.7237
Epoch 17, Loss: 0.0355, Test Accuracy: 0.6842
Epoch 18, Loss: 0.0002, Test Accuracy: 0.7237
Epoch 19, Loss: 0.0011, Test Accuracy: 0.6974
Epoch 20, Loss: 0.0002, Test Accuracy: 0.7500
Epoch 21, Loss: 0.0012, Test Accuracy: 0.7237
Epoch 22, Loss: 0.0004, Test Accuracy: 0.73

In [22]:
# Used when experimenting with different PCA reduction amounts
#torch.save(model.state_dict(), r"C:\Users\arnav\Downloads\graphsage_model_weights_300.pth")
#print("New model weights saved successfully.")

In [23]:
model.eval()
_, pred = model(data).max(dim=1)

true_labels = data.y[data.test_mask].cpu().numpy()
predictions = pred[data.test_mask].cpu().numpy()

print("Confusion Matrix:")
print(confusion_matrix(true_labels, predictions))
print("\nClassification Report:")
print(classification_report(true_labels, predictions))

Confusion Matrix:
[[21  7  0]
 [ 0 33  0]
 [ 1  8  6]]

Classification Report:
              precision    recall  f1-score   support

           0       0.95      0.75      0.84        28
           1       0.69      1.00      0.81        33
           2       1.00      0.40      0.57        15

    accuracy                           0.79        76
   macro avg       0.88      0.72      0.74        76
weighted avg       0.85      0.79      0.78        76



In [24]:
embeddings = model(data).detach().cpu().numpy()
tsne = TSNE(n_components=2, random_state=42)
embeddings_2d = tsne.fit_transform(embeddings)

with torch.no_grad():  
    logits = model(data)
    probabilities = F.softmax(logits, dim=1) 

test_probs = probabilities[data.test_mask]
test_labels = data.y[data.test_mask]

In [25]:
test_probs_np = test_probs.cpu().numpy()
test_labels_np = test_labels.cpu().numpy()
true_labels = np.array([true_labels])
if hasattr(test_probs, 'cpu'):
    test_probs_np = test_probs.cpu().numpy()
else:
    test_probs_np = test_probs
true_labels_np = test_y

mask = (true_labels_np == 0) | (true_labels_np == 1)
filtered_true = true_labels_np[mask]
filtered_pred = test_probs_np[mask] 
binary_true = (filtered_true == 1).astype(int)

binary_pred = filtered_pred[:, 1] 

In [26]:
class GraphSAGEModel(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.5):
        super(GraphSAGEModel, self).__init__()
        self.fc1 = Linear(input_dim, hidden_dim)
        self.sage1 = SAGEConv(hidden_dim, hidden_dim)
        self.sage2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = Dropout(dropout_rate)
        self.fc2 = Linear(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.fc1(x))
        x = self.dropout(F.relu(self.sage1(x, edge_index)))
        x = self.dropout(F.relu(self.sage2(x, edge_index)))
        x = self.fc2(x)
        return x
input_dim = data.x.size(1)
hidden_dim = 256
output_dim = len(torch.unique(data.y))
dropout_rate = 0.5

model = GraphSAGEModel(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim)

In [27]:
model.load_state_dict(torch.load(r"C:\Users\arnav\Downloads\graphsage_model_weights87.pth"))
print("Model weights loaded successfully.")

Model weights loaded successfully.


  model.load_state_dict(torch.load(r"C:\Users\arnav\Downloads\graphsage_model_weights87.pth"))


In [28]:
model.eval()
_, pred = model(data).max(dim=1)

true_labels = data.y[data.test_mask].cpu().numpy()
predictions = pred[data.test_mask].cpu().numpy()

print("Confusion Matrix:")
print(confusion_matrix(true_labels, predictions))
print("\nClassification Report:")
print(classification_report(true_labels, predictions))

auc_roc = roc_auc_score(binary_true, binary_pred)
print(f"AUC-ROC score between class 1 and 0: {auc_roc:.4f}")

Confusion Matrix:
[[28  0  0]
 [ 1 28  4]
 [ 0  5 10]]

Classification Report:
              precision    recall  f1-score   support

           0       0.97      1.00      0.98        28
           1       0.85      0.85      0.85        33
           2       0.71      0.67      0.69        15

    accuracy                           0.87        76
   macro avg       0.84      0.84      0.84        76
weighted avg       0.87      0.87      0.87        76

AUC-ROC score between class 1 and 0: 0.9253


In [29]:
# CAPTUM INTERPRETATION CODE

from captum.attr import IntegratedGradients

def forward_func(x_input):
    temp_data = Data(
        x=x_input,
        edge_index=data.edge_index,
        y=data.y,
        train_mask=data.train_mask,
        test_mask=data.test_mask
    )
    logits = model(temp_data)
    return logits 

baseline = data.x.mean(dim=0, keepdim=True)
ig = IntegratedGradients(forward_func)

num_runs = 5
attributions_list = []
for _ in range(num_runs):
    attributions = ig.attribute(
        inputs=data.x,
        baselines=baseline + 0.01 * torch.randn_like(baseline),
        target=1 
    )
    attributions_list.append(attributions)
mean_attributions = torch.stack(attributions_list).mean(dim=0)

feature_importances = mean_attributions.abs().mean(dim=0)

top_k = 300
sorted_indices = feature_importances.argsort(descending=True)
top_indices = sorted_indices[:top_k]
top_values = feature_importances[top_indices]

feature_names = data1_X.columns
top_features_with_names = [(idx.item(), feature_names[idx.item()], val.item())
                           for idx, val in zip(top_indices, top_values)]

def evaluate_als_classification(data, top_indices, num_remove=10):
    masked_data = data.clone()
    masked_data.x[:, top_indices[:num_remove]] = 0 

    model.eval()
    with torch.no_grad():
        logits = model(masked_data)
        _, preds = logits.max(dim=1)

    true_labels = data.y[data.test_mask].cpu().numpy()
    predicted_labels = preds[data.test_mask].cpu().numpy()
    binary_true = (true_labels == 1).astype(int)
    binary_pred = (predicted_labels == 1).astype(int)

    accuracy = (binary_pred == binary_true).sum() / len(binary_true)
    return accuracy

original_accuracy = evaluate_als_classification(data, top_indices, num_remove=0)

reduced_accuracy = evaluate_als_classification(data, top_indices, num_remove=10)

random_indices = torch.randperm(data.x.shape[1])[:10]
random_accuracy = evaluate_als_classification(data, random_indices, num_remove=10)

print("\nEvaluation of ALS Classification After Feature Removal:")
print(f"Original ALS Test Accuracy: {original_accuracy:.4f}")
print(f"Reduced Test Accuracy (Top 10 Features Removed): {reduced_accuracy:.4f}")
print(f"Random Test Accuracy (10 Random Features Removed): {random_accuracy:.4f}")

print("\nTop 300 ALS Biomarker Features with Gene Names:\n")
for rank, (idx, name, score) in enumerate(top_features_with_names, start=1):
    print(f"Rank {rank:3d}: Feature {idx:<5} Gene: {name:<20} Importance Score: {score:.6f}")


Evaluation of ALS Classification After Feature Removal:
Original ALS Test Accuracy: 0.8684
Reduced Test Accuracy (Top 10 Features Removed): 0.6711
Random Test Accuracy (10 Random Features Removed): 0.8421

Top 300 ALS Biomarker Features with Gene Names:

Rank   1: Feature 6     Gene: ILMN_1651237         Importance Score: 2.721718
Rank   2: Feature 10    Gene: ILMN_1651259         Importance Score: 2.703543
Rank   3: Feature 14    Gene: ILMN_1651278         Importance Score: 2.473942
Rank   4: Feature 3     Gene: ILMN_1651230         Importance Score: 2.414838
Rank   5: Feature 11    Gene: ILMN_1651260         Importance Score: 2.377332
Rank   6: Feature 8     Gene: ILMN_1651253         Importance Score: 2.004411
Rank   7: Feature 4     Gene: ILMN_1651232         Importance Score: 1.879834
Rank   8: Feature 22    Gene: ILMN_1651328         Importance Score: 1.827627
Rank   9: Feature 26    Gene: ILMN_1651339         Importance Score: 1.817911
Rank  10: Feature 32    Gene: ILMN_1651358