In [22]:
import math

from IPython.display import display
from ipywidgets import VBox, HBox, Layout
import ipywidgets as widgets
from traitlets import Float

from one_locus_two_alleles_simulator import GenotypicFreqs, INF

DEFAULT_POP_SIZE = 1000
MIN_POP_SIZE = 10


class GenoFreqsWidget(widgets.Box):
    freq_AA = Float().tag(sync=True)
    freq_Aa = Float().tag(sync=True)
    freq_aa = Float().tag(sync=True)

    def __init__(
        self,
        freq_AA: float,
        freq_Aa: float,
        freq_aa: float | None = None,
        pop_size: int | float = INF,
        *args,
        **kwargs
    ):
        if freq_aa is None:
            freq_aa = 1 - freq_AA - freq_Aa
        GenotypicFreqs(freq_AA, freq_Aa, freq_aa)

        self.pop_size = pop_size
        self.freq_AA = freq_AA
        self.freq_Aa = freq_Aa
        self.freq_aa = freq_aa

        self.freqs_slider = widgets.FloatRangeSlider(
            value=(freq_AA, freq_AA + freq_Aa),
            min=0.0,
            max=1.0,
            step=0.01,
            readout=False,
        )
        self.freqs_slider.layout = Layout(width="auto", flex="1 1 auto")
        self.text_AA = widgets.FloatText(
            value=round(freq_AA, 2), description="Freq AA:", disabled=True
        )
        self.text_Aa = widgets.FloatText(
            value=round(freq_Aa, 2), description="Freq Aa:", disabled=True
        )
        self.text_aa = widgets.FloatText(
            value=round(freq_aa, 2), description="Freq aa:", disabled=True
        )
        self.text_A = widgets.FloatText(
            value=round(freq_aa + 0.5*freq_Aa, 2), description="Freq A:", disabled=False, min=0.0, max=1.0, step=0.01
        )
        self.text_obs_het = widgets.FloatText(value=round(freq_Aa, 2), description="Obs het:", disabled=False, min=0.0, max=1.0, step=0.01)
        pop_size, pop_is_inf = self._infer_pop_widget_values_from_pop_size()
        self.pop_size_text = widgets.IntText(value=pop_size, description="Pop size")
        self.pop_is_inf_checkbox = widgets.Checkbox(
            value=pop_is_inf, description="inf."
        )

        self.pop_is_inf_checkbox.observe(self._update_inf_size_checkbox, names="value")
        self.pop_size_text.observe(self._update_pop_size_text, names="value")
        self.freqs_slider.observe(self._update_slider, names="value")
        self.text_A.observe(self._update_freq_A, names='value')
        self.text_obs_het.observe(self._update_obs_het, names='value')

        freqs_box = VBox(
            [
                HBox([self.pop_size_text, self.pop_is_inf_checkbox]),
                HBox([self.text_A, self.text_obs_het]),
                HBox([self.text_AA, self.text_Aa, self.text_aa]),
                HBox([self.freqs_slider]),
            ]
        )
        super().__init__(children=[freqs_box], *args, **kwargs)

    def _update_freq_A(self, change):
        old_A = self.freq_AA + self.freq_Aa * 0.5
        freq_A = change["new"]

        if math.isclose(old_A, freq_A):
            return

        if freq_A < 0:
            freq_A = 0
        if freq_A > 1:
            freq_A = 1
        freq_AA = freq_A - 0
        
        old_Aa = self.freq_Aa

        min_Aa = 0
        max_Aa = freq_A * 2
        if max_Aa  > 1:
            max_Aa = 1

        max_Aa = freq_A 
        if self.freq_Aa > max_Aa:
            self.freq_Aa = max_Aa
        elif self.freq_Aa < min_Aa:
            self.freq_Aa = min_Aa
        else:
            self.freq_Aa = old_Aa

        print(f'{freq_A=}')
        self.freq_AA = freq_A - (self.freq_Aa * 0.5)
        self.freq_aa = 1 - self.freq_AA - self.freq_Aa
        print(f'{self.freq_AA=} {self.freq_Aa=} {self.freq_aa=}')

        self.text_obs_het.value = round(self.freq_Aa, 2)
        self.text_A.value = round(freq_A)
        self._update_geno_freq_texts()

    def _update_obs_het(self, change):
        old_Aa = 1 - self.freq_AA - self.freq_aa
        freq_Aa = change["new"]

        if math.isclose(old_Aa, freq_Aa):
            return

        if freq_Aa < 0:
            freq_Aa = 0
        if freq_Aa > 1:
            freq_Aa = 1
        self.freq_Aa = freq_Aa

        min_A = self.freq_Aa * 0.5
        max_A = 1 - self.freq_Aa * 0.5
        if old_A < min_A:
            freq_A = min_A
        elif old_A > max_A:
            freq_A = max_A
        else:
            freq_A = old_A
        self.freq_AA = freq_A - self.freq_Aa * 0.5
        self.freq_aa = 1 - self.freq_Aa - self.freq_AA

        self.text_A.value = round(freq_A, 2)
        self.text_obs_het.value = round(self.freq_Aa, 2)
        self._update_geno_freq_texts()


    def _get_pop_size_values(self):
        if self.pop_is_inf_checkbox.value:
            return INF
        else:
            size = self.pop_size
            if size < 1:
                size = DEFAULT_POP_SIZE
            return size

    def _infer_pop_widget_values_from_pop_size(self) -> tuple[int, bool]:
        if math.isinf(self.pop_size):
            return 0, True
        if self.pop_size > 0:
            return self.pop_size, False
        raise RuntimeError("Fixme")

    def _update_pop_size_widgets(self):
        pop_size, pop_is_inf = self._infer_pop_widget_values_from_pop_size()
        self.pop_size_text.value = pop_size
        self.pop_is_inf_checkbox.value = pop_is_inf

    def _update_inf_size_checkbox(self, change):
        pop_is_inf = change["new"]
        if pop_is_inf:
            self.pop_size = INF
        else:
            self.pop_size = DEFAULT_POP_SIZE
        self._update_pop_size_widgets()

    def _update_pop_size_text(self, change):
        size = change["new"]
        if size > MIN_POP_SIZE:
            self.pop_size = size
        self._update_pop_size_widgets()

    def _update_slider(self, change):
        new_value = change["new"]
        self.freq_AA = new_value[0]
        self.freq_Aa = new_value[1] - self.freq_AA
        self.freq_aa = 1 - self.freq_AA - self.freq_Aa
        self._update_geno_freq_texts()

    def _update_geno_freq_texts(self):
        self.text_AA.value = round(self.freq_AA, 2)
        self.text_Aa.value = round(self.freq_Aa, 2)
        self.text_A.value = round(self.freq_aa + self.freq_Aa * 0.5, 2)

    def _get_value(self):
        return self.range_slider.value

    def _set_value(self, value: tuple[float, float]):
        self.range_slider.value = value

    value = property(_get_value, _set_value)


class OneLocusSimApp:
    def __init__(self):
        self.app_title = "One locus simulator"

    def _generate_layout(self):
        geno_freqs_widget = GenoFreqsWidget(0.3, 0.3)
        return VBox([geno_freqs_widget])

    def _setup_event_handlers(self):
        # self.datagrid.observe(self.plot_stock, names='selections')
        pass

    def run_application(self):
        self._setup_event_handlers()
        display(self._generate_layout())


app = OneLocusSimApp()
app.run_application()

VBox(children=(GenoFreqsWidget(children=(VBox(children=(HBox(children=(IntText(value=0, description='Pop size'…

freq_A=1.0
self.freq_AA=0.85 self.freq_Aa=0.3 self.freq_aa=-0.14999999999999997
freq_A=0.0
self.freq_AA=0.0 self.freq_Aa=0.0 self.freq_aa=1.0
freq_A=1.0
self.freq_AA=1.0 self.freq_Aa=0.0 self.freq_aa=0.0
freq_A=0.0
self.freq_AA=0.0 self.freq_Aa=0.0 self.freq_aa=1.0
freq_A=1.0
self.freq_AA=1.0 self.freq_Aa=0.0 self.freq_aa=0.0
freq_A=0.0
self.freq_AA=0.0 self.freq_Aa=0.0 self.freq_aa=1.0
freq_A=1.0
self.freq_AA=1.0 self.freq_Aa=0.0 self.freq_aa=0.0
freq_A=0.0
self.freq_AA=0.0 self.freq_Aa=0.0 self.freq_aa=1.0
freq_A=1.0
self.freq_AA=1.0 self.freq_Aa=0.0 self.freq_aa=0.0
freq_A=0.0
self.freq_AA=0.0 self.freq_Aa=0.0 self.freq_aa=1.0
freq_A=1.0
self.freq_AA=1.0 self.freq_Aa=0.0 self.freq_aa=0.0
freq_A=0.0
self.freq_AA=0.0 self.freq_Aa=0.0 self.freq_aa=1.0
freq_A=1.0
self.freq_AA=1.0 self.freq_Aa=0.0 self.freq_aa=0.0
freq_A=0.0
self.freq_AA=0.0 self.freq_Aa=0.0 self.freq_aa=1.0
freq_A=1.0
self.freq_AA=1.0 self.freq_Aa=0.0 self.freq_aa=0.0
freq_A=0.0
self.freq_AA=0.0 self.freq_Aa=0.0 self.fr

RecursionError: maximum recursion depth exceeded while calling a Python object