In [None]:
# trf_maker.py
# Avery Krieger 11/2022, adapted components from Tim Currier and Max Turner

# arguments: [1] date (yyyy-mm-dd); [2] series_number; [3] roi_set_name
# implementation: save_strfs.py 2022-03-17 1 roi_set_post

import os
from pathlib import Path

import matplotlib.patches as mpatches  # for the legend patches
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem
from scipy import interpolate
from two_photon_analysis import medulla_analysis as ma
from visanalysis.analysis import imaging_data, shared_analysis

In [None]:
# Function to randomly draw a value
def getRandVal(rand_min, rand_max, start_seed, update_rate, t):
    """Draw the stimulus value for time ``t`` from a seeded three-level choice.

    The seed is ``int(round(start_seed + t * update_rate))``, so every ``t``
    that rounds to the same seed yields the same value. The value is one of
    ``rand_min``, the midpoint ``(rand_min + rand_max) / 2``, or ``rand_max``.

    A private ``np.random.RandomState`` is used. It produces exactly the same
    draw as seeding the legacy global generator and calling
    ``np.random.choice``, but it leaves NumPy's global RNG state untouched.

    Parameters
    ----------
    rand_min, rand_max : number
        Lowest and highest stimulus levels.
    start_seed : number
        Seed offset for the epoch.
    update_rate : number
        Multiplier applied to ``t`` before rounding to a seed.
    t : float
        Time at which the value is drawn.

    Returns
    -------
    numpy scalar
        The selected level.
    """
    seed = int(round(start_seed + t * update_rate))
    rng = np.random.RandomState(seed)  # same MT19937 stream as np.random.seed(seed)
    return rng.choice([rand_min, (rand_min + rand_max) / 2, rand_max], size=1)[0]

In [None]:
def optoSplit(roi_trfs, silent = True):
    """Split per-trial TRFs into alternating no-opto / opto trials and summarize each set.

    Trials along axis 2 are treated as alternating: even indices are no-opto
    trials and odd indices are opto trials. Assumes the protocol starts with a
    no-opto trial — TODO confirm against the stimulus protocol.

    Parameters
    ----------
    roi_trfs : np.ndarray
        3-D array indexed as (ROI, filter time, trial).
    silent : bool, default True
        When False, print the shapes of the intermediate arrays.

    Returns
    -------
    tuple of np.ndarray
        (roi_mean_trf, nopto_mean_trf, nopto_sem_plus, nopto_sem_minus,
         yopto_mean_trf, yopto_sem_plus, yopto_sem_minus).
        Every array has shape (ROI, filter time). The ``*_sem_plus`` and
        ``*_sem_minus`` arrays are the trial mean plus or minus the standard
        error of the mean across trials.
    """
    # SPLIT into NO Opto and YES Opto trials (alternating)
    no_slice = roi_trfs[:, :, 0::2]  # every 2nd trial starting at 0
    yes_slice = roi_trfs[:, :, 1::2]  # every 2nd trial starting at 1

    # Mean TRF over all trials, and over each condition separately
    roi_mean_trf = np.mean(roi_trfs, 2)
    nopto_mean_trf = np.mean(no_slice, 2)
    yopto_mean_trf = np.mean(yes_slice, 2)

    # Standard Error of the Mean bounds.
    # No np.squeeze here. With a single ROI, squeezing would drop the ROI axis
    # from the SEM bounds but not from the means. avgAcrossROIs(axis=0) would
    # then average the SEM over time instead of over ROIs.
    nopto_sem = sem(no_slice, axis=2)
    nopto_sem_plus = nopto_mean_trf + nopto_sem
    nopto_sem_minus = nopto_mean_trf - nopto_sem
    yopto_sem = sem(yes_slice, axis=2)
    yopto_sem_plus = yopto_mean_trf + yopto_sem
    yopto_sem_minus = yopto_mean_trf - yopto_sem

    if not silent:
        # Checking outputs for separating no opto from opto trials
        print("\n----------------------------------------------------------------------------")
        print("----------------------------------------------------------------------------")
        print("||    Checking the shape of the various trfs to ensure opto/no opto split:!")
        print("||    Shape of roi_trfs is: " + str(roi_trfs.shape))
        print("||    Shape of no_slice is: " + str(no_slice.shape))
        print("||    Shape of yes_slice is: " + str(yes_slice.shape))
        print("||")
        print("||    Shape of roi_mean_trf is: " + str(roi_mean_trf.shape))
        print("||    Shape of nopto_mean_trf is: " + str(nopto_mean_trf.shape))
        print("||    Shape of yopto_mean_trf is: " + str(yopto_mean_trf.shape))
        print("----------------------------------------------------------------------------")
        print("----------------------------------------------------------------------------\n")

        print(f'shape of std error is {nopto_sem_plus.shape}')

    return (roi_mean_trf, nopto_mean_trf, nopto_sem_plus, nopto_sem_minus,
            yopto_mean_trf, yopto_sem_plus, yopto_sem_minus
           )

In [None]:
def avgAcrossROIs(nopto_mean_trf, nopto_sem_plus, nopto_sem_minus, yopto_mean_trf, yopto_sem_plus, yopto_sem_minus):
    """Average each opto / no-opto summary array over its first axis (the ROI axis).

    Returns
    -------
    tuple of np.ndarray
        The six inputs, in the same order, each reduced with
        ``np.mean(..., axis=0)``:
        (nopto_trf, nopto_sem_plus, nopto_sem_minus,
         yopto_trf, yopto_sem_plus, yopto_sem_minus).
    """
    per_roi_arrays = (
        nopto_mean_trf, nopto_sem_plus, nopto_sem_minus,
        yopto_mean_trf, yopto_sem_plus, yopto_sem_minus,
    )
    return tuple(np.mean(per_roi, axis=0) for per_roi in per_roi_arrays)

In [None]:
def _regenerate_stimulus(epoch_params, stim_times):
    """Rebuild the random stimulus trace shown during one epoch.

    Parameters
    ----------
    epoch_params : dict
        One entry of ``ID.getEpochParameters()``. It must provide:
        'distribution_data', a string that evaluates to a dict exposing
        ``['kwargs']['rand_min']`` and ``['kwargs']['rand_max']``;
        'start_seed'; and 'update_rate'.
    stim_times : np.ndarray
        Frame times (sec, relative to stimulus onset) at which to sample.

    Returns
    -------
    np.ndarray
        One getRandVal draw per entry of ``stim_times``.
    """
    # NOTE(review): eval() executes arbitrary code stored in the data file.
    # If 'distribution_data' is always a plain dict literal, ast.literal_eval
    # would be a safer drop-in — confirm before switching.
    distribution_kwargs = eval(epoch_params['distribution_data'])['kwargs']
    rand_min = distribution_kwargs['rand_min']
    rand_max = distribution_kwargs['rand_max']
    start_seed = epoch_params['start_seed']
    update_rate = epoch_params['update_rate']
    return np.array(
        [getRandVal(rand_min, rand_max, start_seed, update_rate, t) for t in stim_times]
    )


def trfMaker(
             experiment_file_directory, experiment_file_name,
             series_number, roi_set_name, filter_length, dff, savefig,
             save_root="/Volumes/ROG2TBAK/data/bruker/trfs/",
            ):
    """Compute per-trial temporal receptive fields (TRFs) for every ROI in a series.

    Steps:
    1. For each epoch, regenerate the random stimulus from its saved seed
       parameters.
    2. Interpolate each ROI's response onto the stimulus frame times.
    3. Take the first ``filter_len`` lags of the FFT-based circular
       cross-correlation between the mean-subtracted response and the
       mean-subtracted stimulus, then time-reverse them. This is the TRF.
    4. Split trials into alternating no-opto / opto sets (optoSplit) and
       average across ROIs (avgAcrossROIs).

    Parameters
    ----------
    experiment_file_directory : str
        Directory containing the experiment HDF5 file.
    experiment_file_name : str
        HDF5 file name without the ".hdf5" extension. Also names the save
        subdirectory.
    series_number : int
        Series within the HDF5 file, passed to ImagingDataObject
        (presumably an int — TODO confirm).
    roi_set_name : str
        Name of the ROI set loaded with ``ID.getRoiResponses``.
    filter_length : float
        Filter duration in seconds. It is converted to samples at
        ``ideal_frame_rate`` (120 Hz).
    dff : bool
        If True, convert responses to dF/F using the mean response over the
        1 s before stimulus onset.
    savefig : bool
        Presumably toggles saving figures into ``save_directory`` — TODO confirm.
    save_root : str, optional
        Root directory under which ``<experiment_file_name>/`` is created.
        Defaults to the previously hard-coded location.
    """
    # join path to proper format for ImagingDataObject()
    file_path = os.path.join(experiment_file_directory, experiment_file_name + ".hdf5")
    print(file_path)
    # create save directory (trailing "/" kept so filenames can be appended directly)
    save_directory = os.path.join(save_root, experiment_file_name) + "/"
    print(save_directory)
    Path(save_directory).mkdir(parents=True, exist_ok=True)
    # create ImagingDataObject (wants a path to an hdf5 file and a series number from that file)
    ID = imaging_data.ImagingDataObject(file_path, series_number, quiet=True)
    # Load the responses of the requested ROI set. The original referenced
    # roi_data without ever loading it, which raised a NameError.
    roi_data = ID.getRoiResponses(roi_set_name)

    # Establish relevant variables
    epoch_parameters = ID.getEpochParameters()
    ideal_frame_rate = 120  # Hz, rate used to regenerate the stimulus frames
    sample_period = ID.getAcquisitionMetadata('sample_period')  # (sec), bruker imaging acquisition period
    filter_len = filter_length * ideal_frame_rate
    n_filter = int(filter_len)
    stimulus_timing = ID.getStimulusTiming(plot_trace_flag=False)
    stimulus_start_times = stimulus_timing['stimulus_start_times']
    run_parameters = ID.getRunParameters
    stim_frames = run_parameters('stim_time') * ideal_frame_rate
    num_epochs = int(run_parameters('num_epochs'))

    # frame flip times in a stimulus presentation (sec, relative to stimulus onset)
    stim_times = np.arange(1, stim_frames + 1) * 1 / ideal_frame_rate

    n_rois = roi_data["epoch_response"].shape[0]
    # Initialize roi_trfs - ROI x Filter x Trials
    roi_trfs = np.zeros((n_rois, n_filter, num_epochs))

    # One interpolator per ROI. The raw trace does not change between epochs,
    # so this is built once instead of once per (epoch, ROI) pair.
    # Sample k of the trace is timestamped k * sample_period.
    roi_interpolators = []
    for roi_ind in range(n_rois):
        response_trace = roi_data.get('roi_response')[roi_ind][0, :]
        response_time = np.arange(1, len(response_trace) + 1) * sample_period
        roi_interpolators.append(interpolate.interp1d(response_time, response_trace))

    baseline_time = 1  # (sec) generally could be pre_time, but for opto only take previous 1 sec
    n_baseline_samples = int(baseline_time / sample_period)

    print('\n----------------------------------------------------------------------------------')
    print('Initializing: (1) visual stim recreation (2) fft responses (3) generate filters...\n')

    for epoch_ind in range(num_epochs):
        if epoch_ind % 10 == 0:
            print(f'...Starting Trial {epoch_ind} of {num_epochs}...')

        # The stimulus and its spectrum depend only on the epoch, not on the ROI
        new_stim = _regenerate_stimulus(epoch_parameters[epoch_ind], stim_times)
        stim_fft_conj = np.conj(np.fft.fft(new_stim - np.mean(new_stim)))
        current_frame_times = stimulus_start_times[epoch_ind] + stim_times  # In Prairie View time (sec)
        baseline_times = np.linspace(
            current_frame_times[0] - baseline_time, current_frame_times[0], n_baseline_samples
        )

        for roi_ind, f_interp_response in enumerate(roi_interpolators):
            current_interp_response = f_interp_response(current_frame_times)
            if dff:
                # Convert to dF/F relative to the pre-stimulus baseline
                baseline = np.mean(f_interp_response(baseline_times))
                current_interp_response = (current_interp_response - baseline) / baseline

            # Circular cross-correlation via FFT; keep the first n_filter lags, reversed
            filter_fft = np.fft.fft(current_interp_response - np.mean(current_interp_response)) * stim_fft_conj
            filt = np.real(np.fft.ifft(filter_fft))[0:n_filter]
            roi_trfs[roi_ind, :, epoch_ind] = np.flip(filt)

    print('\n-------')
    print('DONE!')
    print('-------\n')

    # Split alternating trials into no-opto / opto and compute per-ROI means and SEM bounds
    (roi_mean_trf, nopto_mean_trf, nopto_sem_plus, nopto_sem_minus,
     yopto_mean_trf, yopto_sem_plus, yopto_sem_minus) = optoSplit(roi_trfs)

    # Run function to average across ROIs
    (across_roi_nopto_trf, across_roi_nopto_sem_plus, across_roi_nopto_sem_minus,
     across_roi_yopto_trf, across_roi_yopto_sem_plus, across_roi_yopto_sem_minus) = avgAcrossROIs(
        nopto_mean_trf, nopto_sem_plus, nopto_sem_minus,
        yopto_mean_trf, yopto_sem_plus, yopto_sem_minus)