In [1]:
# Cell 1: Imports
import sys
from pathlib import Path

sys.path.append('../src')

from data_loader import load_mnist
from preprocessing import preprocess_data
from train_models import train_svm, train_rf
from evaluate import evaluate_model

# Cell 2: Load and preprocess
# All artifacts of this notebook (confusion-matrix images and the comparison
# report) go under one folder. The path is relative to the notebook's working
# directory, mirroring the '../src' entry added to sys.path above.
RESULTS_DIR = Path("../results")
# Create the folder before any expensive work: writing into a missing
# directory raises FileNotFoundError, which would otherwise only surface
# after both models have finished training.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

X, y = load_mnist()
X_train, X_test, y_train, y_test = preprocess_data(X, y)

# Cell 3: Train models
# Each trainer returns a (model, elapsed_time) pair; the time is reported in
# seconds in the comparison report below.
svm_model, svm_time = train_svm(X_train, y_train)
rf_model, rf_time = train_rf(X_train, y_train)

# Cell 4: Evaluate models
# evaluate_model returns the test accuracy; the final argument is an image
# path, presumably where it saves the confusion-matrix plot (TODO confirm in
# ../src/evaluate.py). Passed as str in case the helper expects a plain string.
svm_acc = evaluate_model(
    svm_model, X_test, y_test, "SVM",
    str(RESULTS_DIR / "svm_confusion_matrix.png"),
)
rf_acc = evaluate_model(
    rf_model, X_test, y_test, "Random Forest",
    str(RESULTS_DIR / "rf_confusion_matrix.png"),
)

# Cell 5: Save comparison report
# Plain-text summary of accuracy and training time for both models.
with open(RESULTS_DIR / "comparison_report.txt", "w", encoding="utf-8") as f:
    f.write("Model Comparison Report\n")
    f.write("========================\n")
    f.write(f"SVM Accuracy: {svm_acc:.4f}, Time: {svm_time:.2f} sec\n")
    f.write(f"Random Forest Accuracy: {rf_acc:.4f}, Time: {rf_time:.2f} sec\n")

print("✅ Report saved in results folder.")


Loading MNIST dataset...
Dataset loaded: 70000 samples, 784 features.
Preprocessing data...
Train size: (56000, 784), Test size: (14000, 784)
Training SVM model...
SVM trained in 337.77 seconds.
Training Random Forest model...
Random Forest trained in 41.25 seconds.
Evaluating SVM...
SVM Accuracy: 0.9631

              precision    recall  f1-score   support

           0       0.99      0.98      0.98      1343
           1       0.98      0.99      0.98      1600
           2       0.95      0.96      0.95      1380
           3       0.96      0.95      0.96      1433
           4       0.96      0.96      0.96      1295
           5       0.97      0.96      0.96      1273
           6       0.97      0.98      0.98      1396
           7       0.93      0.97      0.95      1503
           8       0.97      0.95      0.96      1357
           9       0.96      0.94      0.95      1420

    accuracy                           0.96     14000
   macro avg       0.96      0.96      0.96