Written by Hadrien Padilla - July 30th 2025 

Purpose: 

Analysis of different nets to investigate the hippocampus' mechanism of representation of reward locations. The study focuses on a four fold evaluation: with and without reward representation crossed with and without reward directed behavior. 

Focus of Analysis: 

- Single-field analysis 
- Complex cell analysis 
- Population analysis 
- Shifting over time (over the course of training at different steps )

October 2025 update:

- Logic has been added to look into the placement of peaks within complex cells when there is more than one peak present. 
- Code refactored for better organization

# Imports

In [10]:
# Imports
%load_ext autoreload
%autoreload 2

import os
import re
import glob
import pandas as pd
import numpy as np
import itertools
import torch
import random
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter, maximum_filter, generate_binary_structure, iterate_structure
from itertools import chain
import plotly.express as px
import plotly.graph_objects as go


from prnn.utils.predictiveNet import PredictiveNet
from prnn.utils.agent import RatInABoxAgent, RandomActionAgent
from prnn.utils.env import make_env
from prnn.utils.general import saveFig
from prnn.utils.figures import TrainingFigure
from prnn.analysis.SpatialTuningAnalysis import SpatialTuningAnalysis
from prnn.analysis.OfflineTrajectoryAnalysis import OfflineTrajectoryAnalysis
from prnn.analysis.OfflineActivityAnalysis import SpontaneousActivityAnalysis
from prnn.analysis.representationalGeometryAnalysis import representationalGeometryAnalysis
from prnn.analysis.TuningCurveAnalysis import TuningCurveAnalysis

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Statistical Test Functions

In [11]:
from scipy.stats import chi2  # for p-value
#method to compute chi-squared test to see if there is significance with the fraction of single-field place cell distribution
#NOTE: you need an expected count of 5 per bin for chi square to be reliable 
def compute_chi_squared(counts):
    expected = np.mean(counts)  # uniform distribution
    chi2_stat = np.sum((counts - expected)**2 / expected)

    df = counts.size - 1
    pval = 1.0 - chi2.cdf(chi2_stat, df)

    return chi2_stat, pval, df, expected


def compute_chi_squared_with_residuals(counts):
    expected = np.mean(counts)
    chi2_stat = np.sum((counts - expected)**2 / expected)
    df = counts.size - 1
    pval = 1.0 - chi2.cdf(chi2_stat, df)

    # standardized residuals
    residuals = (counts - expected) / np.sqrt(expected)

    return chi2_stat, pval, df, expected, residuals

# Analysis Functions

## bins and circles

In [12]:
def counts_within_radius(points_xy, centers_xy, radius):
    if len(centers_xy) == 0: return np.array([], dtype=int)
    if len(points_xy) == 0:  return np.zeros(len(centers_xy), dtype=int)
    diff = points_xy[None,:,:] - centers_xy[:,None,:]
    d2 = np.sum(diff*diff, axis=2)
    return np.sum(d2 <= radius**2, axis=1).astype(int)

def count_angular_bins(points_xy,
                       center=(0.6, 0.6),
                       n_slices=6,
                       start_angle=0.0,
                       radius=None):
    pts = np.asarray(points_xy, float)
    c = np.asarray(center, float)
    v = pts - c
    r = np.hypot(v[:,0], v[:,1])
    theta = (np.arctan2(v[:,1], v[:,0])- start_angle) % (2*np.pi)
    w = 2*np.pi / n_slices
    idx = np.floor(theta / w).astype(int)          # 0..n_slices-1
    counts = np.bincount(idx[idx >= 0], minlength=n_slices)
    return counts, idx

def build_bins(points_xy, *, center=(0.6,0.6), n_slices=6, start_angle=0.0, radius=0.6):
    import numpy as np

    pts = np.asarray(points_xy, float)
    if pts.size == 0:
        return {"counts": [0]*n_slices, "fractions": [0.0]*n_slices}
    if pts.ndim == 1 and pts.size == 2:
        pts = pts.reshape(1, 2)
    try:
        counts, _ = count_angular_bins(pts, center=center, n_slices=n_slices,
                                       start_angle=start_angle, radius=radius)
    except NameError:
        return None

    counts = [int(c) for c in counts]
    total = sum(counts) or 1  # avoid div-by-zero
    return {
        "counts": counts,
        "fractions": [c/total for c in counts]
    }

def build_circles(points_xy, reward_xy, *, radius=0.05):
    counts = counts_within_radius(points_xy, reward_xy, radius)
    total  = max(len(points_xy), 1)
    return {"counts":[int(c) for c in counts],
            "fractions":[float(c)/float(total) for c in counts],
            "radius": float(radius)}

def _safe_pct(nums, totals):
    """Return 100*nums/totals with NaN where totals==0 (no warnings)."""
    return np.divide(100.0 * nums, totals,
                     out=np.full_like(nums, np.nan, dtype=float),
                     where=(totals > 0))

def totals_and_pct_in_circles(df, cat):
    """
    Totals from df['bins'] (denominator), numerator from df['circles'].
    Returns (totals, pct).
    """
    totals = df["bins"].apply(lambda b: sum(b[cat]["counts"])).to_numpy()
    nums   = df["circles"].apply(lambda c: sum(c[cat]["counts"])).to_numpy()
    return totals, _safe_pct(nums, totals)

def totals_and_pct_in_bins(df, cat):
    """
    Totals and numerator from df['bins'].
    'nums' is the sum of selected reward bins (indices 0,1,3) — change if needed.
    Returns (totals, pct).
    """
    totals = df["bins"].apply(lambda b: sum(b[cat]["counts"])).to_numpy()
    nums   = df["bins"].apply(lambda b: sum(b[cat]["counts"][i] for i in (0,1,3))).to_numpy()
    return totals, _safe_pct(nums, totals)

EXTRACTORS = {
    "circles": totals_and_pct_in_circles,
    "bins":    totals_and_pct_in_bins,
}

## Coordinates

In [17]:
def com_center(tc):
    y, x = np.indices(tc.shape)
    w = np.clip(tc - np.nanmin(tc), 0, None)  # nonnegative weights
    s = w.sum()
    return (int(round((y*w).sum()/s)), int(round((x*w).sum()/s)))

# def centers_xy_from_rc(TCA, idx, env, use_start_offset=True):
#     rc = np.array([com_center(TCA.tuning_curves[i]) for i in idx], dtype=int)  # (N,2) [r,c]
#     coords = env.env.discrete_coords
#     off = int(getattr(TCA, "start_pos", 0)) if use_start_offset else 0
#     return np.array([coords[r+off, c+off] for r,c in rc], dtype=float)



def centers_xy_from_rc(TCA, idx, env, use_start_offset=True):
    # Handle empty index case
    if len(idx) == 0:
        return np.zeros((0, 2), dtype=float)
    
    rc = np.array([com_center(TCA.tuning_curves[i]) for i in idx], dtype=int)  # (N,2) [r,c]
    
    # Ensure rc is 2D even if there's only one cell
    if rc.ndim == 1:
        rc = rc.reshape(1, -1)
    
    coords = np.asarray(env.env.discrete_coords)                                # shape (H, W, …)
    H, W = coords.shape[:2]
    off = int(getattr(TCA, "start_pos", 0)) if use_start_offset else 0

    rr = rc[:, 0] + off
    cc = rc[:, 1] + off
    mask = (rr >= 0) & (rr < H) & (cc >= 0) & (cc < W)
    if not mask.all():
        pass
    rr, cc = rr[mask], cc[mask]

    return coords[rr, cc].astype(float)


def _peak_coords_2d_single(tc, threshold=0.5, neighborhood=1):
    """Return list of (row, col) local maxima in a single 2D array."""
    arr = np.where(np.isnan(tc), -np.inf, tc)
    struct = generate_binary_structure(2, neighborhood)  # 8-connected for neighborhood=1
    local_max = arr == maximum_filter(arr, footprint=struct, mode="nearest")
    mask = arr >= threshold
    peaks = np.argwhere(local_max & mask)
    return [tuple(rc) for rc in peaks]



## Analysis Run

In [5]:
os.chdir('/Users/hadrienpadilla/Documents/McGill/Peyrache Lab/pRNN')
savefolder = 'Data/hadrien_analyzed_nets/GD_reward_repeats/'
netfolder = 'for_Hadrien/GD_reward_repeats/3003/'
netfiles = glob.glob(os.path.join('nets', netfolder, '*.dill'))
netnames = [os.path.splitext(os.path.basename(f))[0] for f in netfiles]
rows = []
print(netnames)

['multRNN_5win_i2_o23-_GD_repeat_25_ep5_3003-s1042_ep5', 'multRNN_5win_i2_o23-_GD_repeat_10_ep5_3003-s1042_ep5', 'multRNN_5win_i2_o23-_GD_repeat_50_ep5_3003-s1042_ep5', 'multRNN_5win_i2_o23-_GD_repeat_100_ep5_3003-s1042_ep5', 'multRNN_5win_i2_o23-_GD_repeat_5_ep5_3003-s1042_ep5', 'multRNN_5win_i2_o23-_GD_repeat_150_ep5_3003-s1042_ep5', 'multRNN_5win_i2_o23-_GD_repeat_75_ep5_3003-s1042_ep5']


In [6]:
for netname in netnames:
    predictiveNet = PredictiveNet.loadNet(netfolder+netname, dil=True)
    env = predictiveNet.EnvLibrary[0]         
    predictiveNet.env_shell = env
    agent = RatInABoxAgent('Vis')
    env.wanb_log = False
    TCA = TuningCurveAnalysis(predictiveNet)

    # selections
    idx_single  = np.flatnonzero(TCA.groupID == 2)
    idx_complex = np.flatnonzero((TCA.groupID == 5) & (TCA.metrics['EVs'] >= 0.5))
    idx_union   = np.unique(np.concatenate([idx_single, idx_complex]))

    # centers
    XY_single  = centers_xy_from_rc(TCA, idx_single,  env)
    XY_complex = centers_xy_from_rc(TCA, idx_complex, env)
    XY_union   = centers_xy_from_rc(TCA, idx_union,   env)


    # cells list: union with type per cell
    single_set  = set(idx_single.tolist())
    complex_set = set(idx_complex.tolist())
    cells = []
    for i, xy in zip(idx_union.tolist(), XY_union.tolist()):
        ctype = "single" if i in single_set else "complex"
        cells.append({"idx": int(i), "center": [float(xy[0]), float(xy[1])], "type": ctype})


    # bins (per selection)
    bins = {
        "single":          build_bins(XY_single),
        "single+complex":  build_bins(np.vstack([XY_single, XY_complex]) if len(XY_complex) else XY_single),
    }

    # circles (per selection)
    rewards_xy = getattr(env.Reward, "place_cell_centres", None)
    print(rewards_xy)
    circles = {
        "single":          build_circles(XY_single, rewards_xy, radius=0.05),
        "single+complex":  build_circles(np.vstack([XY_single, XY_complex]) if len(XY_complex) else XY_single,
                                         rewards_xy, radius=0.05),
    }

    rows.append({
        "netname": netname,
        "bins": bins,
        "cells": cells,
        "circles": circles,
    })

Net Loaded from pathname


  print(f"loss: {steploss:>.2}, sparsity: {sparsity:>.2}, meanrate: {meanrate:>.2} [{bb:>5d}\{num_trials:>5d}]")
  print(f"loss: {steploss:>.2}, sparsity: {sparsity:>.2}, meanrate: {meanrate:>.2} [{bb:>5d}\{num_trials:>5d}]")


KeyboardInterrupt: 

In [10]:
rewards_xy = getattr(env.Reward, "place_cell_centres", None)
print(rewards_xy)

[[0.95       0.87692308]
 [0.55       0.78461538]
 [0.35       0.32307692]
 [1.21       0.6       ]]


In [25]:
#save
os.makedirs(savefolder, exist_ok=True)
df = pd.DataFrame(rows).sort_values("netname").reset_index(drop=True)
out_pkl = os.path.join(savefolder, "GD_reward_repeats_3003.pkl")
df.to_pickle(out_pkl)
print("Saved:", out_pkl)

Saved: Data/hadrien_analyzed_nets/GD_reward_repeats/GD_reward_repeats_3003.pkl


## Complex peak Analysis Functions

In [14]:
def _normalize_tc(tc):
    """Normalize one 2D tuning curve to [0,1], NaN-safe."""
    tc = np.asarray(tc, dtype=float)
    mn = np.nanmin(tc)
    mx = np.nanmax(tc)
    if not np.isfinite(mn) or not np.isfinite(mx) or mx <= mn:
        # all NaNs or flat — return zeros so it won’t produce peaks
        return np.zeros_like(tc, dtype=float)
    return (tc - mn) / (mx - mn)

def identify_peaks(TCA, complexCells, threshold=0.5, neighborhood=1, normalize=True):
    """
    TCA.tuning_curves: array-like of shape (N, H, W)
    complexCells: iterable of cell indices (e.g., [0, 3, 7])
    Returns: dict {cell_index: [(r,c), ...], ...}
    """
    peaks_by_cell = {}
    tcs = TCA.tuning_curves  # assume shape (N, H, W)

    for i in complexCells:
        tc = np.array(tcs[i], dtype=float)
        if normalize:
            tc = _normalize_tc(tc)
        peaks_by_cell[i] = _peak_coords_2d_single(tc, threshold=threshold, neighborhood=neighborhood)

    return peaks_by_cell

In [None]:
def _to_2d(a):
    """Return a 2-D float array from array-like or dict->{array}."""
    if a is None or (isinstance(a, float) and np.isnan(a)):
        raise ValueError("missing")
    if isinstance(a, dict):
        a = a[0] if 0 in a else a[min(a.keys())]
    a = np.asarray(a, dtype=float).squeeze()
    if a.ndim == 2:
        return a
    if a.ndim == 3:
        return a[..., 0] if a.shape[-1] <= 4 else a[0, ...]
    raise ValueError(f"expected 2D array, got {a.shape}")

#normalize values between 0 and 1 
def _normalize(a):
    a = np.asarray(a, dtype=float)
    lo, hi = np.nanmin(a), np.nanmax(a)
    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        return np.zeros_like(a)
    return (a - lo) / (hi - lo)

# ---------- peak finding ----------
def _peak_coords(a2d, threshold=0.3, neighborhood=2, min_distance=3, sigma=0.8):
    # Get raw and smoothed arrays
    raw_arr = np.array(a2d)
    smooth_arr = gaussian_filter(a2d, sigma=sigma)
    
    # Create footprint for local maximum detection
    fp = generate_binary_structure(raw_arr.ndim, neighborhood)
    
    # Find peaks separately in raw and smoothed data
    raw_peaks = np.argwhere((raw_arr == maximum_filter(raw_arr, footprint=fp)) & 
                           (raw_arr >= threshold))
    smooth_peaks = np.argwhere((smooth_arr == maximum_filter(smooth_arr, footprint=fp)) & 
                              (raw_arr >= threshold))  # Still use raw_arr for threshold
    
    # Combine peaks
    all_peaks = np.unique(np.vstack([raw_peaks, smooth_peaks]), axis=0)
    
    # Sort peaks by raw value (highest first)
    peak_values = raw_arr[all_peaks[:, 0], all_peaks[:, 1]]
    sorted_idx = np.argsort(-peak_values)
    all_peaks = all_peaks[sorted_idx]
    
    # Filter by minimum distance
    filtered_peaks = []
    for peak in all_peaks:
        if not filtered_peaks:
            filtered_peaks.append(peak)
            continue
        
        distances = np.sqrt(np.sum((np.array(filtered_peaks) - peak)**2, axis=1))
        if np.all(distances > min_distance):
            filtered_peaks.append(peak)
    
    return [tuple(p) for p in filtered_peaks]
    


def identify_peaks(cells, *, threshold=0.6, neighborhood=1, normalize=True, cell_ids=None, sigma = 0.8, min_distance = 3): #these parameters found after fine-tuning
    out = {}
    for i, cell in enumerate(cells):
        try:
            a = _to_2d(cell)
        except ValueError:
            continue
        if normalize:
            a = _normalize(a)
        key = cell_ids[i] if (cell_ids is not None and i < len(cell_ids)) else i
        out[key] = _peak_coords(a, threshold=threshold, neighborhood=neighborhood, min_distance = min_distance, sigma = sigma)
    return out

def _grid_xy(env, H, W):
    """Return (H,W,2) array of bin centers (x,y)."""
    e = getattr(env, "env", env)
    coords = getattr(e, "discrete_coords", None)
    if coords is not None and np.asarray(coords).shape[:2] == (H, W):
        return np.asarray(coords)
    xedges, yedges = getattr(e, "xedges", None), getattr(e, "yedges", None)
    if xedges is not None and yedges is not None:
        xc = (np.asarray(xedges)[:-1] + np.asarray(xedges)[1:]) / 2
        yc = (np.asarray(yedges)[:-1] + np.asarray(yedges)[1:]) / 2
        X, Y = np.meshgrid(xc, yc)
        return np.stack([X, Y], -1)
    X, Y = np.meshgrid((np.arange(W) + 0.5) / W, (np.arange(H) + 0.5) / H)
    return np.stack([X, Y], -1)


def peaks_to_xy_centers(peaks_by_cell, pf_dict, env):
    # discover grid
    out = {}

    if not peaks_by_cell:
        return out

    first = next(pf_dict[k] for k in peaks_by_cell if k in pf_dict)
    H, W = _to_2d(first).shape
    grid = _grid_xy(env, H, W)

    for cid, rc_list in peaks_by_cell.items():
        if cid not in pf_dict or not rc_list:
            out[cid] = []
            continue
        a2d = _to_2d(pf_dict[cid])
        xy = []
        for r, c in rc_list:
            ri, ci = int(r), int(c)
            x, y = grid[ri, ci]
            xy.append((float(x), float(y)))
        out[cid] = xy
    return out

## Updated Analysis Run (with mult-peaks)

In [18]:
os.chdir('/Users/hadrienpadilla/Documents/McGill/Peyrache Lab/pRNN')
savefolder = 'Data/hadrien_analyzed_nets/rand_rew_mult/'
netfolder = 'for_Hadrien/random_rew_mult/2002/'
netfiles = glob.glob(os.path.join('nets', netfolder, '*.pkl'))
netnames = [os.path.splitext(os.path.basename(f))[0] for f in netfiles]
rows = []
print(netnames)

['multRNN_5win_i2_o23-multiply_75_ep5_2002-s1042_ep5', 'multRNN_5win_i2_o23-multiply_150_ep5_2002-s1042_ep5', 'multRNN_5win_i2_o23-multiply_50_ep5_2002-s1042_ep5', 'multRNN_5win_i2_o23-multiply_25_ep5_2002-s1042_ep5', 'multRNN_5win_i2_o23-multiply_10_ep5_2002-s1042_ep5', 'multRNN_5win_i2_o23-multiply_100_ep5_2002-s1042_ep5', 'multRNN_5win_i2_o23-multiply_5_ep5_2002-s1042_ep5']


In [20]:
rows = []  

netnames = [os.path.splitext(os.path.basename(f))[0] for f in netfiles]

for netname in netnames:
    predictiveNet = PredictiveNet.loadNet(netfolder+netname, dil=False)
    env = predictiveNet.EnvLibrary[0]         
    predictiveNet.env_shell = env
    agent = RatInABoxAgent('Vis')
    env.wanb_log = False
    TCA = TuningCurveAnalysis(predictiveNet)

    # selections
    idx_single  = np.flatnonzero(TCA.groupID == 2)
    idx_complex = np.flatnonzero((TCA.groupID == 5) & (TCA.metrics['EVs'] >= 0.5))
    idx_union   = np.unique(np.concatenate([idx_single, idx_complex]))

    # centers
    XY_single  = centers_xy_from_rc(TCA, idx_single,  env)
    XY_complex = centers_xy_from_rc(TCA, idx_complex, env)
    XY_union   = centers_xy_from_rc(TCA, idx_union,   env)

    pf_dict = predictiveNet.TrainingSaver.place_fields[0]

    pf_list = [pf_dict[int(i)] for i in idx_complex]   # int() for safety
    peaks_rc_by_cell = identify_peaks(
        pf_list,  threshold=0.6, neighborhood=1, sigma=0.8, min_distance=3
    )

    pf_dict_for_complex = {int(i): pf_dict[int(i)] for i in idx_complex}
    print("peaks_rc_by_cell keys:", list(peaks_rc_by_cell.keys())[:20])
    print("pf_dict_for_complex keys:", list(pf_dict_for_complex.keys())[:20])
    print("intersection:", set(peaks_rc_by_cell) & set(pf_dict_for_complex))
    peaks_xy_by_cell = peaks_to_xy_centers(peaks_rc_by_cell, pf_dict_for_complex, env)

    peaks_list_xy = [xy for lst in peaks_xy_by_cell.values() for xy in lst]
    XY_complex_mult_peaks_arr = (
        np.asarray(peaks_list_xy, dtype=float).reshape(-1, 2) if len(peaks_list_xy) else np.zeros((0, 2))
    )


    # cells list: union with type per cell
    single_set  = set(idx_single.tolist())
    complex_set = set(idx_complex.tolist())
    cells = []
    for i, xy in zip(idx_union.tolist(), XY_union.tolist()):
        ctype = "single" if i in single_set else "complex"
        cells.append({"idx": int(i), "center": [float(xy[0]), float(xy[1])], "type": ctype})

    complex_peaks = []
    for i in idx_complex:
        # Option A: store row/col indices
        peaks_rc = peaks_rc_by_cell.get(int(i), [])
        complex_peaks.append({"idx": int(i), "peaks": [(int(r), int(c)) for r, c in peaks_rc]})

    # bins (per selection)
    bins = {
        "single":         build_bins(XY_single),
        "single+complex": build_bins(np.vstack([XY_single, XY_complex]) if len(XY_complex) else XY_single),
        "single+complex_mult_peaks": build_bins(
            np.vstack([XY_single, XY_complex_mult_peaks_arr]) if len(XY_complex_mult_peaks_arr) else XY_single
        ),
    }

    # circles (per selection)
    rewards_xy = getattr(env.Reward, "place_cell_centres", None)
    circles = {
        "single":         build_circles(XY_single, rewards_xy, radius=0.05),
        "single+complex": build_circles(
            np.vstack([XY_single, XY_complex]) if len(XY_complex) else XY_single, rewards_xy, radius=0.05
        ),
        "single+complex_mult_peaks": build_circles(
            np.vstack([XY_single, XY_complex_mult_peaks_arr]) if len(XY_complex_mult_peaks_arr) else XY_single,
            rewards_xy, radius=0.05
        ),
    }

    rows.append({
        "netname": netname,
        "bins": bins,
        "cells": cells,
        "circles": circles,
        "complex_peaks": complex_peaks
    })

Net Loaded from pathname
Calculating EV_s
Running WAKE


  border_score = (mean_border_rate - mean_nonborder_rate) / (mean_border_rate + mean_nonborder_rate)


peaks_rc_by_cell keys: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
pf_dict_for_complex keys: [3, 4, 7, 9, 10, 11, 13, 14, 15, 18, 20, 21, 23, 24, 25, 27, 30, 32, 33, 38]
intersection: {3, 4, 7, 9, 10, 11, 13, 14, 15, 18, 20, 21, 23, 24, 25, 27, 30, 32, 33, 38, 40, 42, 46, 53, 57, 64, 68, 69, 70, 71, 73, 74, 75, 79, 80, 81, 83, 85, 86, 88, 92, 98, 102, 103, 107, 109, 111, 112, 113, 116, 118, 121, 122, 124, 127, 132, 136, 137, 138, 140, 142, 143, 144, 145, 147, 149, 153, 159, 160, 166, 167, 168, 169, 170, 173, 180, 185, 187, 194, 196, 198, 201, 202, 205, 206, 213, 214, 215, 216, 217, 219, 221, 224, 225, 226, 227, 228, 234, 237, 240, 242, 243, 244, 245, 246, 251, 253, 254, 255, 262, 263, 266, 271, 273, 275, 277, 280, 281, 282, 283, 284, 288, 290, 291, 292, 295, 298, 301, 302, 304}
Net Loaded from pathname
Calculating EV_s
Running WAKE
peaks_rc_by_cell keys: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
pf_dict_for_complex keys: [3, 4, 5,

  fxfr = fx / fr
  SI = SI / fr[:, 0, 0]
  fxfr = fx / fr


Calculating EV_s
Running WAKE


  border_score = (mean_border_rate - mean_nonborder_rate) / (mean_border_rate + mean_nonborder_rate)


peaks_rc_by_cell keys: []
pf_dict_for_complex keys: []
intersection: set()


  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled == whicharea
  centerlabeled = labeled

StopIteration: 

In [None]:
#Saving
os.makedirs(savefolder, exist_ok=True)
df = pd.DataFrame(rows).sort_values("netname").reset_index(drop=True)

out_pkl = os.path.join(savefolder, "new_multiply-2002_new.pkl")
df.to_pickle(out_pkl)
print("Saved:", out_pkl)

Saved: Data/hadrien_analyzed_nets/GD_no_reward/GD_noRew_1042.pkl


In [22]:
os.chdir('/Users/hadrienpadilla/Documents/McGill/Peyrache Lab/pRNN')
savefolder = 'Data/hadrien_analyzed_nets/GD_rew_mults/'
netfolder = 'for_Hadrien/GD_reward_repeats/1042/'
netfiles = glob.glob(os.path.join('nets', netfolder, '*.dill'))
netnames = [os.path.splitext(os.path.basename(f))[0] for f in netfiles]
rows = []
print(netnames)

rows = []  

netnames = [os.path.splitext(os.path.basename(f))[0] for f in netfiles]

for netname in netnames:
    predictiveNet = PredictiveNet.loadNet(netfolder+netname, dil=True)
    env = predictiveNet.EnvLibrary[0]         
    predictiveNet.env_shell = env
    agent = RatInABoxAgent('Vis')
    env.wanb_log = False
    TCA = TuningCurveAnalysis(predictiveNet)

    # selections
    idx_single  = np.flatnonzero(TCA.groupID == 2)
    idx_complex = np.flatnonzero((TCA.groupID == 5) & (TCA.metrics['EVs'] >= 0.5))
    idx_union   = np.unique(np.concatenate([idx_single, idx_complex]))

    # centers
    XY_single  = centers_xy_from_rc(TCA, idx_single,  env)
    XY_complex = centers_xy_from_rc(TCA, idx_complex, env)
    XY_union   = centers_xy_from_rc(TCA, idx_union,   env)

    pf_dict = predictiveNet.TrainingSaver.place_fields[0]

    pf_list = [pf_dict[int(i)] for i in idx_complex]   # int() for safety
    peaks_rc_by_cell = identify_peaks(
        pf_list, cell_ids=idx_complex, threshold=0.6, neighborhood=1, sigma=0.8, min_distance=3
    )

    pf_dict_for_complex = {int(i): pf_dict[int(i)] for i in idx_complex}
    peaks_xy_by_cell = peaks_to_xy_centers(peaks_rc_by_cell, pf_dict_for_complex, env)

    peaks_list_xy = [xy for lst in peaks_xy_by_cell.values() for xy in lst]
    XY_complex_mult_peaks_arr = (
        np.asarray(peaks_list_xy, dtype=float).reshape(-1, 2) if len(peaks_list_xy) else np.zeros((0, 2))
    )


    # cells list: union with type per cell
    single_set  = set(idx_single.tolist())
    complex_set = set(idx_complex.tolist())
    cells = []
    for i, xy in zip(idx_union.tolist(), XY_union.tolist()):
        ctype = "single" if i in single_set else "complex"
        cells.append({"idx": int(i), "center": [float(xy[0]), float(xy[1])], "type": ctype})

    complex_peaks = []
    for i in idx_complex:
        # Option A: store row/col indices
        peaks_rc = peaks_rc_by_cell.get(int(i), [])
        complex_peaks.append({"idx": int(i), "peaks": [(int(r), int(c)) for r, c in peaks_rc]})

    # bins (per selection)
    bins = {
        "single":         build_bins(XY_single),
        "single+complex": build_bins(np.vstack([XY_single, XY_complex]) if len(XY_complex) else XY_single),
        "single+complex_mult_peaks": build_bins(
            np.vstack([XY_single, XY_complex_mult_peaks_arr]) if len(XY_complex_mult_peaks_arr) else XY_single
        ),
    }

    # circles (per selection)
    rewards_xy = getattr(env.Reward, "place_cell_centres", None)
    circles = {
        "single":         build_circles(XY_single, rewards_xy, radius=0.05),
        "single+complex": build_circles(
            np.vstack([XY_single, XY_complex]) if len(XY_complex) else XY_single, rewards_xy, radius=0.05
        ),
        "single+complex_mult_peaks": build_circles(
            np.vstack([XY_single, XY_complex_mult_peaks_arr]) if len(XY_complex_mult_peaks_arr) else XY_single,
            rewards_xy, radius=0.05
        ),
    }

    rows.append({
        "netname": netname,
        "bins": bins,
        "cells": cells,
        "circles": circles,
        "complex_peaks": complex_peaks
    })

#Saving
os.makedirs(savefolder, exist_ok=True)
df = pd.DataFrame(rows).sort_values("netname").reset_index(drop=True)

out_pkl = os.path.join(savefolder, "GD_reward_mult_1042_new.pkl")
df.to_pickle(out_pkl)
print("Saved:", out_pkl)


['multRNN_5win_i2_o23-_GD_repeat_150_ep5-s1042_ep5-cpu', 'multRNN_5win_i2_o23-_GD_repeat_75_ep5-s1042_ep5-cpu', 'multRNN_5win_i2_o23-_GD_repeat_100_ep5-s1042_ep5-cpu', 'multRNN_5win_i2_o23-_GD_repeat_5_ep5-s1042_ep5-cpu', 'multRNN_5win_i2_o23-_GD_repeat_10_ep5-s1042_ep5-cpu', 'multRNN_5win_i2_o23-_GD_repeat_25_ep5-s1042_ep5-cpu', 'multRNN_5win_i2_o23-_GD_repeat_50_ep5-s1042_ep5-cpu']
Net Loaded from pathname


KeyboardInterrupt: 

# Add complex peaks

In [25]:
os.chdir('/Users/hadrienpadilla/Documents/McGill/Peyrache Lab/pRNN')
savefolder = 'Data/hadrien_analyzed_nets/GD_reward_repeats/'
net_root   = 'nets'
net_folder = 'for_Hadrien/GD_reward_repeats/1042/'  # <- your folder
pattern    = os.path.join(net_root, net_folder, '*.pkl')

def repeat_key(name): return int(re.search(r'repeat_(\d+)', name).group(1))

netfiles = sorted(glob.glob(pattern), key=lambda p: repeat_key(os.path.basename(p)))
netnames  = [os.path.splitext(os.path.basename(p))[0] for p in netfiles]

analysis = pd.read_pickle('Data/hadrien_analyzed_nets/GD_reward_repeats/GD_reward_repeats_1042.pkl')
print(analysis)
analysis = analysis.sort_values(
    'netname',
    key=lambda s: s.str.extract(r'repeat_(\d+)').astype(int)[0],
    kind='stable'
).reset_index(drop=True)

rows = []

                                             netname  \
0  multRNN_5win_i2_o23-_GD_repeat_100_ep5-s1042_e...   
1  multRNN_5win_i2_o23-_GD_repeat_150_ep5-s1042_e...   
2  multRNN_5win_i2_o23-_GD_repeat_5_ep5-s1042_ep5...   
3  multRNN_5win_i2_o23-_GD_repeat_75_ep5-s1042_ep...   
4  multRNN_5win_i2_o23-_GD_repeat_10_ep5-s1042_ep...   
5  multRNN_5win_i2_o23-_GD_repeat_25_ep5-s1042_ep...   
6  multRNN_5win_i2_o23-_GD_repeat_50_ep5-s1042_ep...   

                                                bins  \
0  {'single': {'counts': [8, 15, 17, 23, 19, 12],...   
1  {'single': {'counts': [9, 18, 20, 21, 12, 12],...   
2  {'single': {'counts': [5, 19, 16, 22, 16, 14],...   
3  {'single': {'counts': [9, 14, 19, 23, 17, 12],...   
4  {'single': {'counts': [5, 15, 19, 22, 13, 15],...   
5  {'single': {'counts': [6, 15, 19, 23, 14, 16],...   
6  {'single': {'counts': [7, 16, 23, 24, 15, 15],...   

                                               cells  \
0  [{'idx': 1, 'center': [0.59375, 0.656254020

In [26]:
index = 0
for netname in netnames:
    predictiveNet = PredictiveNet.loadNet(net_folder+netname)
    env = predictiveNet.EnvLibrary[0]         
    predictiveNet.env_shell = env
    agent = RatInABoxAgent('Vis')
    env.wanb_log = False

    curr_analysis = analysis.iloc[index]

    bins = curr_analysis.bins
    cells = curr_analysis.cells
    circles = curr_analysis.circles


    single_pos = [i for i, d in enumerate(cells) if isinstance(d, dict) and d.get('type') == 'single']
    XY_single = [d['center'] for d in cells if isinstance(d, dict) and d.get('type') == 'single']
    
    complex_pos = [i for i, d in enumerate(cells) if isinstance(d, dict) and d.get('type') == 'complex']
    complex_ids = [d['idx'] for d in cells if isinstance(d, dict) and d.get('type') == 'complex']

    pf_dict = predictiveNet.TrainingSaver.place_fields[0]
    env_obj = predictiveNet.EnvLibrary[0]
    
    ids = [cid for cid in complex_ids if cid in pf_dict]
    complex_pfs = [pf_dict[cid] for cid in ids]
    peaks1 = identify_peaks(complex_pfs, threshold=0.6, cell_ids = ids)
    peaks_list = [xy for lst in peaks1.values() for xy in lst]
    XY_complex_mult_peaks_arr = np.array(peaks_list, dtype=float).reshape(-1, 2)

    complex_peaks = []
    for i in complex_ids:
        peaks = peaks1.get(i, [])
        complex_peaks.append({"idx": int(i), "peaks": [(int(r), int(c)) for r,c in peaks]})

        # bins (per selection)
    bins_new = {
        "single":          bins['single'],
        "single+complex":  bins['single+complex'],
        "single+complex_mult_peaks": build_bins(np.vstack([XY_single, XY_complex_mult_peaks_arr]) if len(XY_complex_mult_peaks_arr) else XY_single)
    }

    rewards_xy = getattr(env.Reward, "place_cell_centres", None)
    circles_new = {
        "single":          circles['single'],
        "single+complex":  circles['single+complex'],
        "single+complex_mult_peaks": build_circles(np.vstack([XY_single, XY_complex_mult_peaks_arr]) if len(XY_complex_mult_peaks_arr) else XY_single,
                                         rewards_xy, radius=0.05)
    }

    rows.append ({
        "netname": netname,
        "bins": bins_new,
        "cells": cells,
        "circles": circles_new,
        "complex_peaks": complex_peaks
    })
    index += 1

In [27]:
df = pd.DataFrame(rows)
print("n_rows:", len(df))
print("columns:", list(df.columns))
print(df.head(3))


n_rows: 0
columns: []
Empty DataFrame
Columns: []
Index: []


In [22]:
os.makedirs(savefolder, exist_ok=True)
df = pd.DataFrame(rows).sort_values("netname").reset_index(drop=True)

out_pkl = os.path.join(savefolder, "new_repeat-1042-new.pkl")
df.to_pickle(out_pkl)
print("Saved:", out_pkl)

KeyError: 'netname'

# Plotting Functions

In [None]:
#function to visualize bins when plotting
def overlay_angular_bins(center=(0.6, 0.6), radius=0.6, n_slices=6, **line_kwargs):
    ax = plt.gca()
    for k in range(n_slices):
        ang = k * (2*np.pi/n_slices)
        end = center + radius * np.array([np.cos(ang), np.sin(ang)])
        ax.plot([center[0], end[0]], [center[1], end[1]],
                **{"color": "k", "linewidth": 1, **line_kwargs})
        

In [31]:
def _normalize_metric(metric: str) -> str:
    m = (metric or "").strip().lower()
    return "pct" if m in {"pct", "percent", "percentage", "%"} else "totals"

def short_label(full: str) -> str:
    s = full.lower()
    if re.search(r'no[-_]?reward', s):
        return 'no_reward'
    m = re.search(r'(repeat|multiply|multiple)[^0-9]*([0-9]+)', s)
    if m:
        kind, num = m.group(1), int(m.group(2))
        return f"{'r' if kind=='repeat' else 'm'}{num}"
    return full

def order_by_repeat(labels_iterable):
    """
    Build x-order: no_reward, then r* ascending, then m* ascending.
    labels_iterable can be any iterable of short labels.
    """
    labels = set(labels_iterable)
    has_nr = "no_reward" in labels
    r_nums = sorted({int(x[1:]) for x in labels if re.fullmatch(r"r\d+", x)})
    m_nums = sorted({int(x[1:]) for x in labels if re.fullmatch(r"m\d+", x)})
    return (["no_reward"] if has_nr else []) + [f"r{n}" for n in r_nums] + [f"m{n}" for n in m_nums]

In [32]:
#   TABLE BUILD
def build_long_table(
    dfs_by_seed,
    *,
    source="circles",           # "circles" or "bins"
    cat="single+complex",
    metric="pct",               # "pct" or "totals"
    include_base=False,
    base_df=None,
    base_label="base"
):
    if not dfs_by_seed:
        raise ValueError("dfs_by_seed is empty")
    if source not in EXTRACTORS:
        raise ValueError(f"source must be one of {list(EXTRACTORS)}, got {source!r}")
    metric = _normalize_metric(metric)
    extractor = EXTRACTORS[source]

    # Build x-axis order from ALL seeds (robust to mixed r*/m*/no_reward across seeds)
    labels_all = []
    for df_ in dfs_by_seed:
        if not {"netname", "bins", "circles"}.issubset(df_.columns):
            raise KeyError("Each dataframe must have columns: 'netname', 'bins', 'circles'")
        labels_all.extend(short_label(s) for s in df_["netname"])
    short_order = order_by_repeat(labels_all)

    rows = []
    for s_idx, df_seed in enumerate(dfs_by_seed, start=1):
        totals, pct = extractor(df_seed, cat)
        vals = pct if metric == "pct" else totals
        for cond, val in zip(df_seed["netname"], vals):
            rows.append({
                "condition": cond,
                "cond_short": short_label(cond),
                "full_name": cond,
                "value": val,
                "seed": f"seed_{s_idx}",
            })

        if include_base:
            if base_df is None or len(base_df) != 1:
                raise ValueError("include_base=True requires base_df with exactly one row.")
            b_totals, b_pct = extractor(base_df, cat)
            b_val = b_pct[0] if metric == "pct" else b_totals[0]
            rows.append({
                "condition": base_label,
                "cond_short": base_label,
                "full_name": "BASE: " + base_label,
                "value": b_val,
                "seed": f"seed_{s_idx}",
            })

    data = pd.DataFrame(rows)
    if include_base and base_label not in short_order:
        short_order.append(base_label)
    return data, short_order


In [33]:
def make_box_with_seed_lines(
    dfs_by_seed,
    *,
    source="circles",           # "circles" or "bins"
    cat="single+complex",
    metric="pct",               # "pct" or "totals"
    include_base=False,
    base_df=None,
    base_label="base",
    title=None
):
    data, short_order = build_long_table(
        dfs_by_seed, source=source, cat=cat, metric=metric,
        include_base=include_base, base_df=base_df, base_label=base_label
    )

    metric = _normalize_metric(metric)

    # 1) Box layer (no scatter in this trace to avoid double plotting)
    fig = px.box(
        data, x="cond_short", y="value",
        category_orders={"cond_short": short_order},
        template="plotly_white"
    )
    for tr in fig.data:
        if tr.type == "box":
            tr.boxmean = True
            tr.showlegend = False
            tr.hovertemplate = "%{x}: %{y:.2f}<extra></extra>"

    # 2) Seed lines (points connected across conditions)
    for seed_name, df_seed_points in data.groupby("seed", sort=False):
        # d = (df_seed_points
        #      .set_index("cond_short")
        #      .reindex(short_order)   # align to x-order
        #      .reset_index()
        #      .dropna(subset=["value"]))  # drop missing for this seed
        order = pd.CategoricalDtype(categories=short_order, ordered=True)
        d = (df_seed_points[df_seed_points['cond_short'].isin(short_order)]
            .assign(cond_short=pd.Categorical(df_seed_points['cond_short'], dtype=order))
            .sort_values('cond_short')
            .dropna(subset=['value']))
        fig.add_trace(go.Scatter(
            x=d["cond_short"], y=d["value"],
            mode="lines+markers",
            name=seed_name,
            customdata=d["full_name"],
            hovertemplate="%{x}: %{y:.2f}<br>%{customdata}<extra></extra>",
            line=dict(width=1.5),
            marker=dict(size=7, opacity=0.85)
        ))

    # 3) Titles / axes
    y_title = "Percent" if metric == "pct" else "Total counts"
    src_tag = f" ({source})"
    full_title = title or f"{cat} — {y_title}{src_tag}"
    if metric == "pct":
        y_title = "Percent" if source == "circles" else "Percent (bins-based %)"

    fig.update_layout(
        title=full_title,
        xaxis_title="Condition",
        yaxis_title=y_title,
        boxmode="group",
        legend_title_text="Seed",
        margin=dict(l=50, r=20, t=60, b=60)
    )
    fig.update_xaxes(tickangle=0)
    return fig

# Plots

groups = {
            'untuned':untuned,
            'HD_cells':HD_cells,
            'single_field':single_field,
            'border_cells':border_cells,
            'spatial_HD':spatial_HD,
            'complex_cells':complex_cells
        }

Bins with rewards: 1,2,4

### Base Net

In [11]:
base_net = pd.read_pickle(os.path.join(savefolder, "base_net.pkl"))

### new_zip (seeds 1001-6006)

In [None]:
import pandas as pd
seed_1001 = pd.read_pickle(open('Figures/RiaB/1001_place_fields_per_net.pkl', 'rb'))
seed_2002 = pd.read_pickle(open('Figures/RiaB/2002_place_fields_per_net.pkl', 'rb'))
seed_3003 = pd.read_pickle(open('Figures/RiaB/3003_place_fields_per_net.pkl', 'rb'))
seed_4004 = pd.read_pickle(open('Figures/RiaB/4004_place_fields_per_net.pkl', 'rb'))
seed_5005 = pd.read_pickle(open('Figures/RiaB/5005_place_fields_per_net.pkl', 'rb'))
seed_6006 = pd.read_pickle(open('Figures/RiaB/6006_place_fields_per_net.pkl', 'rb'))
dfs_by_seed = [seed_1001, seed_2002, seed_3003, seed_4004, seed_5005, seed_6006]
fig1 = make_box_with_seed_lines(dfs_by_seed, cat="single", metric="pct")
fig1.show()
fig2 = make_box_with_seed_lines(dfs_by_seed, cat="single+complex", metric="pct")
fig2.show()
fig3 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single", metric="pct")
fig3.show()
fig4 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single+complex", metric="pct")
fig4.show()

### new_repeats (original + seeds 1001-6006)

In [30]:
original = pd.read_pickle(os.path.join(savefolder, "place_fields_per_net.pkl"))
original = original.drop(index=0).reset_index(drop=True)
seed_1001 = pd.read_pickle(open('Figures/RiaB/new_repeat-1001.pkl', 'rb'))
seed_2002 = pd.read_pickle(open('Figures/RiaB/new_repeat-2002.pkl', 'rb'))
seed_3003 = pd.read_pickle(open('Figures/RiaB/new_repeat-3003.pkl', 'rb'))
seed_4004 = pd.read_pickle(open('Figures/RiaB/new_repeat-4004.pkl', 'rb'))
seed_5005 = pd.read_pickle(open('Figures/RiaB/new_repeat-5005.pkl', 'rb'))
seed_6006 = pd.read_pickle(open('Figures/RiaB/new_repeat-6006.pkl', 'rb'))
dfs_by_seed = [original, seed_1001, seed_2002, seed_3003, seed_4004, seed_5005, seed_6006]
fig1 = make_box_with_seed_lines(dfs_by_seed,source = "circles", cat="single", metric="pct")
fig1.show()
fig2 = make_box_with_seed_lines(dfs_by_seed, source = "circles", cat="single+complex", metric="pct")
fig2.show()
fig3 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single", metric="pct")
fig3.show()
fig4 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single+complex", metric="pct")
fig4.show()

### new_multiply (original + seeds 1001-6006)

In [31]:
original = pd.read_pickle(os.path.join(savefolder, "place_fields_per_net(multiply).pkl"))
original = original.drop(index=0).reset_index(drop=True)
seed_1001 = pd.read_pickle(open('Figures/RiaB/new-multiply-1001.pkl', 'rb'))
seed_2002 = pd.read_pickle(open('Figures/RiaB/new-multiply-2002.pkl', 'rb'))
seed_3003 = pd.read_pickle(open('Figures/RiaB/new-multiply-3003.pkl', 'rb'))
seed_4004 = pd.read_pickle(open('Figures/RiaB/new-multiply-4004.pkl', 'rb'))
seed_5005 = pd.read_pickle(open('Figures/RiaB/new-multiply-5005.pkl', 'rb'))
seed_6006 = pd.read_pickle(open('Figures/RiaB/new-multiply-6006.pkl', 'rb'))
dfs_by_seed = [original, seed_1001, seed_2002, seed_3003, seed_4004, seed_5005, seed_6006]
fig1 = make_box_with_seed_lines(dfs_by_seed, source ='circles', cat="single", metric="pct")
fig1.show()
fig2 = make_box_with_seed_lines(dfs_by_seed, source ='circles', cat="single+complex", metric="pct")
fig2.show()
fig3 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single", metric="pct")
fig3.show()
fig4 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single+complex", metric="pct")
fig4.show() 

### GD Multiply 

In [12]:
savefolder = 'Figures/RiaB/'
GD_multiply = pd.read_pickle(os.path.join(savefolder, "GD_multiply.pkl"))
GD_repeat = pd.read_pickle(os.path.join(savefolder, "GD_repeat.pkl"))
dfs_by_seed = [GD_multiply]
fig1 = make_box_with_seed_lines(dfs_by_seed, source = "circles", cat="single", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig1.show()
fig2 = make_box_with_seed_lines(dfs_by_seed, source = "circles", cat="single+complex", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig2.show()
fig3 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig3.show()
fig4 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single+complex", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig4.show()

### GD Repeat

In [20]:
dfs_by_seed = [GD_repeat]
fig1 = make_box_with_seed_lines(dfs_by_seed, source = "circles", cat="single", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig1.show()
fig2 = make_box_with_seed_lines(dfs_by_seed, source = "circles", cat="single+complex", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig2.show()
fig3 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig3.show()
fig4 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single+complex", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig4.show()

### GD_no_reward

In [19]:
GD_no_reward = pd.read_pickle(os.path.join(savefolder, "new-GD-no_rew_2.pkl"))
dfs_by_seed = [GD_no_reward]
fig1 = make_box_with_seed_lines(dfs_by_seed, source = "circles", cat="single", metric="pct", include_base=True, base_df=base_net, base_label="base")

print()

fig1.show()
fig2 = make_box_with_seed_lines(dfs_by_seed, source = "circles", cat="single+complex", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig2.show()
fig3 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig3.show()
fig4 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single+complex", metric="pct", include_base=True, base_df=base_net, base_label="base")
fig4.show()




### GD + reward repeats

In [26]:
print(savefolder)

Data/hadrien_analyzed_nets/GD_reward_repeats/


In [34]:
savefolder = 'Data/hadrien_analyzed_nets/'
base = pd.read_pickle(os.path.join(savefolder, "base/base_net.pkl"))
seed_1001 = pd.read_pickle(os.path.join(savefolder, 'GD_reward_repeats/GD_reward_repeats_1001.pkl'))
seed_2002 = pd.read_pickle(os.path.join(savefolder, 'GD_reward_repeats/GD_reward_repeats_2002.pkl'))
seed_3003 = pd.read_pickle(os.path.join(savefolder, 'GD_reward_repeats/GD_reward_repeats_3003.pkl'))
dfs_by_seed = [base, seed_1001, seed_2002, seed_3003]
fig1 = make_box_with_seed_lines(dfs_by_seed, source ='circles', cat="single", metric="pct", include_base=True, base_df=base, base_label="base")
fig1.show()
fig2 = make_box_with_seed_lines(dfs_by_seed, source ='circles', cat="single+complex", metric="pct", include_base=True, base_df=base, base_label="base")
fig2.show()
fig3 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single", metric="pct", include_base=True, base_df=base, base_label="base")
fig3.show()
fig4 = make_box_with_seed_lines(dfs_by_seed, source = "bins", cat="single+complex", metric="pct", include_base=True, base_df=base, base_label="base")
fig4.show() 

In [None]:
savefolder = 'Data/hadrien_analyzed_nets/fullDF/'
