In [346]:
import scipy
import numpy as np
import json
import glob
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt
from bisect import bisect_left
import statsmodels.api as sm
import os

In [347]:
def load_bpod_data(dpath):
    """
    Read a JSON file and return its parsed content.

    Parameters
    ----------
    dpath : path of the JSON file to open (opened in text read mode).

    Returns
    -------
    Whatever `json.load` produces for the file (dict, list, ...).
    """
    with open(dpath, 'r') as handle:
        return json.load(handle)

def take_closest(ms_stamps, bpod_stamps):
    """
    Map every stamp in bpod_stamps to the index of its closest entry in ms_stamps.

    Assumes ms_stamps is sorted in ascending order.

    Parameters
    ----------
    ms_stamps : non-empty 1-D sequence of numbers, sorted ascending.
    bpod_stamps : iterable of numbers (a single scalar is accepted too).

    Returns
    -------
    numpy.ndarray of int, always 1-D, holding one index into ms_stamps per
    entry of bpod_stamps (NOT the stamp values themselves).
    Stamps lying before the first / after the last entry of ms_stamps are
    clamped to index 0 / len(ms_stamps) - 1.
    If two neighbours are equally close, the index of the smaller one is returned.

    Raises
    ------
    ValueError if ms_stamps is empty.
    """
    ms_stamps = np.asarray(ms_stamps)
    if ms_stamps.size == 0:
        raise ValueError("ms_stamps must contain at least one stamp")

    last_index = len(ms_stamps) - 1
    msAlignedIndex = []
    for bpod_stamp in np.atleast_1d(bpod_stamps):
        # bisect_left gives the first position whose stamp is >= bpod_stamp,
        # so the two candidates are positions pos - 1 and pos.
        pos = bisect_left(ms_stamps, bpod_stamp)
        if pos == 0:
            # Earlier than every ms stamp: clamp instead of aborting the loop.
            msAlignedIndex.append(0)
        elif pos > last_index:
            # Later than every ms stamp: clamp to the last one.
            msAlignedIndex.append(last_index)
        elif ms_stamps[pos] - bpod_stamp < bpod_stamp - ms_stamps[pos - 1]:
            msAlignedIndex.append(pos)
        else:
            # Strictly closer to the left neighbour, or a tie.
            msAlignedIndex.append(pos - 1)
    return np.array(msAlignedIndex, dtype=int)
    
def get_event_time(event, bpod_data):
    """
    Return the first onset time of a state for every trial that is kept.

    Parameters
    ----------
    event : name of the state, used as key in each trial's 'States' dict.
    bpod_data : parsed Bpod session; trials are read from
        bpod_data["SessionData"]['RawEvents']['Trial'].

    Returns
    -------
    numpy.ndarray of float with one entry per trial whose
    trial['States']['WrongPort'][0] is falsy (presumably meaning the
    WrongPort state was never entered -- TODO confirm the encoding).
    Each entry is the first value of trial['States'][event], i.e. the onset
    of the first visit, relative to trial start. A None onset becomes NaN.
    """
    onsets = []
    for trial in bpod_data["SessionData"]['RawEvents']['Trial']:
        if trial['States']['WrongPort'][0]:
            continue
        # The state entry is either [start, end] or a list of such pairs when
        # the state was visited several times. Walk down to the first scalar
        # instead of testing `type(...) != float`, which treated integer
        # timestamps (JSON `1` parses as int) as nested lists and crashed.
        onset = trial['States'][event]
        while isinstance(onset, (list, tuple)):
            onset = onset[0]
        onsets.append(onset)
    # dtype=float keeps the result numeric for int onsets and maps None to NaN.
    return np.array(onsets, dtype=float)
    
def get_RPE_neuron_activities(trialRewards, calcium_activities, calcium_baselines):
    """
    Select RPE relative neurons using the method from Muller. 2024

    For every neuron (one row of calcium_activities) an ordinary least squares
    model of its per-trial activity against trialRewards is fitted; the neuron
    is kept when the p-value is below 0.005.

    Parameters
    ----------
    trialRewards : regressor passed to sm.OLS as-is, one value per trial.
    calcium_activities : iterable of per-neuron activity vectors.
    calcium_baselines : indexable by neuron position; the entry is subtracted
        from the activity of every selected neuron.

    Returns
    -------
    (activities, ids) : baseline-subtracted activities of the selected neurons
    and their positions within calcium_activities, both as numpy arrays.
    """
    # NOTE(review): trialRewards is handed to sm.OLS without sm.add_constant,
    # so the fit has no intercept term -- confirm this is intended.
    selected_activities = []
    selected_ids = []
    for neuron_idx, neuron_activity in enumerate(calcium_activities):
        fit = sm.OLS(neuron_activity, trialRewards).fit()
        if not (fit.pvalues < 0.005):
            continue
        selected_ids.append(neuron_idx)
        selected_activities.append(neuron_activity - calcium_baselines[neuron_idx])
    return np.array(selected_activities), np.array(selected_ids)

### Load Data

In [348]:
# Session selector: animal id and recording date (YYYY-MM-DD).
# Loading cells further down interpolate these into data and result paths.
mouse = 'ZZ0024-LR'
date = '2024-08-27'


In [349]:
### load wavesurfer stamps
# Both .mat files of a session live in the same per-mouse folder.
ws_session_prefix = f"/Users/fgs/HMLworkplace/Arena_analysis/Data/ws_data/processed/{mouse}/{mouse}_{date}"

# Stamps of the miniscope frames; np.squeeze drops the singleton dimensions.
ms_mat = scipy.io.loadmat(f"{ws_session_prefix}_ms.mat")
ms_stamps = np.squeeze(ms_mat['ms_frames_samplingstamps'])

# Stamps of the Bpod trial starts.
bpod_mat = scipy.io.loadmat(f"{ws_session_prefix}_bpod.mat")
bpod_trialstart_stamps = np.squeeze(bpod_mat['bpod_trialstart_samplingstamps'])

In [351]:
### load bpod results
# The file name is built from `mouse` and `date` (like the other loaders) so
# that changing the session in the config cell cannot silently combine the
# Bpod data of one session with the imaging data of another. Only the
# time-of-day suffix of the Bpod file still has to be given by hand.
BPOD_SESSION_TIME = '111252'
bpod_json_path = (
    f"/Users/fgs/HMLworkplace/Arena_analysis/Data/bpod_data/{mouse}/"
    f"{mouse}_unity_stage3test_v1_{date.replace('-', '')}_{BPOD_SESSION_TIME}.json"
)
bpod_data = load_bpod_data(bpod_json_path)
trialRewards = np.array(bpod_data["SessionData"]['TrialRewards'])

### Exclude wrong trials
# A trial is dropped when the first entry of its 'WrongPort' state is truthy.
wrongTrials = [
    i
    for i, trial in enumerate(bpod_data['SessionData']['RawEvents']['Trial'])
    if trial['States']['WrongPort'][0]
]
trialRewards = np.delete(trialRewards, wrongTrials)

In [None]:
### load minian results
# Opened as an xarray Dataset.
minian_path = f"/Users/fgs/HMLworkplace/Arena_analysis/Data/minian_data/{mouse}_{date}.netcdf"
minian_results = xr.open_dataset(minian_path)

### Reversal Points

In [None]:
BASELINE_WND_DURATION = 0.5  # length of the baseline window; multiplied by fps later, so in seconds
ACTIVITY_WND_DURATION = 2    # length of the reward response window, same unit
FS = 20000                   # stamps per second (presumably the wavesurfer sampling rate -- TODO confirm)


### Handle sessions during which the wavesurfer crashed
# NOTE: this cell rewrites bpod_trialstart_stamps and ms_stamps in place, so it
# must be run exactly once after the loading cells; re-running it would drop
# the wrong trials a second time.
bpod_trials = bpod_data['SessionData']['RawEvents']['Trial']
# One trial-start stamp is needed per Bpod trial. Use the real trial count
# instead of a hard-coded 40, so other session lengths work too.
n_bpod_trials = len(bpod_trials)
if len(bpod_trialstart_stamps) < n_bpod_trials:
    for i in range(len(bpod_trialstart_stamps), n_bpod_trials):
        # Reconstruct the missing start from the previous trial's start plus
        # the end time of that trial's 'EndState', converted to stamps.
        next_stamp = bpod_trialstart_stamps[i-1] + bpod_trials[i-1]['States']['EndState'][1] * FS
        bpod_trialstart_stamps = np.append(bpod_trialstart_stamps, next_stamp)

bpod_trialstart_stamps = np.delete(bpod_trialstart_stamps, wrongTrials)

# When minian holds at least 150 more frames than there are recorded frame
# stamps, extrapolate the missing stamps at a fixed spacing of int(FS/30).
n_missing_frames = len(minian_results.frame) - len(ms_stamps)
if n_missing_frames >= 150:
    frame_step = int(FS/30)
    # Built in one go instead of one np.append (a full copy) per missing frame.
    extrapolated_stamps = ms_stamps[-1] + frame_step * np.arange(1, n_missing_frames + 1)
    ms_stamps = np.append(ms_stamps, extrapolated_stamps.astype(ms_stamps.dtype))


In [356]:
### get the needed bpod events' time stamps
toneTime = get_event_time('Tone',bpod_data)
rewardTime = get_event_time('Reward', bpod_data)
lickingTime = get_event_time('Drinking', bpod_data)

### set the baseline window(f0) and get the wavesurfer stamps for baseline window
# The baseline window is the BASELINE_WND_DURATION seconds right before reward
# onset. The tone-based alternative that used to precede this line
# (toneTime - BASELINE_WND_DURATION - 0.5) was overwritten immediately and
# never used, so it has been dropped.
baseline_time = rewardTime - BASELINE_WND_DURATION
# Event times are relative to trial start: convert to stamps and add the
# trial-start stamp.
baseline_stamps_bpod = baseline_time*FS + bpod_trialstart_stamps

### get the wavesurfer stamps for reward response window
reward_stamps_bpod = rewardTime*FS + bpod_trialstart_stamps

In [358]:
### get the needed events frames
FPS = 30
FRAME_DOWNSAMPLE = 1
# effective frame rate after downsampling
fps = FPS/FRAME_DOWNSAMPLE


def _frame_window(event_stamps, window_duration):
    """Return (start_frames, end_frames) of the windows opening at event_stamps."""
    # closest miniscope frame index, rescaled to the downsampled frame axis
    start_frames = (take_closest(ms_stamps, event_stamps)/FRAME_DOWNSAMPLE).astype(int)
    end_frames = (start_frames + window_duration*fps).astype(int)
    return start_frames, end_frames


### get the baseline start and end frames
msBaselineStartFrames, msBaselineEndFrames = _frame_window(baseline_stamps_bpod, BASELINE_WND_DURATION)

### get the reward start and end frames
msRewardStartFrames, msRewardEndFrames = _frame_window(reward_stamps_bpod, ACTIVITY_WND_DURATION)


In [359]:
calcium_traces = minian_results.C


def _mean_activity_per_window(traces, start_frames, end_frames):
    """
    Average traces over each [start, end) frame window.

    traces is indexed as traces[:, start:end] and averaged over axis 1; the
    stacked result is transposed so that rows follow the first axis of traces
    and columns follow the windows.
    """
    per_window_means = [
        np.mean(traces[:, first_frame:last_frame], axis=1)
        for first_frame, last_frame in zip(start_frames, end_frames)
    ]
    return np.array(per_window_means).T


### calculate the mean calcium activities of baseline windows for each trials
calcium_traces_baselines = _mean_activity_per_window(
    calcium_traces, msBaselineStartFrames, msBaselineEndFrames)

### calculate the mean calcium activities of reward response windows for each trials
mean_reward_responses = _mean_activity_per_window(
    calcium_traces, msRewardStartFrames, msRewardEndFrames)

In [360]:
### Directly use original calcium activities instead of df/f0
# An all-zero baseline of matching shape makes the later subtraction a no-op.
baseline_shape = calcium_traces_baselines.shape
no_baseline = np.zeros(baseline_shape)

In [361]:
### get the RPE relative neurons
# Returns the selected neurons' responses and their row positions in
# mean_reward_responses.
RPE_neurons_reward_responses, RPE_neurons_id = get_RPE_neuron_activities(
    trialRewards=trialRewards,
    calcium_activities=mean_reward_responses,
    calcium_baselines=no_baseline,
)


In [362]:
# for i,activity in enumerate(RPE_neurons_rewards_activities):
#     fig_path = f'/Users/fgs/HMLworkplace/Arena_analysis/Results_temp/{mouse}/{date}/C/-0.5~1beforelicking'
#     if not os.path.exists(fig_path):
#         os.makedirs(fig_path)
#     slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(trialRewards, activity)
#     x_intercept = -intercept / slope
#     plt.scatter(trialRewards, activity, color='blue', label='Data Points')
#     plt.plot(trialRewards, slope * trialRewards + intercept, color='red', label=f'Fitted Line: y={slope:.2f}x + {intercept:.2f}')
#     plt.axhline(0, color='gray', linestyle='--')
#     plt.axvline(x_intercept, color='green', linestyle='--', label=f'x-intercept: {x_intercept:.2f}')
#     plt.title('Linear Regression and x-intercept')
#     plt.xlabel('x')
#     plt.ylabel('y')
#     plt.legend()
#     plt.grid(True)
#     plt.savefig(f'{fig_path}/reversal_point{i}.png')
#     plt.clf()

# x_intercept

In [363]:
# for i,activity in enumerate(RPE_neurons_rewards_activities):
#     fig_path = f'/Users/fgs/HMLworkplace/Arena_analysis/Results_temp/{mouse}/{date}/C/'
#     if not os.path.exists(fig_path):
#         os.makedirs(fig_path)
#     plt.scatter(trialRewards, activity)
#     plt.savefig(f'{fig_path}/RPE_neuron{i}.png')
#     plt.clf()

In [364]:
### Z-score the mean calcium activities
# Per neuron (row): subtract the mean across trials and divide by the standard
# deviation across trials; keepdims lets both broadcast against the 2-D input.
rpe_response_mean = np.mean(RPE_neurons_reward_responses, axis=1, keepdims=True)
rpe_response_std = np.std(RPE_neurons_reward_responses, axis=1, keepdims=True)
zscored_RPE_reward_responses = (RPE_neurons_reward_responses - rpe_response_mean)/rpe_response_std

In [365]:
### Estimate and plot reversal points
reward_amounts = np.unique(trialRewards)

# The output folder does not depend on the neuron: create it once.
fig_path = f'/Users/fgs/HMLworkplace/Arena_analysis/Results_temp/{mouse}/{date}/C/normalized'
os.makedirs(fig_path, exist_ok=True)

for i,activity in enumerate(zscored_RPE_reward_responses):
    # One explicit figure per neuron, closed after saving, so that no empty
    # figure is left behind as cell output.
    fig, ax = plt.subplots()
    # every single trial, faint
    ax.scatter(trialRewards, activity, color='gray', alpha=0.1)
    means = []
    for reward_amount in reward_amounts:
        amount_activity = activity[trialRewards == reward_amount]
        mean = np.mean(amount_activity)
        std = np.std(amount_activity)
        means.append(mean)
        # mean +/- std of the trials sharing this reward amount
        ax.errorbar(reward_amount, mean, yerr=std, fmt='o', color='orange', ecolor='lightgray', elinewidth=2, capsize=4)
    means = np.array(means)
    # Linear fit over all trials; the reversal point (RP) is the x-intercept
    # of the fitted line. A zero slope gives a non-finite RP.
    slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(trialRewards, activity)
    ax.plot(reward_amounts, slope * reward_amounts + intercept, color='orange', label=f'RP={-(intercept/slope)}')
    ax.set_xlabel('Reward amount')
    ax.set_ylabel('Z-scored mean reward response')
    ax.legend()
    # RPE_neurons_id[i] is the neuron's position among all neurons.
    fig.savefig(f'{fig_path}/RPE_neuron{RPE_neurons_id[i]}.png')
    plt.close(fig)

<Figure size 640x480 with 0 Axes>

### Calcium Trace Plot


In [343]:
### A function to plot given calcium activity around time points of interest
def plot_calcium(reward_amounts, neural_signal, neuron_id, save_path):
    """
    Plot single-trial traces and their average, one panel per reward amount.

    Parameters
    ----------
    reward_amounts : sequence of reward amounts; each one gets its own panel
        and is used as key into neural_signal.
    neural_signal : mapping reward amount -> 2-D array (trials x samples).
        The samples are drawn on a time axis running from -3 to 3.
    neuron_id : identifier used in the figure title and the file name.
    save_path : existing directory; the figure is written to
        '{save_path}/neuron{neuron_id}.png'.

    All panels share the y-range [0, largest value over all reward amounts].
    The figure is closed after saving.
    """
    # squeeze=False keeps `axs` an array even for a single reward amount, where
    # plt.subplots would otherwise hand back a bare Axes without `.flat`.
    fig, axs = plt.subplots(1, len(reward_amounts), figsize=(14,10), squeeze=False)

    # common upper y-limit (never below 0)
    max_ylim = 0
    for reward_amount in reward_amounts:
        max_ylim = max(max_ylim, np.max(neural_signal[reward_amount]))

    for i, reward_amount in enumerate(reward_amounts):
        ax = axs.flat[i]
        trial_traces = neural_signal[reward_amount]
        # Take the number of samples from the data rather than assuming 180.
        time_range = np.linspace(-3, 3, np.shape(trial_traces)[-1])
        colors = plt.cm.viridis(np.linspace(0, 1, len(trial_traces)))
        average_signal = np.mean(trial_traces, axis=0)

        for j, series in enumerate(trial_traces):
            ax.plot(time_range, series, color=colors[j], alpha=0.5)

        ax.plot(time_range, average_signal, color = 'black', linewidth = 2, label='Average')
        ax.set_ylim(0, max_ylim)
        ax.set_title(f'Reward amount {reward_amount}')
        ax.set_xlabel('Time')
        ax.set_ylabel('Raw fluorescence')
        ax.legend()
    fig.suptitle(f'Neuron{neuron_id} activity around reward', fontsize=16)
    # leave the top 4% of the figure to the suptitle
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(f'{save_path}/neuron{neuron_id}.png')
    plt.close(fig)


In [344]:
### A function to plot calcium activity as Muller. 2024 style
def plot_calcium_DRL(reward_amounts, neural_signal, neuron_id, save_path):
    """
    Plot the trial-averaged trace (+/- one std band) of every reward amount in one axes.

    Parameters
    ----------
    reward_amounts : sequence of reward amounts, used as keys into neural_signal.
    neural_signal : mapping reward amount -> 2-D array (trials x samples).
        The samples are drawn on a time axis running from -3 to 3.
    neuron_id : identifier used in the figure title and the file name.
    save_path : existing directory; the figure is written to
        '{save_path}/neuron{neuron_id}.png'.

    The figure is closed after saving.
    """
    cmap = plt.get_cmap("tab10")
    colors = [cmap(i) for i in np.linspace(0, 1, len(reward_amounts))]
    # Explicit axes instead of the implicit "current figure" of pyplot.
    fig, ax = plt.subplots(figsize=(14,10))
    for i, reward_amount in enumerate(reward_amounts):
        trial_traces = neural_signal[reward_amount]
        # Take the number of samples from the data rather than assuming 180.
        time_range = np.linspace(-3, 3, np.shape(trial_traces)[-1])
        average_signal = np.mean(trial_traces, axis=0)
        std = np.std(trial_traces, axis=0)
        ax.plot(time_range, average_signal, color=colors[i], label=f'reward{reward_amount}')
        ax.fill_between(time_range, average_signal-std, average_signal+std, color=colors[i], alpha=0.2)
    # One legend is enough; skip it when nothing was plotted.
    if len(reward_amounts) > 0:
        ax.legend()
    ax.set_xlabel('Times')
    ax.set_ylabel('Raw fluorescence')
    fig.suptitle(f'Neuron{neuron_id} activity around reward', fontsize=16)
    fig.savefig(f'{save_path}/neuron{neuron_id}.png')
    plt.close(fig)

In [None]:
reward_amounts = np.unique(trialRewards)

### Plot calcium activites around reward onset
# Output folders are the same for every neuron: create them once.
save_path = f'/Users/fgs/HMLworkplace/Arena_analysis/Results_temp/{mouse}/{date}/activitiesAroundReward'
save_pathDRL = f'/Users/fgs/HMLworkplace/Arena_analysis/Results_temp/{mouse}/{date}/DRL'
os.makedirs(save_path, exist_ok=True)
os.makedirs(save_pathDRL, exist_ok=True)

# The frame windows around reward depend only on the reward amount, not on the
# neuron, so the stamp -> frame lookup is done once per reward amount instead
# of once per neuron and reward amount.
rewardWindowsByAmount = {}
for reward_amount in reward_amounts:
    # np.atleast_1d keeps a reward amount with a single trial iterable.
    msRewardFrames = np.atleast_1d(
        take_closest(ms_stamps, reward_stamps_bpod[trialRewards == reward_amount])
    ).astype(int)
    # FPS*3 frames before and after the reward frame
    rewardWindowStarts = msRewardFrames - FPS*3
    rewardWindowEnds = msRewardFrames + FPS*3
    rewardWindowsByAmount[reward_amount] = (rewardWindowStarts, rewardWindowEnds)

# NOTE(review): windows reaching before frame 0 or past the last frame yield
# shorter slices than the others -- decide whether such trials should be
# dropped or padded.

# Iterate all neurons
for k, single_neuron_C in enumerate(minian_results.C):
    # plain numpy view of this neuron's trace, so the slicing below is cheap
    neuron_trace = np.asarray(single_neuron_C)
    singleNeuronAroundReward = {}
    # Iterate all reward amounts
    for reward_amount in reward_amounts:
        rewardWindowStarts, rewardWindowEnds = rewardWindowsByAmount[reward_amount]
        # one row per trial with this reward amount
        singleNeuronAroundReward[reward_amount] = np.array([
            neuron_trace[rewardWindowStart:rewardWindowEnd]
            for rewardWindowStart, rewardWindowEnd in zip(rewardWindowStarts, rewardWindowEnds)
        ])
    plot_calcium(reward_amounts, singleNeuronAroundReward, k,
                 save_path=save_path)
    plot_calcium_DRL(reward_amounts, singleNeuronAroundReward, k,
                     save_path=save_pathDRL)
    
