In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
from kernelcanvas.kernelcanvas2 import KernelCanvas2
import wisardpkg as wp

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

In [2]:
# Fetch version 1 of the 'mnist_784' dataset from OpenML; as_frame=False asks
# for array data instead of a pandas DataFrame.
MNIST = fetch_openml(
    name='mnist_784',
    version=1,
    as_frame=False,
)

In [3]:
# Number of leading MNIST samples to use. None keeps the whole dataset; set a
# small value (e.g. 10_000) for quick experiments instead of editing the slices.
N_SAMPLES = None

X = MNIST.data[:N_SAMPLES]
y = MNIST.target[:N_SAMPLES]

# Turn every flat sample into a 28x28 grid so it matches the canvas shape used below.
X_reshaped = X.reshape(-1, 28, 28)

# Fixed random_state keeps the 80/20 split identical between runs.
X_train, X_test, y_train, y_test = train_test_split(X_reshaped, y, test_size=0.2, random_state=42)

In [4]:
class KernelCanvas (KernelCanvas2):
    """KernelCanvas2 with matplotlib helpers to show or save the canvas layout.

    The drawing helpers rely only on the accessors used below: getShape(),
    getKernelPoints(), getClosestKernel() and getNumberOfKernels().
    """

    def __init__(self, shape: tuple[int, int], numberOfKernels: int, bitsPerKernel: int = 1, activationDegree: float = 0.07):
        KernelCanvas2.__init__(self, shape, numberOfKernels, bitsPerKernel, activationDegree)

    def _blankCanvas(self):
        """Return an all-white RGB image with the canvas' shape."""
        return np.ones((*(self.getShape()), 3))

    def _paintKernelPoints(self, canvas):
        """Colour every kernel point of `canvas` red, in place."""
        for row, col in self.getKernelPoints():
            canvas[row, col] = [1, 0, 0]  # Red color (R=1, G=0, B=0)

    def _drawCanvas(self):
        """Draw the kernel layout on a new figure and return `(fig, ax)`.

        Kernel points are red, and every cell is labelled with the value that
        getClosestKernel() reports for it.
        """
        canvas = self._blankCanvas()
        self._paintKernelPoints(canvas)

        fig, ax = plt.subplots()
        ax.imshow(canvas, vmin=0, vmax=1)

        # Add numbers to each square
        closest_kernels = self.getClosestKernel()
        shape = self.getShape()
        for i in range(shape[0]):
            for j in range(shape[1]):
                # imshow puts columns on the x axis, hence (j, i).
                ax.text(j, i, str(closest_kernels[i][j]),
                        va='center', ha='center', color='black', fontsize=6)
        return fig, ax

    def showCanvas(self):
        """Display the kernel layout (red kernel points, labelled cells)."""
        self._drawCanvas()
        plt.show()

    def saveCanvas(self):
        """Save the kernel layout to the first unused ./img2/canvas_<n>.png.

        The directory is created when missing, and the figure is always closed
        afterwards so repeated calls do not accumulate open figures.
        """
        fig, _ = self._drawCanvas()
        try:
            # Find next available filename
            os.makedirs("./img2", exist_ok=True)
            i = 1
            while os.path.exists(f"./img2/canvas_{i}.png"):
                i += 1
            # Save through the figure object rather than pyplot's implicit
            # "current figure", so the right figure is always written.
            fig.savefig(f"./img2/canvas_{i}.png")
        finally:
            plt.close(fig)

    def showTransformed(self, TransformedData):
        """Display which canvas cells belong to activated kernels.

        A cell is painted black when the entry of `TransformedData` selected by
        that cell's getClosestKernel() value is truthy; kernel points are then
        painted red on top.

        Raises:
            ValueError: if `TransformedData` does not have exactly one entry
                per kernel.
        """
        if len(TransformedData) != self.getNumberOfKernels():
            raise ValueError(f"Transformed data length {len(TransformedData)} does not match number of kernels {self.getNumberOfKernels()}")

        canvas = self._blankCanvas()
        closest_kernels = self.getClosestKernel()

        # Paint the activated kernels
        shape = self.getShape()
        for i in range(shape[0]):
            for j in range(shape[1]):
                if TransformedData[closest_kernels[i][j]]:
                    canvas[i, j] = [0, 0, 0]

        # Kernel points are painted last so they stay visible on black cells.
        self._paintKernelPoints(canvas)

        fig, ax = plt.subplots()
        ax.imshow(canvas, vmin=0, vmax=1)
        plt.show()
        

In [5]:
# Canvas used for the baseline experiment: 28x28 grid with 128 kernels,
# leaving bitsPerKernel and activationDegree at the class defaults.
kc = KernelCanvas(shape=(28, 28), numberOfKernels=128)

In [6]:
def drawFig(idx):
    """Show sample `idx` of `X_reshaped` in grayscale, titled with its label from `y`.

    A dedicated figure/axes pair is created so the image never ends up drawn
    on top of a figure that some earlier code left open.
    """
    fig, ax = plt.subplots()
    ax.imshow(X_reshaped[idx], cmap='gray')
    ax.set_title(f"Label: {y[idx]}")
    ax.axis('off')
    plt.show()
    
def accuracy(y_pred, y_target):
    """Return the fraction of predictions that equal their target label.

    Both arguments may be any array-like (list, tuple, ndarray). They are
    converted to arrays first, because comparing two plain lists with `==`
    yields a single bool instead of an element-wise result.

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    predicted = np.asarray(y_pred)
    expected = np.asarray(y_target)
    if predicted.shape != expected.shape:
        raise ValueError(
            f"Shape mismatch: predictions {predicted.shape} vs targets {expected.shape}"
        )
    # The mean of the boolean match mask is exactly hits / total.
    return np.mean(predicted == expected)

In [7]:
# print(kc.transform(X_reshaped[0]))
# drawFig(0)
# kc.showTransformed(kc.transform(X_reshaped[0]))
# kc.saveCanvas()

In [8]:
# Encode every train and test image with the baseline canvas `kc`.
train_points = list(map(kc.transform, X_train))
test_points = list(map(kc.transform, X_test))

In [9]:
# Pair the encoded samples with their labels in wisardpkg datasets.
train_ds, test_ds = (
    wp.DataSet(train_points, y_train),
    wp.DataSet(test_points, y_test),
)

In [11]:
# Baseline: train one WiSARD on the data encoded with `kc` and report its
# accuracy on the held-out test set.
# NOTE(review): 16 is presumably the RAM address size — confirm against the
# wisardpkg documentation.
model = wp.Wisard(16, verbose=True)
model.train(train_ds)
pred = model.classify(test_ds)
# Last expression of the cell, so the accuracy is rendered as the cell output.
accuracy(pred, y_test)

np.float64(0.817)

In [12]:
# Initial state for the evolutionary search in the next cell.
gen_cnt = 0  # generation counter; the search loop increments it

# Starting canvas; its layout is written to ./img2 as the first snapshot.
canvas = KernelCanvas(
    shape=(28, 28),
    numberOfKernels=128,
    bitsPerKernel=1,
    activationDegree=0.07,
)
canvas.saveCanvas()

# One classifier instance; the search loop trains and untrains it per candidate.
model = wp.Wisard(16, verbose=True)

In [None]:
from multiprocessing import cpu_count
from concurrent.futures import ProcessPoolExecutor
import copy
from time import time
# import gc


# Upper bound on the number of generations the search loop runs.
generations = 1_000

def transform_train(aux):
    """Encode every sample of the global `X_train` with canvas `aux`; return a list."""
    encoded = []
    for sample in X_train:
        encoded.append(aux.transform(sample))
    return encoded

def transform_test(aux):
    """Encode every sample of the global `X_test` with canvas `aux`; return a list."""
    encoded = []
    for sample in X_test:
        encoded.append(aux.transform(sample))
    return encoded

# Elitist evolutionary search over kernel canvases.
# Each generation deep-copies the current champion `canvas` 20 times, mutates
# copies 1..19 (copy 0 stays untouched), scores every copy by training the
# shared `model` on data encoded with it, and promotes the best scorer.
# `best_accuracy` always refers to the last *completed* generation, so it is
# well defined whenever the run is interrupted.
best_accuracy = None
try:
    for i in range(generations):
        generation_time_start = time()
        genePool = [copy.deepcopy(canvas) for _ in range(20)]

        # Iterate with a dedicated name: looping with `canvas` here would rebind
        # the champion to the last (not yet evaluated) mutant, which is what an
        # interrupt in the middle of a generation would then leave behind.
        for mutant in genePool[1:]:
            mutant.mutateKernel(mutationFactor=np.random.uniform(0.008, 0.5))

        transform_start_time = time()
        # A single pool serves both passes, so workers are started once per generation.
        with ProcessPoolExecutor() as executor:
            train_points_list = list(executor.map(transform_train, genePool))
            test_points_list = list(executor.map(transform_test, genePool))
        transform_stop_time = time()
        print(f"Transformation time: {transform_stop_time - transform_start_time:.2f} seconds")

        train_ds_list = [wp.DataSet(train_points, y_train) for train_points in train_points_list]
        test_ds_list = [wp.DataSet(test_points, y_test) for test_points in test_points_list]

        generation_best_accuracy = 0
        best_model_index = 0
        # Candidate-specific names keep the notebook-level `train_ds` / `pred`
        # of the baseline cells intact.
        for j, (candidate_train_ds, candidate_test_ds) in enumerate(zip(train_ds_list, test_ds_list)):
            model.train(candidate_train_ds)
            candidate_pred = model.classify(candidate_test_ds)
            acc = accuracy(candidate_pred, y_test)
            # Strict '>' means ties keep the earlier candidate, i.e. the
            # unmutated champion in slot 0 wins a tie.
            if acc > generation_best_accuracy:
                generation_best_accuracy = acc
                best_model_index = j
            # Reset the shared model before scoring the next candidate.
            model.untrain(candidate_train_ds)

        generation_time_stop = time()
        print(f"Generation {gen_cnt} time: {generation_time_stop - generation_time_start:.2f} seconds")
        print(f"Accuracy generation {gen_cnt}: {generation_best_accuracy} (canvas {best_model_index+1})")
        print("------------------------------------------------------\n\n")
        canvas = genePool[best_model_index]  # Keep the best canvas for the next generation
        best_accuracy = generation_best_accuracy
        gen_cnt += 1
        # Only snapshot the layout when a mutant actually replaced the champion.
        if best_model_index != 0:
            canvas.saveCanvas()

except KeyboardInterrupt:
    print("Training interrupted by user.")
    if best_accuracy is None:
        print("Final accuracy: n/a (no generation completed)")
    else:
        print(f"Final accuracy: {best_accuracy}")

# Show the champion of the last completed generation.
canvas.showCanvas()

Transformation time: 8.23 seconds
Generation 0 time: 14.63 seconds
Accuracy generation 0: 0.8227857142857142 (canvas 8)
------------------------------------------------------


Transformation time: 9.82 seconds
Generation 1 time: 16.37 seconds
Accuracy generation 1: 0.8275714285714286 (canvas 15)
------------------------------------------------------


Transformation time: 10.42 seconds
Generation 2 time: 15.02 seconds
Accuracy generation 2: 0.8275714285714286 (canvas 1)
------------------------------------------------------


Transformation time: 11.85 seconds
Generation 3 time: 18.33 seconds
Accuracy generation 3: 0.8301428571428572 (canvas 7)
------------------------------------------------------


Transformation time: 9.85 seconds
Generation 4 time: 14.27 seconds
Accuracy generation 4: 0.8337857142857142 (canvas 18)
------------------------------------------------------


Transformation time: 9.95 seconds
Generation 5 time: 16.45 seconds
Accuracy generation 5: 0.8340714285714286 (c