### README


## `Model.py`
**What does the script do?**

* **`build_rslds_model`** – create an input-driven rSLDS with configurable discrete states (`K`), latent dimension, stickiness (`kappa`), and AR emissions.  
* **`fit_rslds_model`** – fit the model with variational inference; returns posterior means of latents/states and the ELBO curve.  
* **`save_model_outputs`** – write all expected artifacts for downstream analysis:  
  - `*model*.pkl` (fitted model)  
  - `x_hat.npy` (latents)  
  - `z_hat.npy` (discrete states)  
  - `elbos.npy` (training curve)  
  - `footshock.npy` (+ optional `speed.npy`, `pupil.npy`)  
* **`load_model_outputs`** – reload artifacts in a consistent format for plotting/evaluation.  
* Note: Python 3.10 works best with this model (the notebook code uses `str | None` type hints, which require Python 3.10 or newer).

---

**Inputs**  

* Wide-format CSV (time × neurons) and HDF5 with firing rates & footshock times.  
* Configurable parameters: number of discrete states, latent dimension, learning settings.  

---

**Outputs (saved to run directory)**  

* Latents (`x_hat.npy`), states (`z_hat.npy`), ELBO (`elbos.npy`), shock/speed/pupil vectors.  
* Pickled model (`*model*.pkl`) for reproducibility and vector field plotting.  



1.   This notebook runs the different models for all rats and saves the outputs to Google Drive.

2.  This notebook also runs cross-validation for all rats and saves the outputs to Google Drive.



# Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive; the data and output folders used in
# this notebook all live under /content/drive/My Drive/rSLDS.
from google.colab import drive

drive.mount("/content/drive")


# Clone SSM from git

In [None]:
# Clone the ssm sources and install ssm from GitHub.
# --no-build-isolation builds ssm against the packages already installed in
# this runtime (e.g. numpy) instead of a fresh, isolated build environment.
!git clone https://github.com/lindermanlab/ssm.git
!pip install --no-build-isolation git+https://github.com/lindermanlab/ssm.git

In [None]:
# Replace the runtime's scientific stack with a mutually compatible, pinned set.
# The `==X.Y.*` wildcards must keep their trailing `*` (a bare "1.10." is an
# invalid version specifier and makes the whole install fail), and they are
# quoted so the shell does not try to glob-expand `*`.
!pip uninstall -y numpy scipy numba scikit-learn pandas matplotlib seaborn autograd ssm
!pip install "numpy==1.26.4" "scipy==1.10.*" "numba==0.58.*" "scikit-learn==1.3.*" \
            "pandas==1.5.*" "matplotlib==3.7.*" "seaborn==0.12.*"
# The uninstall above also removed autograd and ssm; reinstall them on top of
# the pinned numpy (--no-build-isolation builds ssm against that numpy).
# NOTE(review): rSLDS.modelling / rSLDS.cross_validation presumably import ssm — confirm.
!pip install autograd
!pip install --no-build-isolation git+https://github.com/lindermanlab/ssm.git

# Clone rSLDS from git

In [None]:
# Clone the rSLDS helper code into ./rSLDS (git's default folder name for this
# URL, spelled out explicitly). Presumably the `from rSLDS... import ...` lines
# resolve it from the working directory — confirm if running elsewhere.
!git clone https://github.com/ElifAyten/rSLDS.git rSLDS

# Import the Libraries

In [None]:
# Standard library
import traceback
from pathlib import Path

# Third-party
import pandas as pd

# Root of the project folder on the mounted Google Drive (inputs and outputs).
DRIVE = Path("/content/drive/My Drive/rSLDS")

# Project code from the cloned rSLDS repository.
from rSLDS.cross_validation import crossval_rslds
from rSLDS.modelling import fit_single_rslds

# Model Runs

In [None]:
# model runner for all rats

RAT_IDS = [3, 4, 8, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21]

# Which areas to run the model with: (area, subset) pairs, where subset is
# "allActive" or "responsive". Entries can be removed to skip combinations.
AREA_SUBSETS = [
    ("ventral",  "allActive"),
    ("ventral",  "responsive"),
    ("dorsal",   "allActive"),
    ("dorsal",   "responsive"),
    ("thalamus", "allActive"),
    ("thalamus", "responsive"),
]

# Also fit the responsive neurons of all areas combined.
RUN_RESPONSIVE_ALL = True

# hyperparameter grid
K_GRID        = [4]             # numbers of discrete states to try (can change)
SEEDS         = [0, 1]          # repeat runs with different random seeds
KAPPA_GRID    = [0.0]           # stickiness values to try
# Stickiness (kappa): if kappa = 0, the model might switch states too rapidly,
# even within 1-2 time bins, just to explain small fluctuations in the firing
# rates. With a larger kappa the model assumes states are more stable in time
# (like "once you're in a shock-response state, you'll stay there for a while").
NUM_ITERS     = 300             # fitting iterations passed to fit_single_rslds
OVERWRITE     = True            # if False, existing run folders will be skipped
VERBOSE       = True

# latent dimensionality: None -> chosen automatically using VARIANCE_GOAL
LATENT_DIM      = None
VARIANCE_GOAL   = 0.90          # fraction of variance to explain (0.95 is also reasonable)

# helpers that build the Drive paths used by the model runs
def h5_path_for_rat(rid: int) -> Path:
    """Return the firing-rate/behavior HDF5 file of rat ``rid`` (10 ms bins, 0 smoothing per its filename)."""
    rat_folder = DRIVE / "Rat-Data-hdf5" / f"Rat{rid}"
    return rat_folder / "NpxFiringRate_Behavior_SBL_10msBINS_0smoothing.hdf5"

def csv_path_area(rid: int, area: str, subset: str) -> Path:
    """Return ``{area}_wide.csv`` from the ``{subset}`` area split of rat ``rid``."""
    split_folder = f"area_splits_rat{rid}_{subset}"
    # "Seperate-by-Area" presumably mirrors the existing Drive folder name; keep the spelling.
    return DRIVE.joinpath("Sub-Data", "Seperate-by-Area", f"Rat{rid}", split_folder, f"{area}_wide.csv")

def csv_path_responsive_all(rid: int) -> Path:
    """Return the CSV of responsive neurons (all areas combined) for rat ``rid``."""
    split_folder = DRIVE.joinpath("Sub-Data", "Only-Responsive", f"Rat{rid}", f"area_splits_rat{rid}_responsive")
    return split_folder / "responsive_rates_raw.csv"

def base_model_dir(rid: int) -> Path:
    """Return the per-rat folder that contains every model-run folder of rat ``rid``."""
    return DRIVE.joinpath("rSLDS-Model-Outputs", f"Rat{rid}-Model-Outputs")

def run_dir(rid: int, suffix: str, K: int, seed: int, kappa: float) -> Path:
    """
    Return a unique folder for one configuration, inside ``base_model_dir(rid)``.

    Folder name: ``models_Rat{rid}_{suffix}_K{K}_seed{seed}_kappa{kappa:g}``.
    ``kappa`` uses the ``g`` format, so 0.0 -> "0", 10.0 -> "10", 0.5 -> "0.5".
    Example: models_Rat15_dorsal_responsive_K2_seed0_kappa0
    """
    name = f"models_Rat{rid}_{suffix}_K{K}_seed{seed}_kappa{kappa:g}"
    return base_model_dir(rid) / name

def suffix_from(area: str | None, subset: str | None) -> str:
    if area is None and subset == "responsive":
        return "responsive_all"
    if subset == "responsive":
        return f"{area}_responsive"
    elif subset == "allActive":
        return f"{area}"
    else:
        raise ValueError(f"Unknown (area={area}, subset={subset})")

# running
def run_fit(h5_path: Path, csv_path: Path, save_dir: Path,
            K_states: int, seed: int, kappa: float):
    """
    Fit one rSLDS configuration with ``fit_single_rslds`` and summarise the result.

    Parameters
    ----------
    h5_path : Path
        HDF5 file for the rat (checked for existence first).
    csv_path : Path
        Wide-format CSV with the neurons to fit (checked for existence first).
    save_dir : Path
        Run folder passed to ``fit_single_rslds``; its parent is created here.
    K_states : int
        Number of discrete states.
    seed : int
        Random seed; used to seed NumPy's global RNG right before fitting.
    kappa : float
        Stickiness passed to ``fit_single_rslds``.

    Returns
    -------
    dict
        ``{"status", "path", "msg"}`` where status is "ok", "skip_h5_missing",
        "skip_csv_missing", "skip_exists" (a FileExistsError was raised,
        presumably an existing run with OVERWRITE=False) or "error".
        Exceptions are reported, not raised, so a long grid run keeps going.
    """
    try:
        save_dir.parent.mkdir(parents=True, exist_ok=True)
        if not h5_path.exists():
            return dict(status="skip_h5_missing", path=str(save_dir), msg=str(h5_path))
        if not csv_path.exists():
            return dict(status="skip_csv_missing", path=str(save_dir), msg=str(csv_path))

        print(f"→ fit: K={K_states}, seed={seed}, kappa={kappa:g}")
        # `seed` is not forwarded to fit_single_rslds, so without this the
        # "seed" repeats would run on whatever RNG state happened to be current.
        # NOTE(review): assumes the fit draws from NumPy's global RNG (ssm
        # presumably uses numpy.random) and does not reseed internally — confirm.
        import numpy as np  # local: numpy is not imported at the top of the notebook
        np.random.seed(seed)

        fit_single_rslds(
            h5_path     = h5_path,
            csv_path    = csv_path,
            save_dir    = save_dir,
            K_states    = K_states,
            num_iters   = NUM_ITERS,
            kappa       = kappa,
            overwrite   = OVERWRITE,
            verbose     = VERBOSE,
            latent_dim  = LATENT_DIM,
            variance_goal = VARIANCE_GOAL,
        )
        return dict(status="ok", path=str(save_dir), msg="trained")
    except FileExistsError as e:
        return dict(status="skip_exists", path=str(save_dir), msg=str(e))
    except Exception as e:
        traceback.print_exc(limit=1)
        return dict(status="error", path=str(save_dir), msg=f"{type(e).__name__}: {e}")

# loop it: fit every (rat, neuron set, K, seed, kappa) configuration and
# collect one summary row per run.
import itertools

summ_rows = []

for rid in RAT_IDS:
    h5p = h5_path_for_rat(rid)

    # Neuron sets for this rat as (suffix, csv path): each area/subset pair,
    # then (optionally) the responsive neurons of all areas combined.
    jobs = [
        (suffix_from(area, subset), csv_path_area(rid, area, subset))
        for area, subset in AREA_SUBSETS
    ]
    if RUN_RESPONSIVE_ALL:
        jobs.append((suffix_from(None, "responsive"), csv_path_responsive_all(rid)))

    for suffix, csvp in jobs:
        # product() iterates K (outer) -> seed -> kappa (inner), like nested loops.
        for K, seed, kappa in itertools.product(K_GRID, SEEDS, KAPPA_GRID):
            # run_fit creates savedir.parent (== base_model_dir(rid)) itself.
            savedir = run_dir(rid, suffix, K, seed, kappa)
            print(f"[Rat {rid:2d}] {suffix:20s}  K={K} seed={seed} kappa={kappa:g}")
            info = run_fit(h5p, csvp, savedir, K, seed, kappa)
            summ_rows.append({
                "rat": rid, "suffix": suffix,
                "K": K, "seed": seed, "kappa": kappa,
                "status": info["status"], "path": info["path"], "msg": info["msg"],
            })

# summary csv: one row per attempted configuration
summary_dir = DRIVE / "rSLDS-Model-Outputs" / "_run_summaries"
summary_dir.mkdir(parents=True, exist_ok=True)
summary_csv = summary_dir / "rslds_fits_summary.csv"
summary_df = pd.DataFrame(summ_rows)
summary_df.to_csv(summary_csv, index=False)
print(f"Saved {len(summary_df)} run summaries to {summary_csv}")
if not summary_df.empty:
    print(summary_df["status"].value_counts().to_string())

# Cross Validation

In [None]:
# Cross-validation configuration. This cell re-declares the paths and grids it
# needs, so it can be run on its own after the imports cell.
DRIVE = Path("/content/drive/My Drive/rSLDS")

RAT_IDS = [3, 4, 8, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21]

# (area, subset) pairs: every area with both the "allActive" and the
# "responsive" subset, in area-major order.
AREA_SUBSETS = [
    (area, subset)
    for area in ("ventral", "dorsal", "thalamus")
    for subset in ("allActive", "responsive")
]

# Also cross-validate the responsive neurons of all areas combined.
RUN_RESPONSIVE_ALL = True

K_GRID = [4]      # use 4-state models (change if needed)
N_FOLDS = 5       # number of cross-validation folds
SEEDS = [0, 1]    # optional repeats with different seeds
NUM_ITERS = 300   # iterations passed to crossval_rslds
VERBOSE = True


# helpers that build the Drive paths used by the cross-validation runs
def h5_path_for_rat(rid: int) -> Path:
    """Return the binned firing-rate/behavior HDF5 file for rat ``rid``."""
    parts = ("Rat-Data-hdf5", f"Rat{rid}", "NpxFiringRate_Behavior_SBL_10msBINS_0smoothing.hdf5")
    return DRIVE.joinpath(*parts)

def csv_path_area(rid: int, area: str, subset: str) -> Path:
    """Return ``{area}_wide.csv`` from the ``{subset}`` area split of rat ``rid``."""
    rat_folder = DRIVE / "Sub-Data" / "Seperate-by-Area" / f"Rat{rid}"
    return rat_folder / f"area_splits_rat{rid}_{subset}" / f"{area}_wide.csv"

def csv_path_responsive_all(rid: int) -> Path:
    """Return the CSV of responsive neurons (all areas combined) for rat ``rid``."""
    parts = ("Sub-Data", "Only-Responsive", f"Rat{rid}",
             f"area_splits_rat{rid}_responsive", "responsive_rates_raw.csv")
    return DRIVE.joinpath(*parts)

def suffix_from(area: str | None, subset: str | None) -> str:
    if area is None and subset == "responsive":
        return "responsive_all"
    if subset == "responsive":
        return f"{area}_responsive"
    elif subset == "allActive":
        return f"{area}"
    else:
        raise ValueError(f"Unknown (area={area}, subset={subset})")

def cv_save_dir(rid: int, suffix: str, K: int, seed: int) -> Path:
    """
    Return the cross-validation output folder for one configuration:
    ``.../Rat{rid}-Model-Outputs/cross_validation/models_Rat{rid}_{suffix}_K{K}_seed{seed}``.
    """
    run_name = f"models_Rat{rid}_{suffix}_K{K}_seed{seed}"
    return DRIVE.joinpath("rSLDS-Model-Outputs", f"Rat{rid}-Model-Outputs", "cross_validation", run_name)

# run
def run_cv(h5_path: Path, csv_path: Path, save_dir: Path, K_states: int, seed: int):
    """
    Cross-validate one configuration with ``crossval_rslds`` and summarise the result.

    Parameters
    ----------
    h5_path, csv_path : Path
        Input HDF5 file and wide-format CSV; missing inputs are reported as skips.
    save_dir : Path
        Output folder passed to ``crossval_rslds``. It is created only when both
        inputs exist, so skipped runs do not leave empty folders behind.
    K_states : int
        Number of discrete states.
    seed : int
        Random seed; used to seed NumPy's global RNG right before cross-validating.

    If ``crossval_rslds`` returns a DataFrame or a dict, it is also written to
    ``save_dir / "cv_results.csv"``. This is best effort: a failed write is
    reported as a warning but does not fail the run.

    Returns
    -------
    dict
        ``{"status", "path", "msg"}`` with status "ok", "skip_h5_missing",
        "skip_csv_missing" or "error". Exceptions are reported, not raised, so a
        long grid run keeps going.
    """
    try:
        if not h5_path.exists():
            return dict(status="skip_h5_missing", path=str(save_dir), msg=str(h5_path))
        if not csv_path.exists():
            return dict(status="skip_csv_missing", path=str(save_dir), msg=str(csv_path))
        save_dir.mkdir(parents=True, exist_ok=True)

        print(f"→ CV: K={K_states}, seed={seed}, folds={N_FOLDS}")
        # `seed` is not forwarded to crossval_rslds, so seed NumPy's global RNG to
        # make the per-seed runs reproducible and actually different.
        # NOTE(review): assumes crossval_rslds draws from NumPy's global RNG and
        # does not reseed internally — confirm.
        import numpy as np  # local: numpy is not imported at the top of the notebook
        np.random.seed(seed)

        res = crossval_rslds(
            h5_path   = h5_path,
            csv_path  = csv_path,
            save_dir  = save_dir,
            K_states  = K_states,
            num_iters = NUM_ITERS,
            n_folds   = N_FOLDS,
            verbose   = VERBOSE,
        )

        results_csv = save_dir / "cv_results.csv"
        try:
            if isinstance(res, pd.DataFrame):
                res.to_csv(results_csv, index=False)
            elif isinstance(res, dict):
                pd.DataFrame([res]).to_csv(results_csv, index=False)
        except Exception as err:
            # The cross-validation itself finished; only warn about the extra CSV.
            print(f"  warning: could not write {results_csv}: {type(err).__name__}: {err}")

        return dict(status="ok", path=str(save_dir), msg="cv_done")
    except Exception as e:
        traceback.print_exc(limit=1)
        return dict(status="error", path=str(save_dir), msg=f"{type(e).__name__}: {e}")

# loop it: cross-validate every (rat, neuron set, K, seed) configuration and
# collect one summary row per run.
import itertools

summ_rows = []

for rid in RAT_IDS:
    h5p = h5_path_for_rat(rid)

    # Neuron sets for this rat as (suffix, csv path): each area/subset pair,
    # then (optionally) the responsive neurons of all areas combined.
    jobs = [
        (suffix_from(area, subset), csv_path_area(rid, area, subset))
        for area, subset in AREA_SUBSETS
    ]
    if RUN_RESPONSIVE_ALL:
        jobs.append((suffix_from(None, "responsive"), csv_path_responsive_all(rid)))

    for suffix, csvp in jobs:
        # product() iterates K (outer) -> seed (inner), like nested loops.
        for K, seed in itertools.product(K_GRID, SEEDS):
            out = cv_save_dir(rid, suffix, K, seed)
            print(f"[Rat {rid:2d}] {suffix:20s}  K={K} seed={seed}")
            info = run_cv(h5p, csvp, out, K, seed)
            summ_rows.append({
                "rat": rid, "suffix": suffix, "K": K, "seed": seed,
                "status": info["status"], "path": info["path"], "msg": info["msg"],
            })

# summary csv: one row per attempted configuration
summary_dir = DRIVE / "rSLDS-Model-Outputs" / "_run_summaries"
summary_dir.mkdir(parents=True, exist_ok=True)
summary_csv = summary_dir / "rslds_crossval_summary.csv"
summary_df = pd.DataFrame(summ_rows)
summary_df.to_csv(summary_csv, index=False)
print(f"Saved {len(summary_df)} cross-validation summaries to {summary_csv}")
if not summary_df.empty:
    print(summary_df["status"].value_counts().to_string())
