In [2]:
import os
import pickle
import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from model import GFR, PolynomialActivation
from evaluate import explained_variance_ratio
from data import get_data, get_train_test_data, preprocess_data

# Notebook-wide pandas display settings: never truncate the rows of a printed
# frame, and render floats with thousands separators and three decimals.
pd.set_option("display.max_rows", None)
pd.set_option("display.float_format", "{:,.3f}".format)

In [3]:
def get_dataset(params, threshold=0.6):
    """Build a per-cell dataset of flattened model parameters and labels.

    Args:
        params: mapping ``cell_id -> record``. Each record must provide
            ``"evr2"`` and a ``"params"`` dict holding the tensors ``"a"``,
            ``"b"`` and a nested ``"g"`` dict with ``"poly_coeff"``, ``"b"``,
            ``"max_current"`` and ``"max_firing_rate"``.
        threshold: cells whose ``"evr2"`` is below this value (or NaN) are
            left out.

    Returns:
        dict mapping each kept ``cell_id`` to ``(x, y, evr2)`` where ``x`` is
        the 1-D concatenation of the flattened tensors listed above (in that
        order) and ``y`` is the entry for that cell in
        ``model/labels.pickle``.

    Raises:
        KeyError: if a kept cell has no entry in the labels file.
    """
    # NOTE: pickle.load can execute arbitrary code; only read trusted files.
    with open("model/labels.pickle", "rb") as f:
        labels = pickle.load(f)

    dataset = {}
    for cell_id, record in params.items():
        evr2 = record["evr2"]
        # Written as "not >=" so that NaN scores are excluded as well.
        if not evr2 >= threshold:
            continue

        y = labels[cell_id]
        p = record["params"]
        g = p["g"]

        # The feature vector is built straight from the stored tensors; no
        # model object is needed for that, so none is constructed here.
        x = torch.cat([
            p["a"].reshape(-1),
            p["b"].reshape(-1),
            g["poly_coeff"].reshape(-1),
            g["b"].reshape(-1),
            g["max_current"].reshape(-1),
            g["max_firing_rate"].reshape(-1),
        ])

        dataset[cell_id] = (x, y, evr2)

    return dataset
    
def get_params(bin_size, activation_bin_size, C, patch_seq=False):
    """Load every pickled per-cell record saved for one configuration.

    The directory is ``model/params/{bin_size}_{activation_bin_size}_{C}/``,
    with a ``patch_seq_`` prefix on the last component when ``patch_seq`` is
    true. Only files ending in ``.pickle`` are read; the part of the file
    name before the first dot is converted to ``int`` and used as the key.

    Returns:
        dict mapping integer cell id to the unpickled object.
    """
    prefix = "patch_seq_" if patch_seq else ""
    save_path = f"model/params/{prefix}{bin_size}_{activation_bin_size}_{C}/"

    loaded = {}
    for entry_name in os.listdir(save_path):
        if not entry_name.endswith(".pickle"):
            continue
        cell_id = int(entry_name.split(".")[0])
        with open(save_path + entry_name, "rb") as handle:
            loaded[cell_id] = pickle.load(handle)
    return loaded

def get_all_params(patch_seq=False):
    """Load the saved records for every configuration in the fixed grid.

    The grid is the product of the hard-coded bin sizes, activation bin sizes
    and ``C`` values below, restricted to combinations where the activation
    bin size is at least the bin size.

    Returns:
        dict keyed by ``(bin_size, activation_bin_size, c)`` whose values are
        whatever ``get_params`` returns for that configuration.
    """
    bin_size_options = (10, 20, 50, 100)
    activation_options = (20, 100)
    c_options = (1, 0.5, 0.1, 0.05, 0.01, 0.005, 0.001, 0)

    grid = [
        (bin_size, activation_bin_size, c)
        for bin_size in bin_size_options
        for activation_bin_size in activation_options
        if bin_size <= activation_bin_size
        for c in c_options
    ]
    return {config: get_params(*config, patch_seq=patch_seq) for config in grid}

# Summary statistics for the per-cell records of a single configuration.
def summarize(params):
    """Condense one configuration's per-cell records into a few numbers.

    Args:
        params: mapping ``cell_id -> record`` where each record has
            ``"evr1"``, ``"evr2"`` and a non-empty ``"train_losses"`` list.

    Returns:
        ``{}`` when ``params`` has no cells, otherwise a dict with
        ``n_cells`` (number of cells), ``p_zero_evr`` (fraction with
        ``evr2 < 0.01``), ``p_early_stop`` (fraction with fewer than 50
        recorded training losses) and ``median_evr`` (median ``evr2`` over
        rows with ``evr1 > 0.01`` and no missing values).
    """
    cell_ids = list(params)
    columns = {
        "cell_id": cell_ids,
        "evr1": [params[c]["evr1"] for c in cell_ids],
        "evr2": [params[c]["evr2"] for c in cell_ids],
        "loss": [params[c]["train_losses"][-1] for c in cell_ids],
        "epochs": [len(params[c]["train_losses"]) for c in cell_ids],
    }
    table = pd.DataFrame(columns).set_index("cell_id").sort_values("evr2")

    if table.empty:
        return {}

    n_cells = len(table)
    usable = table[table["evr1"] > 0.01].dropna()
    n_zero_evr = int((table["evr2"] < 0.01).sum())
    n_early_stop = int((table["epochs"] < 50).sum())

    return {
        "n_cells": n_cells,
        "p_zero_evr": n_zero_evr / n_cells,
        "p_early_stop": n_early_stop / n_cells,
        "median_evr": np.median(usable["evr2"].to_numpy()),
    }

def _select_best_by_evr1(params, config_matches):
    """For each cell, keep the record with the highest "evr1" among matching configs.

    Args:
        params: mapping ``config -> {cell_id -> record}``; every record must
            have an ``"evr1"`` entry.
        config_matches: predicate called with each config key; only configs
            for which it returns true are considered.

    Returns:
        dict mapping ``cell_id`` to the winning record. Ties keep the config
        that comes first in ``params``. Cells with no usable candidate are
        omitted.
    """
    # Filter the configurations once rather than re-testing them per cell.
    candidates = [config for config in params if config_matches(config)]

    cell_ids = set()
    for config in candidates:
        cell_ids.update(params[config].keys())

    best_params = {}
    for cell_id in cell_ids:
        best_config = None
        best_evr = -1e10

        for config in candidates:
            if cell_id in params[config] and params[config][cell_id]["evr1"] > best_evr:
                best_evr = params[config][cell_id]["evr1"]
                best_config = config

        # best_config stays None when every candidate "evr1" is NaN (any
        # comparison with NaN is False) or not above the -1e10 floor; such
        # cells are left out of the result.
        if best_config is not None:
            best_params[cell_id] = params[best_config][cell_id]

    return best_params


def get_best_params_for_actv_bin_size(params, bin_size, actv_bin_size):
    """Best record per cell over configs with the given bin and activation bin size.

    Config keys are indexed positionally: ``config[0]`` is compared with
    ``bin_size`` and ``config[1]`` with ``actv_bin_size``. "Best" means the
    highest ``"evr1"``; cells whose only scores are NaN are omitted.

    Returns:
        dict mapping ``cell_id`` to the selected record.
    """
    return _select_best_by_evr1(
        params,
        lambda config: config[0] == bin_size and config[1] == actv_bin_size,
    )


def get_best_params(params, bin_size):
    """Best record per cell over every config with the given bin size.

    Only ``config[0]`` is compared with ``bin_size``; all other components of
    the config key are free. "Best" means the highest ``"evr1"``; cells whose
    only scores are NaN are omitted.

    Returns:
        dict mapping ``cell_id`` to the selected record.
    """
    return _select_best_by_evr1(params, lambda config: config[0] == bin_size)

def visualize_data(params):
    """Print headline statistics and draw diagnostic plots for one set of records.

    Args:
        params: mapping ``cell_id -> record`` where each record has
            ``"evr1"``, ``"evr2"`` and a non-empty ``"train_losses"`` list.

    Returns:
        DataFrame indexed by ``cell_id`` with columns ``evr1``, ``evr2``,
        ``loss`` (last training loss) and ``epochs`` (number of recorded
        training losses), sorted by ``evr2``. For empty ``params`` the empty
        frame is returned after printing the cell count, and nothing is
        plotted.
    """
    data = {"cell_id": [], "evr1": [], "evr2": [], "loss": [], "epochs": []}

    for cell_id, record in params.items():
        data["cell_id"].append(cell_id)
        data["evr1"].append(record["evr1"])
        data["evr2"].append(record["evr2"])
        data["loss"].append(record["train_losses"][-1])
        data["epochs"].append(len(record["train_losses"]))

    df = pd.DataFrame(data).set_index("cell_id").sort_values("evr2")

    n_cells = len(df)
    print(f"Total number of cells: {n_cells}")
    if n_cells == 0:
        # The proportions below divide by the cell count and there is nothing
        # to plot, so stop here instead of raising ZeroDivisionError.
        return df

    n_nonpositive = int((df["evr2"] <= 0).sum())
    n_early_stop = int((df["epochs"] < 50).sum())
    print(f"Number/proportion of cells with evr<=0: {n_nonpositive}/{n_nonpositive / n_cells}")
    print(f"Number/proportion of cells with epochs<50: {n_early_stop}/{n_early_stop / n_cells}")

    # Rows without an evr2 value are excluded from the median and the plots.
    df_corrected = df[df["evr2"].notna()]
    print(f"Median evr: {np.median(df_corrected.dropna()['evr2'].values)}")

    # Select by column name rather than position so a column reorder cannot
    # silently swap what gets plotted.
    evrs1 = df_corrected["evr1"]
    evrs2 = df_corrected["evr2"]
    losses = df_corrected["loss"]

    plt.figure()
    plt.hist(evrs2, bins="auto")
    plt.xlabel("evr2")
    plt.ylabel("counts")
    plt.title("evr2 histogram (failed optimizations removed)")

    plt.figure()
    plt.hist(losses, bins="auto")
    plt.xlabel("loss")
    plt.ylabel("counts")
    plt.title("loss histogram (failed optimizations removed)")

    plt.figure()
    plt.scatter(evrs2, losses, alpha=0.5)
    plt.xlabel("evr2")
    plt.ylabel("loss")
    plt.title("evr2 vs loss scatter plot (failed optimizations removed)")

    plt.figure()
    plt.scatter(evrs1, evrs2, alpha=0.5)
    plt.xlabel("evr1")
    plt.ylabel("evr2")
    plt.title("evr1 vs evr2 scatter plot (failed optimizations removed)")

    return df

def summarize_all_models(all_params):
    """Print a nested summary of every configuration plus the per-cell best picks.

    Three groups of entries are produced, in this order: one per
    configuration in ``all_params``; one per (bin size, activation bin size)
    pair from the fixed grids where the bin size does not exceed the
    activation bin size; and one per bin size. Each entry is the result of
    ``summarize`` and the whole mapping is rendered with ``nice_dict``.
    Nothing is returned.
    """
    bin_size_grid = [10, 20, 50, 100]
    actv_bin_size_grid = [20, 100]

    summ = {
        f"bin_size={config[0]}, activation_bin_size={config[1]}, C={config[2]}": summarize(records)
        for config, records in all_params.items()
    }

    # The loop variable names are part of the printed labels ("{name=}").
    for bin_size in bin_size_grid:
        for actv_bin_size in actv_bin_size_grid:
            if actv_bin_size < bin_size:
                continue
            picked = get_best_params_for_actv_bin_size(all_params, bin_size, actv_bin_size)
            summ[f"Best for {bin_size=}, {actv_bin_size=}:"] = summarize(picked)

    for bin_size in bin_size_grid:
        picked = get_best_params(all_params, bin_size)
        summ[f"Best for {bin_size=}:"] = summarize(picked)

    print(nice_dict(summ))
    
def nice_dict(d, indent=0):
    """Render a (possibly nested) dict as indented ``key: value`` lines.

    Args:
        d: mapping to render. Dict values are rendered recursively on the
            lines after their key; float values are shown with four
            decimals; everything else uses its default string form.
        indent: nesting depth; every line produced at this level is prefixed
            with ``"| "`` repeated ``indent`` times.

    Returns:
        The rendered text, lines joined by newlines (``""`` for an empty
        dict).
    """
    lines = []
    for key, value in d.items():
        if isinstance(value, dict):
            lines.append(f"{key}\n{nice_dict(value, indent=indent + 1)}")
        # isinstance rather than an exact type check so that float subclasses
        # such as numpy.float64 (what np.median returns) get four decimals too.
        elif isinstance(value, float):
            lines.append(f"{key}: {value:.4f}")
        else:
            lines.append(f"{key}: {value}")

    prefix = "| " * indent
    return "\n".join(prefix + line for line in lines)

def save_best_params(all_params):
    """Pickle the per-cell best records for every valid (bin, activation bin) pair.

    Pairs come from the fixed grids below and are kept only when the bin size
    does not exceed the activation bin size. The resulting dict, keyed by
    ``(bin_size, actv_bin_size)``, is written to ``model/best_params.pickle``
    with the highest pickle protocol.
    """
    pairs = [
        (bin_size, actv_bin_size)
        for bin_size in (10, 20, 50, 100)
        for actv_bin_size in (20, 100)
        if bin_size <= actv_bin_size
    ]
    best_params = {
        pair: get_best_params_for_actv_bin_size(all_params, pair[0], pair[1])
        for pair in pairs
    }
    with open("model/best_params.pickle", "wb") as out_file:
        pickle.dump(best_params, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [22]:
# Metadata table for the cells; get_line_name below reads its "specimen__id"
# and "line_name" columns.
df = pd.read_csv("data/metadata.csv")

def get_line_name(df, cell_id):
    """Return the ``line_name`` of the first row whose ``specimen__id`` equals ``cell_id``.

    Raises:
        IndexError: if no row matches.
    """
    matching_names = df.loc[df["specimen__id"] == cell_id, "line_name"]
    return matching_names.to_numpy()[0]

In [4]:
# Load previously saved results. Despite the name, the object is indexed
# below with a (bin_size, actv_bin_size) pair, i.e. it appears to hold
# per-pair best records rather than the full configuration grid.
# NOTE(review): this file name differs from the one save_best_params writes
# ("model/best_params.pickle") — confirm how it was produced.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open("model/best_params_99_full.pickle", "rb") as f:
    all_params = pickle.load(f)

In [5]:
# Choose which (bin_size, actv_bin_size) entry of the loaded results to
# inspect in the cells that follow.
bin_size = 20
actv_bin_size = 100
params = all_params[(bin_size, actv_bin_size)]

In [None]:
# Walk every cell of the selected configuration and pull out its stored
# parameters, line name, scores and final losses.
# NOTE(review): each iteration overwrites these names and nothing is stored,
# printed or plotted, so only the last cell's values survive the loop. The
# cell also has no execution count — this looks unfinished; confirm intent.
for cell_id in params:
    p = params[cell_id]["params"]
    cell_type = get_line_name(df, cell_id)
    train_evr = params[cell_id]["evr1"]
    val_evr = params[cell_id]["evr2"]
    train_mse = params[cell_id]["train_losses"][-1]
    # assumes every record also carries a non-empty "test_losses" list — TODO confirm
    test_mse = params[cell_id]["test_losses"][-1]

In [8]:
# Use the first cell id (in the dict's iteration order) as a worked example;
# raises IndexError if params is empty.
cell_id = list(params.keys())[0]

In [9]:
# Fetch the example cell's stored parameter dict and hand it to
# GFR.from_params (presumably rebuilds the fitted model — verify in model.py).
p = params[cell_id]["params"]
model = GFR.from_params(p)

In [21]:
# Display the complete stored record for the example cell (the params dict,
# evr1, evr2 and the full train_losses list, as the output below shows).
params[cell_id]

{'params': {'a': tensor([[ 8.0004e+00,  1.3346e-03,  1.0409e-03,  7.7979e-04, -2.9032e-05,
           -2.8466e-03, -6.2044e-03, -9.9217e-03]]),
  'b': tensor([[-0.0322, -0.0330, -0.0339, -0.0345, -0.0348, -0.0343, -0.0335, -0.0328]]),
  'g': {'max_current': Parameter containing:
   tensor([169.9915]),
   'max_firing_rate': Parameter containing:
   tensor([0.0815], dtype=torch.float64),
   'poly_coeff': tensor([[-2.8457e-01,  9.0654e-01,  4.1583e-23,  1.0626e-15]]),
   'b': tensor([48.4476]),
   'bin_size': 100},
  'ds': tensor([1.0000, 0.6321, 0.3297, 0.1813, 0.0952, 0.0392, 0.0198, 0.0100]),
  'bin_size': 20},
 'evr1': 0.2126748905506103,
 'evr2': 0.30039277317421137,
 'train_losses': [0.12678318073164233,
  0.1267899627169833,
  0.12680980089374244,
  0.12683317489626902,
  0.12685919281129168,
  0.12688775898759214,
  0.1269187919423955,
  0.1269525843526269,
  0.1269891026464727,
  0.12702822716609485,
  0.1270700218981745,
  0.12711512791453647,
  0.12716348770805402,
  0.12721473