### GAT: mild vs. severe classification
### Load data (all CSV files)

In [None]:
import numpy as np
import pandas as pd
from sklearn.svm import SVC
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import roc_curve, auc, accuracy_score
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GATConv
from torch_geometric.data import Data

import seaborn as sns

from torch_geometric.explain import Explainer, GNNExplainer

In [None]:
# Get features to retain: load the train/test splits. Every CSV carries a
# leftover index column ('Unnamed: 0') that is dropped before the values are
# converted to NumPy arrays.
def _load_csv_without_index(csv_path):
    """Read *csv_path*, drop its 'Unnamed: 0' column, and return an ndarray."""
    frame = pd.read_csv(csv_path)
    return np.array(frame.drop(columns=['Unnamed: 0']))


X_train = _load_csv_without_index('mild_vs_severe_training_data.csv')
y_train = _load_csv_without_index('mild_vs_severe_training_label.csv')
X_test = _load_csv_without_index('mild_vs_severe_test_data.csv')
y_test = _load_csv_without_index('mild_vs_severe_test_label.csv')


In [None]:
# Sanity check: feature and label array shapes for each split (train first).
for features, targets in ((X_train, y_train), (X_test, y_test)):
    print(features.shape, targets.shape)

In [None]:
# Build the training graph. Each row of X_train (one sample) becomes a node,
# and two nodes are joined when the Pearson correlation between their feature
# vectors has absolute value >= threshold.
data_matrix = X_train

# Correlation cut-off: only sample pairs with |r| >= 0.8 count as highly
# correlated and get an edge.
threshold = 0.8

if isinstance(data_matrix, pd.DataFrame):
    data_matrix = data_matrix.to_numpy()

# np.corrcoef treats each row as one variable, so this is the
# (n_samples x n_samples) sample-to-sample correlation matrix.
correlation_matrix = np.corrcoef(data_matrix)
is_connected = np.abs(correlation_matrix) >= threshold
adj_matrix = torch.tensor(is_connected.astype(int), dtype=torch.long)

# Node labels: 1 where the label row equals ['Severe'], 0 otherwise.
labels = torch.tensor([bool(row == ['Severe']) for row in y_train], dtype=torch.long)
feature_matrix = torch.tensor(data_matrix, dtype=torch.float32)

# COO edge list of shape (2, num_edges), in row-major order of the nonzero
# adjacency entries. The diagonal (r = 1) yields self-loops.
edge_index = adj_matrix.nonzero().t().contiguous()

# Graph object consumed by the GAT model.
data_train = Data(x=feature_matrix, edge_index=edge_index, y=labels)


In [None]:
# Inspect the training graph: dense adjacency shape, then the Data summary.
print(adj_matrix.shape, data_train, sep="\n")

In [None]:
# Build the test graph the same way as the training graph: one node per row
# of X_test, with an edge wherever |Pearson r| between two rows' feature
# vectors reaches the threshold.
threshold = 0.8
data_matrix = X_test
if isinstance(data_matrix, pd.DataFrame):
    data_matrix = data_matrix.to_numpy()

# Rows are the variables here, so this is sample-to-sample correlation.
correlation_matrix = np.corrcoef(data_matrix)
adj_matrix = torch.tensor(
    (np.abs(correlation_matrix) >= threshold).astype(int),
    dtype=torch.long,
)

# 1 for rows labelled ['Severe'], 0 for everything else.
labels = torch.tensor([bool(row == ['Severe']) for row in y_test], dtype=torch.long)
feature_matrix = torch.tensor(data_matrix, dtype=torch.float32)

# (2, num_edges) COO edge list taken from the nonzero adjacency entries.
edge_index = adj_matrix.nonzero().t().contiguous()

data_test = Data(x=feature_matrix, edge_index=edge_index, y=labels)


In [None]:
# Inspect the test graph: dense adjacency shape, then the Data summary.
# (This previously printed data_train, so the test graph was never shown.)
print(adj_matrix.shape)
print(data_test)

In [None]:
class GATClassifier(nn.Module):
    """Two-layer graph attention network for node classification.

    Layer 1 runs ``num_heads`` attention heads and concatenates them, giving
    ``hidden_dim * num_heads`` features per node. Layer 2 maps those features
    to class scores and *averages* its heads (``concat=False``), so the output
    has exactly ``num_classes`` columns.

    With ``concat=True`` on the output layer, the result had
    ``num_classes * num_heads`` columns. Softmax, argmax and the loss then ran
    over that many "classes", and predictions could fall outside the valid
    label range.

    ``forward`` returns per-node class probabilities (softmax over dim 1).
    """

    def __init__(self, input_dim, hidden_dim, num_classes, num_heads):
        super().__init__()
        self.gat1 = GATConv(input_dim, hidden_dim, heads=num_heads, concat=True, dropout=0.6)
        # Average the heads so the output width equals num_classes.
        self.gat2 = GATConv(hidden_dim * num_heads, num_classes, heads=num_heads, concat=False, dropout=0.6)

    def forward(self, x, edge_index):
        """Return a (num_nodes, num_classes) tensor of class probabilities."""
        x = torch.relu(self.gat1(x, edge_index))
        x = self.gat2(x, edge_index)
        return torch.softmax(x, dim=1)

def evaluate(model, data):
    """Return the node-classification accuracy of *model* on *data*.

    The model is put in eval mode and run without gradient tracking. The
    predicted class of each node is the arg-max over dim 1 of the model output.
    The return value is the fraction of nodes whose prediction equals ``data.y``.
    """
    model.eval()
    with torch.no_grad():
        predicted = model(data.x, data.edge_index).argmax(dim=1)
    n_correct = (predicted == data.y).sum().item()
    return float(n_correct) / len(data.y)

# Hyperparameters. Sizes come from the *training* graph. The module-level
# `feature_matrix` / `labels` were last set by the test-graph cell, so reading
# them here would size the model from the test split.
input_dim = data_train.x.shape[1]
hidden_dim = 64
num_classes = 2  # labels are built as 1 = 'Severe', 0 = any other label
num_heads = 4
num_epochs = 200

# Fix the RNG so weight initialisation and attention dropout are repeatable.
torch.manual_seed(42)

model = GATClassifier(input_dim, hidden_dim, num_classes, num_heads)
optimizer = optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)


def criterion(probs, target):
    """Cross-entropy loss for a model whose forward() already returns probabilities.

    nn.CrossEntropyLoss expects raw logits and applies log-softmax itself.
    Feeding it GATClassifier's softmax output would apply softmax twice and
    flatten the gradients. Taking the log of the probabilities and using NLL
    gives the intended loss; the clamp guards against log(0).

    Args:
        probs: (num_nodes, num_classes) class probabilities.
        target: (num_nodes,) integer class labels.

    Returns:
        Scalar loss tensor (mean over nodes).
    """
    return nn.functional.nll_loss(torch.log(probs.clamp_min(1e-12)), target)


for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    out = model(data_train.x, data_train.edge_index)
    loss = criterion(out, data_train.y)
    loss.backward()
    optimizer.step()

    # Accuracy on both splits after this update. evaluate() switches the model
    # to eval mode; model.train() at the top of the next epoch switches it back.
    train_acc = evaluate(model, data_train)
    test_acc = evaluate(model, data_test)

    print(f"Epoch: {epoch + 1}, Loss: {loss.item():.4f}, "
          f"Train Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}")


In [None]:
# F1 score (weighted average over classes) of the trained model on the test graph.
model.eval()
with torch.no_grad():
    out = model(data_test.x, data_test.edge_index)

# NOTE: y_test is rebound here from the raw label array to a list of 0/1 ints.
# The confusion-matrix and accuracy cells below rely on this list form.
y_test = data_test.y.tolist()
pred = out.argmax(dim=1).tolist()

f1 = f1_score(y_true=y_test, y_pred=pred, average='weighted')
print("f1:", f1)


In [None]:
# Confusion matrix on the test set.
# Label encoding (from the graph-building cells): 1 = 'Severe', 0 = any other
# label (presumably 'Mild' for this dataset). sklearn orders rows/columns by
# label value, so index 0 is the non-severe class and index 1 is 'Severe'.
# The old captions called index 0 "Positive", which swapped the meaning of
# every cell when 'Severe' is the positive class. labels=[0, 1] keeps the
# matrix 2x2 even if one class is missing from y_test / pred.
cm = confusion_matrix(y_test, pred, labels=[0, 1])
cm_df = pd.DataFrame(
    cm,
    columns=['Predicted Mild', 'Predicted Severe'],
    index=['Actual Mild', 'Actual Severe'],
)

print(cm_df)


In [None]:
# AUC-ROC on the test set, scored with the predicted probability of class 1
# ('Severe'). Passing hard arg-max labels, as before, gives a single operating
# point: the "curve" has three points and its area is just balanced accuracy.
model.eval()
with torch.no_grad():
    test_probs = model(data_test.x, data_test.edge_index)

y_true = data_test.y.numpy()
severe_scores = test_probs[:, 1].numpy()
test_pred = test_probs.argmax(dim=1).numpy()

fpr, tpr, _ = roc_curve(y_true, severe_scores)
roc_auc = auc(fpr, tpr)

plt.figure()
plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show()

accuracy = accuracy_score(y_true, test_pred)
print("Accuracy:", accuracy)

In [None]:
def evaluate_f1(model, data, mask=None):
    """Return the macro-averaged F1 score of *model* on *data*.

    Args:
        model: module called as ``model(data.x, data.edge_index)``; the
            highest-scoring class in dim 1 is the prediction.
        data: graph with ``x``, ``edge_index`` and integer labels ``y``.
        mask: optional boolean node mask. When given, only those nodes are
            scored; the whole graph is still used for message passing.

    Returns:
        Macro F1 over the scored nodes.
    """
    model.eval()
    with torch.no_grad():
        pred = model(data.x, data.edge_index).argmax(dim=1)
    y_true = data.y
    if mask is not None:
        pred, y_true = pred[mask], y_true[mask]
    return f1_score(y_true.numpy(), pred.numpy(), average='macro')


learning_rates = [0.003, 0.004, 0.005, 0.006, 0.007]
weight_decays = [3e-4, 4e-4, 5e-4, 6e-4, 7e-4]

# Hyperparameters are compared on a validation subset held out from the
# *training* graph. Scoring on data_test, as before, lets the test set steer
# model selection and biases the reported test metrics upward. Validation
# nodes still take part in message passing, but their labels are excluded
# from the loss.
split_generator = torch.Generator().manual_seed(0)
n_train_nodes = data_train.x.shape[0]
shuffled_nodes = torch.randperm(n_train_nodes, generator=split_generator)
n_val = max(1, int(0.2 * n_train_nodes))
val_mask = torch.zeros(n_train_nodes, dtype=torch.bool)
val_mask[shuffled_nodes[:n_val]] = True
fit_mask = ~val_mask

results = np.zeros((len(learning_rates), len(weight_decays)))

for i, lr in enumerate(learning_rates):
    for j, wd in enumerate(weight_decays):
        # Same seed for every configuration, so cells differ by (lr, wd)
        # rather than by initialisation/dropout noise.
        torch.manual_seed(0)
        # Local names keep the globally trained `model`/`optimizer` intact.
        # Previously the last grid configuration overwrote `model`, and the
        # explainer cell then explained that model instead.
        candidate = GATClassifier(input_dim, hidden_dim, num_classes, num_heads)
        candidate_opt = optim.Adam(candidate.parameters(), lr=lr, weight_decay=wd)

        for _ in range(num_epochs):
            candidate.train()
            candidate_opt.zero_grad()
            out = candidate(data_train.x, data_train.edge_index)
            loss = criterion(out[fit_mask], data_train.y[fit_mask])
            loss.backward()
            candidate_opt.step()

        results[i, j] = evaluate_f1(candidate, data_train, mask=val_mask)


x_tick_labels = [f"{wd:.4f}" for wd in weight_decays]
y_tick_labels = [f"{lr:.4f}" for lr in learning_rates]

fig, ax = plt.subplots()
# sns.heatmap already places one centred tick per cell with these labels,
# so no extra set_xticks/set_xticklabels calls are needed.
sns.heatmap(results, annot=True, fmt=".2f", xticklabels=x_tick_labels, yticklabels=y_tick_labels, ax=ax)

ax.set_xlabel("Weight Decay")
ax.set_ylabel("Learning Rate")
ax.set_title("Validation F1 Heatmap")

plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')

plt.show()

In [None]:
# GNNExplainer: which input features drive the model's prediction for one
# test node. Explainer and GNNExplainer are already imported at the top of the
# notebook. return_type='probs' matches GATClassifier.forward, which ends in
# a softmax.
# NOTE(review): this explains whatever `model` is bound to when the cell runs.
# Confirm that it is the model you intend to interpret.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

node_index = 0
explanation = explainer(data_test.x, data_test.edge_index, index=node_index)
print(f'Generated explanations in {explanation.available_explanations}')

# Label the bars with feature names instead of bare column indices. X_train
# was built by dropping 'Unnamed: 0' from this CSV, so the remaining header
# lines up one-to-one with the node-feature columns.
feature_names = list(
    pd.read_csv('mild_vs_severe_training_data.csv', nrows=0).columns.drop('Unnamed: 0')
)

path = 'feature_importance_mildVSsevere.png'
explanation.visualize_feature_importance(path, feat_labels=feature_names, top_k=15)
print(f"Feature importance plot has been saved to '{path}'")

# Optional: explanation.visualize_graph('subgraph_v1.pdf') saves a
# visualization of the explanatory subgraph.

In [None]:
# Map feature indices (e.g. those ranked highest in the feature-importance
# plot) back to column names. X_train was converted to a NumPy array when the
# data was loaded, so it has no `.columns`; read the header from the CSV
# instead. Dropping 'Unnamed: 0' reproduces X_train's column order, which
# replaces the manual `i + 1` offset.
feature_names = pd.read_csv('mild_vs_severe_training_data.csv', nrows=0).columns.drop('Unnamed: 0')
for i in [1012, 793, 728, 1140, 216, 1664, 1073, 212, 1222, 658]:
    print(feature_names[i])