In [1]:
import sys
import os
from pathlib import Path
import warnings
from tqdm import tqdm
from contextlib import redirect_stdout, redirect_stderr

import numpy as np
import scipy
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Add the local Npxl library folder to the import path. The relative path is
# turned into an absolute one against the current working directory, so the
# notebook presumably has to be run from its own folder — TODO confirm.
sys.path.append(os.path.abspath('../library/Npxl'))
# Project modules, presumably found in the Npxl folder added just above.
import data as d
import bayes as b
import utils as u
import results as r
import figures as figs
import pipeline as p

# Ignore warnings of category RuntimeWarning from this point on.
warnings.filterwarnings("ignore", category=RuntimeWarning)

## Regular Position Decoding


Data objects are arranged hierarchically:
- `d.Data` is the lowest level; it corresponds to the data from a single `area` in a single session (`sesh`).
- `d.Sesh` is one level above and contains information about the session.
- `d.Mouse` is the highest-level object, one per mouse.

`d.Data` can be initialised by information from `d.Sesh`, which can be initialised from the information in `d.Mouse`.

<br>

`p.decode_positions` is the main decoding function which takes in:
- `data` is the data object which typically is `d.Data`
- `x` is the number of position bins at the start of the tunnel to be cropped (default `5`)
- `num_pbins` is the number of position bins in the tunnel (default `50`)
- `num_pbins_decode` is the number of position bins excluding the reward zone (default `46`)
- `rel_pos` is the option to decode relative position (default `False`)
- `grouping` is the option of grouping by trial start locations (default `True`)

Decoding is done in the paradigms: `['lgtlgt', 'drkdrk', 'lgtdrk', 'drklgt']`.

<br>

`p.get_decoding_results` takes in `data`, `num_pbins` (here it refers to the number of position bins to plot in the confusion matrices, which is default `46`, the same as `num_pbins_decode` previously), and `rel_pos` (default `False`) as inputs. It outputs a `results` which is a nested `dict` with result type on the first level and paradigms on the second level, examples:
- `results['confusion_mtx']['lgtlgt']` has shape (46, 46)
- `results['mean_accuracy']['lgtlgt']` is a single float
- `results['mean_error']['lgtlgt']` is a single float

Result types:
- `'confusion_mtx'`: A 46 by 46 matrix. Vertical axis is True position, Horizontal axis is Decoded position.
- `'mean_accuracy'`: An average across trials of mean accuracy per trial.
- `'mean_error'`: An average across trials of mean error per trial.
- `'median_error'`: An average across trials of median error per trial.
- `'rt_mse'`: An average across trials of root mean squared error per trial.
- `'mean_wt_error'`: An average across trials of weighted error by posterior per trial.
- `'MostFreqPred_error'`: An average across position bins of errors between true position and most frequently decoded position.

<br>

`p.save_and_plot` has the following options:
- `rel_pos` (default `False`)
- `saveoutput` to save the `POSTERIOR`, `DECODED_POS`, and `results` (default `True`)
- `savecsv` to save the `results` (except confusion_mtx) in a table (default `True`)
- `savefig` to save the `confusion_mtx` plots (default `True`)
- `save_dir` the directory to save the above in

In [None]:
# Regular position decoding: for every mouse / session date / recorded area,
# load the data, decode position, summarise the decoding and save the
# outputs, tables and confusion-matrix figures.

# Data root, resolved to an absolute path.
data_dir = Path("../datafiles/").resolve()

# One d.Mouse object per animal, built from its ID and the data root.
VDCN01 = d.Mouse('VDCN01', data_dir)
VDCN02 = d.Mouse('VDCN02', data_dir)
VDCN04 = d.Mouse('VDCN04', data_dir)
VDCN05 = d.Mouse('VDCN05', data_dir)
VDCN09 = d.Mouse('VDCN09', data_dir)

# Mice processed by this run; swap in the commented list for the other animals.
# mice = [VDCN01, VDCN02, VDCN04, VDCN05]
mice = [VDCN09]

for mouse in mice:
    for date in mouse.dates:
        # Initialise session object
        sesh = d.Sesh.from_mouse(
            mouse = mouse,
            date = date,
            tau = 0.25,
            rewardzone = np.arange(46,62).tolist(),
            sesh_dir = Path(mouse.mouse_dir / f"{date}/250ms_50sigma")
        )
        print(sesh.areas)

        for area in sesh.areas:
            try:
                # Load data
                data = d.Data.from_sesh(sesh, area)
                # Run decoding with the pipeline defaults
                data.POSTERIOR, data.DECODED_POS = p.decode_positions(data)
                data.results = p.get_decoding_results(data)

                # Save results and plots.
                # The results folder mirrors the "250ms_50sigma" session folder
                # name used for sesh_dir above.
                save_dir = Path(f"../results/{data.mouse_ID}/{data.date}/250ms_50sigma/regular/{area}")
                save_dir.mkdir(parents=True, exist_ok=True)
                p.save_and_plot(
                    data,
                    rel_pos=False,
                    saveoutput=True,
                    savecsv=True,
                    savefig=True,
                    save_dir=save_dir
                )
            except Exception as e:
                # Best-effort batch run: report the failing mouse/date/area
                # and carry on with the next area instead of aborting.
                print(f"Error processing {mouse.mouse_ID} on {date} in area {area}: {e}")

## Chance Regular Position Decoding

Similar to regular position decoding, but we have one extra data object, `d.Shuffled`. It is initialised from `d.Sesh` and contains the shuffled data. We treat each repetition of the shuffled data as a `d.Data` object and follow the regular decoding pipeline.

In [None]:
# Chance-level (shuffled) regular position decoding: for every mouse /
# session date / area, either decode the shuffled data (mode 1) or reload a
# previously saved decoding (mode 2), then save the tables and figures.

# Data root, resolved to an absolute path.
data_dir = Path("../datafiles/").resolve()

# One d.Mouse object per animal, built from its ID and the data root.
VDCN01 = d.Mouse('VDCN01', data_dir)
VDCN02 = d.Mouse('VDCN02', data_dir)
VDCN04 = d.Mouse('VDCN04', data_dir)
VDCN05 = d.Mouse('VDCN05', data_dir)
VDCN09 = d.Mouse('VDCN09', data_dir)

# mice = [VDCN01, VDCN02, VDCN04, VDCN05]
mice = [VDCN09]

mode = 1  # 1 for decode shuffled data; 2 for load shuffled outputs

# Fail fast on an unsupported mode. Without this check neither branch below
# runs and the plotting step is attempted on data that was never decoded or
# loaded, for every single area.
if mode not in (1, 2):
    raise ValueError(f"mode must be 1 (decode) or 2 (load), got {mode!r}")

for mouse in mice:
    for date in mouse.dates:

        # Initialise session object
        sesh = d.Sesh.from_mouse(
            mouse = mouse,
            date = date,
            tau = 0.25,
            rewardzone = np.arange(46,62).tolist(),
            sesh_dir = Path(mouse.mouse_dir / f"{date}/250ms_50sigma")
        )
        print(sesh.areas)

        for area in sesh.areas:
            try:
                # Load shuffled data
                shuffled_data = d.Shuffled.from_sesh(sesh, area)

                save_dir = Path(f"../results/{shuffled_data.mouse_ID}/{shuffled_data.date}/250ms_50sigma/shuffled/{area}")
                save_dir.mkdir(parents=True, exist_ok=True)

                if mode == 1:
                    # Run decoding on shuffled data and save
                    p.decode_shuffled_data(shuffled_data, save_dir)

                elif mode == 2:
                    # Load previously done shuffled results.
                    # NOTE: allow_pickle=True can execute arbitrary code while
                    # unpickling; only load archives this pipeline wrote itself.
                    # The with-block closes the archive even if a key is missing.
                    with np.load(save_dir / "shuffled_decoding_output.npz", allow_pickle=True) as shuffled_output:
                        shuffled_data.POSTERIOR = shuffled_output['posterior']
                        shuffled_data.DECODED_POS = shuffled_output['decoded_pos']
                        shuffled_data.results = shuffled_output['results']
                    del shuffled_output

                # Save plots for shuffled data
                p.save_and_plot_shuffled(
                    shuffled_data,
                    savecsv=True,
                    savefig=True,
                    save_dir=save_dir
                )
                # Drop the reference before the next area is loaded.
                del shuffled_data

            except Exception as e:
                # Best-effort batch run: report the failing mouse/date/area
                # and carry on with the next area instead of aborting.
                print(f"Error processing {mouse.mouse_ID} on {date} in area {area}: {e}")
                continue

## Chance Relative Position Decoding

In [None]:
# Chance-level (shuffled) relative position decoding: same flow as the
# shuffled regular decoding, but with relpos=True and num_relbins bins, and
# with outputs written to a separate "shuffled_relpos" folder.

# Data root, resolved to an absolute path.
data_dir = Path("../datafiles/").resolve()

# One d.Mouse object per animal, built from its ID and the data root.
VDCN01 = d.Mouse('VDCN01', data_dir)
VDCN02 = d.Mouse('VDCN02', data_dir)
VDCN04 = d.Mouse('VDCN04', data_dir)
VDCN05 = d.Mouse('VDCN05', data_dir)
VDCN09 = d.Mouse('VDCN09', data_dir)

# mice = [VDCN01, VDCN02, VDCN04, VDCN05]
mice = [VDCN02, VDCN04, VDCN05, VDCN09]

# Number of relative-position bins handed to the shuffled decoder.
num_relbins = 50

mode = 1  # 1 for decode shuffled data; 2 for load shuffled outputs

# Fail fast on an unsupported mode. Without this check neither branch below
# runs and the plotting step is attempted on data that was never decoded or
# loaded, for every single area.
if mode not in (1, 2):
    raise ValueError(f"mode must be 1 (decode) or 2 (load), got {mode!r}")

for mouse in mice:
    for date in mouse.dates:

        # Initialise session object
        sesh = d.Sesh.from_mouse(
            mouse = mouse,
            date = date,
            tau = 0.25,
            rewardzone = np.arange(46,62).tolist(),
            sesh_dir = Path(mouse.mouse_dir / f"{date}/250ms_50sigma")
        )
        print(sesh.areas)

        for area in sesh.areas:
            try:
                # Load shuffled data
                shuffled_data = d.Shuffled.from_sesh(sesh, area)

                save_dir = Path(f"../results/{shuffled_data.mouse_ID}/{shuffled_data.date}/250ms_50sigma/shuffled_relpos/{area}")
                save_dir.mkdir(parents=True, exist_ok=True)

                if mode == 1:
                    # Run decoding on shuffled data and save
                    p.decode_shuffled_data(shuffled_data, save_dir, relpos=True, num_relbins=num_relbins)

                elif mode == 2:
                    # Load previously done shuffled results.
                    # NOTE: allow_pickle=True can execute arbitrary code while
                    # unpickling; only load archives this pipeline wrote itself.
                    # The with-block closes the archive even if a key is missing.
                    with np.load(save_dir / "shuffled_decoding_output.npz", allow_pickle=True) as shuffled_output:
                        shuffled_data.POSTERIOR = shuffled_output['posterior']
                        shuffled_data.DECODED_POS = shuffled_output['decoded_pos']
                        shuffled_data.results = shuffled_output['results']
                    del shuffled_output

                # Save plots for shuffled data
                p.save_and_plot_shuffled(
                    shuffled_data,
                    relpos=True,
                    savecsv=True,
                    savefig=True,
                    save_dir=save_dir
                )
                # Drop the reference before the next area is loaded.
                del shuffled_data

            except Exception as e:
                # Best-effort batch run: report the failing mouse/date/area
                # and carry on with the next area instead of aborting.
                print(f"Error processing {mouse.mouse_ID} on {date} in area {area}: {e}")
                continue

## Relative Position Decoding

In essence the same as regular position decoding except that:
- `rel_pos` is set to `True`
- `num_relbins` (the number of relative position bins) is passed as both `num_pbins` and `num_pbins_decode`
- `grouping` in decoding is set to `False`

In [None]:
# Relative position decoding: walk every mouse / session date / area, decode
# with rel_pos=True over num_relbins bins (no grouping), then save outputs.

# Absolute path of the data root shared by all d.Mouse objects below.
data_dir = Path("../datafiles/").resolve()

VDCN01 = d.Mouse('VDCN01', data_dir)
VDCN02 = d.Mouse('VDCN02', data_dir)
VDCN04 = d.Mouse('VDCN04', data_dir)
VDCN05 = d.Mouse('VDCN05', data_dir)
VDCN09 = d.Mouse('VDCN09', data_dir)

# mice = [VDCN01, VDCN02, VDCN04, VDCN05]
mice = [VDCN09]

num_relbins = 50

# Keyword options handed unchanged to every p.decode_positions call.
decode_kwargs = dict(
    num_pbins=num_relbins,
    num_pbins_decode=num_relbins,
    rel_pos=True,
    grouping=False,
)

for mouse in mice:
    for date in mouse.dates:
        # Session folder for this date, then the session object itself.
        sesh_dir = Path(mouse.mouse_dir / f"{date}/250ms_50sigma")
        sesh = d.Sesh.from_mouse(
            mouse=mouse,
            date=date,
            tau=0.25,
            rewardzone=np.arange(46, 62).tolist(),
            sesh_dir=sesh_dir,
        )
        print(sesh.areas)

        for area in sesh.areas:
            try:
                data = d.Data.from_sesh(sesh, area)

                # Decode, attach both outputs to the data object, then summarise.
                posterior, decoded_pos = p.decode_positions(data, **decode_kwargs)
                data.POSTERIOR = posterior
                data.DECODED_POS = decoded_pos
                data.results = p.get_decoding_results(data, num_pbins=num_relbins, rel_pos=True)

                # Output folder for this mouse / date / area; created on demand.
                save_dir = Path(f"../results/{data.mouse_ID}/{data.date}/250ms_50sigma/relative_pos/{area}")
                save_dir.mkdir(parents=True, exist_ok=True)
                p.save_and_plot(data, rel_pos=True, saveoutput=True, savecsv=True, savefig=True, save_dir=save_dir)
            except Exception as e:
                # Report the failure and fall through to the next area.
                print(f"Error processing {mouse.mouse_ID} on {date} in area {area}: {e}")