# Imports

In [1]:
import scipy.io
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

RNDS = 1389

# Data

In [2]:
mat_data = scipy.io.loadmat('./data/mnistAll.mat')
mnist_data = mat_data['mnist']

In [3]:
# load images
train_images = mnist_data['train_images'][0, 0].T.squeeze()
test_images = mnist_data['test_images'][0, 0].T.squeeze()
train_labels = mnist_data['train_labels'][0, 0].T.squeeze()
test_labels = mnist_data['test_labels'][0, 0].T.squeeze()
# find idxs that are of 3 and 7
train_filter = np.where((train_labels == 3) | (train_labels == 7))[0].squeeze()
test_filter = np.where((test_labels == 3) | (test_labels == 7))[0]

# filter only images and labels of 3 and 7
train_images = train_images[train_filter]
train_labels = train_labels[train_filter]
test_images = test_images[test_filter]
test_labels = test_labels[test_filter]

# split train into train and val
train_images, val_images, train_labels, val_labels = train_test_split(train_images, train_labels, test_size=0.2, random_state=RNDS)

print(train_images.shape, val_images.shape, test_images.shape, train_labels.shape, val_labels.shape, test_labels.shape)

(9916, 28, 28) (2480, 28, 28) (2038, 28, 28) (9916,) (2480,) (2038,)


In [4]:
# scale inputs to be between 0 and 1
train_images = train_images.astype(float) / 255.0 # 255 because greyscale images
val_images = val_images.astype(float) / 255.0
test_images = test_images.astype(float) / 255.0

print(train_images[0], val_images[0], test_images[0])

# scale classes to be 0 and 1 as well
train_labels = np.where(train_labels == 3, 0, 1)
val_labels = np.where(val_labels == 3, 0, 1)
test_labels = np.where(test_labels == 3, 0, 1)

[[0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.25098039 0.25098039 0.         0.
  0.         0.    