## Setup

In [1]:
%matplotlib inline

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache
from allensdk.brain_observatory.ecephys.stimulus_analysis.receptive_field_mapping import ReceptiveFieldMapping

import torch
import cv2 as cv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Setup cache
# Build the AllenSDK project cache from a manifest file located inside `data_dir`.
# `cache` is used below to load the session and the natural-scene templates.
data_dir = "./allendata"
manifest_path = os.path.join(data_dir, "manifest.json")
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)

In [3]:
# Get session
# Load one hard-coded recording session through the cache.
# NOTE(review): `timeout` looks like an option of the cache constructor
# (`EcephysProjectCache.from_warehouse`) rather than of `get_session_data` —
# confirm it is actually honored here.
session_id = 798911424 # TODO: document why this particular session was chosen
session = cache.get_session_data(session_id, timeout=3000)

  return func(args[0], **pargs)
  return func(args[0], **pargs)


## Get data

In [4]:
# Get stimulus
# Keep only the rows of the stimulus table whose stimulus is "natural_scenes".
presentations = session.stimulus_presentations.loc[
    lambda table: table["stimulus_name"] == "natural_scenes"
]

In [5]:
# Get spikes
# One row per spike, labelled with its `stimulus_presentation_id` and `unit_id`.
df_spike_times = session.presentationwise_spike_times()

# Spike-count matrix: rows = stimulus presentations, columns = units.
# `size()` counts the spikes of every (presentation, unit) pair directly, and
# `unstack` fills pairs that never occurred with 0.
df_spike_counts = (
    df_spike_times
    .groupby(["stimulus_presentation_id", "unit_id"])
    .size()
    .unstack("unit_id", fill_value=0)
)

# Restrict to the selected (natural-scene) presentations explicitly instead of
# relying on NaN alignment + dropna. A presentation during which no unit fired
# is kept as a row of zeros rather than being silently discarded.
df_spike_counts = (
    df_spike_counts
    .reindex(presentations.index, fill_value=0)
    .rename_axis(index="stimulus_presentation_id")
)

# Counts -> rates: divide every row by the duration of its presentation
# (spikes per second, assuming `duration` is expressed in seconds — TODO confirm).
df_firing_rates = df_spike_counts.div(presentations["duration"], axis=0)

# Rows can only be NaN here if the presentation's duration is missing.
df_firing_rates = df_firing_rates.dropna()

In [6]:
# Preview the firing-rate table: one row per stimulus presentation, one column per unit.
df_firing_rates.head()

unit_id,951088664,951088679,951088721,951088734,951088823,951088862,951088891,951088939,951088948,951088957,...,951117252,951117258,951117264,951117297,951117365,951117389,951117418,951117426,951117435,951117571
stimulus_presentation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
51355,15.985981,15.985981,11.989486,0.0,15.985981,27.975466,0.0,3.996495,39.964952,0.0,...,7.99299,7.99299,0.0,0.0,11.989486,0.0,0.0,0.0,0.0,0.0
51356,11.989486,3.996495,15.985981,7.99299,7.99299,19.982476,0.0,0.0,47.957942,3.996495,...,7.99299,11.989486,0.0,0.0,15.985981,0.0,0.0,0.0,0.0,0.0
51357,39.964952,11.989486,27.975466,0.0,3.996495,31.971961,0.0,0.0,63.943923,7.99299,...,15.985981,19.982476,0.0,0.0,19.982476,15.985981,0.0,0.0,0.0,0.0
51358,15.985981,7.99299,35.968457,0.0,0.0,31.971961,0.0,0.0,55.950933,0.0,...,11.989486,0.0,11.989486,3.996495,7.99299,7.99299,0.0,0.0,0.0,0.0
51359,19.984152,11.990491,47.961965,0.0,3.99683,35.971474,3.99683,7.993661,67.946117,0.0,...,15.987322,7.993661,15.987322,3.99683,0.0,0.0,7.993661,7.993661,0.0,0.0


In [7]:
# Preview the natural-scenes stimulus table (the `frame` and `duration` columns are used below).
presentations.head()

Unnamed: 0_level_0,stimulus_block,start_time,stop_time,contrast,color,stimulus_name,size,x_position,phase,orientation,spatial_frequency,temporal_frequency,frame,y_position,duration,stimulus_condition_id
stimulus_presentation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
51355,9.0,5909.794447,5910.044666,,,natural_scenes,,,,,,,13.0,,0.250219,4908
51356,9.0,5910.044666,5910.294885,,,natural_scenes,,,,,,,38.0,,0.250219,4909
51357,9.0,5910.294885,5910.545104,,,natural_scenes,,,,,,,30.0,,0.250219,4910
51358,9.0,5910.545104,5910.795324,,,natural_scenes,,,,,,,35.0,,0.250219,4911
51359,9.0,5910.795324,5911.045522,,,natural_scenes,,,,,,,112.0,,0.250198,4912


## Select neurons by receptive field

In [8]:
# Receptive-field analysis object for this session; queried per unit below.
rf_mapping = ReceptiveFieldMapping(session)

In [9]:
def is_receptive_field_centered(rf, center=(5, 5)):
    """Tell whether the peak of a receptive-field map lies on a given grid cell.

    Parameters
    ----------
    rf : 2-D array-like
        Receptive-field map. The position of its maximum is used as the
        receptive-field location (first occurrence if the maximum is tied).
    center : tuple of int, optional
        ``(row, column)`` cell, 0-based, that the peak must fall on.
        Defaults to ``(5, 5)``.
        NOTE(review): if the map is a 9x9 grid (as the depth cropping in this
        notebook suggests), its geometric centre is ``(4, 4)`` — confirm that
        ``(5, 5)`` is the intended cell.

    Returns
    -------
    bool
        True when the peak is exactly on ``center``.
    """
    rf = np.asarray(rf)
    # argmax works on the flattened map; unravel_index turns it back into
    # (row, column) coordinates.
    peak_row, peak_col = np.unravel_index(np.argmax(rf), rf.shape)
    expected_row, expected_col = center
    return bool(peak_row == expected_row and peak_col == expected_col)

In [10]:
# Units whose receptive field is not on the target cell are collected for removal.
drop_units = [
    candidate_id
    for candidate_id in df_firing_rates.columns
    if not is_receptive_field_centered(rf_mapping.get_receptive_field(candidate_id))
]

In [11]:
# Remove the rejected units, then turn the presentation-id index into a regular
# column so it can be used for look-ups below.
df_firing_rates = df_firing_rates.drop(columns=drop_units).reset_index()

In [12]:
# Preview the table after unit selection; `stimulus_presentation_id` is now a column.
df_firing_rates.head()

unit_id,stimulus_presentation_id,951095829,951103530,951104171,951104421,951105377,951112628,951112992,951113221
0,51355,7.99299,23.978971,0.0,11.989486,0.0,35.968457,11.989486,15.985981
1,51356,3.996495,15.985981,0.0,3.996495,0.0,0.0,11.989486,7.99299
2,51357,0.0,31.971961,0.0,3.996495,0.0,0.0,19.982476,23.978971
3,51358,0.0,11.989486,0.0,0.0,3.996495,3.996495,7.99299,11.989486
4,51359,0.0,11.990491,0.0,27.977813,0.0,0.0,11.990491,7.993661


In [13]:
# Keep only rows whose presentation has a non-negative `frame`
# (presumably negative frames mark presentations without an image — confirm).
has_valid_frame = presentations["frame"] >= 0
valid_presentation_ids = presentations.loc[has_valid_frame].index
keep_row = df_firing_rates["stimulus_presentation_id"].isin(valid_presentation_ids)
df_firing_rates = df_firing_rates[keep_row]

In [14]:
# Work on the first 200 presentations only
# (presumably to keep the per-image depth estimation below affordable — confirm).
df_firing_rates = df_firing_rates.head(200)

## Select images

In [None]:
def get_image_by_id(id):
    """Return the natural-scene template shown during one stimulus presentation.

    `id` is a label of the index of `presentations` (a stimulus presentation
    id). The `frame` value of that row is cast to int and used to fetch the
    template from the cache.
    """
    frame_number = int(presentations.loc[id, "frame"])
    return cache.get_natural_scene_template(frame_number)

## Get depth maps

In [15]:
# Pretrained MiDaS depth model ("DPT_Large") and its preprocessing transforms,
# both obtained through torch.hub.
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
# The model is only used for inference: switch it to evaluation mode so that
# training-only layers (dropout, batch-norm statistics updates) are disabled.
midas.eval()
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
# Preprocessing pipeline that matches the DPT model variants.
transform = midas_transforms.dpt_transform

def get_depth(id, grid_size=9, tile_row=5, tile_col=5):
    """Estimate the mean MiDaS output over one tile of a presentation's image.

    The image shown during presentation `id` is run through the MiDaS model;
    the resulting map is split into a ``grid_size`` x ``grid_size`` grid and the
    mean over the tile at (``tile_row``, ``tile_col``) is returned.

    Parameters
    ----------
    id : stimulus presentation id, forwarded to `get_image_by_id`.
    grid_size : int, optional
        Number of tiles along each image axis (default 9).
    tile_row, tile_col : int, optional
        0-based position of the tile to average (default 5, 5).
        NOTE(review): on a 9x9 grid the geometric centre is tile (4, 4) —
        confirm that (5, 5) is the intended tile.

    Returns
    -------
    numpy scalar
        Mean of the model output over the selected tile.
        NOTE(review): MiDaS is documented as predicting relative *inverse*
        depth — confirm that this is the quantity wanted as "depth".

    Raises
    ------
    ValueError
        If the requested tile lies outside the grid.
    """
    if not (0 <= tile_row < grid_size and 0 <= tile_col < grid_size):
        raise ValueError(
            f"tile ({tile_row}, {tile_col}) is outside a {grid_size}x{grid_size} grid"
        )

    img = get_image_by_id(id)
    # The template is converted from single-channel grayscale to 3-channel RGB
    # before being handed to the MiDaS preprocessing.
    img = cv.cvtColor(img, cv.COLOR_GRAY2RGB)

    batch = transform(img).cpu()

    # Inference only: no gradients needed.
    with torch.no_grad():
        prediction = midas(batch).squeeze(0)

    depth_map = prediction.cpu().numpy()

    # Axis 0 of the map is the vertical (row) axis, axis 1 the horizontal one.
    tile_height = depth_map.shape[0] // grid_size
    tile_width = depth_map.shape[1] // grid_size

    # TODO : check if cropping is right
    tile = depth_map[
        tile_row * tile_height:(tile_row + 1) * tile_height,
        tile_col * tile_width:(tile_col + 1) * tile_width,
    ]

    return tile.mean()

Using cache found in /Users/riccardoalberghi/.cache/torch/hub/intel-isl_MiDaS_master
Using cache found in /Users/riccardoalberghi/.cache/torch/hub/intel-isl_MiDaS_master


In [16]:
# Compute the depth target of every remaining presentation.
# `assign` builds a new frame instead of writing a column into the frame
# produced by the row slicing above, which avoids pandas' SettingWithCopyWarning.
df_firing_rates = df_firing_rates.assign(
    depth=df_firing_rates["stimulus_presentation_id"].apply(get_depth)
)

## Train/test split

In [17]:
from sklearn.model_selection import train_test_split

# Features are the per-unit firing rates: every column except the presentation
# id and the target. Selecting them by label (instead of by position) keeps the
# split correct even if columns are later added or reordered.
feature_columns = df_firing_rates.columns.difference(
    ["stimulus_presentation_id", "depth"], sort=False
)

# 80/20 split with a fixed seed so the evaluation is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    df_firing_rates[feature_columns],
    df_firing_rates["depth"],
    test_size=0.2,
    random_state=42,
)

## Linear Regression

In [18]:
from sklearn.linear_model import LinearRegression

# Fit a linear model predicting the depth target from the firing rates
# of the training presentations.
linear_model = LinearRegression()
reg = linear_model.fit(X_train, y_train)

## Evaluate

In [19]:
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Mean absolute error of the fitted model on the held-out presentations.
y_pred = reg.predict(X_test)
mean_absolute_error(y_test, y_pred)

6.493222316293735