In [1]:
# Following instructions here: https://alleninstitute.github.io/openscope_databook/basics/download_nwb.html
#run using openscope environment

In [2]:
import warnings
warnings.filterwarnings('ignore')

try:
    from databook_utils.dandi_utils import dandi_stream_open
except:
    !git clone https://github.com/AllenInstitute/openscope_databook.git
    %cd openscope_databook
    %pip install -e .

In [3]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np


import os, glob
from pathlib import Path

from dandi import dandiapi
from pynwb import NWBHDF5IO
from dandi.dandiapi import DandiAPIClient
from nwbwidgets.view import default_neurodata_vis_spec
from dotenv import load_dotenv

import pynwb
from nwbwidgets import nwb2widget


from typing import Union, Iterator, Callable, Tuple, Dict
import os
from pathlib import Path
from tqdm.notebook import tqdm

%matplotlib inline

In [4]:
# define functions to download files with a progress bar

MAX_CHUNK_SIZE = int(os.environ.get("DANDI_MAX_CHUNK_SIZE", 1024 * 1024 * 8))  

def get_download_file_iter_with_steps(
    file, chunk_size: int = MAX_CHUNK_SIZE
) -> Tuple[Callable[[int], Iterator[bytes]], Dict[str, int]]:

    url = file.base_download_url
    steps_dict = {"total_steps": None}
    result = file.client.session.get(url, stream=True)

    total_size = int(result.headers.get('content-length', 0))
    steps_dict["total_steps"] = total_size // chunk_size
    print(f"Downloading {total_size} bytes in {steps_dict['total_steps']} steps")

    def downloader(start_at: int = 0) -> Iterator[bytes]:
        headers = None
        if start_at > 0:
            headers = {"Range": f"bytes={start_at}-"}
        result = file.client.session.get(url, stream=True, headers=headers)
        result.raise_for_status()
        for chunk in result.iter_content(chunk_size=chunk_size):
            if chunk:  
                yield chunk

    return downloader, steps_dict

def download_with_progressbar(
    file, filepath: Union[str, Path], chunk_size: int = MAX_CHUNK_SIZE
) -> None:
    downloader, steps_dict = get_download_file_iter_with_steps(file)
    with open(filepath, "wb") as fp:
        for chunk in tqdm(downloader(0), total=steps_dict["total_steps"], unit="chunk", unit_scale=True, unit_divisor=1024):
            fp.write(chunk)

# Testing: Average Receptive Field Across Probes

In [6]:
# load config file
load_dotenv(dotenv_path="config.env")

download_loc = os.getenv("SAMPLE_DATA_DIR")
print(download_loc)

None


In [7]:
dandiset_id = "000021"
dandi_filepath = "sub-699733573/sub-699733573_ses-715093703.nwb"

dandi_api_key = None
authenticate = False


if authenticate:
    client = dandiapi.DandiAPIClient(token=dandi_api_key)
else:
    client = dandiapi.DandiAPIClient()
my_dandiset = client.get_dandiset(dandiset_id,"draft")

print(f"Got dandiset {my_dandiset}")

filename = dandi_filepath.split("/")[-1]
filepath = f"{download_loc}\{filename}" if os.name == 'nt' else f"{download_loc}/{filename}"
print(filepath)

file = my_dandiset.get_asset_by_path(dandi_filepath)
# this may take awhile, especially if the file to download is large
download_with_progressbar(file, filepath)

print(f"Downloaded file to {filepath}")

Got dandiset DANDI:000021/draft
None/sub-699733573_ses-715093703.nwb
Downloading 2856232912 bytes in 340 steps


FileNotFoundError: [Errno 2] No such file or directory: 'None/sub-699733573_ses-715093703.nwb'

## Open and Display the NWB file

In [8]:
# open the downloaded NWB file
ROOT_DIR = download_loc

# Look for the first NWB file under the dandiset folder
candidates = glob.glob(str(Path(ROOT_DIR) / "**" / "*.nwb"), recursive=True)
if not candidates:
    raise FileNotFoundError("No .nwb file found under " + ROOT_DIR)
nwb_path = candidates[0]
print("Opening:", nwb_path)
io = NWBHDF5IO(nwb_path, "r", load_namespaces=True)


TypeError: expected str, bytes or os.PathLike object, not NoneType

In [None]:
nwb = io.read()
print(nwb)


In [None]:
rf_stim_table = nwb.intervals["gabors_presentations"].to_dataframe()
rf_stim_table[:10]

Opening another nwb file: 

In [9]:
# nwb_path = r"C:\Users\MaryBeth\projects\SarvestaniLab\OpenScopeMouseV1\001568\sub-810531\sub-810531_ses-ecephys-810531-2025-09-17-15-14-30_ecephys.nwb"

# io = NWBHDF5IO(nwb_path, "r", load_namespaces=True)
# nwb = io.read()

# print(nwb)
            
# units = nwb.units
# units.colnames
# nwb.intervals.keys()
# channel_probes = {nwb.electrodes["id"][i]: nwb.electrodes["group_name"][i] for i in range(len(nwb.electrodes))}

# # function retrieves peak channel ID from "units" dataset, looks up corresponding group name, returns probe associated with peak channel
# def get_unit_probe(unit_idx):
#     return str(units['device_name'][unit_idx])

# print(set(channel_probes.values()))

In [None]:
## Retrieve probe names
channel_probes = {nwb.electrodes["id"][i]: nwb.electrodes["group_name"][i] for i in range(len(nwb.electrodes))}

# function retrieves peak channel ID from "units" dataset, looks up corresponding group name, returns probe associated with peak channel
def get_unit_probe(unit_idx):
    peak_channel_id = units["peak_channel_id"][unit_idx]
    return channel_probes[peak_channel_id]

print(set(channel_probes.values()))
## Get units from selected probe & their receptive fields
### helper functions made from electrodes table to get brain location of unit or probe name of unit

# map channel ids to brain location and probe group name
# creates a dictionary that iterates over indices of the dataset, assigns "id" as the key, and makes the location the value
channel_locations = {nwb.electrodes["id"][i]: nwb.electrodes["location"][i] for i in range(len(nwb.electrodes))}
# creates similar dictionary as previous line, but with the group name/probe name as the value
channel_probes = {nwb.electrodes["id"][i]: nwb.electrodes["group_name"][i] for i in range(len(nwb.electrodes))}

# function retrieves peak channel ID for given unit index from "units" dataset and looks up corresponding probe from dictionary
def get_unit_location(unit_idx):
    peak_channel_id = units["peak_channel_id"][unit_idx]
    return channel_locations[peak_channel_id]

# function retrieves peak channel ID from "units" dataset, looks up corresponding group name, returns probe associated with peak channel
def get_unit_probe(unit_idx):
    peak_channel_id = units["peak_channel_id"][unit_idx]
    return channel_probes[peak_channel_id]

print(set(channel_locations.values()))
print(set(channel_probes.values()))

In [10]:
def select_condition(unit_idx, probe, all_units=False):
    # the values below are recommended thresholds for these quality metrics
    if all_units:
        return True
    return units["snr"][unit_idx] > 1 and \
            get_unit_probe(unit_idx) == probe

selected_unit_idxs = []
for unit_idx in range(len(units)):
    if select_condition(unit_idx, "probeB"):
        selected_unit_idxs.append(unit_idx)
        
if len(selected_unit_idxs) == 0:
    raise IndexError("There are no units for this selection")

print(len(selected_unit_idxs))

NameError: name 'units' is not defined

In [None]:
print(rf_stim_table.columns.tolist())

In [11]:
### get x and y coordinates of gabors displayed to build receptive field

xs = np.sort(list(set(rf_stim_table.x_position)))
ys = np.sort(list(set(rf_stim_table.y_position)))
field_units = rf_stim_table.units[0]
print(xs)
print(ys)
print(field_units)


### get receptive field of a unit using its spike times and the stim table

def get_rf(spike_times):
    # creates 2D array that stores response spike counts for each coordinate of the receptive field
    unit_rf = np.zeros([ys.size, xs.size])
    # for every x and y coordinate in the field
    for xi, x in enumerate(xs):
        for yi, y in enumerate(ys):
            
            # for this coordinate of the rf, count all the times that this neuron responds to a stimulus time with a spike
            stim_times = rf_stim_table[(rf_stim_table.x_position == x) & (rf_stim_table.y_position == y)].start_time
            response_spike_count = 0
            for stim_time in stim_times:
                # any spike within 0.2 seconds after stim time is considered a response
                start_idx, end_idx = np.searchsorted(spike_times, [stim_time, stim_time+0.2])
                response_spike_count += end_idx-start_idx

            unit_rf[yi, xi] = response_spike_count
    

NameError: name 'rf_stim_table' is not defined

In [None]:
### compute receptive fields for each unit in selected units

unit_rfs = []
for unit_idx in selected_unit_idxs:
    unit_spike_times = units["spike_times"][unit_idx]
    unit_rfs.append(get_rf(unit_spike_times))

## Plot in cloud

In [12]:
### Plotly with COLOR
import plotly.graph_objects as go
import numpy as np
from PIL import Image
import io
import base64
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Random positions
np.random.seed(42)
n_rfs = len(unit_rfs)
x_positions = np.random.uniform(0, 10, n_rfs)
y_positions = np.random.uniform(0, 10, n_rfs)

fig = go.Figure()

# Add scatter points (clickable)
fig.add_trace(go.Scatter(
    x=x_positions,
    y=y_positions,
    mode='markers',
    marker=dict(size=20, color='lightblue', line=dict(width=1, color='black')),
    text=[f"Unit {i}" for i in range(n_rfs)],
    hoverinfo='text'
))

# Add RF images as layout images with COLOR
for i, (rf, x, y) in enumerate(zip(unit_rfs, x_positions, y_positions)):
    # Normalize RF for display (0-1 range)
    rf_norm = (rf - rf.min()) / (rf.max() - rf.min()) if rf.max() > rf.min() else rf
    
    # Apply colormap (RdBu_r) to get RGB image
    cmap = cm.get_cmap('viridis')  # You can change this: 'viridis', 'plasma', 'hot', etc.
    rf_colored = cmap(rf_norm)  # This returns RGBA
    
    # Convert to uint8 RGB (drop alpha channel)
    rf_uint8 = (rf_colored[:, :, :3] * 255).astype(np.uint8)
    
    # Convert to PIL Image
    img = Image.fromarray(rf_uint8, mode='RGB')
    
    # Convert to base64 string
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    
    fig.add_layout_image(
        dict(
            source=f"data:image/png;base64,{img_str}",
            x=x - 0.3,
            y=y + 0.3,
            sizex=0.6,
            sizey=0.6,
            xref="x",
            yref="y",
            opacity=1.0,
            layer="above"
        )
    )

fig.update_layout(
    width=1200,
    height=900,
    xaxis=dict(
        range=[-0.5, 10.5],
        showgrid=False,
        zeroline=False,
        showticklabels=False,  # <-- Hide x tick labels
        visible=False           # <-- Hide x-axis entirely
    ),
    yaxis=dict(
        range=[-0.5, 10.5],
        showgrid=False,
        zeroline=False,
        showticklabels=False,  # <-- Hide y tick labels
        visible=False           # <-- Hide y-axis entirely
    ),
    showlegend=False,
)

fig.show()

NameError: name 'unit_rfs' is not defined

## Plot the average

In [None]:
### Plot the AVERAGE receptive field
import numpy as np
import matplotlib.pyplot as plt

# Calculate average RF
average_rf = np.mean(unit_rfs, axis=0)

# Create figure
fig, ax = plt.subplots(figsize=(8, 8))

# Plot average RF with colorbar
im = ax.imshow(average_rf, origin="lower", cmap='viridis')
ax.set_xlabel(field_units)
ax.set_ylabel(field_units)
#ax.set_title(f'Average Receptive Field (n={len(unit_rfs)} units)', fontsize=14)

# Add axis labels
ax.set_xticks(range(len(xs)))
ax.set_xticklabels(xs, rotation=90, fontsize=8)
ax.set_yticks(range(len(ys)))
ax.set_yticklabels(ys, fontsize=8)

# Show every other tick label for clarity
for i, l in enumerate(ax.xaxis.get_ticklabels()):
    if i % 2 != 0:
        l.set_visible(False)
for i, l in enumerate(ax.yaxis.get_ticklabels()):
    if i % 2 != 0:
        l.set_visible(False)

# # Add colorbar
# cbar = plt.colorbar(im, ax=ax)
# cbar.set_label('Average Response', rotation=270, labelpad=20)

ax.axis('off')

plt.tight_layout()
plt.show()

# Optional: Print statistics
print(f"Average RF shape: {average_rf.shape}")
print(f"Average RF min: {average_rf.min():.4f}")
print(f"Average RF max: {average_rf.max():.4f}")
print(f"Average RF mean: {average_rf.mean():.4f}")