In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".13"

In [2]:
import pandas as pd
from model import Brain
from submodels import factories
import matplotlib.pyplot as plt
import pandas as pd
from itertools import accumulate
import numpy as np
from collections import defaultdict
import re

from jf.db import DB
from lib.score import fate_corr, score_both_size_norm
from lib.preprocess import get_fmetric_pairs
from lib.sde.grn3 import GRNMain3
from lib.sde.mutate import mutate_grn2
from lib.ga.utils import weighted_selection
from lib.ga.objective import Objective
from jf.profiler import Profiler
from jf.utils.export import Exporter
from jf.autocompute.jf import O

In [3]:
HISTORY = defaultdict(dict)
HALL_OF_FAME = []

In [4]:
REF = O(
    stats=pd.read_csv("output/results/setup_basic/export/ref_basic2.csv"),  # ref is a mean
    # fmetric=setup_ref_fmetric("output/results/setup_basic/export/ref_fmetric_tristate.csv"),
)

In [5]:
def individual_generator():
    return Solution(GRNMain3(5, 0, 0))

In [6]:
class Solution:
    def __init__(self, grn):
        self.grn = grn
        
    def copy(self):
        return Solution(self.grn.copy())
        
    def mutate(self):
        for i in range(1):
            mutate_grn2(self.grn)
        # here force values
        self.grn.set_mutable()
        self.grn._params[4, :] = 1
        self.grn._params[5, :] = 0
        self.grn.compile()

In [7]:
def score_bb_size(bb, ref, *args, **kwargs):
    return score_both_size_norm(bb.stats, ref.stats, *args, **kwargs, norm=2.0)

In [8]:
def get_bb(prun, grn):
    ccls = factories["grn3"](grn=grn)
    bb = Brain(time_step=0.5, verbose=False, start_population=3, max_pop_size=1e3,
            cell_cls=ccls, end_time=prun.end_time, start_time=50, silent=True,
              run_tissue=False)
    return bb

In [9]:
def run_grn(prun, grn):
    get_bb(prun, grn)
    bb.run()
    return bb

In [10]:
def fitness_func(prun, grn, score_func):
    bb = run_grn(prun, grn)
    output = score_func(bb.stats, REF, max_step=prun.end_time)
    fitness = 1.0 / output
    return fitness

In [11]:
def fitness_multistep(prun, grn, steps):
    total_fitness = 0
    stop = False
    previous_time = None
    bb = get_bb(prun, grn)
    # first step
    for step in steps:
        if not bb.run_until(step.end_time):
            stop = True
        # score_step = score_both_size(bb.stats, prun.ref, max_step=step.end_time, min_step=previous_time)
        score_step = step.score_func(bb, prun.ref, max_step=step.end_time, min_step=step.start_time)
        fitness_step = 1.0 / score_step
        fitness_step = min(fitness_step, step.max_fitness)
        total_fitness += fitness_step
        if fitness_step < step.min_fitness or stop:
            return total_fitness, bb.stats
        else:
            previous_time = step.end_time
            step.passed()
        
    return total_fitness, bb.stats

def score_multistep(prun, stats, steps):
    total_fitness = 0
    stop = False
    previous_time = None
    # first step
    for i, step in enumerate(steps):
        # score_step = score_both_size(bb.stats, prun.ref, max_step=step.end_time, min_step=previous_time)
        score_step = step.score_func(bb, prun.ref, max_step=step.end_time, min_step=step.start_time)
        print(f"Score for step {i} is {score_step}")
        fitness_step = 1.0 / score_step
        fitness_step = min(fitness_step, step.max_fitness)
        total_fitness += fitness_step
        if fitness_step < step.min_fitness or stop:
            return total_fitness
        previous_time = step.end_time
        
    return total_fitness

In [12]:
def mean_sd_fitness(prun, grn, run=3):
    fitnesses = [fitness_multistep(prun, grn) for i in range(run)]
    return np.mean(fitnesses), np.std(fitnesses)

In [13]:
def multi_fitness(*args):
    fitnesses = [fitness_multistep(*args) for i in range(3)]
    scores = [x[0] for x in fitnesses]
    return fitnesses[scores.index(min(scores))]

In [14]:
def do_init_pop(prun):
    return [individual_generator() for i in range(prun.pop_size)]

def do_fitness(prun, pop):
    # fitness = [fitness_func(prun, sol.grn, score_func) for sol in pop]
    # fitness = [fitness_strategy(prun, sol.grn) for sol in pop]
    fitness, stats = zip(*[multi_fitness(prun, sol.grn, prun.steps) for sol in pop])
    return fitness, stats

def do_selection(prun, pop_fit, pop):
    # print("Fit score : ", pop_fit)
    acc = list(accumulate(pop_fit))
    best = max(pop_fit)
    best_id = pop_fit.index(best)
    
    print("Total fitness :", acc[-1])
    
    pop_sel, history_sel = weighted_selection(pop, pop_fit, individual_generator, new_fitness=0.3)
        
    return pop_sel, history_sel, best_id

def do_mutation(prun, pop_sel):
    [p.mutate() for p in pop_sel]
    return pop_sel

In [15]:
class ObjectiveStep(O):
    start_time = 0
    end_time = 0
    max_fitness = 4
    min_fitness = 1
    name = ""
    _passed = False
    
    def reset(self):
        self._passed = False
    
    def passed(self):
        if self._passed:
            return
        print(f"Step {self.name} passed !")
        self._passed = True
    
example_steps = [
    ObjectiveStep(name="1", start_time=50, end_time=53, score_func=score_bb_size),
    ObjectiveStep(name="2", start_time=53, end_time=56, score_func=score_bb_size),
    ObjectiveStep(name="3", start_time=56, end_time=59, score_func=score_bb_size),
    ObjectiveStep(name="4", start_time=59, end_time=62, score_func=score_bb_size),
    ObjectiveStep(name="5", start_time=62, end_time=65, score_func=score_bb_size),
    ObjectiveStep(name="6", start_time=65, end_time=68, score_func=score_bb_size),
    ObjectiveStep(name="7", start_time=68, end_time=71, score_func=score_bb_size),
    ObjectiveStep(name="8", start_time=71, end_time=74, score_func=score_bb_size),
    ObjectiveStep(name="9", start_time=74, end_time=77, score_func=score_bb_size),
]

class ParamRun(O):
    pop_size = 10
    n_gen = 30
    current_gen = 0
    end_time = 83
    ref = REF

args = ParamRun()
args.steps = example_steps
for step in args.steps:
    step.reset()

In [16]:
def main(prun):
    exporter = Exporter()
    best = 0
    pop = do_init_pop(prun)
    for generation in range(prun.n_gen):
        # args.generation = generation
        # objective.new_trial()
        fit, stats = do_fitness(prun, pop)
        # objective.best_current(max(fit))
        
        # TODO get the stats associated with the best scores
        sel, history_sel, best_id = do_selection(prun, fit, pop)
        if fit[best_id] > best:
            print(f"++ Best {fit[best_id]}")
            best = fit[best_id]
            HALL_OF_FAME.append(pop[best_id].copy())
        else:
            print(f"-- Best {best}")
        pop = do_mutation(prun, sel)
        
        # history
        monitor = dict(
            transition=history_sel,
            solution=pop,
            fitness=fit,
            stats=stats,
        )
        HISTORY[generation] = monitor
        exporter(monitor, f"generation_g{generation}")
        
    return best

In [17]:
print(1)

1


In [18]:
sol = main(args)

Exporting at output/2022-01-27/16:15:32.018554_545533
Total fitness : 1.3431384247449774
++ Best 0.6568660741681976
Step 1 passed !
Total fitness : 0.9491868989710649
-- Best 0.6568660741681976
Total fitness : 2.1914319268646008
++ Best 0.9695173503085515
Total fitness : 2.911649535113378
-- Best 0.9695173503085515
Total fitness : 6.300735561959916
++ Best 1.404767964601398
Total fitness : 6.462010626455761
++ Best 1.5333391343172889
Total fitness : 8.448035358770564
++ Best 1.5513892570588579
Total fitness : 7.061514036320228
-- Best 1.5513892570588579
Step 2 passed !
Total fitness : 10.246567219761197
++ Best 1.5853769891328842
Total fitness : 9.637158114366361
++ Best 2.12337845871288
Total fitness : 8.55875810534278
-- Best 2.12337845871288
Total fitness : 9.61341983336285
-- Best 2.12337845871288
Total fitness : 9.176921442206933
-- Best 2.12337845871288
Total fitness : 10.516460160778795
-- Best 2.12337845871288
Total fitness : 8.324066385195136
-- Best 2.12337845871288
Total fit

In [19]:
gen = HISTORY[27]
idx = gen["fitness"].index(max(gen["fitness"]))
best_stats = gen["stats"][idx]

In [20]:
score_multistep(args, best_stats, args.steps)

NameError: name 'bb' is not defined

In [None]:
raise

In [None]:
# GOOD_POP = HISTORY[2]["solution"]

In [None]:
# GOOD_POP = HISTORY[2]

In [None]:
# best = HISTORY[25]

In [None]:
HISTORY.keys()

In [None]:
last = HISTORY[0]
last.keys()

In [None]:
fit = last['fitness']
index = fit.index(max(fit))
sol = last["solution"][index]
index, max(fit)

In [None]:
stats = last["stats"][index]

In [None]:
def show_curve(stats, ref, max_step=None, show=True):
    """
    The evaluation function for progenitor population size.
    The lower is the better
    :param stats: the stats of the bb after running
    """
    stats, ref = shrink_and_align_stats(stats, ref, max_step=max_step)
    
    x, y = preprocess_progenitor_size(stats, ref)
    
    plt.plot(ref.index, x, label="Reference Prog")
    plt.plot(ref.index, y, label="Simulation Prog")
    
    x, y = preprocess_whole_size(stats, ref)
    
    plt.plot(ref.index, x, label="Reference Whole")
    plt.plot(ref.index, y, label="Simulation Whole")
    
    plt.legend()
    
    if show:
        plt.show()

In [None]:
def print_fmetrics(population, ref):
    pairs = get_fmetric_pairs(population, min_time=50, max_time=60)
    fmetric = fate_corr(pairs)
    print(f"Population : {fmetric}, ref : {ref[50][60]}")

In [None]:
show_curve(stats, REF, max_step=args.end_time)

In [None]:
score_multistep(args, stats, args.steps)

In [None]:
REF.head(10)

In [None]:
def fill_df(stats, ls_val):
    for time, prog_pop, whole_pop in ls_val:
        stats.loc[time] = {"progenitor_pop_size": prog_pop, "time": time, "whole_pop_size": whole_pop}

In [None]:
vals = [
    (50.5, 25, 27),
    (51., 26, 30),
    (51.5, 27, 40),
    (52., 30, 50),
    (52.5, 35, 60),
    (53., 40, 27),
    (53.5, 45, 27),
    (54., 50, 27),
    (54.5, 55, 27),
    (55., 60, 27),
]

In [None]:
fill_df(stats, ls_val=vals)
stats

In [None]:
show_curve(stats, REF.stats, max_step=55)
print_fmetrics(population, REF.fmetric)
score_multistep(args, stats, args.steps)

In [None]:
# peut être mettre sqrt plutôt que abs en dénominateur
# ensuite équilibrer les fit value en normalisant avec la médiane ou la moyenne ou autre (3e quartile ?)