In [1]:
from astro.load import SessionData, Loader
from astro.transforms import GroupSplitter
from astro.config import Config

In [2]:
# Project configuration from `Config.from_env()`; below, only `paths.data_dir`
# is used (to point the Loader at the data directory).
# NOTE(review): presumably resolved from environment variables, per the method
# name — confirm in astro.config.
paths = Config.from_env()

### Session Data

- A container for the data recorded in a single session
- Has methods for loading all of the session's data
- Has methods for loading data as dicts keyed by group name

In [3]:
# Load the mice and neuron tables once, build a splitter from them that
# excludes the "VEH-VEH" group, and wrap everything in the "ret" session.
loader = Loader(data_dir=paths.data_dir)
df_mice = loader.load_mice()
df_neurons = loader.load_neurons()

group_splitter = GroupSplitter(
    df_mice=df_mice,
    df_neurons=df_neurons,
    excluded_groups=["VEH-VEH"],
)

session_data = SessionData(
    loader,
    session_name="ret",
    group_splitter=group_splitter,
)

In [4]:
# Traces: preview the full-session table, then the traces for the CNO-VEH group.
display(session_data.df_traces.head(2))

trace_dict = session_data.traces_by_group
cno_veh_traces = trace_dict["CNO-VEH"]
display(cno_veh_traces.head(2))

Unnamed: 0,time,1050,1051,1053,1054,1055,1056,1057,1059,106,...,869,871,874,875,876,877,878,879,880,881
0,0.0,-0.411073,15.39712,8.223338,11.41209,1.958277,5.975921,0.776018,5.613267,6.147187,...,2.535664,-0.049567,13.52503,-0.21261,-2.273412,1.355439,2.556248,0.799562,1.940257,-2.880405
1,0.1,1.193376,15.53693,9.570282,10.82667,2.167694,3.99791,0.867376,4.860403,4.00393,...,3.79568,-0.282045,15.99658,-0.670835,-1.421511,-0.190482,0.138527,-2.068296,4.265496,-3.336056


Unnamed: 0,time,1315,1316,1318,1323,1325,1327,1328,1329,1331,...,869,871,874,875,876,877,878,879,880,881
0,0.0,-0.974774,8.514519,1.714646,-1.009016,4.13278,33.13108,18.86387,5.097204,18.60578,...,2.535664,-0.049567,13.52503,-0.21261,-2.273412,1.355439,2.556248,0.799562,1.940257,-2.880405
1,0.1,-2.580639,6.832629,1.288389,2.397229,4.518605,31.63235,18.51209,5.173556,17.77478,...,3.79568,-0.282045,15.99658,-0.670835,-1.421511,-0.190482,0.138527,-2.068296,4.265496,-3.336056


In [5]:
# Block start times for the "CS" block group: all mice first, then the
# per-group dict, previewed for the CNO-VEH group.
block_group = "CS"
display(session_data.df_block_starts(block_group=block_group).head(2))

block_starts_dict = session_data.df_block_starts_by_group(block_group=block_group)
display(block_starts_dict["CNO-VEH"].head(2))

Unnamed: 0,mouse_name,block_name,block_group,start_time
10,AS-Gq-GRIN-25,,CS,180.0
11,AS-Gq-GRIN-25,,CS,240.0


Unnamed: 0,mouse_name,block_name,block_group,start_time
10,AS-Gq-GRIN-25,,CS,180.0
11,AS-Gq-GRIN-25,,CS,240.0


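## Neuron Permutation

Rebuild the session with `permute_neurons=True`, then measure how much the neuron sets of the CNO-VEH and VEH-CNO groups overlap using the Jaccard similarity.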

In [6]:
# Same construction as the earlier setup cell, but the splitter is built with
# permute_neurons=True. Note: this rebinds `loader`, `group_splitter` and
# `session_data`, so earlier cells re-run after this one see the permuted setup.
loader = Loader(data_dir=paths.data_dir)
df_mice = loader.load_mice()
df_neurons = loader.load_neurons()

group_splitter = GroupSplitter(
    df_mice=df_mice,
    df_neurons=df_neurons,
    excluded_groups=["VEH-VEH"],
    permute_neurons=True,
)

session_data = SessionData(
    loader,
    session_name="ret",
    group_splitter=group_splitter,
)

In [7]:
# Compare which neuron columns end up in each group.
# Jaccard similarity = |A ∩ B| / |A ∪ B|: 0 means the two groups share no
# neurons, 1 means they contain exactly the same neurons.
trace_dict = session_data.traces_by_group
display(trace_dict["CNO-VEH"].head(2))

# Every column except "time" is a neuron id.
cno_cols = set(trace_dict["CNO-VEH"].columns) - {"time"}
veh_cols = set(trace_dict["VEH-CNO"].columns) - {"time"}

intersection = cno_cols & veh_cols
union = cno_cols | veh_cols
# If neither group has any neuron columns the ratio is undefined; report NaN
# instead of raising ZeroDivisionError.
score = len(intersection) / len(union) if union else float("nan")
print(f"Jaccard similarity of neurons in groups: {score:.2f}")

Unnamed: 0,time,1051,1053,1054,106,1060,1067,107,1070,1071,...,857,858,860,865,866,874,876,877,878,880
0,0.0,15.39712,8.223338,11.41209,6.147187,2.168524,13.67741,72.85371,35.84431,2.273836,...,0.924617,1.099478,-2.620656,2.545328,2.344234,13.52503,-2.273412,1.355439,2.556248,1.940257
1,0.1,15.53693,9.570282,10.82667,4.00393,3.449356,13.82774,70.85292,36.45652,1.131818,...,1.656075,1.036856,-2.184451,1.813191,-0.016554,15.99658,-1.421511,-0.190482,0.138527,4.265496


Jaccard similarity of neurons in groups: 0.00
