# Visual experiment tests

In [1]:
import numpy as np
import ipywidgets as wd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from enum import Enum, unique, auto

In [2]:
@unique
class Shape(Enum):
    """The figures that may be hidden in a noisy scatter plot.

    Each ``get_*`` generator returns a 2-row array: row 0 holds x coordinates,
    row 1 holds y coordinates, and every column is one vertex, listed in the
    order in which the figure's edges should be drawn.
    """

    TRIANGLE = 'triangle'
    RECTANGLE = 'rectangle'
    NOTHING = 'nothing'

    def get_shape(self):
        """Return the bound generator method that matches this member."""
        generators = dict(
            triangle=self.get_triangle,
            rectangle=self.get_rectangle,
            nothing=self.get_nothing,
        )
        return generators[self.value]

    @staticmethod
    def get_line():
        """Return a ``(2, 2)`` array of two random points at least 0.3 apart.

        Both points are drawn uniformly from the unit square; each column is one
        endpoint. Pairs that are too close are rejected and redrawn.
        """
        while True:
            ends = np.random.rand(2, 2)
            if np.linalg.norm(ends[:, 0] - ends[:, 1]) >= .3:
                return ends

    @staticmethod
    def get_orthogonal(x):
        """Return a unit vector perpendicular to the 2-D vector *x*.

        Which of the two perpendicular directions comes back depends on a random
        sample.
        """
        length = np.linalg.norm(x)
        sample = np.random.rand(2)
        # A single Gram-Schmidt step: remove the sample's component along x.
        along = np.dot(x / length, sample)
        perpendicular = sample - along * x / length
        return perpendicular / np.linalg.norm(perpendicular)

    def get_rectangle(self):
        """Return a ``(2, 4)`` array of rectangle corners in drawing order.

        A random base segment is shifted sideways by a perpendicular offset
        whose length is drawn uniformly from [0.3, 0.8). The corners are ordered
        (start + offset, start, end, end + offset).
        """
        base = self.get_line()
        start, end = base[:, 0], base[:, 1]
        lift = self.get_orthogonal(start - end) * (np.random.rand(1) / 2 + 0.3)
        return np.column_stack([start + lift, start, end, end + lift])

    def get_triangle(self):
        """Return a ``(2, 3)`` array: the two base endpoints followed by the apex.

        The apex is offset perpendicularly from the base midpoint by a length
        drawn uniformly from [0.3, 0.8), so the triangle is isosceles.
        """
        base = self.get_line()
        start, end = base[:, 0], base[:, 1]
        lift = self.get_orthogonal(start - end) * (np.random.rand(1) / 2 + 0.3)
        apex = (start + end) / 2 + lift
        return np.column_stack([base, apex])

    def get_nothing(self):
        """Return an empty ``(2, 0)`` vertex array (no hidden figure)."""
        return np.empty((2, 0))



class ShapeManager:
    """Run the guessing game: plot noise with a hidden shape and score the answers.

    Parameters
    ----------
    shape : tuple
        Matplotlib ``figsize`` of each generated plot.
    s : float
        Marker size of the hidden shape's vertices.
    a : float
        Alpha (transparency) of the background noise points.
    color : str
        Matplotlib colour used for both noise and shape points.
    size : int
        Number of uniformly distributed noise points per plot.

    Bookkeeping: after the first ``create_shape`` call, ``generated_shapes``
    holds exactly one more entry than ``guessed_shapes``. Its last element is the
    round currently on screen, which has not been answered yet.
    """

    def __init__(self, shape=(12,9), s=60, a=0.6, color='b', size=300):
        self.shapes = list(Shape)
        self.figsize = shape
        self.s, self.a, self.color = s, a, color
        self.num_of_points = size
        # Grow via np.append with one Shape member per generated / answered round.
        self.generated_shapes = np.asarray([])
        self.guessed_shapes = np.asarray([])
        # One dict per round not (yet) answered correctly: 'noise', 'fig', 'real',
        # plus 'guess' once the round has been answered wrongly.
        self.incorrect = []
        self.show_lines = False

    def create_shape(self):
        """Plot a new round (random noise with a hidden shape) and record it as pending."""
        shape = self.choose_shape()
        x = self.generate_points(shape.get_shape())
        noise = np.random.rand(2, self.num_of_points)
        self.plot_figure(noise, x)
        # Stored as "incorrect" until add_answer confirms a correct guess.
        self.incorrect.append({'noise':noise, 'fig':x, 'real':shape.value})
        self.generated_shapes = np.append(self.generated_shapes, shape)

    def choose_shape(self):
        """Return one of ``self.shapes`` chosen uniformly at random."""
        return self.shapes[np.random.randint(0, len(self.shapes))]

    def generate_points(self, func):
        """Call *func* until every returned coordinate lies within [0.1, 0.9].

        This keeps the hidden shape's vertices away from the plot edges. An empty
        result (no vertices) is accepted immediately.
        """
        x = func()
        while np.any(x > .9) or np.any(x < .1):
            x = func()
        return x

    def add_answer(self, answer):
        """Record *answer* as the guess for the pending round.

        A correct guess removes the round from ``incorrect``. A wrong guess
        annotates it with the guessed value.
        """
        if answer == self.generated_shapes[-1]:
            self.incorrect.pop()
        else:
            self.incorrect[-1]['guess'] = answer.value
        self.guessed_shapes = np.append(self.guessed_shapes, answer)

    def get_score(self):
        """Return the fraction of answered rounds guessed correctly (0 if none)."""
        if len(self.guessed_shapes) > 0:
            # The last generated shape is still unanswered, so it is excluded.
            return sum(self.generated_shapes[:-1] == self.guessed_shapes)/len(self.guessed_shapes)
        return 0

    def clear_score(self):
        """Forget all answered rounds, keeping only the pending one (if any)."""
        # Slicing keeps the last element when present and yields an empty
        # container otherwise, instead of raising IndexError on a fresh manager.
        self.generated_shapes = self.generated_shapes[-1:].copy()
        self.guessed_shapes = np.asarray([])
        self.incorrect = self.incorrect[-1:]

    @staticmethod
    def draw_lines(fig, x):
        """Connect the columns of *x* in order with red segments, closing the polygon.

        *fig* is the Axes-like object to draw on (it must provide ``plot``).
        An empty *x* draws nothing.
        """
        size = x.shape[-1]
        for start in range(size):
            # Segment from vertex `start` to the next one, wrapping to the first.
            idx = np.array([start, (start + 1) % size])
            fig.plot(x[0, idx], x[1, idx], linewidth=2, color='r')

    def plot_figure(self, noise, x):
        """Show a scatter plot of *noise* plus the shape vertices *x*.

        Noise points are drawn translucent and shape vertices with marker size
        ``self.s``. Red outlines are added when ``self.show_lines`` is set.
        """
        _, ax = plt.subplots(1, 1, figsize=self.figsize)
        ax.scatter(*noise, color=self.color, alpha=self.a)
        ax.scatter(*x, color=self.color, s=self.s)
        if self.show_lines:
            self.draw_lines(ax, x)
        ax.axis('off')
        plt.show()

In [3]:
buttons = [wd.Button(description=label) for label in ["Triangle", "Rectangle", "Nothing"]]
lines = wd.Checkbox(description="Lines", value=False, indent=False)
output = wd.Output(layout={})
display(lines, output, wd.HBox(buttons))
sm = ShapeManager()

# Draw the first round right away; later rounds are drawn by the button callbacks.
with output:
    sm.create_shape()


def _make_answer_handler(answer):
    """Return a button callback that submits *answer* and draws the next round."""
    def handler(b):
        with output:
            clear_output()
            sm.add_answer(answer)
            sm.create_shape()
    return handler


# The original callback names are kept so other cells can still refer to them.
triangle_button_clicked = _make_answer_handler(Shape.TRIANGLE)
rectangle_button_clicked = _make_answer_handler(Shape.RECTANGLE)
nothing_button_clicked = _make_answer_handler(Shape.NOTHING)


def line_switch(b):
    """Copy the checkbox state into the manager; it applies from the next plot on."""
    sm.show_lines = lines.value


# Button order matches the labels: Triangle, Rectangle, Nothing.
for button, callback in zip(buttons, (triangle_button_clicked,
                                      rectangle_button_clicked,
                                      nothing_button_clicked)):
    button.on_click(callback)
# React only to changes of the checkbox value, not to every trait notification.
lines.observe(line_switch, names='value')

Checkbox(value=False, description='Lines', indent=False)

Output()

HBox(children=(Button(description='Triangle', style=ButtonStyle()), Button(description='Rectangle', style=Butt…

In [4]:
print(f"Accuracy: {sm.get_score()}")

# The last entry of sm.incorrect is the round still on screen (not answered yet),
# so only the entries before it are actual mistakes. Slicing also copes with an
# empty list, where the old `len(...) - 1` gave -1 and was truthy.
mistakes = sm.incorrect[:-1]
num_of_incorrect = len(mistakes)
if num_of_incorrect > 0:
    # squeeze=False keeps `axs` 2-D (rows x 2) even for a single mistake.
    _, axs = plt.subplots(num_of_incorrect, 2, figsize=(26, 10*num_of_incorrect), squeeze=False)

    for incorrect, (guess_ax, real_ax) in zip(mistakes, axs):
        # Left: the plot as the player saw it. Right: the same plot with the
        # hidden shape outlined.
        for ax in (guess_ax, real_ax):
            ax.scatter(*incorrect['noise'], color=sm.color, alpha=sm.a)
            ax.scatter(*incorrect['fig'], color=sm.color, s=sm.s)
        sm.draw_lines(real_ax, incorrect['fig'])
        for ax in (guess_ax, real_ax):
            ax.axis('off')
        guess_ax.set_title('Guessed: {}'.format(incorrect['guess']), size=24)
        real_ax.set_title('Real answer: {}'.format(incorrect['real']), size=24)

    plt.show()

Accuracy: 0
