In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, auc, classification_report, adjusted_rand_score, silhouette_score
import numpy as np

In [None]:
# Seed NumPy's global (legacy) random generator once, up front, so anything that
# draws from np.random produces the same numbers on every Restart & Run All.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

In [None]:
# 1. Load the class-balanced ECG dataset (path is relative to this notebook)
#    and preview the first rows.
DATA_PATH = '../Preprocessing/balanced_ecg_data.csv'
df = pd.read_csv(DATA_PATH)
df.head()

In [None]:
# Number of rows per target class: a quick sanity check that the data is balanced.
df.loc[:, 'target'].value_counts()

In [None]:
# Features are every column except the label column 'target'; the label is kept separately.
X, y = df.drop(columns='target'), df['target']

In [None]:
# 2. Hold out 30% of the rows for testing. Stratifying on y keeps each class's
#    proportion the same in both splits; the fixed random_state makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=42, test_size=0.3
)

In [None]:
# 3. Feature scaling. KMeans assigns points to centroids by Euclidean distance, so
#    features with large numeric ranges would otherwise dominate the clustering.
#    StandardScaler rescales each feature to zero mean / unit variance.
#    The scaler is fit on the training split only and then applied unchanged to the
#    test split, so no test-set statistics leak into training.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

In [None]:
# 4. Initialize the KMeans clustering model.
# The number of clusters is taken from the data: one cluster per distinct training label.
# A hard-coded count that differs from the number of classes either leaves some classes
# with no cluster (so they can never be predicted) or forces classes to share clusters.
kmeans = KMeans(
    n_clusters=y_train.nunique(),  # one cluster per class label present in y_train
    init='k-means++',     # k-means++ seeding spreads the initial centroids apart
    max_iter=300,         # Max iterations per run
    random_state=42,      # Fixes centroid initialisation for reproducibility
    n_init=10,            # 10 initialisations; the lowest-inertia run is kept
    algorithm='lloyd',    # Classic Lloyd iteration ('elkan' is the alternative)
)

In [None]:
# Fit KMeans on the scaled training features. KMeans is unsupervised: scikit-learn's
# fit() ignores any y argument (it exists only for API consistency), so the labels are
# not passed here, which avoids suggesting the clustering is trained on them.
kmeans.fit(X_train_scaled)

In [None]:
# 5. Predict clusters for the test set and turn them into class predictions.
# KMeans cluster ids are arbitrary (cluster 3 need not correspond to class 3), and a
# simple `1 - y_pred` flip is only meaningful when there are exactly two clusters and two
# classes. Instead, each cluster is mapped to the most frequent training label among its
# members (kmeans.labels_ holds the training-set cluster assignments). Each test point
# then inherits the label of the cluster it is assigned to.
test_clusters = kmeans.predict(X_test_scaled)

# Rows: cluster ids 0..n_clusters-1, reindexed so a cluster with no training members
# still has a row (its all-zero row falls back to the first label). Columns: class labels.
cluster_label_counts = (
    pd.crosstab(kmeans.labels_, np.asarray(y_train))
    .reindex(range(kmeans.n_clusters), fill_value=0)
)
cluster_to_label = cluster_label_counts.idxmax(axis=1)  # cluster id -> majority label
y_pred = cluster_to_label.to_numpy()[test_clusters]

# Accuracy depends on the mapping above. The Adjusted Rand Index compares the raw
# cluster assignment with the true labels and does not depend on how clusters are numbered.
accuracy = accuracy_score(y_test, y_pred)
ari = adjusted_rand_score(y_test, test_clusters)

print("Final Centroids shape:", kmeans.cluster_centers_.shape)
print("Cluster -> label mapping:", cluster_to_label.to_dict())
print("Predicted labels (first 20):", y_pred[:20])
print(f"\nAccuracy: {accuracy:.4f}")
print(f"Adjusted Rand Index: {ari:.4f}")


In [None]:
# Per-class precision / recall / F1 of the predictions against the true test labels.
# zero_division=0 reports 0 (instead of emitting a warning) for any class that is never predicted.
print("Classification Report (if labels are available):")
report_text = classification_report(y_test, y_pred, zero_division=0)
print(report_text)

In [None]:
# 6. Confusion matrix of true vs. predicted labels.
# The label list is built once from the data and passed to both confusion_matrix and
# the heatmap. The tick labels therefore always match the matrix rows/columns, whatever
# classes actually occur (a hard-coded list can disagree with the matrix size).
class_labels = np.unique(np.concatenate([np.asarray(y_test), np.asarray(y_pred)]))
cm = confusion_matrix(y_test, y_pred, labels=class_labels)
print("Confusion Matrix:", cm)

tick_names = [f"Class {label}" for label in class_labels]
fig, ax = plt.subplots(figsize=(10, 8))  # room for an annotated matrix with many classes
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=tick_names,
    yticklabels=tick_names,
    ax=ax,
)
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted Label")
ax.set_ylabel("True Label")
plt.show()

In [None]:
# Overall fraction of test samples whose predicted label equals the true label,
# shown both as a ratio and as a percentage.
accuracy = accuracy_score(y_test, y_pred)
accuracy_pct = accuracy * 100
print(f"Overall Accuracy: {accuracy:.4f} ({accuracy_pct:.2f}%)")

In [None]:
from itertools import cycle

# 7. One-vs-rest ROC curves computed from the fitted KMeans model.
# KMeans has no predict_proba, so each class gets a distance-based score. For class k:
#   score = -(distance to the nearest centroid whose cluster is majority-labelled k
#             on the training set)
# A test sample that lies closer to a class-k centroid therefore ranks higher for class k.

# Majority training label per cluster id. Reindexing guarantees that row i is cluster i
# even if a cluster has no training members (its all-zero row falls back to the first label).
cluster_majority_label = (
    pd.crosstab(kmeans.labels_, np.asarray(y_train))
    .reindex(range(kmeans.n_clusters), fill_value=0)
    .idxmax(axis=1)
    .to_numpy()
)
centroid_distances = kmeans.transform(X_test_scaled)  # (n_test_samples, n_clusters)
y_test_arr = np.asarray(y_test)
classes = np.unique(y_test_arr)
n_classes = len(classes)

# Set up the figure with larger dimensions
fig, ax = plt.subplots(figsize=(16, 12))  # Wider and taller for better visibility

# Custom color cycle (15 colors, repeated if there are more classes)
colors = cycle([
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
    '#1a55FF', '#FF33F4', '#228B22', '#B22222', '#8B008B'
])

fpr, tpr, roc_auc = {}, {}, {}
for cls in classes:
    owns_cluster = cluster_majority_label == cls
    if not owns_cluster.any():
        # No cluster is labelled with this class, so the model provides no score for it.
        print(f"Class {cls}: no cluster maps to this label - ROC curve skipped.")
        continue
    class_score = -centroid_distances[:, owns_cluster].min(axis=1)
    fpr[cls], tpr[cls], _ = roc_curve((y_test_arr == cls).astype(int), class_score)
    roc_auc[cls] = auc(fpr[cls], tpr[cls])
    ax.plot(fpr[cls], tpr[cls], color=next(colors), lw=3,
            label=f'Class {cls} (AUC = {roc_auc[cls]:.3f})')

# Formatting
ax.plot([0, 1], [0, 1], 'k--', lw=2, alpha=0.5)  # chance-level diagonal
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=14)
ax.set_ylabel('True Positive Rate', fontsize=14)
ax.set_title('Multi-class ROC Curves for ECG Classification', fontsize=16, pad=20)
ax.legend(loc='lower right', fontsize=10, framealpha=1)
ax.grid(alpha=0.3)