# Generate Mean Normalized Waveform Numpy
### Generates a .npy file that stores the normalized mean waveform of each good cluster in each recording, along with its firing rate, cluster_id, and recording name

# Inputs & Data
`BASE_DIR` is the single parent directory containing all recording folders (the Cyborg external hard drive in this example); each recording folder must contain a `phy` sub-folder

`SAVE_PATH` is the name and location of the output .npy file

# Processing
This script reads the `cluster_info.tsv` file in each recording's `phy` folder to find the units labelled "good" whose firing rate (`fr`) is above 0.5. For each of those units it then uses the phylib package to load all of the unit's spike waveforms on its best channel only. It averages them into a single mean waveform and normalizes it so that its peak is either 1 or -1. The result is stored in the .npy file that gets created.

This script takes a while to run (~1 hr on my laptop for all of RCE Cohorts 2+3)

In [1]:
import os
import pandas as pd
import numpy as np
import concurrent.futures
from phylib.io.model import load_model
from sklearn.preprocessing import normalize
from matplotlib import pyplot as plt

# Constants
# Parent directory whose immediate sub-folders are the recordings; each one is
# expected to contain a `phy` sub-folder holding cluster_info.tsv and params.py.
# Raw strings keep the Windows backslashes from being read as escape sequences.
BASE_DIR = r'D:\pc_lab\RCE\finished_proc\phy_curation'
SAVE_PATH = r'C:\Users\short\Documents\GitHub\waveform_id\meanWave_clust_240715_1.npy'  # Output .npy file (np.save writes a pickled dict here)

def process_folder(folder, min_firing_rate=0.5):
    """Compute normalized mean waveforms for the good units of one recording.

    Reads ``<folder>/phy/cluster_info.tsv`` and keeps clusters whose ``group``
    is ``'good'`` and whose ``fr`` column is strictly greater than
    ``min_firing_rate``. It then loads the phy model from
    ``<folder>/phy/params.py``. For each kept cluster it averages the spike
    waveforms on the first channel phylib returns for that cluster and scales
    the mean with ``sklearn.preprocessing.normalize(norm='max')``.

    Parameters
    ----------
    folder : str
        Recording directory containing a ``phy`` sub-folder.
    min_firing_rate : float, optional
        Strict lower bound applied to the ``fr`` column (default 0.5).

    Returns
    -------
    list of tuple
        One ``(normalized_mean_waveform, cluster_id, firing_rate,
        recording_name)`` tuple per kept cluster, where ``cluster_id`` is an
        ``int`` and ``recording_name`` is the folder's base name. Returns
        ``[]`` when no cluster passes the filter, or (after printing the
        error) when the folder cannot be read or processed.
    """
    phy_dir = os.path.join(folder, 'phy')
    cluster_info_path = os.path.join(phy_dir, 'cluster_info.tsv')
    params_path = os.path.join(phy_dir, 'params.py')
    recording_name = os.path.basename(folder)

    # Read the TSV file; a missing/corrupt file skips this recording.
    try:
        cluster_info = pd.read_csv(cluster_info_path, sep='\t')
    except Exception as e:
        print(f"Error reading {cluster_info_path}: {e}")
        return []

    model = None
    try:
        good_clusters = cluster_info[
            (cluster_info['group'] == 'good') &
            (cluster_info['fr'] > min_firing_rate)
        ]
        if good_clusters.empty:
            # Nothing to extract: avoid the comparatively slow model load.
            return []

        # Load the TemplateModel
        model = load_model(params_path)

        results = []
        # Iterate the columns directly instead of using DataFrame.iterrows:
        # iterrows packs each row into one Series with a common dtype, which
        # upcasts the integer cluster_id to float64 next to the float `fr`.
        for cluster_id, firing_rate in zip(good_clusters['cluster_id'], good_clusters['fr']):
            cluster_id = int(cluster_id)

            # [:, :, 0] keeps only the first channel phylib returns for this
            # cluster (presumably its best channel -- confirm for your phylib).
            waveforms = model.get_cluster_spike_waveforms(cluster_id)[:, :, 0]
            # Average over spikes, then scale by the peak (norm='max').
            mean_waveform = waveforms.mean(axis=0)
            norm_mean_waveform = normalize(mean_waveform.reshape(1, -1), norm='max').squeeze()
            results.append((norm_mean_waveform, cluster_id, float(firing_rate), recording_name))

        return results

    except Exception as e:
        # Best-effort per recording: report and keep processing other folders.
        print(f"Error processing data in {folder}: {e}")
        return []

    finally:
        # Release the model's memory-mapped data promptly instead of waiting
        # for garbage collection (guarded in case close() is unavailable).
        close = getattr(model, 'close', None)
        if callable(close):
            try:
                close()
            except Exception as e:
                print(f"Error closing model for {folder}: {e}")

def main(base_dir=None, save_path=None, max_workers=4):
    """Collect normalized mean waveforms from every recording and save them.

    Every immediate sub-directory of ``base_dir`` is treated as a recording
    and passed to ``process_folder`` on a thread pool. The per-cluster results
    are combined into parallel arrays. They are written with ``np.save`` as a
    pickled dict with keys ``'waveforms'``, ``'cluster_ids'``,
    ``'firing_rates'`` and ``'recording_names'``, which can be read back with
    ``np.load(path, allow_pickle=True).item()``.

    Parameters
    ----------
    base_dir : str, optional
        Parent directory of the recordings; defaults to ``BASE_DIR``.
    save_path : str, optional
        Output file path; defaults to ``SAVE_PATH``.
    max_workers : int, optional
        Number of worker threads (default 4).
    """
    base_dir = BASE_DIR if base_dir is None else base_dir
    save_path = SAVE_PATH if save_path is None else save_path

    # Retrieve all recording folders. Sorted so the output order is
    # reproducible (os.listdir order is filesystem-dependent, and
    # executor.map yields results in input order).
    folders = sorted(
        os.path.join(base_dir, name)
        for name in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, name))
    )

    # Process each folder
    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        for result in executor.map(process_folder, folders):
            if result:
                results.extend(result)

    if not results:
        print("No waveforms found.")
        return

    # Split the (waveform, cluster_id, firing_rate, recording_name) tuples
    # into parallel columns.
    waveforms, cluster_ids, firing_rates, recording_names = zip(*results)
    # np.stack fails loudly if recordings produced waveforms of different lengths.
    normWFs_array = np.stack(waveforms)
    # Force integer IDs: depending on how they were extracted they may arrive
    # as floats (e.g. 12.0).
    cluster_ids = np.asarray(cluster_ids, dtype=np.int64)
    firing_rates = np.asarray(firing_rates, dtype=float)
    recording_names = np.asarray(recording_names)

    # Make sure the destination exists so a long run is not lost at the final save.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Save the results to a file
    np.save(save_path, {'waveforms': normWFs_array, 'cluster_ids': cluster_ids, 'firing_rates': firing_rates, 'recording_names': recording_names})
    print(f"Saved {len(normWFs_array)} waveforms to {save_path}")

# Entry point: run the full extraction when this file/cell is executed directly.
if __name__ == "__main__":
    main()

## To load the data afterwards (in order to perform UMAP):

`SAVED_PATH = r"C:\Users\short\Documents\GitHub\waveform_id\meanWave_clust_240715_1.npy"  # Path to the saved waveform data`

`data = np.load(SAVED_PATH, allow_pickle=True).item()`

This step produced the normalized mean waveforms, but before you perform UMAP you'll need to flip positive waveforms to negative. This is because we believe the direction (polarity) of a waveform reflects the electrode's proximity to the cell body rather than the cell type. Instructions for this are in the next notebook: `2_THE_WaveMAP_Notebook.ipynb`