In [1]:
#　解析に使用する定数の定義
import numpy as np
import matplotlib.pyplot as plt
import h5py
from datetime import datetime
from scipy.signal import firwin, lfilter
from scipy.signal import argrelmin
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from hdbscan import HDBSCAN
from scipy import signal
import os
import itertools
import functools
import glob
import numpy as np


In [2]:
# raw dataからtrigger情報を取得
#import numpy as np

def TRGfromDAT(dat_path, sampling_rate=20000, CHANNEL_NUM = 61, CHANNEL_TRG = 0):
    dat = np.fromfile(dat_path,dtype='h').reshape(-1,CHANNEL_NUM)
    trg_column = dat[:,CHANNEL_TRG]
    not_saturate = np.where(trg_column != trg_column.max())[0] #search unsaturated range 
    temp_trg = np.diff(not_saturate)
    temp_trg_index = np.where(temp_trg != 1)[0] + 1 #search uncontinous index from unsaturated range 
    trg_index = np.r_[0,temp_trg_index]
    trg_array = not_saturate[trg_index] / sampling_rate
    return trg_array

# raw dataから1ch分の情報を取得
#import numpy as np

def ReadRawFile(path, channel, CHANNEL_NUM = 61):
    raw = np.fromfile(path,dtype='h').reshape(-1,CHANNEL_NUM)
    # print('raw.shape', raw.shape)
    return raw[:,CH_ARRAY==channel][:,0]


def BandPassFilter(wave_raw, bottom=300, top=3000, sampling_rate=20000):
    nyq = sampling_rate / 2
    cutoff = np.array([bottom, top]) / nyq
    numtaps = 255
    bpf = firwin(numtaps, cutoff, pass_zero=False)
    return lfilter(bpf, 1, wave_raw)[int((numtaps-1)/2):]

def SpikeDetection(wave_filtered, sd_thr=4, order=15, spike=-1):
    peaks = argrelmin(-1*spike*wave_filtered, order=order)[0]
    #選別用の閾値の計算
    ##median
    median = np.median(wave_filtered)
    threshold = median - sd_thr * (np.median(abs(wave_filtered - median)) / 0.6745)
    # print('Threshold: ', threshold)
    #スパイクの選別
    spike_index = peaks[wave_filtered[peaks] < threshold]
    # print('Peak number: ', spike_index.size)
    return spike_index


def MakeWaveShape(temp_wave_array):
    return np.arange(temp_wave_array[0], temp_wave_array[1])

def GetWaveShape(spike_index, wave_filtered, area_before_peak_ms=1, area_after_peak_ms=2, sampling_rate=20000, ms=1000):
    area_before_peak_index = int(area_before_peak_ms * sampling_rate / ms)
    area_after_peak_index = int(area_after_peak_ms * sampling_rate / ms)
    temp_wave_array = np.c_[spike_index-area_before_peak_index, spike_index+area_after_peak_index]
    wave_array = np.array(list(map(MakeWaveShape, temp_wave_array)))

    unuse_peak_index_1 = np.where(wave_array[:,-1] > wave_filtered.size-40)[0]
    unuse_peak_index_2 = np.where(wave_array[:,0] < 0)[0]
    wave_array = np.delete(wave_array, np.r_[unuse_peak_index_1,unuse_peak_index_2], axis=0)
    spike_index = np.delete(spike_index, np.r_[unuse_peak_index_1,unuse_peak_index_2], axis=0)

    spike_shape = wave_filtered[wave_array]
    # print('Wave_Shape.shape :', spike_shape.shape)
    return spike_shape, spike_index

def CutWaveShape(spike_shape, area=13):
#     roi = np.arange(spike_shape.shape[1]/3 - area, spike_shape.shape[1]/3 + 2*area + 1).astype(np.int)
    roi = np.arange(
    spike_shape.shape[1] / 3 - area, 
    spike_shape.shape[1] / 3 + 2 * area + 1
).astype(int)
    return spike_shape[:,roi]

def DimensionalityReductionWithDiff1(features, n_comp):
    features_diff = np.diff(features, n=1)
    pca = PCA(n_components=n_comp)
    X_pca = pca.fit_transform(features_diff)
    return X_pca, pca.explained_variance_ratio_

def DimensionalityReductionWithDiffs(features, n_comp):
    features_diff = np.c_[np.diff(features, n=1), np.diff(features, n=2)]
    pca = PCA(n_components=features.shape[1])
    X_pca = pca.fit_transform(features_diff)
    return X_pca[:,:int(n_comp)], pca.explained_variance_ratio_[:int(n_comp)]


# def ClusteringWithHDBSCAN(spike_feature, clu_size=2500, min_sam=250,cor_num=4, lea_siz=100): 
#     try:
#         clusters = HDBSCAN(min_cluster_size=clu_size, min_samples=min_sam, leaf_size=lea_siz,
#                            cluster_selection_method='leaf',core_dist_n_jobs=cor_num).fit_predict(spike_feature)
#         return clusters
#     except ValueError:
#         print('There was ValueError!! So now using eom!!!')
#         hdbscan = HDBSCAN(min_cluster_size=10, min_samples=100, core_dist_n_jobs=cor_num,allow_single_cluster=True)
#         hdbscan.fit(spike_feature)
#         clusters = hdbscan.labels_
#         clusters[hdbscan.probabilities_ < 0.3] = -1
#         return clusters
#     except:
#         print('Any Error Were Occured!!!')
#         return 0

def ClusteringWithHDBSCAN(spike_feature, clu_size=2000, min_sam=250,cor_num=1, lea_siz=100): 
        try:
            clusters = HDBSCAN(min_cluster_size=clu_size, min_samples=min_sam, leaf_size=lea_siz,
                               cluster_selection_method='leaf',core_dist_n_jobs=cor_num).fit_predict(spike_feature)
#             print("pjifea")
            if(np.unique(clusters).shape[0] == 1):
                raise ValueError
#             print("あ")
            return clusters
        except ValueError:
            try:
                clusters = HDBSCAN(min_cluster_size=320, min_samples=10, leaf_size=lea_siz,
                                        cluster_selection_method='leaf',core_dist_n_jobs=cor_num, allow_single_cluster=True).fit_predict(spike_feature)
                if(np.unique(clusters).shape[0] == 1):
                    raise ValueError
#                 print("い")
                return clusters
            except ValueError:
                print('There was ValueError!! So now using eom!!!')
                hdbscan = HDBSCAN(min_cluster_size=10, min_samples=100, core_dist_n_jobs=cor_num,allow_single_cluster=True)
                hdbscan.fit(spike_feature)
                clusters = hdbscan.labels_
                clusters[hdbscan.probabilities_ < 0.3] = -1

                return clusters
            except:
                print('Any Error Were Occured!!!')
                return clusters

        except:
            print('Any Error Were Occured!!!')

            return 0

from scipy import signal
import os


def CalcACR(spike_time):
    #変数の定義
    window_auto = 1000
    binWidth_auto = 1 #[ms]
    #格納先の作成
    bin_num = int(((window_auto * 2)/binWidth_auto) + 1)
    hist_auto = np.zeros(bin_num)
    isi_size = spike_time.size
    ##ex) spike_time_all = [1,3,10,100,2100]
    for mid_search in spike_time:
        ##1週目
        #mid_serch = 1
        # left_end = 1 - 1000 = -999 (ms)
        left_end = mid_search - window_auto
        #right_end = 1 + 1000 = 1001 (ms)
        right_end = mid_search + window_auto
        # -999(ms) ~  1001(ms)の間に存在するスパイク発火のインデックスをindex_serchに格納
        #該当するスパイク　 1, 3, 10, 100 　＊ spike_time_all内のスパイク発火時間
        #index_search = 0,1,2,3
        index_search = np.where((spike_time >= left_end)&(spike_time <= right_end))[0]
        #spike_time_all[index_search] = 1,3,10
        #temp_spike_time - 1 = 0, 2, 9
        temp_spike_time = spike_time[index_search] - mid_search
        #ヒストグラム書く
        hist_auto += np.histogram(temp_spike_time, bins=bin_num, range=(-window_auto,window_auto))[0]

    return hist_auto  #Norm_Hist_Auto
##############################################

def CalcPOW(acr,ex_file_path=''):
    sampling_rate = 1000
    freq, P = signal.periodogram(x=acr,fs=sampling_rate)
    roi_x_pow = np.array([0,80])
    
    ##オシレーションインデックス算出
    #Search_Area = np.array([5,15]) #5~10Hz
    #PWR_ROI_Index = np.where((freq >= Search_Area[0]) & (PWR[:,0] < Search_Area[1]))[0]
    #PWR_ROI = (np.max(PWR[PWR_ROI_Index,1]) - np.mean(PWR[PWR_ROI_Index,1]))/np.std(PWR[:,1])
    #return np.max(PWR[PWR_ROI_Index,1]), PWR[np.argmax(PWR[PWR_ROI_Index,1])+PWR_ROI_Index[0],0], np.mean(PWR[:,1]), np.std(PWR[:,1]), PWR_ROI
    
    # figure_pow = plt.figure()
    # pow_pointer = figure_pow.add_subplot(1,1,1)
    # pow_pointer.plot(freq,P,color='black')
    # pow_pointer.set_xlim(roi_x_pow)
    # pow_pointer.set_xlabel('Frequency[Hz]')
    # pow_pointer.set_ylabel('Power/frequency')
    # # cell_id = os.path.split(ex_file_path)[1]
    # pow_pointer.set_title(cell_id)
    # #plt.ylim(0, 0.0007)

    # plt.show()
    # plt.clf()
    # plt.close('all')
    
    return np.c_[freq,P]

def JudgeAcr(xAxis, acr):
    #クラスターが自己相関の基準を突破できるか判定
    AllIndex = np.where((xAxis >= -200) & (xAxis <= 200))[0] #-200 ms ~ 200 msのインデックスを取得
    SearchIndex =[998, 999, 1001, 1002] #-2,-1, 1,2 msのインデックスを取得

    AllACR = acr[AllIndex]
    SearchACR = acr[SearchIndex]

    # 基準値算出
    FireIndex = np.sum(SearchACR)/np.sum(AllACR)* 100 
    return FireIndex

def CalcCCR(spike_time1, spike_time2):
    #変数の定義
    #1000 ms
    window_auto = 1000
    #瓶サイズ
    binWidth_auto = 1 #[ms]
    #格納先の作成
    #瓶の数 
    #2000 / 1 + 1 = 2001
    bin_num = int(((window_auto * 2)/binWidth_auto) + 1)
    #2001個の0をhist_autoに格納
    hist_auto = np.zeros(bin_num)
    #全てのスパイク時間

############################################################################################
##変更点1： 
##①spike_time_all ⇨　cell1_spike_time_all，　
##②さらにcell2_spike_time_all = isi[:,2]を追加　
##
############################################################################################
    #スパイク発火の数
    isi_size = spike_time1.size
    ##ex) spike_time_all = [1,3,10,100,2100]
    for mid_search in spike_time1:
        ##1週目
        #mid_serch = 1
        # left_end = 1 - 1000 = -999 (ms)
        left_end = mid_search - window_auto
        #right_end = 1 + 1000 = 1001 (ms)
        right_end = mid_search + window_auto
        # -999(ms) ~  1001(ms)の間に存在するスパイク発火のインデックスをindex_serchに格納
        #該当するスパイク　 1, 3, 10, 100 　＊ spike_time_all内のスパイク発火時間
############################################################################################
##変更点2：         i
##①ndex_search = np.where((spike_time_all >= left_end)&(spike_time_all <= right_end))[0]
##このspike_time_all ⇨　cell1_spike_time_all，に変更
##② temp_spike_time = spike_time_all[index_search] - mid_search
##このspike_time_all ⇨　cell1_spike_time_all，に変更
############################################################################################
        #index_search = 0,1,2,3
        index_search = np.where((spike_time2 >= left_end)&(spike_time2 <= right_end))[0]
        #spike_time_all[index_search] = 1,3,10
        #temp_spike_time - 1 = 0, 2, 9
        temp_spike_time = spike_time2[index_search] - mid_search
        #ヒストグラム書く
        hist_auto += np.histogram(temp_spike_time, bins=bin_num, range=(-window_auto,window_auto))[0]
    return hist_auto #Norm_Hist_Auto
##############################################

def JudgeCcr(isi, PairList):
    NewCluNosAfterAcr = isi[:, 3]
    for pair in list(itertools.permutations((NewCluNosAfterAcr), 2)):
        if pair in PairList:
            continue
        CluNo1, CluNo2 = pair
        print(CluNo1, CluNo2)
        TempClusters = isi[:, 3]
        Tempindex1 = np.where(TempClusters == CluNo1)[0]
        Tempindex2 = np.where(TempClusters == CluNo2)[0]

        TempSpikeTime1 = isi[Tempindex1, 1]
        TempSpikeTime2 = isi[Tempindex2, 1]

        ccr = CalcCCR(TempSpikeTime1, TempSpikeTime2)

        #オートコレロをもとにした基準値を算出
        FireIndex = JudgeAcr(xAxis, ccr)
        print(FireIndex)

        if FireIndex < 1:
            NewCluNo = np.max(CluNosAfterAcr) + 1
            NewCluIndex = np.r_[Tempindex1, Tempindex2]
            np.put(TempClusters, NewCluIndex, NewCluNo)
            break
    return isi, PairList

def GetTemplates(waves):
    template = np.mean(waves, axis=0)
    template_sd = np.std(waves, axis=0)
    return np.array([template, template_sd])

def GetWaves(clu, result, wave_shape):
    temp_index = np.where(result==clu)[0]
    return wave_shape[temp_index,:]

# クラスタリング結果を元に波形からテンプレートを作成
def MakeTemplates(clu, result, wave_shape):
    return GetTemplates(GetWaves(clu, result, wave_shape))

def CheckTemplate(template, wave):
    temp_late, temp_late_sd = template
    temp_lower = wave > (temp_late-temp_late_sd)
    temp_upper = wave < (temp_late+temp_late_sd)
    temp_index = np.sum([temp_upper.T[11:14],  temp_lower.T[11:14]])
    if(temp_index == 6):
        #return np.array([np.sum([temp_lower, temp_upper]), temp_index]).astype(np.int)
        return np.array([np.sum([temp_lower, temp_upper]), temp_index]).astype(int)

    else:
        #return np.array([0, temp_index]).astype(np.int)
        return np.array([0, temp_index]).astype(int)


# マージ結果を元にクラスタを更新
def ChangeCluster(cluster, marges):
    new_cluster = cluster.copy()
    main_clus = marges[0]
    cluss = np.unique(cluster)
    for clus in np.unique(marges[1:]):
        new_cluster[cluster==clus] =  main_clus
    return new_cluster    

def MargeCluster_TM(cluster, wave_shape, thr_marge=115):
    #default value of thr_marge is set for waveform with 60 points.
    new_cluster = cluster.copy()
    clus_list = np.unique(cluster)[1:]
    templates = np.array(list(map(functools.partial(MakeTemplates, result=cluster, wave_shape=wave_shape), clus_list)))
    clus_score = np.zeros([clus_list.shape[0],clus_list.shape[0]])
    for i in range(clus_list.shape[0]):
        clus_score[:,i] = np.array(list(map(functools.partial(CheckTemplate, wave=templates[i,0]),templates)))[:,0]
        clus_score[:i+1,i] = 0
    marges = np.where(clus_score.flatten() >= thr_marge)[0]
    marges = np.c_[marges%clus_list.shape[0], marges//clus_list.shape[0]]
    if marges.shape[0] >= 1:
        for marge in reversed(marges):
            # print('marge clusters : ', marge)
            new_cluster = ChangeCluster(new_cluster, marge)
    # print(np.unique(new_cluster))
    return new_cluster

def ReclustNoise(noise_wave, templates, thr_socre=72):
    clu = -1
    clus_score = np.array(list(map(functools.partial(CheckTemplate, wave=noise_wave),templates)))
    max_index = np.argmax(clus_score[:,0])
    if((clus_score[max_index,0] > 72) & (clus_score[max_index,1] == 6)):
        clu = max_index
    return clu

# noiseに分類されたspikeをtemplate-matchingによって救済
def RescueNoise(cluster, wave_shape, thr_noise=72):
    new_cluster = cluster.copy()
    ori_clus = np.roll(np.unique(new_cluster), -1)
    noise_index = np.where(cluster==-1)[0]
    templates = np.array(list(map(functools.partial(MakeTemplates, result=cluster, wave_shape=wave_shape), np.unique(cluster)[1:])))
    noise_waves = GetWaves(-1, cluster, wave_shape)
    noise_reclust = np.array(list(map(functools.partial(ReclustNoise, templates = templates), noise_waves)))
    new_cluster[noise_index] = ori_clus[noise_reclust]
    return new_cluster

# raw dataに格納されている情報の列方向の順番
# CH_ARRAY = np.array(['trigger', 'ch47', 'ch48', 'ch46', 'ch45', 'ch38', 'ch37', 'ch28',
# 'ch36', 'ch27', 'ch17', 'ch26', 'ch16', 'ch35', 'ch25', 'ch15', 'ch14', 'ch24', 'ch34',
# 'ch13', 'ch23', 'ch12', 'ch22', 'ch33', 'ch21', 'ch32', 'ch31', 'ch44', 'ch43', 'ch41',
# 'ch42', 'ch52', 'ch51', 'ch53', 'ch54', 'ch61', 'ch62', 'ch71', 'ch63', 'ch72', 'ch82',
# 'ch73', 'ch83', 'ch64', 'ch74', 'ch84', 'ch85', 'ch75', 'ch65', 'ch86', 'ch76', 'ch87',
# 'ch77', 'ch66', 'ch78', 'ch67', 'ch68', 'ch55', 'ch56', 'ch58', 'ch57'])

CH_ARRAY = np.array(['trigger', 'ch21', 'ch31', 'ch41', 'ch51', 'ch61', 'ch71', 'ch12',
'ch22', 'ch32', 'ch42', 'ch52', 'ch62', 'ch72', 'ch82', 'ch13', 'ch23', 'ch33', 'ch43',
'ch53', 'ch63', 'ch73', 'ch83', 'ch14', 'ch24', 'ch34', 'ch44', 'ch54', 'ch64', 'ch74',
'ch84', 'ch15', 'ch25', 'ch35', 'ch45', 'ch55', 'ch65', 'ch75', 'ch85', 'ch16', 'ch26',
'ch36', 'ch46', 'ch56', 'ch66', 'ch76', 'ch86', 'ch17', 'ch27', 'ch37', 'ch47', 'ch57',
'ch67', 'ch77', 'ch87', 'ch28', 'ch38', 'ch48', 'ch58', 'ch68', 'ch78'])

## 可視化する時のファイル形式
PIC_EXT = '.png'

## 可視化する際のグラフの線の色
COLOR = ['b', 'chartreuse', 'r', 'c', 'm', 'y', 'k', 'Brown', 'ForestGreen', 'darkcyan', 'maroon', 'orange', 'green', 'steelblue', 'purple', 'gold', 'navy', 'gray', 'indigo', 'black', 'darkgoldenrod']

## secondからmili secondへの変換
MS = 1000

# 関数化

In [3]:
import os

raw_path = "C:/Users/Imaris/Desktop/watanabe/250801/リポジトリ/test2ch.raw"
print("raw_path", raw_path)

# 変数
data_name = os.path.basename(raw_path).split(".")[0]
channels = ["ch21", "ch55"]
channel_num = len(channels)


raw_path C:/Users/Imaris/Desktop/watanabe/250801/リポジトリ/test2ch.raw


In [4]:
import numpy as np
from typing import Optional

def load_raw_matrix(path: "path", n_channels: int, dtype="h", scale=10.0) -> np.ndarray:
    """RAW -> (n_samples, n_channels) の 2 次元配列に変換"""
    data = np.fromfile(path, dtype=dtype)
    if data.size % n_channels:
        raise ValueError(f"RAW size {data.size} not divisible by n_channels={n_channels}")
    return data.reshape(-1, n_channels).astype(np.float32) / scale


def extract_channel(
    raw_matrix: np.ndarray,
    ch: int,
    ch_array:  Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    raw_matrix: load_raw_matrix で得た 2D 配列
    ch        : 取り出したいチャンネル番号 (0-based)
    ch_array  : 旧コードと同じ 'ch47', 'ch48'... などの順序を保持している配列（任意）
    """
    if raw_matrix.ndim != 2:
        raise ValueError("raw_matrix must be 2D")

    if raw_matrix.shape[1] == 1:
        # 1 ch RAW の場合はそのまま返す
        return raw_matrix[:, 0]

    if ch_array is None:
        # RAW が単純な 0..N-1 の並びなら ch を直接使う
        if ch < 0 or ch >= raw_matrix.shape[1]:
            raise IndexError(f"channel {ch} out of range (0..{raw_matrix.shape[1]-1})")
        return raw_matrix[:, ch]

    # 旧ノートのように CH_ARRAY で並びを決めている場合
    mask = ch_array == ch
    if not np.any(mask):
        raise ValueError(f"{ch} は CH_ARRAY に存在しません")
    return raw_matrix[:, mask][:, 0]


In [5]:
# #入力データの準備
# raw_matrix = load_raw_matrix(raw_path, channel_num)
# wave = extract_channel(raw_matrix, ch)


In [6]:
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from sklearn.preprocessing import StandardScaler
import itertools
import matplotlib.pyplot as plt



@dataclass
class ChannelMeta:
    data_name: str
    channel_label: str  # e.g. 'ch000'
    figure_dirs: dict[str, Path]  # 'spike_detect', 'pca', 'sorting_cluster', 'auto_correlo'
    h5_path: Path
    log_path: Path
    sampling_rate: int


def process_channel(raw_wave: np.ndarray, cfg) -> dict:
    """
    1 チャネル分の波形を受け取り、フィルタ〜クラスタリングまで処理する。
    戻り値は後段の可視化/保存で使う dict。
    """
    filtered = BandPassFilter(
        raw_wave,
        bottom=cfg.band_bottom,
        top=cfg.band_top,
        sampling_rate=cfg.fs,
    )
    spike_idx = SpikeDetection(
        filtered,
        sd_thr=cfg.spike_threshold_sd,
        order=cfg.spike_order,
        spike=cfg.spike_polarity,
    )

    waveforms, spike_idx = GetWaveShape(
        spike_idx,
        filtered,
        area_before_peak_ms=cfg.window_before_ms,
        area_after_peak_ms=cfg.window_after_ms,
        sampling_rate=cfg.fs,
        ms=MS,
    )
    if spike_idx.size == 0:
        return {
            "filtered": filtered,
            "spike_idx": spike_idx,
            "waveforms": np.empty((0, 0)),
            "waveforms_roi": np.empty((0, 0)),
            "features": np.empty((0, cfg.pca_components)),
            "variance": np.zeros(cfg.pca_components),
            "labels": np.array([], dtype=int),
            "spike_times_ms": np.array([]),
        }

    waveforms_roi = CutWaveShape(waveforms, area=cfg.cut_area)
    x_pca, variance = DimensionalityReductionWithDiffs(waveforms_roi, cfg.pca_components)
    features = StandardScaler().fit_transform(x_pca)

    clusters = ClusteringWithHDBSCAN(features, clu_size=cfg.cluster_min_size, min_sam=cfg.cluster_min_samples)
    merged = MargeCluster_TM(cluster=clusters, wave_shape=waveforms, thr_marge=cfg.template_merge_score)
    refined = RescueNoise(
        cluster=merged,
        wave_shape=CutWaveShape(waveforms, area=cfg.cut_area),
        thr_noise=cfg.noise_reassign_score,
    )

    spike_times_ms = spike_idx / (cfg.fs / MS)
    isi = np.c_[np.arange(1, spike_times_ms.size + 1), spike_times_ms, np.diff(np.r_[0, spike_times_ms]), refined]

    return {
        "filtered": filtered,
        "spike_idx": spike_idx,
        "waveforms": waveforms,
        "waveforms_roi": waveforms_roi,
        "features": features,
        "variance": variance,
        "labels": refined,
        "spike_times_ms": spike_times_ms,
        "isi": isi,
    }


def visualize_and_save(meta: ChannelMeta, result: dict) -> None:
    """
    process_channel の戻り値を受け取り、図を保存しつつ HDF5 とログに結果をまとめる。
    """
    filtered = result["filtered"]
    spike_idx = result["spike_idx"]
    waveforms = result["waveforms"]
    waveforms_roi = result["waveforms_roi"]
    features = result["features"]
    labels = result["labels"]
    spike_times_ms = result["spike_times_ms"]
    variance = result["variance"]

    total_spikes = int(spike_idx.size)
    if labels.size:
        unique_labels, counts = np.unique(labels, return_counts=True)
    else:
        unique_labels = np.array([], dtype=int)
        counts = np.array([], dtype=int)
    positive_labels = unique_labels[unique_labels >= 0]

    log_lines = [
        f"[{meta.channel_label}]",
        f"total_spikes: {total_spikes}",
        f"detected_clusters: {int(positive_labels.size)}",
    ]
    if variance.size and total_spikes:
        log_lines.append("pca_variance: " + ", ".join(f"{v:.4f}" for v in np.atleast_1d(variance)))
    if unique_labels.size:
        log_lines.append("cluster_counts:")
        for clu, count in zip(unique_labels, counts):
            label_name = f"cluster {int(clu)}" if clu >= 0 else "noise"
            log_lines.append(f"  {label_name}: {int(count)}")
    else:
        log_lines.append("cluster_counts: none")
    acr_logs = []

    # 生波形の可視化
    plt.figure(figsize=(10, 3))
    t = np.arange(filtered.size) / meta.sampling_rate
    plt.plot(t, filtered, lw=0.5, color="steelblue")
    if spike_idx.size:
        plt.plot(spike_idx / meta.sampling_rate, filtered[spike_idx], "r.", ms=3)
    plt.xlabel("Time [s]")
    plt.ylabel("Filtered (uV)")
    plt.tight_layout()
    plt.savefig(meta.figure_dirs["spike_detect"] / f"{meta.data_name}_{meta.channel_label}_spike_detect.png")
    plt.close()

    # PCA 図
    if features.size:
        plt.figure(figsize=(5, 4))
        plt.scatter(features[:, 0], features[:, 1], s=5, c="gray", alpha=0.5)
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title(f"PCA scatter {meta.data_name}_{meta.channel_label}")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["pca"] / f"{meta.data_name}_{meta.channel_label}_pca_raw.png")
        plt.close()

        plt.figure(figsize=(5, 4))
        for clu, color in zip(unique_labels, itertools.cycle(COLOR)):
            mask = labels == clu
            if not mask.any():
                continue
            plt.scatter(features[mask, 0], features[mask, 1], s=8, alpha=0.7, label=f"clu {clu}", c=color)
        plt.legend(fontsize=8)
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title(f"PCA clustered {meta.data_name}_{meta.channel_label}")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["pca"] / f"{meta.data_name}_{meta.channel_label}_pca_cluster.png")
        plt.close()

    # クラスタ波形
    for clu, color in zip(unique_labels, itertools.cycle(COLOR)):
        mask = labels == clu
        if clu < 0 or not mask.any():
            continue
        plt.figure(figsize=(6, 3))
        plt.plot(waveforms[mask].T, color=color, alpha=0.2, lw=0.5)
        plt.plot(np.median(waveforms[mask], axis=0), color=color, lw=2)
        plt.title(f"{meta.channel_label} cluster {clu} (n={mask.sum()})")
        plt.xlabel("Samples")
        plt.ylabel("uV")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["sorting_cluster"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_waveforms.png")
        plt.close()

    # ACR とパワースペクトル
    x_axis = np.arange(-1000, 1001)
    for clu in unique_labels:
        mask = labels == clu
        if clu < 0 or not mask.any():
            continue
        acr = CalcACR(spike_times_ms[mask])
        fire_index = JudgeAcr(x_axis, acr)
        acr_logs.append(
            f"  cluster {int(clu)}: n={int(mask.sum())}, fire_index={fire_index:.3f}% ({'PASS' if fire_index <= 1.0 else 'FAIL'})"
        )

        plt.figure(figsize=(5, 3))
        plt.plot(x_axis, acr, color="black", lw=1)
        plt.xlim(-200, 200)
        plt.xlabel("Time lag [ms]")
        plt.ylabel("Autocorrelation")
        plt.title(f"{meta.channel_label} cluster {clu} autocorrelogram")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["auto_correlo"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_acr.png")
        plt.close()

        freqP = CalcPOW(acr, ex_file_path="")
        plt.figure(figsize=(5, 3))
        plt.plot(freqP[:, 0], freqP[:, 1], color="black", lw=1)
        plt.xlim(0, 80)
        plt.xlabel("Frequency [Hz]")
        plt.ylabel("Power/frequency")
        plt.title(f"{meta.channel_label} cluster {clu} power spectrum")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["auto_correlo"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_power.png")
        plt.close()

    # HDF5 保存
    datasets = {
        "labels": labels,
        "spike_times_ms": spike_times_ms,
        "waveforms": waveforms,
        "waveforms_roi": waveforms_roi,
        "variance": variance,
    }
    with h5py.File(meta.h5_path, "a") as h5:
        group = h5.require_group(meta.channel_label)
        for name, data in datasets.items():
            if name in group:
                del group[name]
            if data.size:
                group.create_dataset(name, data=data, compression="gzip", compression_opts=4)
            else:
                group.create_dataset(name, shape=data.shape, dtype=data.dtype)

    if acr_logs:
        log_lines.append("acr_tests:")
        log_lines.extend(acr_logs)
    else:
        log_lines.append("acr_tests: n/a")

    with meta.log_path.open("a", encoding="utf-8") as fp:
        fp.write("\n".join(log_lines) + "\n\n")



In [7]:
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from sklearn.preprocessing import StandardScaler
import itertools
import matplotlib.pyplot as plt



@dataclass
class ChannelMeta:
    data_name: str
    channel_label: str  # e.g. 'ch000'
    figure_dirs: dict[str, Path]  # 'spike_detect', 'pca', 'sorting_cluster', 'auto_correlo'
    h5_path: Path
    log_path: Path
    sampling_rate: int


def process_channel(raw_wave: np.ndarray, cfg) -> dict:
    """
    1 チャンネル分の波形からクラスタリング結果までをまとめて計算。
    戻り値は可視化・保存で使う情報を dict で返す。
    """
    filtered = BandPassFilter(
        raw_wave,
        bottom=cfg.band_bottom,
        top=cfg.band_top,
        sampling_rate=cfg.fs,
    )
    spike_idx = SpikeDetection(
        filtered,
        sd_thr=cfg.spike_threshold_sd,
        order=cfg.spike_order,
        spike=cfg.spike_polarity,
    )

    waveforms, spike_idx = GetWaveShape(
        spike_idx,
        filtered,
        area_before_peak_ms=cfg.window_before_ms,
        area_after_peak_ms=cfg.window_after_ms,
        sampling_rate=cfg.fs,
        ms=MS,
    )
    if spike_idx.size == 0:
        return {
            "filtered": filtered,
            "spike_idx": spike_idx,
            "waveforms": np.empty((0, 0)),
            "waveforms_roi": np.empty((0, 0)),
            "features": np.empty((0, cfg.pca_components)),
            "variance": np.zeros(cfg.pca_components),
            "labels": np.array([], dtype=int),
            "spike_times_ms": np.array([]),
        }

    waveforms_roi = CutWaveShape(waveforms, area=cfg.cut_area)
    x_pca, variance = DimensionalityReductionWithDiffs(waveforms_roi, cfg.pca_components)
    features = StandardScaler().fit_transform(x_pca)

    clusters = ClusteringWithHDBSCAN(features, clu_size=cfg.cluster_min_size, min_sam=cfg.cluster_min_samples)
    merged = MargeCluster_TM(cluster=clusters, wave_shape=waveforms, thr_marge=cfg.template_merge_score)
    refined = RescueNoise(
        cluster=merged,
        wave_shape=CutWaveShape(waveforms, area=cfg.cut_area),
        thr_noise=cfg.noise_reassign_score,
    )

    spike_times_ms = spike_idx / (cfg.fs / MS)
    isi = np.c_[np.arange(1, spike_times_ms.size + 1), spike_times_ms, np.diff(np.r_[0, spike_times_ms]), refined]

    return {
        "filtered": filtered,
        "spike_idx": spike_idx,
        "waveforms": waveforms,
        "waveforms_roi": waveforms_roi,
        "features": features,
        "variance": variance,
        "labels": refined,
        "spike_times_ms": spike_times_ms,
        "isi": isi,
    }


def visualize_and_save(meta: ChannelMeta, result: dict) -> None:
    """
    Receive process_channel output, produce plots, and persist results in HDF5.
    """
    filtered = result["filtered"]
    spike_idx = result["spike_idx"]
    waveforms = result["waveforms"]
    waveforms_roi = result["waveforms_roi"]
    features = result["features"]
    labels = result["labels"]
    spike_times_ms = result["spike_times_ms"]
    variance = result["variance"]

    # 検出波形
    plt.figure(figsize=(10, 3))
    t = np.arange(filtered.size) / meta.sampling_rate
    plt.plot(t, filtered, lw=0.5, color="steelblue")
    if spike_idx.size:
        plt.plot(spike_idx / meta.sampling_rate, filtered[spike_idx], "r.", ms=3)
    plt.xlabel("Time [s]")
    plt.ylabel("Filtered (uV)")
    plt.tight_layout()
    plt.savefig(meta.figure_dirs["spike_detect"] / f"{meta.data_name}_{meta.channel_label}_spike_detect.png")
    plt.close()

    # PCA 図
    if features.size:
        plt.figure(figsize=(5, 4))
        plt.scatter(features[:, 0], features[:, 1], s=5, c="gray", alpha=0.5)
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title(f"PCA scatter {meta.data_name}_{meta.channel_label}")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["pca"] / f"{meta.data_name}_{meta.channel_label}_pca_raw.png")
        plt.close()

        plt.figure(figsize=(5, 4))
        for clu, color in zip(np.unique(labels), itertools.cycle(COLOR)):
            mask = labels == clu
            plt.scatter(features[mask, 0], features[mask, 1], s=8, alpha=0.7, label=f"clu {clu}", c=color)
        plt.legend(fontsize=8)
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title(f"PCA clustered {meta.data_name}_{meta.channel_label}")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["pca"] / f"{meta.data_name}_{meta.channel_label}_pca_cluster.png")
        plt.close()

    # クラスタ波形
    for clu, color in zip(np.unique(labels), itertools.cycle(COLOR)):
        mask = labels == clu
        if clu < 0 or not mask.any():
            continue
        plt.figure(figsize=(6, 3))
        plt.plot(waveforms[mask].T, color=color, alpha=0.2, lw=0.5)
        plt.plot(np.median(waveforms[mask], axis=0), color=color, lw=2)
        plt.title(f"{meta.channel_label} cluster {clu} (n={mask.sum()})")
        plt.xlabel("Samples")
        plt.ylabel("uV")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["sorting_cluster"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_waveforms.png")
        plt.close()

    # ACR とパワースペクトラム
    x_axis = np.arange(-1000, 1001)
    for clu in np.unique(labels):
        mask = labels == clu
        if clu < 0 or not mask.any():
            continue
        acr = CalcACR(spike_times_ms[mask])

        plt.figure(figsize=(5, 3))
        plt.plot(x_axis, acr, color="black", lw=1)
        plt.xlim(-200, 200)
        plt.xlabel("Time lag [ms]")
        plt.ylabel("Autocorrelation")
        plt.title(f"{meta.channel_label} cluster {clu} autocorrelogram")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["auto_correlo"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_acr.png")
        plt.close()

        freqP = CalcPOW(acr, ex_file_path="")
        plt.figure(figsize=(5, 3))
        plt.plot(freqP[:, 0], freqP[:, 1], color="black", lw=1)
        plt.xlim(0, 80)
        plt.xlabel("Frequency [Hz]")
        plt.ylabel("Power/frequency")
        plt.title(f"{meta.channel_label} cluster {clu} power spectrum")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["auto_correlo"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_power.png")
        plt.close()

    # HDF5 保存
    datasets = {
        "labels": labels,
        "spike_times_ms": spike_times_ms,
        "waveforms": waveforms,
        "waveforms_roi": waveforms_roi,
        "variance": variance,
    }
    with h5py.File(meta.h5_path, "a") as h5:
        group = h5.require_group(meta.channel_label)
        for name, data in datasets.items():
            if name in group:
                del group[name]
            if data.size:
                group.create_dataset(name, data=data, compression="gzip", compression_opts=4)
            else:
                group.create_dataset(name, shape=data.shape, dtype=data.dtype)


In [8]:
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from sklearn.preprocessing import StandardScaler
import itertools
import matplotlib.pyplot as plt



@dataclass
class ChannelMeta:
    data_name: str
    channel_label: str  # e.g. 'ch000'
    figure_dirs: dict[str, Path]  # 'spike_detect', 'pca', 'sorting_cluster', 'auto_correlo'
    h5_path: Path
    log_path: Path
    sampling_rate: int


def process_channel(raw_wave: np.ndarray, cfg) -> dict:
    """
    1 チャンネル分の波形からクラスタリング結果までをまとめて計算。
    戻り値は可視化・保存で使う情報を dict で返す。
    """
    filtered = BandPassFilter(
        raw_wave,
        bottom=cfg.band_bottom,
        top=cfg.band_top,
        sampling_rate=cfg.fs,
    )
    spike_idx = SpikeDetection(
        filtered,
        sd_thr=cfg.spike_threshold_sd,
        order=cfg.spike_order,
        spike=cfg.spike_polarity,
    )

    waveforms, spike_idx = GetWaveShape(
        spike_idx,
        filtered,
        area_before_peak_ms=cfg.window_before_ms,
        area_after_peak_ms=cfg.window_after_ms,
        sampling_rate=cfg.fs,
        ms=MS,
    )
    if spike_idx.size == 0:
        return {
            "filtered": filtered,
            "spike_idx": spike_idx,
            "waveforms": np.empty((0, 0)),
            "waveforms_roi": np.empty((0, 0)),
            "features": np.empty((0, cfg.pca_components)),
            "variance": np.zeros(cfg.pca_components),
            "labels": np.array([], dtype=int),
            "spike_times_ms": np.array([]),
        }

    waveforms_roi = CutWaveShape(waveforms, area=cfg.cut_area)
    x_pca, variance = DimensionalityReductionWithDiffs(waveforms_roi, cfg.pca_components)
    features = StandardScaler().fit_transform(x_pca)

    clusters = ClusteringWithHDBSCAN(features, clu_size=cfg.cluster_min_size, min_sam=cfg.cluster_min_samples)
    merged = MargeCluster_TM(cluster=clusters, wave_shape=waveforms, thr_marge=cfg.template_merge_score)
    refined = RescueNoise(
        cluster=merged,
        wave_shape=CutWaveShape(waveforms, area=cfg.cut_area),
        thr_noise=cfg.noise_reassign_score,
    )

    spike_times_ms = spike_idx / (cfg.fs / MS)
    isi = np.c_[np.arange(1, spike_times_ms.size + 1), spike_times_ms, np.diff(np.r_[0, spike_times_ms]), refined]

    return {
        "filtered": filtered,
        "spike_idx": spike_idx,
        "waveforms": waveforms,
        "waveforms_roi": waveforms_roi,
        "features": features,
        "variance": variance,
        "labels": refined,
        "spike_times_ms": spike_times_ms,
        "isi": isi,
    }


def visualize_and_save(meta: ChannelMeta, result: dict) -> None:
    """
    Receive process_channel output, produce plots, and persist results in HDF5.
    """
    filtered = result["filtered"]
    spike_idx = result["spike_idx"]
    waveforms = result["waveforms"]
    waveforms_roi = result["waveforms_roi"]
    features = result["features"]
    labels = result["labels"]
    spike_times_ms = result["spike_times_ms"]
    variance = result["variance"]

    # 検出波形
    plt.figure(figsize=(10, 3))
    t = np.arange(filtered.size) / meta.sampling_rate
    plt.plot(t, filtered, lw=0.5, color="steelblue")
    if spike_idx.size:
        plt.plot(spike_idx / meta.sampling_rate, filtered[spike_idx], "r.", ms=3)
    plt.xlabel("Time [s]")
    plt.ylabel("Filtered (uV)")
    plt.tight_layout()
    plt.savefig(meta.figure_dirs["spike_detect"] / f"{meta.data_name}_{meta.channel_label}_spike_detect.png")
    plt.close()

    # PCA 図
    if features.size:
        plt.figure(figsize=(5, 4))
        plt.scatter(features[:, 0], features[:, 1], s=5, c="gray", alpha=0.5)
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title(f"PCA scatter {meta.data_name}_{meta.channel_label}")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["pca"] / f"{meta.data_name}_{meta.channel_label}_pca_raw.png")
        plt.close()

        plt.figure(figsize=(5, 4))
        for clu, color in zip(np.unique(labels), itertools.cycle(COLOR)):
            mask = labels == clu
            plt.scatter(features[mask, 0], features[mask, 1], s=8, alpha=0.7, label=f"clu {clu}", c=color)
        plt.legend(fontsize=8)
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title(f"PCA clustered {meta.data_name}_{meta.channel_label}")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["pca"] / f"{meta.data_name}_{meta.channel_label}_pca_cluster.png")
        plt.close()

    # クラスタ波形
    for clu, color in zip(np.unique(labels), itertools.cycle(COLOR)):
        mask = labels == clu
        if clu < 0 or not mask.any():
            continue
        plt.figure(figsize=(6, 3))
        plt.plot(waveforms[mask].T, color=color, alpha=0.2, lw=0.5)
        plt.plot(np.median(waveforms[mask], axis=0), color=color, lw=2)
        plt.title(f"{meta.channel_label} cluster {clu} (n={mask.sum()})")
        plt.xlabel("Samples")
        plt.ylabel("uV")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["sorting_cluster"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_waveforms.png")
        plt.close()

    # ACR とパワースペクトラム
    x_axis = np.arange(-1000, 1001)
    for clu in np.unique(labels):
        mask = labels == clu
        if clu < 0 or not mask.any():
            continue
        acr = CalcACR(spike_times_ms[mask])

        plt.figure(figsize=(5, 3))
        plt.plot(x_axis, acr, color="black", lw=1)
        plt.xlim(-200, 200)
        plt.xlabel("Time lag [ms]")
        plt.ylabel("Autocorrelation")
        plt.title(f"{meta.channel_label} cluster {clu} autocorrelogram")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["auto_correlo"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_acr.png")
        plt.close()

        freqP = CalcPOW(acr, ex_file_path="")
        plt.figure(figsize=(5, 3))
        plt.plot(freqP[:, 0], freqP[:, 1], color="black", lw=1)
        plt.xlim(0, 80)
        plt.xlabel("Frequency [Hz]")
        plt.ylabel("Power/frequency")
        plt.title(f"{meta.channel_label} cluster {clu} power spectrum")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["auto_correlo"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_power.png")
        plt.close()

    # HDF5 保存
    datasets = {
        "labels": labels,
        "spike_times_ms": spike_times_ms,
        "waveforms": waveforms,
        "waveforms_roi": waveforms_roi,
        "variance": variance,
    }
    with h5py.File(meta.h5_path, "a") as h5:
        group = h5.require_group(meta.channel_label)
        for name, data in datasets.items():
            if name in group:
                del group[name]
            if data.size:
                group.create_dataset(name, data=data, compression="gzip", compression_opts=4)
            else:
                group.create_dataset(name, shape=data.shape, dtype=data.dtype)


In [9]:
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from sklearn.preprocessing import StandardScaler
import itertools
import matplotlib.pyplot as plt



@dataclass
class ChannelMeta:
    data_name: str
    channel_label: str  # e.g. 'ch000'
    figure_dirs: dict[str, Path]  # 'spike_detect', 'pca', 'sorting_cluster', 'auto_correlo'
    h5_path: Path
    log_path: Path
    sampling_rate: int


def process_channel(raw_wave: np.ndarray, cfg) -> dict:
    """
    1 チャンネル分の波形からクラスタリング結果までをまとめて計算。
    戻り値は可視化・保存で使う情報を dict で返す。
    """
    filtered = BandPassFilter(
        raw_wave,
        bottom=cfg.band_bottom,
        top=cfg.band_top,
        sampling_rate=cfg.fs,
    )
    spike_idx = SpikeDetection(
        filtered,
        sd_thr=cfg.spike_threshold_sd,
        order=cfg.spike_order,
        spike=cfg.spike_polarity,
    )

    waveforms, spike_idx = GetWaveShape(
        spike_idx,
        filtered,
        area_before_peak_ms=cfg.window_before_ms,
        area_after_peak_ms=cfg.window_after_ms,
        sampling_rate=cfg.fs,
        ms=MS,
    )
    if spike_idx.size == 0:
        return {
            "filtered": filtered,
            "spike_idx": spike_idx,
            "waveforms": np.empty((0, 0)),
            "waveforms_roi": np.empty((0, 0)),
            "features": np.empty((0, cfg.pca_components)),
            "variance": np.zeros(cfg.pca_components),
            "labels": np.array([], dtype=int),
            "spike_times_ms": np.array([]),
        }

    waveforms_roi = CutWaveShape(waveforms, area=cfg.cut_area)
    x_pca, variance = DimensionalityReductionWithDiffs(waveforms_roi, cfg.pca_components)
    features = StandardScaler().fit_transform(x_pca)

    clusters = ClusteringWithHDBSCAN(features, clu_size=cfg.cluster_min_size, min_sam=cfg.cluster_min_samples)
    merged = MargeCluster_TM(cluster=clusters, wave_shape=waveforms, thr_marge=cfg.template_merge_score)
    refined = RescueNoise(
        cluster=merged,
        wave_shape=CutWaveShape(waveforms, area=cfg.cut_area),
        thr_noise=cfg.noise_reassign_score,
    )

    spike_times_ms = spike_idx / (cfg.fs / MS)
    isi = np.c_[np.arange(1, spike_times_ms.size + 1), spike_times_ms, np.diff(np.r_[0, spike_times_ms]), refined]

    return {
        "filtered": filtered,
        "spike_idx": spike_idx,
        "waveforms": waveforms,
        "waveforms_roi": waveforms_roi,
        "features": features,
        "variance": variance,
        "labels": refined,
        "spike_times_ms": spike_times_ms,
        "isi": isi,
    }


def visualize_and_save(meta: ChannelMeta, result: dict) -> None:
    """
    Receive process_channel output, produce plots, and persist results in HDF5.
    """
    filtered = result["filtered"]
    spike_idx = result["spike_idx"]
    waveforms = result["waveforms"]
    waveforms_roi = result["waveforms_roi"]
    features = result["features"]
    labels = result["labels"]
    spike_times_ms = result["spike_times_ms"]
    variance = result["variance"]

    # 検出波形
    plt.figure(figsize=(10, 3))
    t = np.arange(filtered.size) / meta.sampling_rate
    plt.plot(t, filtered, lw=0.5, color="steelblue")
    if spike_idx.size:
        plt.plot(spike_idx / meta.sampling_rate, filtered[spike_idx], "r.", ms=3)
    plt.xlabel("Time [s]")
    plt.ylabel("Filtered (uV)")
    plt.tight_layout()
    plt.savefig(meta.figure_dirs["spike_detect"] / f"{meta.data_name}_{meta.channel_label}_spike_detect.png")
    plt.close()

    # PCA 図
    if features.size:
        plt.figure(figsize=(5, 4))
        plt.scatter(features[:, 0], features[:, 1], s=5, c="gray", alpha=0.5)
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title(f"PCA scatter {meta.data_name}_{meta.channel_label}")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["pca"] / f"{meta.data_name}_{meta.channel_label}_pca_raw.png")
        plt.close()

        plt.figure(figsize=(5, 4))
        for clu, color in zip(np.unique(labels), itertools.cycle(COLOR)):
            mask = labels == clu
            plt.scatter(features[mask, 0], features[mask, 1], s=8, alpha=0.7, label=f"clu {clu}", c=color)
        plt.legend(fontsize=8)
        plt.xlabel("PC1")
        plt.ylabel("PC2")
        plt.title(f"PCA clustered {meta.data_name}_{meta.channel_label}")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["pca"] / f"{meta.data_name}_{meta.channel_label}_pca_cluster.png")
        plt.close()

    # クラスタ波形
    for clu, color in zip(np.unique(labels), itertools.cycle(COLOR)):
        mask = labels == clu
        if clu < 0 or not mask.any():
            continue
        plt.figure(figsize=(6, 3))
        plt.plot(waveforms[mask].T, color=color, alpha=0.2, lw=0.5)
        plt.plot(np.median(waveforms[mask], axis=0), color=color, lw=2)
        plt.title(f"{meta.channel_label} cluster {clu} (n={mask.sum()})")
        plt.xlabel("Samples")
        plt.ylabel("uV")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["sorting_cluster"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_waveforms.png")
        plt.close()

    # ACR とパワースペクトラム
    x_axis = np.arange(-1000, 1001)
    for clu in np.unique(labels):
        mask = labels == clu
        if clu < 0 or not mask.any():
            continue
        acr = CalcACR(spike_times_ms[mask])

        plt.figure(figsize=(5, 3))
        plt.plot(x_axis, acr, color="black", lw=1)
        plt.xlim(-200, 200)
        plt.xlabel("Time lag [ms]")
        plt.ylabel("Autocorrelation")
        plt.title(f"{meta.channel_label} cluster {clu} autocorrelogram")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["auto_correlo"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_acr.png")
        plt.close()

        freqP = CalcPOW(acr, ex_file_path="")
        plt.figure(figsize=(5, 3))
        plt.plot(freqP[:, 0], freqP[:, 1], color="black", lw=1)
        plt.xlim(0, 80)
        plt.xlabel("Frequency [Hz]")
        plt.ylabel("Power/frequency")
        plt.title(f"{meta.channel_label} cluster {clu} power spectrum")
        plt.tight_layout()
        plt.savefig(meta.figure_dirs["auto_correlo"] / f"{meta.data_name}_{meta.channel_label}_cluster{clu}_power.png")
        plt.close()

    # HDF5 保存
    datasets = {
        "labels": labels,
        "spike_times_ms": spike_times_ms,
        "waveforms": waveforms,
        "waveforms_roi": waveforms_roi,
        "variance": variance,
    }
    with h5py.File(meta.h5_path, "a") as h5:
        group = h5.require_group(meta.channel_label)
        for name, data in datasets.items():
            if name in group:
                del group[name]
            if data.size:
                group.create_dataset(name, data=data, compression="gzip", compression_opts=4)
            else:
                group.create_dataset(name, shape=data.shape, dtype=data.dtype)


In [9]:
from dataclasses import dataclass, field
import numpy as np

@dataclass
class SortingConfig:
    fs: int = 20_000
    band_bottom: int = 300
    band_top: int = 3_000
    band_numtaps: int = 255
    spike_threshold_sd: float = 4.0
    spike_order: int = 15
    spike_polarity: int = -1
    window_before_ms: float = 1.0
    window_after_ms: float = 2.0
    cut_area: int = 13
    pca_components: int = 2
    cluster_min_size: int = 2000
    cluster_min_samples: int = 250
    template_merge_score: int = 115
    noise_reassign_score: int = 72
    ch_array: np.ndarray = field(
        default_factory=lambda: np.array([
            'trigger', 'ch21', 'ch31', 'ch41', 'ch51', 'ch61', 'ch71', 'ch12',
            'ch22', 'ch32', 'ch42', 'ch52', 'ch62', 'ch72', 'ch82', 'ch13', 'ch23',
            'ch33', 'ch43', 'ch53', 'ch63', 'ch73', 'ch83', 'ch14', 'ch24', 'ch34',
            'ch44', 'ch54', 'ch64', 'ch74', 'ch84', 'ch15', 'ch25', 'ch35', 'ch45',
            'ch55', 'ch65', 'ch75', 'ch85', 'ch16', 'ch26', 'ch36', 'ch46', 'ch56',
            'ch66', 'ch76', 'ch86', 'ch17', 'ch27', 'ch37', 'ch47', 'ch57', 'ch67',
            'ch77', 'ch87', 'ch28', 'ch38', 'ch48', 'ch58', 'ch68', 'ch78'
        ])
    )



In [10]:
# raw_path = Path(raw_path)              # 文字列なら Path 化
# data_name = raw_path.stem             # 例: '01_Flash_0001'
# base_dir = Path(output_root)
# dirs = setup_output_dirs(base_dir)    # 例として spike_detect / pca / sorting_cluster / auto_correlo / npy を返す関数
# figure_dirs = {
#     "spike_detect": dirs["spike_detect"],
#     "pca": dirs["pca"],
#     "sorting_cluster": dirs["sorting_cluster"],
#     "auto_correlo": dirs["auto_correlo"],
# }

In [11]:
# load_raw_matrix は後続セルでチャネル設定と合わせて実行します


In [12]:
base_dir = Path("./temp_results")  # 出力先
base_dir.mkdir(parents=True, exist_ok=True)

figure_dirs = {
    "spike_detect": base_dir / "spike_detect",
    "pca": base_dir / "pca",
    "sorting_cluster": base_dir / "sorting_cluster",
    "auto_correlo": base_dir / "auto_correlo",
}
for path in figure_dirs.values():
    path.mkdir(parents=True, exist_ok=True)

h5_path = base_dir / f"{data_name}_spike_sort.h5"
if h5_path.exists():
    h5_path.unlink()

log_path = base_dir / f"{data_name}_spike_sort.log"
with log_path.open("w", encoding="utf-8") as fp:
    fp.write(f"Spike sorting log for {data_name} (created {datetime.now().isoformat()})\n\n")


In [13]:
channel_array = np.array(channels, dtype=str)
cfg = SortingConfig(ch_array=channel_array)
raw_matrix = load_raw_matrix(raw_path, channel_num)

for ch_label in channel_array[0:1]:
    wave = extract_channel(raw_matrix, ch_label, ch_array=channel_array)
    result = process_channel(wave, cfg)
    meta = ChannelMeta(data_name, ch_label, figure_dirs, h5_path, log_path, cfg.fs)
    visualize_and_save(meta, result)


TypeError: __init__() takes 6 positional arguments but 7 were given

In [67]:
from dataclasses import fields

print([f.name for f in fields(ChannelMeta)])



['data_name', 'channel_label', 'figure_dirs', 'h5_path', 'sampling_rate']


TypeError: __init__() takes 6 positional arguments but 7 were given

TypeError: __init__() takes 6 positional arguments but 7 were given

In [None]:
from pathlib import Path
import os
print(os.getcwd())  # まずはカレントディレクトリ確認

base = Path(r"C:/Users/Imaris/Desktop/watanabe/250801/リポジトリ/spike_sort_test_20250919_155648/npy")
for path in base.glob("*.npy"):
    data = np.load(path, allow_pickle=True)
    print(path.name, data.shape, data[:10] if data.ndim == 1 else data[:10, :5])

c:\Users\Imaris\Desktop\watanabe\250801\リポジトリ
test_ch000_cluster0_spikes.npy (10819,) [2069.65 9101.35 9116.8  9135.45 9138.8  9148.55 9206.75 9233.8  9279.6
 9305.4 ]
test_ch000_cluster1_spikes.npy (4928,) [ 963.05 2614.7  3001.25 3925.8  4004.15 4131.2  4537.05 4705.6  5235.7
 5284.9 ]
test_ch000_cluster2_spikes.npy (45117,) [ 60.25 200.25 268.15 385.9  439.   512.1  592.85 796.45 798.75 812.7 ]
test_ch000_cluster_labels.npy (151881,) [-1 -1  2 -1 -1 -1  2 -1  2 -1]
test_ch000_spike_times_ms.npy (151881,) [ 15.85  58.35  60.25  87.85 134.75 145.15 200.25 258.8  268.15 296.25]
test_ch000_variance.npy (2,) [0.55240433 0.09035588]
