In [2]:
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
from collections import Counter

# Train a random-forest classifier on the feature vectors stored in
# ./data.pickle and save the fitted model to model.p.
#
# Pipeline: load -> drop samples with the wrong feature length -> drop classes
# that have a single sample -> stratified 80/20 split -> fit -> report
# accuracy -> pickle the model.

# Number of features every sample must have; samples of any other length are
# discarded so the data can be stacked into one rectangular array.
expected_length = 42

# Load the data.
# NOTE(security): unpickling can execute arbitrary code; only load a
# data.pickle that comes from a trusted source.
with open('./data.pickle', 'rb') as f:
    data_dict = pickle.load(f)

# Filter samples and labels together as pairs so they cannot drift out of
# alignment, keeping only samples with the expected feature length.
valid_pairs = [
    (np.array(sample), label)
    for sample, label in zip(data_dict['data'], data_dict['labels'])
    if len(sample) == expected_length
]

# Convert to NumPy arrays
data = np.array([sample for sample, _ in valid_pairs])
labels = np.array([label for _, label in valid_pairs])

# Check the distribution of labels
label_counts = Counter(labels)
print(label_counts)

# Identify classes with only one sample: a stratified split needs at least
# two members per class, so these have to be removed first.
classes_to_remove = [label for label, count in label_counts.items() if count == 1]

# Build one boolean mask (set lookup is O(1) per label) and apply it to both
# arrays, which keeps samples and labels aligned.
singleton_classes = set(classes_to_remove)
keep_mask = np.array([label not in singleton_classes for label in labels], dtype=bool)
data_filtered = data[keep_mask]
labels_filtered = labels[keep_mask]

# Split the filtered data, preserving the class proportions in both parts.
# No random_state is set here or on the classifier below, so the split and
# the reported accuracy vary from run to run.
x_train, x_test, y_train, y_test = train_test_split(
    data_filtered,
    labels_filtered,
    test_size=0.2,
    shuffle=True,
    stratify=labels_filtered,
)

# Train the classifier
model = RandomForestClassifier()
model.fit(x_train, y_train)

# Predict and calculate accuracy (scikit-learn metrics take y_true first,
# then y_pred).
y_predict = model.predict(x_test)
score = accuracy_score(y_test, y_predict)

print('{}% of samples were classified correctly!'.format(score * 100))

# Save the model
with open('model.p', 'wb') as f:
    pickle.dump({'model': model}, f)


Counter({'4': 199, '29': 195, '23': 194, '30': 194, '17': 192, '3': 191, '31': 189, '7': 186, '8': 181, '20': 179, '2': 178, '5': 166, '6': 166, '0': 115, '1': 115, '26': 111, '9': 84, '11': 81, '14': 69, '10': 67, '16': 34, '25': 34, '22': 27, '28': 26, '18': 23, '12': 22, '27': 22, '13': 21, '32': 20, '24': 15, '15': 11, '34': 10, '21': 9, '19': 3, '33': 1})
93.24324324324324% of samples were classified correctly!
