In [None]:
import numpy as np
import multiprocessing as mp
from multiprocessing import Pool, cpu_count
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import fetch_openml
import os.path
import cv2
import matplotlib.pyplot as plt
from ctypes import *
import sys
import threading

In [None]:
def sort_by_target(mnist, n_train=60000):
    """Sort the train and test partitions of ``mnist`` by label, in place.

    The first ``n_train`` rows are treated as the training split and the
    remaining rows as the test split; each split is reordered independently
    so that its rows appear in ascending target order. Rows with equal
    targets keep their original relative order.

    Args:
        mnist: object exposing ``data`` and ``target`` attributes that support
            NumPy slice assignment and integer-array indexing.
            NOTE(review): assumes both are ndarrays (e.g. loaded with
            ``as_frame=False``) -- confirm against the loader.
        n_train: number of leading rows that form the training split.

    Returns:
        None. ``mnist.data`` and ``mnist.target`` are modified in place.
    """
    targets = np.asarray(mnist.target)

    # A stable argsort yields the same order as sorting (target, index) pairs,
    # but the result is always an integer index array. Building an array from
    # (target, index) tuples coerces the indices to the target dtype, which
    # breaks indexing when targets are strings or floats.
    # Both orders are computed before anything is overwritten.
    train_order = np.argsort(targets[:n_train], kind="stable")
    test_order = np.argsort(targets[n_train:], kind="stable") + n_train

    # Integer-array indexing returns copies, so the right-hand sides are
    # fully materialised before the slices are overwritten.
    mnist.data[:n_train] = mnist.data[train_order]
    mnist.target[:n_train] = mnist.target[train_order]
    mnist.data[n_train:] = mnist.data[test_order]
    mnist.target[n_train:] = mnist.target[test_order]
    
def saveImages(mnist, out_dir="mnist/images"):
    """Write every digit in ``mnist`` as a 28x28 grayscale PNG.

    Files are named ``mnist1.png``, ``mnist2.png``, ... inside ``out_dir``.
    When running under Google Colab each file is also offered for download.

    Args:
        mnist: indexable collection whose items can be reshaped to (28, 28).
            NOTE(review): pixel values are assumed to lie in 0..255 -- confirm.
        out_dir: directory that receives the PNG files; created if missing.
    """
    # PIL is not among the notebook's top-level imports, so import it here.
    from PIL import Image

    # The Colab helper is imported once (not per image) and is optional so the
    # function also works in a local Jupyter session.
    try:
        from google.colab import files
    except ImportError:
        files = None

    os.makedirs(out_dir, exist_ok=True)
    for i in range(len(mnist)):
        # 8-bit grayscale needs uint8 data; other dtypes would be misread.
        digit = mnist[i].reshape(28, 28).astype(np.uint8)
        # Save and download use the same path so the download finds the file.
        path = os.path.join(out_dir, "mnist{}.png".format(i + 1))
        Image.fromarray(digit).save(path)
        if files is not None:
            files.download(path)
    
def saveLabels(labels, path="mnist/labels/labels.txt"):
    """Write ``labels`` to ``path`` as one comma-terminated sequence.

    Each label is formatted with ``str.format`` and followed by a comma, so
    the file content looks like ``5,0,4,`` (including the trailing comma).

    Args:
        labels: indexable sequence of labels.
        path: destination text file; its parent directory is created if needed.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # The file must be opened for writing ("r" makes every write() fail), and
    # the context manager closes it even if formatting a label raises.
    with open(path, "w") as f:
        f.write("".join("{},".format(labels[i]) for i in range(len(labels))))

        

In [None]:
# Relative path of the directory holding the compiled shared libraries.
# Keep the trailing slash: library file names are appended to this value by
# plain string concatenation.
SO_DIRPATH = "../libs/"

In [None]:
# Loop Counter
# Load the loop-counter shared library and declare the C signature of
# python_loop_count so ctypes validates and converts the arguments.
SO_FILE_COUNTER = SO_DIRPATH + 'libcounter.so'
counter_c = CDLL(SO_FILE_COUNTER)
counter_c.python_loop_count.argtypes = [
    # C-contiguous unsigned-byte NumPy buffer, followed by two ints.
    np.ctypeslib.ndpointer(dtype=c_ubyte, flags="C_CONTIGUOUS"),
    c_int,
    c_int,
]
# ctypes reads the attribute `restype` (singular); the former `restypes`
# assignment was silently ignored.
counter_c.python_loop_count.restype = c_int

In [None]:
def loop_count_c(img, nx, ny):
    """Call the C routine ``python_loop_count`` on a flattened image.

    Args:
        img: one-dimensional ``np.ndarray`` holding the image pixels.
        nx: first int forwarded to the C routine.
        ny: second int forwarded to the C routine.
            NOTE(review): presumably the image width/height -- confirm
            against the C source.

    Returns:
        The value returned by ``counter_c.python_loop_count``.

    Raises:
        TypeError: if ``img`` is not an ``np.ndarray``.
        AssertionError: if ``img`` is not one-dimensional.
        Exception: any error raised while calling the C routine is re-raised
            after the troubleshooting hints have been printed.
    """
    if not isinstance(img, np.ndarray):
        raise TypeError("Image must be a ndarray, get {}".format(type(img).__name__))
    assert img.ndim == 1, "Image must be serialized"

    try:
        return counter_c.python_loop_count(img, nx, ny)
    except Exception as e:
        print("Exception occured: {}".format(e))
        print("If error caused by undefined counter. Make sure to load counter_c library first before running this function\n")
        print("Make sure to defined the restype and argtype of python_loop_count")
        # Re-raise instead of exit(-1): exiting terminates the notebook
        # kernel, and inside a multiprocessing worker it kills the worker so
        # the parent's pool.map() never gets a result.
        raise

In [None]:
def readImages(path):
    """Read one image file as grayscale and bump the shared progress counter.

    Relies on the module-level ``n_finished`` (a ``multiprocessing.Value``)
    existing before this function runs; when called from Pool workers the
    counter has to be visible in the worker process as well.

    Args:
        path: path of the image file to read.

    Returns:
        The image array returned by ``cv2.imread``.

    Raises:
        OSError: if OpenCV could not read the file (``cv2.imread`` signals
            failure by returning ``None`` rather than raising).
    """
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # Fail at the source instead of letting a None slip into the dataset
        # and surface later as an unrelated error.
        raise OSError("cv2.imread could not read image: {}".format(path))
    with n_finished.get_lock():
        n_finished.value += 1
    return image

def readTrain(X=None,
              img_dir="../../mnist/images/train_images",
              labels_path="../../mnist/labels/train_labels/train.txt",
              n_samples=60000):
    """Load the training images and labels from disk.

    Labels are read from a comma-separated text file; images are read in
    parallel from ``img_dir/number1.png`` ... ``img_dir/number<n_samples>.png``
    by mapping ``readImages`` over a process pool.

    Args:
        X: optional two-slot mutable sequence. When given, ``X[0]`` receives
            the list of images and ``X[1]`` the labels column, which lets the
            function deliver its result when run as a thread target.
        img_dir: directory containing the numbered PNG files.
        labels_path: comma-separated labels file.
        n_samples: expected number of labels / images.

    Returns:
        ``(images, labels)`` where ``images`` is the list returned by
        ``pool.map`` and ``labels`` is an ``(n_samples, 1)`` integer array.

    Raises:
        AssertionError: if the labels file does not hold ``n_samples`` labels.
    """
    # read labels
    with open(labels_path, "r") as f:
        raw = f.read()
    # Skip blank tokens so a trailing comma or newline does not hit int('').
    labels = [int(token) for token in raw.split(",") if token.strip()]
    assert len(labels) == n_samples, "Corrupted labels. Expected {} got {}".format(n_samples, len(labels))

    paths = [os.path.join(img_dir, "number{}.png".format(i + 1)) for i in range(n_samples)]
    with Pool(cpu_count()) as pool:
        train_img = pool.map(readImages, paths)

    labels_column = np.array(labels).reshape(-1, 1)
    if X is not None:
        X[0] = train_img
        X[1] = labels_column
    return train_img, labels_column

In [None]:
# Blocking load of the training set (no progress display).
# readTrain delivers its result through the list it is given: slot 0 receives
# the images and slot 1 the labels column. The shared counter must exist
# first because readImages increments it for every file.
n_finished = mp.Value('i', 0)
train = [None, None]
readTrain(train)

In [None]:
# Load the training set on a background thread and show a live counter of
# how many images have been read so far.
n_finished = mp.Value('i', 0)
X = [None, None]
t = threading.Thread(target=readTrain, args=(X,))
t.start()
while t.is_alive():
    sys.stdout.write('\r' + "n_finished={}".format(n_finished.value))
    sys.stdout.flush()
    # Wait up to 0.2 s for the thread to finish instead of spinning: a tight
    # loop pins a CPU core and floods the output stream.
    t.join(timeout=0.2)
# One last update so the final count is shown once the thread has ended.
sys.stdout.write('\r' + "n_finished={}\n".format(n_finished.value))
sys.stdout.flush()

In [None]:
# Unpack the two-slot result holder: images first, labels column second.
train_images, labels = X

In [None]:
# Visual sanity check: display a single training image.
SAMPLE_INDEX = 5370
fig, ax = plt.subplots()
ax.imshow(train_images[SAMPLE_INDEX], cmap='gray')
ax.set_title("train image #{}".format(SAMPLE_INDEX))
ax.axis("off");

In [None]:
def loop(some_digit):
    """Compute the loop count of one digit image and return it with the pixels.

    The image is flattened, passed to ``loop_count_c`` with a 28x28 size, and
    the shared ``n_finished`` progress counter is incremented.

    Args:
        some_digit: image array that can be flattened with ``reshape``.
            NOTE(review): assumed to be a 28x28 uint8 image as produced by
            the loader -- confirm.

    Returns:
        A 1-D array holding the flattened pixels followed by the loop count
        as the last element (index 784 for a 28x28 image). Returning the
        value, rather than printing it inside a worker process, is what lets
        the caller collect it from ``pool.map``.
    """
    flat = some_digit.reshape(1, -1)[0]
    cnt = loop_count_c(flat, 28, 28)
    with n_finished.get_lock():
        n_finished.value += 1
    # np.append promotes the dtype, so the count is not truncated to 8 bits.
    return np.append(flat, cnt)

In [None]:
def process(process_fnc, data, return_value):
    """Apply ``process_fnc`` to every item of ``data`` in a process pool.

    Args:
        process_fnc: picklable callable taking one item of ``data``.
        data: iterable of work items.
        return_value: list that is extended with the results, in input order.
            This is how results are delivered when the function runs as a
            thread target; pass ``None`` to skip it.

    Returns:
        The list of results produced by ``pool.map``.
    """
    with Pool(cpu_count()) as pool:
        results = pool.map(process_fnc, data)
    if return_value is not None:
        return_value.extend(results)
    return results

In [None]:
# Run `loop` over every training image on a background thread and show a
# live counter of how many images have been processed.
n_finished = mp.Value('i', 0)
X = []  # count (list handed to process() as its return_value argument)
t = threading.Thread(target=process, args=(loop, train_images, X))
t.start()
while t.is_alive():
    sys.stdout.write('\r' + "n_finished={}".format(n_finished.value))
    sys.stdout.flush()
    # Wait up to 0.2 s for the thread to finish instead of spinning: a tight
    # loop pins a CPU core and floods the output stream.
    t.join(timeout=0.2)
# One last update so the final count is shown once the thread has ended.
sys.stdout.write('\r' + "n_finished={}\n".format(n_finished.value))
sys.stdout.flush()

In [None]:
# X is the list that was handed to process(); labels is the labels column
# unpacked from the loader's result holder in the earlier cell.
X_train, y_train = X, labels

In [None]:
# check loop count
# Print "<loop count> <label>" for the first few samples only; raise
# N_PREVIEW to inspect more. Bounding by len(X_train) avoids an IndexError
# when fewer samples are available.
N_PREVIEW = 20
for i in range(min(N_PREVIEW, len(X_train))):
    print(X_train[i][784], y_train[i][0])