In [None]:
from caiman.source_extraction.cnmf import cnmf

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import scipy.io as sio
import seaborn as sns
from sklearn.cluster import AgglomerativeClustering

import os

from src.datetime import add_frames_to_datetime, image_desc_to_datetime, timestamp_to_datetime
from src.interpolate import interpolate, stitch, truncate
from src.tensor import minmax
from src.tensor_creation_hyperparams import Hyperparams

## Hyperparameter Setup

In [None]:
# Analysis configuration for dataset F147
F147 = Hyperparams(name='F147')

# Result files that are later opened with cnmf.load_CNMF, one entry per file
F147.set_data_paths(
    estimates=[
        'results/F147_0_memmap__d1_247_d2_256_d3_1_order_C_frames_20995_.hdf5',
        'results/F147_1_memmap__d1_73_d2_256_d3_1_order_C_frames_20995_.hdf5',
    ]
)

# Trial metadata (.mat): variable name, timing/output field names and the rate used to
# convert event frame counts into times
F147.set_trial_metadata(trial='data/2p_raw/F147/20210526_LT_18_0.mat', trial_var='trial',
                        trial_time_field='timestamps', trial_output_field='output', trial_fr=160)

# Image metadata (.mat): variable name, timing field name and the rate used to convert
# seconds into image frames
F147.set_image_metadata(image='results/F147_imfinfo_edit.mat', image_var='image',
                        image_time_field='ImageDescription', image_fr=4.5)

# Components with a signal-to-noise ratio below snr_thr are discarded; the noise region is
# the baseline trial at position baseline_selected among trials whose output equals baseline_name
F147.set_component_evaluation(
    snr_thr=1.25,
    baseline_name='baseline',
    baseline_selected=1,
)

# Number of clusters for ordering the traces and the symmetric colour limit of the heatmaps
F147.set_visualization_params(
    n_clusters=20,
    heatmap_bound=2,
)

# Three events delimit four intervals, so align_opts holds one entry per interval
F147.set_alignment_params(
    events_field=['laseron', 'turn_frame', 'laseroff'],
    align_opts=[
        ('interpolate', 'mean'),
        ('interpolate', 'mean'),
        ('stitch', [2, 2]),
        ('truncate', 20),
    ],
)

In [None]:
# Analysis configuration for dataset F201
F201 = Hyperparams(name='F201')

# Result file that is later opened with cnmf.load_CNMF
F201.set_data_paths(
    estimates=[
        'results/F201_0_memmap__d1_320_d2_256_d3_1_order_C_frames_24040_.hdf5',
    ]
)

# Trial metadata (.mat): variable name, timing/output field names and the rate used to
# convert event frame counts into times
F201.set_trial_metadata(trial='data/2p_raw/F201/20210812_RT_13_59.mat', trial_var='trial',
                        trial_time_field='timestamps', trial_output_field='output', trial_fr=160)

# Image metadata (.mat): variable name, timing field name and the rate used to convert
# seconds into image frames
F201.set_image_metadata(image='results/F201_imfinfo_edit.mat', image_var='image',
                        image_time_field='ImageDescription', image_fr=4.5)

# Components with a signal-to-noise ratio below snr_thr are discarded; the noise region is
# the baseline trial at position baseline_selected among trials whose output equals baseline_name
F201.set_component_evaluation(
    snr_thr=1.25,
    baseline_name='baseline',
    baseline_selected=1,
)

# Number of clusters for ordering the traces and the symmetric colour limit of the heatmaps
F201.set_visualization_params(
    n_clusters=20,
    heatmap_bound=2,
)

# Three events delimit four intervals, so align_opts holds one entry per interval
F201.set_alignment_params(
    events_field=['laseron', 'turn_frame', 'laseroff'],
    align_opts=[
        ('interpolate', 'mean'),
        ('interpolate', 'mean'),
        ('stitch', [2, 2]),
        ('truncate', 20),
    ],
)

In [None]:
# Currently selected hyperparameters
# Every later cell reads its settings through `hyp`; assign F201 here to process that dataset instead
hyp = F147

## Metadata Loading

In [None]:
# Move to the main project directory (the parent of the working directory at first execution).
# The target is resolved to an absolute path once and remembered in PROJECT_ROOT, so re-running
# this cell returns to the same directory instead of climbing one level higher each time.
if 'PROJECT_ROOT' not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
os.chdir(PROJECT_ROOT)

In [None]:
# Read the trial and image metadata variables from their .mat files as flat 1-D arrays
trial_info, image_info = (
    sio.loadmat(mat_path)[mat_var].flatten()
    for mat_path, mat_var in ((hyp.trial, hyp.trial_var), (hyp.image, hyp.image_var))
)

## Trial Grouping

In [None]:
# Positions of the time-carrying fields within the structured dtypes of the trial metadata
# and the image metadata, respectively
trial_time_index, image_time_index = (
    info.dtype.names.index(field_name)
    for info, field_name in (
        (trial_info, hyp.trial_time_field),
        (image_info, hyp.image_time_field),
    )
)

In [None]:
# One float slot per image frame; holds the trial index of the frame, or NaN when the frame
# lies outside every trial (values are filled in by the grouping loop)
trials_by_frame = np.empty(shape=image_info.size, dtype=np.float64)

In [None]:
# Microsecond-resolution buffers: start and end time per trial ...
trial_times_start, trial_times_end = (
    np.empty(trial_info.size, dtype='datetime64[us]') for _ in range(2)
)

# ... and one timestamp per image frame
frame_timestamps = np.empty(image_info.size, dtype='datetime64[us]')

In [None]:
# A trial starts at the first entry of its timestamp field and ends at the last one
for trial_idx, trial_record in enumerate(trial_info):
    trial_stamps = trial_record[trial_time_index]
    trial_times_start[trial_idx] = timestamp_to_datetime(trial_stamps[0])
    trial_times_end[trial_idx] = timestamp_to_datetime(trial_stamps[-1])

In [None]:
# Cursors for the grouping loop: current trial and current image frame both start at zero
trial_curr = image_curr = 0

In [None]:
# Walk through the frames in order and label each with the trial whose time span contains it
n_frames = image_info.size
last_trial = trial_info.size - 1

while image_curr < n_frames:

    # Record the time of the current frame
    frame_time = image_desc_to_datetime(image_info[image_curr][image_time_index][0])
    frame_timestamps[image_curr] = frame_time

    # Advance past trials that ended before this frame; stop as soon as the frame precedes
    # the start of the current trial or no further trial remains
    while (not frame_time < trial_times_start[trial_curr]
           and frame_time > trial_times_end[trial_curr]
           and trial_curr < last_trial):
        trial_curr += 1

    # The frame belongs to no trial if it falls before the current trial's start or after
    # its end (the latter can only still hold for the last trial)
    if frame_time < trial_times_start[trial_curr] or frame_time > trial_times_end[trial_curr]:
        trials_by_frame[image_curr] = np.nan

    # Otherwise the frame lies within the current trial
    else:
        trials_by_frame[image_curr] = trial_curr

    # Move on to the next frame
    image_curr += 1

## Data Loading

In [None]:
# Open every configured results file with caiman's CNMF loader
cnms = [cnmf.load_CNMF(estimate_path) for estimate_path in hyp.estimates]

In [None]:
# Collect the `estimates.C` trace array of every loaded result
traces = [loaded.estimates.C for loaded in cnms]

In [None]:
# Stack the trace arrays of all result files along the first (component) axis
data_orig = np.concatenate(tuple(traces), axis=0)

## Component Evaluation

In [None]:
# Position of the output field within the structured dtype of the trial metadata
trial_field_names = trial_info.dtype.names
trial_output_index = trial_field_names.index(hyp.trial_output_field)

In [None]:
# Indices of the trials whose output field equals the configured baseline name
trials_baseline = [
    trial_idx
    for trial_idx in range(trial_info.size)
    if trial_info[trial_idx][trial_output_index][0] == hyp.baseline_name
]

In [None]:
# Frames of the selected baseline trial form the noise region; every other frame (including
# frames that belong to no trial, whose label is NaN) forms the signal region
baseline_trial = trials_baseline[hyp.baseline_selected]
in_baseline = trials_by_frame == baseline_trial

# Index the frame axis with plain 1-D integer indices. Passing the tuple returned by np.where
# adds a singleton axis that then has to be squeezed away, and squeeze() also collapses the
# component or frame axis whenever it has length one, which breaks the axis=1 reduction.
frames_noise = np.flatnonzero(in_baseline)
frames_signal = np.flatnonzero(~in_baseline)

# Sample standard deviation (ddof=1) of each component over the signal region
std_sig = np.std(data_orig[:, frames_signal], axis=1, ddof=1)

# Sample standard deviation (ddof=1) of each component over the noise region
std_noise = np.std(data_orig[:, frames_noise], axis=1, ddof=1)

# Calculate the signal-to-noise ratio for each component
sig_noise_ratio = std_sig / std_noise

In [None]:
# A component counts as noise when its signal-to-noise ratio is below the threshold
noise_indices = [
    component
    for component, ratio in enumerate(sig_noise_ratio)
    if ratio < hyp.snr_thr
]

# Drop the rows of all noise components
data = np.delete(data_orig, noise_indices, axis=0)

## Hierarchical Clustering

In [None]:
# Z-score every component along its time axis, using the sample standard deviation (ddof=1)
data_norm = stats.zscore(a=data, axis=1, ddof=1)

In [None]:
# Assign each normalized trace to one of hyp.n_clusters agglomerative clusters
cluster_model = AgglomerativeClustering(n_clusters=hyp.n_clusters)
clustering = cluster_model.fit_predict(data_norm)

In [None]:
# Bucket the traces by cluster label, keeping their original order within each bucket
data_clusters = [[] for _ in range(hyp.n_clusters)]
for cluster_label, trace_row in zip(clustering, data_norm):
    data_clusters[cluster_label].append(trace_row)

In [None]:
# Rebuild data_norm with its rows grouped cluster by cluster (overwrites the previous ordering)
data_norm = np.concatenate([np.asarray(bucket) for bucket in data_clusters], axis=0)

## Heatmap

In [None]:
# Apply the seaborn theme with a wide default figure size for the heatmaps
heatmap_rc = {'figure.figsize': (6.5, 3)}
sns.set_theme(rc=heatmap_rc)

In [None]:
# Draw the clustered, z-scored traces with a colour scale symmetric around zero
ax = sns.heatmap(data_norm, vmin=-hyp.heatmap_bound, vmax=hyp.heatmap_bound, cmap='jet')

# Title and axis labels
ax.set_title(hyp.name + " Extracted Sources")
ax.set_xlabel("Frame")
ax.set_ylabel("Source")

# Hide the tick marks and their numbers on both axes
ax.set_xticks([])
ax.set_yticks([])

# Render the figure
plt.show()

## Alignment Parameters

In [None]:
# Count the configured alignment events and the configured per-interval options
n_event, n_interval = len(hyp.events_field), len(hyp.align_opts)

In [None]:
# Position of every event field within the structured dtype of the trial metadata
events_index = [trial_info.dtype.names.index(field_name) for field_name in hyp.events_field]

In [None]:
# Event times used for alignment: one row per event, one column per trial
events_time_shape = (len(hyp.events_field), trial_info.size)
events_time = np.empty(events_time_shape, dtype='datetime64[us]')

In [None]:
# Fill in the time of every event for every trial
for event_row, field_index in enumerate(events_index):
    for trial in range(trial_info.size):

        # Frames at which this event was recorded during the trial (may be empty)
        event_frames = trial_info[trial][field_index][0]

        # Use the first recorded frame, converted to a time relative to the trial start
        if event_frames.size:
            events_time[event_row, trial] = add_frames_to_datetime(
                trial_times_start[trial], event_frames[0], hyp.trial_fr)

        # Mark the event as missing when it did not occur
        else:
            events_time[event_row, trial] = np.datetime64('NaT')

In [None]:
# NaN-filled first/last frame tables: one row per interval (number of events plus one),
# one column per trial
intervals_shape = (len(hyp.events_field) + 1, trial_info.size)
intervals_frame_first = np.full(intervals_shape, np.nan)
intervals_frame_last = np.full(intervals_shape, np.nan)

In [None]:
# Trials whose own data is kept, and trials whose data will be substituted
trials_valid, trials_replace = [], []

In [None]:
# Add interval frame bounds and trial validity information
# Start from empty lists so that re-running this cell does not record every trial a second time
trials_valid = []
trials_replace = []

for trial in range(trial_info.size):

    # Do not keep baselines
    if trial in trials_baseline:
        continue

    # If any event is missing (NaT), the trial's data is substituted later on
    if np.isnat(events_time[:, trial]).any():
        trials_replace.append(trial)
        continue

    # Determine the first and last frames of each interval
    for frame in np.where(trials_by_frame == trial)[0]:

        # A frame that precedes no event belongs to the interval after the last event,
        # which is stored in the last row of the interval tables
        interval = -1

        # Otherwise the frame belongs to the interval before the first event it precedes
        for i in range(events_time.shape[0]):
            if frame_timestamps[frame] < events_time[i, trial]:
                interval = i
                break

        # Set the frame as the first one of the interval if none exists
        if np.isnan(intervals_frame_first[interval, trial]):
            intervals_frame_first[interval, trial] = frame

        # Update the last frame of the interval
        intervals_frame_last[interval, trial] = frame

    # Replace the trial if any interval does not contain frames
    if np.isnan(intervals_frame_first[:, trial]).any():
        trials_replace.append(trial)

    # The trial is valid otherwise
    else:
        trials_valid.append(trial)

In [None]:
# Number of output frames required for each interval, filled in below
intervals_n = list()

In [None]:
# Frame count of every interval in every valid trial (both bounds inclusive)
first_valid = intervals_frame_first[:, trials_valid]
last_valid = intervals_frame_last[:, trials_valid]
intervals_frame_elapsed = last_valid - first_valid + 1

# Per-interval mean and minimum across the valid trials, truncated to integers
intervals_frame_elapsed_mean = intervals_frame_elapsed.mean(axis=1).astype(np.int64)
intervals_frame_elapsed_min = intervals_frame_elapsed.min(axis=1).astype(np.int64)

In [None]:
# Find the number of frames needed for each interval
# Start from an empty list so that re-running this cell does not append the counts twice
intervals_n = []

for i in range(len(hyp.align_opts)):
    mode, setting = hyp.align_opts[i][0], hyp.align_opts[i][1]

    # For regular interpolation, use the mean or specified number of seconds
    if mode == 'interpolate':
        if setting == 'mean':
            intervals_n.append(intervals_frame_elapsed_mean[i])
        else:
            intervals_n.append(round(setting * hyp.image_fr))

    # For stitching, use the total specified number of seconds (start part plus end part)
    elif mode == 'stitch':
        intervals_n.append(round(setting[0] * hyp.image_fr) + round(setting[1] * hyp.image_fr))

    # For truncating, use the minimum, possibly bounded above by the specified number of seconds
    elif mode == 'truncate':
        if setting == 'min':
            intervals_n.append(intervals_frame_elapsed_min[i])
        else:
            intervals_n.append(min(intervals_frame_elapsed_min[i], round(setting * hyp.image_fr)))

    # Silently skipping an unknown mode would shift every later entry of intervals_n out of
    # step with align_opts, so fail loudly instead
    else:
        raise ValueError(f"Unknown alignment option {mode!r} for interval {i}")

In [None]:
# Print each interval's frame count next to the running total up to and including it
cumulative = 0
for interval_n in intervals_n:
    cumulative = cumulative + interval_n
    print(f"{interval_n} {cumulative}")

## Interpolation

In [None]:
# All non-baseline trials (kept and to-be-replaced) in ascending trial order
trials_experiment = sorted([*trials_valid, *trials_replace])

In [None]:
# One uninitialized (trial, component, frame) tensor per interval, sized by that interval's frame count
interpol = [
    np.empty((len(trials_experiment), data_norm.shape[0], interval_n))
    for interval_n in intervals_n
]

In [None]:
# Interpolate points for each trial
for i, trial in enumerate(trials_experiment):

    # Interpolate points if the trial is valid
    if trial in trials_valid:

        # Initialize an array to hold distances for translating the data after stitching
        distance = np.zeros((data_norm.shape[0], 1))

        # Interpolate points within each interval of the current trial
        for j in range(len(hyp.align_opts)):
            mode = hyp.align_opts[j][0]

            # Get the indices and times of the first and last image frame in the interval
            frame_start = int(intervals_frame_first[j, trial])
            frame_end = int(intervals_frame_last[j, trial])
            time_start = frame_timestamps[frame_start]
            time_end = frame_timestamps[frame_end]

            # Interpolate data within the interval if specified
            if mode == 'interpolate':
                interpol[j][i] = interpolate(data_norm, intervals_n[j], time_start, time_end,
                                             frame_start, frame_end, hyp.image_fr)
                interpol[j][i] += distance

            # Check if stitching is specified
            elif mode == 'stitch':

                # If there are not enough frames to stitch, interpolate data instead
                if frame_end - frame_start + 1 < intervals_n[j]:
                    interpol[j][i] = interpolate(data_norm, intervals_n[j], time_start, time_end,
                                                 frame_start, frame_end, hyp.image_fr)
                    interpol[j][i] += distance

                # Stitch the start and end of the interval together otherwise; the translation
                # distance is updated only after it has been applied to this interval
                else:
                    interpol[j][i], distance_change = stitch(data_norm, hyp.align_opts[j][1],
                                                             frame_start, frame_end, hyp.image_fr)
                    interpol[j][i] += distance
                    distance += distance_change

            # Truncate data in the interval if specified
            elif mode == 'truncate':
                interpol[j][i] = truncate(data_norm, intervals_n[j], frame_start)
                interpol[j][i] += distance

            # An unknown option would leave uninitialized values in the output tensor
            else:
                raise ValueError(f"Unknown alignment option {mode!r} for interval {j}")

    # Use the previous trial's interpolated data if the trial needs to be replaced
    else:

        # There is no previous trial for the first one: index i - 1 would wrap around to the
        # last row, which still holds uninitialized memory at this point
        if i == 0:
            raise ValueError(
                f"Trial {trial} is the first trial of the experiment and needs to be replaced, "
                "but there is no preceding trial to copy data from")

        for j in range(len(hyp.align_opts)):
            interpol[j][i] = interpol[j][i - 1]

## Tensor Creation

In [None]:
# Join the per-interval tensors along their last (frame) axis into one (trial, component, frame) tensor
tensor = np.concatenate(tuple(interpol), axis=2)

In [None]:
# Unpack the three tensor dimensions: trials, components and frames per trial
trials, neurons, times = np.shape(tensor)

In [None]:
# Two-dimensional version of the tensor: the trials are laid side by side along the frame axis,
# so column i * times + t of tensor_2d holds frame t of trial i. The trial axis must be moved
# behind the component axis before flattening; a bare reshape of `tensor` would interleave
# the data incorrectly.
tensor_2d = np.empty((neurons, times * trials))
tensor_2d[...] = tensor.transpose(1, 0, 2).reshape(neurons, times * trials)

In [None]:
# Draw the aligned traces (all trials side by side) with a colour scale symmetric around zero
ax = sns.heatmap(tensor_2d, vmin=-hyp.heatmap_bound, vmax=hyp.heatmap_bound, cmap='jet')

# Title and axis labels
ax.set_title(hyp.name + " Extracted Sources After Alignment")
ax.set_xlabel("Frame")
ax.set_ylabel("Source")

# Hide the tick marks and their numbers on both axes
ax.set_xticks([])
ax.set_yticks([])

# Render the figure
plt.show()

In [None]:
# Min-max normalization
# Applied to the 2-D layout along axis 1, i.e. across the frames of all trials at once
# NOTE(review): presumably rescales each component to a fixed range - confirm in src.tensor.minmax
tensor_minmax_2d = minmax(tensor_2d, axis=1)

In [None]:
# Fold the min-max array back into a (trial, component, frame) tensor: columns
# i * times ... (i + 1) * times - 1 of the 2-D array become trial i
tensor_minmax = np.empty((trials, neurons, times))
tensor_minmax[...] = np.reshape(tensor_minmax_2d, (neurons, trials, times)).transpose(1, 0, 2)

In [None]:
# Write the z-scored and the min-max tensor to the results directory, named after the dataset
for file_suffix, saved_tensor in (
    ('_tensor_zscore.npy', tensor),
    ('_tensor_minmax.npy', tensor_minmax),
):
    np.save('results/' + hyp.name + file_suffix, saved_tensor)