In [48]:
import csv
import ast
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [49]:
from sklearn import svm

In [50]:
from sklearn.neural_network import MLPClassifier

In [2]:
import collections

In [3]:
from sklearn.metrics import confusion_matrix, classification_report, f1_score

In [4]:
features_key = "features"
label_key = "label"
csv_file_path = 'num_recognition_training_data.csv'

def get_csv_data():
        train_labels = []
        train_features = []
        with open(csv_file_path) as csvfile:
            data_reader = csv.reader(csvfile, delimiter=',')
            for row in data_reader:
                train_labels.append(row[0])
                features = ast.literal_eval(row[1])
                train_features.append(features)

        return {
            label_key: train_labels,
            features_key: train_features
        }

In [25]:
labels, features = get_csv_data().values()

In [57]:
def train_data(labels, features, solver='SVC', **kwargs):
    
    if solver == 'SVC':
        classifier = svm.SVC(**kwargs)
    else:
        print('using other than SVC')
        classifier = MLPClassifier()
        
    classifier.fit(labels, features)
    
    return classifier

In [40]:
def cross_validate(classifier, to_predict_features, to_predict_labels):
    # make prediction
    prediction = classifier.predict(to_predict_features)
    
    # check validity
    return confusion_matrix(to_predict_labels, prediction), \
        classification_report(to_predict_labels, prediction), \
        f1_score(to_predict_labels, prediction, average='weighted')

In [60]:
def make_model(labels, features):
    # convert to numpy array
    labels = np.array(labels, dtype='uint8')
    features = np.array(features, dtype='float_')
    
    # create random state to maintain randomness
    rs = np.random.RandomState(seed=1234567890)
    
    # randomize the data
    rs.shuffle(labels)
    rs.shuffle(features)

    total_test = len(labels) * 5 // 6
    
    # check occurences
    test_occ = sorted(collections.Counter(labels[:total_test]).items())
    cv_occ = sorted(collections.Counter(labels[total_test:]).items())
    
    print('Train labels: {}\nCV    labels: {}\n\n'.format(test_occ, cv_occ))
    
    classifier = train_data(features[:total_test], labels[:total_test], solver='NN')
    matrix, report, f1_score = cross_validate(classifier, features[total_test:], labels[total_test:])
    print(matrix)
    print('\n')
    print(report)
    
    costs = np.linspace(10, 100, num=5) 
    gammas = np.linspace(0.05, 0.06, num=5)
    scores = []
    
#     for cost in costs:
#         for gamma in gammas:
#             classifier = train_data(features[:total_test], labels[:total_test], solver='NN', C=cost, gamma=gamma)

#             matrix, report, f1_score = cross_validate(classifier, features[total_test:], labels[total_test:])
#             scores.append((cost, gamma, f1_score))
# #             print('Cost: {}\nGamma: {}\nScore: {}\n\n'.format(cost, gamma, f1_score))
# #             print(matrix)
# #             print('\n')
# #             print(report)
    
#     fig = plt.figure()
#     ax = fig.add_subplot(111, projection='3d')
    
#     x, y, z = zip(*scores)
# #     print(x)
# #     print(y)
#     print(z)
    
#     ax.scatter(x, y, z, c='r', marker='o')

#     ax.set_xlabel('C')
#     ax.set_ylabel('Gamma')
#     ax.set_zlabel('F1_score')

#     plt.show()

In [61]:
make_model(labels, features)

Train labels: [(0, 64), (1, 32), (2, 80), (3, 48), (4, 66), (5, 52), (6, 61), (7, 86), (8, 87), (9, 90)]
CV    labels: [(0, 13), (1, 7), (2, 22), (3, 6), (4, 9), (5, 16), (6, 13), (7, 18), (8, 17), (9, 13)]


using other than SVC
[[1 0 2 1 1 0 1 2 3 2]
 [1 1 1 0 1 0 0 1 0 2]
 [2 0 2 0 1 2 3 3 2 7]
 [1 0 2 1 1 0 1 0 0 0]
 [1 0 0 1 1 1 1 1 1 2]
 [2 0 0 1 1 0 3 5 2 2]
 [1 0 2 2 0 0 1 4 1 2]
 [1 3 4 0 0 5 0 3 1 1]
 [2 0 1 0 4 1 3 2 3 1]
 [2 1 0 0 5 0 0 1 1 3]]


             precision    recall  f1-score   support

          0       0.07      0.08      0.07        13
          1       0.20      0.14      0.17         7
          2       0.14      0.09      0.11        22
          3       0.17      0.17      0.17         6
          4       0.07      0.11      0.08         9
          5       0.00      0.00      0.00        16
          6       0.08      0.08      0.08        13
          7       0.14      0.17      0.15        18
          8       0.21      0.18      0.19        17
      

