<a href="https://colab.research.google.com/github/IanQS/neuromatch_project/blob/main/steinmetz_modeling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Modeling of the Steinmetz dataset

- Uses the [Neuromatch Load Steinmetz Decisions](https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/main/projects/neurons/load_steinmetz_decisions.ipynb#scrollTo=DJ-jzsE5eLxX) notebook as a base

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import zscore
from sklearn.decomposition import PCA
import concurrent.futures
from multiprocessing import Pool
from typing import Dict, List, Any
from sklearn.utils import shuffle
import pandas as pd

import copy
from sklearn.model_selection import train_test_split

np.random.seed(42)

# !pip install -q ipython-autotime
# %load_ext autotime

In [2]:
# @title Data Downloading And Stacking
import os, requests

fname = []
for j in range(3):
  fname.append('steinmetz_part%d.npz'%j)
url = ["https://osf.io/agvxh/download"]
url.append("https://osf.io/uv3mw/download")
url.append("https://osf.io/ehmw2/download")

for j in range(len(url)):
  if not os.path.isfile(fname[j]):
    try:
      r = requests.get(url[j])
    except requests.ConnectionError:
      print("!!! Failed to download data !!!")
    else:
      if r.status_code != requests.codes.ok:
        print("!!! Failed to download data !!!")
      else:
        with open(fname[j], "wb") as fid:
          fid.write(r.content)

all_ds = np.array([])
for j in range(len(fname)):
  all_ds = np.hstack((all_ds,
                      np.load('steinmetz_part%d.npz'%j,
                              allow_pickle=True)['dat']))

# Dataset Description

(taken and modified from the Neuromatch Load Steinmetz Decisions notebook)

## High-level

`all_ds` contains 39 sessions from 10 mice, with data from Steinmetz et al., 2019. Time bins for all measurements are 10ms, starting 500ms before stimulus onset. The mouse had to determine which side had the higher contrast. For each `curr_ds = all_ds[k]`, the fields below are available. For extra variables, check out the extra notebook and extra data files (LFP, waveforms, and exact non-binned spike times).
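Because each bin is 10ms wide and the first bin starts 500ms before stimulus onset, a bin index maps directly to a time relative to stimulus onset. A minimal sketch, assuming 250 bins per trial (as in the session used below) and a bin size of 0.01 s; the notebook reads the actual value from `curr_ds["bin_size"]`, and `bin_times` is a hypothetical helper name:

```python
import numpy as np

def bin_times(num_bins: int = 250, bin_size: float = 0.01, t_start: float = -0.5) -> np.ndarray:
    """Start time of each bin, in seconds relative to stimulus onset."""
    # Bin k starts at t_start + k * bin_size
    return t_start + np.arange(num_bins) * bin_size

t = bin_times()
# Index of the first bin that starts at or after stimulus onset
onset_idx = int(np.searchsorted(t, 0.0))
print(onset_idx, t[0], t[-1])
```

With these assumptions the first 50 bins cover the pre-stimulus period.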

## Fields Used

* `curr_ds['spks']`: neurons by trials by time bins.    
* `curr_ds['brain_area']`: brain area for each neuron recorded.
* `curr_ds['response']`: which side the response was (`-1`, `0`, `1`). When the right-side stimulus had higher contrast, the correct choice was `-1`. `0` is a no-go response.
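Since `spks` is indexed as neurons × trials × time bins, selecting trials by response is a boolean mask on the second axis. A minimal sketch on synthetic data; the array sizes and random values are made up for illustration, and only the `-1`/`0`/`1` label convention comes from the field descriptions above:

```python
import numpy as np

rng = np.random.default_rng(0)
num_neurons, num_trials, num_bins = 5, 8, 250

# Synthetic stand-ins for curr_ds['spks'] and curr_ds['response']
spks = rng.poisson(0.05, size=(num_neurons, num_trials, num_bins))
response = np.array([-1, 0, 1, 1, -1, 0, 1, -1])

right_trials = spks[:, response < 0]   # response == -1
no_go_trials = spks[:, response == 0]  # no-go
left_trials = spks[:, response > 0]    # response == 1

print(right_trials.shape, no_go_trials.shape, left_trials.shape)
```

Note that a mask such as `response >= 0` also keeps the no-go trials.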

## Fields present (not all are used)

* `curr_ds['mouse_name']`: mouse name
* `curr_ds['date_exp']`: when a session was performed
* `curr_ds['ccf']`: Allen Institute brain atlas coordinates for each neuron.
* `curr_ds['ccf_axes']`: axes names for the Allen CCF.
* `curr_ds['contrast_right']`: contrast level for the right stimulus, which is always contralateral to the recorded brain areas.
* `curr_ds['contrast_left']`: contrast level for left stimulus.
* `curr_ds['gocue']`: when the go cue sound was played.
* `curr_ds['response_time']`: when the response was registered, which has to be after the go cue. The mouse can turn the wheel before the go cue (and nearly always does!), but the stimulus on the screen won't move before the go cue.  
* `curr_ds['feedback_time']`: when feedback was provided.
* `curr_ds['feedback_type']`: if the feedback was positive (`+1`, reward) or negative (`-1`, white noise burst).  
* `curr_ds['wheel']`: turning speed of the wheel that the mouse uses to make a response, sampled at `10ms`.
* `curr_ds['pupil']`: pupil area (noisy, because the pupil is very small) + pupil horizontal and vertical position.
* `curr_ds['face']`: average face motion energy from a video camera.
* `curr_ds['licks']`: lick detections, 0 or 1.   
* `curr_ds['trough_to_peak']`: measures the width of the action potential waveform for each neuron. Widths `<=10` samples are "putative fast spiking neurons".
* `curr_ds['%X%_passive']`: same as above for `X` = {`spks`, `pupil`, `wheel`, `contrast_left`, `contrast_right`}, but for passive trials at the end of the recording, when the mouse was no longer engaged and had stopped making responses.
* `curr_ds['prev_reward']`: time of the feedback (reward/white noise) on the previous trial in relation to the current stimulus time.
* `curr_ds['reaction_time']`: ntrials by 2. First column: reaction time computed from the wheel movement as the first sample above `5` ticks/10ms bin. Second column: direction of the wheel movement (`0` = no move detected).  
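The `trough_to_peak` field can be turned into a per-neuron mask of putative fast-spiking units with the `<=10` samples rule above, and that mask indexes the first (neuron) axis of `spks`. A minimal sketch with made-up waveform widths and an all-zero synthetic spike array:

```python
import numpy as np

# Synthetic stand-in for curr_ds['trough_to_peak'] (one width per neuron, in samples)
trough_to_peak = np.array([8, 12, 10, 25, 6, 15])
spks = np.zeros((len(trough_to_peak), 4, 250))  # neurons x trials x time bins

is_fast_spiking = trough_to_peak <= 10  # putative fast-spiking neurons
fs_spks = spks[is_fast_spiking]         # keep only those neurons (first axis)

print(is_fast_spiking.sum(), fs_spks.shape)
```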


The original dataset is here: https://figshare.com/articles/dataset/Dataset_from_Steinmetz_et_al_2019/9598406

In [3]:
regions = ["vis ctx", "thal", "hipp", "other ctx", "midbrain", "basal ganglia", "cortical subplate", "other"]
region_colors = ['blue', 'red', 'green', 'darkblue', 'violet', 'lightblue', 'orange', 'gray']
brain_groups = [["VISa", "VISam", "VISl", "VISp", "VISpm", "VISrl"],  # visual cortex
                ["CL", "LD", "LGd", "LH", "LP", "MD", "MG", "PO", "POL", "PT", "RT", "SPF", "TH", "VAL", "VPL", "VPM"], # thalamus
                ["CA", "CA1", "CA2", "CA3", "DG", "SUB", "POST"],  # hippocampal
                ["ACA", "AUD", "COA", "DP", "ILA", "MOp", "MOs", "OLF", "ORB", "ORBm", "PIR", "PL", "SSp", "SSs", "RSP","TT"],  # non-visual cortex
                ["APN", "IC", "MB", "MRN", "NB", "PAG", "RN", "SCs", "SCm", "SCig", "SCsg", "ZI"],  # midbrain
                ["ACB", "CP", "GPe", "LS", "LSc", "LSr", "MS", "OT", "SNr", "SI"],  # basal ganglia
                ["BLA", "BMA", "EP", "EPd", "MEA"]  # cortical subplate
                ]

In [4]:
DATASET_IDX = 11
if DATASET_IDX != 11:
    raise ValueError("Code is only meant for DATASET_IDX=11")
curr_ds = all_ds[DATASET_IDX]

dt = curr_ds["bin_size"]
NUM_NEURONS_RECORDED, NUM_TRIALS, NUM_BINNED_TIMES = curr_ds["spks"].shape
# Session 11 only has neurons in the first 4 groups of brain_groups (plus "root")
NUM_REGIONS = 4

# Sanity check: one brain-area label per recorded neuron
assert len(curr_ds["brain_area"]) == NUM_NEURONS_RECORDED

# Region index per neuron: j for the j-th group in brain_groups, NUM_REGIONS for "other"
brain_subregions = NUM_REGIONS * np.ones(NUM_NEURONS_RECORDED)
for j in range(NUM_REGIONS):
    brain_subregions[np.isin(curr_ds['brain_area'], brain_groups[j])] = j


# Creating the dataset

1) Create the labels.

2) Create a dataset dictionary where the keys are brain areas (subregions) and the values are the spike data of all neurons recorded in that area.

3) Let users configure how they want the data: whether to consider region interactions, whether to use the start, middle, or end of the spike train, etc.
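Steps 1 and 2 amount to turning each region's neurons × trials × time array into one row per (neuron, trial) pair, prefixed with a one-hot region code and paired with a label. A minimal vectorized sketch of that reshaping on synthetic data; the helper `to_samples`, the sizes, and the random values are illustrative assumptions, and the spike-count filtering done in the notebook's own `populate_data` is omitted:

```python
import numpy as np

def to_samples(region_spks: np.ndarray, region_idx: int, num_regions: int, label: int):
    """Flatten (neurons, trials, bins) into (neurons * trials, num_regions + bins) rows."""
    n_neurons, n_trials, n_bins = region_spks.shape
    # Row i corresponds to neuron i // n_trials, trial i % n_trials
    rows = region_spks.reshape(n_neurons * n_trials, n_bins)
    one_hot = np.zeros((rows.shape[0], num_regions), dtype=rows.dtype)
    one_hot[:, region_idx] = 1
    X = np.concatenate([one_hot, rows], axis=1)
    y = np.full(rows.shape[0], label)
    return X, y

rng = np.random.default_rng(1)
left = rng.poisson(0.1, size=(3, 4, 250))   # 3 neurons, 4 left-response trials
right = rng.poisson(0.1, size=(3, 2, 250))  # 3 neurons, 2 right-response trials

X_l, y_l = to_samples(left, region_idx=0, num_regions=4, label=1)
X_r, y_r = to_samples(right, region_idx=0, num_regions=4, label=-1)
X = np.vstack([X_l, X_r])
y = np.concatenate([y_l, y_r])
print(X.shape, y.shape)
```

A single `reshape` produces the same neuron-major row order as a nested loop over neurons and then trials.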

In [5]:
LABELS = curr_ds["response"]  # RIGHT - NO_GO - LEFT (-1, 0, 1)
y = LABELS

In [6]:
# @title Logging for Sanity Checking
def log_shapes(ds):
    _ds = ds['spks']
    print(f"All spikes shape: {_ds.shape}")
    _ds_brain_region = _ds[brain_subregions == 0]
    print(f"\t- Spike shape for sample brain region (0-th): {_ds_brain_region.shape}")

    # NOTE: `y >= 0` keeps left (1) *and* no-go (0) trials; use `y > 0` for left responses only
    _ds_0th_left_response = _ds_brain_region[:, y >= 0]
    print(f"\t- Spike shape for sample brain region (0-th) left responses: {_ds_0th_left_response.shape}")

    averaged_over_left_response = _ds_0th_left_response.mean(axis=(0, 1))
    print(f"\t- Averaged brain region (0-th) left responses: {averaged_over_left_response.shape}")

log_shapes(curr_ds)


All spikes shape: (698, 340, 250)
	- Spike shape for sample brain region (0-th): (145, 340, 250)
	- Spike shape for sample brain region (0-th) left responses: (145, 199, 250)
	- Averaged brain region (0-th) left responses: (250,)


In [7]:
# @title Creating the Fine-Grained Data Dictionary (RUN ME!)
def dataset_by_subregion(arr_of_subregions: np.ndarray, ds: Dict[str, Any]) -> Dict[str, List[np.ndarray]]:
    """Map each subregion name to [left-response spikes, right-response spikes].

    Each array has shape (neurons in subregion, trials, time bins). Trials are
    split using the global label array `y` (-1 = right, 0 = no-go, 1 = left);
    no-go trials are dropped.
    """
    spike_partitioned = {}  # brain region to spike mapping
    unique_subregions = set(arr_of_subregions)
    for subregion in unique_subregions:
        subregion_idxs = arr_of_subregions == subregion
        subregion_data = ds["spks"][subregion_idxs]

        # From the "Dataset Description" section above:
        #       > which side the response was (-1, 0, 1)
        spikes_for_right_response = subregion_data[:, y < 0]
        spikes_for_left_response = subregion_data[:, y > 0]
        # spikes_for_no_response = subregion_data[:, y == 0]

        spike_partitioned[subregion] = [
            spikes_for_left_response,
            # spikes_for_no_response,
            spikes_for_right_response
        ]
    return spike_partitioned

subregion_data_dict = dataset_by_subregion(curr_ds["brain_area"], curr_ds)

print("Number of Neurons recorded in each subregion ")
running_sum = 0
for k, v in subregion_data_dict.items():
    print(f"\t{k}\t {v[0].shape[0]}")
    running_sum += v[0].shape[0]

assert running_sum == curr_ds["spks"].shape[0], "Per-subregion neuron counts do not sum to the total number of recorded neurons"
print(running_sum)

Number of Neurons recorded in each subregion 
	VISp	 66
	VISam	 79
	DG	 65
	LGd	 11
	CA1	 50
	SUB	 105
	LH	 18
	PL	 56
	ACA	 16
	root	 100
	MD	 126
	MOs	 6
698


## Creating the Coarse-Grained Data Dictionary

- We do this manually since this session has only a few subregions

### All regions
```python
["VISa", "VISam", "VISl", "VISp", "VISpm", "VISrl"],  # visual cortex
["CL", "LD", "LGd", "LH", "LP", "MD", "MG", "PO", "POL", "PT", "RT", "SPF", "TH", "VAL", "VPL", "VPM"], # thalamus
["CA", "CA1", "CA2", "CA3", "DG", "SUB", "POST"],  # hippocampal
["ACA", "AUD", "COA", "DP", "ILA", "MOp", "MOs", "OLF", "ORB", "ORBm", "PIR", "PL", "SSp", "SSs", "RSP","TT"],  # non-visual cortex
["APN", "IC", "MB", "MRN", "NB", "PAG", "RN", "SCs", "SCm", "SCig", "SCsg", "ZI"],  # midbrain
["ACB", "CP", "GPe", "LS", "LSc", "LSr", "MS", "OT", "SNr", "SI"],  # basal ganglia
["BLA", "BMA", "EP", "EPd", "MEA"]  # cortical subplate
```


### Refined Regions

- Only the subregions present in this session

```text
MD -> thalamus
ACA -> non-visual-cortex
SUB -> hippocampal
CA1 -> hippocampal
DG -> hippocampal
LGd -> thalamus
LH -> thalamus
PL -> non-visual-cortex
root -> (not mapped; excluded from the coarse dataset)
VISp -> visual-cortex
MOs -> non-visual-cortex
VISam -> visual-cortex
```
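The refined mapping can also be derived programmatically by looking up each recorded subregion in the region groups instead of writing it by hand. A minimal sketch, assuming the first four groups of `brain_groups` (copied here so the block runs on its own) and the subregion names printed earlier; the dictionary keys and the helper `coarse_name` are illustrative, and unmatched names such as `root` come back as `None`:

```python
brain_groups_subset = {
    "visual-cortex": ["VISa", "VISam", "VISl", "VISp", "VISpm", "VISrl"],
    "thalamus": ["CL", "LD", "LGd", "LH", "LP", "MD", "MG", "PO", "POL", "PT",
                 "RT", "SPF", "TH", "VAL", "VPL", "VPM"],
    "hippocampal": ["CA", "CA1", "CA2", "CA3", "DG", "SUB", "POST"],
    "non-visual-cortex": ["ACA", "AUD", "COA", "DP", "ILA", "MOp", "MOs", "OLF",
                          "ORB", "ORBm", "PIR", "PL", "SSp", "SSs", "RSP", "TT"],
}

def coarse_name(subregion):
    """Return the coarse group containing `subregion`, or None if it is unmapped."""
    for coarse, members in brain_groups_subset.items():
        if subregion in members:
            return coarse
    return None

recorded = ["VISp", "VISam", "DG", "LGd", "CA1", "SUB", "LH", "PL", "ACA", "root", "MD", "MOs"]
mapping = {}
for sub in recorded:
    coarse = coarse_name(sub)
    if coarse is not None:  # drop unmapped subregions such as "root"
        mapping.setdefault(coarse, []).append(sub)
print(mapping)
```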

In [None]:
# @title Creating the Coarse-Grained Data Dictionary
def consolidate_fine_grained(subregion_dict):

    mapping = {
        "thalamus": ["MD", "LGd", "LH"],
        "non-visual-cortex": ["ACA", "PL", "MOs"],
        "hippocampal": ["SUB", "CA1", "DG"],
        "visual-cortex": ["VISp", "VISam"]
    }


    coarse_region_data_dict: Dict[str, List[List[np.ndarray]]] = dict()

    for coarse_region_name, subregion_name_arr in mapping.items():
        print("*" * 10)
        print(coarse_region_name)
        for subregion_name in subregion_name_arr:
            print(f"Subregion: {subregion_name}")
            if coarse_region_name not in coarse_region_data_dict:
                print(f"\tInit: Left and Right: {subregion_dict[subregion_name][0].shape}, {subregion_dict[subregion_name][1].shape}")
                coarse_region_data_dict[coarse_region_name] = copy.deepcopy(subregion_dict[subregion_name])
            else:
                subregion_left, subregion_right = subregion_dict[subregion_name]

                print(f"\tIncoming Shapes: Left and Right: {subregion_left.shape}, {subregion_right.shape}")
                # print(f"Container: {coarse_region_data_dict[coarse_region_name][1]}")
                coarse_region_data_dict[coarse_region_name][0] = np.vstack(
                    (coarse_region_data_dict[coarse_region_name][0],
                    subregion_left)
                )
                coarse_region_data_dict[coarse_region_name][1] = np.vstack(
                    (coarse_region_data_dict[coarse_region_name][1],
                    subregion_right)
                )
            print(f"\tPost-stack shapes: {coarse_region_data_dict[coarse_region_name][0].shape} {coarse_region_data_dict[coarse_region_name][1].shape}")
    return coarse_region_data_dict

coarse_region_data_dict = consolidate_fine_grained(subregion_data_dict)

In [43]:
# @title Constructing the Dataset (pre-processing)

def populate_data(designed_matrix, is_left, X_container, y_container, min_num_spikes, num_regions):
    """
    Append every sufficiently active spike train in `designed_matrix` to the containers.

    designed_matrix is of shape (a, b, c)
        a:= num_neurons in coarse_region
        b:= num_trials (in this case either the left or right response trials)
        c:= num_regions one-hot entries followed by the spike train (e.g. 4 + 250 = 254)
    """
    OFFSET_FOR_ONE_HOT = 1
    for trials_matrix in designed_matrix:   # one neuron: (trials, c)
        for spike_train in trials_matrix:   # one trial of that neuron: (c,)
            # Number of non-empty time bins (not total spikes), skipping the one-hot prefix
            non_zero_length = np.count_nonzero(spike_train[num_regions:] > 0)
            # NOTE: the one-hot prefix is already sliced off above, so adding OFFSET_FOR_ONE_HOT
            # makes the effective rule "more than min_num_spikes + 1 non-empty bins"
            if non_zero_length > min_num_spikes + OFFSET_FOR_ONE_HOT:
                X_container.append(spike_train)
                y_container.append(1 if is_left else -1)
    return X_container, y_container

def encode_coarse_data(
    coarse_data_dict: Dict[str, List[np.ndarray]],
    min_num_spikes: int
):
    num_regions = len(coarse_data_dict)
    X_container = []
    y_container = []
    for one_hot_idx, (coarse_region_name, coarse_region_data) in enumerate(coarse_data_dict.items()):
        print("*" * 20)
        print(coarse_region_name)
        # Each coarse region holds [left, right], each of shape (neurons, trials, time bins)
        left = coarse_region_data[0]   # left-response trials (positive label)
        right = coarse_region_data[1]  # right-response trials (negative label)
        _l_shape = left.shape
        _r_shape = right.shape

        assert _l_shape[-1] == 250
        assert _r_shape[-1] == 250

        ##########################################
        # Prepend a one-hot encoding of the coarse region to every spike train,
        # e.g. [1, 0, 0, 0] for the first region
        #   For more information: https://en.wikipedia.org/wiki/One-hot
        vec_one_hot = [0] * num_regions
        vec_one_hot[one_hot_idx] = 1
        left_pad = np.tile(vec_one_hot, (_l_shape[0], _l_shape[1], 1))
        right_pad = np.tile(vec_one_hot, (_r_shape[0], _r_shape[1], 1))
        left_designed = np.concatenate((left_pad, left), axis=-1)
        right_designed = np.concatenate((right_pad, right), axis=-1)

        pre_population_length = len(X_container)
        X_container, y_container = populate_data(left_designed, True, X_container, y_container, min_num_spikes, num_regions)
        X_container, y_container = populate_data(right_designed, False, X_container, y_container, min_num_spikes, num_regions)
        post_population_length = len(X_container)
        print(f"\tPre-population Length: {pre_population_length}")
        print(f"\tTotal Samples in {coarse_region_name}: {_l_shape[0] * _l_shape[1] + _r_shape[0] * _r_shape[1]}")
        print(f"\tSamples from {coarse_region_name} inserted post-min_num_spikes filtering: {post_population_length - pre_population_length}")

    return np.asarray(X_container), np.asarray(y_container)

# Dataset Configuration

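Here we define `create_dataset`, which applies the `DATASET_PARAMETERS` configured further below: it selects the spike-train window, optionally shuffles and relabels the data, and splits it into train and test sets.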
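Choosing a START, MID, or END window amounts to picking a contiguous block of spike-train columns while always keeping the one-hot columns at the front of each row. A minimal sketch of that index arithmetic, assuming `num_regions` one-hot columns followed by `num_bins` spike-train columns; `window_indices` is a hypothetical helper that takes plain strings (so it does not shadow the notebook's `WindowChoice`), and the centered "mid" window is one possible reading of "middle":

```python
import numpy as np

def window_indices(choice: str, window_size: int, num_regions: int = 4, num_bins: int = 250) -> np.ndarray:
    """Column indices to keep: the one-hot prefix, then a contiguous spike-train window."""
    assert 0 < window_size <= num_bins, "window must fit inside the spike train"
    if choice == "start":
        start = 0
    elif choice == "mid":
        start = (num_bins - window_size) // 2  # centered on the middle of the train
    elif choice == "end":
        start = num_bins - window_size
    else:
        raise ValueError(f"unknown window choice: {choice}")
    # Spike-train columns come after the num_regions one-hot columns
    spike_cols = num_regions + np.arange(start, start + window_size)
    return np.concatenate([np.arange(num_regions), spike_cols])

idx_start = window_indices("start", 100)
idx_mid = window_indices("mid", 100)
idx_end = window_indices("end", 100)

X = np.arange(3 * 254).reshape(3, 254)  # 3 synthetic rows: 4 one-hot + 250 bins
X_windowed = X[:, idx_end]              # select the columns for all rows at once
print(X_windowed.shape)
```

Indexing with `X[:, idx]` avoids a Python loop over rows, and the assertion rejects window sizes that would overrun the spike train or spill into the one-hot columns.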

In [44]:
def create_dataset(dataset_parameters, X_data, y_data):
    """
    Window, optionally shuffle/relabel, and split the encoded dataset.

    Expected keys in `dataset_parameters`:
        "window_choice":     WindowChoice.START / MID / END
        "window_size":       number of spike-train bins to keep (<= number of bins)
        "train-test-split":  fraction of samples used for training, in (0, 1)
        "shuffle":           whether to shuffle before splitting
        "false_as_0":        relabel right responses from -1 to 0
        "coarse":            must be True (fine-grained data is not supported yet)
    """

    ##############################################
    # First, extract the spike-train window of interest (start, mid, or end),
    # always keeping the one-hot columns at the front of each row
    window_choice = dataset_parameters["window_choice"]
    window_size = dataset_parameters["window_size"]

    if dataset_parameters["coarse"]:
        region_offset_for_onehot = len(coarse_region_data_dict)
    else:
        raise NotImplementedError("Not yet supported")

    num_bins = X_data.shape[-1] - region_offset_for_onehot
    assert 0 < window_size <= num_bins, "window_size must fit inside the spike train"

    if window_choice == WindowChoice.START:
        start = region_offset_for_onehot
    elif window_choice == WindowChoice.MID:
        # Window centered on the middle of the spike train
        start = region_offset_for_onehot + (num_bins - window_size) // 2
    else:
        # Last `window_size` bins, including the final one
        start = X_data.shape[-1] - window_size
    end = start + window_size

    idxs_to_use = list(range(region_offset_for_onehot)) + list(range(start, end))
    X_data = X_data[:, idxs_to_use]

    ##############################################
    # Optionally shuffle the dataset
    if dataset_parameters["shuffle"]:
        X_data, y_data = shuffle(X_data, y_data)

    # Optionally relabel right responses from -1 to 0
    if dataset_parameters["false_as_0"]:
        y_data = np.maximum(y_data, 0)

    ##############################################
    # Finally, split into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X_data, y_data,
        train_size=dataset_parameters["train-test-split"],
        random_state=42,
    )

    return X_train, X_test, y_train, y_test

# Do the modeling!

Feel free to try different models and experiment to your heart's content :)

**NOTE**

You may want to see what happens as you increase the minimum-spike threshold (`MIN_NUM_SPIKES`), which controls how many samples are dropped. With a window of 100 bins, I observed the following per-sample spike counts:

```
mean 	4.393311
std 	8.028090
min 	0.000000
25% 	0.000000
50% 	1.000000
75% 	5.000000
max 	84.000000
```

which is striking: 75% of the samples contain at most 5 spikes in that 100-bin window, and at least a quarter contain none.

At a window size of 250, we have the following statistics:

```
count 	132038.000000
mean 	10.849104
std 	19.008744
min 	0.000000
25% 	0.000000
50% 	2.000000
75% 	13.000000
max 	183.000000
```
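These statistics are `describe()` summaries of the total spike count per sample inside the window. A minimal sketch of how such a summary, and the fraction of samples that a spike-count threshold would keep, can be computed; the Poisson data is synthetic and will not reproduce the numbers above, `count_stats` is a hypothetical helper, and the notebook's own filter in `populate_data` counts non-empty bins rather than total spikes:

```python
import numpy as np
import pandas as pd

def count_stats(X: np.ndarray, num_regions: int = 4) -> pd.Series:
    """Summarize per-sample spike counts, ignoring the one-hot prefix."""
    spike_counts = X[:, num_regions:].sum(axis=1)
    return pd.Series(spike_counts).describe()

rng = np.random.default_rng(2)
one_hot = np.zeros((1000, 4))
one_hot[:, 0] = 1
X = np.hstack([one_hot, rng.poisson(0.02, size=(1000, 100))])  # 100-bin window

stats = count_stats(X)
print(stats)

# Fraction of samples that would survive a threshold on the spike count
counts = X[:, 4:].sum(axis=1)
for threshold in (0, 5, 10):
    print(threshold, np.mean(counts > threshold))
```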

---

Remember a few things:

1) K-fold cross-validation might be useful

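2) Accuracy is a reasonable test metric for this classification task, as long as the left and right classes are roughly balanced; otherwise, compare it against the accuracy of always predicting the majority class.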

3) Each row of the training data has length 4 + `WINDOW_SIZE` (254 for the full 250-bin spike train). The first 4 entries are a one-hot encoding of the coarse brain region the neuron belongs to (there are 4 regions in our setup); the remaining entries are the spike-train data within the chosen window.

4) Remember that regularization might be helpful if you are overfitting.

5) Look at a confusion matrix to understand where your model is making mistakes.
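Cross-validation, accuracy, regularization, and a confusion matrix fit together in a few lines of scikit-learn. A minimal sketch on synthetic data from `make_classification` (standing in for `X_train`/`y_train`), using an L2-regularized `LogisticRegression` as an example model and a `DummyClassifier` as a majority-class baseline; the model choice, `C=1.0`, and the 5 folds are arbitrary assumptions:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_val_score

X, y = make_classification(n_samples=500, n_features=20, random_state=0)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
model = LogisticRegression(C=1.0, max_iter=1000)  # smaller C = stronger L2 penalty

scores = cross_val_score(model, X, y, cv=cv, scoring="accuracy")
baseline = cross_val_score(DummyClassifier(strategy="most_frequent"), X, y, cv=cv, scoring="accuracy")
print(f"model accuracy:    {scores.mean():.3f} +/- {scores.std():.3f}")
print(f"majority baseline: {baseline.mean():.3f}")

# Out-of-fold predictions give one prediction per sample for the confusion matrix
y_pred = cross_val_predict(model, X, y, cv=cv)
print(confusion_matrix(y, y_pred))  # rows: true class, columns: predicted class
```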

Have fun!

In [45]:
# @title Misc Setup (Run me!)

import enum

class WindowChoice(enum.Enum):
    START = 0
    MID = 1
    END = 2

In [48]:
###################################################
# TODO: Set parameters as you see fit!
###################################################

WINDOW_SIZE = 249
TRAIN_TEST_SPLIT = 0.8  # 80% is training, 20% test
SHUFFLE_DATASET = True
FALSE_AS_0 = True  # Convert the false (right-response) label to 0 instead of -1
USE_COARSE_DATASET = True
WINDOW_CHOICE = WindowChoice.END
MIN_NUM_SPIKES = 50  # Keep a spike train only if it has more than MIN_NUM_SPIKES + 1 non-empty time bins (see populate_data)

assert 0 < TRAIN_TEST_SPLIT < 1.0  # train_test_split rejects a float train_size of 1.0
assert 0 < WINDOW_SIZE <= 250


DATASET_PARAMETERS = dict()

DATASET_PARAMETERS["window_choice"] = WINDOW_CHOICE
DATASET_PARAMETERS["window_size"] = WINDOW_SIZE
DATASET_PARAMETERS["train-test-split"] = TRAIN_TEST_SPLIT
DATASET_PARAMETERS["shuffle"] = SHUFFLE_DATASET
DATASET_PARAMETERS["false_as_0"] = FALSE_AS_0
DATASET_PARAMETERS["coarse"] = USE_COARSE_DATASET
DATASET_PARAMETERS["min_num_spikes"] = MIN_NUM_SPIKES


In [49]:
if DATASET_PARAMETERS["coarse"]:
    Xs, ys = encode_coarse_data(
        coarse_region_data_dict,
        DATASET_PARAMETERS["min_num_spikes"]
    )
else:
    raise NotImplementedError("Fine-Grained dataset not supported yet")
X_train, X_test, y_train, y_test = create_dataset(DATASET_PARAMETERS, Xs, ys)

summed_ax_1 = np.sum(X_train, axis=1) - 1  # Subtract 1 to remove the one-hot entry, leaving the total spike count in the window

pd.DataFrame(summed_ax_1).describe()

********************
thalamus
	Pre-population Length: 0
	Total Samples in thalamus: 42780
	Samples from thalamus inserted post-min_num_spikes filtering: 4176
********************
non-visual-cortex
	Pre-population Length: 4176
	Total Samples in non-visual-cortex: 21528
	Samples from non-visual-cortex inserted post-min_num_spikes filtering: 13
********************
hippocampal
	Pre-population Length: 4189
	Total Samples in hippocampal: 60720
	Samples from hippocampal inserted post-min_num_spikes filtering: 1878
********************
visual-cortex
	Pre-population Length: 6067
	Total Samples in visual-cortex: 40020
	Samples from visual-cortex inserted post-min_num_spikes filtering: 494


| statistic | value |
|---|---|
| count | 5248.0 |
| mean | 79.035823 |
| std | 26.378363 |
| min | 51.0 |
| 25% | 60.0 |
| 50% | 69.0 |
| 75% | 90.0 |
| max | 183.0 |
