In [None]:
import os
import glob
import json
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display, clear_output
import h5py
import cv2

In [None]:
class ClickablePeakAnnotator:
    """Interactive notebook tool for labelling peaks in row-wise CSV signals.

    Every ``*.csv`` file in ``data_folder`` is read without a header; each row
    is treated as one signal whose samples start at column ``signal_start_col``
    (the preceding columns are presumably metadata -- confirm against the data
    format). The user clicks a peak on a Plotly figure, after which a window
    around the click and a random window away from it are saved as spectrogram
    PNGs. Completed rows are recorded in a JSON log so a session can resume.

    Args:
        data_folder: Directory that contains the input CSV files.
        output_folder: Directory that receives the spectrogram sub-folders and
            the ``annotation_log.json`` progress file.
        window_size: Number of samples in each saved window.
        signal_start_col: Index of the first column holding signal samples.
        peak_prominence: ``prominence`` passed to ``scipy.signal.find_peaks``
            for the suggested-peak markers.
        peak_distance: ``distance`` passed to ``scipy.signal.find_peaks``.
    """

    def __init__(self, data_folder, output_folder, window_size=300,
                 signal_start_col=16, peak_prominence=1000, peak_distance=50):
        # --- Paths and Config ---
        self.data_folder = data_folder
        self.output_folder = output_folder
        self.peak_folder = os.path.join(output_folder, 'peak_spectrograms')
        self.non_peak_folder = os.path.join(output_folder, 'non_peak_spectrograms')
        self.window_size = window_size
        self.signal_start_col = signal_start_col
        self.peak_prominence = peak_prominence
        self.peak_distance = peak_distance
        self.progress_file = os.path.join(output_folder, 'annotation_log.json')

        # --- Create Directories ---
        os.makedirs(self.peak_folder, exist_ok=True)
        os.makedirs(self.non_peak_folder, exist_ok=True)

        # --- State Variables ---
        self.files = sorted(glob.glob(os.path.join(data_folder, '*.csv')))
        self.file_basenames = [os.path.basename(f) for f in self.files]
        self.current_file_idx = 0
        self.current_row_idx = 0
        self.df_current = None
        self.current_signal = None
        self.current_peaks = None

        # --- Load & Pre-scan ---
        self.annotated_log = self.load_progress()
        print("Scanning files to get row counts...")
        self.file_row_counts = self.scan_all_files()  # Get total rows for all files
        print("Scan complete. Building UI.")

        # --- UI Components ---
        self.overview_label = widgets.HTML(value=self.build_overview_html())
        self.file_dropdown = widgets.Dropdown(options=self.file_basenames, description='Select File:', style={'description_width': 'initial'})
        self.start_button = widgets.Button(description='Load File', button_style='primary')
        self.skip_button = widgets.Button(description="Skip Signal (No Peak)", button_style='warning')

        self.progress_bar = widgets.IntProgress(value=0, min=0, max=100, description='File Progress:', style={'description_width': 'initial'})
        self.progress_label = widgets.Label(value='(0/0)')

        self.header_output = widgets.Output()
        self.plot_output = widgets.Output()

        # --- UI Layout ---
        self.file_selector_ui = widgets.HBox([self.file_dropdown, self.start_button])
        self.progress_ui = widgets.HBox([self.progress_bar, self.progress_label])

        self.main_annotator_ui = widgets.VBox([
            self.header_output,
            self.plot_output,
            self.skip_button
        ])

        self.full_ui = widgets.VBox([
            widgets.HTML("<h3>Overall Project Progress</h3>"),
            self.overview_label,
            widgets.HTML("<hr><h3>Start Annotation Session</h3>"),
            self.file_selector_ui,
            self.progress_ui,
            self.main_annotator_ui
        ])

        # Initialize UI state
        self.progress_ui.layout.visibility = 'hidden'
        self.main_annotator_ui.layout.visibility = 'hidden'
        # With no CSV files there is nothing to load, so keep the button inert.
        self.start_button.disabled = not self.files

        # --- Setup Callbacks ---
        self.start_button.on_click(self.start_annotation_session)
        self.skip_button.on_click(self.skip_current)

        # --- Display the UI ---
        display(self.full_ui)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _current_fname(self):
        """Returns the basename of the currently selected file."""
        return os.path.basename(self.files[self.current_file_idx])

    def _mark_row_done(self, fname, row_idx):
        """Records ``row_idx`` as completed for ``fname`` exactly once.

        Returns:
            True if the row was newly added, False if it was already logged.
        """
        rows = self.annotated_log.setdefault(fname, [])
        row_idx = int(row_idx)
        if row_idx in rows:
            return False
        rows.append(row_idx)
        return True

    # ------------------------------------------------------------------
    # Progress bookkeeping
    # ------------------------------------------------------------------
    def scan_all_files(self):
        """Pre-scans all CSVs to get their total row counts for the progress overview.

        Returns:
            dict mapping file basename -> row count (0 when the file could not
            be read; a warning is printed in that case).
        """
        counts = {}
        for fname, file_path in zip(self.file_basenames, self.files):
            try:
                # Reading only the first column keeps the scan cheap.
                df_temp = pd.read_csv(file_path, header=None, usecols=[0])
                counts[fname] = df_temp.shape[0]
            except Exception as e:
                print(f"Warning: Could not read {fname} for row count. Error: {e}")
                counts[fname] = 0  # Mark as 0 if unreadable
        return counts

    def build_overview_html(self):
        """Creates the HTML for the overall progress report.

        Returns:
            An HTML string: a bold overall total followed by one list item per
            file showing ``(completed / total)`` and a status suffix.
        """
        items = []
        total_rows_done = 0
        total_rows_all_files = 0

        for fname in self.file_basenames:
            total = self.file_row_counts.get(fname, 0)
            # Count distinct rows so duplicate entries in an older log file
            # cannot push a file past 100 %.
            completed_count = len(set(self.annotated_log.get(fname, [])))

            total_rows_done += completed_count
            total_rows_all_files += total

            status = ""
            if completed_count >= total and total > 0:
                status = "<b> - ✅ COMPLETE</b>"
            elif completed_count > 0:
                status = " - In Progress"

            items.append(f"<li><b>{fname}:</b> ({completed_count} / {total}) rows {status}</li>")

        overall_summary = f"<b>Total All Files: ({total_rows_done} / {total_rows_all_files}) rows annotated</b>"
        return overall_summary + "<br>" + "<ul>" + "".join(items) + "</ul>"

    def load_progress(self):
        """Loads the annotation log file if it exists.

        Returns:
            dict in the format ``{"filename": [list_of_completed_row_indices]}``;
            empty when the log is missing, is not valid JSON, or does not hold
            a JSON object (a warning is printed in the latter two cases).
        """
        if not os.path.exists(self.progress_file):
            return {}
        try:
            with open(self.progress_file, 'r') as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            # Tell the user: the next save will overwrite the damaged file.
            print(f"Warning: {self.progress_file} is not valid JSON ({e}). Starting with an empty log.")
            return {}
        if not isinstance(data, dict):
            print(f"Warning: {self.progress_file} does not contain a JSON object. Starting with an empty log.")
            return {}
        return data

    def save_progress(self):
        """Saves the current annotation progress to the log file.

        The log is written to a temporary sibling file and then moved into
        place, so an interrupted write cannot leave a truncated log behind.
        """
        tmp_path = self.progress_file + '.tmp'
        with open(tmp_path, 'w') as f:
            json.dump(self.annotated_log, f, indent=2)
        os.replace(tmp_path, self.progress_file)

    def get_next_unprocessed_row(self):
        """Finds the next unprocessed row *in the currently loaded file*.

        Returns:
            The lowest row index of ``self.df_current`` that is not in the log,
            or -1 when every row has been processed.
        """
        done = set(self.annotated_log.get(self._current_fname(), []))
        for r_idx in range(self.df_current.shape[0]):
            if r_idx not in done:
                return r_idx
        return -1  # This file is complete

    def update_progress_ui(self):
        """Updates the progress bar and label for the *current file*."""
        fname = self._current_fname()
        total_rows = self.file_row_counts.get(fname, 0)
        completed_rows = len(set(self.annotated_log.get(fname, [])))

        self.progress_bar.max = total_rows
        self.progress_bar.value = completed_rows
        self.progress_label.value = f"({completed_rows} / {total_rows})"

    # ------------------------------------------------------------------
    # Session flow
    # ------------------------------------------------------------------
    def start_annotation_session(self, b):
        """Callback for the 'Load File' button.

        Reads the selected CSV, reveals the annotator widgets and shows the
        first unprocessed signal. Load errors are printed in the header area.
        """
        selected_fname = self.file_dropdown.value
        if selected_fname is None:
            # Dropdown has no options (no CSV files were found).
            with self.header_output:
                clear_output()
                print(f"No CSV files found in {self.data_folder}.")
            return

        self.current_file_idx = self.file_basenames.index(selected_fname)
        try:
            self.df_current = pd.read_csv(self.files[self.current_file_idx], header=None)
        except Exception as e:
            with self.header_output:
                clear_output()
                print(f"Error loading {selected_fname}: {e}")
            return

        # Keep the progress total in step with the frame actually loaded
        # (the pre-scan may have failed or the file may have changed since).
        self.file_row_counts[selected_fname] = self.df_current.shape[0]

        # Show the annotator UI
        self.progress_ui.layout.visibility = 'visible'
        self.main_annotator_ui.layout.visibility = 'visible'
        self.file_dropdown.disabled = True
        self.start_button.disabled = True

        self.update_progress_ui()  # Set initial progress
        self.load_next_batch()

    def load_next_batch(self):
        """Loads the next unprocessed signal from the current file.

        When the file has no unprocessed rows left, the annotator widgets are
        hidden, file selection is re-enabled and the overview is refreshed.
        """
        r_idx = self.get_next_unprocessed_row()
        fname = self._current_fname()

        if r_idx == -1:
            # This file is finished
            with self.header_output:
                clear_output()
                print(f"✅ File {fname} is complete! Select a new file above.")
            with self.plot_output:
                clear_output()

            # Hide annotator and re-enable file selection
            self.main_annotator_ui.layout.visibility = 'hidden'
            self.progress_ui.layout.visibility = 'hidden'
            self.file_dropdown.disabled = False
            self.start_button.disabled = False
            self.df_current = None

            # Refresh the overall progress view
            self.overview_label.value = self.build_overview_html()
            return

        self.current_row_idx = r_idx
        self.current_signal = self.df_current.iloc[r_idx, self.signal_start_col:].astype(float)
        self.current_peaks, _ = find_peaks(
            self.current_signal.values,
            prominence=self.peak_prominence,
            distance=self.peak_distance,
        )

        self.render_plot(fname, r_idx)

    def render_plot(self, fname, r_idx):
        """Renders the interactive Plotly graph.

        Args:
            fname: Basename of the file being annotated (shown in the header).
            r_idx: Row index of the signal being shown.
        """
        with self.header_output:
            clear_output(wait=True)
            print(f"File: {fname} | Signal Row: {r_idx} | {len(self.current_peaks)} peaks found.")
            print("➡️ INSTRUCTIONS: Zoom in, then CLICK directly on the peak (on the blue line).")

        with self.plot_output:
            clear_output(wait=True)

            fig = go.FigureWidget()

            # Trace 0: Signal (Clickable)
            fig.add_trace(go.Scatter(
                y=self.current_signal.values, mode='lines', name='Signal',
                line=dict(color='#1f77b4', width=1), hoverinfo='x+y'
            ))

            # Trace 1: Peaks (Visual guide only)
            fig.add_trace(go.Scatter(
                x=self.current_peaks,
                y=self.current_signal.values[self.current_peaks],
                mode='markers', name='Suggested Peaks',
                marker=dict(color='red', size=12, symbol='x-thin', line=dict(width=2)),
                hoverinfo='skip'
            ))

            fig.update_layout(
                xaxis_title="Sample", yaxis_title="Amplitude", height=500,
                margin=dict(l=20, r=20, t=20, b=20),
                hovermode='closest', dragmode='zoom'
            )

            # Remember which signal this figure shows so that a late click on
            # an outdated figure cannot annotate a different row.
            shown_file_idx = self.current_file_idx

            # --- CLICK CALLBACK ---
            def on_click(trace, points, selector):
                if not points.point_inds:
                    return
                is_stale = (
                    self.df_current is None
                    or self.current_file_idx != shown_file_idx
                    or self.current_row_idx != r_idx
                )
                if is_stale:
                    return
                # Trace 0 has no explicit x, so the point index is the sample index.
                real_peak_index = points.point_inds[0]

                with self.header_output:
                    print(f"Processing Peak at sample {real_peak_index}...")
                self.process_selection(real_peak_index)

            # Attach Callback to Trace 0 (The Signal Line)
            fig.data[0].on_click(on_click)
            display(fig)

    def process_selection(self, peak_index):
        """Generates and saves spectrograms and updates progress.

        A window of up to ``window_size`` samples centred on ``peak_index``
        (shifted to stay inside the signal) is saved as a 'peak' spectrogram.
        A randomly placed window that does not contain ``peak_index`` is saved
        as 'non_peak' when one can be found within 50 attempts. The row is
        then logged and the next signal is loaded.

        Args:
            peak_index: Sample index of the clicked peak within the signal.
        """
        if self.current_signal is None:
            return

        samples = self.current_signal.values
        n_samples = len(samples)

        # Clamp on both sides; for a signal shorter than the window the whole
        # signal is used instead of a negative (wrap-around) start index.
        start = max(0, peak_index - self.window_size // 2)
        end = min(n_samples, start + self.window_size)
        start = max(0, end - self.window_size)
        peak_window = samples[start:end]

        non_peak_window = None
        max_start = n_samples - self.window_size
        # A background window only exists if the signal is longer than the
        # window; randint would raise for an empty range otherwise.
        if max_start > 0:
            for _ in range(50):
                # +1 because randint's upper bound is exclusive.
                rand_start = np.random.randint(0, max_start + 1)
                if not (rand_start <= peak_index < rand_start + self.window_size):
                    non_peak_window = samples[rand_start:rand_start + self.window_size]
                    break

        self.save_spectrogram(peak_window, 'peak')
        if non_peak_window is not None:
            self.save_spectrogram(non_peak_window, 'non_peak')

        self._mark_row_done(self._current_fname(), self.current_row_idx)
        self.save_progress()

        self.update_progress_ui()  # Update progress bar
        self.load_next_batch()

    def save_spectrogram(self, window_data, label):
        """Uses Matplotlib to save a grayscale spectrogram PNG.

        The figure is 3x3 inches at 100 dpi, i.e. nominally 300x300 pixels.
        NOTE(review): ``bbox_inches='tight'`` may crop the output, so the exact
        pixel size should be confirmed on real output.

        Args:
            window_data: 1-D array of samples to transform.
            label: 'peak' selects the peak folder; any other value selects the
                non-peak folder.
        """
        now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        fname = f"{label}_f{self.current_file_idx}_r{self.current_row_idx}_{now}.png"
        folder = self.peak_folder if label == 'peak' else self.non_peak_folder
        save_path = os.path.join(folder, fname)

        fig, ax = plt.subplots(figsize=(3, 3))
        try:
            # Use the figure's own methods rather than pyplot's "current
            # figure" state, and always release the figure afterwards.
            fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
            ax.axis('off')

            # Fs=1024 is the sampling rate assumed by the original author --
            # TODO confirm against the acquisition setup.
            ax.specgram(window_data, Fs=1024, mode='psd', scale='dB', cmap='gray')
            fig.savefig(save_path, dpi=100, bbox_inches='tight', pad_inches=0)
        finally:
            plt.close(fig)

    def skip_current(self, b):
        """Skips the current signal and updates progress.

        The row is logged as completed without saving any spectrogram.
        """
        if self.df_current is None:
            # No file is loaded, so there is no signal to skip.
            return

        fname = self._current_fname()
        self._mark_row_done(fname, self.current_row_idx)
        self.save_progress()

        with self.header_output:
            print(f"Skipping {fname} - Row {self.current_row_idx}...")

        self.update_progress_ui()  # Update progress bar
        self.load_next_batch()

In [None]:
# --- Annotator configuration ---
# DATA_DIR is searched for *.csv inputs; OUTPUT_DIR receives the spectrogram
# folders and the progress log.
DATA_DIR, OUTPUT_DIR = 'data', 'annotated_dataset'

# --- Launch ---
# Constructing the annotator scans the CSV files and displays the
# file-selection widgets.
print("Initializing Annotator...")
app = ClickablePeakAnnotator(data_folder=DATA_DIR, output_folder=OUTPUT_DIR)

In [None]:
def combine_spectrograms_to_h5(output_folder, h5_filename='combined_spectrograms.h5'):
    """
    Finds all PNGs in the peak/non_peak subfolders and combines them
    into a single HDF5 file for efficient ML training.

    The file receives two datasets: ``images`` (uint8, shape ``(N, h, w, 1)``)
    and ``labels`` (uint8, 1 = peak, 0 = non-peak), where ``h`` and ``w`` come
    from the first image found. Images of a different size are resized to
    match. Images that cannot be read are skipped with a warning, and the
    datasets are trimmed so that ``N`` equals the number of images actually
    stored.

    Args:
        output_folder: Folder containing the ``peak_spectrograms`` and
            ``non_peak_spectrograms`` sub-folders.
        h5_filename: Path of the HDF5 file to create (overwritten if present).

    Returns:
        None. Progress and errors are reported with ``print``.
    """
    peak_path = os.path.join(output_folder, 'peak_spectrograms')
    non_peak_path = os.path.join(output_folder, 'non_peak_spectrograms')

    # Sorted so the row order in the HDF5 file is reproducible across runs.
    peak_files = sorted(glob.glob(os.path.join(peak_path, '*.png')))
    non_peak_files = sorted(glob.glob(os.path.join(non_peak_path, '*.png')))

    print(f"Found {len(peak_files)} peaks and {len(non_peak_files)} non-peaks.")

    if not peak_files and not non_peak_files:
        print("No images found. Annotate some data first.")
        return

    # Read one image to get dimensions (e.g., 300x300)
    sample_file = peak_files[0] if peak_files else non_peak_files[0]
    sample = cv2.imread(sample_file, cv2.IMREAD_GRAYSCALE)
    if sample is None:
        print(f"Error reading sample image: {sample_file}")
        return
    h, w = sample.shape

    total_count = len(peak_files) + len(non_peak_files)
    # (display name, numeric label, files) for each class, peaks first.
    groups = (
        ('Peak', 1, peak_files),
        ('Non-Peak', 0, non_peak_files),
    )

    with h5py.File(h5_filename, 'w') as hf:
        # maxshape makes the datasets resizable so they can be trimmed below.
        dset_img = hf.create_dataset(
            'images', (total_count, h, w, 1),
            maxshape=(total_count, h, w, 1), dtype='u1'
        )
        dset_labels = hf.create_dataset(
            'labels', (total_count,), maxshape=(total_count,), dtype='u1'
        )

        idx = 0
        for title, label, files in groups:
            print(f"Processing {len(files)} {title} Images...")
            for f in files:
                img = cv2.imread(f, cv2.IMREAD_GRAYSCALE)
                if img is None:
                    print(f"Warning: could not read {f}; skipping.")
                    continue
                if img.shape != (h, w):
                    # cv2.resize takes the target size as (width, height).
                    img = cv2.resize(img, (w, h), interpolation=cv2.INTER_AREA)
                dset_img[idx] = np.expand_dims(img, axis=-1)  # Add channel dimension
                dset_labels[idx] = label
                idx += 1

        # Drop unfilled trailing rows; otherwise skipped images would remain
        # as all-zero pictures carrying label 0.
        if idx < total_count:
            dset_img.resize(idx, axis=0)
            dset_labels.resize(idx, axis=0)

    print(f"\nSuccessfully saved {idx} items to {h5_filename}")

In [None]:
# --- Combiner settings ---
# OUTPUT_DIR is repeated here so this cell does not depend on the annotator
# cell having been run in the current kernel session.
OUTPUT_DIR = 'annotated_dataset'
H5_FILE_NAME = 'my_training_data.h5'

# --- Build the HDF5 training file from the saved spectrogram PNGs ---
combine_spectrograms_to_h5(output_folder=OUTPUT_DIR, h5_filename=H5_FILE_NAME)