In [4]:
import os
import pickle
import json
import utils
import numpy as np
import pandas as pd

from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from data import get_data, preprocess_data
from config import config

# Read the serialized dataset and convert it with the project helper.
with open("model/gfr_dataset.json", "r") as f:
    json_dataset = json.load(f)
dataset = utils.df_from_json(json_dataset)

# Gather every distinct value found in the "cell_id" column of each entry of
# `dataset`, then expose the result as a list (set order is unspecified).
unique_cell_ids = set()
for key in dataset.keys():
    unique_cell_ids.update(dataset[key]["cell_id"].to_list())
cell_ids = list(unique_cell_ids)

In [None]:
# Persist the collected ids as a single comma-separated line.
with open("cell_ids.csv", "w") as f:
    f.write(",".join(str(cell_id) for cell_id in cell_ids))

In [5]:
# -----------------------------------------------------------------------------
# Original Polynomial Activation
# -----------------------------------------------------------------------------
class PolynomialActivation(torch.nn.Module):
    """Saturating polynomial activation.

    The input is shifted by the learnable offset ``b`` and divided by
    ``max_current``. A polynomial of order ``degree`` whose coefficients are
    the squares of ``poly_coeff`` (hence non-negative) is evaluated on the
    result, passed through ``tanh``, scaled by ``max_firing_rate`` and clipped
    at zero.
    """

    def __init__(self, degree, max_current, max_firing_rate, bin_size):
        super().__init__()
        self.degree = degree

        # Constants are kept as frozen (non-trainable) parameters.
        current_scale = torch.tensor(max_current).reshape(1)
        self.max_current = torch.nn.Parameter(current_scale, requires_grad=False)
        rate_scale = torch.tensor(max_firing_rate).reshape(1)
        self.max_firing_rate = torch.nn.Parameter(rate_scale, requires_grad=False)
        self.bin_size = bin_size

        # Exponents 0..degree, one per polynomial term.
        exponents = torch.tensor(list(range(degree + 1)))
        self.p = torch.nn.Parameter(exponents, requires_grad=False)

        # Trainable quantities: raw coefficients (squared in forward) and the offset.
        self.poly_coeff = torch.nn.Parameter(torch.randn(1, self.degree + 1))
        self.b = torch.nn.Parameter(torch.tensor([0.0]))

    def forward(self, z):
        """Evaluate the activation on ``z``.

        The einsum contracts over the exponent axis and requires the second
        axis of ``z`` to match the first axis of ``poly_coeff`` (size 1).
        """
        scaled = (z - self.b) / self.max_current
        monomials = scaled.unsqueeze(dim=2).pow(self.p.reshape(1, 1, -1))
        nonneg_coeff = self.poly_coeff ** 2
        poly = torch.einsum("ijk,jk->ij", monomials, nonneg_coeff)
        saturated = self.max_firing_rate * torch.tanh(poly)
        return F.relu(saturated).to(torch.float32)

    def get_params(self):
        """Return the module's configuration and learned values as plain Python objects."""
        coeff_list = self.poly_coeff.detach().cpu().tolist()
        offset_list = self.b.detach().cpu().tolist()
        return dict(
            type="polynomial",
            degree=self.degree,
            max_current=self.max_current.item(),
            max_firing_rate=self.max_firing_rate.item(),
            poly_coeff=coeff_list,
            b=offset_list,
            bin_size=self.bin_size,
        )

    @classmethod
    def from_params(cls, params):
        """Rebuild a module from a dict produced by ``get_params``."""
        module = cls(
            params["degree"],
            params["max_current"],
            params["max_firing_rate"],
            params["bin_size"],
        )
        # Overwrite the random initialisation with the stored values.
        module.poly_coeff = torch.nn.Parameter(torch.tensor(params["poly_coeff"]))
        module.b = torch.nn.Parameter(torch.tensor(params["b"]))
        return module

# -----------------------------------------------------------------------------
# Benchmark Activation Modules
# -----------------------------------------------------------------------------
class ReLUActivation(torch.nn.Module):
    """Rectified-linear benchmark activation: ``relu((a * z - b) / max_current)``.

    ``a`` and ``b`` are trainable scalars. ``max_current`` is a fixed
    normalisation constant and ``bin_size`` is stored on the module so the
    fitting code can read it.
    """

    def __init__(self, bin_size, max_current):
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(1))
        self.b = torch.nn.Parameter(torch.zeros(1))
        self.bin_size = bin_size
        self.max_current = max_current

    def forward(self, z):
        """Apply the affine map, normalise by ``max_current`` and rectify."""
        x = (self.a * z - self.b) / self.max_current
        return F.relu(x)

    def get_params(self):
        """Return everything needed to rebuild the module as plain Python values."""
        return {
            "type": "relu",
            "a": self.a.detach().item(),
            "b": self.b.detach().item(),
            "bin_size": self.bin_size,
            # Stored so that from_params can reconstruct the module; forward
            # depends on it, so a dict without it cannot reproduce the fit.
            "max_current": None if self.max_current is None else float(self.max_current),
        }

    @classmethod
    def from_params(cls, params):
        """Rebuild a module from a dict produced by ``get_params``.

        Raises:
            ValueError: if ``params`` has no ``"max_current"`` entry (dicts
                saved before that key was added cannot be reconstructed).
        """
        if "max_current" not in params:
            raise ValueError(
                "ReLUActivation.from_params requires 'max_current' in params"
            )
        g = cls(params["bin_size"], params["max_current"])
        # reshape(1) keeps the same parameter shape as __init__ creates.
        g.a = torch.nn.Parameter(torch.tensor(params["a"], dtype=torch.float32).reshape(1))
        g.b = torch.nn.Parameter(torch.tensor(params["b"], dtype=torch.float32).reshape(1))
        return g

class SigmoidActivation(torch.nn.Module):
    """Sigmoid benchmark activation: ``c * sigmoid((a * z - b) / max_current)``.

    ``a`` and ``b`` are trainable scalars. ``c`` is a frozen scale initialised
    from ``max_firing_rate``; ``max_current`` is a fixed normalisation constant
    and ``bin_size`` is stored on the module so the fitting code can read it.
    """

    def __init__(self, bin_size, max_firing_rate, max_current):
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(1))
        self.b = torch.nn.Parameter(torch.zeros(1))
        self.c = torch.nn.Parameter(torch.tensor(max_firing_rate), requires_grad=False)  # scale
        self.bin_size = bin_size
        self.max_current = max_current

    def forward(self, z):
        """Apply the affine map, normalise by ``max_current`` and squash to ``(0, c)``."""
        x = (self.a * z - self.b) / self.max_current
        return self.c * torch.sigmoid(x)

    def get_params(self):
        """Return everything needed to rebuild the module as plain Python values."""
        return {
            "type": "sigmoid",
            "a": self.a.detach().item(),
            "b": self.b.detach().item(),
            "c": self.c.detach().item(),
            "bin_size": self.bin_size,
            # Stored so that from_params can reconstruct the module; forward
            # depends on it, so a dict without it cannot reproduce the fit.
            "max_current": None if self.max_current is None else float(self.max_current),
        }

    @classmethod
    def from_params(cls, params):
        """Rebuild a module from a dict produced by ``get_params``.

        Raises:
            ValueError: if ``params`` has no ``"max_current"`` entry (dicts
                saved before that key was added cannot be reconstructed).
        """
        if "max_current" not in params:
            raise ValueError(
                "SigmoidActivation.from_params requires 'max_current' in params"
            )
        # The stored scale ``c`` is the max_firing_rate the module was built with.
        g = cls(params["bin_size"], params["c"], params["max_current"])
        # reshape(1) keeps the same parameter shape as __init__ creates.
        g.a = torch.nn.Parameter(torch.tensor(params["a"], dtype=torch.float32).reshape(1))
        g.b = torch.nn.Parameter(torch.tensor(params["b"], dtype=torch.float32).reshape(1))
        return g

class SoftplusActivation(torch.nn.Module):
    """Softplus benchmark activation: ``c * softplus((a * z - b) / max_current)``.

    ``a`` and ``b`` are trainable scalars. ``c`` is a frozen scale (1.0 on
    construction); ``max_current`` is a fixed normalisation constant and
    ``bin_size`` is stored on the module so the fitting code can read it.
    """

    def __init__(self, bin_size, max_current):
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(1))
        self.b = torch.nn.Parameter(torch.zeros(1))
        self.c = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)  # scale
        self.bin_size = bin_size
        self.max_current = max_current

    def forward(self, z):
        """Apply the affine map, normalise by ``max_current`` and apply scaled softplus."""
        x = (self.a * z - self.b) / self.max_current
        return self.c * F.softplus(x)

    def get_params(self):
        """Return everything needed to rebuild the module as plain Python values."""
        return {
            "type": "softplus",
            "a": self.a.detach().item(),
            "b": self.b.detach().item(),
            "c": self.c.detach().item(),
            "bin_size": self.bin_size,
            # Stored so that from_params can reconstruct the module; forward
            # depends on it, so a dict without it cannot reproduce the fit.
            "max_current": None if self.max_current is None else float(self.max_current),
        }

    @classmethod
    def from_params(cls, params):
        """Rebuild a module from a dict produced by ``get_params``.

        Raises:
            ValueError: if ``params`` has no ``"max_current"`` entry (dicts
                saved before that key was added cannot be reconstructed).
        """
        if "max_current" not in params:
            raise ValueError(
                "SoftplusActivation.from_params requires 'max_current' in params"
            )
        g = cls(params["bin_size"], params["max_current"])
        # reshape(1) keeps the same parameter shape as __init__ creates.
        g.a = torch.nn.Parameter(torch.tensor(params["a"], dtype=torch.float32).reshape(1))
        g.b = torch.nn.Parameter(torch.tensor(params["b"], dtype=torch.float32).reshape(1))
        # Keep the scale frozen, exactly as __init__ does; otherwise a reloaded
        # module would silently start training ``c``.
        g.c = torch.nn.Parameter(torch.tensor(params["c"], dtype=torch.float32), requires_grad=False)
        return g

# -----------------------------------------------------------------------------
# Utility: Choose activation class
# -----------------------------------------------------------------------------
def make_activation(act_type: str, bin_size: float, degree=None, max_current=None, max_firing_rate=None):
    """Build the activation module named by ``act_type``.

    Supported names are ``'polynomial'``, ``'relu'``, ``'sigmoid'`` and
    ``'softplus'``.

    Raises:
        ValueError: if ``act_type`` is not one of the supported names, if
            ``'polynomial'`` is requested without ``degree``, ``max_current``
            and ``max_firing_rate``, or if ``'sigmoid'`` is requested without
            ``max_firing_rate``.
    """
    if act_type == 'polynomial':
        required = (degree, max_current, max_firing_rate)
        if any(value is None for value in required):
            raise ValueError("PolynomialActivation requires degree, max_current, and max_firing_rate")
        return PolynomialActivation(degree, max_current, max_firing_rate, bin_size)

    if act_type == 'relu':
        return ReLUActivation(bin_size, max_current)

    if act_type == 'sigmoid':
        if max_firing_rate is None:
            raise ValueError("SigmoidActivation requires max_firing_rate")
        return SigmoidActivation(bin_size, max_firing_rate, max_current)

    if act_type == 'softplus':
        return SoftplusActivation(bin_size, max_current)

    # Nothing matched: the name is not a known activation family.
    raise ValueError(f"Unknown activation type: {act_type}")

# -----------------------------------------------------------------------------
# Fitting routines (with train/test split)
# -----------------------------------------------------------------------------
def _summed_sample_loss(actv, criterion, currents, rates):
    """Sum ``criterion`` over samples, evaluating one (current, rate) pair at a time."""
    running = 0
    for current, fr in zip(currents, rates):
        pred = actv(current.reshape(1, 1))
        # Prediction and target are both multiplied by the module's bin_size.
        running += criterion(pred * actv.bin_size, fr.reshape(1, 1) * actv.bin_size)
    return running


def fit_activation(
    actv,
    criterion,
    optimizer,
    Is_train,
    fs_train,
    Is_test=None,
    fs_test=None,
    epochs=1000
):
    """Train ``actv`` with one optimizer step per epoch on the summed training loss.

    Args:
        actv: module mapping a ``(1, 1)`` input to a prediction; must expose
            a ``bin_size`` attribute.
        criterion: loss callable taking ``(prediction, target)``.
        optimizer: optimizer already bound to ``actv``'s parameters.
        Is_train, fs_train: paired training inputs and targets, iterated
            sample by sample.
        Is_test, fs_test: optional held-out pairs; evaluated after each step
            only when both are given.
        epochs: number of optimizer steps.

    Returns:
        ``(train_losses, test_losses)``: per-epoch summed loss divided by the
        number of samples. ``test_losses`` is empty when no test data is given.
    """
    evaluate = Is_test is not None and fs_test is not None
    train_history, test_history = [], []

    for _epoch in range(epochs):
        # Training loss is measured before the parameter update of this epoch.
        epoch_loss = _summed_sample_loss(actv, criterion, Is_train, fs_train)
        optimizer.zero_grad()
        epoch_loss.backward()
        optimizer.step()
        train_history.append(epoch_loss.item() / len(Is_train))

        if not evaluate:
            continue

        # Held-out loss is measured after the update, without tracking gradients.
        with torch.no_grad():
            held_out_loss = _summed_sample_loss(actv, criterion, Is_test, fs_test)
            test_history.append(held_out_loss.item() / len(Is_test))

    return train_history, test_history

# -----------------------------------------------------------------------------
# Activation fitting with optional test split
# -----------------------------------------------------------------------------
def get_activations(
    act_type,
    Is_train,
    fs_train,
    bin_size,
    epochs=1000,
    device=None,
    g=None,
    degree=None,
    max_firing_rate=None,
    Is_test=None,
    fs_test=None,
    lr=0.05
):
    """Create (if needed) and fit an activation module with a Poisson NLL loss.

    Args:
        act_type: activation family name passed to ``make_activation``.
        Is_train, fs_train: training inputs and targets (tensors).
        bin_size: bin size handed to the activation module.
        epochs: number of optimizer steps.
        device: device the module is moved to (``None`` leaves it in place).
        g: existing module to keep training; a new one is built when ``None``.
        degree, max_firing_rate: forwarded to ``make_activation``.
        Is_test, fs_test: optional held-out inputs and targets.
        lr: Adam learning rate.

    Returns:
        ``(g, train_losses, test_losses)`` as produced by ``fit_activation``.
    """
    # init if needed
    if g is None:
        # torch.cat needs a sequence of tensors, so wrap the training tensor
        # even when there is no test split (passing the bare tensor raises).
        currents = [Is_train] if Is_test is None else [Is_train, Is_test]
        # Normalisation constant: largest absolute input over all provided data.
        max_current = float(torch.max(torch.abs(torch.cat(currents))))
        g = make_activation(act_type, bin_size, degree, max_current, max_firing_rate)
    g = g.to(device)
    # log_input=False: the module's output is treated as the rate itself.
    criterion = torch.nn.PoissonNLLLoss(log_input=False)
    optimizer = torch.optim.Adam(g.parameters(), lr=lr)
    train_losses, test_losses = fit_activation(
        g, criterion, optimizer,
        Is_train, fs_train,
        Is_test, fs_test,
        epochs
    )
    return g, train_losses, test_losses

# -----------------------------------------------------------------------------
# Full model fitting including activation train/test split
# -----------------------------------------------------------------------------
def fit_model(
    cell_id,
    activation_bin_size,
    activation_type,
    max_firing_rate=None,
    degree=None,
    device=None,
    g=None
):
    """Fit one activation function for one cell on a seeded 80/20 train/test split.

    The cell's data is loaded with ``get_data`` and turned into two arrays by
    ``preprocess_data`` (presumably input currents and firing rates, judging
    by the variable names; confirm against those helpers).

    Returns:
        ``(activation_params, train_losses, test_losses)`` where
        ``activation_params`` is the fitted module's ``get_params()`` dict.
    """
    raw = get_data(cell_id)
    currents_np, rates_np = preprocess_data(raw, activation_bin_size)
    currents = torch.tensor(currents_np).to(device)
    rates = torch.tensor(rates_np).to(device)

    # Reproducible shuffle: a dedicated generator with a fixed seed.
    n_samples = len(currents)
    shuffled = torch.randperm(n_samples, generator=torch.Generator().manual_seed(42))
    n_train = int(0.8 * n_samples)
    train_idx = shuffled[:n_train]
    test_idx = shuffled[n_train:]

    fitted, train_losses, test_losses = get_activations(
        activation_type,
        currents[train_idx],
        rates[train_idx],
        activation_bin_size,
        epochs=1000,
        device=device,
        g=g,
        degree=degree,
        max_firing_rate=max_firing_rate,
        Is_test=currents[test_idx],
        fs_test=rates[test_idx]
    )

    # Only the serialisable parameters are returned, not the module itself.
    return fitted.get_params(), train_losses, test_losses

def summarize_losses(df):
    """
    Average the losses per activation configuration.

    Expects a DataFrame containing at least the columns
      ['activation_type', 'activation_bin_size', 'train_loss', 'test_loss']
    and returns a new DataFrame with one row per
    (activation_type, activation_bin_size) pair holding the mean train_loss
    and mean test_loss of that group.
    """
    group_keys = ['activation_type', 'activation_bin_size']
    loss_columns = ['train_loss', 'test_loss']

    mean_losses = df.groupby(group_keys)[loss_columns].mean()
    # Turn the group keys back into ordinary columns.
    return mean_losses.reset_index()


In [10]:
# User-set: bin sizes and activation families to compare for every cell.
# cell_ids = [313860745]
activation_bin_sizes = [20, 100]
actv_types = ['polynomial', 'relu', 'sigmoid', 'softplus']

# Fields recorded for every individual fit (the DataFrame columns after 'cell_id').
fit_keys = ('activation_bin_size', 'activation_type', 'train_loss', 'test_loss', 'activation_params')

rows = []
for cell_id in tqdm(cell_ids):
    path = config["data_path"] + f"processed_I_and_firing_rate_{cell_id}.pickle"
    path2 = "data/activation_function/" + f"{cell_id}.pickle"

    # Nothing to fit without the cell's processed data.
    if not os.path.exists(path):
        print(f"Skipping cell {cell_id}: data file does not exist.")
        continue

    if os.path.exists(path2):
        print(f"Skipping cell {cell_id}: activation function has already been fitted.")
        # NOTE: pickle.load can execute arbitrary code; only load cache files
        # that were produced by this notebook.
        with open(path2, "rb") as f:
            data = pickle.load(f)
        # Cache files written by the earlier version of this cell hold a single
        # fit at the top level (the last combination that was fitted). Newer
        # files additionally carry every fit under 'all_fits'. Delete a legacy
        # file to have all combinations refitted for that cell.
        cell_fits = data.get('all_fits', [data])
    else:
        # load that cell's max firing rate
        with open(f"{config['mfr_path']}{cell_id}.pickle", "rb") as f:
            max_firing_rate = pickle.load(f)

        cell_fits = []
        for act_type in actv_types:
            for activation_bin_size in activation_bin_sizes:
                # fit_model returns (activation_params, train_losses, test_losses)
                activation_params, train_losses, test_losses = fit_model(
                    cell_id,
                    activation_bin_size,
                    act_type,
                    max_firing_rate,
                    degree=2
                )
                cell_fits.append({
                    'activation_bin_size': activation_bin_size,
                    'activation_type': act_type,
                    'train_loss': train_losses[-1],
                    'test_loss':  test_losses[-1],
                    "activation_params": activation_params
                })

        # Write the cache once, after every combination has been fitted, so
        # that (a) no fit is overwritten by a later one and (b) an interrupted
        # run does not leave a partially fitted cell marked as done.
        # The top-level keys mirror the last fit, which is what the file
        # contained before, so existing readers of this file keep working.
        if cell_fits:
            with open(path2, "wb") as f:
                pickle.dump({**cell_fits[-1], 'all_fits': cell_fits}, f)

    for fit in cell_fits:
        rows.append({'cell_id': cell_id, **{key: fit[key] for key in fit_keys}})

# build & inspect; explicit columns keep the frame well-formed even with no rows
df = pd.DataFrame(rows, columns=['cell_id', *fit_keys])

# save dataframe as pickle
with open("model/activation_fits.pickle", "wb") as f:
    pickle.dump(df, f)

  0%|          | 0/1796 [00:00<?, ?it/s]

Skipping cell 566517779: activation function has already been fitted.
Skipping cell 486875162: activation function has already been fitted.
Skipping cell 562651165: activation function has already been fitted.
Skipping cell 526573598: activation function has already been fitted.
Skipping cell 608108585: activation function has already been fitted.
Skipping cell 584146987: activation function has already been fitted.
Skipping cell 479993900: activation function has already been fitted.
Skipping cell 569270348: activation function has already been fitted.
Skipping cell 508911693: activation function has already been fitted.
Skipping cell 517693519: activation function has already been fitted.
Skipping cell 609435731: activation function has already been fitted.
Skipping cell 479010903: activation function has already been fitted.
Skipping cell 341442651: activation function has already been fitted.
Skipping cell 530022494: activation function has already been fitted.
Skipping cell 485245

 23%|██▎       | 408/1796 [15:59<3:40:21,  9.53s/it]

Skipping cell 490718897: data file does not exist.


100%|██████████| 1796/1796 [4:10:05<00:00,  8.35s/it]  


In [11]:
# Mean recorded train/test loss for each (activation type, bin size) pair.
summary_df = summarize_losses(df)
print(summary_df)

  activation_type  activation_bin_size  train_loss  test_loss
0      polynomial                   20    0.207010   0.314541
1      polynomial                  100   -0.480744  -0.278284
2            relu                   20    0.348153   0.480723
3            relu                  100    1.018951   1.374021
4         sigmoid                   20    1.015814   1.087993
5         sigmoid                  100    3.427711   3.648883
6        softplus                   20    7.577573   7.980943
7        softplus                  100   35.628174  37.570494


In [12]:
# Drop every row whose train_loss or test_loss is NaN, then recompute the
# per-(activation type, bin size) means on the remaining rows.
df_no_nan = df.dropna(subset=['train_loss', 'test_loss'])
summary_no_nan = summarize_losses(df_no_nan)
print(summary_no_nan)

  activation_type  activation_bin_size  train_loss  test_loss
0      polynomial                   20    0.207010   0.314541
1      polynomial                  100   -0.480744  -0.278284
2            relu                   20    0.348153   0.480723
3            relu                  100    1.018951   1.374021
4         sigmoid                   20    1.015814   1.087993
5         sigmoid                  100    3.427711   3.648883
6        softplus                   20    7.577573   7.980943
7        softplus                  100   35.628174  37.570494


In [15]:
# Number of rows left after dropping NaN losses.
print(f"Number of rows in df_no_nan: {len(df_no_nan)}")

# Number of rows in the unfiltered frame; equal counts mean no row was dropped.
print(f"Number of rows in df: {len(df)}")

Number of rows in df_no_nan: 12161
Number of rows in df: 12161
