In [17]:
import sys
if sys.version_info[0] < 3:
	raise Exception("Python 3 not detected.")
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
from scipy import io


# Sanity-check every dataset: print the shape of each array stored in its .npz archive.
fields = ("test_data", "training_data", "training_labels")
for data_name in ["mnist", "spam", "cifar10"]:
    # np.load on an .npz returns an NpzFile that holds the archive open until
    # closed; the context manager releases it once the shapes are printed.
    with np.load(f"../data/{data_name}-data.npz") as data:
        print(f"\nloaded {data_name} data!")
        for field in fields:
            print(field, data[field].shape)



loaded mnist data!
test_data (10000, 1, 28, 28)
training_data (60000, 1, 28, 28)
training_labels (60000,)

loaded spam data!
test_data (1000, 32)
training_data (4172, 32)
training_labels (4172,)

loaded cifar10 data!
test_data (10000, 3072)
training_data (50000, 3072)
training_labels (50000,)


In [20]:
# Q2: Data Partitioning

def shuffle_partition(data_name, count, data_dir="../data", rng=None):
    """Randomly split a dataset's training examples into validation and training indices.

    Opens ``{data_dir}/{data_name}-data.npz`` only to count its
    ``training_labels``, shuffles the indices ``0 .. n-1`` and splits them.

    Args:
        data_name: dataset prefix, e.g. "mnist", "spam" or "cifar10".
        count: number of indices to set aside for validation; must satisfy
            ``0 <= count <= n`` where ``n`` is the number of training labels.
        data_dir: directory containing the ``*-data.npz`` files.
        rng: optional ``numpy.random.Generator`` (anything with a
            ``permutation`` method) for a reproducible split. When omitted the
            global ``np.random`` state is used, exactly as before.

    Returns:
        ``(validation_indices, training_indices)``: two disjoint 1-D integer
        arrays that together cover every training example exactly once.

    Raises:
        ValueError: if ``count`` is negative or exceeds the number of training
            examples (slicing would otherwise silently yield a wrong split).
    """
    with np.load(f"{data_dir}/{data_name}-data.npz") as data:
        n = len(data["training_labels"])
    if not 0 <= count <= n:
        raise ValueError(f"count must be in [0, {n}] for {data_name!r}, got {count}")
    # permutation(n) == arange(n) shuffled; with the global state it draws the
    # same random stream the original arange + np.random.shuffle did.
    shuffler = np.random if rng is None else rng
    indices = shuffler.permutation(n)
    return indices[:count], indices[count:]
    
# shuffle_partition shuffles with NumPy's global RNG, so fix the seed here;
# otherwise every Restart & Run All produces a different train/validation split.
SEED = 42
np.random.seed(SEED)

# 10,000 training data set aside for validation.
mnist_validation, mnist_training = shuffle_partition("mnist", 10000)
print(mnist_training)
# 20% of training data set aside for validation (834 of the 4,172 spam examples).
spam_validation, spam_training = shuffle_partition("spam", 834)
# 5,000 training data set aside for validation.
cifar10_validation, cifar10_training = shuffle_partition("cifar10", 5000)


[18082 39568 44943 ...  4828 11047 53788]


In [23]:
# Q3: Support Vector Coding

# Train Model
def train_svm(data_name, training_indices, num_train, data_dir="../data"):
    """Fit a linear SVM on the first ``num_train`` examples of a training split.

    Each selected example's raw ``training_data`` values are flattened into one
    feature vector; ``training_labels`` supplies the targets.

    Args:
        data_name: dataset prefix, e.g. "mnist".
        training_indices: indices into the dataset's training arrays (e.g. the
            training half returned by ``shuffle_partition``); used in order.
        num_train: how many of ``training_indices`` to train on.
        data_dir: directory containing the ``*-data.npz`` files.

    Returns:
        The fitted ``svm.LinearSVC`` classifier.

    Raises:
        ValueError: if ``num_train`` is not in ``1 .. len(training_indices)``.
    """
    # Use the caller's indices, not a dataset-specific global split.
    indices = np.asarray(training_indices)
    if not 0 < num_train <= len(indices):
        raise ValueError(
            f"num_train must be in [1, {len(indices)}], got {num_train}"
        )
    chosen = indices[:num_train]
    with np.load(f"{data_dir}/{data_name}-data.npz") as data:
        # Flatten each example (e.g. MNIST's 1x28x28 images) into a single row
        # of raw pixel features, as LinearSVC expects a 2-D design matrix.
        X = data["training_data"][chosen].reshape(num_train, -1)
        Y = data["training_labels"][chosen]
    clf = svm.LinearSVC()
    clf.fit(X, Y)
    return clf

# Train MNIST using raw pixels as features. Keep the returned model so later
# cells can evaluate it on the validation split instead of discarding it.
MNIST_NUM_TRAIN = 100
mnist_clf = train_svm("mnist", mnist_training, MNIST_NUM_TRAIN)




[4, 8, 1, 3, 0, 8, 4, 2, 6, 5, 6, 4, 8, 5, 9, 7, 5, 0, 2, 1, 7, 4, 1, 7, 8, 1, 8, 0, 2, 4, 5, 4, 3, 9, 8, 5, 9, 5, 7, 8, 3, 9, 5, 1, 2, 0, 8, 8, 3, 8, 5, 1, 0, 5, 2, 8, 2, 0, 7, 9, 5, 4, 2, 5, 8, 7, 7, 1, 0, 4, 3, 4, 2, 7, 4, 9, 8, 0, 1, 7, 3, 6, 8, 9, 7, 6, 5, 7, 5, 9, 5, 8, 7, 6, 6, 2, 4, 8, 9, 6]
