In [49]:
import os
import re
import pandas as pd
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from pathlib import Path as P
from biodata import ratio_eomes, sorted_number_cells

In [103]:
def ModelStats(f):
    """Load the CSV at ``f`` into a DataFrame with every missing value replaced by 0.

    ``f`` is forwarded unchanged to ``pandas.read_csv``.
    """
    return pd.read_csv(f).fillna(0)

In [92]:
class RefMetric:
    """Reference curves that the simulated metrics are plotted against.

    The data is read from the module-level ``ratio_eomes`` and
    ``sorted_number_cells`` objects imported from ``biodata``.
    """

    def ip_ratio(self):
        """Return ``(x, y)``: the index of ``ratio_eomes`` and its "val" column / 100."""
        # Per the original note the values are stored in percent and there is
        # no explicit x column, so the index is used as x.
        xs = list(ratio_eomes.index)
        ys = ratio_eomes["val"] / 100
        return xs, ys

    def number_progenitor(self):
        """Return ``sorted_number_cells`` unchanged."""
        # Per the original note this already holds both x and y (normalized).
        return sorted_number_cells

    def number_neuron(self):
        """Return an empty ``(x, y)`` pair: there is no reference for this metric."""
        return [], []

    def get(self, cat):
        """Return the reference curve for ``cat``.

        ``cat`` must be "neuron", "progenitor" or "ip_ratio"; any other value
        raises ``KeyError``.
        """
        dispatch = {
            "ip_ratio": self.ip_ratio,
            "progenitor": self.number_progenitor,
            "neuron": self.number_neuron,
        }
        handler = dispatch[cat]
        return handler()

    def corr_progeny(self):
        """Placeholder: not implemented yet, returns ``None``."""
        return None
    

In [21]:
class ModelMetric:
    """Metrics computed from the statistics table of one simulation run."""

    def __init__(self, x):
        """Keep ``x`` as the stats table, loading it with ``ModelStats`` if it is a path.

        ``x`` is treated as a path when it is a ``str`` or ``pathlib.Path``;
        anything else is stored as-is.
        """
        is_path = isinstance(x, (str, pathlib.Path))
        self.stats = ModelStats(x) if is_path else x

    def ip_ratio(self):
        """Return ``size_type_IP`` divided by the sum of all progenitor type sizes.

        The denominator is RG + IP, plus GP when a ``size_type_GP`` column exists.
        """
        stats = self.stats
        if "size_type_GP" in stats.columns:
            non_ip = stats.size_type_RG + stats.size_type_GP
        else:
            non_ip = stats.size_type_RG

        ip = stats.size_type_IP
        return ip / (non_ip + ip)

    def number_progenitor(self):
        """Return ``progenitor_pop_size`` divided by the first ``whole_pop_size`` value."""
        baseline = self.stats.whole_pop_size.iloc[0]
        progenitors = self.stats.progenitor_pop_size
        return progenitors / baseline

    def number_neuron(self):
        """Return ``whole_pop_size`` divided by its own first value."""
        whole = self.stats.whole_pop_size
        baseline = whole.iloc[0]
        return whole / baseline

    def get(self, cat):
        """Return the metric for ``cat``.

        ``cat`` must be "neuron", "progenitor" or "ip_ratio"; any other value
        raises ``KeyError``.
        """
        dispatch = {
            "ip_ratio": self.ip_ratio,
            "progenitor": self.number_progenitor,
            "neuron": self.number_neuron,
        }
        handler = dispatch[cat]
        return handler()

    def get_time(self):
        """Return the ``time`` column of the stats table."""
        stats = self.stats
        return stats.time

In [55]:
class MultiModel:
    """Aggregate the same metric over several runs into a mean and a spread."""

    def __init__(self, ls):
        """Wrap every entry of ``ls`` in a ``ModelMetric``."""
        self.model_metrics = list(map(ModelMetric, ls))

    def get(self, cat):
        """Return ``(mean, sd)`` of metric ``cat`` computed across all runs."""
        per_run = []
        for metric in self.model_metrics:
            per_run.append(metric.get(cat))
        return self.gather_samples(per_run)

    def get_time(self):
        """Return the time axis of the first run."""
        first_run = self.model_metrics[0]
        return first_run.get_time()

    def gather_samples(self, samples):
        """Return ``(mean, sd)`` of ``samples`` taken along axis 0."""
        mean = np.mean(samples, axis=0)
        sd = np.std(samples, axis=0)
        return mean, sd

In [67]:
def pick_files(var, ctrl=None, root="output/results"):
    """List the CSV result files in ``root`` whose name contains ``var`` (or ``ctrl``).

    Parameters
    ----------
    var : str
        Substring a file name must contain to be selected.
    ctrl : str, optional
        Second substring; files containing it are selected as well.
    root : str or path-like, optional
        Directory to scan. Defaults to "output/results".

    Returns
    -------
    list of str
        Paths (``root`` joined with the file name), sorted by file name.

    Raises
    ------
    FileNotFoundError
        If ``root`` does not exist (raised by ``os.listdir``).
    """
    root = P(root)
    picked = []
    # os.listdir returns entries in arbitrary order; sort so that the result
    # (and anything that depends on "the first file") is reproducible.
    for name in sorted(os.listdir(root)):
        # Substring test on purpose: also keeps names such as "x.csv.gz".
        if ".csv" not in name:
            continue
        if var in name or (ctrl is not None and ctrl in name):
            picked.append(str(root / name))

    return picked

def make_viz(curve, legend=None):
    """Draw ``curve`` on the current matplotlib axes.

    ``curve`` is a ``(time, y)`` pair. When ``y`` is a tuple it is unpacked as
    ``(mean, sd)``: the mean is drawn as a line and ``mean ± sd`` as a shaded
    band. Otherwise ``y`` is drawn as a plain line. ``legend`` becomes the
    line's label.
    """
    xs, ys = curve

    # Plain curve: a single line, nothing else to draw.
    if not isinstance(ys, tuple):
        plt.plot(xs, ys, label=legend)
        return

    center, spread = ys
    plt.plot(xs, center, label=legend)
    lower = center - spread
    upper = center + spread
    plt.fill_between(xs, lower, upper, alpha=0.3)

In [85]:
# draw image with multiple params

def extract_param(filels, param_name):
    """Collect the distinct values of ``param_name`` encoded in the file names.

    A value is the text between ``"<param_name>_"`` and the next ``"_"``
    (first occurrence in each name). ``param_name`` is matched literally.

    Parameters
    ----------
    filels : iterable of str
        File names or paths to inspect.
    param_name : str
        Name of the parameter to look for.

    Returns
    -------
    list of str
        The distinct values, sorted as strings. If any entry carries no such
        token, a notice is printed and ``["default"]`` is returned so callers
        treat all files as a single group. An empty ``filels`` gives ``[]``.
    """
    # Escape the name so characters such as "." are not read as regex syntax.
    pattern = re.compile(re.escape(param_name) + r"_([^_]+)_")

    values = set()
    for name in filels:
        found = pattern.search(name)
        if found is None:
            # Only the "no token in this name" case falls back; genuine
            # errors (e.g. a non-string entry) are no longer swallowed.
            print("Exception occured")
            return ["default"]
        values.add(found.group(1))

    return sorted(values)
        

def filter_by_param(filels, param):
    """Return the entries of ``filels`` that contain ``param`` as a substring.

    Input order is preserved. The match is a plain substring test, so a short
    ``param`` also matches longer tokens that contain it.
    """
    kept = []
    for path in filels:
        if param in path:
            kept.append(path)
    return kept

In [117]:
def line_viz(filels, cat, legend=None):
    """Plot metric ``cat`` for the runs in ``filels`` on the current axes.

    More than one file is aggregated through ``MultiModel``; otherwise the
    first entry is used as a single ``ModelMetric``. ``legend`` is passed on
    to ``make_viz`` as the curve label.
    """
    mod = MultiModel(filels) if len(filels) > 1 else ModelMetric(filels[0])
    curve = mod.get(cat)
    make_viz((mod.get_time(), curve), legend=legend)
        
def full_viz(var, cat, fname=None):
    """Plot metric ``cat`` for variate ``var`` and save it as a PNG under ``output/``.

    One curve is drawn per distinct value of ``var`` found in the result file
    names (or a single curve when there is at most one value), followed by
    the control runs and the reference curve.

    Parameters
    ----------
    var : str
        Variate name; selects the result files and is the token searched for
        in their names.
    cat : str
        Metric category passed to ``line_viz`` and ``RefMetric.get``.
    fname : str, optional
        Suffix used in the title and output file name. Defaults to ``var``.

    The figure is written to ``output/<cat>_<fname>.png``. It is closed even
    when plotting fails, so a failing variate does not leave figures open.
    """
    if fname is None:
        fname = var
    root = P("output/")

    fig = plt.figure(figsize=(12, 8))
    try:
        plt.title(f"{cat}_{fname}")

        # load files for modality
        filels = pick_files(var)

        # split params
        params = extract_param(filels, var)

        # plot for * params
        if len(params) <= 1:
            line_viz(filels, cat, legend=cat)
        else:
            for p in params:
                # Select on the full "<var>_<value>_" token the value was
                # extracted from. Filtering on the bare value would also pick
                # up other groups (e.g. value "1" is a substring of
                # "b1_10_...") and mix their runs into this curve.
                ls = filter_by_param(filels, f"{var}_{p}_")
                line_viz(ls, cat, legend=f"{cat}_{p}")

        ctrl_ls = pick_files("ctrl")
        line_viz(ctrl_ls, cat, legend="ctrl")

        # plot ref
        plt.plot(*RefMetric().get(cat))

        # export
        plt.legend()
        plt.savefig(str(root / f"{cat}_{fname}.png"))
    finally:
        # Always release the figure, including on error.
        plt.close(fig)

In [118]:
def build_viz():
    """Call ``full_viz`` for every (variate, category) pair, variates outermost."""
    cats = ("neuron", "progenitor", "ip_ratio")
    variates = (
        "gpasip",
        "smooth",
        "startval",
        "b1",
        "b2",
        "b3",
        "b4",
        "b5",
    )

    jobs = [(var, cat) for var in variates for cat in cats]
    for var, cat in jobs:
        full_viz(var, cat)

In [119]:
# Render every figure: one full_viz call (and one saved PNG) per variate/category pair.
build_viz()

Exception occured
Exception occured
Exception occured
