In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve, auc, accuracy_score
import shap
import matplotlib.pyplot as plt


### Prepare the dataset for training

In [None]:
# Read the raw dataset from a CSV file in the notebook's working directory.
DATASET_CSV = 'vascular_dementia_dataset.csv'
df = pd.read_csv(DATASET_CSV)

# Summary statistics of the numeric columns, shown as the cell's output.
df.describe()

In [None]:

# Separate the feature matrix from the binary target column.
X = df.drop('Has Vascular Dementia', axis=1).values  # Features
y = df['Has Vascular Dementia'].values  # Target

# Split BEFORE scaling so that test-set statistics never influence the scaler
# (fitting the scaler on the full data leaks information into training).
# A fixed seed makes the split reproducible; stratifying keeps the class
# balance the same in both partitions.
SPLIT_SEED = 42
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=SPLIT_SEED, stratify=y
)

# Fit the scaler on the training partition only, then apply it to both.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

# Full matrix scaled with the training statistics; kept so that any code
# expecting the `X_scaled` name continues to work.
X_scaled = scaler.transform(X)

### Define the neural network

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected binary classifier: input -> 64 -> 32 -> 1.

    Each hidden linear layer is followed by batch normalisation and ReLU;
    the single output unit is passed through a sigmoid, so the network
    returns one value in [0, 1] per input row.
    """

    def __init__(self, input_size):
        """Build the layers.

        Args:
            input_size: number of input features per sample.
        """
        super().__init__()
        # First hidden stage.
        self.layer1 = nn.Linear(input_size, 64)
        self.bn1 = nn.BatchNorm1d(64)
        # Second hidden stage.
        self.layer2 = nn.Linear(64, 32)
        self.bn2 = nn.BatchNorm1d(32)
        # Output stage: one unit squashed by a sigmoid.
        self.layer3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid output of shape (batch, 1) for a (batch, input_size) tensor."""
        hidden = self.bn1(self.layer1(x)).relu()
        hidden = self.bn2(self.layer2(hidden)).relu()
        raw_score = self.layer3(hidden)
        return self.sigmoid(raw_score)

In [None]:
# Seed PyTorch before building the model so the weight initialisation (and
# therefore the whole training run) is repeatable across kernel restarts.
TORCH_SEED = 42
torch.manual_seed(TORCH_SEED)

# Initialize the model with one input unit per feature column.
input_size = X_train.shape[1]
model = NeuralNetwork(input_size)

# Loss and optimizer. BCELoss expects probabilities, which matches the
# sigmoid applied at the end of NeuralNetwork.forward.
criterion = nn.BCELoss()  # Binary Cross-Entropy loss for binary classification
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)

# Convert the training data to float32 tensors; the target becomes a column
# vector so its shape matches the (n, 1) model output.
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)

# Same conversion for the held-out test data.
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)


In [None]:
# Training loop: full-batch updates on the whole training tensor each epoch.
# Accuracy on the train and test sets is only reported every `log_every`
# epochs, so it is only computed on those epochs instead of on every one.
epochs = 500
log_every = 10
for epoch in range(epochs):
    model.train()

    # Standard optimisation step: reset gradients, forward, backward, update.
    optimizer.zero_grad()
    output_train = model(X_train_tensor)
    loss = criterion(output_train, y_train_tensor)
    loss.backward()
    optimizer.step()

    # Skip the metric computation unless this epoch is going to be logged.
    if (epoch + 1) % log_every != 0:
        continue

    # Training accuracy from the forward pass made before this epoch's update.
    # detach() makes it explicit that no gradient history reaches NumPy, and
    # ravel() hands accuracy_score a 1-D vector matching y_train.
    predicted_train = (output_train.detach() > 0.5).float()
    train_accuracy = accuracy_score(y_train, predicted_train.numpy().ravel())

    # Monitor the held-out set in eval mode (BatchNorm uses running stats).
    model.eval()
    with torch.no_grad():
        output_test = model(X_test_tensor)
        predicted_test = (output_test > 0.5).float()
        test_accuracy = accuracy_score(y_test, predicted_test.numpy().ravel())

    print(f'Epoch [{epoch+1}/{epochs}], Training Loss: {loss.item():.4f}, '
          f'Training Accuracy: {train_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}')

# Leave the model in eval mode regardless of whether the last epoch was logged.
model.eval()


In [None]:
# Step 5: Test the model (after training).
# eval() switches the BatchNorm layers to their running statistics, and
# no_grad() avoids building an autograd graph for this inference-only pass.
model.eval()
with torch.no_grad():
    # y_pred holds the model's sigmoid outputs for X_test_tensor.
    y_pred = model(X_test_tensor)

In [None]:
# Collapse the model output into a 1-D NumPy vector of scores.
y_pred_prob = y_pred.numpy().reshape(-1)

# ROC points across all score thresholds, then the area under that curve.
fpr, tpr, _ = roc_curve(y_test, y_pred_prob)
roc_auc = auc(fpr, tpr)

In [None]:
# Draw the ROC curve together with the diagonal chance line.
fig, ax = plt.subplots()
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')

# Axis ranges, labels and title so the figure stands on its own.
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(loc='lower right')
plt.show()

### Feature importance extraction using SHAP (Shapley values)

In [None]:
# Rebuild the float32 test tensor so this cell only needs X_test and model.
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)

# Explain the model in eval mode so the BatchNorm layers use their running
# statistics rather than statistics of whatever batch SHAP feeds through.
model.eval()

# Create SHAP explainer (GradientExplainer for deep learning).
# The test set is used both as the background data and as the rows explained.
explainer = shap.GradientExplainer(model, X_test_tensor)

# GradientExplainer estimates attributions by random sampling; fixing the
# seed makes the SHAP values repeatable between runs.
SHAP_SEED = 42
shap_values = explainer.shap_values(X_test_tensor, rseed=SHAP_SEED)

In [None]:
# Sanity-check the SHAP output against the feature matrix.
# The network has a single sigmoid output, so there is no per-class pair of
# arrays to index; indexing [0] / [1] on an array would pick samples instead.
# NOTE(review): the exact container (array vs. list) depends on the installed
# SHAP version, so both are handled here.
print(type(shap_values))
if isinstance(shap_values, list):
    print(len(shap_values))  # number of model outputs
    for output_values in shap_values:
        print(output_values.shape)
else:
    print(shap_values.shape)
print(X_test.shape)  # Check feature matrix shape

In [None]:

# Column names of the feature matrix, in the same order as X_test's columns.
feature_names = df.drop(columns=["Has Vascular Dementia"]).columns.to_list()

# Reduce the SHAP output to a 2-D (samples, features) matrix.
# A bare squeeze() would also collapse a length-1 sample or feature axis, so
# only the trailing single-output axis is removed.
# NOTE(review): assumes a list return holds one array per model output;
# confirm against the installed SHAP version.
if isinstance(shap_values, list):
    shap_values = shap_values[0]
if shap_values.ndim == 3 and shap_values.shape[-1] == 1:
    shap_values = shap_values[..., 0]

shap.summary_plot(shap_values, X_test, feature_names=feature_names)