In [None]:
import numpy as np
from keras.datasets import mnist
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from UniversalImageRocketMnist import UniversalImageRocket

# Seed NumPy's global RNG so repeated runs draw the same random kernels and report the
# same metrics. NOTE(review): this is only effective if UniversalImageRocket samples its
# kernels from np.random's global state — if it accepts a seed/random_state argument,
# pass SEED there instead.
SEED = 42
np.random.seed(SEED)

# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# normalize pixel values from [0, 255] to [0, 1]
X_train = X_train.astype("float32") / 255.0
X_test = X_test.astype("float32") / 255.0

# instantiate model
model = UniversalImageRocket(num_kernels=1000)

# training
model.fit(X_train, y_train)

# Predict once and derive every metric below from the same predictions; calling
# model.score() as well would presumably run a second full inference pass over X_test.
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"MNIST Test Accuracy: {acc:.4f}")

# precision, recall, f1-score & confusion matrix
print("\nClassification Report:\n")
print(classification_report(y_test, y_pred, digits=4))

print("\nConfusion Matrix:\n")
print(confusion_matrix(y_test, y_pred))

# Evaluation MNIST with 1000 kernels
MNIST Test Accuracy: 0.9776

Classification Report:

              precision    recall  f1-score   support

           0     0.9759    0.9918    0.9838       980
           1     0.9894    0.9885    0.9890      1135
           2     0.9777    0.9758    0.9767      1032
           3     0.9733    0.9752    0.9743      1010
           4     0.9897    0.9776    0.9836       982
           5     0.9699    0.9753    0.9726       892
           6     0.9863    0.9781    0.9822       958
           7     0.9793    0.9669    0.9731      1028
           8     0.9640    0.9887    0.9762       974
           9     0.9689    0.9574    0.9631      1009

    accuracy                         0.9776     10000
   macro avg     0.9774    0.9775    0.9775     10000
weighted avg     0.9777    0.9776    0.9776     10000


Confusion Matrix:

[[ 972    0    1    0    0    0    2    1    4    0]
 [   0 1122    5    0    2    0    4    1    0    1]
 [   2    0 1007    5    0    1    0    8    9    0]
 [   0    0    3  985    0   12    0    3    4    3]
 [   0    0    1    0  960    0    1    0    3   17]
 [   1    1    0    9    0  870    6    0    5    0]
 [   5    4    0    0    1    8  937    0    3    0]
 [   1    3   12    3    4    3    0  994    0    8]
 [   4    0    1    1    0    1    0    2  963    2]
 [  11    4    0    9    3    2    0    6    8  966]]

In [None]:
from UniversalImageRocketCifar import UniversalImageRocket
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Seed NumPy's global RNG for reproducible runs.
# NOTE(review): only effective if UniversalImageRocket draws its random kernels (and any
# randomized PCA) from np.random's global state — pass a seed argument instead if it has one.
SEED = 42
np.random.seed(SEED)

# initialize model
# NOTE(review): alpha is presumably the RidgeClassifier regularization strength (see the
# plot title) — confirm against UniversalImageRocketCifar.
model = UniversalImageRocket(num_kernels=6000, pca_components=5000, alpha=1.0)

# load data via the model's own loader
X_train, X_test, y_train, y_test = model.load_data()

# fit
model.fit(X_train, y_train)

# Flatten labels before any elementwise comparison: a label array of shape (n, 1) (the
# shape keras' cifar10.load_data uses) compared against 1-D predictions would broadcast
# to an (n, n) matrix and silently produce a meaningless accuracy.
y_test = np.ravel(y_test)
y_pred = np.ravel(model.predict(X_test))

# evaluation
acc = np.mean(y_pred == y_test)
print(f"\nTest Accuracy: {acc:.4f}\n")

# digits=4 matches the 4-decimal accuracy printed above and the MNIST report
print("Classification Report:\n", classification_report(y_test, y_pred, digits=4))

cm = confusion_matrix(y_test, y_pred)
print("\nConfusion Matrix:\n")
print(cm)

fig, ax = plt.subplots(figsize=(9, 7))
# annot/fmt write the raw count into each cell so the figure is readable on its own
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=range(10), yticklabels=range(10), ax=ax)
ax.set(xlabel="Predicted", ylabel="True",
       title="CIFAR-10 Confusion Matrix (ROCKET + RidgeClassifier)")
plt.show()


# Evaluation CIFAR-10 with 6000 kernels
Test Accuracy: 0.5003

Classification Report:
               precision    recall  f1-score   support

           0       0.59      0.53      0.56      1000
           1       0.54      0.57      0.56      1000
           2       0.43      0.36      0.39      1000
           3       0.39      0.31      0.34      1000
           4       0.45      0.38      0.41      1000
           5       0.45      0.47      0.46      1000
           6       0.50      0.58      0.53      1000
           7       0.51      0.53      0.52      1000
           8       0.58      0.68      0.63      1000
           9       0.52      0.60      0.55      1000

    accuracy                           0.50     10000
   macro avg       0.50      0.50      0.50     10000
weighted avg       0.50      0.50      0.50     10000


 Confusion Matrix:
 [[525  38  64  25  30  25  46  40 149  58]
 [ 42 572  14  15  18  24  46  37  74 158]
 [ 81  36 357  81  97 101 119  54  49  25]
 [ 44  55  67 310  81 176 106  67  25  69]
 [ 26  44  93  57 383  76 110 123  44  44]
 [ 15  30  77 126  71 474  58  87  25  37]
 [ 25  42  68  73  65  64 580  34  16  33]
 [ 18  39  35  64  77  89  41 528  29  80]
 [ 79  64  28  23  15  16  23  17 676  59]
 [ 28 135  21  26  15  19  40  45  73 598]]