In [1]:
import os
# Cap the fraction of GPU memory that JAX/XLA preallocates at 13%.
# JAX reads this variable when its GPU backend initialises, so it must be set before
# jax is first imported (presumably via `model` / `lib.sde` in the next cell — TODO confirm).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".13"

In [2]:
import pandas as pd
from model import Brain
from submodels import factories
import matplotlib.pyplot as plt
import pandas as pd
from itertools import accumulate
import numpy as np
from collections import defaultdict
import re

from jf.db import DB
from lib.score import fate_corr, score_both_size_norm, shrink_and_align_stats
from lib.preprocess import *
from lib.sde.grn3 import GRNMain3
from lib.sde.mutate import mutate_grn2
from lib.ga.utils import weighted_selection
from lib.ga.objective import Objective
from jf.profiler import Profiler
from jf.utils.export import Exporter
from jf.autocompute.jf import O

In [3]:
# Every evaluated Solution, keyed by the evaluation index `generation` used in main().
# NOTE(review): the `dict` default factory is never exercised in the visible cells
# (main() only assigns HISTORY[generation] = solution).
HISTORY = defaultdict(dict)
# NOTE(review): not referenced by any visible cell.
HALL_OF_FAME = []

In [4]:
_count = -1


def provide_id():
    """Return a fresh integer id: 0 on the first call, then 1, 2, ...

    The counter lives in the module-level ``_count`` so that ids stay unique across
    every Solution created during this kernel session.
    """
    global _count
    _count = _count + 1
    return _count

In [5]:
# Reference that simulations are scored against. `stats` is the reference statistics
# table (presumably indexed by time), loaded from the exported reference CSV.
REF = O(
    stats=pd.read_csv("output/results/setup_basic/export/ref_basic2.csv"),  # ref is a mean
    # fmetric=setup_ref_fmetric("output/results/setup_basic/export/ref_fmetric_tristate.csv"),
    # NOTE(review): with `fmetric` disabled REF has no `fmetric` attribute, so any cell
    # reading REF.fmetric will fail.
)

In [6]:
def individual_generator(id_=-1):
    """Create a brand-new Solution wrapping a freshly built GRNMain3 network.

    NOTE(review): the meaning of GRNMain3's positional arguments (5, 0, 0) is not
    visible here — confirm in lib.sde.grn3.
    """
    network = GRNMain3(5, 0, 0)
    return Solution(network, id_=id_)

In [7]:
class Solution:
    """One individual of the evolutionary search: a gene regulatory network plus bookkeeping.

    Attributes:
        id     -- identifier of this individual (see provide_id)
        grn    -- the network being evolved
        parent -- id of the Solution this one was copied from, -1 for a fresh individual
        fit    -- fitness, -1 until the search loop evaluates it
        stats  -- statistics of the evaluating simulation, None until evaluated
    """

    def __init__(self, grn, id_=0, parent=-1):
        self.id = id_
        self.grn = grn
        self.parent = parent
        self.fit = -1
        self.stats = None

    def copy(self, id_=0):
        """Return an unevaluated child: a copy of the network whose `parent` is this id."""
        return Solution(self.grn.copy(), id_=id_, parent=self.id)

    def mutate(self, n_mutations=1):
        """Mutate the network in place, then re-pin the parameters that must stay fixed.

        :param n_mutations: number of successive mutate_grn2 calls (default 1, as before).
        """
        for _ in range(n_mutations):
            mutate_grn2(self.grn)
        # Force parameter rows 4 and 5 back to fixed values after every mutation.
        # NOTE(review): what these rows encode is not visible here — confirm in lib.sde.grn3.
        # set_mutable()/compile() presumably unlock the parameters for writing and rebuild
        # the network afterwards — TODO confirm.
        self.grn.set_mutable()
        self.grn._params[4, :] = 1
        self.grn._params[5, :] = 0
        self.grn.compile()

In [8]:
def score_bb_size(bb, ref, *args, **kwargs):
    """Score a simulation's population sizes against a reference with score_both_size_norm.

    :param bb: object exposing ``.stats`` (the simulation being scored)
    :param ref: reference object exposing ``.stats`` (e.g. REF)
    Extra positional/keyword arguments (e.g. max_step, min_step) are forwarded unchanged.
    ``norm`` defaults to 2.0 but can now be overridden by the caller; previously passing
    ``norm`` raised TypeError ("got multiple values for keyword argument 'norm'").
    :return: the score (lower is better; callers use 1 / score as fitness)
    """
    kwargs.setdefault("norm", 2.0)
    return score_both_size_norm(bb.stats, ref.stats, *args, **kwargs)

In [9]:
def get_bb(prun, grn):
    """Build (without running) a Brain simulation whose cells come from the "grn3" factory.

    :param prun: run parameters; only ``prun.end_time`` is read
    :param grn: network handed to the cell factory
    :return: the configured Brain
    """
    cell_factory = factories["grn3"](grn=grn)
    return Brain(
        time_step=0.5,
        verbose=False,
        start_population=4,
        max_pop_size=1e3,
        cell_cls=cell_factory,
        end_time=prun.end_time,
        start_time=50,
        silent=True,
        run_tissue=False,
    )

In [10]:
def run_grn(prun, grn):
    """Build a simulation for `grn` with get_bb, run it, and return it.

    Previously the Brain returned by get_bb was discarded, so ``bb.run()`` raised NameError.
    """
    bb = get_bb(prun, grn)
    bb.run()
    return bb

In [11]:
def fitness_func(prun, grn, score_func):
    """Single-shot fitness: run `grn` to completion and return 1 / score.

    NOTE(review): unlike the ObjectiveStep score functions, `score_func` here receives the
    raw stats and the global REF (not prun.ref) — confirm this is the intended signature.
    """
    simulation = run_grn(prun, grn)
    score = score_func(simulation.stats, REF, max_step=prun.end_time)
    return 1.0 / score

In [12]:
def fitness_multistep(prun, grn, steps):
    """Run one simulation of `grn` and accumulate a fitness over successive time windows.

    For each step the simulation is advanced to ``step.end_time`` and scored on
    [step.start_time, step.end_time] against ``prun.ref``. The step fitness is ``1 / score``
    capped at ``step.max_fitness``. A zero score counts as the cap and a NaN score as 0.
    Previously, a zero score raised ZeroDivisionError, and a NaN score silently "passed"
    the step and poisoned the weighted selection. Evaluation stops after a step whose
    fitness is below ``step.min_fitness``, or once ``bb.run_until`` returns a falsy value.

    :return: (total_fitness, bb.stats)
    """
    total_fitness = 0
    bb = get_bb(prun, grn)
    for step in steps:
        # A falsy return from run_until ends the evaluation after this step is scored
        # (why the run stopped is not visible here — TODO confirm in model.Brain).
        keep_running = bb.run_until(step.end_time)
        score_step = step.score_func(bb, prun.ref, max_step=step.end_time, min_step=step.start_time)
        if np.isnan(score_step):
            fitness_step = 0.0
        elif score_step == 0:
            fitness_step = step.max_fitness
        else:
            fitness_step = min(1.0 / score_step, step.max_fitness)
        total_fitness += fitness_step
        if fitness_step < step.min_fitness or not keep_running:
            return total_fitness, bb.stats
        step.passed()

    return total_fitness, bb.stats

def score_multistep(prun, stats, steps):
    """Score already-computed simulation `stats` with the same step schedule as fitness_multistep.

    No simulation is run. `stats` is wrapped in a lightweight object exposing ``.stats``,
    which is what the step score functions (e.g. score_bb_size) read. The previous version
    referenced an undefined name ``bb`` and always raised NameError.

    Each step's raw score is printed. A step's fitness is ``1 / score`` capped at
    ``step.max_fitness``; a zero score counts as the cap and a NaN score as 0. Scoring
    stops after the first step whose fitness is below ``step.min_fitness``.

    :return: the accumulated fitness
    """
    from types import SimpleNamespace

    bb = SimpleNamespace(stats=stats)
    total_fitness = 0
    for i, step in enumerate(steps):
        score_step = step.score_func(bb, prun.ref, max_step=step.end_time, min_step=step.start_time)
        print(f"Score for step {i} is {score_step}")
        if np.isnan(score_step):
            fitness_step = 0.0
        elif score_step == 0:
            fitness_step = step.max_fitness
        else:
            fitness_step = min(1.0 / score_step, step.max_fitness)
        total_fitness += fitness_step
        if fitness_step < step.min_fitness:
            return total_fitness

    return total_fitness

In [13]:
def mean_sd_fitness(prun, grn, run=3, steps=None):
    """Mean and standard deviation of the multistep fitness of `grn` over `run` simulations.

    Previously fitness_multistep was called without its required `steps` argument (TypeError).
    Its (fitness, stats) tuples were also averaged as a whole; now only the fitness is used.

    :param steps: objective steps to score against; defaults to ``prun.steps``
    :return: (mean, std) of the total fitness values
    """
    if steps is None:
        steps = prun.steps
    fitnesses = [fitness_multistep(prun, grn, steps)[0] for _ in range(run)]
    return np.mean(fitnesses), np.std(fitnesses)

In [14]:
def multi_fitness(*args):
    """Run fitness_multistep three times with the same arguments and keep the worst result.

    Keeping the lowest fitness presumably guards against a solution scoring well on a
    single lucky stochastic run — TODO confirm intent.

    :return: the (fitness, stats) pair with the lowest fitness (the first one on ties)
    """
    runs = []
    for _ in range(3):
        runs.append(fitness_multistep(*args))
    worst = min(range(len(runs)), key=lambda k: runs[k][0])
    return runs[worst]

In [15]:
def get_sample_index(sample, acc):
    """Return the index of the first cumulative weight in `acc` strictly greater than `sample`.

    `acc` is expected to be non-decreasing, e.g. a running sum of non-negative weights as
    built by weighted_selection_one. A binary search therefore replaces the former linear
    ``min(filter(...))`` + ``index`` scan. If floating-point rounding leaves
    ``sample >= acc[-1]``, the last index is returned; the old code raised ValueError here.

    :raises ValueError: if `acc` is empty
    """
    import bisect

    if not acc:
        raise ValueError("get_sample_index() needs a non-empty list of cumulative weights")
    return min(bisect.bisect_right(acc, sample), len(acc) - 1)

In [16]:
def weighted_selection_one(population, fitness_values, create_func, new_fitness=1., id_=0):
    """Fitness-proportional (roulette-wheel) choice of one parent, or a brand-new individual.

    The wheel has one slice per population member, sized by its fitness. One extra slice
    of size `new_fitness` yields ``create_func(id_)`` instead of a copy.

    :param population: candidates; the chosen one is copied via ``.copy(id_)``
    :param fitness_values: non-negative weights, one per candidate, in the same order
    :param create_func: factory called as ``create_func(id_)`` for a new individual
    :param new_fitness: weight of the "create a new individual" slice
    :param id_: id given to the returned individual
    :return: (individual, index of the copied parent, or -1 for a new individual)
    """
    acc = list(accumulate(fitness_values))
    # Size the population part of the wheel by the last cumulative value, not sum().
    # sum() may round differently (Python >= 3.12 uses compensated float summation).
    # A sample landing between the two values would leave get_sample_index without a match.
    total_population_fitness = acc[-1] if acc else 0
    sample = np.random.uniform(0, total_population_fitness + new_fitness)
    if sample >= total_population_fitness:
        return create_func(id_), -1
    chosen_id = get_sample_index(sample, acc)
    return population[chosen_id].copy(id_), chosen_id

In [17]:
def do_init(prun):
    """Create the first individual of the search with a fresh id (`prun` is not used)."""
    first_id = provide_id()
    return individual_generator(first_id)

def do_fitness(prun, sol):
    """Evaluate `sol`'s network against ``prun.steps``.

    :return: (fitness, stats) of the lowest-fitness run kept by multi_fitness
    """
    result_fitness, result_stats = multi_fitness(prun, sol.grn, prun.steps)
    return result_fitness, result_stats

def do_selection(prun, pop_fit, pop):
    """Pick the next individual to evaluate; it always receives a fresh id.

    While fewer than ``prun.min_pop`` candidates are available, a brand-new individual is
    created. Otherwise weighted_selection_one either copies a candidate, weighted by
    `pop_fit`, or (with weight 0.5) creates a new one.
    """
    if len(pop) >= prun.min_pop:
        chosen, _parent_index = weighted_selection_one(
            pop, pop_fit, individual_generator, new_fitness=0.5, id_=provide_id()
        )
        return chosen
    return individual_generator(provide_id())

def do_mutation(prun, sol):
    """Mutate `sol` in place and hand back the same object (`prun` is not used)."""
    mutated = sol
    mutated.mutate()
    return mutated

In [18]:
class ObjectiveStep(O):
    # One scoring window of the multistep objective. The simulation is scored on
    # [start_time, end_time], the window's fitness is capped at max_fitness, and
    # evaluation stops when it falls below min_fitness (see fitness_multistep).
    start_time = 0
    end_time = 0
    max_fitness = 4
    min_fitness = 1
    name = ""
    _passed = False

    def reset(self):
        """Forget that this step was passed, so its message is printed again."""
        self._passed = False

    def passed(self):
        """Record that the step was passed, announcing it only the first time."""
        if not self._passed:
            print(f"Step {self.name} passed !")
            self._passed = True
    
# Nine consecutive 3-time-unit scoring windows covering t = 50..77, named "1".."9",
# all scored with score_bb_size.
example_steps = [
    ObjectiveStep(
        name=str(k + 1),
        start_time=50 + 3 * k,
        end_time=53 + 3 * k,
        score_func=score_bb_size,
    )
    for k in range(9)
]

class ParamRun(O):
    # Run-wide configuration of the evolutionary search.
    # main() performs pop_size * n_gen fitness evaluations in total.
    pop_size = 100
    n_gen = 100
    current_gen = 0  # NOTE(review): not read by any visible cell
    end_time = 83  # passed to Brain(end_time=...) in get_bb
    ref = REF  # reference used by the step score functions in fitness_multistep
    min_pop = 20  # below this many candidates, do_selection creates a fresh individual
    max_pop = 50  # selection window: only the last max_pop solutions are candidates

# Default configuration for this notebook, scored with the windows in example_steps.
args = ParamRun()
args.steps = example_steps
# Clear each step's "passed" flag before a search starts.
for step in args.steps:
    step.reset()

In [19]:
def main(prun):
    """Evolutionary search over GRNs, one evaluation per iteration; returns the best fitness.

    Performs ``prun.n_gen * prun.pop_size`` evaluations. Each evaluated Solution is stored
    in ``HISTORY[generation]`` and passed to an Exporter. The next candidate is picked by
    do_selection among the last ``prun.max_pop`` solutions, then mutated and appended.
    A line is printed every 100 evaluations and whenever the best fitness improves.
    """
    export = Exporter()
    best_fitness = 0
    current = do_init(prun)
    population = [current]
    n_evaluations = prun.n_gen * prun.pop_size

    for generation in range(n_evaluations):
        fit, run_stats = do_fitness(prun, current)
        current.fit = fit
        current.stats = run_stats

        if not generation % 100:
            print(f"Step {generation}")
        if fit > best_fitness:
            print(f"++ Best {fit} for generation {generation}")
            best_fitness = fit

        HISTORY[generation] = current
        export(current, f"generation_g{generation}")

        # Only the most recent max_pop solutions are selection candidates.
        window = population[-prun.max_pop:]
        window_fits = [member.fit for member in window]
        parent_or_new = do_selection(prun, window_fits, window)
        current = do_mutation(prun, parent_or_new)
        population.append(current)

    return best_fitness

In [None]:
# main() returns the best fitness reached (a float), not a Solution. The best Solution
# itself is recovered from HISTORY further down.
best_fitness = main(args)

Exporting at output/2022-01-29/14:11:04.030819_41618
Step 0
++ Best 0.09446645943899437 for generation 0
++ Best 0.09505919615260927 for generation 1
++ Best 0.13330481974056496 for generation 3
Step 1 passed !
Step 2 passed !
Step 3 passed !
Step 4 passed !
++ Best 0.6520484124333811 for generation 10
++ Best 1.5747953060483506 for generation 15
++ Best 1.8007324187625688 for generation 23
++ Best 2.0520752396339144 for generation 35
++ Best 2.862270944218904 for generation 38
++ Best 2.969895212631636 for generation 48
++ Best 3.148916361206117 for generation 92
Step 100
++ Best 5.7913864560012644 for generation 126


In [None]:
print(1)  # NOTE(review): looks like a leftover kernel-responsiveness check; safe to delete

In [None]:
# Deliberate stop so that "Run All" halts here. The cells below inspect HISTORY
# after main() has run. A bare `raise` only produced "No active exception to reraise".
raise RuntimeError("Stopped on purpose: run the cells below manually after main() has finished.")

In [None]:
# Recover the best-scoring Solution recorded during main() (first one on ties).
sols = list(HISTORY.values())
fits = [candidate.fit for candidate in sols]
idx = max(range(len(fits)), key=fits.__getitem__)
sol = sols[idx]

In [None]:
# Fitness of the best Solution recovered from HISTORY.
sol.fit

In [None]:
def show_curve(stats, ref, max_step=None, show=True):
    """Plot simulated vs reference population-size curves on the current axes.

    Draws four lines after aligning both tables with shrink_and_align_stats: progenitor
    population size and whole population size, each for the reference and the simulation.
    (The previous docstring described a scoring function, which this is not.)

    :param stats: simulation statistics (e.g. ``Solution.stats``)
    :param ref: reference statistics table (e.g. ``REF.stats``)
    :param max_step: forwarded to shrink_and_align_stats to truncate both tables
    :param show: call plt.show() at the end; pass False to keep drawing on the same axes
    """
    stats, ref = shrink_and_align_stats(stats, ref, max_step=max_step)
    ax = plt.gca()

    ref_prog, sim_prog = preprocess_progenitor_size(stats, ref)
    ax.plot(ref.index, ref_prog, label="Reference Prog")
    ax.plot(ref.index, sim_prog, label="Simulation Prog")

    ref_whole, sim_whole = preprocess_whole_size(stats, ref)
    ax.plot(ref.index, ref_whole, label="Reference Whole")
    ax.plot(ref.index, sim_whole, label="Simulation Whole")

    # The aligned index is assumed to be simulation time — TODO confirm.
    ax.set_xlabel("time")
    ax.set_ylabel("population size")
    ax.set_title("Simulation vs reference population size")
    ax.legend()

    if show:
        plt.show()

In [None]:
def print_fmetrics(population, ref, min_time=50, max_time=60):
    """Print the fate-correlation metric of `population` next to the reference value.

    :param population: passed to get_fmetric_pairs
    :param ref: reference fmetric table, indexed as ``ref[min_time][max_time]``
    :param min_time: start of the time window (was hard-coded to 50 in two places)
    :param max_time: end of the time window (was hard-coded to 60 in two places)
    """
    pairs = get_fmetric_pairs(population, min_time=min_time, max_time=max_time)
    fmetric = fate_corr(pairs)
    print(f"Population : {fmetric}, ref : {ref[min_time][max_time]}")

In [None]:
# Compare the selected Solution's population curves with the reference up to args.end_time.
show_curve(sol.stats, REF.stats, max_step=args.end_time)

In [None]:
# Per-step score breakdown of the selected Solution.
# A bare `stats` is not defined at notebook level here (it only exists inside functions).
score_multistep(args, sol.stats, args.steps)

In [None]:
# REF is an O wrapper; the reference table itself is REF.stats.
REF.stats.head(10)

In [None]:
def fill_df(stats, ls_val):
    """Write (time, progenitor size, whole size) tuples into `stats` in place, one row per time.

    Rows are addressed with ``stats.loc[time]``: an existing time is overwritten and a new
    time is appended as a new row.
    """
    for entry in ls_val:
        t, n_prog, n_whole = entry
        row = {"progenitor_pop_size": n_prog, "time": t, "whole_pop_size": n_whole}
        stats.loc[t] = row

In [None]:
# Hand-made (time, progenitor_pop_size, whole_pop_size) rows written by fill_df. They are
# used to see how the curves and the score react to a chosen trajectory.
# NOTE(review): from t=53.0 on, whole_pop_size (27) is below progenitor_pop_size. A whole
# population smaller than its progenitor part looks odd — confirm this is intended.
vals = [
    (50.5, 25, 27),
    (51., 26, 30),
    (51.5, 27, 40),
    (52., 30, 50),
    (52.5, 35, 60),
    (53., 40, 27),
    (53.5, 45, 27),
    (54., 50, 27),
    (54.5, 55, 27),
    (55., 60, 27),
]

In [None]:
# Build a hand-edited copy of the selected Solution's stats. fill_df writes in place, so
# copy first to keep sol.stats intact. `stats` was previously undefined at this point.
# NOTE(review): starting from sol.stats is an assumption about the intended base table.
stats = sol.stats.copy()
fill_df(stats, ls_val=vals)
stats

In [None]:
show_curve(stats, REF.stats, max_step=55)
# Disabled because `population` is not defined in any visible cell and REF.fmetric is
# commented out where REF is built. The call raised NameError, so the scoring below never ran.
# print_fmetrics(population, REF.fmetric)
score_multistep(args, stats, args.steps)

In [None]:
# Maybe use sqrt rather than abs in the denominator.
# Then balance the fit values by normalizing with the median, the mean or something else (3rd quartile?).