In [None]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".13"

In [2]:
import pandas as pd
from brain import BrainModel
from submodels import factories
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
from random import shuffle

from lib.score import (
    fate_corr, score_both_size_norm, shrink_and_align_stats, score_stats_norm
)
from lib.preprocess import *
from lib.callback import (
    cell_number_callback, progenitor_number_callback, neuron_number_callback,
    TargetPopulation, TagNumberCallback,
)
from lib.sde.grn.grn5 import GRNMain5 as GRNMain
from lib.sde.mutate import mutate_grn5 as mutate_grn

from lib.ga.utils import weighted_selection_one, normalize_fitness_values
from jf.utils.export import Exporter
from jf.autocompute.jf import O, L
from itertools import product
import jf.models.stringmodel as sm
from lib.analyser import show_curve, show_curve_progenitor
import random
from lib.sde.mutate import mutate_gene4, mutate_tree_2

In [3]:
HISTORY = defaultdict(dict)
HALL_OF_FAME = []

In [4]:
_count = -1
def provide_id():
    global _count
    _count += 1
    return _count

In [5]:
REF = O(
    stats=pd.read_csv("reference/ref_tristate2.csv"),  # ref is a mean
)

In [6]:
def individual_generator(id_=-1, cb_init=None, nb_genes=4):
    return Solution(GRNMain(nb_genes, 0, 1, generate_funcs=cb_init), id_=id_)

In [7]:
_MUTATE_FUNC = mutate_grn

In [8]:
class Solution:
    def __init__(self, grn, id_=0, parent=-1):
        self.id = id_
        self.grn = grn
        self.parent = parent
        self.fit = -1
        self.stats = None
        
    def copy(self, id_=0):
        return Solution(self.grn.copy(), id_=id_, parent=self.id)
        
    def mutate(self):
        global _MUTATE_FUNC
        _MUTATE_FUNC(self.grn)

In [9]:
def score_bb_size(bb, ref, *args, **kwargs):
    stats = bb.stats.copy()
    last_time = max(stats.time)
    stats = stats.set_index("time")
    prog = stats.loc[last_time]["progenitor_pop_size"]
    neuron = stats.loc[last_time]["neuron_pop_size"]
    # print(prog, neuron, neuron - prog**1.8)

    return 1 / max(1, 1000 + (neuron - prog**1.5))

In [10]:
def setup_tag(cp):
    indexes = list(cp.base_population.keys())
    shuffle(indexes)
    splits = np.array_split(indexes, 3)
    for i, ls in enumerate(splits):
        for idx in ls:
            cp.base_population[idx].tag["subbrain"] = i

In [11]:
def get_bb(prun, grn):
    ccls = factories["grn5"](grn=grn)
    callbacks = dict(
        progenitor_pop_size=progenitor_number_callback,
        whole_pop_size=cell_number_callback,
        neuron_pop_size=neuron_number_callback,
    )
    bb = BrainModel(time_step=0.5, verbose=False, start_population=prun.size, max_pop_size=3e2,
            cell_cls=ccls, end_time=prun.end_time, start_time=56, silent=True, opti=True,
              run_tissue=True, monitor_callbacks=callbacks, tag_func=setup_tag)
    return bb

In [12]:
def run_grn(prun, grn):
    get_bb(prun, grn)
    bb.run()
    return bb

In [13]:
def fitness_multistep(prun, grn, steps):
    total_fitness = 0
    stop = False
    previous_time = None
    bb = get_bb(prun, grn)
    # first step
    for step in steps:
        if not bb.run_until(step.end_time):
            stop = True
        # score_step = score_both_size(bb.stats, prun.ref, max_step=step.end_time, min_step=previous_time)
        score_step = step.score_func(bb, prun.ref, max_step=step.end_time, min_step=step.start_time)
        fitness_step = 1.0 / score_step
        fitness_step = min(fitness_step, step.max_fitness)
        total_fitness += fitness_step
        if fitness_step < step.min_fitness or stop:
            return total_fitness, bb.stats
        else:
            previous_time = step.end_time
            step.passed()
        
    return total_fitness, bb.stats

In [14]:
def mean_sd_fitness(prun, grn, run=3):
    fitnesses = [fitness_multistep(prun, grn) for i in range(run)]
    return np.mean(fitnesses), np.std(fitnesses)

In [15]:
def multi_fitness(*args):
    fitnesses = [fitness_multistep(*args) for i in range(3)]
    scores = [x[0] for x in fitnesses]
    return fitnesses[scores.index(min(scores))]

In [16]:
def do_init(prun):
    return individual_generator(provide_id(), prun.cb_init)

def do_fitness(prun, sol):
    fitness, stats = fitness_multistep(prun, sol.grn, prun.steps)
    return fitness, stats

def do_selection(prun, pop_fit, pop):
    if len(pop) < prun.min_pop:
        return individual_generator(provide_id(), prun.cb_init)
    
    pop_fit = normalize_fitness_values(pop_fit)
    
    return weighted_selection_one(pop, pop_fit, lambda x: individual_generator(x, prun.cb_init),
                                  new_fitness=0.5, id_=provide_id())[0]

def do_mutation(prun, sol):
    sol.mutate()
    return sol

In [17]:
class ObjectiveStep(O):
    start_time = 0
    end_time = 0
    max_fitness = 1e6
    min_fitness = 1
    name = ""
    _passed = False
    
    def reset(self):
        self._passed = False
    
    def passed(self):
        if self._passed:
            return
        print(f"Step {self.name} passed !")
        self._passed = True
    
example_steps = [
    ObjectiveStep(name="The ONE", start_time=56, end_time=86, score_func=score_bb_size, min_fitness=0.2),
]

class ParamRun(O):
    pop_size = 10
    n_gen = 20
    current_gen = 0
    end_time = 86
    ref = REF
    min_pop = 25
    max_pop = 50

def get_prun(size=5, exponent=1):
    prun = ParamRun()
    prun.cb_init = dict()
    prun.size = size
    prun.exponent = exponent
    prun.steps = example_steps
    return prun

In [18]:
def main(prun):
    prun.history = defaultdict(dict)
    # exporter = Exporter()
    best = 0
    sol = do_init(prun)
    pop = [sol]
    for generation in range(prun.n_gen * prun.pop_size):
        # args.generation = generation
        # objective.new_trial()
        fit, stats = do_fitness(prun, sol)
        sol.fit = fit
        sol.stats = stats
        
        # history
        # print(f"Fitness = {fit}", end="\t\t")
        if generation % 100 == 0:
            print(f"Step {generation}")
        if fit > best:
            print(f"++ Best {fit} for generation {generation}")
            best = fit
            
        monitor = sol
        prun.history[generation] = monitor
        # exporter(monitor, f"generation_g{generation}")
        
        # TODO get the stats associated with the best scores
        sub_pop = pop[-prun.max_pop:]
        sol = do_selection(prun, [s.fit for s in sub_pop], sub_pop)
            
        sol = do_mutation(prun, sol)
        pop.append(sol)
        
    print("OVER")
    return best

In [19]:
prun = get_prun()
# res = main(prun)

In [20]:
def get_score_parametrized(cb_init, name=None):
    args = get_prun()
    args.cb_init = cb_init
    args.name = name
    main(args)
    p1 = L(args.history.values()).fit
    idx = p1.index(max(p1))
    sol = args.history[idx]
    scores = [fitness_multistep(args, sol.grn, args.steps)[0] for i in range(10)]
    print("Final score", np.mean(scores), "+-", np.std(scores))
    return np.mean(scores), np.std(scores), max(p1)

In [21]:
def mutate_grn_verysparse(grn, temperature=0.1):
    grn.set_mutable()
    shape = grn._params.shape
    r = random.random()
    param_prob = 0.8
    if r < param_prob:
        mask = (np.random.uniform(0, 1, shape) > 0.9)
        coeff = np.random.normal(0, temperature, shape)
        true_coeff = mask * coeff + 1
        grn._params *= true_coeff
    else:
        one_gene = random.choice(grn.genes)
        one_gene.tree = mutate_tree_2(one_gene.tree, one_gene.get_labels_not_in_tree())
    grn.compile()

In [22]:
model = sm.StringModel("expfit_func{mutfunc}_i{i}")
mut_func = dict(
    fullctrl=lambda x: mutate_grn_verysparse(x),
    full005=lambda x: mutate_grn_verysparse(x, temperature=0.05),
    full02=lambda x: mutate_grn_verysparse(x, temperature=0.2),
)

In [23]:
callback_init = dict(
    init=lambda: np.random.beta(1.5, 3) * 3,
    b=lambda: np.random.beta(1.5, 3) * 5,
    expr=lambda: 1,
    deg=lambda: 0.1,
)

In [24]:
# main loop
exporter = Exporter(name="exp_mutfunc8_280222")
for (funcname, func), i in product(mut_func.items(), range(10)):
    name = model.fill(mutfunc=funcname, i=i)
    print(name)
    if name in exporter.list():
        continue
    _MUTATE_FUNC = func
    res = get_score_parametrized(callback_init)
    exporter(res, name)

Exporting at output/exp_mutfunc8_280222
expfit_funcfullctrl_i0
Step 0
++ Best 1.0 for generation 0
++ Best 1029.0 for generation 1
Step The ONE passed !
++ Best 1041.0 for generation 33
++ Best 1226.180920072727 for generation 34
++ Best 1299.430780618347 for generation 60
++ Best 1679.9377863789657 for generation 79
Step 100
++ Best 1688.0177871865296 for generation 121
++ Best 1716.8302205930995 for generation 126
++ Best 1756.68323274845 for generation 133
++ Best 1769.9072043644996 for generation 180
OVER
Final score 1610.632097525448 +- 71.66948936602877
expfit_funcfullctrl_i1
Step 0
++ Best 844.8302205930993 for generation 0
++ Best 971.4307806183468 for generation 4
++ Best 1024.1715728752538 for generation 5
++ Best 1038.0 for generation 22
Step 100
OVER
Final score 1011.1073632772119 +- 12.764143149619951
expfit_funcfullctrl_i2
Step 0
++ Best 1.0 for generation 0
++ Best 795.752267887179 for generation 2
++ Best 995.3772233983163 for generation 3
++ Best 1028.3030615433008 for

In [25]:
dicts = [dict(**model.extract(x), **dict(zip(("mean", "sd", "max"), exporter.load(x)))) for x in exporter.list() if model.match(x)]

In [26]:
df = pd.DataFrame(dicts)

In [27]:
df.head()

Unnamed: 0,mutfunc,i,mean,sd,max
0,full02,0,1367.528493,39.211868,1493.430781
1,full005,7,1264.568294,43.307899,1335.430781
2,full02,9,1031.607566,13.936723,1044.0
3,full005,3,1199.115282,108.355397,1319.424492
4,fullctrl,7,1020.113881,5.143318,1048.479741


In [28]:
new_df = df.groupby(["mutfunc"]).mean()
new_df

Unnamed: 0_level_0,mean,sd,max
mutfunc,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
full005,1186.883243,98.268458,1323.538007
full02,1110.681596,209.391858,1434.075282
fullctrl,1182.617017,44.817931,1274.237468


In [29]:
for i, row in new_df.iterrows():
    print(row.name, row["mean"], row["mean"] - row.sd / np.sqrt(10), row["mean"] + row.sd / np.sqrt(10))

full005 1186.8832432004149 1155.8080284115326 1217.9584579892971
full02 1110.681595732502 1044.4660762446963 1176.8971152203078
fullctrl 1182.617016673857 1168.444342628164 1196.7896907195502


In [30]:
row.name

'fullctrl'