# k-NN for Stroke Prediction: Effect of k Under Class Imbalance

This notebook accompanies the report and demonstrates the experiments described there.
It uses the Stroke Prediction Dataset (stroke.csv) and explores how changing **k** in k-NN
affects performance, especially on the minority (stroke) class.


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
    average_precision_score,
    ConfusionMatrixDisplay,
    RocCurveDisplay,
    PrecisionRecallDisplay
)
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Default size (width, height in inches) for figures that do not set their own
plt.rcParams.update({'figure.figsize': (7, 5)})

In [None]:
# Read the raw stroke data into a DataFrame.
# The CSV is looked up relative to the working directory, so stroke.csv
# has to sit in the same directory as this notebook.
csv_name = "stroke.csv"
df = pd.read_csv(csv_name)
print("Shape:", df.shape)
df.head()

In [None]:
# Basic info and class balance
print("\nData info:")
df.info()

class_counts = df['stroke'].value_counts()
print("\nClass distribution (stroke):")
print(class_counts)
print("\nClass proportions (stroke):")
print(df['stroke'].value_counts(normalize=True))

# Bar plot of class distribution.
# value_counts() orders its result by frequency, not by class value, so the
# counts are re-ordered to [0, 1] before plotting. This guarantees that the
# tick labels below always sit under the bar they describe, whichever class
# happens to be more frequent (a class with no rows is drawn as a zero bar).
class_labels = {0: 'No stroke (0)', 1: 'Stroke (1)'}
ordered_counts = class_counts.reindex(list(class_labels), fill_value=0)
ordered_counts.plot(kind='bar')
plt.xticks([0, 1], list(class_labels.values()), rotation=0)
plt.ylabel('Count')
plt.title('Class distribution: stroke vs non-stroke')
plt.tight_layout()
plt.show()

print("\nMissing values per column before imputation:")
print(df.isna().sum())

# Additional EDA plots: distributions of key numeric features by class.
# Each entry is (column, x-axis label, panel title); one panel per entry.
hist_specs = [
    ('age', 'Age', 'Age distribution by stroke outcome'),
    ('avg_glucose_level', 'Average glucose level', 'Glucose level by stroke outcome'),
]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

for ax, (column, xlabel, title) in zip(axes, hist_specs):
    # Overlay one semi-transparent histogram per class on the same axes
    for cls, label in class_labels.items():
        ax.hist(df.loc[df['stroke'] == cls, column], bins=30, alpha=0.5, label=label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    ax.set_title(title)
    ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Simple BMI imputation
df['bmi'] = df['bmi'].fillna(df['bmi'].median())

print("\nMissing values per column after BMI imputation:")
print(df.isna().sum())

In [None]:
# Define features and target
X = df.drop('stroke', axis=1)
y = df['stroke']

categorical_cols = ['gender', 'ever_married', 'work_type',
                    'Residence_type', 'smoking_status']
numeric_cols = ['age', 'hypertension', 'heart_disease',
                'avg_glucose_level', 'bmi']

numeric_transformer = StandardScaler()
categorical_transformer = OneHotEncoder(handle_unknown='ignore')

preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_cols),
        ('cat', categorical_transformer, categorical_cols),
    ]
)

# Train-test split (stratified)
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
    stratify=y
)

print("Train size:", X_train.shape[0])
print("Test size:", X_test.shape[0])
print("Stroke prevalence in train:", y_train.mean())
print("Stroke prevalence in test:", y_test.mean())

In [None]:
# Loop over different k values.
# For every k a fresh pipeline (preprocessing + k-NN) is fitted on the training
# split and scored on the test split; the scores are collected for the plots
# in the following cells.
# NOTE: these are test-set scores. They illustrate the effect of k, but picking
# k from them would give an optimistic estimate; use cross-validation on the
# training split for actual model selection.
k_values = [1, 3, 5, 7, 9, 11, 15, 21]
test_acc_list = []
f1_stroke_list = []

for k in k_values:
    model = Pipeline(steps=[
        ('preprocess', preprocessor),
        ('knn', KNeighborsClassifier(n_neighbors=k))
    ])
    model.fit(X_train, y_train)
    y_test_pred = model.predict(X_test)

    test_acc = accuracy_score(y_test, y_test_pred)
    # With a rare positive class, larger k can predict no strokes at all, which
    # makes precision (and hence F1) undefined. zero_division=0 reports that
    # case as 0.0 explicitly instead of emitting an UndefinedMetricWarning for
    # every such k.
    f1_stroke = f1_score(y_test, y_test_pred, pos_label=1, zero_division=0)

    test_acc_list.append(test_acc)
    f1_stroke_list.append(f1_stroke)

    print(f"k = {k:2d} | test accuracy = {test_acc:.3f} | F1 (stroke=1) = {f1_stroke:.3f}")

In [None]:
# Test accuracy as a function of k, one marker per evaluated k
fig, ax = plt.subplots()
ax.plot(k_values, test_acc_list, marker='o')
ax.set(
    xlabel='k (number of neighbours)',
    ylabel='Test accuracy',
    title='Effect of k on k-NN accuracy (stroke dataset)',
)
ax.grid(True, linestyle='--', alpha=0.4)
fig.tight_layout()
plt.show()

In [None]:
# F1-score of the stroke (positive) class as a function of k
fig, ax = plt.subplots()
ax.plot(k_values, f1_stroke_list, marker='o')
ax.set(
    xlabel='k (number of neighbours)',
    ylabel='F1-score (stroke class)',
    title='Effect of k on F1-score for stroke class',
)
ax.grid(True, linestyle='--', alpha=0.4)
fig.tight_layout()
plt.show()

In [None]:
# Detailed analysis for a chosen k (e.g. k = 1).
# best_k is set by hand; change it here to inspect a different k.
best_k = 1
best_model = Pipeline(steps=[
    ('preprocess', preprocessor),
    ('knn', KNeighborsClassifier(n_neighbors=best_k))
])
best_model.fit(X_train, y_train)
y_test_pred_best = best_model.predict(X_test)

print(f"\n=== Detailed evaluation for k = {best_k} ===")
# labels=[0, 1] pins the matrix to 2x2 with rows/columns in the same order as
# the display labels used for the heatmap below, even if one class is never
# predicted.
cm = confusion_matrix(y_test, y_test_pred_best, labels=[0, 1])
print("Confusion matrix (counts):\n", cm)
print("\nClassification report:")
# zero_division=0: if the chosen k predicts no strokes, report 0.0 for the
# undefined precision/F1 instead of emitting warnings.
print(classification_report(y_test, y_test_pred_best, digits=3, zero_division=0))

# Plot confusion matrix as a heatmap for better visualisation
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=['No stroke (0)', 'Stroke (1)'])
disp.plot(values_format='d')
plt.title(f'Confusion matrix (k = {best_k})')
plt.tight_layout()
plt.show()

# ROC and Precision-Recall curves using predicted probabilities.
# Column 1 of predict_proba is the score for class 1 (stroke). For k-NN this is
# the share of the k neighbours voting for stroke, so with k = 1 the scores are
# only 0 or 1 and both curves collapse to a single operating point.
y_scores = best_model.predict_proba(X_test)[:, 1]

roc_auc = roc_auc_score(y_test, y_scores)
avg_prec = average_precision_score(y_test, y_scores)
print(f"ROC AUC (k = {best_k}): {roc_auc:.3f}")
print(f"Average precision (PR AUC, k = {best_k}): {avg_prec:.3f}")

RocCurveDisplay.from_predictions(y_test, y_scores)
plt.title(f'ROC curve (k = {best_k})')
plt.tight_layout()
plt.show()

PrecisionRecallDisplay.from_predictions(y_test, y_scores)
plt.title(f'Precision-Recall curve (k = {best_k})')
plt.tight_layout()
plt.show()

## References

- Cover, T. M., & Hart, P. E. (1967). *Nearest neighbor pattern classification*. IEEE Transactions on Information Theory, 13(1), 21–27.
- Scikit-learn developers. *Nearest Neighbors*. In scikit-learn User Guide.
- Fedesoriano. (n.d.). *Stroke Prediction Dataset*. Kaggle.
- He, H., & Garcia, E. A. (2009). *Learning from imbalanced data*. IEEE Transactions on Knowledge and Data Engineering, 21(9), 1263–1284.
