# Visual experiment tests

In [1]:
import numpy as np
import ipywidgets as wd
import matplotlib.pyplot as plt
import pickle
from IPython.display import clear_output
from enum import Enum, unique, auto
from datetime import datetime
from random import choice

In [2]:
@unique
class Shape(Enum):
    """Kinds of figure that may be hidden inside the noise.

    Every member's value names a ``get_<value>`` method that produces the
    figure's vertices as a 2 x K array (row 0 holds x, row 1 holds y).
    """
    TRIANGLE = 'triangle'
    RECTANGLE = 'rectangle'
    NOTHING = 'nothing'

    def get_shape(self):
        """Return the bound vertex generator that belongs to this member."""
        method_name = "get_" + self.value
        return getattr(self, method_name)

    @staticmethod
    def get_line():
        """Sample two points in the unit square that are at least 0.3 apart.

        Returns:
            A 2 x 2 array whose columns are the two endpoints.
        """
        while True:
            segment = np.random.rand(2, 2)
            if np.linalg.norm(segment[:, 0] - segment[:, 1]) >= .3:
                return segment

    @staticmethod
    def get_orthogonal(x):
        """Return a random unit vector orthogonal to the 2-D vector ``x``."""
        length = np.linalg.norm(x)
        candidate = np.random.rand(2)
        # Gram-Schmidt: subtract the component of the candidate along x.
        candidate = candidate - np.dot(x/length, candidate) * x/length
        return candidate/np.linalg.norm(candidate)

    def get_rectangle(self):
        """Return a rectangle (2 x 4) built on a random base segment.

        The base is shifted along a random unit normal by a length drawn from
        [0.3, 0.8). Columns are ordered so that neighbours (wrapping around)
        are adjacent corners.
        """
        base = self.get_line()
        first, second = base[:, 0], base[:, 1]
        offset = self.get_orthogonal(first - second)*(np.random.rand(1)/2+0.3)
        return np.column_stack([first + offset, first, second, second + offset])

    def get_triangle(self):
        """Return a triangle (2 x 3): a random base plus an apex above its midpoint.

        The apex sits on a random unit normal through the base midpoint, at a
        distance drawn from [0.3, 0.8).
        """
        base = self.get_line()
        first, second = base[:, 0], base[:, 1]
        height = self.get_orthogonal(first - second)*(np.random.rand(1)/2+0.3)
        apex = (first + second)/2 + height
        return np.column_stack([first, second, apex])

    def get_nothing(self):
        """Return an empty 2 x 0 vertex array (no hidden figure)."""
        return np.empty((2, 0))

In [3]:
@unique
class ShapeType(Enum):
    """Strategy for embedding a shape's vertices into the background noise.

    * ``SINGLE``: noise untouched, no extra points (``None``) are returned.
    * ``MULTIPLE``: jittered points are sampled along every edge and returned
      together with the vertices.
    * ``INVISIBLE``: noise points close to any edge are removed and an empty
      2 x 0 array is returned as extras.
    * ``ANY``: one of the three above, chosen uniformly per call.
    """
    ANY = 'any'
    SINGLE = 'single'
    MULTIPLE = 'multiple'
    INVISIBLE = 'invisible'

    # Refactor the points defining a shape into a specific type
    def adjust_data(self, noise, points):
        """Apply this member's strategy.

        Args:
            noise: 2 x N array of background points.
            points: 2 x K array of shape vertices (K may be 0).

        Returns:
            ``(noise, extras)``: the possibly thinned noise and the shape points
            to draw (``None`` for ``SINGLE``).
        """
        return getattr(self, f"{self.value}_shape_type")(noise, points)

    def any_shape_type(self, noise, points):
        """Delegate to a uniformly chosen concrete strategy (never ``ANY`` itself)."""
        members = type(self)
        concrete = [member for member in members if member is not members.ANY]
        return choice(concrete).adjust_data(noise, points)

    def single_shape_type(self, noise, points):
        """Leave the noise as is and return no extra points."""
        return noise, None

    def multiple_shape_type(self, noise, points):
        """Keep the noise and add jittered points sampled along every edge.

        Returns:
            ``(noise, extras)`` where ``extras`` is the vertices followed by the
            sampled edge points, as one 2 x M array.
        """
        size = points.shape[-1]
        pieces = [points]
        for i in range(size):
            edge = self.interpolate(points[:, i], points[:, (i + 1) % size])
            pieces.append(self.add_noise(self.select_points(edge)))
        # Concatenate once instead of growing the array inside the loop.
        return noise, np.concatenate(pieces, axis=1)

    def interpolate(self, x, y, density=100):
        """Return ``int(|x - y| * density)`` evenly spaced points from x to y, as an (L, 2) array."""
        return np.linspace(x, y, int(np.linalg.norm(x - y)*density))

    def select_points(self, line):
        """Keep each point of ``line`` with probability 0.25.

        Returns:
            A 2 x k array (k may be 0). It is always 2-D, so the result can be
            concatenated with other 2 x ? arrays even when nothing was kept.
        """
        line = np.reshape(line, (-1, 2))
        kept = line[np.random.rand(len(line)) > 0.75]
        return kept.T

    @staticmethod
    def add_noise(points, scale=0.005):
        """Return ``points`` jittered by Gaussian noise with standard deviation ``scale``."""
        return points + np.random.randn(*points.shape)*scale

    def invisible_shape_type(self, noise, points):
        """Carve the shape out of the noise by deleting points near its edges."""
        size = points.shape[-1]
        for i in range(size):
            edge = self.interpolate(points[:, i], points[:, (i + 1) % size])
            noise = self.remove_points(noise, edge)
        return noise, np.asarray([[], []])

    def remove_points(self, noise, line, margin=0.02):
        """Drop the noise columns lying strictly within ``margin`` (per axis) of any line point.

        Vectorized: builds an (L, 2, N) comparison instead of looping over every
        line point and noise point in Python.
        """
        line = np.reshape(line, (-1, 2))
        if len(line) == 0 or noise.shape[-1] == 0:
            return noise
        lower = line[:, :, None] - margin   # (L, 2, 1), broadcasts against (2, N)
        upper = line[:, :, None] + margin
        inside = ((noise > lower) & (noise < upper)).all(axis=1)   # (L, N)
        return noise[:, ~inside.any(axis=0)]

    def find_in_local(self, noise, local, margin=0.02):
        """Return the column indices of ``noise`` lying strictly within ``margin`` of ``local`` on both axes."""
        local = np.asarray(local, dtype=float).reshape(2, 1)
        inside = ((noise > local - margin) & (noise < local + margin)).all(axis=0)
        return np.flatnonzero(inside)

    @staticmethod
    def in_local(point, local, margin=0.02):
        """Return True if ``point`` lies in the open square of half-width ``margin`` around ``local``."""
        return point[0] > local[0] - margin and point[0] < local[0] + margin \
            and point[1] > local[1] - margin and point[1] < local[1] + margin

In [4]:
class FigureManager():
    """Render experiment trials: noise scatter plots with an optionally hidden shape.

    Args:
        shape_type: ``ShapeType`` used to blend the shape into the noise;
            ``None`` selects ``ShapeType.SINGLE``.
        shape: figure size passed to ``plt.subplots`` for a single trial.
        s: marker size for the shape points, passed to ``scatter``.
        a: alpha for the noise points, passed to ``scatter``.
        color: colour of every scattered point.
        size: number of uniformly random noise points per trial.
    """
    def __init__(self, shape_type=None, shape=(12,10), s=60, a=0.6, color='b', size=300):
        self.type = shape_type if shape_type is not None else ShapeType.SINGLE
        self.figsize = shape
        self.s, self.a, self.color = s, a, color
        self.num_of_points = size
        # When True, plot_figure also outlines the hidden shape in red.
        self.show_lines = False

    # Method used to create a scatter plot of random points with hidden shape
    def create_shape(self):
        """Draw and show a random trial; return ``(record, shape)``.

        ``record`` holds the noise, the shape vertices (``'fig'``), the extra
        shape points, the true shape name (``'real'``) and the moment the
        figure was shown (``'time'``).
        """
        shape = choice(list(Shape))
        x = self.generate_points(shape.get_shape())
        noise = np.random.rand(2, self.num_of_points)
        noise, extras = self.type.adjust_data(noise, x)
        current = {'noise': noise, 'fig': x, 'extras': extras, 'real': shape.value}
        self.plot_figure(noise, x, extras=extras)
        plt.show()
        # Start the clock only after the figure is shown, so rendering time is
        # not counted as part of the user's response time.
        current['time'] = datetime.now()
        return current, shape

    # Makes sure none of the points constituting the shape are too close to the edge
    def generate_points(self, func, low=.1, high=.9):
        """Call ``func`` until every coordinate it returns lies within [low, high]."""
        x = func()
        while np.any(x > high) or np.any(x < low):
            x = func()
        return x

    def show_incorrect(self, incorrect):
        """Plot each wrong answer: the guess on the left, the outlined truth on the right.

        ``show_lines`` is toggled for the two columns and restored afterwards, so
        reviewing mistakes does not reveal the outline in later trials.
        """
        if not incorrect:
            return  # nothing to show; plt.subplots(0, 2) would raise
        rows = len(incorrect)
        _, axs = plt.subplots(rows, 2, figsize=(26, 10*rows))
        axs = np.asarray(axs).reshape(-1, 2)
        previous = self.show_lines
        try:
            self.show_lines = False
            self.show_figures(incorrect, axs[:,0], 'guess', template='Guessed: {}')
            self.show_lines = True
            self.show_figures(incorrect, axs[:,1], 'real', template='Real answer: {}')
        finally:
            self.show_lines = previous
        plt.show()

    def show_figures(self, incorrect, axs, title_key, template="{}", font_size=24):
        """Plot each record on its axes, titled with ``template`` filled by ``record[title_key]``."""
        for record, ax in zip(incorrect, axs):
            self.plot_figure(record['noise'], record['fig'], extras=record['extras'], fig=ax)
            ax.set_title(template.format(record[title_key]), size=font_size)

    def line_switch(self, b):
        """Observer callback: follow the checkbox's boolean value, ignore other trait changes."""
        if isinstance(b['new'], bool):
            self.show_lines = b['new']

    # The lines connect the 'hidden' points constituting the shape
    @staticmethod
    def draw_lines(fig, x):
        """Outline the closed polygon through the columns of ``x`` in red."""
        size = x.shape[-1]
        for i in range(size):
            ends = [i, (i + 1) % size]
            fig.plot(x[0, ends], x[1, ends], linewidth=2, color='r')

    def plot_figure(self, noise, x, extras=None, fig=None):
        """Scatter one trial onto the axes ``fig`` (a new figure if ``None``).

        ``noise`` is drawn with alpha ``self.a``. ``extras`` (or the vertices
        ``x`` when ``extras`` is None) is drawn with marker size ``self.s``.
        Axes decorations are hidden.
        """
        if fig is None: _, fig = plt.subplots(1, 1, figsize=self.figsize)
        if extras is None: extras = x
        fig.scatter(*noise, color=self.color, alpha=self.a)
        fig.scatter(*extras, color=self.color, s=self.s)
        if self.show_lines: self.draw_lines(fig, x)
        fig.axis('off')

In [5]:
class WidgetManager:
    """Creates and tracks the ipywidgets controls and output area of an experiment."""

    def __init__(self):
        # Answer controls (buttons, optional checkbox) in creation order.
        self.widgets = []

    def create_begin_button(self, experiment):
        """Display a new output area followed by a "Begin" button and return the button."""
        self.e = experiment
        button = wd.Button(description="Begin")
        output = wd.Output()
        self.begin_button, self.output = button, output
        display(output, button)
        return button

    def begin_button_clicked(self, b):
        """Click handler that hides the "Begin" button (``b`` is unused)."""
        with self.output:
            self.begin_button.layout.display = 'none'

    def create_buttons(self):
        """Create one answer button per ``Shape`` plus a fresh output area; return the buttons."""
        buttons = []
        for member in Shape:
            buttons.append(wd.Button(description=member.value.title()))
        self.output = wd.Output()
        self.widgets.extend(buttons)
        return buttons

    def create_checkbox(self):
        """Create, register and return the "Lines" checkbox."""
        checkbox = wd.Checkbox(description="Lines", indent=False)
        self.widgets.append(checkbox)
        return checkbox

    def place_layout(self):
        """Display the current output area above a row holding every registered control."""
        controls = wd.HBox(self.widgets)
        display(self.output, controls)

    def clear_widgets(self):
        """Hide every registered control."""
        hidden = 'none'
        for control in self.widgets:
            control.layout.display = hidden

In [6]:
class Experiment:
    """Notebook shape-recognition experiment.

    Shows ``num_of_tests`` noisy figures one after another. The user labels each
    with the Triangle / Rectangle / Nothing buttons. True shapes, guesses,
    response times and the wrong answers are collected for :meth:`results`.

    Args:
        mode: ``'default'``, ``'single'``, ``'multiple'``, ``'invisible'`` or
            ``'debug'``; any other value is configured from ``kwargs``.
        **kwargs: in custom mode, ``num_of_tests`` plus any of the
            ``FigureManager`` arguments ``shape_type``, ``shape``, ``s``, ``a``,
            ``color``, ``size``.
    """
    def __init__(self, mode='custom', **kwargs):
        self.mode = mode
        self.kwargs = kwargs
        self.wm = WidgetManager()
        self.num_of_completed = 0
        self.line_checkbox = False
        self.current = None
        self.incorrect = []
        # generated_shapes also contains the trial currently on screen, so it can
        # be one entry longer than guessed_shapes.
        self.generated_shapes = np.empty(0)
        self.guessed_shapes = np.empty(0)
        self.response_times = np.empty(0)
        self.create_experiment(mode, **kwargs)

    def create_experiment(self, mode, **kwargs):
        """Set ``num_of_tests`` and build the FigureManager for ``mode``."""
        self.num_of_tests = 30
        if mode in ['default', 'single', 'multiple', 'invisible', 'debug']:
            args = getattr(self, f"{mode}_init")()
        else:
            args = self.custom_init(**kwargs)
        self.fm = FigureManager(**args)

    def default_init(self):
        return {'shape_type': ShapeType.ANY}

    def single_init(self):
        return {'shape_type': ShapeType.SINGLE, 's': 60, 'a': 0.6}

    def multiple_init(self):
        return {'shape_type': ShapeType.MULTIPLE, 's': None, 'a': None, 'size': 500}

    def invisible_init(self):
        return {'shape_type': ShapeType.INVISIBLE, 's': None, 'a': None, 'size': 1000}

    def debug_init(self):
        """Effectively unlimited invisible-mode trials with the "Lines" checkbox enabled."""
        self.num_of_tests = 999999
        self.line_checkbox = True
        return {'shape_type': ShapeType.INVISIBLE, 's': None, 'a': None, 'size': 1000}

    def custom_init(self, **kwargs):
        """Take ``num_of_tests`` and the FigureManager arguments from ``kwargs``; ignore the rest."""
        if 'num_of_tests' in kwargs: self.num_of_tests = kwargs['num_of_tests']
        legal_args = ['shape_type', 'shape', 's', 'a', 'color', 'size']
        return {key: kwargs[key] for key in set(kwargs).intersection(legal_args)}

    def start(self):
        """Display the "Begin" button; clicking it hides the button and starts the session."""
        button = self.wm.create_begin_button(self)
        button.on_click(self.wm.begin_button_clicked)
        button.on_click(self.initialize_experiment)

    def initialize_experiment(self, b):
        """Click handler for "Begin": build the answer controls and show the first trial."""
        with self.wm.output:
            buttons = self.wm.create_buttons()
            if self.line_checkbox: self.wm.create_checkbox().observe(self.fm.line_switch)
            self.wm.place_layout()

        for button, shape in zip(buttons, Shape):
            button.on_click(getattr(self, f"{shape.value}_button_clicked"))
        # create_buttons replaced wm.output, so the first trial goes to the new area.
        with self.wm.output:
            self._show_next_shape()

    def _show_next_shape(self):
        """Draw a new trial (call inside the output context) and record its true shape."""
        self.current, shape = self.fm.create_shape()
        self.generated_shapes = np.append(self.generated_shapes, shape)

    def triangle_button_clicked(self, b):
        self.button_clicked(Shape.TRIANGLE)

    def rectangle_button_clicked(self, b):
        self.button_clicked(Shape.RECTANGLE)

    def nothing_button_clicked(self, b):
        self.button_clicked(Shape.NOTHING)

    def button_clicked(self, shape):
        """Record ``shape`` as the answer for the trial on screen, then advance.

        The answer to the final trial is recorded too. After ``num_of_tests``
        answers the controls are hidden and no further trial is drawn. Clicks
        arriving after completion are not recorded.
        """
        with self.wm.output:
            clear_output()
            if not self.tests_completed():
                self.add_answer(shape)
                self.num_of_completed += 1
            if self.tests_completed():
                self.wm.clear_widgets()
                print('Tests completed')
            else:
                self._show_next_shape()

    def tests_completed(self):
        """True once ``num_of_tests`` answers have been recorded."""
        return self.num_of_completed >= self.num_of_tests

    def add_answer(self, answer):
        """Store ``answer`` and the response time for the current trial."""
        # Replace the display timestamp with the elapsed seconds.
        self.current['time'] = (datetime.now() - self.current['time']).total_seconds()
        # If incorrect, add to incorrect guess list
        if answer.value != self.current['real']:
            self.current['guess'] = answer.value
            self.incorrect.append(self.current)
        self.guessed_shapes = np.append(self.guessed_shapes, answer)
        self.response_times = np.append(self.response_times, self.current['time'])

    def results(self):
        """Print accuracy and response-time statistics, then plot the wrong answers."""
        print(f"Accuracy: {self.get_score():.3f}")
        if len(self.response_times) > 0:
            print(f"Mean response time: {self.response_times.mean():.3f}")
            print(f"Response time standard deviation: {self.response_times.std():.3f}")

        if self.incorrect: self.fm.show_incorrect(self.incorrect)

    def get_score(self):
        """Return the fraction of answers that matched the shape shown (0 if none)."""
        answered = len(self.guessed_shapes)
        if answered == 0:
            return 0
        # Align on answered trials only; a trial still on screen has no guess yet.
        hits = np.sum(self.generated_shapes[:answered] == self.guessed_shapes)
        return hits/answered

    def clear_score(self):
        """Forget recorded answers, keeping only trials that are still unanswered."""
        self.generated_shapes = self.generated_shapes[len(self.guessed_shapes):]
        self.guessed_shapes, self.response_times = np.empty(0), np.empty(0)
        self.incorrect = []

    def pickle(self, path):
        """Write the configuration and collected results to ``path``."""
        with open(path, 'wb') as out:
            pickle.dump(self.mode, out, pickle.DEFAULT_PROTOCOL)
            pickle.dump(self.kwargs, out, pickle.DEFAULT_PROTOCOL)
            pickle.dump(self.generated_shapes, out, pickle.DEFAULT_PROTOCOL)
            pickle.dump(self.guessed_shapes, out, pickle.DEFAULT_PROTOCOL)
            pickle.dump(self.response_times, out, pickle.DEFAULT_PROTOCOL)
            pickle.dump(self.incorrect, out, pickle.DEFAULT_PROTOCOL)

    @staticmethod
    def from_pickle(path):
        """Rebuild an Experiment and its results from a file written by :meth:`pickle`.

        NOTE: ``pickle.load`` can execute arbitrary code; only load files you trust.
        """
        with open(path, 'rb') as inp:
            mode = pickle.load(inp)
            kwargs = pickle.load(inp)
            e = Experiment(mode=mode, **kwargs)
            e.generated_shapes = pickle.load(inp)
            e.guessed_shapes = pickle.load(inp)
            e.response_times = pickle.load(inp)
            e.incorrect = pickle.load(inp)
        return e