# AI Capstone: Project 1
## *with a public non-image dataset*
Author: 0816066 官澔恩

Dataset Source: [Kaggle](https://www.kaggle.com/fedesoriano/stellar-classification-dataset-sdss17)

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, Markdown

from imblearn.under_sampling import RandomUnderSampler

from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split, cross_validate, ParameterGrid
from sklearn.metrics import confusion_matrix, classification_report

from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier

# Data Preprocessing

In [2]:
data_path = '../input/stellar-classification-dataset-sdss17/star_classification.csv'
data = pd.read_csv(data_path)
data

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,class,redshift,plate,MJD,fiber_ID
0,1.237661e+18,135.689107,32.494632,23.87882,22.27530,20.39501,19.16573,18.79371,3606,301,2,79,6.543777e+18,GALAXY,0.634794,5812,56354,171
1,1.237665e+18,144.826101,31.274185,24.77759,22.83188,22.58444,21.16812,21.61427,4518,301,5,119,1.176014e+19,GALAXY,0.779136,10445,58158,427
2,1.237661e+18,142.188790,35.582444,25.26307,22.66389,20.60976,19.34857,18.94827,3606,301,2,120,5.152200e+18,GALAXY,0.644195,4576,55592,299
3,1.237663e+18,338.741038,-0.402828,22.13682,23.77656,21.61162,20.50454,19.25010,4192,301,3,214,1.030107e+19,GALAXY,0.932346,9149,58039,775
4,1.237680e+18,345.282593,21.183866,19.43718,17.58028,16.49747,15.97711,15.54461,8102,301,3,137,6.891865e+18,GALAXY,0.116123,6121,56187,842
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,1.237679e+18,39.620709,-2.594074,22.16759,22.97586,21.90404,21.30548,20.73569,7778,301,2,581,1.055431e+19,GALAXY,0.000000,9374,57749,438
99996,1.237679e+18,29.493819,19.798874,22.69118,22.38628,20.45003,19.75759,19.41526,7917,301,1,289,8.586351e+18,GALAXY,0.404895,7626,56934,866
99997,1.237668e+18,224.587407,15.700707,21.16916,19.26997,18.20428,17.69034,17.35221,5314,301,4,308,3.112008e+18,GALAXY,0.143366,2764,54535,74
99998,1.237661e+18,212.268621,46.660365,25.35039,21.63757,19.91386,19.07254,18.62482,3650,301,4,131,7.601080e+18,GALAXY,0.455040,6751,56368,470


In [3]:
print(data['class'].value_counts())
# labeling encode the feature 'class'
encoder = LabelEncoder()
label_encoded_classes = encoder.fit_transform(data['class'].values)
data['class'] = label_encoded_classes

GALAXY    59445
STAR      21594
QSO       18961
Name: class, dtype: int64


In [4]:
corr = data.corr()
corr_class = corr['class'].sort_values()
threshold = (corr_class > 0.02) | (corr_class < -0.02)
print(pd.DataFrame({'corr_coef': corr_class, 'meet_thres': threshold}))

to_drops = threshold.index[~threshold].tolist()
to_drops.append('class')

             corr_coef  meet_thres
r            -0.076766        True
redshift     -0.054239        True
fiber_ID     -0.041586        True
run_ID       -0.036014        True
obj_ID       -0.036012        True
field_ID     -0.034833        True
u            -0.024645        True
g            -0.020066        True
alpha        -0.011756       False
spec_obj_ID  -0.010060       False
plate        -0.010060       False
z            -0.001614       False
MJD          -0.000405       False
delta         0.014452       False
i             0.015028       False
cam_col       0.023138        True
class         1.000000        True
rerun_ID           NaN       False


`alpha`, `spec_obj_ID`, `plate`, `z`, `MJD`, `delta`, `i`, and `rerun_ID` has little relation with the target feature `class`.

Therefore, I'll drop them.

In [5]:
# drop unrelated features and the target feature
data_X = data.drop(columns=to_drops)
data_y = data['class']

In [6]:
data_X.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100000 entries, 0 to 99999
Data columns (total 9 columns):
 #   Column    Non-Null Count   Dtype  
---  ------    --------------   -----  
 0   obj_ID    100000 non-null  float64
 1   u         100000 non-null  float64
 2   g         100000 non-null  float64
 3   r         100000 non-null  float64
 4   run_ID    100000 non-null  int64  
 5   cam_col   100000 non-null  int64  
 6   field_ID  100000 non-null  int64  
 7   redshift  100000 non-null  float64
 8   fiber_ID  100000 non-null  int64  
dtypes: float64(5), int64(4)
memory usage: 6.9 MB


`run_ID`, `cam_col`, `field_ID`, and `fiber_ID` is categorical.

They have to be one-hot encoded.

In [7]:
# transform each categorical attributes into one-hot encoded ones
cat_attrs = ['run_ID', 'cam_col', 'field_ID', 'fiber_ID']
for cat_attr in cat_attrs:
    one_hot = pd.get_dummies(data_X[cat_attr]) \
                .add_prefix(f'{ cat_attr }_')
    data_X = data_X.drop(columns=cat_attr)
    data_X = data_X.join(one_hot)

data_X = data_X.values
print(f'Number of features becomes { data_X.shape[1] }.')

Number of features becomes 2297.


In [8]:
test_sizes = [0.2, 0.3]
datasets = [ train_test_split(data_X, data_y.values, test_size=test_size)
             for test_size in test_sizes ]

In [9]:
rus = RandomUnderSampler()
pca = PCA(n_components=100)
scaler = MinMaxScaler()

for idx, dataset in enumerate(datasets):
    X_train, X_test, y_train, y_test = dataset
    
    X_train, y_train = rus.fit_resample(X_train, y_train)
    print(f'Number of each class: { np.unique(y_train, return_counts=True)[0] }')
    
    X_train = pca.fit_transform(X_train)
    X_test = pca.transform(X_test)
    
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    
    datasets[idx] = (X_train, X_test, y_train, y_test)

Number of each class: [0 1 2]
Number of each class: [0 1 2]


# Models

In [10]:
# make a specified model with desired parameters
def get_model(model_type, params):
    if model_type == 'knn':
        return KNeighborsClassifier(**params)
    elif model_type == 'rf':
        return RandomForestClassifier(criterion='gini', **params)
    elif model_type == 'svm':
        return SVC(kernel='rbf', **params)
    elif model_type == 'mlp':
        return MLPClassifier(**params)
    else:
        return None

# Validation & Results

In [11]:
# display a confusion matrix and the classification report
def show_performance(y_true, y_pred):
    c_matrix = confusion_matrix(y_true, y_pred)
    c_table = pd.DataFrame(c_matrix)
    c_table.columns.name = 'truth\pred'
    display(c_table)
    
    report = classification_report(y_true, y_pred)
    print(report)

In [12]:
def show_cross_validate_report(res):
    report = pd.DataFrame({
        'fit_time': res['fit_time'],
        'score_time': res['score_time'],
        'test_score': res['test_score'],
    })
    display(report)

In [13]:
# train a model with 5-fold cross validation and validate the best model with the testing set
def train_model(model_type, param_grid, datasets):
    for test_size, dataset in zip(test_sizes, datasets):
        display(Markdown(f'### Test size: { test_size }'))
        X_train, X_test, y_train, y_test = dataset

        for params in param_grid:
            display(Markdown(f'#### { params }'))
            model = get_model(model_type, params)
            res = cross_validate(model, X_train, y_train, return_estimator=True)

            display(Markdown('#### Training Performance:'))
            show_cross_validate_report(res)

            best_model = res['estimator'][res['test_score'].argmax()]
            y_pred = best_model.predict(X_test)

            display(Markdown('#### Testing Performance:'))
            show_performance(y_test, y_pred)

## KNN

In [14]:
param_grid = ParameterGrid({
    'n_neighbors': [5, 10, 15]
})
train_model('knn', param_grid, datasets)

### Test size: 0.2

#### {'n_neighbors': 5}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,0.591357,27.756326,0.712545
1,0.589011,24.303007,0.724838
2,0.591559,24.698927,0.713862
3,0.600424,18.960065,0.726704
4,0.585859,25.054924,0.726375


#### Testing Performance:

truth\pred,0,1,2
0,9113,1129,1674
1,617,2877,282
2,1053,381,2874


              precision    recall  f1-score   support

           0       0.85      0.76      0.80     11916
           1       0.66      0.76      0.70      3776
           2       0.60      0.67      0.63      4308

    accuracy                           0.74     20000
   macro avg       0.70      0.73      0.71     20000
weighted avg       0.76      0.74      0.75     20000



#### {'n_neighbors': 10}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,0.596431,33.938075,0.712216
1,0.595089,29.946334,0.71935
2,0.593379,30.550804,0.715509
3,0.596539,22.89802,0.717594
4,0.593348,30.364032,0.720119


#### Testing Performance:

truth\pred,0,1,2
0,9149,985,1782
1,602,2869,305
2,1109,370,2829


              precision    recall  f1-score   support

           0       0.84      0.77      0.80     11916
           1       0.68      0.76      0.72      3776
           2       0.58      0.66      0.61      4308

    accuracy                           0.74     20000
   macro avg       0.70      0.73      0.71     20000
weighted avg       0.75      0.74      0.75     20000



#### {'n_neighbors': 15}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,0.593985,38.627063,0.712216
1,0.590709,32.515585,0.720997
2,0.596555,33.590678,0.716936
3,0.635099,26.2542,0.718253
4,0.589623,34.751071,0.717594


#### Testing Performance:

truth\pred,0,1,2
0,8892,992,2032
1,583,2868,325
2,1068,359,2881


              precision    recall  f1-score   support

           0       0.84      0.75      0.79     11916
           1       0.68      0.76      0.72      3776
           2       0.55      0.67      0.60      4308

    accuracy                           0.73     20000
   macro avg       0.69      0.72      0.70     20000
weighted avg       0.75      0.73      0.74     20000



### Test size: 0.3

#### {'n_neighbors': 5}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,0.52314,19.036655,0.722818
1,0.523246,19.218825,0.729676
2,0.529201,14.709198,0.718329
3,0.543404,17.714947,0.71156
4,0.511894,19.28845,0.720663


#### Testing Performance:

truth\pred,0,1,2
0,13262,1738,2820
1,982,4195,418
2,1616,559,4410


              precision    recall  f1-score   support

           0       0.84      0.74      0.79     17820
           1       0.65      0.75      0.69      5595
           2       0.58      0.67      0.62      6585

    accuracy                           0.73     30000
   macro avg       0.69      0.72      0.70     30000
weighted avg       0.74      0.73      0.73     30000



#### {'n_neighbors': 10}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,0.534756,23.034114,0.728304
1,0.516666,22.719599,0.723691
2,0.516264,17.794184,0.721945
3,0.520218,20.817512,0.719541
4,0.518215,23.319162,0.718793


#### Testing Performance:

truth\pred,0,1,2
0,13155,1664,3001
1,945,4184,466
2,1604,555,4426


              precision    recall  f1-score   support

           0       0.84      0.74      0.78     17820
           1       0.65      0.75      0.70      5595
           2       0.56      0.67      0.61      6585

    accuracy                           0.73     30000
   macro avg       0.68      0.72      0.70     30000
weighted avg       0.74      0.73      0.73     30000



#### {'n_neighbors': 15}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,0.524993,24.68735,0.725561
1,0.524052,25.362343,0.723441
2,0.51113,20.247196,0.724813
3,0.522579,23.656578,0.7158
4,0.514059,26.158248,0.715052


#### Testing Performance:

truth\pred,0,1,2
0,13024,1580,3216
1,919,4214,462
2,1634,532,4419


              precision    recall  f1-score   support

           0       0.84      0.73      0.78     17820
           1       0.67      0.75      0.71      5595
           2       0.55      0.67      0.60      6585

    accuracy                           0.72     30000
   macro avg       0.68      0.72      0.70     30000
weighted avg       0.74      0.72      0.73     30000



## Random Forest

In [15]:
param_grid = ParameterGrid({
    'min_samples_leaf': [1, 5, 10]
})
train_model('rf', param_grid, datasets)

### Test size: 0.2

#### {'min_samples_leaf': 1}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,45.675269,0.227156,0.88673
1,46.613976,0.223292,0.894852
2,45.54797,0.224702,0.885303
3,46.398122,0.226215,0.888816
4,45.338431,0.22725,0.886401


#### Testing Performance:

truth\pred,0,1,2
0,10431,305,1180
1,278,3434,64
2,393,6,3909


              precision    recall  f1-score   support

           0       0.94      0.88      0.91     11916
           1       0.92      0.91      0.91      3776
           2       0.76      0.91      0.83      4308

    accuracy                           0.89     20000
   macro avg       0.87      0.90      0.88     20000
weighted avg       0.90      0.89      0.89     20000



#### {'min_samples_leaf': 5}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,41.612931,0.196904,0.880474
1,41.560394,0.196464,0.890462
2,41.109924,0.197646,0.879596
3,41.548987,0.199101,0.885633
4,42.266616,0.195365,0.89145


#### Testing Performance:

truth\pred,0,1,2
0,10371,326,1219
1,276,3427,73
2,419,5,3884


              precision    recall  f1-score   support

           0       0.94      0.87      0.90     11916
           1       0.91      0.91      0.91      3776
           2       0.75      0.90      0.82      4308

    accuracy                           0.88     20000
   macro avg       0.87      0.89      0.88     20000
weighted avg       0.89      0.88      0.89     20000



#### {'min_samples_leaf': 10}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,39.903651,0.189353,0.876413
1,39.905484,0.187791,0.882779
2,40.343466,0.19675,0.876413
3,40.329696,0.185294,0.881791
4,39.510999,0.188216,0.880803


#### Testing Performance:

truth\pred,0,1,2
0,10276,349,1291
1,291,3400,85
2,463,8,3837


              precision    recall  f1-score   support

           0       0.93      0.86      0.90     11916
           1       0.90      0.90      0.90      3776
           2       0.74      0.89      0.81      4308

    accuracy                           0.88     20000
   macro avg       0.86      0.88      0.87     20000
weighted avg       0.88      0.88      0.88     20000



### Test size: 0.3

#### {'min_samples_leaf': 1}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,39.973126,0.1988,0.888404
1,39.49941,0.200034,0.890773
2,39.751418,0.205147,0.885162
3,39.989056,0.208263,0.889637
4,38.928485,0.203368,0.884275


#### Testing Performance:

truth\pred,0,1,2
0,15375,473,1972
1,495,4999,101
2,637,10,5938


              precision    recall  f1-score   support

           0       0.93      0.86      0.90     17820
           1       0.91      0.89      0.90      5595
           2       0.74      0.90      0.81      6585

    accuracy                           0.88     30000
   macro avg       0.86      0.89      0.87     30000
weighted avg       0.89      0.88      0.88     30000



#### {'min_samples_leaf': 5}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,35.865542,0.199885,0.882294
1,35.907469,0.179981,0.885037
2,36.045514,0.183883,0.878678
3,36.469109,0.18281,0.885397
4,36.163516,0.179604,0.878913


#### Testing Performance:

truth\pred,0,1,2
0,15291,470,2059
1,488,5001,106
2,633,6,5946


              precision    recall  f1-score   support

           0       0.93      0.86      0.89     17820
           1       0.91      0.89      0.90      5595
           2       0.73      0.90      0.81      6585

    accuracy                           0.87     30000
   macro avg       0.86      0.88      0.87     30000
weighted avg       0.88      0.87      0.88     30000



#### {'min_samples_leaf': 10}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,34.234166,0.160088,0.87606
1,33.971726,0.159327,0.875935
2,34.160286,0.158448,0.873691
3,34.014092,0.170868,0.88203
4,33.202309,0.159791,0.876543


#### Testing Performance:

truth\pred,0,1,2
0,15214,499,2107
1,507,4965,123
2,713,8,5864


              precision    recall  f1-score   support

           0       0.93      0.85      0.89     17820
           1       0.91      0.89      0.90      5595
           2       0.72      0.89      0.80      6585

    accuracy                           0.87     30000
   macro avg       0.85      0.88      0.86     30000
weighted avg       0.88      0.87      0.87     30000



## SVM

In [16]:
param_grid = ParameterGrid({
    'C': [1, 5, 10]
})
train_model('svm', param_grid, datasets)

### Test size: 0.2

#### {'C': 1}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,164.038457,29.163267,0.876194
1,164.236164,28.99381,0.884535
2,167.390821,28.735543,0.877072
3,163.191633,28.485504,0.875206
4,166.370654,28.363239,0.883108


#### Testing Performance:

truth\pred,0,1,2
0,9761,265,1890
1,461,3268,47
2,131,0,4177


              precision    recall  f1-score   support

           0       0.94      0.82      0.88     11916
           1       0.92      0.87      0.89      3776
           2       0.68      0.97      0.80      4308

    accuracy                           0.86     20000
   macro avg       0.85      0.88      0.86     20000
weighted avg       0.88      0.86      0.86     20000



#### {'C': 5}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,128.137603,20.711776,0.902316
1,126.526571,20.808159,0.909121
2,134.333259,21.147509,0.907365
3,125.702542,20.71186,0.903084
4,137.324692,20.597887,0.911426


#### Testing Performance:

truth\pred,0,1,2
0,10389,314,1213
1,415,3340,21
2,88,1,4219


              precision    recall  f1-score   support

           0       0.95      0.87      0.91     11916
           1       0.91      0.88      0.90      3776
           2       0.77      0.98      0.86      4308

    accuracy                           0.90     20000
   macro avg       0.88      0.91      0.89     20000
weighted avg       0.91      0.90      0.90     20000



#### {'C': 10}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,119.793507,18.491478,0.912853
1,135.968007,18.62257,0.919219
2,133.606181,18.624306,0.919438
3,136.59777,18.379378,0.913292
4,119.959949,18.147001,0.918779


#### Testing Performance:

truth\pred,0,1,2
0,10577,382,957
1,407,3355,14
2,76,1,4231


              precision    recall  f1-score   support

           0       0.96      0.89      0.92     11916
           1       0.90      0.89      0.89      3776
           2       0.81      0.98      0.89      4308

    accuracy                           0.91     20000
   macro avg       0.89      0.92      0.90     20000
weighted avg       0.91      0.91      0.91     20000



### Test size: 0.3

#### {'C': 1}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,124.184621,22.820486,0.876559
1,123.836879,22.865508,0.875561
2,124.530362,22.652485,0.875436
3,127.304084,23.439789,0.876294
4,128.337514,23.087257,0.868313


#### Testing Performance:

truth\pred,0,1,2
0,14219,577,3024
1,706,4796,93
2,226,2,6357


              precision    recall  f1-score   support

           0       0.94      0.80      0.86     17820
           1       0.89      0.86      0.87      5595
           2       0.67      0.97      0.79      6585

    accuracy                           0.85     30000
   macro avg       0.83      0.87      0.84     30000
weighted avg       0.87      0.85      0.85     30000



#### {'C': 5}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,90.092665,16.335645,0.904738
1,90.795599,16.46302,0.901372
2,91.252213,16.3325,0.904863
3,91.754706,16.40981,0.904851
4,91.259218,16.161797,0.899239


#### Testing Performance:

truth\pred,0,1,2
0,15107,534,2179
1,649,4903,43
2,156,1,6428


              precision    recall  f1-score   support

           0       0.95      0.85      0.90     17820
           1       0.90      0.88      0.89      5595
           2       0.74      0.98      0.84      6585

    accuracy                           0.88     30000
   macro avg       0.86      0.90      0.88     30000
weighted avg       0.90      0.88      0.88     30000



#### {'C': 10}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,84.538599,14.932498,0.911721
1,84.879532,14.617152,0.912344
2,84.498279,14.545146,0.91197
3,83.542183,14.406843,0.913705
4,85.667024,14.404963,0.911086


#### Testing Performance:

truth\pred,0,1,2
0,15523,574,1723
1,634,4929,32
2,125,1,6459


              precision    recall  f1-score   support

           0       0.95      0.87      0.91     17820
           1       0.90      0.88      0.89      5595
           2       0.79      0.98      0.87      6585

    accuracy                           0.90     30000
   macro avg       0.88      0.91      0.89     30000
weighted avg       0.91      0.90      0.90     30000



## MLP

In [17]:
param_grid = ParameterGrid({
    'hidden_layer_sizes': [256, 512, 1024]
})
train_model('mlp', param_grid, datasets)

### Test size: 0.2

#### {'hidden_layer_sizes': 256}



#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,291.43626,0.34376,0.908682
1,128.575114,0.302343,0.913621
2,331.683448,0.487561,0.927889
3,344.400716,0.502531,0.915706
4,351.603396,0.411484,0.925804


#### Testing Performance:

truth\pred,0,1,2
0,11296,391,229
1,411,3361,4
2,225,0,4083


              precision    recall  f1-score   support

           0       0.95      0.95      0.95     11916
           1       0.90      0.89      0.89      3776
           2       0.95      0.95      0.95      4308

    accuracy                           0.94     20000
   macro avg       0.93      0.93      0.93     20000
weighted avg       0.94      0.94      0.94     20000



#### {'hidden_layer_sizes': 512}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,187.88732,0.427686,0.882779
1,272.480736,0.464948,0.916255
2,183.663935,0.467801,0.909011
3,329.341082,0.531129,0.82801
4,519.627748,0.840124,0.890572


#### Testing Performance:

truth\pred,0,1,2
0,10213,532,1171
1,392,3367,17
2,9,0,4299


              precision    recall  f1-score   support

           0       0.96      0.86      0.91     11916
           1       0.86      0.89      0.88      3776
           2       0.78      1.00      0.88      4308

    accuracy                           0.89     20000
   macro avg       0.87      0.92      0.89     20000
weighted avg       0.91      0.89      0.89     20000



#### {'hidden_layer_sizes': 1024}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,126.591838,0.105391,0.920645
1,468.679843,0.946874,0.932609
2,496.547998,0.956378,0.933597
3,1212.859523,2.146674,0.881242
4,285.830395,0.734838,0.917792


#### Testing Performance:

truth\pred,0,1,2
0,11098,453,365
1,378,3392,6
2,100,1,4207


              precision    recall  f1-score   support

           0       0.96      0.93      0.94     11916
           1       0.88      0.90      0.89      3776
           2       0.92      0.98      0.95      4308

    accuracy                           0.93     20000
   macro avg       0.92      0.94      0.93     20000
weighted avg       0.94      0.93      0.93     20000



### Test size: 0.3

#### {'hidden_layer_sizes': 256}



#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,211.818545,0.249178,0.916833
1,173.406471,0.221564,0.928678
2,252.109869,0.396253,0.931047
3,243.266058,0.318046,0.929044
4,167.054414,0.24352,0.914952


#### Testing Performance:

truth\pred,0,1,2
0,15913,849,1058
1,534,5038,23
2,58,2,6525


              precision    recall  f1-score   support

           0       0.96      0.89      0.93     17820
           1       0.86      0.90      0.88      5595
           2       0.86      0.99      0.92      6585

    accuracy                           0.92     30000
   macro avg       0.89      0.93      0.91     30000
weighted avg       0.92      0.92      0.92     30000



#### {'hidden_layer_sizes': 512}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,251.327323,0.418798,0.865586
1,123.29928,0.156102,0.925062
2,300.08476,0.548615,0.901372
3,297.641335,0.483227,0.931662
4,237.269882,0.360077,0.914079


#### Testing Performance:

truth\pred,0,1,2
0,15983,617,1220
1,589,4981,25
2,47,1,6537


              precision    recall  f1-score   support

           0       0.96      0.90      0.93     17820
           1       0.89      0.89      0.89      5595
           2       0.84      0.99      0.91      6585

    accuracy                           0.92     30000
   macro avg       0.90      0.93      0.91     30000
weighted avg       0.92      0.92      0.92     30000



#### {'hidden_layer_sizes': 1024}

#### Training Performance:

Unnamed: 0,fit_time,score_time,test_score
0,623.438117,1.044393,0.860599
1,410.221161,0.745225,0.926933
2,535.622451,0.960546,0.901746
3,301.060634,0.531323,0.931288
4,269.034996,0.580155,0.92206


#### Testing Performance:

truth\pred,0,1,2
0,16527,557,736
1,634,4945,16
2,164,1,6420


              precision    recall  f1-score   support

           0       0.95      0.93      0.94     17820
           1       0.90      0.88      0.89      5595
           2       0.90      0.97      0.93      6585

    accuracy                           0.93     30000
   macro avg       0.92      0.93      0.92     30000
weighted avg       0.93      0.93      0.93     30000

