In [1]:
%%capture
!pip install ipympl

%matplotlib widget
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from enum import Enum
import time
import multiprocessing
# from signal import signal, SIGTERM

In [2]:
# Output widget that captures the matplotlib figure created in the `with output:`
# block below; it is displayed next to the control bar in the cell's final HBox.
output = widgets.Output()
    
class Geno(Enum):
    """Genotype at a single two-allele locus; the value counts ``D`` alleles.

    ``D`` is dominant over ``d``: only ``dd`` shows the recessive phenotype
    (see ``geno_to_pheno``).
    """
    dd = 0
    Dd = 1
    DD = 2

def geno_to_pheno(geno):
    """Map genotypes to phenotype colours.

    ``Geno.dd`` (homozygous recessive) maps to ``'white'``; ``Geno.Dd`` and
    ``Geno.DD`` map to ``'black'`` because ``D`` is dominant.

    Args:
        geno: A list, tuple or array of ``Geno`` members (any shape).

    Returns:
        np.ndarray: object array of colour strings with the same shape as
        ``geno``.
    """
    # Coerce first: comparing a plain list with ``Geno.dd`` yields a single
    # bool instead of an elementwise mask, which would paint every mouse black.
    geno = np.asarray(geno, dtype=object)
    is_recessive = geno == Geno.dd
    return np.where(is_recessive, 'white', 'black').astype(object)
    
with output:
    # Geometry/layout defaults for figures created from here on.
    plt.rcParams.update({
        "figure.figsize": (6, 4),
        "figure.constrained_layout.use": True,
    })
    fig = plt.figure()

    # Unit-square field: grey left half, black right half. zorder=1 keeps the
    # background underneath the mice, which are drawn at zorder=2.
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.axvspan(0, 0.5, facecolor='grey', alpha=0.7, zorder=1)
    plt.axvspan(0.5, 1, facecolor='black', alpha=0.7, zorder=1)
    plt.axis('off')

    # Empty scatter so that `scat` exists before any simulation has run.
    scat = plt.scatter([], [])
    fig.canvas.draw()

settings_layout = widgets.Layout(width='50%')

# Simulation parameters; every input shares the same half-width layout.
speed_slider = widgets.IntSlider(value=50, min=0, max=100, readout=False, layout=settings_layout)
num_mice_in = widgets.BoundedIntText(value=1000, min=0, max=10000, layout=settings_layout)
mutation_rate_in = widgets.BoundedFloatText(value=0.001, min=0, max=1, layout=settings_layout)
selection_intensity_in = widgets.BoundedFloatText(value=0.05, min=0, max=1, layout=settings_layout)
mortality_rate_in = widgets.BoundedFloatText(value=0.25, min=0, max=1, layout=settings_layout)
migration_rate_in = widgets.BoundedFloatText(value=0, min=0, max=1, layout=settings_layout)

# Caption/input pairs in display order, so the label column and the input
# column are built from one table and cannot drift out of step.
_setting_rows = [
    ('Simulation Speed', speed_slider),
    ('# of Mice', num_mice_in),
    ('Mutation Rate', mutation_rate_in),
    ('Selection Intensity', selection_intensity_in),
    ('Mortality Rate', mortality_rate_in),
    ('Migration Rate', migration_rate_in),
]
labels = [caption for caption, _ in _setting_rows]
settings_labels = widgets.VBox([widgets.Label(value=caption) for caption in labels])
settings = widgets.VBox([control for _, control in _setting_rows])


def _full_width_button(description):
    """Return a button spanning almost the full width of the control column."""
    return widgets.Button(description=description, layout=widgets.Layout(width='99%'))


simulate_button = _full_width_button('Simulate')
reset_button = _full_width_button('Reset')
clear_button = _full_width_button('Clear')
control_bar = widgets.VBox([
    widgets.HBox([settings_labels, settings]),
    simulate_button,
    clear_button,
    reset_button,
])

def stop_simulation():
    """Stop the running simulation, if any, and wait briefly for it to exit.

    Reads the module-level ``simulation_process`` set by the Simulate handler.
    The worker only needs ``is_alive()``, ``terminate()`` and
    ``join(timeout=...)``; ``multiprocessing.Process`` provides all three.
    Does nothing when no simulation was ever started, when it failed to start,
    or when it has already finished.
    """
    worker = globals().get('simulation_process')
    # is_alive() is False for a worker whose start() failed, where calling
    # terminate() would raise instead of being a no-op.
    if worker is None or not worker.is_alive():
        return
    worker.terminate()
    # Wait so callers do not redraw the figure while the worker is still drawing.
    worker.join(timeout=1)
    
@simulate_button.on_click
def start_simulation(_):
    """Simulate button handler: reset the field and run a new simulation.

    Any running simulation is stopped and the figure cleared first, so at most
    one simulation runs at a time. The mice start at uniformly random
    positions in the unit square, all with genotype ``dd``, and then take up to
    100 random-walk steps. When the run ends, its last frame stays on screen
    until the Clear button is pressed.

    The walk runs on a background daemon thread, not in a separate process.
    The ipympl canvas lives in the kernel process, so draws issued from a
    forked child are not reliably forwarded to the browser. Under the
    ``spawn`` start method (the macOS/Windows default), a nested target
    function cannot be pickled at all.

    Args:
        _: The clicked ``widgets.Button`` (unused).
    """
    import threading

    global simulation_process  # We only run one simulation concurrently.
    global scat

    clear_simulation()

    # Snapshot the inputs at click time; edits made mid-run do not affect it.
    settings = {
        "speed": speed_slider.value,
        "num_mice": num_mice_in.value,
        "mutation_rate": mutation_rate_in.value,
        "selection_intensity": selection_intensity_in.value,
        "mortality_rate": mortality_rate_in.value,
        "migration_rate": migration_rate_in.value
    }
    num_mice = settings['num_mice']
    # Speed 0..100 maps to a per-frame delay of 1.0 s .. 0.01 s; the default
    # speed of 50 reproduces the previous fixed 0.5 s delay.
    frame_delay = max(0.01, (100 - settings['speed']) / 100)

    positions = np.random.random((num_mice, 2))  # one (x, y) row per mouse
    geno = np.full(num_mice, Geno.dd, dtype=object)
    pheno = geno_to_pheno(geno)
    # Create the artist here, before the worker thread starts touching it.
    points = fig.gca().scatter(positions[:, 0], positions[:, 1], c=pheno, s=10, zorder=2)
    scat = points
    fig.canvas.draw_idle()

    def walk(stop_requested, pos):
        """Move each mouse by a uniform step in [-0.05, 0.05) per axis, up to 100 times."""
        for _ in range(100):
            if stop_requested.is_set():
                return
            pos += (np.random.random(pos.shape) - 0.5) * 0.1
            np.clip(pos, 0, 1, out=pos)  # keep every mouse inside the field
            points.set_offsets(pos)
            # Request a redraw rather than rendering on this worker thread.
            fig.canvas.draw_idle()
            # Event.wait returns True as soon as a stop is requested, so a stop
            # never waits out a full frame delay.
            if stop_requested.wait(frame_delay):
                return

    class _SimulationThread(threading.Thread):
        """Daemon worker with a Process-style terminate(), as stop_simulation() expects."""

        def __init__(self):
            super().__init__(daemon=True)
            # Deliberately not named ``_stop``: threading.Thread uses that name internally.
            self._stop_requested = threading.Event()

        def run(self):
            walk(self._stop_requested, positions)

        def terminate(self):
            """Ask the walk to stop and wait briefly for the thread to exit."""
            self._stop_requested.set()
            if self is not threading.current_thread() and self.is_alive():
                self.join(timeout=2)

    # A thread rather than a process; the name is kept because stop_simulation() reads it.
    simulation_process = _SimulationThread()
    simulation_process.start()

def clear():
    """Repaint ``fig`` as the empty field and reset the module-level ``scat``.

    The figure is cleared and given a fresh axes with the unit-square limits,
    the grey/black two-tone background and hidden axis decorations. ``scat``
    is rebound to an empty scatter on the new axes, so it never refers to an
    artist that was removed by the clear.
    """
    global scat  # without this, the assignment below would only create a local
    # Operate on `fig` explicitly rather than whichever figure pyplot considers current.
    fig.clf()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    # zorder=1 keeps the background beneath the mice, which are drawn at zorder=2.
    ax.axvspan(0, 0.5, facecolor='grey', alpha=0.7, zorder=1)
    ax.axvspan(0.5, 1, facecolor='black', alpha=0.7, zorder=1)
    ax.axis('off')
    scat = ax.scatter([], [])
    fig.canvas.draw()
    
# Kept separate from the click handler below: `Button.on_click` returns None, so a function
# decorated with `@button.on_click` is rebound to None and cannot be called by other code.
def clear_simulation():
    """Stop any running simulation, then repaint the empty field.

    Plain (undecorated) function so both button handlers can call it.
    """
    stop_simulation()
    clear()
    
@clear_button.on_click
def clear_simulation_onclick(_):
    """Clear button handler; the clicked button argument is ignored."""
    clear_simulation()
    
@reset_button.on_click
def reset_simulation(_):
    """Reset button handler: stop the running simulation without touching the figure."""
    stop_simulation()

# Figure on the left, controls on the right; as the cell's last expression the
# HBox is rendered as the cell output.
widgets.HBox([output, control_bar])

HBox(children=(Output(), VBox(children=(HBox(children=(VBox(children=(Label(value='Simulation Speed'), Label(v…