Cells that need input parameters are labeled with "=== xxxx ===" in their headline!

## Import packages

In [None]:
import os
import numpy as np
import scipy.io
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import copy
import sys
import random
import h5py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR

## Read and plot fluorescence structure data

In [None]:
# # The code below checks whether a trial's post-stimulus 2 s is identical to another trial's pre-stimulus 2 s.
# # No such match was found, because the post-stimulus 2 s periods are processed to bring them
# # closer to the baseline.

# # This uses dFF data; change mat_data['Master_dFF'] to mat_data['Master_f'] to check F data.

# root_path = "/home/chenhuimiao/dLGN_Neuron_Modeling/Fluorescence_Data/FluoData4Fitting_Average"

# i = 1
# path_ = os.path.join(root_path, cell_name, cell_name + 'red_dFFStructuresrun' + run_num[i] + '.mat')
# mat_data = scipy.io.loadmat(path_)
# fluo_runx = mat_data['Master_dFF']

# conca_runx = np.empty((0,fluo_runx[0,0].shape[2]))
# index_record = []
# action = True
# while action:
#     action = False
#     for index, value in np.ndenumerate(fluo_dFF_runx):
#         value = np.squeeze(value)
#         # print(value.shape)
#         # print(value[-31:, :].shape)
#         if conca_runx.shape[0] == 0:
#             conca_runx = np.concatenate((conca_runx, value), axis=0)
#             index_record = index_record + [index]
#             action = True
#             break
#         if np.array_equal(conca_runx[:31, :], value[-31:, :]):
#             conca_runx = np.concatenate((value[:-31, :], conca_runx), axis=0)
#             index_record = [index] + index_record
#             action = True
#             break
#         if np.array_equal(conca_runx[-31:, :], value[:31, :]):
#             fluo_dFconca_dFF_runxF_runx = np.concatenate((conca_runx, value[31:, :]), axis=0)
#             index_record = index_record + [index]
#             action = True
#             break
# print(conca_runx.shape)
# # a shape of (93, 10) means no match was found.

### Functions: read and plot structure data

In [None]:
def read_conca_fluo_data(cell_name = 'CL090_230515',
             run_num = ['4', '5', '6'],
             color = 'red',
             datatype = 'F',
             root_path = "/home/chenhuimiao/dLGN_Neuron_Modeling/Fluorescence_Data/FluoData4Fitting_Average"):
    '''
    Concatenate dFF or F structure data from several runs into one object array.

    For each run in run_num, the file
    <root_path>/<cell_name>/<cell_name><color>_dFFStructuresrun<run>.mat is read,
    the 'Master_dFF' (datatype 'dFF') or 'Master_f' (datatype 'F') cell array is
    taken, its last column is dropped, and the runs are stacked row-wise
    (one row per run, in run_num order).

    color = 'red' or 'green'
    datatype = 'F' or 'dFF'
    If len(run_num) is 3, the output is <class 'numpy.ndarray'> with shape (3, 48),
    and its elements are <class 'numpy.ndarray'> with shape (93, 1, 10) for red
    and shape (93, 281, 10) for green (281 is the number of components, 10 is
    the number of repeats and 93 is 6 seconds; the 2p imaging frequency is
    15.63 Hz so roughly 93 frames for 6 seconds).
    If run_num is empty, an empty (0, 48) array is returned.

    Raises ValueError if datatype is not 'F' or 'dFF'.
    '''
    # MATLAB variable name holding each data type.
    variable_names = {'dFF': 'Master_dFF', 'F': 'Master_f'}
    if datatype not in variable_names:
        raise ValueError(f"datatype must be 'F' or 'dFF', got {datatype!r}")

    pieces = []
    for run in run_num:
        path_ = os.path.join(root_path, cell_name, cell_name + color + '_dFFStructuresrun' + run + '.mat')
        fluo_data_runx = scipy.io.loadmat(path_)[variable_names[datatype]]
        # Delete the last (49th) column; only the remaining condition columns are returned.
        # NOTE(review): why the 49th column is excluded is not visible here
        # (presumably it is not one of the 48 grating conditions) -- confirm.
        pieces.append(fluo_data_runx[:, :-1])

    if not pieces:
        return np.empty((0, 48))
    return np.concatenate(pieces, axis=0)  # like np.vstack: one row per run

def plot_all_trials(conca_fluo_data,
           cell_name = 'CL090_230515',
           run_num = ['4', '5', '6'],
           color = 'red',
           datatype = 'F',
           component = 1):
    '''
    Plot every repeat of every condition for each run and save the figure as a PDF.

    The figure has one row per condition (48 rows) and one column per run in
    run_num; each panel overlays all repeats of that condition in that run,
    and the stimulus window (2-4 s) is shaded gray.

    conca_fluo_data: the array generated by the function read_conca_fluo_data,
        indexed as conca_fluo_data[run, condition]; each element is indexed
        here as [frame, component, repeat].
    color = 'red' or 'green'
    datatype = 'F' or 'dFF' (used for the y-axis label, title and file name)
    component is for 'green'; it is from 1 to the component number.
    If color is 'red', it is ignored (red data have a single component).

    The PDF is written to the current working directory.
    Raises ValueError if color is not 'red' or 'green'.
    '''
    if color not in ('red', 'green'):
        raise ValueError(f"color must be 'red' or 'green', got {color!r}")

    runs_string = ' '.join(run_num)
    color_name = 'Red' if color == 'red' else 'Green'
    # Red data hold a single component; green data hold one per component.
    component_index = component - 1 if color == 'green' else 0

    # Time axis: 94 edges over [0, 6] s give 93 bins, one per frame of a trial;
    # each sample is plotted at the center of its bin.
    x_intervals = np.linspace(0, 6, 94)
    x_middle_points = (x_intervals[:-1] + x_intervals[1:]) / 2

    n_conditions = 48
    n_runs = len(run_num)
    fig, ax = plt.subplots(nrows=n_conditions, ncols=n_runs, figsize=(5 * n_runs, 3 * n_conditions), constrained_layout=True)

    # With a single run, subplots returns a 1-D array of axes; make it 2-D so
    # ax[row, column] always works.
    if ax.ndim == 1:
        ax = ax[:, np.newaxis]

    # Condition parameters in the fixed column order of the structure data.
    sf = np.tile([0.02, 0.08, 0.32], 16)
    tf = np.tile([1, 1, 1, 4, 4, 4], 8)
    orientation = np.repeat(np.arange(0, 360, 45), 6)

    for x in range(n_conditions):
        for y in range(n_runs):
            data_patch = conca_fluo_data[y, x][:, component_index, :]  # (frames, repeats)
            for i in range(data_patch.shape[1]):
                ax[x, y].plot(x_middle_points, data_patch[:, i])

            ax[x, y].set_xlabel('Time (seconds)', fontsize=16)
            ax[x, y].set_ylabel(datatype, fontsize=16)
            ax[x, y].tick_params(labelsize=16)
            ax[x, y].axvspan(2, 4, facecolor='gray', alpha=0.2)

            if y == 0:
                # Position the condition label to the left of the first panel,
                # in data coordinates derived from that panel's tick range.
                x_ticks = ax[x, y].get_xticks()
                y_ticks = ax[x, y].get_yticks()
                x_text = np.min(x_ticks) - 0.3 * (np.max(x_ticks) - np.min(x_ticks))
                y_text = 0.5 * (np.max(y_ticks) + np.min(y_ticks))

        # Write the condition parameters to the left of each row.
        ax[x, 0].text(x_text, y_text, f'Ori: {orientation[x]}° \nTF: {tf[x]} \nSF: {sf[x]}', fontsize=14, va='center')

    if color == 'red':
        fig.suptitle(f'All Conditions All Rounds All Repeats {color_name} {datatype} Data ({cell_name}) (Columns: run {runs_string})', fontsize=16)
        filename = f'{cell_name}_{color_name}_{datatype}.pdf'
    else:
        fig.suptitle(f'All Conditions All Rounds All Repeats {color_name} {datatype} Data ({cell_name} Component {component}) (Columns: run {runs_string})', fontsize=16)
        filename = f'{cell_name}_{color_name}_{datatype}_{component:03d}.pdf'
    fig.savefig(filename)
    plt.show()


def plot_trials_separately_in_one_run(conca_fluo_data,
           cell_name = 'CL090_230515',
           run_num = ['4', '5', '6'],
           run_num_ = '4',
           color = 'red',
           datatype = 'F',
           component = 1):
    '''
    Plot each repeat of each condition of one run in its own panel and save the figure as a PDF.

    The figure has one row per condition (48 rows) and one column per repeat,
    and the stimulus window (2-4 s) is shaded gray.

    conca_fluo_data: the array generated by the function read_conca_fluo_data,
        indexed as conca_fluo_data[run, condition]; each element is indexed
        here as [frame, component, repeat].
    run_num: the runs conca_fluo_data was built from (same order as its rows).
    run_num_: the run to plot; it must be an element of run_num.
    color = 'red' or 'green'
    datatype = 'F' or 'dFF' (used for the y-axis label, title and file name)
    component is for 'green'; it is from 1 to the component number.
    If color is 'red', it is ignored (red data have a single component).

    The PDF is written to the current working directory.
    Raises ValueError if color is not 'red' or 'green', or if run_num_ is not in run_num.
    '''
    if color not in ('red', 'green'):
        raise ValueError(f"color must be 'red' or 'green', got {color!r}")

    color_name = 'Red' if color == 'red' else 'Green'
    # Red data hold a single component; green data hold one per component.
    component_index = component - 1 if color == 'green' else 0
    run_position = run_num.index(run_num_)  # row of conca_fluo_data for run_num_

    # Time axis: 94 edges over [0, 6] s give 93 bins, one per frame of a trial;
    # each sample is plotted at the center of its bin.
    x_intervals = np.linspace(0, 6, 94)
    x_middle_points = (x_intervals[:-1] + x_intervals[1:]) / 2
    # x_intervals = np.linspace(0, 14, 218) # 31+93+93=217
    # x_middle_points = (x_intervals[:-1] + x_intervals[1:]) / 2
    # !!! Note: choosing (0, 6, 94) or (0, 14, 218) depends on the intertrial
    # duration. Even if the intertrial is 6 s (93 frames; total 93+31+93), the
    # structure data may only use 2 s before and after the stimulation, i.e.
    # 31 frames for intertrials (total 31+31+31, same as the original setting).

    # Condition parameters in the fixed column order of the structure data.
    sf = np.tile([0.02, 0.08, 0.32], 16)
    # sf = np.tile([0.08], 8)
    tf = np.tile([1, 1, 1, 4, 4, 4], 8)
    # tf = np.tile([2], 8)
    orientation = np.repeat(np.arange(0, 360, 45), 6)
    # orientation = np.repeat(np.arange(0, 360, 45), 1)
    n_conditions = 48  # use 8 together with the alternative arrays above

    # One column per repeat; the repeat count is the last axis of each element.
    n_repeats = conca_fluo_data[run_position, 0].shape[2]
    fig, ax = plt.subplots(nrows=n_conditions, ncols=n_repeats, figsize=(5 * n_repeats, 3 * n_conditions), constrained_layout=True)
    if ax.ndim == 1:  # ensure ax is 2-dimensional when there is a single repeat
        ax = ax[:, np.newaxis]

    for x in range(n_conditions):
        data_patch = conca_fluo_data[run_position, x][:, component_index, :]  # (frames, repeats)
        for y in range(data_patch.shape[1]):
            ax[x, y].plot(x_middle_points, data_patch[:, y])

            ax[x, y].set_xlabel('Time (seconds)', fontsize=16)
            ax[x, y].set_ylabel(datatype, fontsize=16)
            ax[x, y].tick_params(labelsize=16)
            ax[x, y].axvspan(2, 4, facecolor='gray', alpha=0.2)
            # ax[x, y].axvspan(6, 8, facecolor='gray', alpha=0.2)
            ax[x, y].set_title(f'Repeat {y+1}')

            if y == 0:
                # Position the condition label to the left of the first panel,
                # in data coordinates derived from that panel's tick range.
                x_ticks = ax[x, y].get_xticks()
                y_ticks = ax[x, y].get_yticks()
                x_text = np.min(x_ticks) - 0.3 * (np.max(x_ticks) - np.min(x_ticks))
                y_text = 0.5 * (np.max(y_ticks) + np.min(y_ticks))

        # Write the condition parameters to the left of each row.
        ax[x, 0].text(x_text, y_text, f'Ori: {orientation[x]}° \nTF: {tf[x]} \nSF: {sf[x]}', fontsize=14, va='center')

    if color == 'red':
        fig.suptitle(f'All Conditions All Repeats for run {run_num_} {color_name} {datatype} Data ({cell_name}) (Columns: repeats)', fontsize=16)
        filename = f'{cell_name}_{color_name}_{datatype}_run-{run_num_}.pdf'
    else:
        fig.suptitle(f'All Conditions All Repeats for run {run_num_} {color_name} {datatype} Data ({cell_name} Component {component}) (Columns: repeats)', fontsize=16)
        # NOTE(review): '__run-' has a double underscore, unlike the red file
        # name; it is kept unchanged so existing output names stay stable.
        filename = f'{cell_name}_{color_name}_{datatype}__run-{run_num_}_{component:03d}.pdf'
    fig.savefig(filename)
    plt.show()

### Eg of Function Use 1

In [None]:
# cell_name = 'CL096_231018'
# run_num = ['4']
# run_num_ = '4'
# color = 'red'
# datatype = 'F'
# component = 1
# conca_fluo_data = read_conca_fluo_data(cell_name, run_num, color, datatype)
# plot_all_trials(conca_fluo_data, cell_name, run_num, color, datatype, component)
# plot_trials_separately_in_one_run(conca_fluo_data, cell_name, run_num, run_num_, color, datatype, component)

# color = 'green'
# datatype = 'F'
# component = 1
# conca_fluo_data = read_conca_fluo_data(cell_name, run_num, color, datatype)
# plot_all_trials(conca_fluo_data, cell_name, run_num, color, datatype, component)

# print(type(conca_fluo_data))
# print(conca_fluo_data.shape)
# print(type(conca_fluo_data[1,1]))
# print(conca_fluo_data[1,1].shape)

### Eg of Function Use 2

Read red and green data then plot

In [None]:
# cell_name_list = ['CL075_230228', 'CL079_230324' 'CL090_230515']
# run_num_list = [['1', '2', '3'], ['1', '2', '3'], ['4', '5', '6']]

# cell_name_list = ['CL090_230515']
# run_num_list = [['4', '5', '6']]

# for cell_name_, run_num_ in zip(cell_name_list, run_num_list):
#     for color_ in ['red', 'green']:
#         for datatype_ in ['F', 'dFF']:
#             conca_fluo_data_ = read_conca_fluo_data(cell_name_, run_num_, color_, datatype_)
#             if color_ == 'red':
#                 plot_all_trials(conca_fluo_data_, cell_name_, run_num_, color_, datatype_)
#             if color_ == 'green':
#                 for component_ in range(1, conca_fluo_data_[0,0].shape[1]+1):
#                     plot_all_trials(conca_fluo_data_, cell_name_, run_num_, color_, datatype_, component_)

### Download files if running on Google Colab

In [None]:
# batch download the plotted figures
# uncomment the code below to download figures if needed

# import glob

# folder_path = '.'
# # file_prefix = 'All_Conditions_All_Rounds_All_Repeats_'
# file_prefix = 'CL'

# # Use glob to find all files with the given prefix in the folder
# matching_files = glob.glob(f"{folder_path}/{file_prefix}*")
# # print(matching_files)
# # # Print the matching file names
# # for file_path in matching_files:
# #     print(file_path)

# import zipfile

# zip_filename = 'files.zip'
# with zipfile.ZipFile(zip_filename, 'w') as zipf:
#     # Add files to the zip file
#     for file_path in matching_files:
#         zipf.write(file_path)

# from google.colab import files
# files.download(zip_filename)

## Extract chronological order of randomized conditions and generate sequences

### Functions: extract chronological order and get condition order

In [None]:
def extract_chronological_order(cell_name = 'CL090_230515',
                 run_num = ['4', '5', '6'],
                 root_path = "/home/chenhuimiao/dLGN_Neuron_Modeling/Fluorescence_Data/FluoData4Fitting_Average",
                 n_repeats = 10,
                 n_conditions = 48):
    '''
    Extract the chronological order of the randomized conditions from frame_2p_metadata_run files.

    For each run, <root_path>/<cell_name>/frame_2p_metadata_run<run>.mat is read
    and the per-frame orientation, temporal frequency and spatial frequency are
    taken from stim.frame. Frames with non-zero tf are treated as stimulus
    frames (ori cannot be used for this because 0 degrees is a valid
    orientation). Only the first frame of each contiguous block of stimulus
    frames is kept, giving one (ori, tf, sf) triple per presented trial.

    Returns ori_3d, tf_3d, sf_3d, each with shape
    (len(run_num), n_repeats, n_conditions), e.g. (3, 10, 48): runs, repeats
    per run, conditions. Along the last axis the ori/tf/sf parameters are in
    chronological order for a given run and repeat.

    Raises ValueError if a run has no non-zero tf frame, or if the number of
    detected trials is not n_repeats * n_conditions.
    '''
    ori_3d = np.empty((len(run_num), n_repeats, n_conditions))
    tf_3d = np.empty((len(run_num), n_repeats, n_conditions))
    sf_3d = np.empty((len(run_num), n_repeats, n_conditions))

    for i, run in enumerate(run_num):
        path_ = os.path.join(root_path, cell_name, 'frame_2p_metadata_run' + run + '.mat')
        frame = scipy.io.loadmat(path_)['stim']['frame'][0, 0]

        ori = np.squeeze(frame['orientation'][0, 0])
        tf = np.squeeze(frame['temporal_frequencies_hz'][0, 0])
        sf = np.squeeze(frame['sf'][0, 0])

        indices = np.nonzero(tf)[0]
        if indices.size == 0:
            raise ValueError(f"No frame with non-zero temporal frequency in {path_}")

        # Keep only the first index of each run of consecutive indices,
        # i.e., [1,2,3,21,22,23,24,33,34,35,36,37,48,49] -> [1,21,33,48]
        is_first = np.concatenate(([True], np.diff(indices) > 1))
        unique_indices = indices[is_first]

        ori_3d[i, :, :] = ori[unique_indices].reshape(n_repeats, n_conditions)
        tf_3d[i, :, :] = tf[unique_indices].reshape(n_repeats, n_conditions)
        sf_3d[i, :, :] = sf[unique_indices].reshape(n_repeats, n_conditions)

    return ori_3d, tf_3d, sf_3d


def get_condition_order(ori_3d, tf_3d, sf_3d):
    '''
    Map each chronologically ordered trial to its column index in the structure data.

    ori_3d, tf_3d, sf_3d are generated by the function extract_chronological_order
    and have shape (runs, repeats, trials), e.g. (3, 10, 48).

    The conditions along the columns in structure data are given in a fixed order,
    following the order of the vectors orientation, tf, sf (created below). But they
    are temporally randomized. That is why this function computes the chronological
    order of the columns (conditions).

    The returned condition_order has the same shape as ori_3d. Along the last
    axis, the elements are column indexes in structure data, like
    [2, 0, 23, 24, ...], meaning that column 2 occurs first, then column 0, ...
    Values are floats; a trial whose (ori, tf, sf) matches no column is reported
    with a printed message and left as NaN. Parameters are compared with
    np.isclose to tolerate floating-point round-off.
    '''
    orientation = np.repeat(np.arange(0, 360, 45), 6)
    tf = np.tile([1, 1, 1, 4, 4, 4], 8)
    sf = np.tile([0.02, 0.08, 0.32], 16)
    # (48, 3) table: row k holds the (ori, tf, sf) of structure-data column k.
    condition_in_column_order = np.stack([orientation, tf, sf], axis=1)

    ori_3d = np.asarray(ori_3d)
    tf_3d = np.asarray(tf_3d)
    sf_3d = np.asarray(sf_3d)
    n_runs, n_repeats, n_trials = ori_3d.shape

    # NaN marks unmatched trials instead of leaving uninitialized memory.
    condition_order = np.full((n_runs, n_repeats, n_trials), np.nan)

    for i in range(n_runs):
        for j in range(n_repeats):
            conditions_in_time_order = np.stack([ori_3d[i, j], tf_3d[i, j], sf_3d[i, j]], axis=1)
            # matches[ii, jj] is True when trial ii has the parameters of column jj.
            matches = np.isclose(conditions_in_time_order[:, np.newaxis, :],
                                 condition_in_column_order[np.newaxis, :, :]).all(axis=2)
            for ii, row in enumerate(matches):
                hits = np.flatnonzero(row)
                if hits.size == 0:
                    print("An Element Not Found!!!")
                else:
                    # The column table has no duplicate rows, so there is at most one hit.
                    condition_order[i, j, ii] = hits[0]

    return condition_order

### Eg of Function Use

In [None]:
# cell_name = 'CL090_230515'
# run_num = ['4', '5', '6']

# ori_3d, tf_3d, sf_3d = extract_chronological_order(cell_name, run_num)
# condition_order = get_condition_order(ori_3d, tf_3d, sf_3d)

### Functions: get red/green time sequence

Note: only for 48 conditions, in chronological order but not continuous.

The sequences produced by the function get_time_sequence below are in chronological order **but not continuous** (**so they may not be useful in later algorithms**), because some recorded segments between them are omitted: for example, gray-screen trials were interleaved with the 48 condition trials during the experiment but are not among the 48 columns.

In [None]:
def get_time_sequence(conca_fluo_data, condition_order, pre_frames=31):
    '''
    Concatenate each repeat's trials in chronological order into one trace per component.

    conca_fluo_data: the array generated by read_conca_fluo_data, indexed as
        conca_fluo_data[run, condition]; each element is indexed as
        [frame, component, repeat].
    condition_order: the array generated by get_condition_order, with shape
        (runs, repeats, trials); entry [i, j, t] is the structure-data column
        of the t-th trial of repeat j in run i.
    pre_frames: number of leading frames of every trial to drop (default 31,
        i.e. the ~2 s before stimulus onset at 15.63 Hz).

    For each trial, frames pre_frames: (the stimulus plus the period after it)
    are appended in chronological order. With 93-frame trials and the default
    pre_frames this is 62 frames (about 4 s) per trial, so for red data the
    result has shape (3, 10, 1, 2976) with 2976 = 62*48: 3 runs, 10 repeats,
    1 component (the soma). For green data the shape is (3, 10, 281, 2976),
    where 281 is the number of components, which depends on the cell.

    Raises ValueError if condition_order contains NaN (an unmatched trial).
    '''
    how_many_run = conca_fluo_data.shape[0]
    trial_frames, how_many_component, how_many_repeat = conca_fluo_data[0, 0].shape
    frames_per_trial = trial_frames - pre_frames
    how_many_trial = condition_order.shape[2]

    time_sequence = np.empty((how_many_run, how_many_repeat, how_many_component, frames_per_trial * how_many_trial))

    for i in range(how_many_run):
        for j in range(how_many_repeat):
            for position, column in enumerate(condition_order[i, j, :]):
                piece = conca_fluo_data[i, int(column)][pre_frames:, :, j]  # (frames, components)
                start = frames_per_trial * position
                time_sequence[i, j, :, start:start + frames_per_trial] = piece.T

    return time_sequence


def plot_time_sequence_each_repeat(time_sequence,
            cell_name = 'CL090_230515',
            run_num = ['4', '5', '6'],
            color = 'red',
            datatype = 'F',
            run_index = 1,
            repeat = 1,
            component = 1):
    '''
    Plot the chronologically ordered trace of one repeat, save it as a PDF and return it.

    time_sequence is the array generated by get_time_sequence, indexed as
    [run, repeat, component, time]. Each trial is assumed to contribute 62
    frames covering 4 s (2 s stimulus followed by 2 s after it). The time axis
    is therefore scaled to 4 s per 62 frames (192 s for 62*48 frames), and the
    first 2 s of every 4-s block are shaded gray.

    run_index, repeat, and component all start from 1, not 0.
    run_index = 1, ..., len(run_num)
    repeat = 1, .., repeat number
    component = 1, .., component number
    For red, component can only be 1.

    The PDF is written to the current working directory.
    Returns the plotted 1-D trace.
    Raises ValueError if color is not 'red' or 'green'.
    '''
    if color not in ('red', 'green'):
        raise ValueError(f"color must be 'red' or 'green', got {color!r}")

    color_name = 'Red' if color == 'red' else 'Green'
    run = int(run_num[run_index-1])
    data = time_sequence[run_index-1, repeat-1, component-1, :]
    # To see the whole sequence of a run (all repeats), use plot_time_sequence_each_run.

    frames_per_trial = 62
    seconds_per_trial = 4
    stimulus_seconds = 2
    time_steps = data.shape[0]
    time_interval = time_steps * seconds_per_trial / frames_per_trial  # seconds
    time_values = np.arange(time_steps) * time_interval / time_steps

    plt.figure(figsize=(18, 6))
    plt.plot(time_values, data, color=color, linewidth=1)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.xlabel('Time (s)', fontsize=16)
    plt.ylabel(f'{datatype}', fontsize=16)
    if color == 'red':
        plt.title(f'{cell_name} {color_name} {datatype} Run {run} Repeat {repeat} in Chronological Order', fontsize=18)
    else:
        plt.title(f'{cell_name} {color_name} {datatype} Run {run} Repeat {repeat} Component {component} in Chronological Order', fontsize=18)

    # Shade the stimulus period at the start of every trial.
    for start in np.arange(0, time_interval, seconds_per_trial):
        plt.axvspan(start, start + stimulus_seconds, facecolor='gray', alpha=0.2)

    if color == 'red':
        filename = f'{cell_name}_{color_name}_{datatype}_Run{run}_Repeat{repeat:02d}_TimeSequence.pdf'
    else:
        filename = f'{cell_name}_{color_name}_{datatype}_{component:03d}_Run{run}_Repeat{repeat:02d}_TimeSequence.pdf'

    plt.savefig(filename)
    plt.show()

    return data


def plot_time_sequence_each_run(time_sequence,
            cell_name = 'CL090_230515',
            run_num = ['4', '5', '6'],
            color = 'red',
            datatype = 'F',
            run_index = 1,
            component = 1):
    '''
    Plot all repeats of one run back to back, save the figure as a PDF and return the trace.

    Why plot all the data of 1 run (all repeats)? Because they form a sequence,
    i.e., they are done in chronological order but not continuously.

    time_sequence is the array generated by get_time_sequence, indexed as
    [run, repeat, component, time]; the repeats of the chosen run are
    concatenated in order. Each trial is assumed to contribute 62 frames
    covering 4 s (2 s stimulus followed by 2 s after it). The time axis is
    scaled to 4 s per 62 frames, and the first 2 s of every 4-s block are
    shaded gray. The axis length follows the data, so any number of repeats works.

    run_index and component both start from 1, not 0.
    run_index = 1, ..., len(run_num)
    component = 1, .., component number
    For red, component can only be 1.

    The PDF is written to the current working directory.
    Returns the plotted 1-D trace.
    Raises ValueError if color is not 'red' or 'green'.
    '''
    if color not in ('red', 'green'):
        raise ValueError(f"color must be 'red' or 'green', got {color!r}")

    color_name = 'Red' if color == 'red' else 'Green'
    run = int(run_num[run_index-1])
    # (repeats, time) -> one trace with the repeats laid end to end.
    data = time_sequence[run_index-1, :, component-1, :].reshape(-1)

    frames_per_trial = 62
    seconds_per_trial = 4
    stimulus_seconds = 2
    time_steps = data.shape[0]
    time_interval = time_steps * seconds_per_trial / frames_per_trial  # seconds
    time_values = np.arange(time_steps) * time_interval / time_steps

    plt.figure(figsize=(100, 6))
    plt.plot(time_values, data, color=color, linewidth=1)

    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.xlabel('Time (s)', fontsize=16)
    plt.ylabel(f'{datatype}', fontsize=16)
    if color == 'red':
        plt.title(f'{cell_name} {color_name} {datatype} Run {run} in Chronological Order', fontsize=18)
    else:
        plt.title(f'{cell_name} {color_name} {datatype} Run {run} Component {component} in Chronological Order', fontsize=18)

    # Shade the stimulus period at the start of every trial.
    for start in np.arange(0, time_interval, seconds_per_trial):
        plt.axvspan(start, start + stimulus_seconds, facecolor='gray', alpha=0.2)

    if color == 'red':
        filename = f'{cell_name}_{color_name}_{datatype}_Run{run}_TimeSequence.pdf'
    else:
        filename = f'{cell_name}_{color_name}_{datatype}_{component:03d}_Run{run}_TimeSequence.pdf'
    plt.savefig(filename)
    plt.show()

    return data

### Eg of Function Use

In [None]:
# cell_name = 'CL090_230515'
# run_num = ['4', '5', '6']
# color = 'green'
# datatype = 'F'

# conca_fluo_data = read_conca_fluo_data(cell_name, run_num, color, datatype)

# ori_3d, tf_3d, sf_3d = extract_chronological_order(cell_name, run_num)
# condition_order = get_condition_order(ori_3d, tf_3d, sf_3d)

# time_sequence = get_time_sequence(conca_fluo_data, condition_order)

# # run_index = 1
# # repeat = 1
# # component = 2

# # plot_time_sequence_each_repeat(time_sequence,
# #             cell_name,
# #             run_num,
# #             color,
# #             datatype,
# #             run_index = run_index,
# #             repeat = repeat,
# #             component = component);

# # plot_time_sequence_each_run(time_sequence,
# #             cell_name,
# #             run_num,
# #             color,
# #             datatype,
# #             run_index = run_index,
# #             component = component);

# print(conca_fluo_data.shape)
# print(conca_fluo_data[0,0].shape)
# print(condition_order.shape)
# print(time_sequence.shape)

## Get valid components

### Functions: get valid components

In [None]:
def _load_bouton_masks(path_):
    '''
    Load the 'BoutonMasks' variable from a MATLAB .mat file of any version.

    MATLAB v7.3 files are HDF5 and are read with h5py. Arrays read this way are
    transposed (all axes reversed) to match the layout returned by
    scipy.io.loadmat. Older .mat versions are not HDF5, so h5py raises OSError
    and the file is read with scipy.io.loadmat instead.
    '''
    try:
        with h5py.File(path_, 'r') as data_:  # MATLAB v7.3 file; closed even on error
            return np.array(data_['BoutonMasks']).T
    except OSError:
        return np.array(scipy.io.loadmat(path_)['BoutonMasks'])


def get_valid_components(cell_name, max_pixels = 100, soma_index = 0,
                         root_path = "/home/chenhuimiao/dLGN_Neuron_Modeling/Fluorescence_Data/FluoData4Fitting_Average"):
    '''
    Select the valid green components and measure their size and distance to the soma.

    A green component is valid when its mask has fewer than max_pixels pixels
    equal to 1 (too big ones are invalid). The soma position is the centroid
    of a red bouton mask. The red file should hold only one mask but may
    contain some invalid ones, so soma_index selects the soma mask when the
    red masks are 3-D.

    Parameters
    cell_name : folder name under root_path and prefix of the mask files
        (<cell_name>red_BoutonMasks.mat and <cell_name>green_BoutonMasks.mat).
    max_pixels : components with this many mask pixels or more are rejected.
    soma_index : index of the soma mask along the last axis of 3-D red masks.
    root_path : folder containing the per-cell data folders.

    Returns
    valid_com_index_list : 0-based indices of the valid green components.
    valid_dis_list : centroid distance (in pixels) from each valid component to the soma.
    valid_size_list : number of mask pixels of each valid component.

    Raises ValueError if the red masks are neither 2-D nor 3-D.
    Note: a component with an empty mask passes the size test and gets a NaN
    distance (np.mean of an empty array).
    '''
    red_bouton_masks = _load_bouton_masks(os.path.join(root_path, cell_name, cell_name + 'red_BoutonMasks.mat'))
    if red_bouton_masks.ndim == 3:
        soma_mask = red_bouton_masks[:, :, soma_index]
    elif red_bouton_masks.ndim == 2:
        soma_mask = red_bouton_masks
    else:
        raise ValueError(f"Unexpected red_bouton_masks shape: {red_bouton_masks.shape}")

    soma_rows, soma_cols = np.where(soma_mask == 1)
    soma_average_row = np.mean(soma_rows)
    soma_average_col = np.mean(soma_cols)

    green_bouton_masks = _load_bouton_masks(os.path.join(root_path, cell_name, cell_name + 'green_BoutonMasks.mat'))
    if green_bouton_masks.ndim == 2:  # a single component is stored as a 2-D mask
        green_bouton_masks = green_bouton_masks[:, :, np.newaxis]

    valid_com_index_list = []
    valid_size_list = []
    valid_dis_list = []

    for i in range(green_bouton_masks.shape[2]):
        rows, cols = np.where(green_bouton_masks[:, :, i] == 1)
        if len(rows) >= max_pixels:
            continue
        distance = np.sqrt((np.mean(rows) - soma_average_row)**2 + (np.mean(cols) - soma_average_col)**2)
        valid_com_index_list.append(i)
        valid_dis_list.append(distance)
        valid_size_list.append(len(rows))

    return valid_com_index_list, valid_dis_list, valid_size_list


### Eg of Function Use

In [None]:
# valid_com_index_list, valid_dis_list, valid_size_list = get_valid_components('CL075_230228', 100, 5)
# print(valid_com_index_list)
# print(valid_dis_list)
# print(valid_size_list)

## Extract signal traces

### Functions: read signal traces

In [None]:
def read_signal_traces(cell_name = 'CL090_230515',
           run_num = ['4', '5', '6'],
           color = 'red',
           root_path = "/home/chenhuimiao/dLGN_Neuron_Modeling/Fluorescence_Data/FluoData4Fitting_Average"):
    '''
    Read full-length signal traces for several runs and stack them.

    Each run's file <root_path>/<cell_name>/<cell_name><color>_SignalTracesrun<run>.mat
    is read once, and SignalTraces.BoutonTraces is transposed to
    (components, frames). The trace length is the 2p length: the number of 2p
    imaging frames, which divided by the 2p imaging frequency (15.63 Hz) gives
    the duration of 2p imaging.
    color = 'red' or 'green'
    The acquired data is the raw fluorescence data, i.e., the F value.
    If len(run_num) is 3, for color 'green', the output is <class 'numpy.ndarray'>
    with shape (3, N, M); for color 'red', the output is <class 'numpy.ndarray'>
    with shape (3, 1, M). N is the number of components and M is the 2p length
    (number of 2p imaging frames). The output dtype is float64.

    Raises ValueError if run_num is empty or if the runs differ in shape.
    '''
    if len(run_num) == 0:
        raise ValueError("run_num must contain at least one run")

    traces_per_run = []
    for run in run_num:
        path_ = os.path.join(root_path, cell_name, cell_name + color + '_SignalTracesrun' + run + '.mat')
        mat_data = scipy.io.loadmat(path_)
        # mat_data['SignalTraces']['BoutonTraces'][0,0] is stored as (frames, components).
        traces_per_run.append(mat_data['SignalTraces']['BoutonTraces'][0, 0].T)

    return np.stack(traces_per_run).astype(np.float64)

### Eg of Function Use

In [None]:
# cell_name = 'CL090_230515'
# run_num = ['4', '5', '6']
# color = 'green'

# signal_traces = read_signal_traces(cell_name, run_num, color)

# print(signal_traces.shape)
# # print(signal_traces[0,0,:])

### Functions: check whether structure data is in signal traces and get locations

In [None]:
def _find_piece_starts(trace, piece):
    '''
    Return the start indices at which `piece` occurs exactly in `trace`.

    Both arguments are 1-D arrays. Candidate starts are first filtered on three
    anchor points (first, middle and last element of the piece), then confirmed
    by comparing the whole piece, so coincidental anchor matches are rejected.
    Values are compared with ==, so NaN values never match.
    '''
    piece_len = len(piece)
    if piece_len == 0 or piece_len > len(trace):
        return []
    middle = piece_len // 2  # 46 for the usual 93-point piece
    last = piece_len - 1     # 92 for the usual 93-point piece
    # only starts that leave room for the whole piece inside the trace are possible
    starts = np.flatnonzero(trace[:len(trace) - last] == piece[0])
    starts = starts[(trace[starts + middle] == piece[middle]) & (trace[starts + last] == piece[last])]
    return [start for start in starts if np.array_equal(trace[start:start + piece_len], piece)]


def get_condition_piece_locations(conca_fluo_data, signal_traces):
    '''
    This function gets the locations of conca_fluo_data (from structure data) in
    signal_traces.

    conca_fluo_data is the data of all conditions; the data is extracted from
    signal_traces. It is an object array indexed [run, condition] whose elements
    are arrays indexed [time, component, repeat].
    signal_traces is indexed [run, component, frame].

    return locations with shape (a,b,c,d) (e.g., (3, 281, 10, 48)) and
    its element's shape is (2,3).
    a is how_many_run
    b is how_many_component # will be 1 if red
    c is how_many_repeat
    d is how_many_condition
    (2,3) are [[run, component, start], [run, component, end]], i.e. the data
    piece's inclusive start and end index in signal_traces.

    The piece length is taken from the data (93 points, corresponding to 6s, in
    the recordings used here) instead of being hard-coded. Entries whose piece
    is not found stay None; if several exact matches exist the last one is
    stored. Both situations are reported with a printed message.
    '''

    how_many_run = conca_fluo_data.shape[0]
    how_many_repeat = conca_fluo_data[0,0].shape[2]
    how_many_component = conca_fluo_data[0,0].shape[1] # will be 1 if red
    how_many_condition = conca_fluo_data.shape[1]

    locations = np.empty((how_many_run, how_many_component, how_many_repeat, how_many_condition), dtype=object)

    # get the position of elements from conca_fluo_data in signal_traces
    for run_index in range(how_many_run):
        for component_index in range(how_many_component):
            trace = signal_traces[run_index, component_index, :]
            for repeat_index in range(how_many_repeat):
                for condition_index in range(how_many_condition):
                    piece = conca_fluo_data[run_index, condition_index][:, component_index, repeat_index]
                    starts = _find_piece_starts(trace, piece)
                    isfound = len(starts)
                    if isfound:
                        # when several matches remain, keep the last one (a warning is printed below)
                        start = starts[-1]
                        locations[run_index, component_index, repeat_index, condition_index] = np.array(
                            [[run_index, component_index, start],
                             [run_index, component_index, start + len(piece) - 1]])
                    if isfound != 1:
                        print(f"isfound is {isfound}")
                        print(f"run {run_index} component {component_index} repeat {repeat_index} condition {condition_index} Not Found or Multiple Found")

    return locations

## below is the original code for finding location, which has been wrapped in the above function
# run_index = 2
# repeat = 9
# component = 279
# condition_index = 0

# empty_array = np.empty((3, 4))

# # get the position of elements from conca_fluo_data in signal_traces
# for condition_index in range(48):
#     for i, ele in enumerate(conca_fluo_data[run_index, condition_index][:, component, repeat]):
#         if i == 0:
#             # print(np.isin(ele, signal_traces))
#             index1 = np.where(signal_traces[run_index, component, :] == ele)[0]
#             # print(np.where(signal_traces[run_index, component, :] == ele))
#             # if np.where(signal_traces[run_index, component, :] == ele)[0].shape[0] > 1:
#             #     print(ele)
#         if i == 92:
#             # print(np.isin(ele, signal_traces))
#             index2 = np.where(signal_traces[run_index, component, :] == ele)[0]
#     isfound = False
#     for index1_ in index1:
#         for index2_ in index2:
#             if index2_ - index1_ == 92:
#                 isfound = True
#                 print(f"condition {condition_index} -- start position: ({run_index},{component},{index1_}); end position: ({run_index},{component},{index2_})")
#     if not isfound:
#         print("Not Found")
#     print("======")

### Example of Function Use

In [None]:
# cell_name = 'CL090_230515'
# run_num = ['4', '5', '6']
# color = 'red'
# datatype = 'F'

# conca_fluo_data = read_conca_fluo_data(cell_name, run_num, color, datatype)
# ori_3d, tf_3d, sf_3d = extract_chronological_order(cell_name, run_num)
# condition_order = get_condition_order(ori_3d, tf_3d, sf_3d)
# time_sequence = get_time_sequence(conca_fluo_data, condition_order)

# signal_traces = read_signal_traces(cell_name, run_num, color)

# print(conca_fluo_data.shape)
# print(conca_fluo_data[0,0].shape)
# print(condition_order.shape)
# print(time_sequence.shape)
# print(signal_traces.shape)
# print("-- --- --")

# locations = get_condition_piece_locations(conca_fluo_data, signal_traces)
# print(locations.shape)
# print(locations[0,0,0,2])
# print(locations[0,0,0,2].shape)

### Functions: recover conca_fluo_data (structure data) from signal traces and locations

In [None]:
def recover_strucure_from_traces_and_locs(signal_traces, locations):
    '''
    This function recovers conca_fluo_data (structure data) from signal traces and locations.

    This function is an inverse of function get_condition_piece_locations: each
    piece is copied back out of signal_traces between its stored start and end
    frames (both inclusive).

    Parameters
    ----------
    signal_traces : ndarray indexed [run, component, frame].
    locations : object ndarray with shape (run, component, repeat, condition),
        as returned by get_condition_piece_locations. Each entry is a (2, 3)
        array [[run, component, start], [run, component, end]].

    Returns
    -------
    conca_fluo_data : object ndarray with shape (run, condition). Each element
        is a float array with shape (piece_length, component, repeat), where
        piece_length is taken from locations[0, 0, 0, 0]; every piece must have
        that length.

    Raises
    ------
    ValueError
        If locations is empty or an entry is None (a piece that
        get_condition_piece_locations could not find).
    '''
    how_many_run, how_many_component, how_many_repeat, how_many_condition = locations.shape
    if locations.size == 0:
        raise ValueError("locations is empty; there is nothing to recover")

    def piece_bounds(index):
        '''Return the (start, end) frame of the piece stored at a 4-D locations index.'''
        location = locations[index]
        if location is None:
            raise ValueError(f"no location stored for (run, component, repeat, condition) = {index}")
        return location[0, 2], location[1, 2]

    start_point, end_point = piece_bounds((0, 0, 0, 0))
    fluo_length_per_condition = end_point - start_point + 1
    conca_fluo_data = np.empty((how_many_run, how_many_condition), dtype=object)
    for i in range(how_many_run):
        for j in range(how_many_condition):
            conca_fluo_data[i, j] = np.zeros((fluo_length_per_condition, how_many_component, how_many_repeat))

    # np.ndindex visits (run, component, repeat, condition) in C order, i.e. the
    # same order as four nested loops with condition innermost.
    for index in np.ndindex(locations.shape):
        run_index, component_index, repeat_index, condition_index = index
        start_point, end_point = piece_bounds(index)
        conca_fluo_data[run_index, condition_index][:, component_index, repeat_index] = signal_traces[run_index, component_index, start_point:end_point+1]

    return conca_fluo_data

### === Input cell_name here ===

In [None]:
cell_name = 'CL096_231018'
run_num = ['3', '4']

# cell_name = 'CL090_230515'
# run_num = ['4', '5', '6']

# cell_name = 'CL075_230228'
# run_num = ['1', '2', '3']

# cell_name = 'CL079_230324'
# run_num = ['1', '2', '3']

datatype = 'F'


def _load_channel(channel_color):
    '''
    Load one color channel of cell_name / run_num / datatype.

    Reads the structure data and the signal traces, locates every structure
    piece inside the traces, prints the five shapes, and returns
    (conca_fluo_data, signal_traces, locations).
    '''
    channel_conca = read_conca_fluo_data(cell_name, run_num, channel_color, datatype)
    channel_traces = read_signal_traces(cell_name, run_num, channel_color)
    channel_locations = get_condition_piece_locations(channel_conca, channel_traces)
    print(channel_conca.shape)
    print(channel_conca[0,0].shape)
    print(channel_traces.shape)
    print(channel_locations.shape)
    print(channel_locations[0,0,0,0].shape)
    return channel_conca, channel_traces, channel_locations


color = 'green'
green_conca_fluo_data, green_signal_traces, green_locations = _load_channel(color)
print("------ ------")

color = 'red'
red_conca_fluo_data, red_signal_traces, red_locations = _load_channel(color)

### Example of Function Use

Check whether recover_strucure_from_traces_and_locs recovers conca_fluo_data exactly.

In [None]:
# # check recover_strucure_from_traces_and_locs to see whether we can recover conca_fluo_data
# green_conca_fluo_data_ = recover_strucure_from_traces_and_locs(green_signal_traces, green_locations)
# red_conca_fluo_data_ = recover_strucure_from_traces_and_locs(red_signal_traces, red_locations)
# for conca_fluo_data, conca_fluo_data_ in [(green_conca_fluo_data, green_conca_fluo_data_), (red_conca_fluo_data, red_conca_fluo_data_)]:
#     are_equal_red = True
#     for i in range(conca_fluo_data.shape[0]):
#         for j in range(conca_fluo_data.shape[1]):
#             if not np.array_equal(conca_fluo_data[i, j], conca_fluo_data_[i, j]):
#                 are_equal_red = False
#                 break
#     print(are_equal_red)
# # the printed outputs are both "True" -- already verified!!!

### Functions: calculate the F value std of all components

In [None]:
def calculate_std_of_dimensions(array, dimensions):
    """Return one standard deviation per selected index along axis 1.

    For every index ``d`` in ``dimensions`` the standard deviation is taken over
    all values of ``array[:, d, :]``, i.e. pooled over every axis except axis 1.
    ``array`` therefore needs at least 3 dimensions. For the signal traces in
    this notebook axis 1 indexes components, so the result holds one std per
    component.

    Args:
      array: A numpy array with at least 3 dimensions.
      dimensions: Indices along axis 1 to evaluate, e.g. ``range(n)``.

    Returns:
      A 1-D float numpy array whose k-th entry is the std of
      ``array[:, dimensions[k], :]``.
    """
    per_index_std = [np.std(array[:, index, :]) for index in dimensions]
    return np.array(per_index_std, dtype=float)




### Example of Function Use

In [None]:
# Standard deviation of green_signal_traces[:, i, :] for every component index i
# (i runs over the full size of axis 1, whatever the recording provides).
F_std_component = calculate_std_of_dimensions(
    green_signal_traces,
    range(green_signal_traces.shape[1]),
)
print(F_std_component.shape[0])

### Calculate and plot the mean and std of each run and of each repeat

In [None]:
### mean each run
# Per-run mean F (pooled over every component and frame of that run) for the
# green and red channels, printed and then plotted on twin y-axes.

print("mean of each run of green:")
green_mean_run = []
for run_ in range(len(run_num)):
    mean_ = np.mean(green_signal_traces[run_,:,:])
    green_mean_run.append(mean_)
    print(f"run index:{run_} -- mean: {mean_}")

print("mean of each run of red:")
red_mean_run = []
for run_ in range(len(run_num)):
    mean_ = np.mean(red_signal_traces[run_,:,:])
    red_mean_run.append(mean_)
    print(f"run index:{run_} -- mean: {mean_}")

# plot figure
# x holds run positions 0..len(run_num)-1; x_custom holds the run names used as tick labels.
x = range(len(run_num))
x_custom = [run_ for run_ in run_num]
y1 = green_mean_run
y2 = red_mean_run

fig, ax1 = plt.subplots(figsize=(6.4, 4.8)) # [6.4, 4.8] is the default size, same as fig, ax1 = plt.subplots()

green_line, = ax1.plot(x, y1, 's-', color='limegreen', label='Green Fluorescence', markersize=14, linewidth=3)
ax1.set_xlabel('Run', fontsize=16)
ax1.set_ylabel('Green F', color='limegreen', fontsize=16)
ax1.tick_params(axis='y', labelcolor='limegreen', labelsize=14)

# set custom x axis labels
ax1.set_xticks(x)
ax1.set_xticklabels(x_custom)
ax1.tick_params(axis='x', labelsize=14)

# The red channel goes on a second y-axis that shares ax1's x-axis.
ax2 = ax1.twinx()

red_line, = ax2.plot(x, y2, 's-', color='salmon', label='Red Fluorescence', markersize=14, linewidth=3)
ax2.set_ylabel('Red F', color='salmon', fontsize=16)
ax2.tick_params(axis='y', labelcolor='salmon', labelsize=14)

# set two y axes with same value for unit length
# Both axes are 1.2 * span tall (span = the larger of the two data ranges),
# each centred on the midpoint of its own data, so equal vertical distances
# mean equal F differences on either axis.
span = max((max(y1) - min(y1), max(y2) - min(y2)))
average_1 = (max(y1) + min(y1)) / 2
average_2 = (max(y2) + min(y2)) / 2
min_y1, max_y1 = average_1 - 0.6*span, average_1 + 0.6*span
min_y2, max_y2 = average_2 - 0.6*span, average_2 + 0.6*span
ax1.set_ylim(min_y1, max_y1)
ax2.set_ylim(min_y2, max_y2)
# make ticks sparse
# Round each axis range inward to multiples of 100.
min_y1_, max_y1_ = np.ceil(min_y1 / 100) * 100, np.floor(max_y1 / 100) * 100
min_y2_, max_y2_ = np.ceil(min_y2 / 100) * 100, np.floor(max_y2 / 100) * 100
# If no multiple of 100 lies strictly inside the range, widen it by 100 on each
# side (these ticks may then extend the axis beyond the ylim set above).
if max_y1_ <= min_y1_:
    min_y1_, max_y1_ = min_y1_ - 100, max_y1_ + 100
if max_y2_ <= min_y2_:
    min_y2_, max_y2_ = min_y2_ - 100, max_y2_ + 100
# here use ceil for min and floor for max to guarantee that the range used in
# set_yticks is within the range of set_ylim, otherwise a new range of the axis
# will consider both set_yticks and set_ylim and give the union of them.
unit_1 = np.ceil((max_y1_ - min_y1_) / 3 / 100) * 100 # 3 means only 3+1 ticks
unit_2 = np.ceil((max_y2_ - min_y2_) / 3 / 100) * 100
# Use the larger step on both axes so their tick spacing is identical.
unit_unified = max(unit_1, unit_2)
# Ticks every unit_unified starting at min_y*_ (np.arange excludes max_y*_).
# max_y*_ is appended as a top tick unless the last generated tick already lies
# above the largest data value. The else branches (ticks at the truncated data
# min/max) appear unreachable after the widening above; kept as a safeguard.
if min_y1_ != max_y1_:
    if np.arange(min_y1_, max_y1_, step=unit_unified)[-1] > max(y1):
        sparse_y_ticks_1 = np.arange(min_y1_, max_y1_, step=unit_unified)
    else:
        sparse_y_ticks_1 = np.append(np.arange(min_y1_, max_y1_, step=unit_unified), max_y1_)
else:
    sparse_y_ticks_1 = np.array([int(min(y1)), int(max(y1))])
if min_y2_ != max_y2_:
    if np.arange(min_y2_, max_y2_, step=unit_unified)[-1] > max(y2):
        sparse_y_ticks_2 = np.arange(min_y2_, max_y2_, step=unit_unified)
    else:
        sparse_y_ticks_2 = np.append(np.arange(min_y2_, max_y2_, step=unit_unified), max_y2_)
else:
    sparse_y_ticks_2 = np.array([int(min(y2)), int(max(y2))])
ax1.set_yticks(sparse_y_ticks_1)
ax2.set_yticks(sparse_y_ticks_2)

# put two legends together
# A single legend on ax2 lists the lines of both axes.
lines = [green_line, red_line]
labels = [line.get_label() for line in lines]
ax2.legend(lines, labels, fontsize=14, facecolor='none')

plt.title("Green and Red F Value Means in Different Runs (Raw)", fontsize=18)

plt.show()


### std each run
# Per-run std of F (pooled over every component and frame of that run) for the
# green and red channels, printed and then plotted on twin y-axes.

print("std of each run of green:")
green_std_run = []
for run_ in range(len(run_num)):
    std_ = np.std(green_signal_traces[run_,:,:])
    green_std_run.append(std_)
    print(f"run index:{run_} -- std: {std_}")

print("std of each run of red:")
red_std_run = []
for run_ in range(len(run_num)):
    std_ = np.std(red_signal_traces[run_,:,:])
    red_std_run.append(std_)
    print(f"run index:{run_} -- std: {std_}")

# plot figure
# x holds run positions 0..len(run_num)-1; x_custom holds the run names used as tick labels.
x = range(len(run_num))
x_custom = [run_ for run_ in run_num]
y1 = green_std_run
y2 = red_std_run

fig, ax1 = plt.subplots(figsize=(6.4, 4.8))

green_line, = ax1.plot(x, y1, 's-', color='limegreen', label='Green Fluorescence', markersize=14, linewidth=3)
ax1.set_xlabel('Run', fontsize=16)
ax1.set_ylabel('Green F', color='limegreen', fontsize=16)
ax1.tick_params(axis='y', labelcolor='limegreen', labelsize=14)

# set custom x axis labels
ax1.set_xticks(x)
ax1.set_xticklabels(x_custom)
ax1.tick_params(axis='x', labelsize=14)

# The red channel goes on a second y-axis that shares ax1's x-axis.
ax2 = ax1.twinx()

red_line, = ax2.plot(x, y2, 's-', color='salmon', label='Red Fluorescence', markersize=14, linewidth=3)
ax2.set_ylabel('Red F', color='salmon', fontsize=16)
ax2.tick_params(axis='y', labelcolor='salmon', labelsize=14)

# set two y axes with same value for unit length
# Both axes are 1.2 * span tall (span = the larger of the two data ranges),
# each centred on the midpoint of its own data.
span = max((max(y1) - min(y1), max(y2) - min(y2)))
average_1 = (max(y1) + min(y1)) / 2
average_2 = (max(y2) + min(y2)) / 2
min_y1, max_y1 = average_1 - 0.6*span, average_1 + 0.6*span
min_y2, max_y2 = average_2 - 0.6*span, average_2 + 0.6*span
ax1.set_ylim(min_y1, max_y1)
ax2.set_ylim(min_y2, max_y2)
# make ticks sparse
# Round each axis range inward to multiples of 100.
min_y1_, max_y1_ = np.ceil(min_y1 / 100) * 100, np.floor(max_y1 / 100) * 100
min_y2_, max_y2_ = np.ceil(min_y2 / 100) * 100, np.floor(max_y2 / 100) * 100
# If no multiple of 100 lies strictly inside the range, widen it by 100 on each
# side (these ticks may then extend the axis beyond the ylim set above).
if max_y1_ <= min_y1_:
    min_y1_, max_y1_ = min_y1_ - 100, max_y1_ + 100
if max_y2_ <= min_y2_:
    min_y2_, max_y2_ = min_y2_ - 100, max_y2_ + 100
# here use ceil for min and floor for max to guarantee that the range used in
# set_yticks is within the range of set_ylim, otherwise a new range of the axis
# will consider both set_yticks and set_ylim and give the union of them.
unit_1 = np.ceil((max_y1_ - min_y1_) / 3 / 100) * 100 # 3 means only 3+1 ticks
unit_2 = np.ceil((max_y2_ - min_y2_) / 3 / 100) * 100
# Use the larger step on both axes so their tick spacing is identical.
unit_unified = max(unit_1, unit_2)
# Ticks every unit_unified starting at min_y*_ (np.arange excludes max_y*_).
# max_y*_ is appended as a top tick unless the last generated tick already lies
# above the largest data value. The else branches appear unreachable after the
# widening above; kept as a safeguard.
if min_y1_ != max_y1_:
    if np.arange(min_y1_, max_y1_, step=unit_unified)[-1] > max(y1):
        sparse_y_ticks_1 = np.arange(min_y1_, max_y1_, step=unit_unified)
    else:
        sparse_y_ticks_1 = np.append(np.arange(min_y1_, max_y1_, step=unit_unified), max_y1_)
else:
    sparse_y_ticks_1 = np.array([int(min(y1)), int(max(y1))])
if min_y2_ != max_y2_:
    if np.arange(min_y2_, max_y2_, step=unit_unified)[-1] > max(y2):
        sparse_y_ticks_2 = np.arange(min_y2_, max_y2_, step=unit_unified)
    else:
        sparse_y_ticks_2 = np.append(np.arange(min_y2_, max_y2_, step=unit_unified), max_y2_)
else:
    sparse_y_ticks_2 = np.array([int(min(y2)), int(max(y2))])
ax1.set_yticks(sparse_y_ticks_1)
ax2.set_yticks(sparse_y_ticks_2)

# put two legends together
# A single legend on ax2 lists the lines of both axes.
lines = [green_line, red_line]
labels = [line.get_label() for line in lines]
ax2.legend(lines, labels, fontsize=14, facecolor='none')

plt.title("Green and Red F Value Std in Different Runs (Raw)", fontsize=18)

plt.show()


### mean each repeat
# Mean F of every repeat of every run (pooled over components and the repeat's
# frames), plotted for green and red on twin y-axes with one point per repeat.

# print("mean of each repeat of green:")
# Each run's frames are split into repeat_num equal consecutive chunks; any
# remainder frames at the end (from the int() truncation) are ignored.
# NOTE(review): assumes the repeats are contiguous and equally long in the trace — confirm.
repeat_num = 10 # 10 repeats per run
repeat_len = int(green_signal_traces.shape[2]/repeat_num)
green_mean_repeat = []
for run_ in range(len(run_num)):
    for i in range(repeat_num):
        mean_ = np.mean(green_signal_traces[run_,:,i*repeat_len:(i+1)*repeat_len])
        green_mean_repeat.append(mean_)
        # print(f"run index:{run_} repeat index:{i} -- mean: {mean_}")

# print("mean of each repeat of red:")
red_mean_repeat = []
for run_ in range(len(run_num)):
    for i in range(repeat_num):
        mean_ = np.mean(red_signal_traces[run_,:,i*repeat_len:(i+1)*repeat_len])
        red_mean_repeat.append(mean_)
        # print(f"run index:{run_} repeat index:{i} -- mean: {mean_}")

# plot figure
# Repeats are numbered 1..repeat_num*len(run_num) consecutively across runs.
x = range(1, repeat_num * len(run_num) + 1)
y1 = green_mean_repeat
y2 = red_mean_repeat

fig, ax1 = plt.subplots(figsize=(6.4*2, 4.8))

green_line, = ax1.plot(x, y1, 's-', color='limegreen', label='Green Fluorescence', markersize=14, linewidth=3)
ax1.set_xlabel('Repeat', fontsize=16)
ax1.set_ylabel('Green F', color='limegreen', fontsize=16)
ax1.tick_params(axis='y', labelcolor='limegreen', labelsize=14)

# set custom x axis labels
ax1.set_xticks(x)
if len(x) == 30 and len(run_num) == 3: # set sparser xticks for the common 3-run case
    ax1.set_xticks([1,4,7,10,11,14,17,20,21,24,27,30])
ax1.tick_params(axis='x', labelsize=14)

# The red channel goes on a second y-axis that shares ax1's x-axis.
ax2 = ax1.twinx()

red_line, = ax2.plot(x, y2, 's-', color='salmon', label='Red Fluorescence', markersize=14, linewidth=3)
ax2.set_ylabel('Red F', color='salmon', fontsize=16)
ax2.tick_params(axis='y', labelcolor='salmon', labelsize=14)

# set two y axes with same value for unit length
# Both axes span 1.4 * span around their own data midpoint; the extra 0.8*span
# above the midpoint leaves room for the run labels drawn at +0.7*span below.
span = max((max(y1) - min(y1), max(y2) - min(y2)))
average_1 = (max(y1) + min(y1)) / 2
average_2 = (max(y2) + min(y2)) / 2
min_y1, max_y1 = average_1 - 0.6*span, average_1 + 0.8*span
min_y2, max_y2 = average_2 - 0.6*span, average_2 + 0.8*span
ax1.set_ylim(min_y1, max_y1)
ax2.set_ylim(min_y2, max_y2)
# make ticks sparse
# Round each axis range inward to multiples of 100.
min_y1_, max_y1_ = np.ceil(min_y1 / 100) * 100, np.floor(max_y1 / 100) * 100
min_y2_, max_y2_ = np.ceil(min_y2 / 100) * 100, np.floor(max_y2 / 100) * 100
# If no multiple of 100 lies strictly inside the range, widen it by 100 on each
# side (these ticks may then extend the axis beyond the ylim set above).
if max_y1_ <= min_y1_:
    min_y1_, max_y1_ = min_y1_ - 100, max_y1_ + 100
if max_y2_ <= min_y2_:
    min_y2_, max_y2_ = min_y2_ - 100, max_y2_ + 100
# here use ceil for min and floor for max to guarantee that the range used in
# set_yticks is within the range of set_ylim, otherwise a new range of the axis
# will consider both set_yticks and set_ylim and give the union of them.
unit_1 = np.ceil((max_y1_ - min_y1_) / 3 / 100) * 100 # 3 means only 3+1 ticks
unit_2 = np.ceil((max_y2_ - min_y2_) / 3 / 100) * 100
# Use the larger step on both axes so their tick spacing is identical.
unit_unified = max(unit_1, unit_2)
# Ticks every unit_unified starting at min_y*_ (np.arange excludes max_y*_).
# max_y*_ is appended as a top tick unless the last generated tick already lies
# above the largest data value. The else branches appear unreachable after the
# widening above; kept as a safeguard.
if min_y1_ != max_y1_:
    if np.arange(min_y1_, max_y1_, step=unit_unified)[-1] > max(y1):
        sparse_y_ticks_1 = np.arange(min_y1_, max_y1_, step=unit_unified)
    else:
        sparse_y_ticks_1 = np.append(np.arange(min_y1_, max_y1_, step=unit_unified), max_y1_)
else:
    sparse_y_ticks_1 = np.array([int(min(y1)), int(max(y1))])
if min_y2_ != max_y2_:
    if np.arange(min_y2_, max_y2_, step=unit_unified)[-1] > max(y2):
        sparse_y_ticks_2 = np.arange(min_y2_, max_y2_, step=unit_unified)
    else:
        sparse_y_ticks_2 = np.append(np.arange(min_y2_, max_y2_, step=unit_unified), max_y2_)
else:
    sparse_y_ticks_2 = np.array([int(min(y2)), int(max(y2))])
ax1.set_yticks(sparse_y_ticks_1)
ax2.set_yticks(sparse_y_ticks_2)

# add separating dash lines and add run names
# A dashed line sits halfway between the last repeat of one run and the first
# of the next; each run name is centred over its block of repeats.
for i, run_ in enumerate(run_num):
    if i < len(run_num) - 1:
        ax1.axvline(x=repeat_num*(i+1)+0.5, color='gray', linestyle='--', linewidth=3)
    ax1.text(repeat_num*(i+0.5)+0.5, average_1+0.7*span, 'Run ' + run_, ha='center', va='center', fontsize=16)

# put two legends together
# A single legend on ax2 lists the lines of both axes.
lines = [green_line, red_line]
labels = [line.get_label() for line in lines]
ax2.legend(lines, labels, loc = 'lower left', fontsize=14, facecolor='none')

plt.title("Green and Red F Value Means in Different Repeats (Raw)", fontsize=18)

plt.show()


### std each repeat
# Std of F for every repeat of every run (pooled over components and the
# repeat's frames), plotted for green and red on twin y-axes.

# print("std of each repeat of green:")
# Same equal-length, consecutive repeat split as the per-repeat mean plot;
# remainder frames at the end of a run are ignored.
repeat_num = 10 # 10 repeats per run
repeat_len = int(green_signal_traces.shape[2]/repeat_num)
green_std_repeat = []
for run_ in range(len(run_num)):
    for i in range(repeat_num):
        std_ = np.std(green_signal_traces[run_,:,i*repeat_len:(i+1)*repeat_len])
        green_std_repeat.append(std_)
        # print(f"run index:{run_} repeat index:{i} -- std: {std_}")

# print("std of each repeat of red:")
red_std_repeat = []
for run_ in range(len(run_num)):
    for i in range(repeat_num):
        std_ = np.std(red_signal_traces[run_,:,i*repeat_len:(i+1)*repeat_len])
        red_std_repeat.append(std_)
        # print(f"run index:{run_} repeat index:{i} -- std: {std_}")

# plot figure
# Repeats are numbered 1..repeat_num*len(run_num) consecutively across runs.
x = range(1, repeat_num * len(run_num) + 1)
y1 = green_std_repeat
y2 = red_std_repeat

fig, ax1 = plt.subplots(figsize=(6.4*2, 4.8))

green_line, = ax1.plot(x, y1, 's-', color='limegreen', label='Green Fluorescence', markersize=14, linewidth=3)
ax1.set_xlabel('Repeat', fontsize=16)
ax1.set_ylabel('Green F', color='limegreen', fontsize=16)
ax1.tick_params(axis='y', labelcolor='limegreen', labelsize=14)

# set custom x axis labels
ax1.set_xticks(x)
if len(x) == 30 and len(run_num) == 3: # set sparser xticks for the common 3-run case
    ax1.set_xticks([1,4,7,10,11,14,17,20,21,24,27,30])
ax1.tick_params(axis='x', labelsize=14)

# The red channel goes on a second y-axis that shares ax1's x-axis.
ax2 = ax1.twinx()

red_line, = ax2.plot(x, y2, 's-', color='salmon', label='Red Fluorescence', markersize=14, linewidth=3)
ax2.set_ylabel('Red F', color='salmon', fontsize=16)
ax2.tick_params(axis='y', labelcolor='salmon', labelsize=14)

# set two y axes with same value for unit length
# Both axes span 1.4 * span around their own data midpoint; the extra 0.8*span
# above the midpoint leaves room for the run labels drawn at +0.7*span below.
span = max((max(y1) - min(y1), max(y2) - min(y2)))
average_1 = (max(y1) + min(y1)) / 2
average_2 = (max(y2) + min(y2)) / 2
min_y1, max_y1 = average_1 - 0.6*span, average_1 + 0.8*span
min_y2, max_y2 = average_2 - 0.6*span, average_2 + 0.8*span
ax1.set_ylim(min_y1, max_y1)
ax2.set_ylim(min_y2, max_y2)
# make ticks sparse
# Round each axis range inward to multiples of 100.
min_y1_, max_y1_ = np.ceil(min_y1 / 100) * 100, np.floor(max_y1 / 100) * 100
min_y2_, max_y2_ = np.ceil(min_y2 / 100) * 100, np.floor(max_y2 / 100) * 100
# If no multiple of 100 lies strictly inside the range, widen it by 100 on each
# side (these ticks may then extend the axis beyond the ylim set above).
if max_y1_ <= min_y1_:
    min_y1_, max_y1_ = min_y1_ - 100, max_y1_ + 100
if max_y2_ <= min_y2_:
    min_y2_, max_y2_ = min_y2_ - 100, max_y2_ + 100
# here use ceil for min and floor for max to guarantee that the range used in
# set_yticks is within the range of set_ylim, otherwise a new range of the axis
# will consider both set_yticks and set_ylim and give the union of them.
unit_1 = np.ceil((max_y1_ - min_y1_) / 3 / 100) * 100 # 3 means only 3+1 ticks
unit_2 = np.ceil((max_y2_ - min_y2_) / 3 / 100) * 100
# Use the larger step on both axes so their tick spacing is identical.
unit_unified = max(unit_1, unit_2)
# No min != max guard is needed here: after the widening above max_y*_ > min_y*_
# and unit_unified >= 100, so np.arange always yields at least one tick.
# max_y*_ is appended as a top tick unless the last generated tick already lies
# above the largest data value.
if np.arange(min_y1_, max_y1_, step=unit_unified)[-1] > max(y1):
    sparse_y_ticks_1 = np.arange(min_y1_, max_y1_, step=unit_unified)
else:
    sparse_y_ticks_1 = np.append(np.arange(min_y1_, max_y1_, step=unit_unified), max_y1_)
if np.arange(min_y2_, max_y2_, step=unit_unified)[-1] > max(y2):
    sparse_y_ticks_2 = np.arange(min_y2_, max_y2_, step=unit_unified)
else:
    sparse_y_ticks_2 = np.append(np.arange(min_y2_, max_y2_, step=unit_unified), max_y2_)
ax1.set_yticks(sparse_y_ticks_1)
ax2.set_yticks(sparse_y_ticks_2)

# add separating dash lines and add run names
# A dashed line sits halfway between the last repeat of one run and the first
# of the next; each run name is centred over its block of repeats.
for i, run_ in enumerate(run_num):
    if i < len(run_num) - 1:
        ax1.axvline(x=repeat_num*(i+1)+0.5, color='gray', linestyle='--', linewidth=3)
    ax1.text(repeat_num*(i+0.5)+0.5, average_1+0.7*span, 'Run ' + run_, ha='center', va='center', fontsize=16)

# put two legends together
# A single legend on ax2 lists the lines of both axes.
lines = [green_line, red_line]
labels = [line.get_label() for line in lines]
ax2.legend(lines, labels, loc = 'lower left', fontsize=14, facecolor='none')

plt.title("Green and Red F Value Std in Different Repeats (Raw)", fontsize=18)

plt.show()


## for later use in decay restoration
# Keep independent copies of the raw (pre-restoration) statistics under
# *_list_raw names; later cells rebind green_mean_repeat, green_mean_run, etc.
green_mean_repeat_list_raw, red_mean_repeat_list_raw = list(green_mean_repeat), list(red_mean_repeat)
green_mean_run_list_raw, red_mean_run_list_raw = list(green_mean_run), list(red_mean_run)
green_std_repeat_list_raw, red_std_repeat_list_raw = list(green_std_repeat), list(red_std_repeat)
green_std_run_list_raw, red_std_run_list_raw = list(green_std_run), list(red_std_run)

# The std snapshots are not used for restoration: after the mean restoration
# the std has to be recalculated from the restored traces.

### Plot one repeat of a signal trace

In [None]:
print(green_signal_traces.shape)
print(red_signal_traces.shape)

## plot a signal trace in a repeat
# Frames per repeat are derived from the loaded traces (10 repeats per run, as
# in the statistics cells) rather than from a hard-coded 32500-frame run, so the
# slice stays correct for recordings of other lengths.
frames_per_repeat = red_signal_traces.shape[2] // 10
repeat_to_plot = 7 # 0-based: the 8th repeat of the first run, first component
# data_to_plot = green_signal_traces[0, 0, :frames_per_repeat]
data_to_plot = red_signal_traces[0, 0, repeat_to_plot*frames_per_repeat:(repeat_to_plot+1)*frames_per_repeat]

plt.figure(figsize=(6.4*2, 4.8))

plt.plot(data_to_plot, color='salmon', linewidth=0.8, label='Red Fluorescence')
plt.xlabel('Frame', fontsize=16)
plt.ylabel('F Value', fontsize=16)
plt.tick_params(axis='y', labelsize=14)
plt.tick_params(axis='x', labelsize=12)

# Set title and legend
plt.title('Red F Trace in One Repeat', fontsize=18)
plt.legend(fontsize=14)

# Show the plot
plt.tight_layout()
plt.show()

## Signal trace restoration (mean and std restoration)

### Restore the means, then plot the mean and std of each run and of each repeat

#### Restore

In [None]:
# Start from the raw statistics saved before any restoration.
green_mean_repeat = np.array(green_mean_repeat_list_raw)
red_mean_repeat = np.array(red_mean_repeat_list_raw)

green_mean_run = np.array(green_mean_run_list_raw)
red_mean_run = np.array(red_mean_run_list_raw)

# print(green_mean_repeat.shape)
# print(red_mean_repeat.shape)
# print(green_mean_run.shape)
# print(red_mean_run.shape)


def exponential_func(x, lambda_):
    '''Exponential model exp(lambda_ * x), fitted by curve_fit to one run's
    repeat means after normalizing them so that the first repeat equals 1.'''
    return np.exp(lambda_ * x)


# The number of runs must follow the input cell. It used to be hard-coded to 2,
# which mis-split the repeat means (15 per "run" instead of 10) whenever
# run_num listed 3 runs.
how_many_run = len(run_num)

x_mean_repeat_restored_list = []
x_mean_repeat_restored_coefficients_list = []
for x_mean_repeat, x_mean_run in [(green_mean_repeat, green_mean_run), (red_mean_repeat, red_mean_run)]:
    repeat_num_per_run = int(x_mean_repeat.shape[0]/how_many_run)

    x_mean_repeat_restored = copy.deepcopy(x_mean_repeat)
    x_mean_repeat_restored_coefficients = copy.deepcopy(x_mean_repeat)
    for i in range(how_many_run):
        run_slice = slice(i*repeat_num_per_run, (i+1)*repeat_num_per_run)
        # Normalize this run's repeat means so that its first repeat equals 1.
        y = copy.deepcopy(x_mean_repeat[run_slice])
        multiplier = 1 / y[0]
        y = y * multiplier
        x = np.arange(len(y))
        popt, pcov = curve_fit(exponential_func, x, y)
        lambda_ = popt[0]
        print(f"exponent in the exp func is {lambda_}")
        y_fit = exponential_func(x, lambda_)

        # NOTE(review): with recover_factor = y / y_fit, a run whose repeat means
        # follow the fitted exponential exactly gets a factor of 1 (the decay is
        # kept). Confirm whether 1 / y_fit (remove the decay) or y_fit / y
        # (smooth onto the fit) was intended.
        recover_factor = y / y_fit
        print("recover_factor.shape:", recover_factor.shape)
        # print(recover_factor)

        # Additionally rescale every run to the mean level of the first run.
        run_coefficients = recover_factor * x_mean_run[0] / x_mean_run[i]
        x_mean_repeat_restored[run_slice] *= run_coefficients
        x_mean_repeat_restored_coefficients[run_slice] = run_coefficients

    x_mean_repeat_restored_list.append(x_mean_repeat_restored)
    x_mean_repeat_restored_coefficients_list.append(x_mean_repeat_restored_coefficients)

green_mean_repeat_restored, red_mean_repeat_restored = x_mean_repeat_restored_list
green_mean_repeat_restored_coefficients, red_mean_repeat_restored_coefficients = x_mean_repeat_restored_coefficients_list

## generate green_signal_traces_mean_restored and red_signal_traces_mean_restored
# Multiply every frame of every component in a repeat by that repeat's
# coefficient, using the same equal-length repeat split as the statistics cell.
green_signal_traces_mean_restored = copy.deepcopy(green_signal_traces)
red_signal_traces_mean_restored = copy.deepcopy(red_signal_traces)

repeat_num = 10 # 10 repeats per run
repeat_len = int(green_signal_traces.shape[2]/repeat_num)

for traces_restored, repeat_coefficients in [(green_signal_traces_mean_restored, green_mean_repeat_restored_coefficients),
                                             (red_signal_traces_mean_restored, red_mean_repeat_restored_coefficients)]:
    for run_ in range(len(run_num)):
        for i in range(repeat_num):
            traces_restored[run_,:,i*repeat_len:(i+1)*repeat_len] *= repeat_coefficients[run_*repeat_num+i]


#### Plot

In [None]:
### mean each run

print("mean of each run of green:")
green_mean_run = []
for run_ in range(len(run_num)):
    mean_ = np.mean(green_signal_traces_mean_restored[run_,:,:])
    green_mean_run.append(mean_)
    print(f"run index:{run_} -- mean: {mean_}")

print("mean of each run of red:")
red_mean_run = []
for run_ in range(len(run_num)):
    mean_ = np.mean(red_signal_traces_mean_restored[run_,:,:])
    red_mean_run.append(mean_)
    print(f"run index:{run_} -- mean: {mean_}")

# plot figure
x = range(len(run_num))
x_custom = [run_ for run_ in run_num]
y1 = green_mean_run
y2 = red_mean_run

fig, ax1 = plt.subplots(figsize=(6.4, 4.8)) # [6.4, 4.8] is deault size, same as fig, ax1 = plt.subplots()

# ---- Plotting helpers shared by the figures in this cell ----

def _nice_tick_step(raw_step):
    """Return the smallest value of the form {1, 2, 2.5, 5, 10} * 10**k that is >= raw_step (> 0)."""
    magnitude = 10.0 ** np.floor(np.log10(raw_step))
    for multiplier in (1.0, 2.0, 2.5, 5.0, 10.0):
        if multiplier * magnitude >= raw_step:
            return multiplier * magnitude
    return 10.0 * magnitude  # only reachable through floating-point round-off


def _ticks_in_range(lower, upper, step):
    """Return every multiple of `step` lying within [lower, upper]."""
    first = np.ceil(lower / step) * step
    count = int(np.floor((upper - first) / step + 1e-9)) + 1
    return first + step * np.arange(max(count, 0))


def _share_y_scale(ax1, ax2, y1, y2, upper_pad):
    """Give both y axes the same height in data units and matching sparse ticks.

    Each axis is centred on the mid-range of its own series and spans
    [mid - 0.6*span, mid + upper_pad*span], where span is the larger of the two
    data ranges, so one unit of F has the same length on both axes.
    Returns (span, mid_1) so callers can place annotations on ax1.
    """
    span = max(max(y1) - min(y1), max(y2) - min(y2))
    if not span > 0:  # every point identical: avoid a zero-height axis
        span = 1.0
    mid_1 = (max(y1) + min(y1)) / 2
    mid_2 = (max(y2) + min(y2)) / 2
    # Both axes have the same height, so one shared step keeps the tick spacing
    # equal on both; a step in [height/4, height/2] gives 2-5 ticks per axis.
    # The step adapts to the data magnitude instead of assuming multiples of 100.
    step = _nice_tick_step((0.6 + upper_pad) * span / 4)
    for ax, mid in ((ax1, mid_1), (ax2, mid_2)):
        lower, upper = mid - 0.6 * span, mid + upper_pad * span
        ax.set_yticks(_ticks_in_range(lower, upper, step))
        # set_yticks may widen the view to show every tick; re-apply the limits.
        ax.set_ylim(lower, upper)
    return span, mid_1


def _plot_green_red_twin(x, y1, y2, *, title, xlabel, ax1=None,
                         figsize=(6.4, 4.8), xticks=None, xticklabels=None,
                         upper_pad=0.6, run_names=None, repeat_num=None,
                         legend_loc=None):
    """Plot green F on the left y axis and red F on a twin right axis, then show.

    ax1: existing axes to draw on; a new figure of `figsize` is created if None.
    xticks / xticklabels: x tick positions (default: x) and optional labels.
    upper_pad: headroom above the data, in units of the shared span.
    run_names, repeat_num: when given, draw dashed separators every
        `repeat_num` points and label each block "Run <name>".
    legend_loc: matplotlib legend location; None keeps matplotlib's default.
    """
    if ax1 is None:
        _, ax1 = plt.subplots(figsize=figsize)

    green_line, = ax1.plot(x, y1, 's-', color='limegreen', label='Green Fluorescence', markersize=14, linewidth=3)
    ax1.set_xlabel(xlabel, fontsize=16)
    ax1.set_ylabel('Green F', color='limegreen', fontsize=16)
    ax1.tick_params(axis='y', labelcolor='limegreen', labelsize=14)
    ax1.set_xticks(x if xticks is None else xticks)
    if xticklabels is not None:
        ax1.set_xticklabels(xticklabels)
    ax1.tick_params(axis='x', labelsize=14)

    ax2 = ax1.twinx()
    red_line, = ax2.plot(x, y2, 's-', color='salmon', label='Red Fluorescence', markersize=14, linewidth=3)
    ax2.set_ylabel('Red F', color='salmon', fontsize=16)
    ax2.tick_params(axis='y', labelcolor='salmon', labelsize=14)

    span, mid_1 = _share_y_scale(ax1, ax2, y1, y2, upper_pad)

    if run_names is not None:
        for i, run_name in enumerate(run_names):
            if i < len(run_names) - 1:
                ax1.axvline(x=repeat_num * (i + 1) + 0.5, color='gray', linestyle='--', linewidth=3)
            ax1.text(repeat_num * (i + 0.5) + 0.5, mid_1 + 0.7 * span, f'Run {run_name}',
                     ha='center', va='center', fontsize=16)

    # One legend covering the lines of both axes.
    lines = [green_line, red_line]
    legend_kwargs = {} if legend_loc is None else {'loc': legend_loc}
    ax2.legend(lines, [line.get_label() for line in lines], fontsize=14, facecolor='none', **legend_kwargs)

    ax2.set_title(title, fontsize=18)  # ax2 is the current axes, as used by plt.title
    plt.show()


def _per_run_stat(traces, stat, n_runs, stat_name):
    """Apply `stat` to traces[run, :, :] for each run, printing each value."""
    values = []
    for run_ in range(n_runs):
        value = stat(traces[run_, :, :])
        values.append(value)
        print(f"run index:{run_} -- {stat_name}: {value}")
    return values


def _per_repeat_stat(traces, stat, n_runs, repeat_num):
    """Apply `stat` to each of `repeat_num` equal-length time chunks of every run.

    Chunk length is traces.shape[2] // repeat_num (trailing leftover samples
    are ignored); values are ordered run-major.
    """
    repeat_len = traces.shape[2] // repeat_num
    return [stat(traces[run_, :, i * repeat_len:(i + 1) * repeat_len])
            for run_ in range(n_runs) for i in range(repeat_num)]


def _repeat_xticks(n_runs, repeat_num, stride=3):
    """Every `stride`-th 1-based repeat position within each run (e.g. 1, 4, 7, 10, 11, ...)."""
    return [run_ * repeat_num + k for run_ in range(n_runs) for k in range(1, repeat_num + 1, stride)]


# Finish the "mean each run" figure whose axes (ax1) and data (x, y1, y2,
# x_custom) are prepared just above.
_plot_green_red_twin(
    x, y1, y2, ax1=ax1, xlabel='Run', xticklabels=x_custom,
    title="Green and Red F Value Means in Different Runs\n(After Mean Restoration)",
)


### std each run

print("std of each run of green:")
green_std_run = _per_run_stat(green_signal_traces_mean_restored, np.std, len(run_num), 'std')
print("std of each run of red:")
red_std_run = _per_run_stat(red_signal_traces_mean_restored, np.std, len(run_num), 'std')

_plot_green_red_twin(
    range(len(run_num)), green_std_run, red_std_run,
    xlabel='Run', xticklabels=list(run_num),
    title="Green and Red F Value Std in Different Runs\n(After Mean Restoration)",
)


### mean each repeat

repeat_num = 10  # 10 repeats per run
repeat_len = green_signal_traces_mean_restored.shape[2] // repeat_num
green_mean_repeat = _per_repeat_stat(green_signal_traces_mean_restored, np.mean, len(run_num), repeat_num)
red_mean_repeat = _per_repeat_stat(red_signal_traces_mean_restored, np.mean, len(run_num), repeat_num)

_plot_green_red_twin(
    range(1, repeat_num * len(run_num) + 1), green_mean_repeat, red_mean_repeat,
    xlabel='Repeat', figsize=(6.4 * 2, 4.8),
    xticks=_repeat_xticks(len(run_num), repeat_num), upper_pad=0.8,
    run_names=run_num, repeat_num=repeat_num, legend_loc='lower left',
    title="Green and Red F Value Means in Different Repeats (After Mean Restoration)",
)


### std each repeat

green_std_repeat = _per_repeat_stat(green_signal_traces_mean_restored, np.std, len(run_num), repeat_num)
red_std_repeat = _per_repeat_stat(red_signal_traces_mean_restored, np.std, len(run_num), repeat_num)

_plot_green_red_twin(
    range(1, repeat_num * len(run_num) + 1), green_std_repeat, red_std_repeat,
    xlabel='Repeat', figsize=(6.4 * 2, 4.8),
    xticks=_repeat_xticks(len(run_num), repeat_num), upper_pad=0.8,
    run_names=run_num, repeat_num=repeat_num, legend_loc='lower left',
    title="Green and Red F Value Std in Different Repeats (After Mean Restoration)",
)

## Snapshot the per-run / per-repeat statistics under new names so the
## later decay (std) restoration starts from the post-mean-restoration values.
(
    green_mean_repeat_list_after_mean_restored,  # same as green_mean_repeat_restored
    red_mean_repeat_list_after_mean_restored,  # same as red_mean_repeat_restored
    green_mean_run_list_after_mean_restored,
    red_mean_run_list_after_mean_restored,
    green_std_repeat_list_after_mean_restored,
    red_std_repeat_list_after_mean_restored,
    green_std_run_list_after_mean_restored,
    red_std_run_list_after_mean_restored,
) = map(copy.deepcopy, (
    green_mean_repeat, red_mean_repeat,
    green_mean_run, red_mean_run,
    green_std_repeat, red_std_repeat,
    green_std_run, red_std_run,
))

### Restore std and plot mean each run, std each run, mean each repeat, std each repeat

#### Restore

In [None]:
## Turn the saved statistic lists back into numpy arrays for the std restoration.
(
    green_mean_repeat, red_mean_repeat,
    green_mean_run, red_mean_run,
    green_std_repeat, red_std_repeat,
    green_std_run, red_std_run,
) = map(np.array, (
    green_mean_repeat_list_after_mean_restored, red_mean_repeat_list_after_mean_restored,
    green_mean_run_list_after_mean_restored, red_mean_run_list_after_mean_restored,
    green_std_repeat_list_after_mean_restored, red_std_repeat_list_after_mean_restored,
    green_std_run_list_after_mean_restored, red_std_run_list_after_mean_restored,
))

In [None]:
# Exponential model used to fit the within-run trend of the per-repeat std.
def exponential_func(x, lambda_):
    """Return exp(lambda_ * x); equals 1 at x = 0 and decays when lambda_ < 0."""
    return np.exp(lambda_ * x)


def _std_restore_coefficients(x_std_repeat, x_std_run, n_runs, repeats_per_run):
    """Return per-repeat multipliers that remove the std decay and align run levels.

    For run i, the per-repeat stds are normalised to the run's first repeat (y)
    and fitted with exp(lambda * k).  Multiplying a repeat's deviations by
    1 / y_fit[k] turns its std into std_0 * y[k] / y_fit[k]: the fitted decay is
    divided out while the residual fluctuation is kept.  (Multiplying by
    y / y_fit instead would give std_0 * y**2 / y_fit, i.e. leave the decay in.)
    The factor x_std_run[0] / x_std_run[i] then rescales run i to run 0's level.

    Raises ValueError if x_std_repeat does not hold n_runs * repeats_per_run values.
    """
    x_std_repeat = np.asarray(x_std_repeat, dtype=float)
    if x_std_repeat.shape[0] != n_runs * repeats_per_run:
        raise ValueError(
            f"expected {n_runs} runs x {repeats_per_run} repeats = "
            f"{n_runs * repeats_per_run} per-repeat stds, got {x_std_repeat.shape[0]}")
    coefficients = np.empty_like(x_std_repeat)
    k = np.arange(repeats_per_run)
    for i in range(n_runs):
        segment = slice(i * repeats_per_run, (i + 1) * repeats_per_run)
        y = x_std_repeat[segment] / x_std_repeat[segment][0]
        # Start from "no decay" (lambda = 0); curve_fit's default p0 = 1 would
        # start from exp(k), which is ~8000 at k = 9.
        popt, _ = curve_fit(exponential_func, k, y, p0=[0.0])
        lambda_ = popt[0]
        print(f"exponent in the exp func is {lambda_}")
        y_fit = exponential_func(k, lambda_)
        coefficients[segment] = x_std_run[0] / x_std_run[i] / y_fit
    return coefficients


repeat_num = 10  # 10 repeats per run
how_many_run = len(run_num)  # must match the data layout (was hard-coded to 2)

x_std_repeat_restored_list = []
x_std_repeat_restored_coefficients_list = []
for x_std_repeat, x_std_run in [(green_std_repeat, green_std_run), (red_std_repeat, red_std_run)]:
    coefficients_ = _std_restore_coefficients(x_std_repeat, x_std_run, how_many_run, repeat_num)
    x_std_repeat_restored_list.append(np.asarray(x_std_repeat, dtype=float) * coefficients_)
    x_std_repeat_restored_coefficients_list.append(coefficients_)

green_std_repeat_restored, red_std_repeat_restored = x_std_repeat_restored_list
green_std_repeat_restored_coefficients, red_std_repeat_restored_coefficients = x_std_repeat_restored_coefficients_list

## generate green_signal_traces_mean_restored_std_restored and red_signal_traces_mean_restored_std_restored
green_signal_traces_mean_restored_std_restored = copy.deepcopy(green_signal_traces_mean_restored)
red_signal_traces_mean_restored_std_restored = copy.deepcopy(red_signal_traces_mean_restored)

# Chunk length is taken from the array that is actually being rescaled.
repeat_len = green_signal_traces_mean_restored_std_restored.shape[2] // repeat_num

# Rescale each repeat's deviations from its own mean: the repeat mean is kept
# and its std is multiplied by the (positive) coefficient.
for traces_, coefficients_ in (
    (green_signal_traces_mean_restored_std_restored, green_std_repeat_restored_coefficients),
    (red_signal_traces_mean_restored_std_restored, red_std_repeat_restored_coefficients),
):
    for run_ in range(len(run_num)):
        for i in range(repeat_num):
            # Basic slicing returns a view, so assigning into it edits traces_ in place.
            segment_ = traces_[run_, :, i * repeat_len:(i + 1) * repeat_len]
            mean_ = np.mean(segment_)
            segment_[...] = mean_ + (segment_ - mean_) * coefficients_[run_ * repeat_num + i]

#### Plot

In [None]:
# ---- Plotting helpers shared by the figures in this cell ----

def _nice_tick_step(raw_step):
    """Return the smallest value of the form {1, 2, 2.5, 5, 10} * 10**k that is >= raw_step (> 0)."""
    magnitude = 10.0 ** np.floor(np.log10(raw_step))
    for multiplier in (1.0, 2.0, 2.5, 5.0, 10.0):
        if multiplier * magnitude >= raw_step:
            return multiplier * magnitude
    return 10.0 * magnitude  # only reachable through floating-point round-off


def _ticks_in_range(lower, upper, step):
    """Return every multiple of `step` lying within [lower, upper]."""
    first = np.ceil(lower / step) * step
    count = int(np.floor((upper - first) / step + 1e-9)) + 1
    return first + step * np.arange(max(count, 0))


def _share_y_scale(ax1, ax2, y1, y2, upper_pad):
    """Give both y axes the same height in data units and matching sparse ticks.

    Each axis is centred on the mid-range of its own series and spans
    [mid - 0.6*span, mid + upper_pad*span], where span is the larger of the two
    data ranges, so one unit of F has the same length on both axes.
    Returns (span, mid_1) so callers can place annotations on ax1.
    """
    span = max(max(y1) - min(y1), max(y2) - min(y2))
    if not span > 0:  # every point identical: avoid a zero-height axis
        span = 1.0
    mid_1 = (max(y1) + min(y1)) / 2
    mid_2 = (max(y2) + min(y2)) / 2
    # Both axes have the same height, so one shared step keeps the tick spacing
    # equal on both; a step in [height/4, height/2] gives 2-5 ticks per axis.
    # The step adapts to the data magnitude instead of assuming multiples of 100.
    step = _nice_tick_step((0.6 + upper_pad) * span / 4)
    for ax, mid in ((ax1, mid_1), (ax2, mid_2)):
        lower, upper = mid - 0.6 * span, mid + upper_pad * span
        ax.set_yticks(_ticks_in_range(lower, upper, step))
        # set_yticks may widen the view to show every tick; re-apply the limits.
        ax.set_ylim(lower, upper)
    return span, mid_1


def _plot_green_red_twin(x, y1, y2, *, title, xlabel, ax1=None,
                         figsize=(6.4, 4.8), xticks=None, xticklabels=None,
                         upper_pad=0.6, run_names=None, repeat_num=None,
                         legend_loc=None):
    """Plot green F on the left y axis and red F on a twin right axis, then show.

    ax1: existing axes to draw on; a new figure of `figsize` is created if None.
    xticks / xticklabels: x tick positions (default: x) and optional labels.
    upper_pad: headroom above the data, in units of the shared span.
    run_names, repeat_num: when given, draw dashed separators every
        `repeat_num` points and label each block "Run <name>".
    legend_loc: matplotlib legend location; None keeps matplotlib's default.
    """
    if ax1 is None:
        _, ax1 = plt.subplots(figsize=figsize)

    green_line, = ax1.plot(x, y1, 's-', color='limegreen', label='Green Fluorescence', markersize=14, linewidth=3)
    ax1.set_xlabel(xlabel, fontsize=16)
    ax1.set_ylabel('Green F', color='limegreen', fontsize=16)
    ax1.tick_params(axis='y', labelcolor='limegreen', labelsize=14)
    ax1.set_xticks(x if xticks is None else xticks)
    if xticklabels is not None:
        ax1.set_xticklabels(xticklabels)
    ax1.tick_params(axis='x', labelsize=14)

    ax2 = ax1.twinx()
    red_line, = ax2.plot(x, y2, 's-', color='salmon', label='Red Fluorescence', markersize=14, linewidth=3)
    ax2.set_ylabel('Red F', color='salmon', fontsize=16)
    ax2.tick_params(axis='y', labelcolor='salmon', labelsize=14)

    span, mid_1 = _share_y_scale(ax1, ax2, y1, y2, upper_pad)

    if run_names is not None:
        for i, run_name in enumerate(run_names):
            if i < len(run_names) - 1:
                ax1.axvline(x=repeat_num * (i + 1) + 0.5, color='gray', linestyle='--', linewidth=3)
            ax1.text(repeat_num * (i + 0.5) + 0.5, mid_1 + 0.7 * span, f'Run {run_name}',
                     ha='center', va='center', fontsize=16)

    # One legend covering the lines of both axes.
    lines = [green_line, red_line]
    legend_kwargs = {} if legend_loc is None else {'loc': legend_loc}
    ax2.legend(lines, [line.get_label() for line in lines], fontsize=14, facecolor='none', **legend_kwargs)

    ax2.set_title(title, fontsize=18)  # ax2 is the current axes, as used by plt.title
    plt.show()


def _per_run_stat(traces, stat, n_runs, stat_name):
    """Apply `stat` to traces[run, :, :] for each run, printing each value."""
    values = []
    for run_ in range(n_runs):
        value = stat(traces[run_, :, :])
        values.append(value)
        print(f"run index:{run_} -- {stat_name}: {value}")
    return values


def _per_repeat_stat(traces, stat, n_runs, repeat_num):
    """Apply `stat` to each of `repeat_num` equal-length time chunks of every run.

    Chunk length is traces.shape[2] // repeat_num (trailing leftover samples
    are ignored); values are ordered run-major.
    """
    repeat_len = traces.shape[2] // repeat_num
    return [stat(traces[run_, :, i * repeat_len:(i + 1) * repeat_len])
            for run_ in range(n_runs) for i in range(repeat_num)]


def _repeat_xticks(n_runs, repeat_num, stride=3):
    """Every `stride`-th 1-based repeat position within each run (e.g. 1, 4, 7, 10, 11, ...)."""
    return [run_ * repeat_num + k for run_ in range(n_runs) for k in range(1, repeat_num + 1, stride)]


### mean each run

print("mean of each run of green:")
green_mean_run = _per_run_stat(green_signal_traces_mean_restored_std_restored, np.mean, len(run_num), 'mean')
print("mean of each run of red:")
red_mean_run = _per_run_stat(red_signal_traces_mean_restored_std_restored, np.mean, len(run_num), 'mean')

_plot_green_red_twin(
    range(len(run_num)), green_mean_run, red_mean_run,
    xlabel='Run', xticklabels=list(run_num),
    title="Green and Red F Value Means in Different Runs\n(After Mean and Std Restoration)",
)


### std each run

print("std of each run of green:")
green_std_run = _per_run_stat(green_signal_traces_mean_restored_std_restored, np.std, len(run_num), 'std')
print("std of each run of red:")
red_std_run = _per_run_stat(red_signal_traces_mean_restored_std_restored, np.std, len(run_num), 'std')

_plot_green_red_twin(
    range(len(run_num)), green_std_run, red_std_run,
    xlabel='Run', xticklabels=list(run_num),
    title="Green and Red F Value Std in Different Runs\n(After Mean and Std Restoration)",
)


### mean each repeat

repeat_num = 10  # 10 repeats per run
repeat_len = green_signal_traces_mean_restored_std_restored.shape[2] // repeat_num
green_mean_repeat = _per_repeat_stat(green_signal_traces_mean_restored_std_restored, np.mean, len(run_num), repeat_num)
red_mean_repeat = _per_repeat_stat(red_signal_traces_mean_restored_std_restored, np.mean, len(run_num), repeat_num)

_plot_green_red_twin(
    range(1, repeat_num * len(run_num) + 1), green_mean_repeat, red_mean_repeat,
    xlabel='Repeat', figsize=(6.4 * 2, 4.8),
    xticks=_repeat_xticks(len(run_num), repeat_num), upper_pad=0.8,
    run_names=run_num, repeat_num=repeat_num, legend_loc='lower left',
    title="Green and Red F Value Means in Different Repeats (After Mean and Std Restoration)",
)


### std each repeat

green_std_repeat = _per_repeat_stat(green_signal_traces_mean_restored_std_restored, np.std, len(run_num), repeat_num)
red_std_repeat = _per_repeat_stat(red_signal_traces_mean_restored_std_restored, np.std, len(run_num), repeat_num)

_plot_green_red_twin(
    range(1, repeat_num * len(run_num) + 1), green_std_repeat, red_std_repeat,
    xlabel='Repeat', figsize=(6.4 * 2, 4.8),
    xticks=_repeat_xticks(len(run_num), repeat_num), upper_pad=0.8,
    run_names=run_num, repeat_num=repeat_num, legend_loc='lower left',
    title="Green and Red F Value Std in Different Repeats (After Mean and Std Restoration)",
)

## Snapshot the per-run / per-repeat statistics computed after both mean and
## std restoration under new names, for later use.
(
    green_mean_repeat_list_after_mean_restored_std_restored,  # same as green_mean_repeat_restored
    red_mean_repeat_list_after_mean_restored_std_restored,  # same as red_mean_repeat_restored
    green_mean_run_list_after_mean_restored_std_restored,
    red_mean_run_list_after_mean_restored_std_restored,
    green_std_repeat_list_after_mean_restored_std_restored,
    red_std_repeat_list_after_mean_restored_std_restored,
    green_std_run_list_after_mean_restored_std_restored,
    red_std_run_list_after_mean_restored_std_restored,
) = map(copy.deepcopy, (
    green_mean_repeat, red_mean_repeat,
    green_mean_run, red_mean_run,
    green_std_repeat, red_std_repeat,
    green_std_run, red_std_run,
))

### Plot tuning curves using restored data

This is the same process as in "Read and plot fluorescence structure data/Eg of function use/Read red and green data then plot", but it uses the restored data. The function recover_strucure_from_traces_and_locs() is needed to recover conca_fluo_data (structure data) from the signal traces and locations.

In [None]:
# cell_name_list = ['CL075_230228']
# run_num_list = [['1', '2', '3']]

# cell_name_list = ['CL090_230515']
# run_num_list = [['4', '5', '6']]

# for cell_name_, run_num_ in zip(cell_name_list, run_num_list):
#     for color_ in ['red', 'green']:
#         for datatype_ in ['F']:
#             if color_ == 'red':
#                 conca_fluo_data_ = recover_strucure_from_traces_and_locs(red_signal_traces_mean_restored_std_restored, red_locations)
#                 plot_all_trials(conca_fluo_data_, cell_name_, run_num_, color_, datatype_)
#             if color_ == 'green':
#                 conca_fluo_data_ = recover_strucure_from_traces_and_locs(green_signal_traces_mean_restored_std_restored, green_locations)
#                 for component_ in range(1, conca_fluo_data_[0,0].shape[1]+1):
#                     plot_all_trials(conca_fluo_data_, cell_name_, run_num_, color_, datatype_, component_)

In [None]:
# # batch download the plotted figures
# # uncomment the code below to download figures if needed

# import glob

# folder_path = '.'
# # file_prefix = 'All_Conditions_All_Rounds_All_Repeats_'
# file_prefix = 'CL'

# # Use glob to find all files with the given prefix in the folder
# matching_files = glob.glob(f"{folder_path}/{file_prefix}*")
# matching_files_new = []
# for file_path in matching_files:
#     # Check if the file ends with ".pdf"
#     if file_path.lower().endswith('.pdf'):
#         matching_files_new.append(file_path)
# matching_files = matching_files_new
# # print(matching_files)
# # # Print the matching file names
# # for file_path in matching_files:
# #     print(file_path)

# import zipfile

# zip_filename = 'files.zip'
# with zipfile.ZipFile(zip_filename, 'w') as zipf:
#     # Add files to the zip file
#     for file_path in matching_files:
#         zipf.write(file_path)

# from google.colab import files
# files.download(zip_filename)

### Save important data

In [None]:
# np.save(cell_name+'red_signal_traces.npy', red_signal_traces)
# np.save(cell_name+'red_conca_fluo_data.npy', red_conca_fluo_data)
# np.save(cell_name+'red_locations.npy', red_locations)

# np.save(cell_name+'green_signal_traces.npy', green_signal_traces)
# np.save(cell_name+'green_conca_fluo_data.npy', green_conca_fluo_data)
# np.save(cell_name+'green_locations.npy', green_locations)

# np.save(cell_name+'green_signal_traces_mean_restored.npy', green_signal_traces_mean_restored)
# np.save(cell_name+'red_signal_traces_mean_restored.npy', red_signal_traces_mean_restored)

# np.save(cell_name+'green_signal_traces_mean_restored_std_restored.npy', green_signal_traces_mean_restored_std_restored)
# np.save(cell_name+'red_signal_traces_mean_restored_std_restored.npy', red_signal_traces_mean_restored_std_restored)


In [None]:
# # Load red signal traces data
# red_signal_traces = np.load(cell_name + 'red_signal_traces.npy')
# print("Red Signal Traces Shape:", red_signal_traces.shape)
# print("Red Signal Traces Type:", type(red_signal_traces))

# # Load red concatenated fluorescence data
# red_conca_fluo_data = np.load(cell_name + 'red_conca_fluo_data.npy', allow_pickle=True)
# print("Red Concatenated Fluorescence Data Shape:", red_conca_fluo_data.shape)
# print("Red Concatenated Fluorescence Data Type:", type(red_conca_fluo_data))
# print("Red Concatenated Fluorescence Data Elements Shape:", red_conca_fluo_data[0,0].shape)
# print("Red Concatenated Fluorescence Data Elements Type:", type(red_conca_fluo_data[0, 0]))

# # Load red locations data
# red_locations = np.load(cell_name + 'red_locations.npy', allow_pickle=True)
# print("Red Locations Shape:", red_locations.shape)
# print("Red Locations Type:", type(red_locations))
# print("Red Locations Elements Shape:", red_locations[0,0].shape)
# print("Red Locations Elements Type:", type(red_locations[0, 0]))

# # Load green signal traces data
# green_signal_traces = np.load(cell_name + 'green_signal_traces.npy')
# print("Green Signal Traces Shape:", green_signal_traces.shape)
# print("Green Signal Traces Type:", type(green_signal_traces))

# # Load green concatenated fluorescence data
# green_conca_fluo_data = np.load(cell_name + 'green_conca_fluo_data.npy', allow_pickle=True)
# print("Green Concatenated Fluorescence Data Shape:", green_conca_fluo_data.shape)
# print("Green Concatenated Fluorescence Data Type:", type(green_conca_fluo_data))
# print("Green Concatenated Fluorescence Elements Data Shape:", green_conca_fluo_data[0,0].shape)
# print("Green Concatenated Fluorescence Elements Data Type:", type(green_conca_fluo_data[0, 0]))

# # Load green locations data
# green_locations = np.load(cell_name + 'green_locations.npy', allow_pickle=True)
# print("Green Locations Shape:", green_locations.shape)
# print("Green Locations Type:", type(green_locations))
# print("Green Locations Elements Shape:", green_locations[0,0].shape)
# print("Green Locations Elements Type:", type(green_locations[0, 0]))

# # Load green signal traces mean restored data
# green_signal_traces_mean_restored = np.load(cell_name + 'green_signal_traces_mean_restored.npy')
# print("Green Signal Traces Mean Restored Shape:", green_signal_traces_mean_restored.shape)
# print("Green Signal Traces Mean Restored Type:", type(green_signal_traces_mean_restored))

# # Load red signal traces mean restored data
# red_signal_traces_mean_restored = np.load(cell_name + 'red_signal_traces_mean_restored.npy')
# print("Red Signal Traces Mean Restored Shape:", red_signal_traces_mean_restored.shape)
# print("Red Signal Traces Mean Restored Type:", type(red_signal_traces_mean_restored))

# # Load green signal traces mean restored std restored data
# green_signal_traces_mean_restored_std_restored = np.load(cell_name + 'green_signal_traces_mean_restored_std_restored.npy')
# print("Green Signal Traces Mean Restored Std Restored Shape:", green_signal_traces_mean_restored_std_restored.shape)
# print("Green Signal Traces Mean Restored Std Restored Type:", type(green_signal_traces_mean_restored_std_restored))

# # Load red signal traces mean restored std restored data
# red_signal_traces_mean_restored_std_restored = np.load(cell_name + 'red_signal_traces_mean_restored_std_restored.npy')
# print("Red Signal Traces Mean Restored Std Restored Shape:", red_signal_traces_mean_restored_std_restored.shape)
# print("Red Signal Traces Mean Restored Std Restored Type:", type(red_signal_traces_mean_restored_std_restored))


### Download files if running on Google Colab

In [None]:
# # batch download the plotted figures
# # uncomment the code below to download figures if needed

# import glob

# folder_path = '.'
# # file_prefix = 'All_Conditions_All_Rounds_All_Repeats_'
# file_prefix = 'CL'

# # Use glob to find all files with the given prefix in the folder
# matching_files = glob.glob(f"{folder_path}/{file_prefix}*")
# # print(matching_files)
# # # Print the matching file names
# # for file_path in matching_files:
# #     print(file_path)

# import zipfile

# zip_filename = 'files.zip'
# with zipfile.ZipFile(zip_filename, 'w') as zipf:
#     # Add files to the zip file
#     for file_path in matching_files:
#         zipf.write(file_path)

# from google.colab import files
# files.download(zip_filename)

## Generate data and label, then train and eval

In this chapter, each code cell has a title so that it can be found easily through the table of contents; some code cells are long and would otherwise be hard to locate.

### === Set whether including sigmoid ===

In [None]:
# When True, FluoModel passes the fc_reduce output through a sigmoid and then a
# 1->1 linear layer (fc_end); when False, the fc_reduce output is returned directly.
with_sigmoid = True # whether include a sigmoid unit in our model (non-hierarchical model)

### === Set whether setting weights non-negative ===

In [None]:
# When True, training adds a soft penalty sum(relu(-w)) * pnlty_coeff on the weights of
# fc_shared and fc_reduce (not a hard clamp), pushing them toward non-negative values.
weights_nonnegative = True # whether only allow non-negative weights for linear layers in training
# When True the penalty coefficient is 1/500; otherwise it is 1/1000.
strong_penalty = False # whether using strong penalty to restrict the weights to nonnegative values
# if weights_nonnegative = False, then the value of strong_penalty doesn't matter.

### Get valid components (=== different cells may have different parameters ===)

In [None]:
# # different cell has different parameters, this is for cell CL075_230228
# valid_com_index_list, valid_dis_list, valid_size_list = get_valid_components(cell_name, 100, 5)
# data_set = copy.deepcopy(green_signal_traces)
# label_set = copy.deepcopy(red_signal_traces)
# print(data_set[:,valid_com_index_list,:].shape) # remain the valid component
# print(valid_com_index_list)
# print(valid_dis_list)


# different cell has different parameters, this is for cell CL090_230515 / CL096_231018
# Cell-specific arguments (100, 0) for the cells listed in the comment above.
(valid_com_index_list,
 valid_dis_list,
 valid_size_list) = get_valid_components(cell_name, 100, 0)
label_set = copy.deepcopy(red_signal_traces)
data_set = copy.deepcopy(green_signal_traces)
# Shape of the green traces once only the valid components are kept.
print(data_set[:, valid_com_index_list, :].shape)
print(valid_com_index_list)

# # different cell has different parameters, this is for cell CL079_230324
# valid_com_index_list, valid_dis_list, valid_size_list = get_valid_components(cell_name, 100, 0)
# data_set = copy.deepcopy(green_signal_traces)
# label_set = copy.deepcopy(red_signal_traces)
# print(data_set[:,valid_com_index_list,:].shape) # remain the valid component
# print(valid_com_index_list)

### Show red and green signal traces

In [None]:
signal_trace_type_list = ["Raw Signal Traces",
                          "Mean Restored Signal Traces",
                          "Mean and Std Restored Signal Traces"]

# How many runs (first axis of the signal-trace arrays) to concatenate in each plot.
# Set to 3 for cells recorded with three runs; it is clamped to the runs available.
num_runs_to_plot = 2


def _select_signal_traces(signal_trace_type):
    """Return deep copies of the (green, red) signal traces for ``signal_trace_type``.

    Only the globals for the requested type are read, so types that were never
    computed in this session are not touched.

    Raises:
        ValueError: if ``signal_trace_type`` is not one of the known types (instead
            of silently reusing the previous iteration's data).
    """
    if signal_trace_type == "Raw Signal Traces":
        return copy.deepcopy(green_signal_traces), copy.deepcopy(red_signal_traces)
    if signal_trace_type == "Mean Restored Signal Traces":
        return (copy.deepcopy(green_signal_traces_mean_restored),
                copy.deepcopy(red_signal_traces_mean_restored))
    if signal_trace_type == "Mean and Std Restored Signal Traces":
        return (copy.deepcopy(green_signal_traces_mean_restored_std_restored),
                copy.deepcopy(red_signal_traces_mean_restored_std_restored))
    raise ValueError(f"Unknown signal trace type: {signal_trace_type!r}")


def _plot_whole_trace(traces, color, title):
    """Concatenate the first runs of a (run, frame) array and plot them as one trace.

    Frames are numbered from 1, so the dashed separator at ``k * T + 0.5`` falls
    exactly between the last frame of run k and the first frame of run k+1
    (T = frames per run).
    """
    frames_per_run = traces.shape[-1]
    n_runs = min(num_runs_to_plot, traces.shape[0])
    whole_trace = np.concatenate([traces[run_idx] for run_idx in range(n_runs)])
    frames = np.arange(1, whole_trace.size + 1)

    plt.figure(figsize=(20, 6))
    plt.plot(frames, whole_trace, color=color, linestyle='-', linewidth=1)
    for boundary in range(1, n_runs):
        plt.axvline(x=frames_per_run * boundary + 0.5, color='gray', linestyle='--', linewidth=2)
    plt.xlabel('Frames', fontsize = 20)
    plt.ylabel('F Value', fontsize = 20)
    plt.tick_params(labelsize=18)
    plt.title(title, fontsize = 24)
    plt.grid(True)
    plt.show()


## plot red
for signal_trace_type in signal_trace_type_list:
    data_set, label_set = _select_signal_traces(signal_trace_type)
    _plot_whole_trace(label_set[:, 0, :], 'salmon',
                      f'Whole red signal trace ({signal_trace_type})')


## plot green
component_index = 20
for signal_trace_type in signal_trace_type_list:
    data_set, label_set = _select_signal_traces(signal_trace_type)
    data_set = data_set[:, valid_com_index_list, :] # use valid components
    _plot_whole_trace(data_set[:, component_index, :], 'limegreen',
                      f'Whole green signal trace of Component {component_index} ({signal_trace_type})')

### Train, plot, save models (it takes some time) (=== set near on/offset ===)

In [None]:
trial_times = 15  # try how many times to get the best one of them as the final result

# end_near_on/offset means the data pieces selected for training end near the on/offset
# of the visual stimuli. Test data is still random without restriction.
# Note: at most one of end_near_offset and end_near_onset can be True!
end_near_offset = False
end_near_onset = False
if end_near_offset and end_near_onset:
    raise ValueError("At most one of end_near_offset and end_near_onset can be True.")


signal_trace_type_list = ["Raw Signal Traces",
                          "Mean Restored Signal Traces",
                          "Mean and Std Restored Signal Traces"]

# signal_trace_type_list = ["Mean Restored Signal Traces",
#                           "Mean and Std Restored Signal Traces"]

# signal_trace_type_list = ["Raw Signal Traces"]
# signal_trace_type_list = ["Mean Restored Signal Traces"]
# signal_trace_type_list = ["Mean and Std Restored Signal Traces"]


# Define dataset class
class FluoDataset(Dataset):
    """Windowed (input, target) samples drawn from the normalized traces.

    Sample ``idx`` uses first-axis index ``z = z_indices[idx]`` (the run) and end
    frame ``x = x_indices[idx]``:
      input  = data_set[z, :, x-62:x]          (one 62-frame window per component)
      target = mean(label_set[z, :, x-31:x])   (scalar mean of the last 31 label frames)
    """

    def __init__(self, data_set, label_set, z_indices, x_indices):
        self.data_set = data_set
        self.label_set = label_set
        self.z_indices = z_indices
        self.x_indices = x_indices

    def __len__(self):
        return len(self.z_indices)

    def __getitem__(self, idx):
        z = self.z_indices[idx]
        x = self.x_indices[idx]

        input_data = self.data_set[z, :, x-62:x]
        target_label = np.mean(self.label_set[z, :, x-31:x])

        return torch.tensor(input_data, dtype=torch.float32), torch.tensor(target_label, dtype=torch.float32)
        # or torch.from_numpy(input_data).float(), torch.from_numpy(target_label).float()


# Define the model
class FluoModel(nn.Module):
    """Shared per-component 62->1 projection, then a component_num->1 reduction.

    When the global ``with_sigmoid`` is True (read at every forward call), the reduced
    value additionally goes through a sigmoid and a 1->1 linear layer (fc_end).
    """

    def __init__(self, component_num):
        super(FluoModel, self).__init__()
        self.fc_shared = nn.Linear(62, 1)  # Shared fully connected layer
        self.fc_reduce = nn.Linear(component_num, 1)  # Fully connected layer to reduce component_num (e.g., 281) channels to 1
        self.sigmoid = nn.Sigmoid()
        self.fc_end = nn.Linear(1, 1)

    def forward(self, x):
        x = x.view(x.size(0), -1, 62)  # Reshape to (batch_size, component_num, 62)

        # Apply the shared fully connected layer along the last dimension (62).
        # self.fc_shared(x) has shape (batch_size, component_num, 1); squeeze the last
        # dimension. nn.Linear on a 3D tensor operates along the last dimension.
        shared_output = self.fc_shared(x).squeeze(2)

        # Reduce component_num (e.g., 281) channels to 1 using a separate fully connected layer
        reduced_output = self.fc_reduce(shared_output)
        if with_sigmoid:
            pre_output = self.sigmoid(reduced_output) # if with sigmoid
            output = self.fc_end(pre_output)
        else:
            output = reduced_output # if without sigmoid

        return output


# Using a custom loss term during training that penalizes negative weights
if weights_nonnegative:
    class NonNegLoss(nn.Module):
        """Penalty equal to the summed magnitude of the negative entries of a tensor."""

        def __init__(self):
            super(NonNegLoss, self).__init__()

        def forward(self, tensor):
            return torch.sum(torch.relu(-tensor))
    non_neg_loss = NonNegLoss()


list_of_train_loss_lists = []
list_of_test_loss_lists = []

for signal_trace_type in signal_trace_type_list:

    if signal_trace_type == "Raw Signal Traces":
        data_set = copy.deepcopy(green_signal_traces)
        label_set = copy.deepcopy(red_signal_traces)
    elif signal_trace_type == "Mean Restored Signal Traces":
        data_set = copy.deepcopy(green_signal_traces_mean_restored)
        label_set = copy.deepcopy(red_signal_traces_mean_restored)
    elif signal_trace_type == "Mean and Std Restored Signal Traces":
        data_set = copy.deepcopy(green_signal_traces_mean_restored_std_restored)
        label_set = copy.deepcopy(red_signal_traces_mean_restored_std_restored)
    else:
        # Without this, an unknown/misspelled type would silently reuse the previous
        # iteration's data and later overwrite that type's saved model file.
        raise ValueError(f"Unknown signal trace type: {signal_trace_type!r}")

    data_set = data_set[:,valid_com_index_list,:] # use valid components

    print(f"data_set.shape: {data_set.shape}")
    print(f"label_set.shape: {label_set.shape}")

    # # Normalize data_set to [-1,1]
    # data_set_min = np.min(data_set)
    # data_set_max = np.max(data_set)
    # data_set = ((data_set - data_set_min) / (data_set_max - data_set_min) - 0.5) * 2

    # # Normalize label_set to [-1,1]
    # label_set_min = np.min(label_set)
    # label_set_max = np.max(label_set)
    # label_set = ((label_set - label_set_min) / (label_set_max - label_set_min) - 0.5) * 2

    # Normalize data_set to [0,1]
    data_set_min = np.min(data_set)
    data_set_max = np.max(data_set)
    data_set = (data_set - data_set_min) / (data_set_max - data_set_min)

    # Normalize label_set to [0,1]
    label_set_min = np.min(label_set)
    label_set_max = np.max(label_set)
    label_set = (label_set - label_set_min) / (label_set_max - label_set_min)

    # Set a random seed for reproducibility
    # np.random.seed(16)  # can use any integer as the seed value

    # Split the data into train and test sets.
    # NOTE(review): samples read frames [x-62, x), so windows ending just after a
    # train/test boundary (e.g. x in 12000..12061 or 18244..18305) include frames
    # from the test range; confirm whether this leakage is acceptable.
    length = 6400*2
    # choices = range(3)
    choices = range(2)
    train_z_indices = np.random.choice(choices, size=length)
    # choices = range(62, 30000)
    choices = list(range(62, 10000)) + list(range(12000, 17000)) + list(range(18244, 32500)) # 90% for train
    if end_near_offset or end_near_onset:
        # Row 0 of a location entry is used for offset, row 1 for onset; column 2 is the end index.
        x_ = 0 if end_near_offset else 1
        allowed_end_indices = set(choices)  # O(1) membership instead of scanning a ~32k-element list
        choices_ = []
        seen_end_indices = set()
        count_ = 0
        shape = red_locations.shape
        while True:
            count_ = count_ + 1
            random_indices = [np.random.randint(dim_size) for dim_size in shape]
            selected_element = red_locations[random_indices[0], random_indices[1], random_indices[2], random_indices[3]]
            end_index_ = selected_element[x_, 2]
            # Jitter the end index by up to +/-5 frames.
            end_index_ = random.randint(end_index_-5, end_index_+5)
            if end_index_ in allowed_end_indices and end_index_ not in seen_end_indices:
                choices_.append(end_index_)
                seen_end_indices.add(end_index_)
            # Stop after enough unique indices, or give up after length*3 draws.
            if count_ >= length*3 or len(choices_) >= length:
                break
        if not choices_:
            raise ValueError("No valid training end indices found near the selected on/offset.")
        choices = choices_
    train_x_indices = np.random.choice(choices, size=length)


    length = 640
    # choices = range(3)
    choices = range(2)
    test_z_indices = np.random.choice(choices, size=length)
    # choices = range(62, 30000)
    # choices = range(30000, 32500)
    choices = list(range(10000, 12000)) + list(range(17000, 18244)) # 10% for test
    # 12000-10000+18244-17000=3244, 3244/(32500-62)=10%, 32500-62 is the all data pieces
    test_x_indices = np.random.choice(choices, size=length)

    length = 640 # this is a subset of train data for test on train
    # choices = range(3)
    choices = range(2)
    small_train_z_indices = np.random.choice(choices, size=length)
    choices = list(range(9000, 10000)) + list(range(12000, 17000)) + list(range(18244, 19000))
    small_train_x_indices = np.random.choice(choices, size=length)

    train_dataset = FluoDataset(data_set, label_set, train_z_indices, train_x_indices)
    test_dataset = FluoDataset(data_set, label_set, test_z_indices, test_x_indices)
    test_on_train_dataset = FluoDataset(data_set, label_set, small_train_z_indices, small_train_x_indices)

    # Create data loaders
    batch_size = 64
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    test_on_train_loader = DataLoader(test_on_train_dataset, batch_size=batch_size, shuffle=False)
    # print(len(train_loader))
    # print(len(test_loader))

    train_loss_list = []
    test_loss_list = []
    for i_ in range(trial_times): # try multiple times to get the best one
        model = FluoModel(data_set.shape[1])

        # Define loss function and optimizer
        criterion = nn.MSELoss()  # Mean Squared Error Loss for regression
        optimizer = optim.Adam(model.parameters(), lr=0.003)

        # optimizer = optim.Adam(model.parameters(), lr=0.004)
        # scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

        # Training loop
        num_epochs = 90
        print_interval = 10

        train_loss_list_ = []
        test_loss_list_ = []
        for epoch in range(num_epochs):
            model.train()
            for inputs, targets in train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                if weights_nonnegative:
                    if strong_penalty:
                        pnlty_coeff = 1 / 500
                    else:
                        pnlty_coeff = 1 / 1000
                    loss = criterion(outputs.squeeze(), targets) + \
                    non_neg_loss(model.fc_shared.weight) * pnlty_coeff + \
                    non_neg_loss(model.fc_reduce.weight) * pnlty_coeff
                else:
                    loss = criterion(outputs.squeeze(), targets)
                loss.backward()
                optimizer.step()
            # scheduler.step()

            # Note: hard-clamping the weights to be non-negative after each step made
            # training very hard to converge, so the soft penalty above is used instead:
            # if weights_nonnegative:
            #     with torch.no_grad():
            #         model.fc_reduce.weight.data.clamp_(min=0)
            #         model.fc_shared.weight.data.clamp_(min=0)

            # Testing on train and test sets
            model.eval()

            train_loss = 0.0
            with torch.no_grad():
                for inputs, targets in test_on_train_loader:
                    outputs = model(inputs)
                    train_loss += criterion(outputs.squeeze(), targets).item()

            # Average of the per-batch mean MSE -- the same normalization as the test
            # loss below, so the two curves share one scale (previously this divided by
            # the number of samples, making the train loss batch_size times too small).
            average_train_loss = train_loss / len(test_on_train_loader)
            train_loss_list_.append(average_train_loss)

            test_loss = 0.0
            with torch.no_grad():
                for inputs, targets in test_loader:
                    outputs = model(inputs)
                    test_loss += criterion(outputs.squeeze(), targets).item()

            average_test_loss = test_loss / len(test_loader)
            if i_ != 0: # at least one complete training -- let the 1st training be complete
                # Terminate trials that are unlikely to converge: at epoch 5 the test loss
                # must be below 90% of the mean of this trial's epochs 1-4.
                if epoch == 5 and average_test_loss >= 0.9 * np.mean(np.array(test_loss_list_)[1:]):
                    break
            test_loss_list_.append(average_test_loss)

            if epoch < 5 or (epoch + 1) % print_interval == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {average_train_loss:.6f} | Test Loss: {average_test_loss:.6f}")

        # Keep the first trial, then any complete trial with a lower final test loss.
        if i_ == 0:
            train_loss_list, test_loss_list = train_loss_list_, test_loss_list_
            model_final = model
        elif test_loss_list_[-1] < test_loss_list[-1] and len(train_loss_list_) == num_epochs:
            train_loss_list, test_loss_list = train_loss_list_, test_loss_list_
            model_final = model

    # plot results
    plt.figure(figsize=(6.4, 4.8))
    plt.plot(train_loss_list, color = 'skyblue', label='Train Loss', linewidth=2)
    plt.plot(test_loss_list, color = 'salmon', label='Test Loss', linewidth=2)
    plt.xlabel('Epoch Index', fontsize=16)
    plt.ylabel('MSE Loss of Normalized Data', fontsize=16)
    plt.tick_params(labelsize=14)
    plt.title(f"Train and Test Loss Curves\n({signal_trace_type})", fontsize=18)
    plt.legend(fontsize=14, facecolor='none')
    plt.show()

    list_of_train_loss_lists.append(train_loss_list)
    list_of_test_loss_lists.append(test_loss_list)

    if signal_trace_type == "Raw Signal Traces":
        model_path = f"./{cell_name}_model_with_raw.pth"
    elif signal_trace_type == "Mean Restored Signal Traces":
        model_path = f"./{cell_name}_model_with_mean_restored.pth"
    elif signal_trace_type == "Mean and Std Restored Signal Traces":
        model_path = f"./{cell_name}_model_with_mean_restored_std_restored.pth"

    with open(model_path, "wb") as f:
        torch.save(model_final.state_dict(), f)

    # Reload the saved weights so the plots below reflect exactly what was written to disk.
    model = FluoModel(data_set.shape[1])
    model.load_state_dict(torch.load(model_path))
    model.eval()  # Set the model in evaluation mode

    outputs_list = []
    labels_list = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs)
            outputs_list.append(outputs.numpy())
            labels_list.append(targets.numpy())

    outputs_array = np.concatenate(outputs_list)
    labels_array = np.concatenate(labels_list)

    # Sort labels and corresponding outputs
    sorted_indices = np.argsort(labels_array)
    sorted_labels = labels_array[sorted_indices]
    sorted_outputs = outputs_array[sorted_indices]

    # plot results
    plt.figure(figsize=(6.4, 4.8))
    plt.plot(sorted_labels, color = 'skyblue', label='True Labels', linewidth=2)
    plt.plot(sorted_outputs, color = 'salmon', label='Model Outputs', linewidth=2)
    plt.xlabel('Sample Index', fontsize=16)
    plt.ylabel('Normalized Value', fontsize=16)
    plt.tick_params(labelsize=14)
    plt.title(f"Comparison between True Labels and Model Outputs\n({signal_trace_type})", fontsize=18)
    plt.legend(fontsize=14, facecolor='none')
    plt.show()


    # Calculate the correlation coefficient
    correlation_coefficient = np.corrcoef(sorted_labels, np.transpose(sorted_outputs))[0, 1]

    # plot results
    plt.figure(figsize=(6.4, 4.8))
    plt.scatter(sorted_labels, sorted_outputs, color = 'salmon', label=f'Correlation Coefficient = {correlation_coefficient:.5f}')
    plt.xlabel('Normalized Value (Labels)', fontsize=16)
    plt.ylabel('Normalized Value (Outputs)', fontsize=16)
    plt.tick_params(labelsize=14)
    plt.title(f"Scatter Plot between True Labels and Model Outputs\n({signal_trace_type})", fontsize=18)
    plt.legend(fontsize=14, facecolor='none')
    plt.show()

### Plot comparison of losses

In [None]:
# Compare train losses, and separately test losses, between the different types of signal traces.

colors_list = [('deepskyblue', 'orangered'), ('royalblue', 'salmon'), ('cornflowerblue', 'violet')]
linestyle_list = ['dashed', 'dotted', 'solid']

# Short legend tag for each signal-trace type.
_short_trace_type_names = {
    "Raw Signal Traces": "(Raw)",
    "Mean Restored Signal Traces": "(Mean Restored)",
    "Mean and Std Restored Signal Traces": "(Mean and Std Restored)",
}


def _plot_loss_comparison(loss_lists, color_slot, loss_name, title, last_n=None):
    """Overlay one loss curve per signal-trace type on a new figure.

    Args:
        loss_lists: per-type loss lists, aligned with ``signal_trace_type_list``.
        color_slot: 0 picks the train color of each ``colors_list`` pair, 1 the test color.
        loss_name: legend prefix, e.g. 'Train Loss' or 'Test Loss'.
        title: figure title.
        last_n: if given, plot only the last ``last_n`` epochs, with 1-based epoch
            numbers on the x axis. Clipped to the curve length so that curves with
            fewer epochs (e.g. a shorter ``num_epochs``) do not raise.
    """
    plt.figure(figsize=(6.4, 4.8))
    for trace_type, losses, colors, linestyle in zip(
        signal_trace_type_list, loss_lists, colors_list, linestyle_list
    ):
        label = f'{loss_name} {_short_trace_type_names[trace_type]}'
        if last_n is None:
            plt.plot(losses, color = colors[color_slot], linestyle = linestyle, label=label, linewidth=2)
        else:
            n_shown = min(last_n, len(losses))
            epochs = np.arange(len(losses) + 1 - n_shown, len(losses) + 1)
            plt.plot(epochs, losses[len(losses) - n_shown:], color = colors[color_slot],
                     linestyle = linestyle, label=label, linewidth=2)
    plt.xlabel('Epoch Index', fontsize=16)
    plt.ylabel('MSE Loss of Normalized Data', fontsize=16)
    plt.tick_params(labelsize=14)
    plt.title(title, fontsize=18)
    plt.legend(fontsize=14, facecolor='none')
    plt.show()


# all epochs
_plot_loss_comparison(list_of_train_loss_lists, 0, 'Train Loss',
                      "Train Loss Comparison Between Different Datasets")
_plot_loss_comparison(list_of_test_loss_lists, 1, 'Test Loss',
                      "Test Loss Comparison Between Different Datasets")


# plot last epochs
_plot_loss_comparison(list_of_train_loss_lists, 0, 'Train Loss',
                      "Train Loss Comparison Between Different Datasets", last_n=50)
_plot_loss_comparison(list_of_test_loss_lists, 1, 'Test Loss',
                      "Test Loss Comparison Between Different Datasets", last_n=50)

### Load model weights and replot (=== set weight redistribution test here ===)

The next chunk plots "Comparison between True Labels and Model Outputs" and "Scatter Plot between True Labels and Model Outputs", plus several intermediate-layer figures.

No model is trained here; the saved model weights are loaded directly (and, if `redistribute = True`, some of them are shuffled to test robustness).

In [None]:
## Weight-redistribution robustness test switches.
## redistribute: shuffle some fc_reduce weights after loading each model.
## strong_or_weak_redistribute: True = strong, False = weak; ignored when redistribute is False.
redistribute = False
strong_or_weak_redistribute = False

# Signal-trace variants evaluated by the loop below. Replace with a shorter list to
# restrict the run, e.g. ["Mean Restored Signal Traces", "Mean and Std Restored Signal Traces"]
# or a single variant such as ["Raw Signal Traces"].
signal_trace_type_list = [
    "Raw Signal Traces",
    "Mean Restored Signal Traces",
    "Mean and Std Restored Signal Traces",
]

# Evaluate saved models without training. For each signal-trace variant:
# - rebuild the normalized dataset and the random window indices;
# - load that variant's saved weights, optionally shuffling some fc_reduce weights
#   as a robustness probe;
# - run the model on the test windows;
# - plot intermediate-layer outputs plus output-vs-label figures.
# Relies on globals from earlier cells: the green/red signal-trace arrays,
# valid_com_index_list, cell_name and with_sigmoid.


class FluoDataset(Dataset):
    """Windowed samples drawn from 3-D arrays indexed as [run, component, frame].

    Sample ``idx`` uses run ``z = z_indices[idx]`` and end frame ``x = x_indices[idx]``.
    The input is ``data_set[z, :, x-62:x]``, i.e. 62 frames for every component.
    The target is the scalar mean of ``label_set[z, :, x-31:x]``, the last 31 frames.
    """

    def __init__(self, data_set, label_set, z_indices, x_indices):
        self.data_set = data_set
        self.label_set = label_set
        self.z_indices = z_indices
        self.x_indices = x_indices

    def __len__(self):
        return len(self.z_indices)

    def __getitem__(self, idx):
        z = self.z_indices[idx]
        x = self.x_indices[idx]

        input_data = self.data_set[z, :, x-62:x]
        target_label = np.mean(self.label_set[z, :, x-31:x])

        return torch.tensor(input_data, dtype=torch.float32), torch.tensor(target_label, dtype=torch.float32)


class FluoModel(nn.Module):
    """Linear read-out: a shared 62->1 layer per component, then component_num->1.

    If the global ``with_sigmoid`` is True, the reduced value passes through a sigmoid
    and a final 1->1 linear layer; otherwise the reduced value is the output.
    """

    def __init__(self, component_num):
        super(FluoModel, self).__init__()
        self.fc_shared = nn.Linear(62, 1)  # Shared fully connected layer over each component's 62 frames
        self.fc_reduce = nn.Linear(component_num, 1)  # Reduces component_num (e.g., 281) channels to 1
        self.sigmoid = nn.Sigmoid()
        self.fc_end = nn.Linear(1, 1)  # only used when with_sigmoid is True

    def forward(self, x):
        x = x.view(x.size(0), -1, 62)  # Reshape to (batch_size, component_num, 62)
        # nn.Linear acts on the last dimension: (B, C, 62) -> (B, C, 1); squeeze -> (B, C)
        shared_output = self.fc_shared(x).squeeze(2)
        reduced_output = self.fc_reduce(shared_output)  # (B, 1)
        if with_sigmoid:
            return self.fc_end(self.sigmoid(reduced_output))
        return reduced_output

    def get_layer_output(self, x, layer_name):
        """Return the activation after ``layer_name``: "shared", "reduce", "sigmoid" or "final".

        The input is reshaped exactly as in ``forward`` for every layer, so all layer
        names accept the same input layout. Raises ValueError for an unknown name.
        """
        if layer_name not in ("shared", "reduce", "sigmoid", "final"):
            raise ValueError("Invalid layer name")
        x = x.view(x.size(0), -1, 62)
        shared_output = self.fc_shared(x).squeeze(2)
        if layer_name == "shared":
            return shared_output
        reduced_output = self.fc_reduce(shared_output)
        if layer_name == "reduce":
            return reduced_output
        if layer_name == "sigmoid":
            return self.sigmoid(reduced_output)
        # "final": same computation as forward()
        if with_sigmoid:
            return self.fc_end(self.sigmoid(reduced_output))
        return reduced_output


for signal_trace_type in signal_trace_type_list:

    # Select the input/label arrays and the matching saved-weight file for this variant.
    if signal_trace_type == "Raw Signal Traces":
        data_set = copy.deepcopy(green_signal_traces)
        label_set = copy.deepcopy(red_signal_traces)
        model_path = f"./{cell_name}_model_with_raw.pth"
    elif signal_trace_type == "Mean Restored Signal Traces":
        data_set = copy.deepcopy(green_signal_traces_mean_restored)
        label_set = copy.deepcopy(red_signal_traces_mean_restored)
        model_path = f"./{cell_name}_model_with_mean_restored.pth"
    elif signal_trace_type == "Mean and Std Restored Signal Traces":
        data_set = copy.deepcopy(green_signal_traces_mean_restored_std_restored)
        label_set = copy.deepcopy(red_signal_traces_mean_restored_std_restored)
        model_path = f"./{cell_name}_model_with_mean_restored_std_restored.pth"
    else:
        # Without this, an unknown name would silently reuse the previous iteration's data.
        raise ValueError(f"Unknown signal_trace_type: {signal_trace_type!r}")

    data_set = data_set[:, valid_com_index_list, :]  # use valid components

    print(f"data_set.shape: {data_set.shape}")
    print(f"label_set.shape: {label_set.shape}")

    # Min-max normalize inputs and labels separately to [0, 1] using each array's global min/max.
    # (A [-1, 1] variant was used before: ((v - v_min) / (v_max - v_min) - 0.5) * 2.)
    data_set_min = np.min(data_set)
    data_set_max = np.max(data_set)
    data_set = (data_set - data_set_min) / (data_set_max - data_set_min)

    label_set_min = np.min(label_set)
    label_set_max = np.max(label_set)
    label_set = (label_set - label_set_min) / (label_set_max - label_set_min)

    # Set a random seed for reproducibility (disabled, so the sampled windows differ per run):
    # np.random.seed(16)

    # Random (run, end-frame) pairs. Only runs 0 and 1 are used (the last run is deleted).
    # End frames start at 62 so that a full 62-frame input window exists.
    # The train and test end-frame ranges do not overlap (~90% / ~10% of the 32500-62
    # possible end frames). Input windows near the range boundaries can still share frames.
    length = 6400*2
    choices = range(2)  # delete last run
    train_z_indices = np.random.choice(choices, size=length)
    choices = list(range(62, 10000)) + list(range(12000, 17000)) + list(range(18244, 32500))  # 90% for train
    train_x_indices = np.random.choice(choices, size=length)

    length = 640
    choices = range(2)  # delete last run
    test_z_indices = np.random.choice(choices, size=length)
    choices = list(range(10000, 12000)) + list(range(17000, 18244))  # 10% for test
    test_x_indices = np.random.choice(choices, size=length)

    length = 640  # this is a subset of train data for test on train
    choices = range(2)  # must match the training runs: the last run is never trained on
    small_train_z_indices = np.random.choice(choices, size=length)
    choices = list(range(9000, 10000)) + list(range(12000, 17000)) + list(range(18244, 19000))
    small_train_x_indices = np.random.choice(choices, size=length)

    train_dataset = FluoDataset(data_set, label_set, train_z_indices, train_x_indices)
    test_dataset = FluoDataset(data_set, label_set, test_z_indices, test_x_indices)
    test_on_train_dataset = FluoDataset(data_set, label_set, small_train_z_indices, small_train_x_indices)

    # Create data loaders
    batch_size = 64
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    test_on_train_loader = DataLoader(test_on_train_dataset, batch_size=batch_size, shuffle=False)

    model = FluoModel(data_set.shape[1])
    # The model is only used on CPU here (outputs are converted with .numpy()), so load
    # the weights onto the CPU even if they were saved from a GPU.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()  # Set the model in evaluation mode

    if redistribute:
        # Robustness probe: randomly swap fc_reduce weights among the components of the
        # first axon group (30 swaps if strong, 10 if weak).
        root_path = "/home/chenhuimiao/dLGN_Neuron_Modeling/Fluorescence_Data/FluoData4Fitting_Average"
        path_ = os.path.join(root_path, cell_name, cell_name + 'green_Axon.mat')
        mat_data = scipy.io.loadmat(path_)
        axons = np.squeeze(mat_data['Axons'], axis=0)  # drop the leading singleton axis
        for i in range(len(axons)):
            # Squeeze each inner array to 1-D and convert it to int (1-based component ids)
            axons[i] = np.squeeze(axons[i].astype(int), axis=0)
        # Position of each 1-based valid component id along the model's component axis.
        position_of_component = {int(com) + 1: pos for pos, com in enumerate(valid_com_index_list)}
        index_form_axons = copy.deepcopy(axons)
        for i, axon in enumerate(axons):
            for j, bouton in enumerate(axon):
                if int(bouton) not in position_of_component:
                    raise ValueError(f"Bouton {bouton} of axon group {i + 1} is not a valid component")
                index_form_axons[i][j] = position_of_component[int(bouton)]
        selected_coms = index_form_axons[0]
        num_ = 30 if strong_or_weak_redistribute else 10
        reduce_weights = model.fc_reduce.weight.data
        for _ in range(num_):
            index_1 = random.choice(selected_coms)
            index_2 = random.choice(selected_coms)
            # clone(): integer indexing returns a 0-d view of the weight storage. Without it,
            # temp_w would be overwritten by the next line and the "swap" would merely
            # copy weight index_2 over index_1.
            temp_w = reduce_weights[0, index_1].clone()
            reduce_weights[0, index_1] = reduce_weights[0, index_2]
            reduce_weights[0, index_2] = temp_w

    inputs_list = []
    outputs_list = []
    labels_list = []
    outputs_of_shared_list = []
    outputs_of_reduce_list = []
    outputs_of_sigmoid_list = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs)
            outputs_of_shared = model.get_layer_output(inputs, "shared")
            outputs_of_shared_list.append(outputs_of_shared.numpy())
            outputs_of_reduce = model.get_layer_output(inputs, "reduce")
            outputs_of_reduce_list.append(outputs_of_reduce.numpy())
            if with_sigmoid:
                outputs_of_sigmoid = model.get_layer_output(inputs, "sigmoid")
                outputs_of_sigmoid_list.append(outputs_of_sigmoid.numpy())
            inputs_list.append(inputs.numpy())
            outputs_list.append(outputs.numpy())
            labels_list.append(targets.numpy())
    print("----")
    print(f"inputs shape: {inputs_list[0].shape}")
    print(f"outputs_of_shared shape: {outputs_of_shared_list[0].shape}")
    print(f"outputs_of_reduce shape: {outputs_of_reduce_list[0].shape}")
    if with_sigmoid:
        print(f"outputs_of_sigmoid shape: {outputs_of_sigmoid_list[0].shape}")
    print(f"outputs shape: {outputs_list[0].shape}")
    print(f"labels shape: {labels_list[0].shape}")
    print("----")

    # Plot some input samples from the first test batch together with their layer outputs.
    # Colour limits are shared across samples (rounded outward to 0.1) so the maps are comparable.
    n_sample = 5
    inputs = inputs_list[0][:n_sample]
    upper_b_inputs = np.ceil(np.max(inputs) * 10) / 10
    lower_b_inputs = np.floor(np.min(inputs) * 10) / 10
    outputs_of_shared = outputs_of_shared_list[0][:n_sample]
    upper_b_shared = np.ceil(np.max(outputs_of_shared) * 10) / 10
    lower_b_shared = np.floor(np.min(outputs_of_shared) * 10) / 10
    outputs_of_reduce = outputs_of_reduce_list[0][:n_sample]
    if with_sigmoid:
        outputs_of_sigmoid = outputs_of_sigmoid_list[0][:n_sample]
    outputs = outputs_list[0][:n_sample]
    labels = labels_list[0][:n_sample]
    print(inputs.shape, outputs_of_shared.shape, outputs_of_reduce.shape)
    for i in range(n_sample):
        input_, share_out_, reduce_out_ = inputs[i], outputs_of_shared[i], outputs_of_reduce[i]
        if with_sigmoid:
            sigmoid_out_ = outputs_of_sigmoid[i]
        output_, label_ = outputs[i], labels[i]

        fig, ax = plt.subplots(figsize=(6.4*1.5, 4.8*1.5))
        cax = ax.imshow(input_, cmap='cool', vmin=lower_b_inputs, vmax=upper_b_inputs)
        colorbar = fig.colorbar(cax)
        colorbar.set_label('Colorbar Label', fontsize=14)
        ax.tick_params(axis='x', labelsize=14)
        ax.tick_params(axis='y', labelsize=14)
        plt.xlabel('Time frames', fontsize=16)
        plt.ylabel('Components', fontsize=16)
        ax.set_title(f"Input Value Map\n(output: {output_[0]:.5f}, label: {label_:.5f})\n({signal_trace_type})", fontsize=18)
        plt.show()

        share_out_ = share_out_[:, np.newaxis]
        # Horizontally repeat the column so it is wide enough to see in the figure
        share_out_ = np.hstack([share_out_, share_out_, share_out_, share_out_, share_out_])
        fig, ax = plt.subplots(figsize=(6.4*1.5, 4.8*1.5))
        cax = ax.imshow(share_out_, cmap='cool', vmin=lower_b_shared, vmax=upper_b_shared)
        colorbar = fig.colorbar(cax)
        colorbar.set_label('Colorbar Label', fontsize=14)
        ax.tick_params(axis='y', labelsize=14)
        plt.ylabel('Components', fontsize=16)
        # Hide x-axis ticks and labels
        ax.set_xticks([])
        ax.set_xticklabels([])
        ax.set_title(f"Shared Layer Output Value Map\n(output: {output_[0]:.5f}, label: {label_:.5f})\n({signal_trace_type})", fontsize=18)
        plt.show()

        print(f"Output of the reduce layer is {reduce_out_[0]:.5f}")
        if with_sigmoid:
            print(f"Output of the sigmoid is {sigmoid_out_[0]:.5f}")
        print((f"Output of the final layer is {output_[0]:.5f}"))
        print((f"Label is {label_:.5f}"))

    outputs_array = np.concatenate(outputs_list)
    labels_array = np.concatenate(labels_list)
    outputs_of_reduce_array = np.concatenate(outputs_of_reduce_list)

    if with_sigmoid:
        # plot the values of reduce layer outputs (before feeding into sigmoid)
        plt.figure(figsize=(6.4, 4.8))
        plt.scatter(range(outputs_of_reduce_array.shape[0]), outputs_of_reduce_array, color = 'salmon')
        plt.xlabel('Sample Index', fontsize=16)
        plt.ylabel('Value', fontsize=16)
        plt.title(f"Outputs of reduce layer (before feeding into sigmoid)\n({signal_trace_type})", fontsize=18)
        plt.show()
        def sigmoid(x):
            return 1 / (1 + np.exp(-x))
        x_ = np.linspace(-5, 5, 100)
        y_ = sigmoid(x_)
        plt.figure(figsize=(6.4, 4.8))
        plt.plot(x_, y_, color = 'skyblue', label='Sigmoid Function', linewidth=2)
        plt.scatter(outputs_of_reduce_array, np.zeros_like(outputs_of_reduce_array), label='Inputs to Sigmoid', color = 'salmon')
        plt.xlabel('Input Value', fontsize=16)
        plt.ylabel('Output Value', fontsize=16)
        plt.title(f"Outputs of reduce layer on Sigmoid\n({signal_trace_type})", fontsize=18)
        plt.legend(fontsize=14, facecolor='none')
        plt.show()
    else:
        # plot the values of reduce layer outputs (final outputs)
        plt.figure(figsize=(6.4, 4.8))
        plt.scatter(range(outputs_of_reduce_array.shape[0]), outputs_of_reduce_array, color = 'salmon')
        plt.xlabel('Sample Index', fontsize=16)
        plt.ylabel('Value', fontsize=16)
        plt.title(f"Outputs of reduce layer (final outputs)\n({signal_trace_type})", fontsize=18)
        plt.show()

    # Sort labels and corresponding outputs
    sorted_indices = np.argsort(labels_array)
    sorted_labels = labels_array[sorted_indices]
    sorted_outputs = outputs_array[sorted_indices]

    # plot results
    plt.figure(figsize=(6.4, 4.8))
    plt.plot(sorted_labels, color = 'skyblue', label='True Labels', linewidth=2)
    plt.plot(sorted_outputs, color = 'salmon', label='Model Outputs', linewidth=2)
    plt.xlabel('Sample Index', fontsize=16)
    plt.ylabel('Normalized Value', fontsize=16)
    plt.tick_params(labelsize=14)
    plt.title(f"Comparison between True Labels and Model Outputs\n({signal_trace_type})", fontsize=18)
    plt.legend(fontsize=14, facecolor='none')
    plt.show()

    # Pearson correlation between labels (N,) and outputs (N, 1), flattened to 1-D
    correlation_coefficient = np.corrcoef(sorted_labels, sorted_outputs.ravel())[0, 1]

    # plot results
    plt.figure(figsize=(6.4, 4.8))
    plt.scatter(sorted_labels, sorted_outputs, color = 'salmon', label=f'Correlation Coefficient = {correlation_coefficient:.5f}')
    plt.xlabel('Normalized Value (Labels)', fontsize=16)
    plt.ylabel('Normalized Value (Outputs)', fontsize=16)
    plt.tick_params(labelsize=14)
    plt.title(f"Scatter Plot between True Labels and Model Outputs\n({signal_trace_type})", fontsize=18)
    plt.legend(fontsize=14, facecolor='none')
    plt.show()


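## Below is the original (incorrect) version, kept for reference only. It reuses `data_set` and `test_loader` left over from an earlier cell instead of rebuilding them for each signal-trace type.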

# signal_trace_type_list = ["Raw Signal Traces",
#                           "Mean Restored Signal Traces",
#                           "Mean and Std Restored Signal Traces"]

# # signal_trace_type_list = ["Mean Restored Signal Traces",
# #                           "Mean and Std Restored Signal Traces"]

# # signal_trace_type_list = ["Raw Signal Traces"]
# # signal_trace_type_list = ["Mean Restored Signal Traces"]
# # signal_trace_type_list = ["Mean and Std Restored Signal Traces"]

# for signal_trace_type in signal_trace_type_list:
#     if signal_trace_type == "Raw Signal Traces":
#         model_path = f"./{cell_name}_model_with_raw.pth"
#     elif signal_trace_type == "Mean Restored Signal Traces":
#         model_path = f"./{cell_name}_model_with_mean_restored.pth"
#     elif signal_trace_type == "Mean and Std Restored Signal Traces":
#         model_path = f"./{cell_name}_model_with_mean_restored_std_restored.pth"

#     model = FluoModel(data_set.shape[1])
#     model.load_state_dict(torch.load(model_path))
#     model.eval()  # Set the model in evaluation mode

#     outputs_list = []
#     labels_list = []

#     with torch.no_grad():
#         for inputs, targets in test_loader:
#             outputs = model(inputs)
#             outputs_list.append(outputs.numpy())
#             labels_list.append(targets.numpy())

#     outputs_array = np.concatenate(outputs_list)
#     labels_array = np.concatenate(labels_list)

#     # Sort labels and corresponding outputs
#     sorted_indices = np.argsort(labels_array)
#     sorted_labels = labels_array[sorted_indices]
#     sorted_outputs = outputs_array[sorted_indices]

#     # Calculate the correlation coefficient
#     correlation_coefficient = np.corrcoef(sorted_labels, np.transpose(sorted_outputs))[0, 1]

#     # plot results
#     plt.figure(figsize=(6.4, 4.8))
#     plt.scatter(sorted_labels, sorted_outputs, color = 'salmon', label=f'Correlation Coefficient = {correlation_coefficient:.5f}')
#     plt.xlabel('Normalized Value (Labels)', fontsize=16)
#     plt.ylabel('Normalized Value (Outputs)', fontsize=16)
#     plt.tick_params(labelsize=14)
#     plt.title(f"Scatter Plot between True Labels and Model Outputs\n({signal_trace_type})", fontsize=18)
#     plt.legend(fontsize=14, facecolor='none')
#     plt.show()

### Load model weights and inspect them

#### Plots with respect to absolute weights

In [None]:
signal_trace_type_list = ["Raw Signal Traces", "Mean Restored Signal Traces", "Mean and Std Restored Signal Traces"]
# signal_trace_type_list = ["Mean Restored Signal Traces", "Mean and Std Restored Signal Traces"]
# signal_trace_type_list = ["Mean and Std Restored Signal Traces"]

# Load this cell's axon grouping and convert each group's 1-based component ids into
# positions along the valid-component axis used by the models.
root_path = "/home/chenhuimiao/dLGN_Neuron_Modeling/Fluorescence_Data/FluoData4Fitting_Average"
path_ = os.path.join(root_path, cell_name, cell_name + 'green_Axon.mat')
mat_data = scipy.io.loadmat(path_)
axons = mat_data['Axons']
# Squeeze the outer (leading singleton) axis
axons = np.squeeze(axons, axis=0)
for i in range(len(axons)):
    # Squeeze the inner array to 1-D and convert the data type to 'int'
    axons[i] = np.squeeze(axons[i].astype(int), axis=0)

valid_com_index_from_one_list = np.array(valid_com_index_list) + 1

# O(1) lookup: 1-based component id -> position among the remaining (valid) components.
position_of_component = {int(com): pos for pos, com in enumerate(valid_com_index_from_one_list)}

# Convert axons to index form, consistent with the indices of the remaining components.
# Store a plain int per bouton (np.where would yield an array) and fail clearly if a
# bouton is not among the valid components.
index_form_axons = copy.deepcopy(axons)
for i, axon in enumerate(axons):
    for j, bouton in enumerate(axon):
        if int(bouton) not in position_of_component:
            raise ValueError(f"Bouton {bouton} of axon group {i + 1} is not a valid component")
        index_form_axons[i][j] = position_of_component[int(bouton)]
flat_index_form_axons = np.concatenate(index_form_axons)


for signal_trace_type in signal_trace_type_list:

    model = FluoModel(data_set.shape[1])

    if signal_trace_type == "Raw Signal Traces":
        model_path = f"./{cell_name}_model_with_raw.pth"
    elif signal_trace_type == "Mean Restored Signal Traces":
        model_path = f"./{cell_name}_model_with_mean_restored.pth"
    elif signal_trace_type == "Mean and Std Restored Signal Traces":
        model_path = f"./{cell_name}_model_with_mean_restored_std_restored.pth"

    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)

    # access the weights for each layer
    fc_shared_weights = model.fc_shared.weight
    fc_reduce_weights = model.fc_reduce.weight
    fc_end_weights = model.fc_end.weight

    # access the biases if needed
    fc_shared_biases = model.fc_shared.bias
    fc_reduce_biases = model.fc_reduce.bias
    fc_end_biases = model.fc_end.bias

    print(f"{signal_trace_type}")

    # Weights
    print("\nfc_shared weights:")
    print(fc_shared_weights)

    print("\nfc_reduce weights:")
    print(fc_reduce_weights)

    print("\nfc_end weights (if without sigmoid, the end layer is not used, then its weight and bias are random):")
    print(fc_end_weights)

    # Biases
    print("\nfc_shared biases:")
    print(fc_shared_biases)

    print("\nfc_reduce biases:")
    print(fc_reduce_biases)

    print("\nfc_end biases (if without sigmoid, the end layer is not used, then its weight and bias are random):")
    print(fc_end_biases)

    # Access the weights
    fc_shared_weights = fc_shared_weights.detach().numpy().flatten()
    fc_reduce_weights = fc_reduce_weights.detach().numpy().flatten()



    ##### Create subplots (weights abs)
    fig, ax1 = plt.subplots(figsize=(6.4, 4.8))

    # Plot fc_shared_weights
    ax1.bar(range(1,len(fc_shared_weights)+1), np.abs(fc_shared_weights), color='skyblue')
    ax1.set_title(f'fc_shared_weights, corresponding to frames\n({signal_trace_type})', fontsize=18)
    ax1.set_xlabel('Weight Index/Frame', fontsize=16)
    ax1.set_ylabel('Absolute Weight Value', fontsize=16)
    ax1.tick_params(axis='x', labelsize=14)
    ax1.tick_params(axis='y', labelsize=14)

    plt.tight_layout()
    plt.show()

    half_len = int(len(fc_shared_weights)/2)
    print(f"The mean the first {half_len} weight absolute values of fc_shared_weights is {np.mean(np.abs(fc_shared_weights)[:half_len])}")
    print(f"The mean the last {half_len} weight absolute values of fc_shared_weights is {np.mean(np.abs(fc_shared_weights)[half_len:])}")


    # Create subplots (1 y axis)
    fig, ax1 = plt.subplots(figsize=(6.4*2, 4.8))

    # Plot fc_reduce_weights
    ax1.bar(range(1,len(fc_reduce_weights)+1), np.abs(fc_reduce_weights), color='salmon')
    ax1.set_title(f'fc_reduce_weights, corresponding to components\n({signal_trace_type})', fontsize=18)
    ax1.set_xlabel('Weight Index/Component', fontsize=16)
    ax1.set_ylabel('Absolute Weight Value', fontsize=16)
    ax1.tick_params(axis='x', labelsize=14)
    ax1.tick_params(axis='y', labelsize=14)

    plt.tight_layout()
    plt.show()



    ##### Create subplots (2 y axes, weights abs and distance)
    fig, ax1 = plt.subplots(figsize=(6.4*2, 4.8))
    x = np.arange(1, len(fc_reduce_weights) + 1)

    # Plot fc_reduce_weights
    ax1.bar(x, np.abs(fc_reduce_weights), color='salmon', label = 'Weights')
    ax1.set_title(f'fc_reduce_weights and Distance, corresponding to components\n({signal_trace_type})', fontsize=18)
    ax1.set_xlabel('Weight Index/Component', fontsize=16)
    ax1.set_ylabel('Absolute Weight Value', fontsize=16)
    ax1.tick_params(axis='x', labelsize=14)
    ax1.tick_params(axis='y', labelsize=14)

    # Create a second y-axis
    ax2 = ax1.twinx()

    # Plot the second curve
    ax2.plot(x, valid_dis_list, color='skyblue', label='Distance to Soma')
    ax2.set_ylabel('Distance to Soma', fontsize=16)
    ax2.tick_params(axis='y', labelsize=14)

    # Adding legend for the second curve
    ax1.legend(loc='upper left', fontsize=14)
    ax2.legend(loc='upper right', fontsize=14)

    plt.tight_layout()
    plt.show()



    ##### Create subplots (correlation between weights abs and distance)
    fig, ax1 = plt.subplots(figsize=(6.4, 4.8))

    # Calculate the correlation coefficient
    sorted_indices = np.argsort(np.abs(fc_reduce_weights))
    sorted_valid_dis_list = np.array(valid_dis_list)[sorted_indices]
    sorted_abs_fc_reduce_weights = np.abs(fc_reduce_weights)[sorted_indices]
    correlation_coefficient = np.corrcoef(sorted_abs_fc_reduce_weights, np.transpose(sorted_valid_dis_list))[0, 1]

    # plot results
    ax1.scatter(sorted_abs_fc_reduce_weights, sorted_valid_dis_list, color = 'salmon', label=f'Correlation Coefficient = {correlation_coefficient:.5f}')
    ax1.set_xlabel('Weights', fontsize=16)
    ax1.set_ylabel('Distance to Soma (Pixels)', fontsize=16)
    ax1.tick_params(labelsize=14)
    ax1.set_title(f"Scatter Plot between Weights and Distance to Soma\n({signal_trace_type})", fontsize=18)
    ax1.legend(fontsize=14, facecolor='none')
    plt.show()



    ##### Create subplots (correlation between weights abs and 1/distance)
    fig, ax1 = plt.subplots(figsize=(6.4, 4.8))

    # Calculate the correlation coefficient
    sorted_indices = np.argsort(np.abs(fc_reduce_weights))
    sorted_valid_dis_list = np.array(valid_dis_list)[sorted_indices]
    sorted_abs_fc_reduce_weights = np.abs(fc_reduce_weights)[sorted_indices]
    correlation_coefficient = np.corrcoef(sorted_abs_fc_reduce_weights, np.transpose(1/sorted_valid_dis_list))[0, 1]

    # plot results
    ax1.scatter(sorted_abs_fc_reduce_weights, 1/sorted_valid_dis_list, color = 'salmon', label=f'Correlation Coefficient = {correlation_coefficient:.5f}')
    ax1.set_xlabel('Weights', fontsize=16)
    ax1.set_ylabel('1/Distance to Soma (Pixels)', fontsize=16)
    ax1.tick_params(labelsize=14)
    ax1.set_title(f"Scatter Plot between Weights and 1/Distance to Soma\n({signal_trace_type})", fontsize=18)
    ax1.legend(fontsize=14, facecolor='none')
    plt.show()



    ##### Create subplots (2 y axes, weights abs and distance, grouped)
    fig, ax1 = plt.subplots(figsize=(6.4*2, 4.8))

    # Reorder fc_reduce_weights and valid_dis_list based on group_indices
    reordered_fc_reduce_weights = [fc_reduce_weights[idx] for idx in flat_index_form_axons]
    reordered_valid_dis_list = [valid_dis_list[idx] for idx in flat_index_form_axons]

    # Create x-labels for groups
    x_labels = [f"Group {i+1}" for i in range(len(index_form_axons))]

    # Bar plot for fc_reduce_weights
    x = np.arange(1, len(flat_index_form_axons) + 1)
    index__ = 0
    for axon in index_form_axons:
        index__ = index__ + len(axon)
        if index__ < len(flat_index_form_axons):
            x[index__:] = x[index__:] + 5 # generate gap space on x axis between groups

    break_indices = np.where(np.diff(x) != 1)[0]
    x_labels_center = [(x[start] + x[end]) / 2 for start, end in zip(np.insert(break_indices + 1, 0, 0), break_indices)]
    x_labels_center.append((x[break_indices[-1]+1] + x[-1])/2)

    ax1.bar(x, np.abs(reordered_fc_reduce_weights), color='salmon', label='Weights')
    ax1.set_title(f'fc_reduce_weights and Distance, corresponding to components, grouped\n({signal_trace_type})', fontsize=18)
    ax1.set_xlabel('Groups', fontsize=16)
    ax1.set_ylabel('Absolute Weight Value', fontsize=16)
    ax1.set_xticks(x_labels_center)
    ax1.set_xticklabels(x_labels, rotation=45, fontsize=14)
    ax1.tick_params(axis='y', labelsize=14)

    # Create a second y-axis
    ax2 = ax1.twinx()

    # Plot the second curve (valid_dis_list)
    ax2.plot(x, reordered_valid_dis_list, color='skyblue', label='Distance to Soma')
    ax2.set_ylabel('Distance to Soma', fontsize=16)
    ax2.tick_params(axis='y', labelsize=14)

    # Adding legend for both curves
    ax1.legend(loc='upper left', fontsize=14)
    ax2.legend(loc='upper right', fontsize=14)

    plt.tight_layout()
    plt.show()



    ##### Create subplots (2 y axes, weights abs and size)
    fig, ax1 = plt.subplots(figsize=(6.4*2, 4.8))
    x = np.arange(1, len(fc_reduce_weights) + 1)

    # Plot fc_reduce_weights
    ax1.bar(x, np.abs(fc_reduce_weights), color='salmon', label = 'Weights')
    ax1.set_title(f'fc_reduce_weights and Size, corresponding to components\n({signal_trace_type})', fontsize=18)
    ax1.set_xlabel('Weight Index/Component', fontsize=16)
    ax1.set_ylabel('Absolute Weight Value', fontsize=16)
    ax1.tick_params(axis='x', labelsize=14)
    ax1.tick_params(axis='y', labelsize=14)

    # Create a second y-axis
    ax2 = ax1.twinx()

    # Plot the second curve
    ax2.plot(x, valid_size_list, color='skyblue', label='Size/Pixels')
    ax2.set_ylabel('Size/Pixels', fontsize=16)
    ax2.tick_params(axis='y', labelsize=14)

    # Adding legend for the second curve
    ax1.legend(loc='upper left', fontsize=14)
    ax2.legend(loc='upper right', fontsize=14)

    plt.tight_layout()
    plt.show()



    ##### Create subplots (correlation between weights abs and size)
    fig, ax1 = plt.subplots(figsize=(6.4, 4.8))

    # Calculate the correlation coefficient
    sorted_indices = np.argsort(np.abs(fc_reduce_weights))
    sorted_valid_size_list = np.array(valid_size_list)[sorted_indices]
    sorted_abs_fc_reduce_weights = np.abs(fc_reduce_weights)[sorted_indices]
    correlation_coefficient = np.corrcoef(sorted_abs_fc_reduce_weights, np.transpose(sorted_valid_size_list))[0, 1]

    # plot results
    ax1.scatter(sorted_abs_fc_reduce_weights, sorted_valid_size_list, color = 'salmon', label=f'Correlation Coefficient = {correlation_coefficient:.5f}')
    ax1.set_xlabel('Weights', fontsize=16)
    ax1.set_ylabel('Size (Pixel Number)', fontsize=16)
    ax1.tick_params(labelsize=14)
    ax1.set_title(f"Scatter Plot between Weights and Size\n({signal_trace_type})", fontsize=18)
    ax1.legend(fontsize=14, facecolor='none')
    plt.show()



    ##### Create subplots (correlation between weights abs and 1/size)
    fig, ax1 = plt.subplots(figsize=(6.4, 4.8))

    # Calculate the correlation coefficient
    sorted_indices = np.argsort(np.abs(fc_reduce_weights))
    sorted_valid_size_list = np.array(valid_size_list)[sorted_indices]
    sorted_abs_fc_reduce_weights = np.abs(fc_reduce_weights)[sorted_indices]
    correlation_coefficient = np.corrcoef(sorted_abs_fc_reduce_weights, np.transpose(1/sorted_valid_size_list))[0, 1]

    # plot results
    ax1.scatter(sorted_abs_fc_reduce_weights, 1/sorted_valid_size_list, color = 'salmon', label=f'Correlation Coefficient = {correlation_coefficient:.5f}')
    ax1.set_xlabel('Weights', fontsize=16)
    ax1.set_ylabel('1/Size (Pixel Number)', fontsize=16)
    ax1.tick_params(labelsize=14)
    ax1.set_title(f"Scatter Plot between Weights and 1/Size\n({signal_trace_type})", fontsize=18)
    ax1.legend(fontsize=14, facecolor='none')
    plt.show()



    ##### Create subplots (2 y axes, weight abs and size, grouped)
    fig, ax1 = plt.subplots(figsize=(6.4*2, 4.8))

    # Reorder fc_reduce_weights and valid_dis_list based on group_indices
    reordered_fc_reduce_weights = [fc_reduce_weights[idx] for idx in flat_index_form_axons]
    reordered_valid_size_list = [valid_size_list[idx] for idx in flat_index_form_axons]

    # Create x-labels for groups
    x_labels = [f"Group {i+1}" for i in range(len(index_form_axons))]

    # Bar plot for fc_reduce_weights
    x = np.arange(1, len(flat_index_form_axons) + 1)
    index__ = 0
    for axon in index_form_axons:
        index__ = index__ + len(axon)
        if index__ < len(flat_index_form_axons):
            x[index__:] = x[index__:] + 5

    break_indices = np.where(np.diff(x) != 1)[0]
    x_labels_center = [(x[start] + x[end]) / 2 for start, end in zip(np.insert(break_indices + 1, 0, 0), break_indices)]
    x_labels_center.append((x[break_indices[-1]+1] + x[-1])/2)

    ax1.bar(x, np.abs(reordered_fc_reduce_weights), color='salmon', label='Weights')
    ax1.set_title(f'fc_reduce_weights and Size, corresponding to components, grouped\n({signal_trace_type})', fontsize=18)
    ax1.set_xlabel('Groups', fontsize=16)
    ax1.set_ylabel('Absolute Weight Value', fontsize=16)
    ax1.set_xticks(x_labels_center)
    ax1.set_xticklabels(x_labels, rotation=45, fontsize=14)
    ax1.tick_params(axis='y', labelsize=14)

    # Create a second y-axis
    ax2 = ax1.twinx()

    # Plot the second curve (valid_dis_list)
    ax2.plot(x, reordered_valid_size_list, color='skyblue', label='Size/Pixels')
    ax2.set_ylabel('Size/Pixels', fontsize=16)
    ax2.tick_params(axis='y', labelsize=14)

    # Adding legend for both curves
    ax1.legend(loc='upper left', fontsize=14)
    ax2.legend(loc='upper right', fontsize=14)

    plt.tight_layout()
    plt.show()

    print("----- ----- -----")



#### Plots with respect to original weights

In [None]:
signal_trace_type_list = ["Raw Signal Traces", "Mean Restored Signal Traces", "Mean and Std Restored Signal Traces"]
# signal_trace_type_list = ["Mean Restored Signal Traces", "Mean and Std Restored Signal Traces"]
# signal_trace_type_list = ["Mean and Std Restored Signal Traces"]

# Load the grouping of this cell's boutons into axons.
root_path = "/home/chenhuimiao/dLGN_Neuron_Modeling/Fluorescence_Data/FluoData4Fitting_Average"
path_ = os.path.join(root_path, cell_name, cell_name + 'green_Axon.mat')
mat_data = scipy.io.loadmat(path_)
# Squeeze the outer (1, n_axons) object array to 1-D; each element holds the
# 1-based component numbers of one axon's boutons.
axons = np.squeeze(mat_data['Axons'], axis=0)
for i in range(len(axons)):
    # Flatten each axon's numbers to a 1-D int array. np.ravel also accepts
    # column vectors and empty cells, which squeeze(axis=0) would reject.
    axons[i] = np.ravel(axons[i].astype(int))

valid_com_index_from_one_list = np.array(valid_com_index_list) + 1

# Convert axons to index form, consistent with the indices of the remaining
# components: 1-based component number -> position in valid_com_index_list
# (i.e. the index of that component's fc_reduce weight).
position_of_component = {int(c): pos for pos, c in enumerate(valid_com_index_from_one_list)}

# Every bouton must be one of the remaining components, otherwise it has no
# weight index; fail with an explicit message instead of a cryptic numpy error.
missing_boutons = sorted({int(b) for axon in axons for b in axon} - set(position_of_component))
if missing_boutons:
    raise ValueError(
        f"Axon boutons {missing_boutons} are not among the remaining components "
        "(valid_com_index_list); they cannot be mapped to a weight index."
    )

# Same structure as `axons` (1-D object array of 1-D int arrays), but holding
# 0-based positions into valid_com_index_list instead of component numbers.
index_form_axons = np.empty(len(axons), dtype=object)
for i, axon in enumerate(axons):
    index_form_axons[i] = np.array([position_of_component[int(b)] for b in axon], dtype=int)
flat_index_form_axons = np.concatenate(index_form_axons)


# Checkpoint file-name suffix for each signal-trace variant; the checkpoint
# for a variant is f"./{cell_name}_model_with_{suffix}.pth".
MODEL_SUFFIX_BY_TRACE_TYPE = {
    "Raw Signal Traces": "raw",
    "Mean Restored Signal Traces": "mean_restored",
    "Mean and Std Restored Signal Traces": "mean_restored_std_restored",
}


def _plot_weight_bars(weights, color, title, xlabel, figsize):
    """Show a single-axis bar chart of *weights*, one bar per weight at x = 1..N."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.bar(range(1, len(weights) + 1), weights, color=color)
    ax.set_title(title, fontsize=18)
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel('Weight Value', fontsize=16)
    ax.tick_params(axis='x', labelsize=14)
    ax.tick_params(axis='y', labelsize=14)
    plt.tight_layout()
    plt.show()


def _plot_weights_with_curve(x, weights, curve, title, curve_label, curve_ylabel,
                             xlabel='Weight Index/Component', xticks=None, xticklabels=None):
    """Show *weights* as salmon bars and *curve* as a line on a twin y-axis.

    Both series are drawn at the same x-positions *x*. When *xticks* is given,
    the x-axis shows *xticklabels* (rotated 45 degrees) at those positions
    instead of the default numeric ticks.
    """
    fig, ax1 = plt.subplots(figsize=(6.4 * 2, 4.8))
    ax1.bar(x, weights, color='salmon', label='Weights')
    ax1.set_title(title, fontsize=18)
    ax1.set_xlabel(xlabel, fontsize=16)
    ax1.set_ylabel('Weight Value', fontsize=16)
    if xticks is None:
        ax1.tick_params(axis='x', labelsize=14)
    else:
        ax1.set_xticks(xticks)
        ax1.set_xticklabels(xticklabels, rotation=45, fontsize=14)
    ax1.tick_params(axis='y', labelsize=14)

    # Second y-axis for the per-component quantity.
    ax2 = ax1.twinx()
    ax2.plot(x, curve, color='skyblue', label=curve_label)
    ax2.set_ylabel(curve_ylabel, fontsize=16)
    ax2.tick_params(axis='y', labelsize=14)

    ax1.legend(loc='upper left', fontsize=14)
    ax2.legend(loc='upper right', fontsize=14)
    plt.tight_layout()
    plt.show()


def _scatter_with_correlation(weights, values, ylabel, title):
    """Scatter *values* against *weights*; the legend shows np.corrcoef's coefficient."""
    # Sorting by weight only fixes the drawing order; the correlation
    # coefficient does not depend on the (joint) order of the samples.
    order = np.argsort(weights)
    sorted_weights = weights[order]
    sorted_values = np.asarray(values)[order]
    correlation_coefficient = np.corrcoef(sorted_weights, sorted_values)[0, 1]

    fig, ax1 = plt.subplots(figsize=(6.4, 4.8))
    ax1.scatter(sorted_weights, sorted_values, color='salmon',
                label=f'Correlation Coefficient = {correlation_coefficient:.5f}')
    ax1.set_xlabel('Weights', fontsize=16)
    ax1.set_ylabel(ylabel, fontsize=16)
    ax1.tick_params(labelsize=14)
    ax1.set_title(title, fontsize=18)
    ax1.legend(fontsize=14, facecolor='none')
    plt.show()


def _grouped_x_positions(groups, gap=5):
    """Lay out bars group by group with an extra gap between consecutive groups.

    Bars are numbered 1, 2, ... in the concatenated order of *groups*; each
    group after the first is shifted right by *gap* additional units.
    Empty groups draw nothing and are skipped (no extra gap, no tick label).

    Returns:
        x: int array with one x-position per bar (length = total group size).
        centers: x-coordinate of the middle of each non-empty group.
        labels: matching "Group k" tick labels (k = 1-based group number).
    """
    x, centers, labels = [], [], []
    start = 1
    for k, group in enumerate(groups):
        size = len(group)
        if size == 0:
            continue
        x.extend(range(start, start + size))
        centers.append(start + (size - 1) / 2)
        labels.append(f"Group {k + 1}")
        start += size + gap
    return np.array(x, dtype=int), centers, labels


# ----- Loop-invariant data shared by every trace type -----
valid_dis_array = np.asarray(valid_dis_list)
valid_size_array = np.asarray(valid_size_list)
valid_F_std_component = np.asarray(F_std_component)[valid_com_index_list]
# Bar positions / group ticks for the "grouped by axon" plots. Works for any
# number of groups (including a single one).
x_grouped, group_centers, group_labels = _grouped_x_positions(index_form_axons)

# Significant components (1-based numbers).
# NOTE(review): this list is for cell CL090_230515 only; the "whether
# significant" plot is meaningless for any other cell_name.
sig_components = np.array([
    2, 5, 7, 15, 17, 19, 25, 28, 30, 36, 37, 38, 44, 49, 52, 59, 60, 62, 63, 67, 69, 71, 74, 76, 78, 80, 82, 86, 87, 88, 94, 99, 100, 103,
    109, 112, 113, 115, 116, 117, 122, 133, 134, 136, 137, 140, 142, 147, 150, 151, 152, 155, 156, 160, 161, 166, 169, 171, 172, 176, 177,
    181, 182, 186, 189, 190, 192, 194, 195, 197, 198, 199, 200, 201, 202, 203, 205, 209, 210, 213, 214, 216, 222, 223, 224, 225, 226, 229, 234,
    236, 237, 238, 240, 242, 244, 250, 252, 254, 255, 256, 257, 258, 260, 261, 262, 268, 270, 271, 275, 276, 277, 278, 281,
])
sig_components_index = list(sig_components - 1)
sig_index_set = set(sig_components_index)
# Positions (within valid_com_index_list, i.e. fc_reduce weight indices) of the
# significant components.
intersection_indices = [k for k, comp in enumerate(valid_com_index_list) if comp in sig_index_set]


for signal_trace_type in signal_trace_type_list:

    # Rebuild the network and load the checkpoint trained on this trace type.
    # An unknown trace type raises KeyError instead of silently reusing the
    # previous iteration's model_path.
    model = FluoModel(data_set.shape[1])
    model_path = f"./{cell_name}_model_with_{MODEL_SUFFIX_BY_TRACE_TYPE[signal_trace_type]}.pth"
    # map_location="cpu" lets a checkpoint saved from a GPU session load on a CPU-only machine.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    # Learned parameters of each layer.
    fc_shared_weights = model.fc_shared.weight
    fc_reduce_weights = model.fc_reduce.weight
    fc_end_weights = model.fc_end.weight
    fc_shared_biases = model.fc_shared.bias
    fc_reduce_biases = model.fc_reduce.bias
    fc_end_biases = model.fc_end.bias

    print(f"{signal_trace_type}")
    for heading, param in (
        ("\nfc_shared weights:", fc_shared_weights),
        ("\nfc_reduce weights:", fc_reduce_weights),
        ("\nfc_end weights (if without sigmoid, the end layer is not used, then its weight and bias are random):", fc_end_weights),
        ("\nfc_shared biases:", fc_shared_biases),
        ("\nfc_reduce biases:", fc_reduce_biases),
        ("\nfc_end biases (if without sigmoid, the end layer is not used, then its weight and bias are random):", fc_end_biases),
    ):
        print(heading)
        print(param)

    # Plain 1-D numpy copies of the weights used for plotting.
    fc_shared_weights = fc_shared_weights.detach().numpy().flatten()
    fc_reduce_weights = fc_reduce_weights.detach().numpy().flatten()

    ##### fc_shared weights (one per frame)
    _plot_weight_bars(fc_shared_weights, 'skyblue',
                      f'fc_shared_weights, corresponding to frames\n({signal_trace_type})',
                      'Weight Index/Frame', (6.4, 4.8))

    # Mean of each half; for an odd length the second half holds one extra value.
    half_len = len(fc_shared_weights) // 2
    print(f"The mean of the first {half_len} weight values of fc_shared_weights is {np.mean(fc_shared_weights[:half_len])}")
    print(f"The mean of the last {len(fc_shared_weights) - half_len} weight values of fc_shared_weights is {np.mean(fc_shared_weights[half_len:])}")

    ##### fc_reduce weights (one per component)
    _plot_weight_bars(fc_reduce_weights, 'salmon',
                      f'fc_reduce_weights, corresponding to components\n({signal_trace_type})',
                      'Weight Index/Component', (6.4 * 2, 4.8))

    x = np.arange(1, len(fc_reduce_weights) + 1)

    ##### Weights and whether each component is significant (0/1)
    whether_sig_array = np.zeros_like(fc_reduce_weights)
    whether_sig_array[intersection_indices] = 1
    _plot_weights_with_curve(
        x, fc_reduce_weights, whether_sig_array,
        f'fc_reduce_weights and whether significant, corresponding to components\n({signal_trace_type})',
        'Whether significant', 'Significant or not',
    )

    ##### Weights and distance to soma
    _plot_weights_with_curve(
        x, fc_reduce_weights, valid_dis_array,
        f'fc_reduce_weights and Distance, corresponding to components\n({signal_trace_type})',
        'Distance to Soma', 'Distance to Soma',
    )

    ##### Correlations with F std and distance
    _scatter_with_correlation(fc_reduce_weights, valid_F_std_component, 'F Std of Components',
                              f"Scatter Plot between Weights and F Std of Components\n({signal_trace_type})")
    _scatter_with_correlation(fc_reduce_weights, valid_dis_array, 'Distance to Soma (Pixels)',
                              f"Scatter Plot between Weights and Distance to Soma\n({signal_trace_type})")
    _scatter_with_correlation(fc_reduce_weights, 1 / valid_dis_array, '1/Distance to Soma (Pixels)',
                              f"Scatter Plot between Weights and 1/Distance to Soma\n({signal_trace_type})")

    ##### Weights and distance, components grouped by axon
    _plot_weights_with_curve(
        x_grouped, fc_reduce_weights[flat_index_form_axons], valid_dis_array[flat_index_form_axons],
        f'fc_reduce_weights and Distance, corresponding to components, grouped\n({signal_trace_type})',
        'Distance to Soma', 'Distance to Soma',
        xlabel='Groups', xticks=group_centers, xticklabels=group_labels,
    )

    ##### Weights and size
    _plot_weights_with_curve(
        x, fc_reduce_weights, valid_size_array,
        f'fc_reduce_weights and Size, corresponding to components\n({signal_trace_type})',
        'Size/Pixels', 'Size/Pixels',
    )

    ##### Correlations with size
    _scatter_with_correlation(fc_reduce_weights, valid_size_array, 'Size (Pixel Number)',
                              f"Scatter Plot between Weights and Size\n({signal_trace_type})")
    _scatter_with_correlation(fc_reduce_weights, 1 / valid_size_array, '1/Size (Pixel Number)',
                              f"Scatter Plot between Weights and 1/Size\n({signal_trace_type})")

    ##### Weights and size, components grouped by axon
    _plot_weights_with_curve(
        x_grouped, fc_reduce_weights[flat_index_form_axons], valid_size_array[flat_index_form_axons],
        f'fc_reduce_weights and Size, corresponding to components, grouped\n({signal_trace_type})',
        'Size/Pixels', 'Size/Pixels',
        xlabel='Groups', xticks=group_centers, xticklabels=group_labels,
    )

    print("----- ----- -----")


#### Plot weights with component locations (colorbar map)

In [None]:
signal_trace_type_list = ["Raw Signal Traces", "Mean Restored Signal Traces", "Mean and Std Restored Signal Traces"]

# Checkpoint file-name suffix for each signal-trace variant; the checkpoint
# for a variant is f"./{cell_name}_model_with_{suffix}.pth".
MODEL_SUFFIX_BY_TRACE_TYPE = {
    "Raw Signal Traces": "raw",
    "Mean Restored Signal Traces": "mean_restored",
    "Mean and Std Restored Signal Traces": "mean_restored_std_restored",
}

root_path = "/home/chenhuimiao/dLGN_Neuron_Modeling/Fluorescence_Data/FluoData4Fitting_Average"
path_ = os.path.join(root_path, cell_name, cell_name + 'green_BoutonMasks.mat')
# MATLAB v7.3 .mat files are HDF5; the with-block closes the file even if reading fails.
with h5py.File(path_, 'r') as data_:
    green_bouton_masks = np.array(data_['BoutonMasks'])  # reads the dataset into memory
# Reverse the axis order (presumably undoing MATLAB's column-major layout) so the
# component index is the last axis, as the indexing below expects.
green_bouton_masks = np.transpose(green_bouton_masks, (2, 1, 0))

for signal_trace_type in signal_trace_type_list:

    # Rebuild the network and load the checkpoint trained on this trace type
    # (unknown trace types raise KeyError instead of reusing a stale path).
    model = FluoModel(data_set.shape[1])
    model_path = f"./{cell_name}_model_with_{MODEL_SUFFIX_BY_TRACE_TYPE[signal_trace_type]}.pth"
    # map_location="cpu" lets a checkpoint saved from a GPU session load on a CPU-only machine.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    fc_reduce_weights = model.fc_reduce.weight

    print(f"{signal_trace_type}")

    print("\nfc_reduce weights:")
    print(fc_reduce_weights)

    fc_reduce_weights = fc_reduce_weights.detach().numpy().flatten()

    # Masks of the components kept in the model, in valid_com_index_list order
    # (one per fc_reduce weight). Cast to float so the real-valued weights are
    # not truncated if the masks were stored with an integer/logical dtype.
    valid_masks = green_bouton_masks[:, :, valid_com_index_list].astype(float)
    # Replace each mask's 1-pixels with that component's weight (other mask
    # values are kept), then sum over components to overlay all boutons.
    # Vectorized equivalent of a per-pixel loop over every component.
    weights_map = np.where(valid_masks == 1, fc_reduce_weights, valid_masks).sum(axis=2)

    fig, ax = plt.subplots(figsize=(6.4*1.5, 4.8*1.5))
    cax = ax.imshow(weights_map, cmap='cool')
    colorbar = fig.colorbar(cax)
    colorbar.set_label('Weight Value', fontsize=14)
    ax.tick_params(axis='x', labelsize=14)
    ax.tick_params(axis='y', labelsize=14)
    ax.set_title(f"Weights Map\n({signal_trace_type})", fontsize=18)
    plt.show()

