In [None]:
import numpy as np
import matplotlib.pyplot as plt
import datajoint as dj
import os
import numpy as np
import statistics as stat

# Database endpoint of the lab's DataJoint server.
dj.config['database.host'] = "arseny-lab.cmte3q4ziyvy.il-central-1.rds.amazonaws.com"
# Credentials are taken from the DJ_USER / DJ_PASS environment variables so
# they never have to be typed into (and saved with) the notebook. When a
# variable is unset the value falls back to the empty string, as before.
dj.config['database.user'] = os.environ.get('DJ_USER', "")
dj.config['database.password'] = os.environ.get('DJ_PASS', "")

# Open the shared connection used by every schema / virtual module in the notebook.
conn = dj.conn()


In [None]:
import datajoint as dj
import numpy as np
import pandas as pd
# import dj_connect
# import getSchema
# from scipy.io import loadmat

# schema = getSchema.getSchema()

# conn = dj_connect.connectToDataJoint("talch012", "simple")

# Schema that receives the tables declared in this notebook via the @schema decorator.
schema = dj.Schema("talch012_EPHYS_IMG")
# Virtual modules expose the tables of already-existing database schemas as Python classes.
exp = dj.VirtualModule("EXPt", "talch012_expt", create_tables=True)
lab = dj.VirtualModule("LABt", "talch012_labt", create_tables=True)
# ephys = dj.VirtualModule("EPHYS", "talch012_ephys_test", create_tables=True)
# Source ephys schema queried by this notebook (Probe, ElectrodeGroup, Unit, TrialSpikes).
ephys = dj.VirtualModule('schema_module', 'talch012_EPHYS_TEST')


In [None]:
# Register one 'behav_only' epoch for subject 101104, session 1.
# skip_duplicates=True makes the cell safe to re-run (Restart & Run All):
# without it the second execution fails with a duplicate-key error.
exp.SessionEpoch.insert1({
    'subject_id': 101104,
    'session': 1,
    'session_epoch_type': 'behav_only',
    'session_epoch_number': 1,
    'session_epoch_start_time': 0.0,
    'session_epoch_end_time': 0.0,
    'flag_photostim_epoch': 0,
}, allow_direct_insert=True, skip_duplicates=True)


In [None]:
#works
# Exploratory version of the per-unit binning that ROISpikes.make performs:
# builds `all_spike_data`, a dict keyed by (subject_id, session, unit) holding
# the binned spike train of every unit found in TrialSpikes.
import numpy as np

schema_module = dj.VirtualModule("schema_module", "talch012_EPHYS_TEST", create_tables=True)
spikes_time_fetch = schema_module.TrialSpikes
spike_times = spikes_time_fetch.fetch()

bin_size = 0.1  # 100 ms bins

units = np.unique(spike_times['unit'])
subjects = np.unique(spike_times['subject_id'])
all_spike_data = {}

# Trial start times depend only on (subject, session, trial), not on the unit,
# so each one is queried once and reused for every unit of that session.
_trial_start_cache = {}


def _trial_start_time(subject_id, session, trial):
    """Return the start time (float) of one trial from exp.SessionTrial10.

    The result is memoised in `_trial_start_cache`. `fetch1` raises if the
    trial has no row (or more than one row) in exp.SessionTrial10.
    """
    cache_key = (subject_id, session, trial)
    if cache_key not in _trial_start_cache:
        trial_times = (exp.SessionTrial10 &
                       f'subject_id="{subject_id}"' &
                       f'session={session}' &
                       f'trial={trial}').fetch1()
        _trial_start_cache[cache_key] = float(trial_times['start_time'])
    return _trial_start_cache[cache_key]


# Narrow the rows level by level so only (subject, session, unit) combinations
# that actually exist are visited; np.unique returns sorted values, so the
# iteration order matches a sorted scan over all combinations.
for subject_id in subjects:
    subject_rows = spike_times[spike_times['subject_id'] == subject_id]

    for session in np.unique(subject_rows['session']):
        session_rows = subject_rows[subject_rows['session'] == session]

        for unit in np.unique(session_rows['unit']):
            trials_subset = session_rows[session_rows['unit'] == unit]

            all_trial_spikes = []
            trial_start_times = []

            # Trials in ascending order (np.unique sorts).
            for trial in np.unique(trials_subset['trial']):
                trial_data = trials_subset[trials_subset['trial'] == trial]

                start_time = _trial_start_time(subject_id, session, trial)
                trial_start_times.append(start_time)

                # NOTE(review): only Python lists are accepted as spike trains;
                # a blob that comes back as a numpy array is skipped - confirm
                # that this is intended.
                for trial1 in trial_data['spike_times']:
                    if isinstance(trial1, list) and len(trial1) > 0:
                        # Shift trial-relative spike times onto the session clock.
                        all_trial_spikes.append(np.array(trial1) + start_time)

            if len(all_trial_spikes) == 0:
                continue

            # Concatenate all spikes across trials
            concatenated_spikes = np.concatenate(all_trial_spikes)

            # The time axis must also cover trials that start after the unit's
            # last spike; otherwise their start frame points past the trace.
            max_time = max(np.max(concatenated_spikes), max(trial_start_times))

            # Compute bin edges and histogram
            time_bins = np.arange(0, max_time + bin_size, bin_size)
            spike_counts, bin_edges = np.histogram(concatenated_spikes, bins=time_bins)
            spike_rate_hz = spike_counts / bin_size
            time_line = bin_edges[:-1] + bin_size / 2  # bin centres

            # 0-based bin index of each trial start. A start exactly on the
            # last edge belongs to the last bin (np.histogram closes the last
            # bin on the right), so cap the index at the last valid bin.
            start_frames = np.minimum(
                np.digitize(trial_start_times, bin_edges) - 1,
                len(spike_counts) - 1,
            )

            # Store results
            key = (subject_id, session, unit)
            all_spike_data[key] = {
                'spike_counts': spike_counts,
                'spike_rate_hz': spike_rate_hz,
                'time_line': time_line,
                'start_frames': np.array(start_frames),
                'bin_edges': bin_edges[:-1]  # left edge of every bin
            }


In [None]:
def compute_bin_edges_and_trial_starts(key, bin_size=0.1):
    """Compute a session-wide bin grid and the bin in which each trial starts.

    Parameters
    ----------
    key : dict
        Restriction applied to both exp.SessionTrial10 and ephys.TrialSpikes
        (FrameStartFile.make passes an exp.SessionEpoch key).
    bin_size : float, optional
        Bin width, in the same unit as the stored start and spike times
        (0.1 by default).

    Returns
    -------
    tuple
        ``(time_bins, trial_nums, start_bins)`` where ``time_bins`` is the
        array of bin edges starting at 0, ``trial_nums`` is a tuple of trial
        numbers in ascending order and ``start_bins`` holds the 0-based bin
        index of each trial start. ``(None, None, None)`` is returned when the
        key matches no trials or no non-empty spike list.
    """
    # fetch() with several attribute names returns one array per attribute,
    # so unpack the two parallel arrays directly and let the database sort.
    trial_nums, start_times = (exp.SessionTrial10 & key).fetch(
        'trial', 'start_time', order_by='trial')
    if len(trial_nums) == 0:
        return None, None, None

    # Convert to float so the start times can be added to the spike times
    # (the rest of the notebook also casts start_time with float()).
    start_times = np.asarray(start_times, dtype=float)

    # Latest trial-relative spike time over all spike lists matching the key.
    # NOTE(review): only Python lists are accepted as spike trains; a blob
    # that comes back as a numpy array is skipped - confirm this is intended.
    spikes_all = (ephys.TrialSpikes & key).fetch('spike_times')
    per_list_max = [
        max(spike_list)
        for spike_list in spikes_all
        if isinstance(spike_list, list) and len(spike_list) > 0
    ]
    if not per_list_max:
        return None, None, None

    # Conservative upper bound: latest trial-relative spike + latest trial start, plus a margin of 1.
    max_time = max(per_list_max) + start_times.max() + 1

    time_bins = np.arange(0, max_time + bin_size, bin_size)

    # Zero-based index of the bin containing each trial start.
    start_bins = np.digitize(start_times, time_bins) - 1

    return time_bins, tuple(trial_nums.tolist()), start_bins


In [None]:
@schema
class FrameStartFile(dj.Imported):
    """Start frame (bin index) of every trial within a session epoch.

    One row is written per trial returned by
    compute_bin_edges_and_trial_starts: the trial number is stored as
    session_epoch_file_num and its start bin as session_epoch_file_start_frame.
    """

    definition = """
    -> exp.SessionEpoch
    session_epoch_file_num : int
    ---
    session_epoch_file_start_frame : double
    """

    key_source = exp.SessionEpoch

    def make(self, key):
        """Insert one row per trial of the session epoch identified by `key`."""
        time_bins, trial_nums, start_bins = compute_bin_edges_and_trial_starts(key)
        if time_bins is None:
            # The helper found no trials or no spikes: nothing to store.
            return

        rows = [
            dict(
                key,
                session_epoch_file_num=trial_num,
                session_epoch_file_start_frame=start_bin,
            )
            for trial_num, start_bin in zip(trial_nums, start_bins)
        ]

        if rows:
            self.insert(rows, skip_duplicates=True)


In [None]:
# WORKED
@schema
class FOV(dj.Imported):
    """Field-of-view rows derived from the ephys probes matching a session key.

    Each distinct numeric ``probe_part_no`` becomes one FOV; when no probe has
    a numeric part number a single FOV with ``fov_num = 0`` is inserted.
    """

    definition = """
        -> exp.Session
        fov_num : int   # assigned probe number as described
        ---
        imaging_frame_rate : int  # e.g. fixed 10 Hz or from bin size
    """

    key_source = exp.Session

    # Frame rate stored with every FOV (10 Hz, i.e. the 0.1 s bins used in this notebook).
    _FRAME_RATE_HZ = 10

    def make(self, key):
        """Insert the FOV rows for the session identified by `key`."""
        # Probes matching this key.
        # NOTE(review): presumably ephys.Probe carries subject_id/session so
        # that this selects the session's probes only - confirm.
        probes = (ephys.Probe & key).fetch(as_dict=True)

        # Keep numeric part numbers only. A set is used because two probes
        # with the same part number would otherwise yield the same primary
        # key twice and make the second insert fail.
        probe_nums = set()
        for p in probes:
            try:
                probe_nums.add(int(p.get('probe_part_no', '')))
            except (ValueError, TypeError):
                # Skip non-numeric or empty entries
                continue

        # Fall back to a single FOV numbered 0 when nothing numeric was found.
        for fov_num in sorted(probe_nums) or [0]:
            self.insert1({
                **key,
                'fov_num': fov_num,
                'imaging_frame_rate': self._FRAME_RATE_HZ,
            }, allow_direct_insert=True)


In [None]:
# Run FOV.make for every exp.Session key that has no FOV rows yet.
FOV.populate()

In [None]:
# worked
@schema
class Plane(dj.Imported):
    """Plane rows of a FOV, one per matching ephys electrode group.

    Both plane_num and channel_num are filled with the electrode group number.
    """

    definition = """
        -> FOV
        plane_num : smallint  # electrode_group from ephys.electrodegroup
        channel_num : smallint  # also from ephys.electrodegroup
        ---
    """

    key_source = FOV

    def make(self, key):
        """Insert one Plane row per electrode group matching the FOV key."""
        # fov_num only takes effect if ephys.ElectrodeGroup has such an
        # attribute; otherwise the match is by subject and session alone.
        group_restriction = dict(
            subject_id=key['subject_id'],
            session=key['session'],
            fov_num=key['fov_num'],
        )
        matching_groups = (ephys.ElectrodeGroup & group_restriction).fetch(as_dict=True)

        for group_row in matching_groups:
            group_number = group_row['electrode_group']
            # The same electrode group number serves as plane and channel number.
            self.insert1(
                {**key, 'plane_num': group_number, 'channel_num': group_number},
                allow_direct_insert=True,
            )


In [None]:
# Run Plane.make for every FOV key that has no Plane rows yet.
Plane.populate()

In [None]:
# worked
@schema
class ROI(dj.Computed):
    """ROI rows of a plane, one per ephys unit of the matching electrode group."""

    definition = """
    -> Plane
    roi_number: smallint  # unit from EPHYS.Unit
    ---
    roi_number_uid: bigint  # unit_uid from EPHYS.Unit
    """

    def make(self, key):
        """Insert one ROI per unit of the electrode group behind this plane."""
        # plane_num is part of the Plane primary key, so it is already in
        # `key`; no extra query against Plane is needed to obtain it.
        unit_numbers, unit_uids = (ephys.Unit & {
            'subject_id': key['subject_id'],
            'session': key['session'],
            'electrode_group': key['plane_num'],  # plane_num == electrode_group
        }).fetch('unit', 'unit_uid')

        # For each unit, insert a ROI
        for unit, unit_uid in zip(unit_numbers, unit_uids):
            self.insert1({
                **key,
                'roi_number': unit,
                'roi_number_uid': unit_uid
            })


In [None]:
# Run ROI.make for every Plane key that has no ROI rows yet.
ROI.populate()

In [None]:
@schema
class ROISpikes(dj.Imported):
    """Binned spike train of one ROI (ephys unit) for one session epoch.

    The spikes of all trials of the unit are shifted onto the session clock
    (trial start time + trial-relative spike time) and histogrammed into
    100 ms bins starting at time 0.

    NOTE: earlier versions of make() looped over every unit of the session
    while inserting under the same key with skip_duplicates=True, so each ROI
    received the trace of the first unit. Rows populated that way should be
    deleted and repopulated.
    """

    definition = """
    -> exp.SessionEpoch
    -> ROI
    ---
    spikes_trace : longblob  # spikes per sec (Hz)
    spike_counts : longblob  # spikes per bin of 100ms
    time_line : longblob     # time line for the spikes trace
    start_frames : longblob  # start frames for each trial by the spikes trace
    bin_edges : longblob     # bin edges for the spikes trace
    """

    key_source = exp.SessionEpoch * ROI

    _BIN_SIZE = 0.1  # 100 ms bins

    @staticmethod
    def _bin_spike_train(spike_times, trial_start_times, bin_size):
        """Histogram session-clock spike times and locate the trial starts.

        Returns a dict holding the secondary attributes of one ROISpikes row.
        """
        # Cover trials that start after the last spike as well; otherwise
        # their start frame would point past the end of the trace.
        max_time = max(np.max(spike_times), max(trial_start_times))

        time_bins = np.arange(0, max_time + bin_size, bin_size)
        spike_counts, bin_edges = np.histogram(spike_times, bins=time_bins)

        # 0-based bin index of each trial start. A start exactly on the last
        # edge belongs to the last bin (np.histogram closes the last bin on
        # the right), so cap the index at the last valid bin.
        start_frames = np.minimum(
            np.digitize(trial_start_times, bin_edges) - 1,
            len(spike_counts) - 1,
        )

        return {
            'spike_counts': spike_counts,
            'spikes_trace': spike_counts / bin_size,         # rate in Hz
            'time_line': bin_edges[:-1] + bin_size / 2,      # bin centres
            'start_frames': start_frames,
            'bin_edges': bin_edges[:-1],                     # left edge of every bin
        }

    def make(self, key):
        """Compute and insert the binned spike train of the ROI in `key`.

        Nothing is inserted when the unit has no TrialSpikes rows or no
        non-empty spike list.

        Raises
        ------
        KeyError
            If a trial with spikes has no row in exp.SessionTrial10.
        """
        session_restriction = {
            'subject_id': key['subject_id'],
            'session': key['session'],
        }

        # ROI.make stores the ephys unit number as roi_number and selects the
        # units of electrode_group == plane_num, so restrict the spikes to
        # exactly that unit.
        unit_restriction = dict(session_restriction, unit=key['roi_number'])
        if 'electrode_group' in ephys.TrialSpikes().heading.names:
            unit_restriction['electrode_group'] = key['plane_num']

        trial_rows = (ephys.TrialSpikes & unit_restriction).fetch(
            'trial', 'spike_times', as_dict=True, order_by='trial')
        if not trial_rows:
            return  # no spikes, nothing to insert

        # All trial start times of the session in a single query.
        # NOTE(review): trials are not restricted to the epoch in `key`, so
        # every epoch of a session gets the same trace - confirm intended.
        trial_nums, start_times = (exp.SessionTrial10 & session_restriction).fetch(
            'trial', 'start_time')
        start_time_by_trial = {
            int(trial): float(start)
            for trial, start in zip(trial_nums, start_times)
        }

        spike_chunks = []
        trial_start_times = []
        seen_trials = set()

        for row in trial_rows:
            trial = int(row['trial'])
            start_time = start_time_by_trial[trial]

            # Record each trial's start once, in ascending trial order.
            if trial not in seen_trials:
                seen_trials.add(trial)
                trial_start_times.append(start_time)

            # NOTE(review): only Python lists are accepted as spike trains; a
            # blob that comes back as a numpy array is skipped - confirm that
            # this is intended.
            spikes = row['spike_times']
            if isinstance(spikes, list) and len(spikes) > 0:
                # Shift trial-relative spike times onto the session clock.
                spike_chunks.append(np.array(spikes) + start_time)

        if not spike_chunks:
            return

        binned = self._bin_spike_train(
            np.concatenate(spike_chunks), trial_start_times, self._BIN_SIZE)

        self.insert1({**key, **binned}, skip_duplicates=True)


In [None]:
# Run ROISpikes.make for every (session epoch, ROI) key that has no ROISpikes row yet.
ROISpikes.populate()

In [None]:
# @schema
# class FrameStartFile(dj.Imported): 
#     definition = """
#     -> exp.SessionEpoch
#     session_epoch_file_num (int) # first and last bin of spike counts
#     ---
#     session_epoch_file_start_frame(double) # (s) session epoch start frame  (bin) relative to the beginning of the session epoch

# """

@schema
class FrameStartFile(dj.Imported):
    """Start frame (bin index) of every trial within a session epoch.

    This cell re-declares the FrameStartFile table and replaces the class of
    the same name defined earlier in the notebook. `make` must be defined
    inside the class body: as a module-level function it is never attached to
    the table, leaving the class without a `make` method for populate() to call.
    """

    definition = """
    -> exp.SessionEpoch
    session_epoch_file_num : int  # first and last bin of spike counts
    ---
    session_epoch_file_start_frame : double  # (s) session epoch start frame (bin) relative to the beginning of the session epoch
    """

    def make(self, key):
        """Insert one row per trial of the session epoch identified by `key`.

        The trial number is stored as session_epoch_file_num and the trial's
        start bin as session_epoch_file_start_frame.
        """
        time_bins, trial_nums, start_bins = compute_bin_edges_and_trial_starts(key)
        if time_bins is None:
            # The helper found no trials or no spikes: nothing to store.
            return

        insert_list = []
        for trial_num, start_bin in zip(trial_nums, start_bins):
            insert_list.append({
                **key,
                'session_epoch_file_num': trial_num,
                'session_epoch_file_start_frame': start_bin
            })
        if insert_list:
            self.insert(insert_list, skip_duplicates=True)


fetches

In [None]:
# Spot check: unit numbers and unit_uids of subject 101104, session 1, electrode group 1
# (the same kind of query ROI.make issues).
(ephys.Unit & {'subject_id': 101104, 'session': 1, 'electrode_group': 1}).fetch('unit', 'unit_uid')


In [None]:
# Fetch every row of ephys.Unit (no restriction, so the whole table is returned).
ephys.Unit.fetch()