In [None]:
import sys
import os
from pathlib import Path
import warnings
from tqdm import tqdm
from contextlib import redirect_stdout, redirect_stderr

import numpy as np
import scipy
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append(os.path.abspath('../library/Npxl'))
import data as d
import bayes as b
import utils as u
import results as r
import figures as figs
import pipeline as p

# Silence every RuntimeWarning for the rest of the notebook session.
warnings.filterwarnings("ignore", category=RuntimeWarning)

## Regular Position Decoding


Data objects are arranged hierarchically:
- `d.Data` is the lowest level which corresponds to the data from a single `area` in a single session (`sesh`).
- `d.Sesh` is one level above and contains the session-level information
- `d.Mouse` is the highest level object for each mouse

`d.Data` can be initialised by information from `d.Sesh`, which can be initialised from the information in `d.Mouse`.

<br>

`p.decode_positions` is the main decoding function which takes in:
- `data` is the data object which typically is `d.Data`
- `x` is the number of position bins at the start of the tunnel to crop (default `5`)
- `num_pbins` is the number of position bins in the tunnel (default `50`)
- `num_pbins_decode` is the number of position bins excluding the reward zone (default `46`)
- `rel_pos` is the option to decode relative position (default `False`)
- `grouping` is the option of grouping by trial start locations (default `True`)

Decoding is done in the paradigms: `['lgtlgt', 'drkdrk', 'lgtdrk', 'drklgt']`.

<br>

`p.get_decoding_results` takes `data`, `num_pbins` (here the number of position bins to plot in the confusion matrices; default `46`, the same as `num_pbins_decode` above), and `rel_pos` (default `False`) as inputs. It returns `results`, a nested `dict` keyed by result type on the first level and by paradigm on the second level, for example:
- `results['confusion_mtx']['lgtlgt']` has shape (46, 46)
- `results['mean_accuracy']['lgtlgt']` is a single float
- `results['mean_error']['lgtlgt']` is a single float

Result types:
- `'confusion_mtx'`: A 46 by 46 matrix. Vertical axis is True position, Horizontal axis is Decoded position.
- `'mean_accuracy'`: An average across trials of mean accuracy per trial.
- `'mean_error'`: An average across trials of mean error per trial.
- `'median_error'`: An average across trials of median error per trial.
- `'rt_mse'`: An average across trials of root mean squared error per trial.
- `'mean_wt_error'`: An average across trials of weighted error by posterior per trial.
- `'MostFreqPred_error'`: An average across position bins of errors between true position and most frequently decoded position.

<br>

`p.save_and_plot` has the following options:
- `rel_pos` (default `False`)
- `saveoutput` to save the `POSTERIOR`, `DECODED_POS`, and `results` (default `True`)
- `savecsv` to save the `results` (except confusion_mtx) in a table (default `True`)
- `savefig` to save the `confusion_mtx` plots (default `True`)
- `save_dir` the directory to save the above in

In [None]:
# Build one d.Mouse object per animal ID, in the order the IDs are listed.
mice = [
    d.Mouse(mouse_id)
    for mouse_id in ('VDCN01', 'VDCN02', 'VDCN04', 'VDCN05', 'VDCN09')
]

# Also expose each mouse under its own name so a cell can select one directly.
VDCN01, VDCN02, VDCN04, VDCN05, VDCN09 = mice

In [None]:
# Decode position for every area of every session of the selected mouse,
# then save the outputs, the results table and the confusion-matrix figures.
mouse = VDCN04
print(mouse.dates)

for date in mouse.dates:
    # Session object for this date; the reward zone is passed as the list 46..61.
    sesh = d.Sesh.from_mouse(
        mouse=mouse,
        date=date,
        tau=0.2,
        rewardzone=list(range(46, 62)),
    )
    print(sesh.areas)

    for area in sesh.areas:
        # Data for this (session, area) pair
        data = d.Data.from_sesh(sesh, area)

        # Decode with the default arguments, then compute the result metrics
        data.POSTERIOR, data.DECODED_POS = p.decode_positions(data)
        data.results = p.get_decoding_results(data)

        # Output folder: ../results/<mouse>/<date>/regular/<area>
        save_dir = Path(f"../results/{data.mouse_ID}/{data.date}/regular/{area}")
        save_dir.mkdir(parents=True, exist_ok=True)

        p.save_and_plot(data, rel_pos=False, saveoutput=True,
                        savecsv=True, savefig=True, save_dir=save_dir)

## Chance Regular Position Decoding

Similar to regular position decoding, but with one extra data object, `d.Shuffled`. It is initialised from `d.Sesh` and contains the shuffled data. We treat each repetition of the shuffled data as a `d.Data` object and follow the regular decoding pipeline.

In [None]:
# Session date per mouse ID. It has no entry for 'VDCN04', which is why the
# date below is written out instead of being looked up as sessions[mouse_ID].
sessions = dict(
    VDCN01='20240801',
    VDCN02='20240730',
    VDCN05='20240724',
)

mouse_ID = 'VDCN04'
date = '20240807'
mouse = d.Mouse(mouse_ID)

# Session object for the chosen mouse and date
sesh = d.Sesh.from_mouse(
    mouse=mouse,
    date=date,
    tau=0.2,
    rewardzone=list(range(46, 62)),
)

# Shuffled data built from that session
shuffled_data = d.Shuffled.from_sesh(sesh)

#### Option 1: Decode and save outputs

In [None]:
# Per-repetition outputs are collected on the shuffled-data object itself.
shuffled_data.POSTERIOR, shuffled_data.DECODED_POS, shuffled_data.results = [], [], []

# The null device is opened once; stdout/stderr are redirected to it only while a
# repetition is being decoded, so the progress bar created outside is untouched.
with open(os.devnull, "w") as fnull:
    for rep in tqdm(range(shuffled_data.num_reps), desc="Decoding shuffled data"):
        with redirect_stdout(fnull), redirect_stderr(fnull):
            # Each shuffle repetition goes through the regular decoding pipeline
            data = d.Data.from_shuffled(shuffled_data, rep)
            data.POSTERIOR, data.DECODED_POS = p.decode_positions(data)
            data.results = p.get_decoding_results(data)

            shuffled_data.POSTERIOR.append(data.POSTERIOR)
            shuffled_data.DECODED_POS.append(data.DECODED_POS)
            shuffled_data.results.append(data.results)
        # Drop the per-repetition object before the next one is created
        del data

# Write all repetitions into a single .npz archive
save_dir = f"../results/{shuffled_data.mouse_ID}/{shuffled_data.date}/shuffled"
os.makedirs(save_dir, exist_ok=True)
np.savez(
    f"{save_dir}/shuffled_decoding_output.npz",
    posterior=shuffled_data.POSTERIOR,
    decoded_pos=shuffled_data.DECODED_POS,
    results=shuffled_data.results,
)

#### Option 2: Start from saved outputs

In [None]:
# Reload the shuffled decoding outputs from the .npz archive in the shuffled
# results folder and attach them to `shuffled_data`.
save_dir = f"../results/{shuffled_data.mouse_ID}/{shuffled_data.date}/shuffled"

# NOTE: allow_pickle=True unpickles arbitrary Python objects, so only load
# archives that were written by this notebook.
# The archive is opened as a context manager so its file handle is released;
# each array is read when indexed, so all reads happen before it closes.
with np.load(f"{save_dir}/shuffled_decoding_output.npz", allow_pickle=True) as shuffled_output:
    shuffled_data.POSTERIOR = shuffled_output['posterior']
    shuffled_data.DECODED_POS = shuffled_output['decoded_pos']
    shuffled_data.results = shuffled_output['results']

### Results and plots

In [None]:
# Average the shuffled decoding results across repetitions, show them as a
# table (one row per paradigm) and save one mean confusion-matrix figure each.
paradigms = ['lgtlgt', 'drkdrk', 'lgtdrk', 'drklgt']
# NOTE(review): the result-type list in the markdown above names
# 'mean_wt_error' where this list has 'mean_weighted_error' - confirm which key
# p.get_decoding_results actually emits.
metrics = ['confusion_mtx', 'mean_accuracy', 'mean_error', 'median_error',
           'rt_mse', 'mean_weighted_error', 'MostFreqPred_error']
# The confusion matrix is averaged element-wise; every other metric is a scalar.
scalar_metrics = [metric for metric in metrics if metric != 'confusion_mtx']

# Dictionary to hold mean results: paradigm -> metric -> mean over repetitions
mean_results = {paradigm: {} for paradigm in paradigms}

for paradigm in paradigms:
    # --- Confusion matrix (46×46) ---
    confusion_stack = np.stack(
        [shuffled_data.results[rep]['confusion_mtx'][paradigm]
         for rep in range(shuffled_data.num_reps)],
        axis=0
    )  # shape: (num_reps, 46, 46)

    # NaN-aware mean over the repetition axis
    mean_confusion = np.nanmean(confusion_stack, axis=0)
    mean_results[paradigm]['confusion_mtx'] = mean_confusion

    # --- Scalar metrics ---
    for metric in scalar_metrics:
        vals = [
            shuffled_data.results[rep][metric][paradigm]
            for rep in range(shuffled_data.num_reps)
        ]
        mean_results[paradigm][metric] = np.nanmean(vals)

# --- Build final DataFrame (confusion matrices stored as array objects) ---
df_mean = pd.DataFrame(mean_results).T  # transpose so paradigms are rows

display(df_mean)

# The output folder is the same for every paradigm, so it is built once and
# created if missing before any figure is written.
save_dir = f"../results/{shuffled_data.mouse_ID}/{shuffled_data.date}/shuffled"
os.makedirs(save_dir, exist_ok=True)

for paradigm, row in df_mean.iterrows():
    confusion_mtx = row['confusion_mtx']
    filepath = f"{save_dir}/{shuffled_data.mouse_ID}_{shuffled_data.date}_200ms_shuffled_confusion_mtx_{paradigm}.png"
    figs.plot_confusion_mtx(confusion_mtx, paradigm, save=True, filepath=filepath)


## Relative Position Decoding

In essence the same as regular position decoding except that:
- `rel_pos` is set to `True`
- `num_relbins` (the number of relative position bins) is passed as both `num_pbins` and `num_pbins_decode`
- `grouping` in decoding is set to `False`

In [None]:
# Build one d.Mouse object per animal ID, in the order the IDs are listed.
mice = [
    d.Mouse(mouse_id)
    for mouse_id in ('VDCN01', 'VDCN02', 'VDCN04', 'VDCN05', 'VDCN09')
]

# Also expose each mouse under its own name so a cell can select one directly.
VDCN01, VDCN02, VDCN04, VDCN05, VDCN09 = mice

In [None]:
# Relative-position decoding for every area of every session of one mouse.
# `num_relbins` is used for both bin-count arguments of the decoder, for the
# results computation and in the output folder name.
num_relbins = 50

mouse = VDCN04
print(mouse.dates)

for date in mouse.dates:
    # Session object for this date; the reward zone is passed as the list 46..61.
    sesh = d.Sesh.from_mouse(
        mouse=mouse,
        date=date,
        tau=0.2,
        rewardzone=list(range(46, 62)),
    )
    print(sesh.areas)

    for area in sesh.areas:
        # Data for this (session, area) pair
        data = d.Data.from_sesh(sesh, area)

        # Decode relative position with grouping switched off
        data.POSTERIOR, data.DECODED_POS = p.decode_positions(
            data,
            num_pbins=num_relbins,
            num_pbins_decode=num_relbins,
            rel_pos=True,
            grouping=False,
        )
        data.results = p.get_decoding_results(data, num_pbins=num_relbins, rel_pos=True)

        # Output folder: ../results/<mouse>/<date>/relative_pos/<area>/<N>bins/
        save_dir = Path(f"../results/{data.mouse_ID}/{data.date}/relative_pos/{area}/{num_relbins}bins/")
        save_dir.mkdir(parents=True, exist_ok=True)

        p.save_and_plot(data, rel_pos=True, saveoutput=True,
                        savecsv=True, savefig=True, save_dir=save_dir)