In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, auc, accuracy_score, confusion_matrix
import shap
import matplotlib.pyplot as plt
import seaborn as sns

### Prepare the dataset for training

In [None]:
# Read the raw table; the path is relative, so the CSV must sit in the
# notebook's working directory.
DATASET_PATH = 'vascular_dementia_dataset.csv'
df = pd.read_csv(DATASET_PATH)

# Last expression of the cell: per-column summary statistics shown as output.
df.describe()

In [None]:
X = df.drop('Has Vascular Dementia', axis=1).values  # Features
y = df['Has Vascular Dementia'].values  # Target

# Split BEFORE scaling. Fitting the scaler on the full dataset would let the
# test rows influence the mean/std applied to the training rows (data leakage)
# and make the held-out metrics optimistic.
# random_state pins the shuffle so Restart & Run All reproduces the same split;
# stratify keeps the class ratio of the target the same in both partitions.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# Learn the scaling statistics from the training rows only, then apply the
# same transform to the held-out rows.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

# Whole feature matrix scaled with the train-only statistics; kept so the
# name X_scaled stays available to any cell that refers to it.
X_scaled = scaler.transform(X)

### Define the neural network

In [None]:
class NeuralNetwork(nn.Module):
    """Feed-forward binary classifier with two batch-normalised hidden layers.

    Architecture: input_size -> 64 -> 32 -> 1, with BatchNorm1d and ReLU after
    each hidden linear layer and a sigmoid on the single output unit, so the
    forward pass returns values in (0, 1).

    Args:
        input_size: number of features per input row.
    """

    def __init__(self, input_size):
        super().__init__()
        # First hidden stage: input_size -> 64
        self.layer1 = nn.Linear(input_size, 64)
        self.bn1 = nn.BatchNorm1d(64)
        # Second hidden stage: 64 -> 32
        self.layer2 = nn.Linear(64, 32)
        self.bn2 = nn.BatchNorm1d(32)
        # Output stage: 32 -> 1, squashed by a sigmoid
        self.layer3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid output, one column per input row of ``x``."""
        hidden = self.layer1(x)
        hidden = self.bn1(hidden)
        hidden = torch.relu(hidden)

        hidden = self.layer2(hidden)
        hidden = self.bn2(hidden)
        hidden = torch.relu(hidden)

        pre_activation = self.layer3(hidden)
        return self.sigmoid(pre_activation)

In [None]:
# Seed PyTorch's RNG so the random weight initialisation below is the same on
# every Restart & Run All; without it each run trains a different model.
torch.manual_seed(42)

# Initialize the model; the input width is the number of feature columns
input_size = X_train.shape[1]
model = NeuralNetwork(input_size)

# Define loss function and optimizer.
# BCELoss expects probabilities, which matches the sigmoid output of the model.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)

# Convert the data to tensors.
# Targets are reshaped to a single column so they line up with the model's
# (n, 1) output when passed to the loss.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)

X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)


In [None]:
# Training loop: each epoch is one optimisation step on the whole training set
epochs = 500
# Report ten times over the run. max(1, ...) keeps the modulo below valid when
# epochs < 10 (epochs // 10 would be 0 and raise ZeroDivisionError).
log_interval = max(1, epochs // 10)

for epoch in range(epochs):
    model.train()

    optimizer.zero_grad()

    output_train = model(X_train_tensor)
    loss = criterion(output_train, y_train_tensor)

    loss.backward()
    optimizer.step()

    # Accuracy figures are only printed on reporting epochs, so skip the
    # metric computation and the extra test-set forward pass on all others.
    if (epoch + 1) % log_interval != 0:
        continue

    # Training accuracy, from the forward pass made before this epoch's step
    predicted_train = (output_train > 0.5).float()
    train_accuracy = accuracy_score(y_train, predicted_train.numpy())

    # Test-set accuracy, in eval mode and without tracking gradients
    model.eval()
    with torch.no_grad():
        output_test = model(X_test_tensor)
        predicted_test = (output_test > 0.5).float()
        test_accuracy = accuracy_score(y_test, predicted_test.numpy())

    print(f'Epoch [{epoch+1}/{epochs}], Training Loss: {loss.item():.4f}, '
          f'Training Accuracy: {train_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}')

# Leave the model in eval mode even when the final epoch was not a reporting
# epoch, so cells that run inference afterwards see the same mode as before.
model.eval()


### Evaluate the model performance

#### Model ROC-AUC curve

In [None]:
# Forward pass over the held-out tensor to get the model outputs used by the
# ROC and confusion-matrix cells. eval() switches the module to inference
# mode; no_grad() skips gradient tracking, which also allows the later
# .numpy() conversion of y_pred.
model.eval()
with torch.no_grad():
    y_pred = model(X_test_tensor)

In [None]:
# Flatten the (n, 1) model output into a 1-D score vector for scikit-learn
y_pred_prob = y_pred.numpy().ravel()
fpr, tpr, _ = roc_curve(y_test, y_pred_prob)
roc_auc = auc(fpr, tpr)

# Plot the ROC curve against the diagonal reference line
fig, ax = plt.subplots()
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim(0.0, 1.0)
ax.set_ylim(0.0, 1.05)  # small headroom so the curve is not clipped at TPR = 1
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(loc='lower right')
plt.show()

#### Confusion matrix

In [None]:
# Turn scores into hard 0/1 labels with a 0.5 decision threshold
y_pred_labels = (y_pred_prob > 0.5).astype(int)
cm = confusion_matrix(y_test, y_pred_labels)

# The same class names label both axes of the heatmap
class_names = ["Negative", "Positive"]

fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel("Predicted Label")
ax.set_ylabel("True Label")
ax.set_title("Confusion Matrix")
plt.show()

### Feature importance extraction using SHAP (Shapley values)

In [None]:
# The model contains BatchNorm layers, so make sure it is in eval mode before
# explaining it: in train mode the explainer's forward passes would normalise
# with batch statistics and update the running estimates. This matters when
# the cell is run right after training or out of order.
model.eval()

# X_test_tensor is reused from the tensor-conversion cell rather than rebuilt.
# NOTE(review): the test set serves as both background data and the rows being
# explained; a sample of the training set is the more usual background. Confirm
# this is intended.
explainer = shap.GradientExplainer(model, X_test_tensor)

# Compute SHAP values
shap_values_raw = explainer.shap_values(X_test_tensor)

feature_names = df.drop(columns=["Has Vascular Dementia"]).columns.to_list()

# plot summary of feature importance
# Remove size-1 dimensions so the values are 2-D for the summary plot.
# Assumes the explainer returns an array with one singleton output axis; the
# return type depends on the shap version. TODO confirm
shap_values = shap_values_raw.squeeze()
shap.summary_plot(shap_values, X_test, feature_names=feature_names)