In [2]:
#imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from scipy.ndimage import gaussian_filter, maximum_filter, generate_binary_structure, iterate_structure
from itertools import chain
import random

In [3]:
# Set working directory and save folder.
# The project root defaults to the original author's checkout but can be
# overridden with the PRNN_ROOT environment variable, so the notebook can run
# on another machine without editing this cell. Every relative path used later
# in the notebook is resolved against this directory.
project_root = os.environ.get(
    'PRNN_ROOT',
    '/Users/hadrienpadilla/Documents/McGill/Peyrache Lab/pRNN',
)
os.chdir(project_root)
savefolder = 'Data/hadrien_analyzed_nets/'

In [4]:
def _to_2d(a):
    """Return a 2-D float array from array-like or dict->{array}."""
    if a is None or (isinstance(a, float) and np.isnan(a)):
        raise ValueError("missing")
    if isinstance(a, dict):
        a = a[0] if 0 in a else a[min(a.keys())]
    a = np.asarray(a, dtype=float).squeeze()
    if a.ndim == 2:
        return a
    if a.ndim == 3:
        return a[..., 0] if a.shape[-1] <= 4 else a[0, ...]
    raise ValueError(f"expected 2D array, got {a.shape}")

#normalize values between 0 and 1 
def _normalize(a):
    a = np.asarray(a, dtype=float)
    lo, hi = np.nanmin(a), np.nanmax(a)
    if not np.isfinite(lo) or not np.isfinite(hi) or hi <= lo:
        return np.zeros_like(a)
    return (a - lo) / (hi - lo)

# ---------- peak finding ----------
def _peak_coords(a2d, threshold=0.3, neighborhood=2, min_distance=3, sigma=0.8):
    # Get raw and smoothed arrays
    raw_arr = np.array(a2d)
    smooth_arr = gaussian_filter(a2d, sigma=sigma)
    
    # Create footprint for local maximum detection
    fp = generate_binary_structure(raw_arr.ndim, neighborhood)
    
    # Find peaks separately in raw and smoothed data
    raw_peaks = np.argwhere((raw_arr == maximum_filter(raw_arr, footprint=fp)) & 
                           (raw_arr >= threshold))
    smooth_peaks = np.argwhere((smooth_arr == maximum_filter(smooth_arr, footprint=fp)) & 
                              (raw_arr >= threshold))  # Still use raw_arr for threshold
    
    # Combine peaks
    all_peaks = np.unique(np.vstack([raw_peaks, smooth_peaks]), axis=0)
    
    # Sort peaks by raw value (highest first)
    peak_values = raw_arr[all_peaks[:, 0], all_peaks[:, 1]]
    sorted_idx = np.argsort(-peak_values)
    all_peaks = all_peaks[sorted_idx]
    
    # Filter by minimum distance
    filtered_peaks = []
    for peak in all_peaks:
        if not filtered_peaks:
            filtered_peaks.append(peak)
            continue
        
        distances = np.sqrt(np.sum((np.array(filtered_peaks) - peak)**2, axis=1))
        if np.all(distances > min_distance):
            filtered_peaks.append(peak)
    
    return [tuple(p) for p in filtered_peaks]
    


def identify_peaks(cells, *, threshold=0.5, neighborhood=1, normalize=True,
                   cell_ids=None, sigma=1.0, min_distance=5):
    """Run peak detection on every cell and collect the results by cell id.

    Parameters
    ----------
    cells : iterable
        One map per cell, in any form accepted by ``_to_2d``. Entries for which
        ``_to_2d`` raises ``ValueError`` are silently skipped.
    threshold, neighborhood, sigma, min_distance
        Forwarded unchanged to ``_peak_coords``.
    normalize : bool
        When true, each map is passed through ``_normalize`` before peak
        detection.
    cell_ids : sequence or None
        Labels for the cells, matched by position. A cell whose position has no
        label (or when ``cell_ids`` is ``None``) is keyed by its position.

    Returns
    -------
    dict
        Maps each cell label to the list returned by ``_peak_coords``.
    """
    peaks_by_cell = {}
    for pos, cell in enumerate(cells):
        try:
            field = _to_2d(cell)
        except ValueError:
            # Missing or wrongly-shaped entries contribute nothing.
            continue

        if normalize:
            field = _normalize(field)

        has_label = cell_ids is not None and pos < len(cell_ids)
        label = cell_ids[pos] if has_label else pos

        peaks_by_cell[label] = _peak_coords(
            field,
            threshold=threshold,
            neighborhood=neighborhood,
            min_distance=min_distance,
            sigma=sigma,
        )
    return peaks_by_cell

def _grid_xy(env, H, W):
    """Return (H,W,2) array of bin centers (x,y)."""
    e = getattr(env, "env", env)
    coords = getattr(e, "discrete_coords", None)
    if coords is not None and np.asarray(coords).shape[:2] == (H, W):
        return np.asarray(coords)
    xedges, yedges = getattr(e, "xedges", None), getattr(e, "yedges", None)
    if xedges is not None and yedges is not None:
        xc = (np.asarray(xedges)[:-1] + np.asarray(xedges)[1:]) / 2
        yc = (np.asarray(yedges)[:-1] + np.asarray(yedges)[1:]) / 2
        X, Y = np.meshgrid(xc, yc)
        return np.stack([X, Y], -1)
    X, Y = np.meshgrid((np.arange(W) + 0.5) / W, (np.arange(H) + 0.5) / H)
    return np.stack([X, Y], -1)


def peaks_to_xy_centers(peaks_by_cell, pf_dict, env):
    """Convert per-cell peak indices into (x, y) bin-center coordinates.

    Parameters
    ----------
    peaks_by_cell : dict
        Maps a cell id to a list of ``(row, col)`` peak indices.
    pf_dict : dict
        Maps a cell id to its map, in any form accepted by ``_to_2d``. Used to
        discover the grid shape and to decide which cells can be converted.
    env : object
        Environment handed to ``_grid_xy`` to obtain the bin centers.

    Returns
    -------
    dict
        Maps every cell id in ``peaks_by_cell`` to a list of ``(x, y)`` float
        tuples. Cells with no peaks, or with no entry in ``pf_dict``, map to an
        empty list.
    """
    # Discover the grid shape from the first cell that has a stored map.
    # A sentinel default keeps an empty or non-overlapping input from leaking
    # a bare StopIteration to the caller.
    missing = object()
    first = next((pf_dict[k] for k in peaks_by_cell if k in pf_dict), missing)
    if first is missing:
        return {cid: [] for cid in peaks_by_cell}

    H, W = _to_2d(first).shape
    grid = _grid_xy(env, H, W)

    out = {}
    for cid, rc_list in peaks_by_cell.items():
        if cid not in pf_dict or not rc_list:
            out[cid] = []
            continue
        xy = []
        for r, c in rc_list:
            # Peak indices may be numpy integers; grid holds (x, y) at [row, col].
            x, y = grid[int(r), int(c)]
            xy.append((float(x), float(y)))
        out[cid] = xy
    return out


In [None]:
# Load the saved analysis results and the trained network for the base
# condition. Both paths are relative to the current working directory.
# pd.read_pickle unpickles arbitrary Python objects, so only point it at
# files from a trusted source.
base_analysis_path = 'Data/hadrien_analyzed_nets/base/base_net.pkl'
base_net_path = 'nets/for_Hadrien/base/multRNN_5win_i2_o2-no_reward-s1042_ep5-cpu.pkl'

base_analysis = pd.read_pickle(base_analysis_path)
base_net = pd.read_pickle(base_net_path)

In [None]:
# Reload the base-condition analysis and network, then locate the peaks of
# every cell whose record is tagged type == 'complex'.
# NOTE(review): the two read_pickle calls below repeat the previous cell.
base_analysis = pd.read_pickle('Data/hadrien_analyzed_nets/base/base_net.pkl')
# Alias only: curr_analysis and base_analysis refer to the same object.
curr_analysis = base_analysis
base_net = pd.read_pickle('nets/for_Hadrien/base/multRNN_5win_i2_o2-no_reward-s1042_ep5-cpu.pkl')

bins = curr_analysis.bins
# NOTE(review): circles is not used again in this cell.
circles = curr_analysis.circles


# explode -> flat list of per-cell records (dicts are the ones inspected below)
cells_s = base_analysis.cells.explode()
cells_l = cells_s.tolist()
# positions in the flattened list (0..K-1) where type == 'complex'
# NOTE(review): complex_pos is not used again in this cell.
complex_pos = [i for i, d in enumerate(cells_l)
               if isinstance(d, dict) and d.get('type') == 'complex']
# the external IDs (the records' 'idx' values) for later labeling
complex_ids = [d['idx'] for d in cells_l
               if isinstance(d, dict) and d.get('type') == 'complex']
pf_dict = base_net.TrainingSaver.place_fields[0]
env_obj = base_net.EnvLibrary[0]
# Keep only the complex-cell ids that have an entry in pf_dict, so ids and
# complex_pfs stay aligned position by position.
ids = [cid for cid in complex_ids if cid in pf_dict]
complex_pfs = [pf_dict[cid] for cid in ids]
# peaks: id -> list of (row, col); xy_by_peak: id -> list of (x, y) bin centers
peaks = identify_peaks(complex_pfs, threshold=0.6, cell_ids = ids)
xy_by_peak = peaks_to_xy_centers(peaks, pf_dict, env_obj)

# NOTE(review): build_bins is not defined anywhere in the visible notebook, so
# this line raises NameError on a fresh kernel. It is also assigned without
# being called; presumably a call such as build_bins(...) using xy_by_peak was
# intended -- confirm the intended arguments and target (bins vs. bins[0]).
bins['single+complex_mult_peaks'] = build_bins


# NOTE(review): bare attribute access with no assignment; looks like leftover
# exploration. Confirm that base_net actually has a `circle` attribute.
base_net.circle

In [17]:
# Reload the saved base-condition analysis and print its stored summaries.
base_analysis = pd.read_pickle('Data/hadrien_analyzed_nets/base/base_net.pkl')

# Bin summary: the whole container first, then its entry at key 0.
bins = base_analysis.bins
print(bins, bins[0], sep='\n')

# Circle summary: entry at key 0 only.
circles = base_analysis.circles
print(circles[0])


0    {'single': {'counts': [12, 18, 19, 24, 32, 40]...
Name: bins, dtype: object
{'single': {'counts': [12, 18, 19, 24, 32, 40], 'fractions': [0.08275862068965517, 0.12413793103448276, 0.1310344827586207, 0.16551724137931034, 0.2206896551724138, 0.27586206896551724]}, 'single+complex': {'counts': [38, 47, 49, 92, 124, 99], 'fractions': [0.08463251670378619, 0.10467706013363029, 0.1091314031180401, 0.20489977728285078, 0.27616926503340755, 0.22048997772828507]}, 'single+complex_mult_peaks': {'counts': [1109, 567, 19, 24, 32, 82], 'fractions': [0.6050190943807965, 0.309328968903437, 0.010365521003818877, 0.01309328968903437, 0.01745771958537916, 0.0447354064375341]}}
{'single': {'counts': [1, 0, 0], 'fractions': [0.006896551724137931, 0.0, 0.0], 'radius': 0.05}, 'single+complex': {'counts': [2, 4, 0], 'fractions': [0.004454342984409799, 0.008908685968819599, 0.0], 'radius': 0.05}, 'single+complex_mult_peaks': {'counts': [1, 0, 0], 'fractions': [0.0005455537370430987, 0.0, 0.0], 'radius':

In [None]:
analysis = 