In [1]:
import os

# Limit the fraction of GPU memory XLA/JAX preallocates to 13%. This only takes
# effect if set before JAX initialises its GPU backend, hence the very first cell.
os.environ.update(XLA_PYTHON_CLIENT_MEM_FRACTION=".13")

In [2]:
import pandas as pd
from brain import BrainModel
from submodels import factories
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
from random import shuffle
import random

from lib.score import (
    fate_corr, score_both_size_norm, shrink_and_align_stats, score_stats_norm
)
from lib.preprocess import *
from lib.callback import (
    cell_number_callback, progenitor_number_callback, neuron_number_callback,
    TargetPopulation, TagNumberCallback,
)
from lib.sde.grn.grn4 import GRNMain4 as GRNMain
from lib.sde.mutate import mutate_param3, mutate_tree

from lib.ga.utils import weighted_selection_one, normalize_fitness_values
from lib.utils import pick_best, pick_last
from jf.utils.export import Exporter
from jf.autocompute.jf import O, L
from itertools import product
import jf.models.stringmodel as sm
from lib.analyser import show_curve, show_curve_progenitor
from jf.models.stringmodel import read_model
from lib.utils import normalize_time, align_time

In [3]:
def _mutate_grn4(grn):
    """Apply a single mutation (via mutate_gene4) to one gene picked uniformly from ``grn.genes``."""
    target_gene = random.choice(grn.genes)
    mutate_gene4(target_gene)

def mutate_grn(grn):
    """Mutate ``grn`` in place.

    The number of single-gene mutations is read from ``grn.var["nb_mut"]``
    (default 1). The mutations are bracketed by ``grn.set_mutable()`` and
    ``grn.compile()``.
    """
    n_mutations = grn.var.get("nb_mut", 1)
    grn.set_mutable()
    for _ in range(n_mutations):
        _mutate_grn4(grn)
    grn.compile()
    
def mutate_gene4(gene):
    """Apply one randomly chosen kind of mutation to ``gene`` (in place).

    Based on a single ``random.random()`` draw:
      * below 0.05                 -> ``mutate_param3`` (parameter mutation),
      * below 0.05 + 0.65          -> ``mutate_smooth_param3`` (small multiplicative nudge),
      * otherwise (about 0.30)     -> ``mutate_tree`` on ``gene.tree``, given the
        labels that are not yet in the tree.
    """
    u = random.random()
    hard_cut = 0.05
    smooth_cut = hard_cut + 0.65

    if u < hard_cut:
        mutate_param3(gene)
        return
    if u < smooth_cut:
        mutate_smooth_param3(gene)
        return
    # TODO (author note): "add the labels here !!!" -- exact intent unclear.
    gene.tree = mutate_tree(gene.tree, gene.get_labels_not_in_tree())
        
def mutate_smooth_param3(gene):
    """Nudge one randomly chosen parameter of ``gene`` by a factor close to 1.

    One of ``b, m, expr, deg, init, noise, asym, theta`` is chosen uniformly at
    random. It is multiplied by ``r`` with ``r ~ U(0.8, 1.2)``. ``expr`` and
    ``deg`` get a damped factor ``1 + (r - 1) * 0.1``, which lies in [0.98, 1.02].
    """
    r = random.uniform(0.8, 1.2)
    # The parameter must be chosen with its own draw. r is confined to [0.8, 1.2],
    # so comparing r itself against k/8 thresholds can only ever reach the
    # last two parameters (asym, theta).
    name = random.choice(("b", "m", "expr", "deg", "init", "noise", "asym", "theta"))
    factor = 1 + (r - 1) * 0.1 if name in ("expr", "deg") else r
    setattr(gene, name, getattr(gene, name) * factor)

In [4]:
# Last id handed out by provide_id(); -1 means none has been issued yet.
_count = -1


def provide_id():
    """Return a new integer id, one greater than the previously issued one."""
    global _count
    _count = _count + 1
    return _count


def init_provide_id(idx):
    """Re-seed the id counter, e.g. after reloading a saved population.

    main() passes the id of the last reloaded solution. The counter is set to
    ``idx + 1``, so the next provide_id() returns ``idx + 2``.
    NOTE(review): one id is skipped; this is harmless as long as ``idx`` is
    the largest id issued so far.
    """
    global _count
    _count = idx + 1

In [5]:
# Reference statistics that every simulation is scored against
# (author's note: the reference is a mean).
REF = O(stats=pd.read_csv("reference/ref_tristate2.csv"))

# Name model for exported generation snapshots: used to match, parse
# (``extract``) and build (``fill``) names of the form "generation...".
SM_GEN = read_model("generation")

In [6]:
def individual_generator(id_=-1, cb_init=None):
    """Create a brand-new random Solution with id ``id_``.

    ``cb_init`` maps parameter names to zero-argument generators and is passed
    to GRNMain as ``generate_funcs``. The positional arguments ``(5, 0, 1)`` are
    interpreted by lib.sde.grn.grn4.GRNMain4 (not visible here).
    """
    fresh_grn = GRNMain(5, 0, 1, generate_funcs=cb_init)
    return Solution(fresh_grn, id_=id_)

In [7]:
class Solution:
    """One individual of the evolutionary search: a GRN plus bookkeeping.

    Attributes:
        id: identifier of this individual.
        grn: the gene regulatory network being evolved.
        parent: id of the Solution this one was copied from (default -1).
        fit: fitness, -1 until main() stores an evaluated value.
        stats: simulation statistics from evaluation, None until then.
    """

    def __init__(self, grn, id_=0, parent=-1):
        self.id, self.grn, self.parent = id_, grn, parent
        # Not evaluated yet.
        self.fit, self.stats = -1, None

    def copy(self, id_=0):
        """Return a new, unevaluated Solution holding ``self.grn.copy()``, with this one as parent."""
        return Solution(grn=self.grn.copy(), id_=id_, parent=self.id)

    def mutate(self):
        """Mutate this solution's GRN in place (see mutate_grn)."""
        mutate_grn(self.grn)

In [8]:
def score_bb_size(bb, ref, *args, **kwargs):
    """Score a simulation by its final progenitor and neuron sizes versus ``ref``.

    The simulated and reference stats are first aligned (``align_time``) and
    normalised (``normalize_time``). Only the values at the last time point of
    each frame are compared. Extra arguments (e.g. ``min_step`` / ``max_step``
    passed by fitness_multistep) are accepted and ignored.

    Returns:
        Sum of squared differences of the final progenitor and neuron values
        (lower is better).
    """
    sim, target = align_time(bb.stats.copy(), ref.stats)
    sim_prog, target_prog = normalize_time(sim, target, "progenitor_pop_size", "progenitor_pop_size")
    # The extra pair of progenitor column names presumably makes neuron counts
    # relative to progenitor counts -- TODO confirm in lib.utils.normalize_time.
    sim_neur, target_neur = normalize_time(sim, target, "neuron_pop_size", "neuron_pop_size",
                                           "progenitor_pop_size", "progenitor_pop_size")
    t_sim, t_target = max(sim.time), max(target.time)

    def at_time(frame, t, column):
        # Value of `column` at time `t`, looked up through a "time" index.
        return frame.set_index("time").loc[t][column]

    d_prog = (at_time(sim_prog, t_sim, "progenitor_pop_size")
              - at_time(target_prog, t_target, "progenitor_pop_size"))
    d_neur = (at_time(sim_neur, t_sim, "neuron_pop_size")
              - at_time(target_neur, t_target, "neuron_pop_size"))
    return d_prog ** 2 + d_neur ** 2

In [9]:
def setup_tag(cp):
    """Randomly split ``cp.base_population`` into three near-equal groups.

    Each cell gets ``tag["subbrain"]`` set to its group index (0, 1 or 2).
    Group sizes follow ``np.array_split`` (the first groups take the remainder).
    """
    keys = list(cp.base_population)
    shuffle(keys)
    for subbrain, chunk in enumerate(np.array_split(keys, 3)):
        for key in chunk:
            cp.base_population[key].tag["subbrain"] = subbrain

In [10]:
def get_bb(prun, grn):
    """Build (but do not run) a BrainModel whose cells are driven by ``grn``.

    Cells come from the "grn5" factory. The model starts at time 56 with
    ``prun.size`` cells, ends at ``prun.end_time`` and records progenitor,
    total and neuron counts through monitor callbacks. ``setup_tag`` is passed
    as ``tag_func`` (presumably applied to the starting population).
    """
    cell_factory = factories["grn5"](grn=grn)
    monitors = {
        "progenitor_pop_size": progenitor_number_callback,
        "whole_pop_size": cell_number_callback,
        "neuron_pop_size": neuron_number_callback,
    }
    return BrainModel(
        time_step=0.5, verbose=False, start_population=prun.size, max_pop_size=5e2,
        cell_cls=cell_factory, end_time=prun.end_time, start_time=56, silent=True,
        opti=True, run_tissue=True, monitor_callbacks=monitors, tag_func=setup_tag,
    )

In [11]:
def run_grn(prun, grn):
    """Build a BrainModel for ``grn`` (see get_bb), call its ``run()`` and return it.

    Previously the model built by get_bb was discarded, so ``bb.run()`` raised
    NameError.
    """
    bb = get_bb(prun, grn)
    bb.run()
    return bb

In [12]:
def fitness_multistep(prun, grn, steps):
    """Simulate ``grn`` through successive objective steps and sum their fitness.

    For each step, the simulation is advanced to ``step.end_time`` and scored
    with ``step.score_func``. The step fitness is ``1 / score``, capped at
    ``step.max_fitness``; a score of exactly 0 (perfect match) earns the cap
    instead of raising ZeroDivisionError. Evaluation stops after a step whose
    fitness is below ``step.min_fitness``, or after the simulation fails to
    reach the step's end time. In the latter case the step is still scored.

    Returns:
        ``(total_fitness, bb.stats)``.
    """
    total_fitness = 0
    bb = get_bb(prun, grn)
    for step in steps:
        # run_until returns a falsy value when end_time was not reached.
        reached = bb.run_until(step.end_time)

        score_step = step.score_func(bb, prun.ref, max_step=step.end_time, min_step=step.start_time)
        if score_step == 0:
            fitness_step = step.max_fitness
        else:
            fitness_step = min(1.0 / score_step, step.max_fitness)
        total_fitness += fitness_step

        if not reached or fitness_step < step.min_fitness:
            break
        step.passed()

    return total_fitness, bb.stats

In [13]:
def do_init(prun):
    """Create the first individual: a fresh random Solution with a new id."""
    new_id = provide_id()
    return individual_generator(new_id, prun.cb_init)

def do_fitness(prun, sol):
    """Evaluate ``sol`` on ``prun.steps``; returns ``(fitness, stats)`` from fitness_multistep."""
    return fitness_multistep(prun, sol.grn, prun.steps)

def do_selection(prun, pop_fit, pop):
    """Choose the next solution to evaluate.

    While ``len(pop) < prun.min_pop`` a brand-new random individual is returned.
    Otherwise one is picked by ``weighted_selection_one`` from normalised fitness.
    ``new_fitness=0.05`` is presumably the weight given to creating a fresh
    individual instead -- TODO confirm in lib.ga.utils.
    """
    if len(pop) >= prun.min_pop:
        weights = normalize_fitness_values(pop_fit)

        def new_individual(x):
            return individual_generator(x, prun.cb_init)

        picked = weighted_selection_one(pop, weights, new_individual,
                                        new_fitness=0.05, id_=provide_id())
        return picked[0]
    return individual_generator(provide_id(), prun.cb_init)

def do_mutation(prun, sol):
    """Mutate ``sol`` in place and hand back the same object (``prun`` is unused)."""
    mutated = sol
    mutated.mutate()
    return mutated

In [14]:
class ObjectiveStep(O):
    """One time window of the multistep objective evaluated by fitness_multistep."""
    start_time = 0       # forwarded to score_func as min_step
    end_time = 0         # simulation is run until this time; forwarded as max_step
    max_fitness = 1e9    # cap on this step's fitness
    min_fitness = 1      # a lower step fitness stops the evaluation
    name = ""            # label used in the "passed" message
    _passed = False      # whether the "passed" message was already printed

    def reset(self):
        """Forget that this step was passed."""
        self._passed = False

    def passed(self):
        """Record that the step was passed, announcing it only the first time."""
        if not self._passed:
            print(f"Step {self.name} passed !")
            self._passed = True
    
# Objective steps evaluated in order by fitness_multistep. Each step runs the
# simulation up to end_time and scores it with score_func. start_time/end_time
# are forwarded as min_step/max_step, which score_bb_size ignores: it only
# compares the last time point.
# NOTE(review): the two disabled steps cover times 50-56, but get_bb starts
# simulations at time 56, so re-enabling them would also need an earlier start.
example_steps = [
    # ObjectiveStep(name="1", start_time=50, end_time=53, score_func=score_bb_size, min_fitness=0.2),
    # ObjectiveStep(name="2", start_time=53, end_time=56, score_func=score_bb_size, min_fitness=0.2),
    ObjectiveStep(name="3", start_time=56, end_time=86, score_func=score_bb_size, min_fitness=0.2),
]

class ParamRun(O):
    """Hyper-parameters of one evolutionary run (class-level defaults).

    get_prun() adds per-run attributes (cb_init, size, exponent, steps), main()
    adds ``history``, and the notebook sets ``name`` before calling main().
    """
    pop_size = 50     # not read anywhere in the visible code
    batch_size = 50   # iterations between two exports in main()
    n_gen = 50        # number of batches main() runs in total
    current_gen = 0   # not read anywhere in the visible code
    end_time = 86     # simulation end time passed to BrainModel by get_bb
    ref = REF         # reference statistics handed to each step's score_func
    min_pop = 50      # below this population size do_selection injects fresh random individuals
    max_pop = 50      # size of the sliding window used for selection and for exports

def get_prun(size=7, exponent=1):
    """Return a ParamRun configured with the given start ``size`` and ``exponent``.

    ``cb_init`` starts empty (callers usually replace it). ``steps`` is the
    shared module-level ``example_steps`` list, so step state such as
    ``_passed`` is shared between runs.
    """
    prun = ParamRun()
    prun.cb_init, prun.size, prun.exponent = {}, size, exponent
    prun.steps = example_steps
    return prun

In [15]:
def pick_last_exported(exporter):
    """Load the newest exported generation, if any.

    Returns:
        ``(population, next_generation)``, where ``next_generation`` is the
        highest saved generation number + 1, or ``(None, 0)`` when nothing
        matching SM_GEN has been exported yet.
    """
    candidates = [name for name in exporter.list() if SM_GEN.match(name)]
    if not candidates:
        return None, 0

    def generation_of(name):
        return int(SM_GEN.extract(name).get("generation"))

    newest = max(candidates, key=generation_of)
    newest_gen = generation_of(newest)
    exporter.print(f"Found generation {newest_gen}", "reload")
    return exporter.load(newest), newest_gen + 1

In [16]:
def main(prun):
    """Run, or resume, the steady-state evolutionary search configured by ``prun``.

    Each iteration ``i`` does the following:
      * evaluates one solution and stores fit/stats on it;
      * records it in ``prun.history[i]``;
      * selects a parent from the last ``prun.max_pop`` solutions, mutates it
        and appends it to the population. That solution is evaluated in the
        next iteration.

    Every ``prun.batch_size`` iterations, the last ``prun.max_pop`` solutions
    are exported as generation ``i // prun.batch_size`` (the index of the batch
    just completed). ``pick_last_exported`` returns that index + 1, so a
    restart resumes at the first iteration of the following batch.

    Returns:
        The best fitness seen during this call. Fitness from earlier sessions
        that this call resumed from is not included.
    """
    prun.history = dict()
    exporter = Exporter(name=prun.name, copy_stdout=True)
    definition = """
    
    """
    exporter.print(definition, slot="definition")
    best = 0

    # Resume from the newest exported generation when one exists.
    pop, batch_gen = pick_last_exported(exporter)
    if pop is None:
        sol = do_init(prun)
        pop = [sol]
        batch_gen = 0
    else:
        # The last saved solution has not been evaluated in its current
        # (mutated) form yet: evaluate it first, and continue ids after it.
        sol = pop[-1]
        init_provide_id(sol.id)

    for i in range(batch_gen * prun.batch_size, prun.n_gen * prun.batch_size):
        fit, stats = do_fitness(prun, sol)
        sol.fit, sol.stats = fit, stats

        if i % 100 == 0:
            exporter.print(f"Step {i}")
        if fit > best:
            exporter.print(f"++ Best {fit} for generation {i}")
            best = fit

        prun.history[i] = sol

        # Steady-state GA: pick one parent from the most recent max_pop solutions.
        # NOTE(review): if weighted_selection_one returns an existing Solution
        # rather than a copy, do_mutation mutates it in place and it appears twice
        # in pop -- confirm that it copies.
        sub_pop = pop[-prun.max_pop:]
        sol = do_selection(prun, [s.fit for s in sub_pop], sub_pop)
        sol = do_mutation(prun, sol)
        pop.append(sol)

        if (i + 1) % prun.batch_size == 0:
            print("Saving ...")
            # Label with the index of the batch just completed. This matches the
            # "+ 1" applied by pick_last_exported when resuming; a hard-coded
            # divisor would let batches overwrite each other and misplace restarts.
            batch_gen = i // prun.batch_size
            exporter(pop[-prun.max_pop:], SM_GEN.fill(generation=batch_gen))

    return best

In [17]:
# Per-parameter generators of initial values for new GRNs. They are installed
# as prun.cb_init and reach GRNMain as ``generate_funcs``.
# Beta(1.5, 3) draws lie in [0, 1], so init is in [0, 3], b in [0, 5] and noise in [0, 1].
callback_init = {
    "init": lambda: np.random.beta(1.5, 3) * 3,
    "b": lambda: np.random.beta(1.5, 3) * 5,
    "expr": lambda: 1,
    "deg": lambda: 0.1,
    "noise": lambda: np.random.beta(1.5, 3) * 1,
    "asym": lambda: 5,
}

In [None]:
# Configure and launch the evolutionary run. Results are exported under
# output/<prun.name>; main() resumes from the newest exported generation if one exists.
prun = get_prun()
prun.cb_init = callback_init
prun.name = "check_simple_obj_r9"
res = main(prun)

Exporting at output/check_simple_obj_r9
[definition] 
    
    
[out] Step 0
[out] ++ Best 0.0014571900538777544 for generation 0
[out] ++ Best 0.0017698049289881278 for generation 1
[out] ++ Best 0.002404714676657859 for generation 4
[out] ++ Best 0.002571669222232167 for generation 42
[out] ++ Best 0.006092948833525033 for generation 49
Saving ...
[out] ++ Best 0.009495417467988142 for generation 53
[out] ++ Best 0.010782570350163022 for generation 66
[out] ++ Best 0.011405410581612178 for generation 92
Saving ...
[out] Step 100
[out] ++ Best 0.013058321639430982 for generation 117
[out] ++ Best 0.0145374429912541 for generation 127
[out] ++ Best 0.03509728710509213 for generation 129
Saving ...
Step 3 passed !
[out] ++ Best 1.3981797775024547 for generation 180
[out] ++ Best 1.8790360073484338 for generation 181
Saving ...
[out] Step 200
[out] ++ Best 1.9459370529537965 for generation 247
Saving ...
[out] ++ Best 2.084627899471055 for generation 279
[out] ++ Best 2.293576765522983 f

In [None]:
# Re-open the run's export directory so saved generations can be loaded below.
exp = Exporter(name=prun.name)

In [None]:
# NOTE(review): this rebinds `sm`, shadowing the `jf.models.stringmodel as sm`
# import from cell 2. It is the same "generation" model as SM_GEN.
sm = read_model("generation")

In [None]:
# Load the most recent export and keep its best solution (presumably highest
# `fit` -- see lib.utils.pick_best / pick_last).
res = pick_best(exp.load(pick_last(exp)))

In [None]:
# Re-score the reloaded solution several times. Results can differ between
# evaluations (e.g. setup_tag shuffles cells at random), so this shows the spread.
for _ in range(5):
    print(fitness_multistep(prun, res.grn, prun.steps))

In [None]:
# Plot the stats stored on `res` during the run (the re-evaluations above do
# not update them) against the reference.
show_curve(res.stats, REF.stats)

In [None]:
# Print the best GRN of every exported generation, in generation order.
# Uses SM_GEN (cell 5) rather than `sm`, which is only a name model if the earlier
# rebinding cell was run; otherwise it is the jf.models.stringmodel module.
generation_names = sorted(
    filter(SM_GEN.match, exp.list()),
    key=lambda name: int(SM_GEN.extract(name)["generation"]),
)
for g in generation_names:
    res = pick_best(exp.load(g))
    print()
    print(f"======= GENERATION {g} =======")
    print(res.grn)