# Gotta go fast!!!!

# Init Code

In [1]:
import Waven.WaveletGenerator as wg
import Waven.Analysis_Utils as au
import Waven.LoadPinkNoise as lpn
import numpy as np
import gc
import os
import torch
# Non-GUI analysis workflow
import matplotlib.pyplot as plt
from pathlib import Path

# Most recent run_analysis_fast() output; rebound once per scan by the batch-processing cell.
results = None

In [2]:
def _decompose_and_load_coarse(video_path, shape_msg, sigmas, parent_dir, lib_path, device,
                               nx0, ny0, nx, ny, n_theta, ns, run_msg=None):
    """Decompose a downsampled video into wavelets, then load the coarse wavelet triple.

    Shared by the two "compute from scratch" branches of ``preload_wavelets`` so the
    decomposition call and its arguments live in one place.

    Parameters:
        video_path: Path to the ``*_downsampled.npy`` video file.
        shape_msg: Prefix printed in front of the loaded video's shape.
        run_msg: Optional message printed right before the decomposition starts.
        Remaining arguments are forwarded unchanged to
        ``wg.waveletDecomposition_batched`` / ``lpn.coarseWavelet``.

    Returns:
        (w_r, w_i, w_c) as unpacked from ``lpn.coarseWavelet``.
    """
    videodata = np.load(video_path)
    print(f"{shape_msg}{videodata.shape}")
    if run_msg is not None:
        print(run_msg)
    wg.waveletDecomposition_batched(videodata, [0, 1], sigmas, parent_dir,
                                    library_path=lib_path, device=device, batch_size=32)
    # Drop our reference to the video before loading the coarse wavelets to lower peak memory.
    # NOTE(review): assumes the decomposition result is persisted under parent_dir (which is
    # what coarseWavelet reads from) and does not need videodata afterwards - confirm.
    del videodata
    gc.collect()
    w_r, w_i, w_c = lpn.coarseWavelet(parent_dir, False, nx0, ny0, nx, ny, n_theta, ns)
    return w_r, w_i, w_c


def preload_wavelets(param_defaults, gabor_param):
    """
    Preload the wavelet decomposition of the stimulus movie (run once per stimulus).

    The returned wavelets can be reused for every scan. Sources are tried from cheapest
    to most expensive, and the first one that succeeds wins:

      1. ``dwt_downsampled_videodata.npy`` next to the movie, holding the three wavelet
         arrays stacked along axis 0 as [w_r, w_i, w_c];
      2. coarse wavelets loaded by ``lpn.coarseWavelet``;
      3. wavelet decomposition of an existing ``<movie>_downsampled.npy``;
      4. full pipeline: downsample the movie, then decompose it.

    Parameters:
        param_defaults: Parameter dictionary with string values (see the parameter cell).
        gabor_param: Gabor parameter dictionary; only ``"N_thetas"`` is read here.

    Returns:
        (w_r_downsampled, w_i_downsampled, w_c_downsampled)
    """
    banner = "=" * 60

    # NOTE(review): eval() executes arbitrary code. That is acceptable for these hand-written
    # config literals, but switch to ast.literal_eval if parameters ever come from outside.
    movpath = param_defaults["Movie Path"]
    lib_path = param_defaults["Library Path"]
    visual_coverage = eval(param_defaults["Visual Coverage"])
    analysis_coverage = eval(param_defaults["Analysis Coverage"])
    nx0 = int(param_defaults["NX0"])
    ny0 = int(param_defaults["NY0"])
    nx = int(param_defaults["NX"])
    ny = int(param_defaults["NY"])
    sigmas = np.array(eval(param_defaults["Sigmas"]))
    ns = len(sigmas)
    n_theta = int(gabor_param["N_thetas"])
    device = param_defaults.get("Device", "cuda:0")

    torch.cuda.set_device(device)
    parent_dir = os.path.dirname(movpath)
    # [:-4] strips a 4-character extension such as ".mp4"; kept as-is because the
    # downsampling step presumably writes its output under the same name - confirm before changing.
    downsampled_video_path = movpath[:-4] + '_downsampled.npy'
    # Arguments shared by both "compute from scratch" branches.
    decompose_args = (sigmas, parent_dir, lib_path, device, nx0, ny0, nx, ny, n_theta, ns)

    print(banner)
    print("PRELOADING WAVELETS (one-time operation)")
    print(banner)

    # 1) Pre-computed downsampled wavelets. Failures fall through deliberately to the next source.
    try:
        wavelets_downsampled = np.load(os.path.join(parent_dir, 'dwt_downsampled_videodata.npy'))
        w_r_downsampled = wavelets_downsampled[0]
        w_i_downsampled = wavelets_downsampled[1]
        w_c_downsampled = wavelets_downsampled[2]
        # Integer indexing returns views, so this only drops the name: the stacked array stays
        # resident for as long as any of the three returned slices is alive.
        del wavelets_downsampled
        gc.collect()
        print("✓ Loaded cached downsampled wavelets")
        print(banner)
        return w_r_downsampled, w_i_downsampled, w_c_downsampled
    except Exception as e:
        print(f"Cached wavelets not found: {e}")

    # 2) Coarse wavelets (best-effort, like step 1).
    try:
        print("Attempting to load coarse wavelets...")
        w_r_downsampled, w_i_downsampled, w_c_downsampled = lpn.coarseWavelet(
            parent_dir, False, nx0, ny0, nx, ny, n_theta, ns)
        print("✓ Loaded coarse wavelets")
        print(banner)
        return w_r_downsampled, w_i_downsampled, w_c_downsampled
    except Exception as e:
        print(f"Coarse wavelets not found: {e}")

    # 3) A downsampled video already exists: only the decomposition is needed.
    if os.path.exists(downsampled_video_path):
        print(f"✓ Found downsampled video at {downsampled_video_path}")
        print("Generating wavelet decomposition...")
        wavelets = _decompose_and_load_coarse(downsampled_video_path, "  Video shape: ",
                                              *decompose_args)
        print("✓ Completed wavelet decomposition")
        print(banner)
        return wavelets

    # 4) Full pipeline: downsample the movie, then decompose it.
    print("Running full pipeline: downsample video + wavelet decomposition...")

    if visual_coverage != analysis_coverage:
        vis = np.array(visual_coverage)
        ana = np.array(analysis_coverage)
        # Fraction of the visual extent covered by the analysis window along each axis
        # (algebraically (ana[0]-ana[1]) / (vis[0]-vis[1]); the original form is kept so
        # floating-point results are unchanged).
        ratio_x = 1 - ((vis[0] - vis[1]) - (ana[0] - ana[1])) / (vis[0] - vis[1])
        ratio_y = 1 - ((vis[2] - vis[3]) - (ana[2] - ana[3])) / (vis[2] - vis[3])
    else:
        ratio_x = 1
        ratio_y = 1

    print(f"  Downsampling video: {movpath}")
    wg.downsample_video_binary(movpath, visual_coverage, analysis_coverage,
                               shape=(ny, nx), chunk_size=1000, ratios=(ratio_x, ratio_y))

    wavelets = _decompose_and_load_coarse(downsampled_video_path, "  Downsampled video shape: ",
                                          *decompose_args,
                                          run_msg="  Running wavelet decomposition...")

    print("✓ Completed full wavelet pipeline")
    print(banner)

    return wavelets


In [3]:
def _plot_rf_maps(neuron_pos, maxes, filter_mask, title):
    """Scatter the four preferred-parameter maps over neuron positions in a 2x2 figure.

    Parameters:
        neuron_pos: Array indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y); labelled in um.
        maxes: Sequence whose items 0-3 hold per-neuron azimuth, elevation, orientation
            and size preferences (``rfs_gabor[2]`` in ``run_analysis_fast``).
        filter_mask: Per-neuron boolean mask; neurons that fail it are drawn fully transparent.
        title: Figure title.
    """
    # (axes position, index into maxes, colormap, panel title)
    panels = (
        ((0, 0), 0, 'jet', 'Azimuth Preference (deg)'),
        ((0, 1), 1, 'jet_r', 'Elevation Preference (deg)'),
        ((1, 0), 2, 'hsv', 'Orientation Preference (deg)'),
        ((1, 1), 3, 'coolwarm', 'Preferred Size (deg)'),
    )
    # Matplotlib expects alpha as floats in [0, 1]; the mask maps to 1.0 (kept) / 0.0 (hidden).
    alpha = np.asarray(filter_mask, dtype=float)

    # rc_context scopes the transparent axes face colour to this figure. Axes take rcParams
    # when they are created, so setting plt.rcParams after plt.subplots had no effect on the
    # current figure and leaked into every later one.
    with plt.rc_context({'axes.facecolor': 'none'}):
        fig, axes = plt.subplots(2, 2, figsize=(14, 12))
        for (row, col), idx, cmap, panel_title in panels:
            ax = axes[row, col]
            sc = ax.scatter(neuron_pos[:, 0], neuron_pos[:, 1], s=10, c=maxes[idx],
                            cmap=cmap, alpha=alpha)
            fig.colorbar(sc, ax=ax)
            ax.set_title(panel_title)
            ax.set_xlabel('X (um)')
            ax.set_ylabel('Y (um)')
        fig.suptitle(title, fontsize=14, y=0.995)
        plt.tight_layout()
        plt.show()


def run_analysis_fast(scan_dir, w_c_preloaded, param_defaults, gabor_param, plot=True,
                      *, respcorr_min=0.2, skewness_max=20):
    """
    Fast analysis using preloaded wavelets - only the neural data is processed.

    Parameters:
        scan_dir: Path to the scan directory (expects a ``suite2p`` subfolder when
            ``param_defaults["Spks Path"]`` is the string ``'None'``).
        w_c_preloaded: Preloaded complex wavelet array. Frames are along axis 0; the
            remaining axes are flattened per frame.
        param_defaults: Parameter dictionary with string values.
        gabor_param: Gabor parameter dictionary; only ``"N_thetas"`` is read.
        plot: Whether to display plots (set False for batch processing).
        respcorr_min: Minimum response repeatability to keep a neuron (only applied when
            more than one trial is kept).
        skewness_max: Maximum response skewness to keep a neuron.

    Returns:
        results: Dictionary with keys 'spks', 'neuron_pos', 'rfs_gabor', 'filter_mask',
        'respcorr' and 'skewness'. The same dict is also written to
        ``<scan_dir>/zebra/analysis_results.npy`` as a pickled object array, so loading it
        back requires ``np.load(..., allow_pickle=True).item()``.
    """
    import time
    start_time = time.perf_counter()

    # NOTE(review): eval() executes arbitrary code. That is acceptable for these hand-written
    # config literals, but switch to ast.literal_eval if parameters ever come from outside.
    visual_coverage = eval(param_defaults["Visual Coverage"])
    analysis_coverage = eval(param_defaults["Analysis Coverage"])
    n_planes = int(param_defaults["Number of Planes"])
    block_end = int(param_defaults["Block End"])
    nx = int(param_defaults["NX"])
    ny = int(param_defaults["NY"])
    sigmas = np.array(eval(param_defaults["Sigmas"]))
    ns = len(sigmas)
    resolution = float(param_defaults["Resolution"])
    spks_path = param_defaults["Spks Path"]
    nb_frames = int(param_defaults["Number of Frames"])
    n_trial2keep = int(param_defaults["Number of Trials to Keep"])
    n_theta = int(gabor_param["N_thetas"])

    screen_ratio = abs(visual_coverage[0] - visual_coverage[1]) / nx
    xM, xm, _, _ = analysis_coverage
    deg_per_pix = abs(xM - xm) / nx
    # Sigmas rescaled by 2 * deg_per_pix and truncated to 2 decimals
    # (presumably pixel sigma -> degrees; the factor 2 is not explained here - confirm).
    sigmas_deg = np.trunc(2 * deg_per_pix * sigmas * 100) / 100

    # normpath strips a trailing slash so basename still yields the scan folder name.
    scan_name = os.path.basename(os.path.normpath(scan_dir))
    print(f"\n{'='*60}")
    print(f"ANALYZING: {scan_name}")
    print(f"{'='*60}")

    # Load spike data: the string 'None' means "align from suite2p", otherwise load a saved file.
    if spks_path == 'None':
        print('⏳ Loading and aligning neural data...')
        pathsuite2p = os.path.join(scan_dir, 'suite2p')
        spks, _, neuron_pos = lpn.loadSPKMesoscope(scan_dir, pathsuite2p, block_end,
                                                   n_planes, nb_frames, threshold=1.25,
                                                   last=True, method='frame2ttl')
        neuron_pos = lpn.correctNeuronPos(neuron_pos, resolution)
    else:
        print(f'⏳ Loading spks from {spks_path}...')
        spks = np.load(spks_path)
        neuron_pos = np.load(os.path.join(os.path.dirname(spks_path), 'pos.npy'))

    print(f"  ✓ Neurons: {spks.shape[0]}, Frames: {spks.shape[1]}")

    # Quality metrics: repeatability only makes sense with more than one trial.
    print('⏳ Computing neuron quality metrics...')
    n_neurons = spks.shape[0]
    multi_trial = n_trial2keep > 1
    if multi_trial:
        respcorr = au.repetability_trial3(spks, neuron_pos, plotting=False)
    else:
        respcorr = np.ones(n_neurons)
    skewness = np.array(au.compute_skewness_neurons(spks, plotting=False))

    filter_mask = skewness <= skewness_max
    if multi_trial:
        filter_mask = np.logical_and(respcorr >= respcorr_min, filter_mask)

    print(f"  ✓ Quality filter: {np.sum(filter_mask)}/{n_neurons} neurons passed")

    # Receptive fields (the main computation).
    print('⏳ Computing receptive fields...')
    n_stim_frames = w_c_preloaded.shape[0]
    n_resp_frames = spks.shape[1]
    n_frames_to_use = min(n_stim_frames, n_resp_frames)
    if n_stim_frames != n_resp_frames:
        # Truncating to the shorter series is only meaningful if both share a time base.
        print(f"  ⚠ Frame count mismatch: stimulus has {n_stim_frames} frames, responses have "
              f"{n_resp_frames}; using the first {n_frames_to_use} of each. "
              "Check that both are on the same time base.")

    rfs_gabor = au.PearsonCorrelationPinkNoise_batched(
        stim=w_c_preloaded[:n_frames_to_use].reshape(n_frames_to_use, -1),
        resp=spks[:, :n_frames_to_use],
        neuron_pos=neuron_pos,
        nx=nx, ny=ny, n_theta=n_theta, ns=ns,
        visual_coverage=analysis_coverage,
        screen_ratio=screen_ratio,
        sigmas=sigmas_deg
    )

    print("  ✓ Receptive fields computed")

    # Save results
    print('⏳ Saving results...')
    save_dir = os.path.join(scan_dir, "zebra", "")
    os.makedirs(save_dir, exist_ok=True)

    np.save(os.path.join(save_dir, 'correlation_matrix.npy'), rfs_gabor[0])
    np.save(os.path.join(save_dir, 'maxes_indices.npy'), rfs_gabor[1])
    np.save(os.path.join(save_dir, 'maxes_corrected.npy'), rfs_gabor[2])

    results = {
        'spks': spks,
        'neuron_pos': neuron_pos,
        'rfs_gabor': rfs_gabor,
        'filter_mask': filter_mask,
        'respcorr': respcorr,
        'skewness': skewness
    }

    np.save(os.path.join(save_dir, 'analysis_results.npy'), results)
    print(f"  ✓ Results saved to {save_dir}")

    if plot:
        print('⏳ Generating plots...')
        _plot_rf_maps(neuron_pos, rfs_gabor[2], filter_mask, scan_name)

    elapsed = time.perf_counter() - start_time
    print(f"{'='*60}")
    print(f"✓ COMPLETE in {elapsed:.1f} seconds ({elapsed/60:.1f} minutes)")
    print(f"{'='*60}\n")

    return results

# Run before processing any scans: one-time setup (parameters and wavelet preloading)

In [4]:
# Default parameters for the Gabor library.
# Every value is stored as a string; consumers convert what they need
# (e.g. int(gabor_param["N_thetas"]) in the analysis functions).
gabor_param = {
    "N_thetas": "6",  # number of orientations (theta), read as int
    # The analysis reads Sigmas from param_defaults, not from here - keep the two in sync.
    "Sigmas": "[2, 3, 4, 5, 6, 8]",
    "Frequencies": "[0.015, 0.04, 0.07, 0.1]",
    "Phases": "[0, 90]",
    "NX": "100",
    "NY": "75",
    # Same file as param_defaults["Library Path"].
    "Save Path": "/datajoint-data/data/leonk/analysis/zebra/gabor_library.npy",
}

In [5]:
# Analysis parameters. Values are strings; list-valued entries are parsed with eval()
# and numeric ones with int()/float() by the functions that read them.
param_defaults = {
    "Dirs": "/datajoint-data/data/leonk/",
    "Number of Planes": "1",  # passed to lpn.loadSPKMesoscope
    "Block End": "0",  # passed to lpn.loadSPKMesoscope
    "screen_x": "800",
    "screen_y": "600",
    "NX0": "100",  # downsampled x positions from Gabor library
    "NY0": "75",  # downsampled y positions from Gabor library
    "NX": "100",  # target downsampled x positions for wavelet computation
    "NY": "75",  # target downsampled y positions for wavelet computation
    "Resolution": "1.2",  # passed to lpn.correctNeuronPos
    "Sigmas": "[2, 3, 4, 5, 6, 8]",
    "Frequencies": "[0.015, 0.04, 0.07, 0.1]",
    # 4-element coverage lists; run_analysis_fast unpacks the analysis one as (xM, xm, yM, ym).
    "Visual Coverage": "[-42, 42, 35, -15]",
    "Analysis Coverage": "[-42, 42, 35, -15]",
    "Number of Frames": "54000",  # stimulus frames in each trial
    "Number of Trials to Keep": "1",  # > 1 enables the repeatability filter
    "Movie Path": "/datajoint-data/data/leonk/analysis/zebra/fullscreen_zebra.mp4",
    "Library Path": "/datajoint-data/data/leonk/analysis/zebra/gabor_library.npy",
    "Spks Path": "None",  # the string 'None' means: load and align from suite2p
    "Device": "cuda:1",  # preload_wavelets falls back to cuda:0 if absent
}


Pre-load the Wavelets

In [6]:
# Run this once to preload wavelets
print("Preloading wavelets for stimulus...")
w_r_global, w_i_global, w_c_global = preload_wavelets(param_defaults, gabor_param)
# All three arrays stay referenced by this notebook, so report their combined footprint too;
# reporting only w_c understated the resident memory roughly threefold.
total_gb = sum(w.nbytes for w in (w_r_global, w_i_global, w_c_global)) / 1e9
print("\n✓ Wavelets cached in memory!")
print(f"  Shape: {w_c_global.shape}")
print(f"  Memory: ~{w_c_global.nbytes / 1e9:.2f} GB (w_c), ~{total_gb:.2f} GB (w_r + w_i + w_c)")
print("  Ready for fast analysis on all scans!\n")

Preloading wavelets for stimulus...
PRELOADING WAVELETS (one-time operation)
PRELOADING WAVELETS (one-time operation)
✓ Loaded cached downsampled wavelets

✓ Wavelets cached in memory!
  Shape: (54000, 100, 75, 6, 6, 1)
  Memory: ~58.32 GB
  Ready for fast analysis on all scans!

✓ Loaded cached downsampled wavelets

✓ Wavelets cached in memory!
  Shape: (54000, 100, 75, 6, 6, 1)
  Memory: ~58.32 GB
  Ready for fast analysis on all scans!



# Run when a scan is ready: list the scans and process them with the preloaded wavelets

Define Scans

In [7]:
# Scan folder names (relative to the data root) to analyze in this batch.
scans = ['LE_ROS-2210_2025-11-28_scan9FXEU7TJ_sess9FXEU7TJ']

In [8]:
# FAST BATCH PROCESSING - Run analysis on all scans using preloaded wavelets
import time

# Use the configured data root instead of repeating the hard-coded path;
# os.path.join is correct whether or not "Dirs" ends with a slash.
data_root = param_defaults["Dirs"]
banner = "=" * 60

total_start = time.perf_counter()
all_results = {}

print("\n" + banner)
print(f"BATCH PROCESSING {len(scans)} SCANS")
print(banner)

for i, scan in enumerate(scans, 1):
    scan_dir = os.path.join(data_root, scan)

    print(f"\n[{i}/{len(scans)}] Processing: {scan}")

    # Run fast analysis with preloaded wavelets. `results` keeps the latest scan's output.
    results = run_analysis_fast(
        scan_dir=scan_dir,
        w_c_preloaded=w_c_global,  # Use the preloaded wavelets!
        param_defaults=param_defaults,
        gabor_param=gabor_param,
        plot=True  # Set to False for even faster batch processing
    )

    all_results[scan] = results

total_elapsed = time.perf_counter() - total_start
n_processed = len(all_results)
# Guard against an empty `scans` list (previously a ZeroDivisionError).
avg_time = total_elapsed / n_processed if n_processed else 0.0

print("\n" + banner)
print("BATCH PROCESSING COMPLETE!")
print(banner)
print(f"Total time: {total_elapsed:.1f}s ({total_elapsed/60:.1f} min)")
print(f"Average per scan: {avg_time:.1f}s ({avg_time/60:.1f} min)")
print(f"Scans processed: {n_processed}")
print(banner)


BATCH PROCESSING 1 SCANS

[1/1] Processing: LE_ROS-2210_2025-11-28_scan9FXEU7TJ_sess9FXEU7TJ

ANALYZING: LE_ROS-2210_2025-11-28_scan9FXEU7TJ_sess9FXEU7TJ
⏳ Loading and aligning neural data...
last session
single plane
single plane
shape spks :  (564, 27300)
neuron_pos spks :  (564, 2)
Found 1 complete trials (54001 stimulus frames, 27300 imaging frames)
Processing trial 1
  Trial 1: 26963 imaging frames, 26963 timepoints
Detected incomplete final trial
Aligning trial 1: (564, 27300), max index: 27008
single plane
shape spks :  (564, 27300)
neuron_pos spks :  (564, 2)
Found 1 complete trials (54001 stimulus frames, 27300 imaging frames)
Processing trial 1
  Trial 1: 26963 imaging frames, 26963 timepoints
Detected incomplete final trial
Aligning trial 1: (564, 27300), max index: 27008
  26963 frames, 26963 timepoints, spks shape: (564, 26963)
  26963 frames, 26963 timepoints, spks shape: (564, 26963)
Alignment complete: 1 trials processed
data aligned
Alignment complete: 1 trials proces