# In [1]:
import tkinter as tk
from tkinter import messagebox, filedialog

import librosa
import pandas as pd
import os
import numpy as np


class LabelingTool:
    """Tkinter GUI for labelling regions of ``.wav`` recordings.

    The main canvas shows the current sample window split into ``n_strips``
    horizontal strips, read left-to-right and top-to-bottom. The bottom
    "timeline" canvas shows the whole recording; dragging on it selects the
    sample window the strips display. Dragging across the strips while a label
    is selected creates a labelled region. Regions are exported per file as a
    CSV with columns ``Start, End, Label, Start Sample, End Sample``.
    """

    def __init__(self, root, folder_path, out_path, labels, sr=11025, n_strips=8, scale_y=1.5):
        """Build the UI and load the first ``.wav`` file found in *folder_path*.

        Args:
            root: Tk window hosting the tool.
            folder_path: Directory scanned for ``.wav`` files.
            out_path: Directory the ``<name>_labels.csv`` files are written to.
            labels: Mapping of label name -> Tk colour used for its button and regions.
            sr: Sample rate every file is resampled to on load.
            n_strips: Initial number of strips the sample window is split into.
            scale_y: Initial vertical amplification of the strip waveforms.
        """
        self.root = root
        self.root.title("Time Series Labeling Tool")
        self.folder_path = folder_path
        self.out_path = out_path
        # Sorted so files are visited in a stable order (os.listdir order is
        # arbitrary); the suffix test is case-insensitive so ".WAV" is included.
        self.files = sorted(f for f in os.listdir(folder_path) if f.lower().endswith('.wav'))
        self.current_file_idx = 0
        self.labels = labels
        # Each region: (start_time_str, end_time_str, label, start_sample, end_sample)
        self.regions = []
        self.current_label = None
        self.drag_start = None
        self.strip_start = None
        self.drag_start_tc = None
        self.sr = sr
        self.redraw_timer = None
        self.total_redraw = False
        self.draw_running = False
        self.n_strips = n_strips
        self.scale_y = scale_y
        self.sample_window_start = 0
        self.sample_window_end = n_strips * 2
        self.samples_per_strip = 2
        # Empty until a file is loaded; drawing and mouse handlers check for this.
        self.y = np.zeros(0, dtype=np.float32)

        # Create UI components
        self.create_widgets()

        # Before the window is mapped winfo_* reports 1 px; clamp so the strip
        # height is never 0 (get_strip divides by it).
        self.strip_height = max(1, self.canvas.winfo_height() // n_strips)
        self.canvas_width = max(1, self.canvas.winfo_width())

        # Load first file
        if self.files:
            self.load_file(self.current_file_idx)
        else:
            messagebox.showwarning("No Files", f"No .wav files found in {folder_path}.")

    def create_widgets(self):
        """Build the strip canvas, timeline canvas, sliders, label buttons and actions."""
        canvas_frame = tk.Frame(self.root)
        canvas_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        # Label to show the current file name
        self.file_label = tk.Label(canvas_frame, text="Current File: ", font=("Arial", 12))
        self.file_label.pack(side=tk.TOP, anchor="w", padx=10, pady=5)

        # Canvas showing the current sample window, split into strips
        self.canvas = tk.Canvas(canvas_frame, bg='white')
        self.canvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Canvas showing the whole recording (used to pick the sample window)
        self.time_canvas = tk.Canvas(canvas_frame, bg='white', height=275)
        self.time_canvas.pack(side=tk.BOTTOM, fill=tk.X)

        # Strip canvas: drag to create a labelled region
        self.canvas.bind("<ButtonPress-1>", self.on_click)
        self.canvas.bind("<B1-Motion>", self.on_drag)
        self.canvas.bind("<ButtonRelease-1>", self.on_release)
        self.canvas.bind("<Configure>", self.on_resize)  # Bind resize event

        # Timeline canvas: drag to choose the displayed sample window
        self.time_canvas.bind("<ButtonPress-1>", self.on_click_tc)
        self.time_canvas.bind("<B1-Motion>", self.on_drag_tc)
        self.time_canvas.bind("<ButtonRelease-1>", self.on_release_tc)
        self.time_canvas.bind("<Configure>", self.on_resize)  # Bind resize event

        # Control frame for configuration and buttons
        control_frame = tk.Frame(self.root)
        control_frame.pack(side=tk.RIGHT, padx=10, pady=10)

        config_frame = tk.Frame(control_frame)
        config_frame.pack(side=tk.TOP, padx=10, pady=10)

        # Strip-count configuration
        tk.Label(config_frame, text="Number of Strips:").grid(row=0, column=0, padx=5, pady=5)
        self.strips_slider = tk.Scale(config_frame, from_=1, to=16, resolution=1, orient=tk.HORIZONTAL,
                                      command=self.update_n_strips)
        self.strips_slider.set(self.n_strips)  # Set default value
        self.strips_slider.grid(row=0, column=1, padx=5, pady=5)

        # Vertical amplification of the strip waveforms
        tk.Label(config_frame, text="Y-Scaling:").grid(row=2, column=0, padx=5, pady=5)
        self.y_scale_slider = tk.Scale(config_frame, from_=1, to=7.5, resolution=0.05, orient=tk.HORIZONTAL,
                                       command=self.update_y_scale)
        self.y_scale_slider.set(self.scale_y)
        self.y_scale_slider.grid(row=2, column=1, padx=5, pady=5)

        # One selection button per label, coloured like the regions it creates
        for label, color in self.labels.items():
            label_button = tk.Button(control_frame, text=label, bg=color, fg="black",
                                     command=lambda label=label: self.set_label(label))
            label_button.pack(side=tk.TOP, padx=5, pady=5)

        # Listbox to show the created labels
        self.label_listbox = tk.Listbox(control_frame, width=50)
        self.label_listbox.pack(side=tk.TOP, padx=10, pady=5)

        # Delete label button
        self.delete_button = tk.Button(control_frame, text="Delete Label", command=self.delete_label)
        self.delete_button.pack(side=tk.TOP, padx=5, pady=5)

        # Export labels to CSV button
        self.export_button = tk.Button(control_frame, text="Export to CSV", command=self.export_labels)
        self.export_button.pack(side=tk.TOP, padx=5, pady=5)

        # Next file button
        self.next_button = tk.Button(control_frame, text="Next File", command=self.next_file)
        self.next_button.pack(side=tk.TOP, padx=5, pady=5)

    def update_y_scale(self, value):
        """Slider callback: set the waveform amplification and schedule a redraw."""
        value = float(value)
        if self.scale_y != value:
            self.scale_y = value
            self.delayed_redraw()

    def update_n_strips(self, value):
        """Slider callback: change the strip count, keeping the same sample window."""
        value = int(value)
        if self.n_strips != value:
            self.n_strips = value
            self.update_samples_per_strip()
            self.delayed_redraw()

    def set_label(self, label):
        """Make *label* the label applied to subsequently dragged regions."""
        self.current_label = label
        messagebox.showinfo("Label Selected", f"Current label set to: {label}")

    def delayed_redraw(self):
        """(Re)schedule a redraw 20 ms from now, coalescing bursts of requests."""
        if self.redraw_timer:
            self.root.after_cancel(self.redraw_timer)
        self.redraw_timer = self.root.after(20, self.redraw)

    def redraw(self):
        """Redraw strips and regions; also rebuild the timeline when ``total_redraw`` is set."""
        if self.draw_running:
            self.delayed_redraw()
            return
        self.draw_running = True
        self.redraw_timer = None
        try:
            self.strip_height = max(1, self.canvas.winfo_height() // self.n_strips)
            self.canvas_width = max(1, self.canvas.winfo_width())
            self.canvas.delete("all")
            self.visualize_waveform()

            if self.total_redraw:
                # Rebuild the timeline BEFORE draw_regions(): deleting "all"
                # afterwards would wipe the region overlays drawn on it.
                self.time_canvas.delete("all")
                self.draw_waveform_timeline()
                self._draw_window_marker()
                self.total_redraw = False

            self.draw_regions()
        finally:
            # Always release the flag; a stuck flag would make every later
            # redraw reschedule itself forever.
            self.draw_running = False

    def _draw_window_marker(self):
        """Shade the current sample window on the timeline (canvas tag ``region_tc``)."""
        self.time_canvas.delete("region_tc")
        if len(self.y) == 0:
            return
        self.draw_time_region(self.sample_to_timecanvas_x(self.sample_window_start),
                              self.sample_to_timecanvas_x(self.sample_window_end),
                              fill="Blue", stipple="gray50", tags="region_tc")

    @staticmethod
    def _clamp(value, low, high):
        """Return *value* limited to the closed interval [low, high]."""
        return min(max(value, low), high)

    def get_strip(self, y):
        """Return the strip index (0 to n_strips - 1) containing canvas y-coordinate *y*."""
        # max(1, ...) guards against a 0 strip height on a tiny/unmapped canvas.
        strip_index = int(y // max(1, self.strip_height))
        # Out-of-bounds y (dragging outside the canvas) snaps to the first/last strip.
        return self._clamp(strip_index, 0, self.n_strips - 1)

    def on_click_tc(self, event):
        """Start a sample-window selection on the timeline."""
        if len(self.y) == 0:
            return
        # Clamp to the full width (not width - 1) so a drag to the right edge
        # can select through the very last sample.
        self.drag_start_tc = self._clamp(event.x, 0, self.time_canvas.winfo_width())

    def on_drag_tc(self, event):
        """Show the in-progress timeline selection as a green overlay."""
        if self.drag_start_tc is None:
            return
        self.time_canvas.delete("highlight_tc")  # Remove previous highlights
        x_stop = self._clamp(event.x, 0, self.time_canvas.winfo_width())
        self.draw_time_region(self.drag_start_tc, x_stop, fill="Green", stipple="gray50", tags="highlight_tc")

    def on_release_tc(self, event):
        """Finish a timeline selection and make it the displayed sample window."""
        if self.drag_start_tc is None:
            return
        width = max(1, self.time_canvas.winfo_width())
        x_start, x_stop = sorted((self.drag_start_tc, self._clamp(event.x, 0, width)))

        # Reset drag state
        self.drag_start_tc = None
        self.time_canvas.delete("highlight_tc")

        # Scale pixel positions to sample indices
        total_samples = len(self.y)
        start_sample = int(x_start * total_samples / width)
        stop_sample = int(x_stop * total_samples / width)

        # A plain click (or sub-sample drag) would give an empty window, leaving
        # strips too short to draw; keep the current window instead.
        if stop_sample <= start_sample:
            return

        self.update_sample_window(start_sample, stop_sample)
        self._draw_window_marker()
        self.delayed_redraw()

    def update_sample_window(self, start_sample, stop_sample):
        """Set the displayed sample range [start, stop), clamped to the loaded signal."""
        start = max(0, int(start_sample))
        stop = max(start, min(int(stop_sample), len(self.y)))
        self.sample_window_start = start
        self.sample_window_end = stop
        self.update_samples_per_strip()

    def update_samples_per_strip(self):
        """Recompute how many samples each strip shows for the current window."""
        window_samples = self.sample_window_end - self.sample_window_start
        # Ceiling division: with floor division the trailing
        # window_samples % n_strips samples would never be shown or labelled.
        self.samples_per_strip = max(1, -(-window_samples // self.n_strips))

    def _strip_bounds(self, strip_index):
        """Return the [start, stop) sample range displayed in strip *strip_index*."""
        start = self.sample_window_start + strip_index * self.samples_per_strip
        return start, start + self.samples_per_strip

    def on_resize(self, event):
        """Handles the resizing of the window."""
        self.total_redraw = True
        self.delayed_redraw()

    def on_click(self, event):
        """Start dragging a new labelled region (only when a label is selected)."""
        if not self.current_label or len(self.y) == 0:
            return
        self.drag_start = self._clamp(event.x, 0, self.canvas_width)
        # Determine which strip is being clicked
        self.strip_start = self.get_strip(event.y)

    def on_drag(self, event):
        """Show the in-progress region as a highlight across the strips."""
        if self.drag_start is None:
            return
        self.canvas.delete("highlight")  # Remove previous highlights
        x_stop = self._clamp(event.x, 0, self.canvas_width)
        self.draw_region(self.strip_start, self.get_strip(event.y), self.drag_start, x_stop,
                         fill=self.labels[self.current_label], stipple="gray50", tags="highlight")

    def on_release(self, event):
        """Finish a drag: store the region with the current label and redraw regions."""
        if self.drag_start is not None and self.current_label:
            strip_start, strip_stop, x_start, x_stop = self.order_strip_and_x(
                self.strip_start,
                self.get_strip(event.y),
                self.drag_start,
                self._clamp(event.x, 0, self.canvas_width))
            # Scale to time domain
            start_sample, start_time = self.scale_to_time(x_start, strip_start)
            end_sample, end_time = self.scale_to_time(x_stop, strip_stop)

            # Ignore zero-length selections (a click without a drag).
            if end_sample > start_sample:
                self.regions.append((start_time, end_time, self.current_label, start_sample, end_sample))
                self.update_label_listbox()
                self.draw_regions()

        # Reset drag state
        self.drag_start = None
        self.strip_start = None
        self.canvas.delete("highlight")

    def order_strip_and_x(self, strip_start, strip_stop, x_start, x_stop):
        """Return (strip_start, strip_stop, x_start, x_stop) ordered so the start precedes the stop."""
        if strip_stop < strip_start or (strip_start == strip_stop and x_start > x_stop):
            return strip_stop, strip_start, x_stop, x_start
        return strip_start, strip_stop, x_start, x_stop

    def sample_to_timecanvas_x(self, sample):
        """Map a sample index to an x-coordinate on the timeline canvas (0 if nothing is loaded)."""
        total_samples = len(self.y)
        if total_samples == 0:
            return 0
        return sample * self.time_canvas.winfo_width() / total_samples

    def _region_strip_span(self, start_sample, end_sample):
        """Return (strip_start, strip_stop, x_start, x_stop) of a region clipped to the strips.

        Returns None when the region lies entirely outside the displayed strips.
        """
        visible_end = self.sample_window_start + self.n_strips * self.samples_per_strip
        if end_sample <= self.sample_window_start or start_sample >= visible_end:
            return None

        strip_start, x_start = self.sample_to_strip(start_sample)
        strip_stop, x_stop = self.sample_to_strip(end_sample)

        if strip_stop >= self.n_strips:
            strip_stop, x_stop = self.n_strips - 1, self.canvas_width
        if strip_start < 0:
            strip_start, x_start = 0, 0
        # A region ending exactly on a strip boundary maps to x=0 of the next
        # strip; end it at the right edge of the previous strip instead of
        # drawing a zero-width sliver.
        if x_stop == 0 and strip_stop > strip_start:
            strip_stop, x_stop = strip_stop - 1, self.canvas_width
        return strip_start, strip_stop, x_start, x_stop

    def draw_regions(self):
        """Redraw every stored region on both the strip canvas and the timeline."""
        self.canvas.delete("region")
        self.time_canvas.delete("region")
        for _start_time, _end_time, label, start_sample, end_sample in self.regions:
            color = self.labels[label]
            self.draw_time_region(self.sample_to_timecanvas_x(start_sample),
                                  self.sample_to_timecanvas_x(end_sample),
                                  fill=color, stipple="gray50", tags="region")

            span = self._region_strip_span(start_sample, end_sample)
            if span is not None:
                self.draw_region(*span, fill=color, stipple="gray50", tags="region")

    def sample_to_strip(self, sample):
        """Map a sample index to (strip index, x-coordinate) on the strip canvas.

        The strip index may be negative or >= n_strips for samples outside the
        displayed window; callers clip it.
        """
        strip, sample_in_strip = divmod(sample - self.sample_window_start, self.samples_per_strip)
        x = int(sample_in_strip * self.canvas_width / self.samples_per_strip)
        return int(strip), x

    def draw_region(self, strip_start, strip_stop, x_start, x_stop, **kwargs):
        """Draw a region spanning one or more strips as one rectangle per strip."""
        strip_start, strip_stop, x_start, x_stop = self.order_strip_and_x(strip_start, strip_stop, x_start, x_stop)

        for i in range(strip_start, strip_stop + 1):
            # Inner strips are covered edge to edge.
            x0 = x_start if i == strip_start else 0
            x1 = x_stop if i == strip_stop else self.canvas_width
            y0 = i * self.strip_height
            y1 = (i + 1) * self.strip_height
            self.canvas.create_rectangle(x0, y0, x1, y1, **kwargs)

    def draw_time_region(self, x_start, x_stop, **kwargs):
        """Draw a full-height rectangle between two x-coordinates on the timeline."""
        if x_start > x_stop:
            x_start, x_stop = x_stop, x_start
        self.time_canvas.create_rectangle(x_start, 0, x_stop, self.time_canvas.winfo_height(), **kwargs)

    def scale_to_time(self, x, strip):
        """Convert a strip-canvas position to (sample index, time string).

        The sample is clamped to [0, len(y)] (the last strip can extend past the
        end of the signal) and rounded to a whole sample index.
        """
        width = max(1, self.canvas_width)
        sample = self.sample_window_start + strip * self.samples_per_strip + x * self.samples_per_strip / width
        sample = int(round(self._clamp(sample, 0, len(self.y))))
        return sample, self.get_time_str(sample)

    def get_time_str(self, sample_number):
        """Format a sample index as seconds with six (truncated) decimal places."""
        time_in_seconds = sample_number / self.sr
        seconds = int(time_in_seconds)
        microseconds = int((time_in_seconds - seconds) * 1e6)
        return f"{seconds}.{microseconds:06d}"

    def update_label_listbox(self):
        """Refill the listbox with one coloured row per stored region."""
        self.label_listbox.delete(0, tk.END)
        for idx, (start, end, label, _start_sample, _end_sample) in enumerate(self.regions):
            self.label_listbox.insert(tk.END, f"{idx + 1}. {start}s - {end}s : {label}")
            self.label_listbox.itemconfig(idx, bg=self.labels[label])

    def delete_label(self):
        """Delete the region selected in the listbox, if any."""
        selected = self.label_listbox.curselection()
        if selected:
            del self.regions[selected[0]]
            self.update_label_listbox()
            self.draw_regions()

    def export_labels(self):
        """Write the current file's regions to ``<out_path>/<wav stem>_labels.csv``."""
        if not self.regions:
            messagebox.showwarning("No Labels", "No labels to export.")
            return

        stem = os.path.splitext(self.files[self.current_file_idx])[0]
        file_path = os.path.join(self.out_path, stem + "_labels.csv")

        df = pd.DataFrame(self.regions, columns=["Start", "End", "Label", "Start Sample", "End Sample"])
        try:
            df.to_csv(file_path, index=False)
        except OSError as e:
            messagebox.showerror("Export Error", f"An error occurred while saving the file: {e}")
        else:
            messagebox.showinfo("Exported", f"Labels exported successfully to {file_path}!")

    def load_file(self, index):
        """Load file *index*, resample it to ``self.sr``, clear labels and show it fully zoomed out."""
        if index < 0 or index >= len(self.files):
            messagebox.showwarning("End of Files", "No more files to load.")
            return

        self.current_file_idx = index
        file_path = os.path.join(self.folder_path, self.files[index])

        # Update the file label to show the current file name
        self.file_label.config(text=f"Current File: {self.files[index]}")

        print(f"Draw {file_path}")

        # Load the data at its native rate, then resample if it differs.
        y, actual_sr = librosa.load(file_path, sr=None)
        if self.sr != actual_sr:
            y = librosa.resample(y, orig_sr=actual_sr, target_sr=self.sr)
            print(f"Resampled {file_path} from {actual_sr} to {self.sr}.")
        self.y = y

        # Labels belong to a single file: start the new file with none.
        self.regions = []
        self.update_label_listbox()

        self.update_sample_window(0, len(y))

        # Full redraw (strips, timeline, window marker) through the normal path.
        self.total_redraw = True
        self.delayed_redraw()

    def visualize_waveform(self):
        """Draws waveform of the current sample window across n_strips horizontal strips."""
        total_samples = len(self.y)
        for i in range(self.n_strips):
            start, stop = self._strip_bounds(i)
            if start >= total_samples:
                break  # later strips start even further past the end
            # Slicing past the end is safe; the last strip is simply shorter.
            self.draw_waveform(self.y[start:stop], i)

    def draw_waveform(self, data, strip_index):
        """Draw *data* as a polyline inside strip *strip_index*."""
        # Tk lines need at least two points.
        if len(data) < 2:
            return
        strip_top = strip_index * self.strip_height

        # Positive amplitudes go up; 0 sits at the strip's vertical centre.
        y_pos = (1 - self.scale_y * np.asarray(data)) * (self.strip_height // 2) + strip_top
        x_pos = np.arange(len(data)) * self.canvas_width / self.samples_per_strip

        coords = np.column_stack((x_pos, y_pos)).ravel()
        self.canvas.create_line(*coords, fill="black", width=1)

    def draw_waveform_timeline(self):
        """Draw the whole signal across the timeline canvas."""
        y_length = len(self.y)
        if y_length < 2:
            return
        height = self.time_canvas.winfo_height()
        width = self.time_canvas.winfo_width()

        x_pos = np.arange(y_length) * (width / y_length)
        y_pos = (1 - np.asarray(self.y)) * (height // 2)

        coords = np.column_stack((x_pos, y_pos)).ravel()
        self.time_canvas.create_line(*coords, fill="black", width=1)

    def next_file(self):
        """Load the next file; stay on the last one (keeping the index valid for export) at the end."""
        if self.current_file_idx + 1 >= len(self.files):
            messagebox.showwarning("No More Files", "You have reached the last file.")
            return
        self.load_file(self.current_file_idx + 1)

if __name__ == '__main__':
    root = tk.Tk()

    wav_path = filedialog.askdirectory(title="Select Folder with .wav Files")
    # Only ask for an output folder once an input folder was actually chosen.
    out_path = filedialog.askdirectory(title="Select Folder to Output .csv Files") if wav_path else ""

    if wav_path and out_path:
        app = LabelingTool(root, wav_path, out_path, labels={"PSW": "lightgreen", "Fibrillation": "lightblue"})
        root.mainloop()
    else:
        messagebox.showerror("No Folder", "You must select both a folder with .wav files and an output folder.")
        # Tear down the otherwise-empty root window; in an interactive session
        # (e.g. a notebook kernel) it would otherwise linger.
        root.destroy()


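# Output of In [1]: ModuleNotFoundError: No module named 'tkinter'
# (the Tk bindings are not installed in this environment, e.g. install the python3-tk package on Debian/Ubuntu)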