# Libraries

In [None]:
import os
import math
import numpy as np
from scipy import signal
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Global matplotlib text defaults applied to every figure created below.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 12,
})

# Font sizes passed explicitly to axis labels and subplot titles.
axisFontSize, titleFontSize = 14, 20

# Parameters

## Data Path

In [None]:
# Directory holding the captured traces; files are read as trace_<index>.npy.
tracePath = 'C:/Users/01sun/source/repos/raspberrypi/trace_20250124'
# Number of traces walked by the chunked alignment pass (indices 0 .. traceNum - 1).
traceNum = 1_000_000

## Spectrogram

In [None]:
# Sampling rate in Hz; must match the oscilloscope acquisition setting.
samplingFrequency = 5_000_000_000
# Number of samples in each spectrogram window.
windowLength = 1_000
# Number of points shared by neighbouring segments.
# None lets scipy fall back to its default of nperseg // 8.
noverlap = None

## FIR (Finite Impulse Response) filter

In [None]:
# NOTE: despite its name this value is passed to firwin as `numtaps`, i.e. the
# number of coefficients (filter order + 1). numtaps must be odd if a passband
# includes the Nyquist frequency.
filterOrder = 99

# Approximate width of the transition region, in the same units as fs.
# When it is not None a Kaiser design is used and the `window` argument is ignored.
filterWidth = None

# Band edges of the pass band, in the same units as fs. The values must be
# positive, monotonically increasing, and strictly between 0 and fs/2.
# Earlier candidates kept for reference:
#   int(895e6), int(905e6), int(995e6), int(1005e6)   /   [int(15e6), int(25e6)]
bandPassFreq = [10_000_000, 100_000_000]

## Alignment

In [None]:
# Index of the trace whose filtered window serves as the alignment template.
referenceTraceIndex = 350_404
# [start, stop) sample range of the template inside the reference trace.
# Earlier candidates: [9500, 44500], [39000, 44900], [42900, 44900]
referenceTraceXrange = [15_000, 40_000]
# Traces whose best correlation with the template is below this value are dropped.
correlationCriterion = 0.65
# [start, stop) sample range kept from each trace when it is loaded for alignment.
traceCutRange = [13_000, 42_000]

# MAwindowSize and jump are only recorded in parameters.txt in the visible code.
MAwindowSize = 500
jump = 1
# Margin (in samples) skipped at both ends of the search and kept around each saved window.
bound = 500

## Result

In [None]:
# All outputs of this notebook go into a dated sub-folder of the trace directory.
resultFolderName = f'{tracePath}/alignment_and_drop_GPU_20250203'
os.makedirs(resultFolderName, exist_ok=True)
print(resultFolderName)

# Record the settings used for this run, one "name: value" line per parameter,
# so the results can be traced back to their configuration.
recordedParameters = [
    ('tracePath', tracePath),
    ('filterOrder', filterOrder),
    ('filterWidth', filterWidth),
    ('bandPassFreq', bandPassFreq),
    ('referenceTraceIndex', referenceTraceIndex),
    ('referenceTraceXrange', referenceTraceXrange),
    ('correlationCriterion', correlationCriterion),
    ('traceCutRange', traceCutRange),
    ('MAwindowSize', MAwindowSize),
    ('jump', jump),
    ('bound', bound),
]
with open(f'{resultFolderName}/parameters.txt', 'w') as fp:
    for parameterName, parameterValue in recordedParameters:
        fp.write(f'{parameterName}: {parameterValue}\n')

In [None]:
# Load the first capture only to report how many samples a trace contains.
trace = np.load(f'{tracePath}/trace_0.npy')
print(f'Number of points: {len(trace)}')

# Select Band Pass Freq

In [None]:
# FIR band-pass coefficients (Hamming window). pass_zero=False keeps DC out of
# the pass band, so the two entries of bandPassFreq act as the band edges.
firCoeff = signal.firwin(
    numtaps=filterOrder,
    cutoff=bandPassFreq,
    width=filterWidth,
    window='hamming',
    pass_zero=False,
    fs=samplingFrequency,
)

In [None]:
# Quick look at the band-pass filtered reference trace.
trace = np.load('{}/trace_{}.npy'.format(tracePath, referenceTraceIndex))
# filtfilt runs the FIR filter forward and backward over the trace.
filteredTrace = signal.filtfilt(b=firCoeff, a=1.0, x=trace)
# NOTE(review): this figure is saved outside resultFolderName - confirm that is intended.
result_path = "C:/Users/01sun"
plt.figure(figsize=(12, 6))
plt.plot(filteredTrace)
# The curve drawn is the filtered trace, so label it as such (the title used to
# read 'Original Trace'); wording matches the other filtered-signal plots.
plt.title('Filtered Signal ({}-{}Hz)'.format(bandPassFreq[0], bandPassFreq[1]))
plt.xlim(0, trace.shape[0])
plt.ylabel('Voltage (V)', fontsize=axisFontSize)
plt.tight_layout()
plt.savefig('{}/trace.png'.format(result_path))
plt.show()


In [None]:
# Half-open [first, last) range of trace indices rendered by the next cell.
trace_index_range = [350_400, 350_410]

In [None]:
# For every trace in trace_index_range, save two figures:
#   spectrogram_<index>.png - raw signal plus five spectrogram views (time axis in seconds)
#   bandPass_<index>.png    - raw signal next to its band-pass filtered version
for traceIndex in tqdm(range(trace_index_range[0], trace_index_range[1])):
    print(traceIndex)
    trace = np.load(f'{tracePath}/trace_{traceIndex}.npy')
    hammingWindow = signal.get_window(window='hamming', Nx=windowLength)
    f, t, Sxx = signal.spectrogram(trace, fs=samplingFrequency, noverlap=noverlap, window=hammingWindow)

    numPoints = trace.shape[0]
    duration = numPoints / samplingFrequency

    # One entry per spectrogram panel:
    # (title, first frequency row, one-past-last frequency row, y-limits or None).
    spectrogramPanels = [
        ('Spectrogram', 0, len(f), None),
        ('Spectrogram near clock freq',
         np.count_nonzero(f <= 550e6), np.count_nonzero(f <= 650e6), (550e6, 650e6)),
        ('Spectrogram low freq 1', 0, np.count_nonzero(f <= 400e6), (0, 400e6)),
        ('Spectrogram low freq 2', 0, np.count_nonzero(f <= 100e6), (0, 100e6)),
        ('Spectrogram band pass',
         np.count_nonzero(f < bandPassFreq[0]), np.count_nonzero(f <= bandPassFreq[1]),
         (bandPassFreq[0], bandPassFreq[1])),
    ]

    plt.figure(figsize=(15, 13))

    # Row 1: the raw signal against time in seconds.
    plt.subplot(6, 1, 1)
    plt.title('Original Signal', fontsize=titleFontSize)
    plt.plot(np.linspace(0, duration, numPoints), trace, linewidth=0.2)
    plt.xlim(0, duration)
    plt.ylabel('Voltage (V)', fontsize=axisFontSize)

    # Rows 2-6: the spectrogram restricted to each frequency band.
    for panelRow, (panelTitle, rowStart, rowStop, yLimits) in enumerate(spectrogramPanels, start=2):
        plt.subplot(6, 1, panelRow)
        plt.title(panelTitle, fontsize=titleFontSize)
        plt.pcolormesh(t, f[rowStart:rowStop], Sxx[rowStart:rowStop, :])
        plt.xlim(0, duration)
        if yLimits is not None:
            plt.ylim(yLimits[0], yLimits[1])
        plt.ylabel('Frequency (Hz)', fontsize=axisFontSize)

    plt.xlabel('Time (sec)', fontsize=axisFontSize)
    plt.tight_layout()
    plt.savefig(f'{resultFolderName}/spectrogram_{traceIndex}.png', dpi=300, bbox_inches='tight')
    plt.show()

    #########################################################################################################

    # Second figure: raw signal above its band-pass filtered counterpart.
    filteredTrace = signal.filtfilt(b=firCoeff, a=1.0, x=trace)
    signalPanels = [
        ('Original Signal', trace),
        (f'Filtered Signal ({bandPassFreq[0]}-{bandPassFreq[1]}Hz)', filteredTrace),
    ]

    plt.figure(figsize=(15, 7))
    for panelRow, (panelTitle, samples) in enumerate(signalPanels, start=1):
        plt.subplot(2, 1, panelRow)
        plt.title(panelTitle, fontsize=titleFontSize)
        plt.plot(samples, linewidth=0.5)
        plt.xlim(0, numPoints - 1)
        plt.ylabel('Voltage (V)', fontsize=axisFontSize)

    plt.xlabel('Time (point)', fontsize=axisFontSize)
    plt.tight_layout()
    plt.savefig(f'{resultFolderName}/bandPass_{traceIndex}.png', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Detailed view of the reference trace with the x axis in sample (point) indices,
# so the reference window can be read off directly.
# The trace index follows referenceTraceIndex instead of repeating its value as a
# literal, because the red markers drawn below come from referenceTraceXrange.
for traceIndex in tqdm([referenceTraceIndex]):
    print(traceIndex)
    trace = np.load('{}/trace_{}.npy'.format(tracePath, traceIndex))
    f, t, Sxx = signal.spectrogram(trace, fs=samplingFrequency, noverlap=noverlap, window=signal.get_window(window='hamming', Nx=windowLength))

    # Convert the spectrogram time axis from seconds to point indices.
    t_points = (t * samplingFrequency).astype(int)
    numPoints = trace.shape[0]

    # One entry per spectrogram panel:
    # (title, first frequency row, one-past-last frequency row, y-limits or None).
    spectrogramPanels = [
        ('Spectrogram', 0, len(f), None),
        ('Spectrogram near clock freq',
         np.count_nonzero(f <= 550e6), np.count_nonzero(f <= 650e6), (550e6, 650e6)),
        ('Spectrogram low freq 1', 0, np.count_nonzero(f <= 400e6), (0, 400e6)),
        ('Spectrogram low freq 2', 0, np.count_nonzero(f <= 100e6), (0, 100e6)),
        ('Spectrogram band pass',
         np.count_nonzero(f < bandPassFreq[0]), np.count_nonzero(f <= bandPassFreq[1]),
         (bandPassFreq[0], bandPassFreq[1])),
    ]

    plt.figure(figsize=(15, 25))

    # Original signal
    plt.subplot(6, 1, 1)
    plt.title('Original Signal', fontsize=titleFontSize)
    plt.plot(range(numPoints), trace, linewidth=0.2)
    plt.xlim(0, numPoints)
    plt.ylabel('Voltage (V)', fontsize=axisFontSize)
    plt.xlabel('Point Index', fontsize=axisFontSize)

    # Spectrogram panels, all sharing the point-index x axis.
    for panelRow, (panelTitle, rowStart, rowStop, yLimits) in enumerate(spectrogramPanels, start=2):
        plt.subplot(6, 1, panelRow)
        plt.title(panelTitle, fontsize=titleFontSize)
        plt.pcolormesh(t_points, f[rowStart:rowStop], Sxx[rowStart:rowStop, :])
        plt.xlim(0, t_points[-1])
        if yLimits is not None:
            plt.ylim(yLimits[0], yLimits[1])
        plt.ylabel('Frequency (Hz)', fontsize=axisFontSize)
        plt.xlabel('Point Index', fontsize=axisFontSize)

    # Red vertical lines at both ends of the reference window. They are drawn on
    # the current axes, i.e. the last (band-pass) panel.
    for markPoint in referenceTraceXrange[:2]:
        plt.axvline(x=markPoint, color='red', linestyle='-', linewidth=1.5, label=f'Mark: {markPoint}')
    plt.legend(loc='upper right')
    # A single layout pass after the legend exists is enough.
    plt.tight_layout()
    plt.savefig('{}/spectrogram_{}.png'.format(resultFolderName, traceIndex), dpi=300, bbox_inches='tight')
    # Show the figure here, as the sibling spectrogram cell does, instead of
    # leaving it pending until the end of the cell.
    plt.show()

    #########################################################################################################

    # Band-pass filtered signal next to the raw one.
    filteredTrace = signal.filtfilt(b=firCoeff, a=1.0, x=trace)

    plt.figure(figsize=(15, 7))
    plt.subplot(2, 1, 1)
    plt.title('Original Signal', fontsize=titleFontSize)
    plt.plot(range(numPoints), trace, linewidth=0.5)
    plt.xlim(0, numPoints - 1)
    plt.ylabel('Voltage (V)', fontsize=axisFontSize)
    plt.xlabel('Point Index', fontsize=axisFontSize)

    plt.subplot(2, 1, 2)
    plt.title('Filtered Signal ({}-{}Hz)'.format(bandPassFreq[0], bandPassFreq[1]), fontsize=titleFontSize)
    plt.plot(range(filteredTrace.shape[0]), filteredTrace, linewidth=0.5)
    plt.xlim(0, filteredTrace.shape[0] - 1)
    plt.ylabel('Voltage (V)', fontsize=axisFontSize)
    plt.xlabel('Point Index', fontsize=axisFontSize)

    plt.tight_layout()
    plt.savefig('{}/bandPass_{}.png'.format(resultFolderName, traceIndex), dpi=300, bbox_inches='tight')
    plt.show()


# Setting Reference Trace

In [None]:
# Load the reference capture explicitly. Relying on whatever `trace` the previous
# cell happened to leave behind makes the reference depend on cell execution
# order and on the trace range that cell was last run with.
trace = np.load('{}/trace_{}.npy'.format(tracePath, referenceTraceIndex))
filtTrace = signal.filtfilt(b=firCoeff, a=1.0, x=trace)

# The alignment template: the filtered reference trace cut to the chosen window.
refTrace = np.array(filtTrace[referenceTraceXrange[0] : referenceTraceXrange[1]], dtype=np.float32)

plt.figure(figsize=(10, 10))

# Panel 1: raw trace with the reference window marked in red.
plt.subplot(4, 1, 1)
plt.title('Original Signal', fontsize=titleFontSize)
plt.plot(trace, linewidth=0.5)
plt.vlines(referenceTraceXrange[0], np.min(trace), np.max(trace), color='r', alpha=0.7)
plt.vlines(referenceTraceXrange[1], np.min(trace), np.max(trace), color='r', alpha=0.7)
plt.xlim(0, len(trace))
plt.ylabel('Voltage (V)', fontsize=axisFontSize)

# Panel 2: filtered trace with the same markers.
plt.subplot(4, 1, 2)
plt.title('Filtered Signal ({}-{}Hz)'.format(bandPassFreq[0], bandPassFreq[1]), fontsize=titleFontSize)
plt.plot(filtTrace, linewidth=0.5)
plt.vlines(referenceTraceXrange[0], np.min(filtTrace), np.max(filtTrace), color='r', alpha=0.7)
plt.vlines(referenceTraceXrange[1], np.min(filtTrace), np.max(filtTrace), color='r', alpha=0.7)
plt.xlim(0, len(trace))
plt.ylabel('Voltage (V)', fontsize=axisFontSize)

# Panel 3: raw trace zoomed to the reference window.
plt.subplot(4, 1, 3)
plt.title('Original Signal', fontsize=titleFontSize)
plt.plot(trace, linewidth=0.5)
plt.xlim(referenceTraceXrange[0], referenceTraceXrange[1])
plt.ylabel('Voltage (V)', fontsize=axisFontSize)

# Panel 4: the extracted template itself (x axis restarts at 0).
plt.subplot(4, 1, 4)
plt.title('Filtered Signal ({}-{}Hz)'.format(bandPassFreq[0], bandPassFreq[1]), fontsize=titleFontSize)
plt.plot(refTrace, linewidth=0.5)
plt.xlim(0, len(refTrace)-1)
plt.ylabel('Voltage (V)', fontsize=axisFontSize)

plt.xlabel('Time (point)', fontsize=axisFontSize)
plt.tight_layout()
plt.savefig('{}/criterionTrace.png'.format(resultFolderName), dpi=300, bbox_inches='tight')
plt.show()

# Test Criterion

In [None]:
def getCorrs(trace, refTrace, refTraceSum, refTraceSquSum):
    """Slide refTrace over trace and compute the Pearson correlation at every offset.

    Args:
        trace: 1-D array to search in.
        refTrace: 1-D reference window (template).
        refTraceSum: precomputed sum of refTrace.
        refTraceSquSum: precomputed sum of the squared refTrace values.

    Returns:
        (maxCorr, maxCorrIndex, corrs): the largest correlation found, the
        offset where it occurs, and the list of correlations for every offset
        from 0 to len(trace) - len(refTrace) inclusive. maxCorr and
        maxCorrIndex stay 0 when no offset has a correlation above 0.
        Offsets whose window (or the reference) has no variance yield NaN.
    """
    squTrace = np.square(trace)
    traceLen = refTrace.shape[0]

    # Loop-invariant reference term: n * sum(y^2) - sum(y)^2.
    refVar = traceLen * refTraceSquSum - refTraceSum ** 2

    corrs = []
    maxCorr      = 0
    maxCorrIndex = 0
    # "+ 1" so the last valid offset (window flush with the end of the trace)
    # is evaluated too; without it a trace as long as the reference is never compared.
    for pointIndex in range(trace.shape[0] - traceLen + 1):
        window      = trace[pointIndex:pointIndex+traceLen]
        traceSum    = np.sum(window)
        traceSquSum = np.sum(squTrace[pointIndex:pointIndex+traceLen])
        windowVar   = traceLen * traceSquSum - traceSum ** 2

        if windowVar > 0 and refVar > 0:
            corr = (traceLen * np.sum(window * refTrace) - traceSum * refTraceSum) / (math.sqrt(windowVar) * math.sqrt(refVar))
        else:
            # Constant window (or rounding pushed the variance below zero):
            # the correlation is undefined, so report NaN instead of dividing
            # by zero or taking the square root of a negative number.
            corr = float('nan')

        corrs.append(corr)
        # NaN never compares greater, so undefined offsets cannot win.
        if maxCorr < corr:
            maxCorr = corr
            maxCorrIndex = pointIndex

    return maxCorr, maxCorrIndex, corrs

In [None]:
# Try the correlation criterion on a sample of traces: for each one, find the
# best match of the reference window, report whether it passes, and save a
# diagnostic figure.
refTrace = np.load('{}/trace_{}.npy'.format(tracePath, referenceTraceIndex))
filteredRefTrace = np.array(signal.filtfilt(b=firCoeff, a=1.0, x=refTrace), dtype=np.float32)
refTraceEnvelope = filteredRefTrace[referenceTraceXrange[0] : referenceTraceXrange[1]]
refLen = refTraceEnvelope.shape[0]

cntOverCriterion = 0
refTraceSum     = np.sum(refTraceEnvelope)
refTraceSquSum  = np.sum(np.square(refTraceEnvelope))
for traceIndex in tqdm(range(350000, 351000)):
    trace = np.load('{}/trace_{}.npy'.format(tracePath, traceIndex))
    filteredTrace = np.array(signal.filtfilt(b=firCoeff, a=1.0, x=trace), dtype=np.float32)
    # The filtered trace is correlated directly; a Hilbert envelope was an
    # earlier alternative: np.abs(signal.hilbert(filteredTrace))
    traceEnvelope = filteredTrace

    maxCorr, maxCorrIndex, corrs = getCorrs(traceEnvelope, refTraceEnvelope, refTraceSum, refTraceSquSum)
    print('Trace {:2d}\tCorrelation: {:.3f}\t(point: {:7d})\tOver criterion: {}'.format(traceIndex, maxCorr, maxCorrIndex, maxCorr >= correlationCriterion))
    cntOverCriterion += maxCorr >= correlationCriterion

    # Sample range covered by the best match.
    matchStart = maxCorrIndex
    matchStop  = maxCorrIndex + refLen
    matchAxis  = np.arange(matchStart, matchStop)

    plt.figure(figsize=(15, 10))

    # Panel 1: raw trace with the matched window marked.
    plt.subplot(4, 1, 1)
    plt.title('Original Signal', fontsize=titleFontSize)
    plt.plot(trace, linewidth=0.5)
    plt.xlim(0, trace.shape[0]-1)
    plt.ylabel('Voltage (V)', fontsize=axisFontSize)
    plt.vlines(matchStart, np.min(trace), np.max(trace), color='r', alpha=0.5)
    plt.vlines(matchStop,  np.min(trace), np.max(trace), color='r', alpha=0.5)

    # Panel 2: filtered trace with the reference window overlaid at the match.
    plt.subplot(4, 1, 2)
    plt.title('Filtered Signal ({}-{}Hz)'.format(bandPassFreq[0], bandPassFreq[1]), fontsize=titleFontSize)
    plt.plot(filteredTrace, linewidth=0.5, alpha=0.7)
    plt.plot(matchAxis, refTraceEnvelope, linewidth=0.5, alpha=0.7)
    plt.xlim(0, filteredTrace.shape[0]-1)
    plt.ylabel('Voltage (V)', fontsize=axisFontSize)
    plt.vlines(matchStart, np.min(filteredTrace), np.max(filteredTrace), color='r', alpha=0.5)
    plt.vlines(matchStop,  np.min(filteredTrace), np.max(filteredTrace), color='r', alpha=0.5)

    # Panel 3: the same overlay zoomed to the matched window.
    plt.subplot(4, 1, 3)
    plt.title('Overlap zoom', fontsize=titleFontSize)
    plt.plot(traceEnvelope, linewidth=0.5, alpha=0.7)
    plt.plot(matchAxis, refTraceEnvelope, linewidth=0.5, alpha=0.7)
    plt.xlim(matchStart, matchStop-1)
    plt.ylabel('Voltage (V)', fontsize=axisFontSize)
    plt.vlines(matchStart, np.min(traceEnvelope), np.max(traceEnvelope), color='r', alpha=0.5)
    plt.vlines(matchStop,  np.min(traceEnvelope), np.max(traceEnvelope), color='r', alpha=0.5)

    # Panel 4: correlation as a function of the offset.
    plt.subplot(4, 1, 4)
    plt.title('Correlation', fontsize=titleFontSize)
    plt.plot(corrs)
    plt.xlim(0, len(corrs)-1)

    plt.xlabel('Time (point)', fontsize=axisFontSize)
    plt.tight_layout()
    plt.savefig('{}/criterionTest_{}.png'.format(resultFolderName, traceIndex), dpi=300, bbox_inches='tight')
    plt.show()
    # pyplot keeps every figure alive until it is closed; with one figure per
    # trace over 1000 traces, release it explicitly.
    plt.close()
print(cntOverCriterion)

# GPU Alignment

## Library

In [None]:
import gc
import cupy as cp
import cusignal
from concurrent.futures import ThreadPoolExecutor

## GPU 정렬을 위한 함수

In [None]:
# Check GPU memory capacity
def check_gpu_memory():
    """Print the free and total GPU memory and return both, in MB.

    Returns:
        (free_mem, total_mem) as floats in MB, taken from entries [0] and [1]
        of cp.cuda.runtime.memGetInfo().
    """
    bytes_per_mb = 1024 ** 2
    mem_info = cp.cuda.runtime.memGetInfo()
    free_mem, total_mem = mem_info[0] / bytes_per_mb, mem_info[1] / bytes_per_mb
    print(f"GPU Memory - Free: {free_mem:.2f} MB, Total: {total_mem:.2f} MB")
    return free_mem, total_mem

def calculate_chunk_size(free_mem_mb, trace_length, ref_length, float32_size=4, mem_usage_ratio=0.5):
    # Usable memory after reserving for overhead
    usable_mem_mb = free_mem_mb * mem_usage_ratio
    if usable_mem_mb <= 0:
        raise ValueError("Insufficient GPU memory available for processing.")

    usable_mem_bytes = usable_mem_mb * (1024 ** 2)
    
    # Memory required for one trace and filtering
    single_trace_mem = trace_length * float32_size
    ref_trace_mem = ref_length * float32_size
    filtering_overhead_mem = single_trace_mem  # 필터링 중 추가 메모리 요구량

    # Total memory for filtering a batch
    overhead_mem = ref_trace_mem * 3  # Reference trace and its sums
    memory_per_trace = single_trace_mem + filtering_overhead_mem

    # Remaining memory for traces
    available_mem = usable_mem_bytes - overhead_mem
    if available_mem <= 0:
        raise ValueError("Not enough memory for even a single trace.")

    # Calculate maximum chunk size
    chunk_size = int(available_mem / memory_per_trace)
    print(f"Adjusted chunk size: {chunk_size} traces (usable GPU memory: {usable_mem_mb:.2f} MB)")
    return max(chunk_size, 1)


# Optimized GPU-based correlation calculation for multiple traces
def getCorrs_gpu_batch(gpu_traces, gpu_refTrace, gpu_refTraceSum, gpu_refTraceSquSum, traceLen, bound):
    """Find, for every trace of a batch, the offset that best matches the reference.

    Args:
        gpu_traces: 2-D array, one trace per row (unpacked as num_traces x trace_length).
        gpu_refTrace: reference window; reshaped to a single row for broadcasting.
        gpu_refTraceSum: sum of the reference window.
        gpu_refTraceSquSum: sum of the squared reference window.
        traceLen: number of samples in the reference window.
        bound: number of offsets skipped at both ends of the search range.

    Returns:
        (maxCorrs, maxCorrIndices) as NumPy arrays with one entry per trace.
        Entries stay 0 for traces where no offset has a correlation above 0.
    """
    num_traces, trace_length = gpu_traces.shape
    best_corrs = cp.zeros(num_traces, dtype=cp.float32)
    best_indices = cp.zeros(num_traces, dtype=cp.int32)

    # Shape (1, traceLen) so the reference broadcasts against every row.
    ref_row = gpu_refTrace.reshape(1, -1)

    search_stop = trace_length - traceLen - bound
    for pointIndex in tqdm(range(bound, search_stop)):
        # Same window position in all traces at once.
        windows = gpu_traces[:, pointIndex:pointIndex + traceLen]
        window_sums = cp.sum(windows, axis=1)
        window_squ_sums = cp.sum(cp.square(windows), axis=1)

        # Pearson correlation written with raw sums.
        numerator = traceLen * cp.sum(windows * ref_row, axis=1) - window_sums * gpu_refTraceSum
        window_term = cp.sqrt(traceLen * window_squ_sums - window_sums ** 2)
        ref_term = cp.sqrt(traceLen * gpu_refTraceSquSum - gpu_refTraceSum ** 2)
        corrs = numerator / (window_term * ref_term)

        # Keep the running maximum and the offset where it occurred.
        improved = corrs > best_corrs
        best_corrs = cp.where(improved, corrs, best_corrs)
        best_indices = cp.where(improved, pointIndex, best_indices)

    return cp.asnumpy(best_corrs), cp.asnumpy(best_indices)

# File I/O Function (multi-threaded; called through ThreadPoolExecutor)

In [None]:
def load_trace(filepath, cut_range):
    """Load one .npy trace and return the samples in [cut_range[0], cut_range[1]).

    Best-effort: any failure is printed and reported as None so that one
    unreadable file does not abort a whole batch.

    Args:
        filepath: path of the .npy file.
        cut_range: indexable holding the start and stop sample indices.

    Returns:
        The sliced array, or None if loading or slicing failed.
    """
    try:
        samples = np.load(filepath)
        window = slice(cut_range[0], cut_range[1])
        return samples[window]
    except Exception as e:
        print(f"Error loading file {filepath}: {e}")
        return None

## Alignment

In [None]:
# Same band-pass design as the CPU filter, built with cuSignal and stored as a
# CuPy array (cuSignal does its GPU computation on CuPy arrays).
gpu_firCoeff = cp.asarray(
    cusignal.firwin(
        numtaps=filterOrder,       # number of filter taps
        cutoff=bandPassFreq,       # band-pass edge frequencies
        width=filterWidth,         # transition width
        window='hamming',          # window function
        fs=samplingFrequency,      # sampling frequency
        pass_zero=False,           # keep DC out of the pass band (band-pass)
    )
)

In [None]:
# Reference template: filter the whole reference capture on the CPU, then keep
# only the reference window.
refTrace = np.load(f'{tracePath}/trace_{referenceTraceIndex}.npy')
refStart, refStop = referenceTraceXrange[0], referenceTraceXrange[1]
filteredRefTrace = np.array(signal.filtfilt(b=firCoeff, a=1.0, x=refTrace), dtype=np.float32)[refStart:refStop]

# Sums reused by every correlation evaluation.
refTraceSum = np.sum(filteredRefTrace)
refTraceSquSum = np.sum(np.square(filteredRefTrace))

# Upload the template and its sums to the GPU once.
gpu_refTrace, gpu_refTraceSum, gpu_refTraceSquSum = (
    cp.asarray(filteredRefTrace),
    cp.asarray(refTraceSum),
    cp.asarray(refTraceSquSum),
)

# Check GPU memory before processing
free_mem, total_mem = check_gpu_memory()

# Determine the chunk size from the free GPU memory, the cut trace length and
# the template length.
trace_length = traceCutRange[1] - traceCutRange[0]
ref_length = filteredRefTrace.shape[0]
chunk_size = calculate_chunk_size(free_mem, trace_length, ref_length)

# Process traces in chunks: load -> filter on the GPU -> correlate against the
# reference -> save the aligned window of every trace that passes the criterion.
chunk_indices = list(range(0, traceNum, chunk_size))
memory_pool = cp.get_default_memory_pool()
refLen = filteredRefTrace.shape[0]

for start_index in tqdm(chunk_indices, desc="Processing chunks"):
    end_index = min(start_index + chunk_size, traceNum)
    requested_indices = range(start_index, end_index)

    # Load all traces for the current chunk using parallel I/O
    with ThreadPoolExecutor(max_workers=16) as executor:
        loaded = list(executor.map(
            lambda idx: load_trace(f"{tracePath}/trace_{idx}.npy", traceCutRange),
            requested_indices
        ))

    # Drop failed loads, but remember which file index each surviving row came
    # from. Deriving the index from the row position would shift every trace
    # after a failed load onto the wrong file name.
    kept = [(idx, samples) for idx, samples in zip(requested_indices, loaded) if samples is not None]
    if not kept:
        # Nothing could be loaded for this chunk; there is nothing to align.
        continue
    kept_indices = [idx for idx, _ in kept]
    traces = np.array([samples for _, samples in kept], dtype=np.float32)

    # Upload traces to GPU
    gpu_traces = cp.asarray(traces)  # (number of loaded traces, trace_length)

    # Apply the FIR filter to all traces at once along the sample axis
    filtered_gpu_traces = cusignal.filtfilt(gpu_firCoeff, 1.0, gpu_traces, axis=1)
    gpu_traces = None  # drop the reference to the unfiltered batch
    memory_pool.free_all_blocks()  # return the freed blocks from the pool

    # Make sure the filtered batch is a CuPy array
    filtered_gpu_traces = cp.asarray(filtered_gpu_traces)

    # Host copy of the filtered traces, used when saving
    filtered_traces = cp.asnumpy(filtered_gpu_traces)
    print("processing...")

    # Perform correlation for the current chunk
    maxCorrs, maxCorrIndices = getCorrs_gpu_batch(
        filtered_gpu_traces, gpu_refTrace, gpu_refTraceSum, gpu_refTraceSquSum,
        refLen, bound
    )

    # Save aligned traces (filtered and raw) under their original trace index
    for row, (traceIndex, maxCorr, maxCorrIndex) in enumerate(zip(kept_indices, maxCorrs, maxCorrIndices)):
        if maxCorr >= correlationCriterion:
            trace = filtered_traces[row]
            cut_start = maxCorrIndex - bound
            cut_stop = maxCorrIndex + refLen + bound
            # Only save when the window plus its margins lies inside the trace
            if 0 <= cut_start and cut_stop < len(trace):
                np.save('{}/alignTrace{}.npy'.format(resultFolderName, traceIndex), trace[cut_start:cut_stop])
                np.save('{}/trace{}.npy'.format(resultFolderName, traceIndex), traces[row][cut_start:cut_stop])

    # Free GPU memory
    del filtered_gpu_traces
    memory_pool.free_all_blocks()

    # Free CPU memory
    del traces, filtered_traces
    gc.collect()