# In [52]:
import os
import cv2
import math
import time
import pandas as pd
from dataclasses import dataclass, asdict
from datetime import datetime, timedelta
from dateutil import parser as dtparser

import tkinter as tk
from tkinter import ttk, filedialog, messagebox

from PIL import Image, ImageTk


@dataclass(init=True, repr=True, eq=True)
class Annotation:
    """A single labelled bout (interval) or event (instant) in a video.

    Attributes:
        video_path: Path of the video the annotation belongs to.
        behavior: Ethogram label that was clicked.
        kind: Either "bout" (start/end interval) or "event" (single mark).
        start_frame: First frame index of the annotation.
        end_frame: Last frame index (the same as ``start_frame`` for events).
        start_time_s: Time value derived from ``start_frame``.
        end_time_s: Time value derived from ``end_frame``.
        start_datetime: Absolute start timestamp as an ISO-style string,
            or "" when no start date/time is known.
        end_datetime: Absolute end timestamp as an ISO-style string,
            or "" when no start date/time is known.
    """

    video_path: str
    behavior: str
    kind: str
    start_frame: int
    end_frame: int
    start_time_s: float
    end_time_s: float
    start_datetime: str
    end_datetime: str


class BehaviourAnnotatorGUI:
    """Tkinter application for annotating behaviours in a video.

    The user loads a video (opened with OpenCV), scrubs or plays through it,
    and clicks ethogram buttons to record either *bouts* (a start and an end
    frame) or *events* (a single frame). Frame indices are converted to
    seconds using the video's FPS (or a user override) and, optionally, to
    absolute timestamps relative to a user-supplied start date/time.
    Annotations are listed in a table and exported to CSV with pandas.
    """

    # Tick interval (ms) used when not playing, or when the FPS is unknown
    # and playback therefore cannot be paced in real time.
    _DEFAULT_TICK_MS = 30

    def __init__(self, root: tk.Tk):
        """Initialise state, build the widgets on *root* and start the UI tick."""
        self.root = root
        self.root.title("Behaviour Annotator")
        self.root.geometry("1200x760")

        # Video state.
        self.cap = None                  # cv2.VideoCapture, or None when no video is open
        self.video_path = ""
        self.fps = 0.0                   # 0.0 means "unknown"
        self.frame_count = 0
        self.cur_frame_idx = 0
        self.playing = False
        self._last_rendered_idx = None   # index of the frame last decoded and drawn
        self._last_frame_rgb = None      # cached RGB frame, redrawn when the canvas resizes
        self._tk_img = None              # kept on self so the PhotoImage is not garbage-collected

        # Optional absolute start time of the recording (datetime or None).
        self.start_dt = None

        # Annotation state.
        self.behaviors = []
        self.active_bout = None          # {"behavior": str, "start_frame": int} while a bout is open
        self.annotations = []            # list[Annotation], in the same order as the table rows
        self._ann_by_item = {}           # Treeview item id -> Annotation shown in that row

        self._build_ui()

        # Periodic tick that drives playback.
        self.root.after(self._DEFAULT_TICK_MS, self._ui_tick)

    # ------------------------------------------------------------------ UI

    def _build_ui(self):
        """Create and lay out all widgets: toolbar, video pane, ethogram, table, footer."""
        self.root.configure(bg="#111316")

        style = ttk.Style()
        try:
            # The "vista" theme only exists on Windows; elsewhere keep the default.
            style.theme_use("vista")
        except tk.TclError:
            pass

        # "." is the root style: foreground sets the default text colour of
        # ttk widgets that do not override it below (e.g. entries, buttons).
        style.configure(".", background="#203C49", foreground="#0C4F69", fieldbackground="#2A2E33")
        style.configure("TFrame", background="#111316")
        # Button colours while hovered ("active") and pressed.
        style.map("TButton",
                  background=[("active", "#050505"), ("pressed", "#1A2F3C")],
                  foreground=[("active", "#0D660D"), ("pressed", "#EC1717")])

        style.configure("TLabel", background="#111316", foreground="#EAEAEA")  # label text on the main window
        style.configure("TButton", padding=6)
        style.configure("TLabelframe", background="#111316", foreground="#EAEAEA")
        style.configure("TLabelframe.Label", background="#111316", foreground="#EAEAEA")
        style.configure("TNotebook", background="#111316", borderwidth=0)
        style.configure("TNotebook.Tab", padding=(12, 8))
        style.map("TNotebook.Tab", background=[("selected", "#1A1F24")], foreground=[("selected", "#FFFFFF")])

        # --- Top toolbar: load video, FPS override, start date/time.
        top = ttk.Frame(self.root)
        top.pack(side=tk.TOP, fill=tk.X, padx=10, pady=10)

        btn_load = ttk.Button(top, text="Load video", command=self.load_video)
        btn_load.pack(side=tk.LEFT)

        self.video_label = ttk.Label(top, text="No video loaded")
        self.video_label.pack(side=tk.LEFT, padx=10)

        ttk.Label(top, text="FPS:").pack(side=tk.LEFT, padx=(20, 4))
        self.fps_var = tk.StringVar(value="")
        self.fps_entry = ttk.Entry(top, width=8, textvariable=self.fps_var)
        self.fps_entry.pack(side=tk.LEFT)
        ttk.Button(top, text="Apply FPS", command=self.apply_fps_override).pack(side=tk.LEFT, padx=6)

        ttk.Label(top, text="Start date/time (optional):").pack(side=tk.LEFT, padx=(20, 4))
        self.startdt_var = tk.StringVar(value="")
        self.startdt_entry = ttk.Entry(top, width=26, textvariable=self.startdt_var)
        self.startdt_entry.pack(side=tk.LEFT)
        ttk.Button(top, text="Set", command=self.apply_start_datetime).pack(side=tk.LEFT, padx=6)

        # --- Main area: video on the left, ethogram and table on the right.
        main = ttk.Frame(self.root)
        main.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=10, pady=(0, 10))

        left = ttk.Frame(main)
        left.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        right = ttk.Frame(main, width=420)
        right.pack(side=tk.RIGHT, fill=tk.Y)
        right.pack_propagate(False)  # keep the fixed 420 px width

        # Video canvas and transport controls.
        vid_box = ttk.LabelFrame(left, text="Video")
        vid_box.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        self.canvas = tk.Canvas(vid_box, bg="#000000", highlightthickness=0)
        self.canvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=10, pady=10)
        # Re-fit the current frame whenever the canvas changes size.
        self.canvas.bind("<Configure>", self._on_canvas_resize)

        controls = ttk.Frame(vid_box)
        controls.pack(side=tk.TOP, fill=tk.X, padx=10, pady=(0, 8))

        self.btn_play = ttk.Button(controls, text="Play", command=self.toggle_play, state=tk.DISABLED)
        self.btn_play.pack(side=tk.LEFT)

        ttk.Button(controls, text="<< -100f", command=lambda: self.step_frames(-100)).pack(side=tk.LEFT, padx=6)
        ttk.Button(controls, text="< -1f", command=lambda: self.step_frames(-1)).pack(side=tk.LEFT)
        ttk.Button(controls, text="+1f >", command=lambda: self.step_frames(1)).pack(side=tk.LEFT, padx=6)
        ttk.Button(controls, text="+100f >>", command=lambda: self.step_frames(100)).pack(side=tk.LEFT)

        self.info_var = tk.StringVar(value="Frame: - / - | Time: - s | Timestamp: -")
        ttk.Label(controls, textvariable=self.info_var).pack(side=tk.RIGHT)

        self.slider = ttk.Scale(vid_box, from_=0, to=0, orient=tk.HORIZONTAL, command=self.on_slider)
        self.slider.pack(side=tk.TOP, fill=tk.X, padx=10, pady=(0, 10))

        # Ethogram definition, mode selector and behaviour buttons.
        eth_box = ttk.LabelFrame(right, text="Ethogram")
        eth_box.pack(side=tk.TOP, fill=tk.X, padx=0, pady=(0, 10))

        eth_top = ttk.Frame(eth_box)
        eth_top.pack(side=tk.TOP, fill=tk.X, padx=10, pady=10)

        ttk.Label(eth_top, text="Behaviours (comma-separated):").pack(side=tk.TOP, anchor="w")
        self.eth_var = tk.StringVar(value="Replace with your behaviour")
        self.eth_entry = ttk.Entry(eth_top, textvariable=self.eth_var)
        self.eth_entry.pack(side=tk.TOP, fill=tk.X, pady=6)

        ttk.Button(eth_top, text="Apply ethogram", command=self.apply_ethogram).pack(side=tk.TOP, anchor="w")

        mode_row = ttk.Frame(eth_box)
        mode_row.pack(side=tk.TOP, fill=tk.X, padx=10, pady=(0, 10))
        self.mode_var = tk.StringVar(value="bout")
        ttk.Radiobutton(mode_row, text="Bout mode (Start/End)", value="bout", variable=self.mode_var).pack(side=tk.LEFT)
        ttk.Radiobutton(mode_row, text="Event mode (Mark)", value="event", variable=self.mode_var).pack(side=tk.LEFT, padx=12)

        self.active_bout_var = tk.StringVar(value="Active bout: none")
        ttk.Label(eth_box, textvariable=self.active_bout_var).pack(side=tk.TOP, anchor="w", padx=10, pady=(0, 8))

        self.beh_btn_frame = ttk.Frame(eth_box)
        self.beh_btn_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=(0, 10))

        # Annotation table with scrollbars and its action buttons.
        ann_box = ttk.LabelFrame(right, text="Annotations")
        ann_box.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        tree_frame = ttk.Frame(ann_box)
        tree_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=10, pady=10)

        cols = ("behavior", "kind", "start_frame", "end_frame", "start_time_s", "end_time_s", "start_datetime", "end_datetime")
        self.tree = ttk.Treeview(tree_frame, columns=cols, show="headings", height=12)
        for c in cols:
            self.tree.heading(c, text=c)
            # stretch=False keeps the declared widths: the columns add up to far
            # more than the 420 px pane, so the table scrolls horizontally
            # instead of squeezing every column.
            self.tree.column(c, width=110 if c in ("behavior", "kind") else 120, anchor="center", stretch=False)

        yscroll = ttk.Scrollbar(tree_frame, orient=tk.VERTICAL, command=self.tree.yview)
        xscroll = ttk.Scrollbar(tree_frame, orient=tk.HORIZONTAL, command=self.tree.xview)
        self.tree.configure(yscrollcommand=yscroll.set, xscrollcommand=xscroll.set)
        # Scrollbars are packed first so they keep their space when the pane is small.
        yscroll.pack(side=tk.RIGHT, fill=tk.Y)
        xscroll.pack(side=tk.BOTTOM, fill=tk.X)
        self.tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        btn_row = ttk.Frame(ann_box)
        btn_row.pack(side=tk.TOP, fill=tk.X, padx=10, pady=(0, 10))

        ttk.Button(btn_row, text="Undo last", command=self.undo_last).pack(side=tk.LEFT)
        ttk.Button(btn_row, text="Delete selected", command=self.delete_selected).pack(side=tk.LEFT, padx=6)
        ttk.Button(btn_row, text="Clear all", command=self.clear_all).pack(side=tk.LEFT, padx=6)

        ttk.Button(btn_row, text="Save CSV", command=self.save_csv).pack(side=tk.RIGHT)

        # --- Footer with credits and a clickable link.
        footer = ttk.Frame(self.root)
        footer.pack(side=tk.BOTTOM, fill=tk.X)
        ttk.Label(
            footer,
            # Escaped so the copyright sign survives re-encoding of this file
            # (it had previously been mangled into "Â©").
            text="\u00a9 2026 BiRBSLAB | UiT | Norway",
            font=("Courier New", 10),
            foreground="orange",
        ).pack(pady=(2, 0))

        self.hyperlink_label = tk.Label(
            footer,
            text="Developed by Hamid Taghipourbibalan",
            font=("Courier New", 8),
            foreground="#1384e7",
            bg="#111316",
            cursor="hand2"
        )
        self.hyperlink_label.pack(pady=(0, 6))
        self.hyperlink_label.bind("<Button-1>", lambda e: self.open_hyperlink("https://www.linkedin.com/in/hamid-taghipourbibalan-b7239088/"))

        # Create the initial behaviour buttons from the default entry text.
        self.apply_ethogram()

    # --------------------------------------------------------------- video

    def load_video(self):
        """Ask for a video file, open it and show its first frame.

        The new file is opened before the current one is released. If it
        cannot be opened, the current video stays loaded and usable.
        """
        path = filedialog.askopenfilename(
            title="Select video",
            filetypes=[("Video files", "*.mp4 *.avi *.mov *.mkv *.m4v"), ("All files", "*.*")]
        )
        if not path:
            return

        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            cap.release()
            messagebox.showerror("Error", "Could not open video.")
            return

        # An open bout refers to frames of the video being replaced. Close it
        # on that video (or drop it) so it cannot end on the new one.
        if self.active_bout is not None:
            if messagebox.askyesno(
                "Active bout",
                "You have an active bout. End it at the current frame before loading another video?\n"
                "(Choosing No discards it.)",
            ):
                self._end_active_bout(end_frame=self.cur_frame_idx)
            self.active_bout = None
            self._refresh_active_bout_label()

        self._close_video()

        # Containers may report 0, negative or NaN values for unknown FPS/length.
        fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
        if not math.isfinite(fps) or fps <= 0:
            fps = 0.0
        raw_count = cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0
        frame_count = int(raw_count) if math.isfinite(raw_count) and raw_count > 0 else 0

        self.cap = cap
        self.video_path = path
        self.fps = float(fps)
        self.frame_count = frame_count
        self.cur_frame_idx = 0
        self.playing = False
        self._last_rendered_idx = None
        self._last_frame_rgb = None
        self.canvas.delete("all")  # don't leave the previous video's frame on screen

        self.video_label.config(text=os.path.basename(path))
        self.fps_var.set("" if self.fps <= 0 else f"{self.fps:.3f}")

        # cur_frame_idx is already 0, so on_slider treats this set() as a no-op.
        self.slider.configure(from_=0, to=max(0, self.frame_count - 1))
        self.slider.set(0)

        self.btn_play.config(state=tk.NORMAL, text="Play")

        self.render_frame(0)

    def _close_video(self):
        """Release the current capture (best effort) and forget decoder state."""
        if self.cap is not None:
            try:
                self.cap.release()
            except Exception:
                # Best effort: a failing release must not block loading
                # another video.
                pass
        self.cap = None
        self._last_rendered_idx = None

    @staticmethod
    def _parse_fps(txt):
        """Return *txt* as a positive, finite FPS value, or None if it is not one.

        ``float()`` accepts "nan" and "inf". Both pass a plain ``<= 0`` test
        and would break every later time computation, so they are rejected.
        """
        try:
            fps = float(txt)
        except ValueError:
            return None
        if not math.isfinite(fps) or fps <= 0:
            return None
        return fps

    def apply_fps_override(self):
        """Replace the video FPS with the value typed in the FPS entry.

        An empty entry is ignored. An invalid entry shows an error and leaves
        the current FPS unchanged. Annotations that were already recorded keep
        the times they were created with.
        """
        txt = self.fps_var.get().strip()
        if not txt:
            return
        fps = self._parse_fps(txt)
        if fps is None:
            messagebox.showerror("Invalid FPS", "Please enter a valid FPS number (e.g., 30 or 29.97).")
            return
        self.fps = fps
        messagebox.showinfo("FPS set", f"FPS set to {self.fps:.3f}")
        self._update_info_label()

    def apply_start_datetime(self):
        """Parse the start date/time entry and use it for absolute timestamps.

        An empty entry clears the start time. Unparseable input shows an
        error and leaves the previous value in place.
        """
        txt = self.startdt_var.get().strip()
        if not txt:
            self.start_dt = None
            self._update_info_label()
            messagebox.showinfo("Start datetime cleared", "No absolute timestamps will be computed.")
            return
        try:
            dt = dtparser.parse(txt)
        except (ValueError, OverflowError):
            # dateutil signals unparseable text with ParserError (a ValueError)
            # and out-of-range numbers with OverflowError.
            messagebox.showerror(
                "Invalid datetime",
                "Could not parse date/time.\nExamples:\n- 2026-02-05 21:15:00\n- 05/02/2026 21:15\n- 2026-02-05T21:15:00"
            )
            return
        self.start_dt = dt
        self._update_info_label()
        messagebox.showinfo("Start datetime set", f"Start datetime set to:\n{dt.isoformat(sep=' ')}")

    def apply_ethogram(self):
        """Rebuild the behaviour buttons from the comma-separated ethogram entry."""
        raw = self.eth_var.get().strip()
        # dict.fromkeys drops repeated names but keeps their order, so one
        # behaviour never gets two buttons.
        behaviors = list(dict.fromkeys(b.strip() for b in raw.split(",") if b.strip()))
        if not behaviors:
            behaviors = ["Behaviour1"]

        self.behaviors = behaviors

        for w in self.beh_btn_frame.winfo_children():
            w.destroy()

        for b in self.behaviors:
            # bb=b binds the current name (avoids the late-binding closure bug).
            ttk.Button(self.beh_btn_frame, text=b, command=lambda bb=b: self.on_behavior_click(bb)).pack(
                side=tk.TOP, fill=tk.X, pady=3
            )

        self._refresh_active_bout_label()

    # ------------------------------------------------------------ playback

    def toggle_play(self):
        """Start or pause playback and update the button label."""
        if self.cap is None:
            return
        self.playing = not self.playing
        self.btn_play.config(text="Pause" if self.playing else "Play")

    def _stop_playback(self):
        """Pause playback and reset the Play/Pause button label."""
        self.playing = False
        self.btn_play.config(text="Play")

    def step_frames(self, n: int):
        """Pause and move *n* frames forward (or backward if negative), clamped to the video."""
        if self.cap is None:
            return
        self._stop_playback()
        new_idx = max(0, min(self.frame_count - 1, self.cur_frame_idx + n))
        # Update the index before the slider so on_slider sees no change.
        self.cur_frame_idx = new_idx
        self.slider.set(new_idx)
        self.render_frame(new_idx)

    def on_slider(self, _val):
        """Slider callback: jump to the selected frame, pausing playback."""
        if self.cap is None:
            return
        idx = int(float(self.slider.get()))
        if idx != self.cur_frame_idx:
            self._stop_playback()
            self.cur_frame_idx = idx
            self.render_frame(idx)

    def _ui_tick(self):
        """Advance playback by one frame if playing, then reschedule itself.

        The next tick is always scheduled (``finally``), so an error while
        rendering cannot stop the tick loop for the rest of the session.
        Playback is paused on error so the traceback is reported once rather
        than on every tick.
        """
        delay_ms = self._DEFAULT_TICK_MS
        try:
            if self.cap is not None and self.playing:
                delay_ms = self._play_next_frame()
        except Exception:
            self._stop_playback()
            raise
        finally:
            self.root.after(delay_ms, self._ui_tick)

    def _play_next_frame(self) -> int:
        """Show the next frame during playback and return the delay (ms) until the next tick."""
        started = time.perf_counter()
        nxt = self.cur_frame_idx + 1
        if nxt >= self.frame_count:
            self._stop_playback()
            return self._DEFAULT_TICK_MS

        # Update the index before the slider so on_slider sees no change.
        self.cur_frame_idx = nxt
        self.slider.set(nxt)
        if not self.render_frame(nxt):
            # Container frame counts are often approximate. Stop at the first
            # unreadable frame instead of retrying failed reads until frame_count.
            self._stop_playback()
            return self._DEFAULT_TICK_MS

        if self.fps <= 0:
            return self._DEFAULT_TICK_MS
        # Pace playback at the video's FPS, subtracting the time spent decoding.
        elapsed_ms = (time.perf_counter() - started) * 1000.0
        return max(1, int(round(1000.0 / self.fps - elapsed_ms)))

    # ----------------------------------------------------------- rendering

    def render_frame(self, idx: int) -> bool:
        """Decode frame *idx* and draw it, scaled to fit the canvas.

        Returns True when frame *idx* is on screen (freshly drawn or already
        shown). Returns False when no video is loaded or the frame cannot be
        read; the previous image then stays visible.
        """
        if self.cap is None:
            return False
        if self._last_rendered_idx == idx:
            self._update_info_label()
            return True

        # Seeking is expensive for compressed video. After reading frame k the
        # capture is already positioned at k + 1, so sequential playback can
        # just keep reading.
        if self._last_rendered_idx is None or idx != self._last_rendered_idx + 1:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ok, frame = self.cap.read()
        if not ok or frame is None:
            # The decoder position is now uncertain; force a seek next time.
            self._last_rendered_idx = None
            self._update_info_label()
            return False

        self._last_rendered_idx = idx
        self._last_frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes to BGR
        self._draw_frame()
        self._update_info_label()
        return True

    def _draw_frame(self):
        """Draw the cached RGB frame centred on the canvas, preserving aspect ratio."""
        if self._last_frame_rgb is None:
            return

        canvas_w = max(1, self.canvas.winfo_width())
        canvas_h = max(1, self.canvas.winfo_height())

        h, w = self._last_frame_rgb.shape[:2]
        scale = min(canvas_w / w, canvas_h / h)
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))

        img = Image.fromarray(self._last_frame_rgb).resize((new_w, new_h), Image.Resampling.LANCZOS)
        # Keep the PhotoImage referenced on self; otherwise it is garbage-
        # collected and the canvas goes blank.
        self._tk_img = ImageTk.PhotoImage(img)

        self.canvas.delete("all")
        x0 = (canvas_w - new_w) // 2
        y0 = (canvas_h - new_h) // 2
        self.canvas.create_image(x0, y0, anchor="nw", image=self._tk_img)

    def _on_canvas_resize(self, _event=None):
        """<Configure> handler: re-fit the current frame to the new canvas size."""
        self._draw_frame()

    def _update_info_label(self):
        """Refresh the 'Frame / Time / Timestamp' readout for the current frame."""
        if self.cap is None or self.frame_count <= 0:
            self.info_var.set("Frame: - / - | Time: - s | Timestamp: -")
            return

        t_s = self._frame_to_seconds(self.cur_frame_idx)
        ts = self._frame_to_datetime_str(self.cur_frame_idx) or "-"

        self.info_var.set(
            f"Frame: {self.cur_frame_idx} / {self.frame_count - 1} | "
            f"Time: {t_s:.3f} s | Timestamp: {ts}"
        )

    def _frame_to_seconds(self, frame_idx: int) -> float:
        """Convert a frame index to seconds using self.fps.

        When the FPS is unknown (<= 0), this returns the frame index itself
        as a float, so the value is in frames, not seconds.
        """
        if self.fps and self.fps > 0:
            return frame_idx / self.fps
        return float(frame_idx)

    def _frame_to_datetime_str(self, frame_idx: int) -> str:
        """Return start_dt + frame time as 'YYYY-MM-DD HH:MM:SS[.ffffff]'.

        Returns "" when no start date/time is set or the FPS is unknown.
        """
        if self.start_dt is None or not (self.fps and self.fps > 0):
            return ""
        dt = self.start_dt + timedelta(seconds=self._frame_to_seconds(frame_idx))
        return dt.isoformat(sep=" ")

    # --------------------------------------------------------- annotating

    def on_behavior_click(self, behavior: str):
        """Handle a click on a behaviour button.

        In event mode, this records a single-frame event at the current frame.
        In bout mode:

        - With no open bout, it starts one at the current frame.
        - Clicking the open bout's own behaviour ends it at the current frame.
        - Clicking a different behaviour ends the open bout at the current
          frame and starts a new bout there.
        """
        if self.cap is None:
            messagebox.showwarning("No video", "Load a video first.")
            return

        if self.mode_var.get() == "event":
            self._add_event(behavior)
            return

        if self.active_bout is None:
            self.active_bout = {"behavior": behavior, "start_frame": self.cur_frame_idx}
        elif self.active_bout["behavior"] == behavior:
            self._end_active_bout(end_frame=self.cur_frame_idx)
        else:
            # Switching behaviour: the old bout ends where the new one starts.
            self._end_active_bout(end_frame=max(self.active_bout["start_frame"], self.cur_frame_idx))
            self.active_bout = {"behavior": behavior, "start_frame": self.cur_frame_idx}

        self._refresh_active_bout_label()

    def _make_annotation(self, behavior, kind, start_frame, end_frame):
        """Build an Annotation for the current video, deriving times from the frame indices."""
        return Annotation(
            video_path=self.video_path,
            behavior=behavior,
            kind=kind,
            start_frame=start_frame,
            end_frame=end_frame,
            start_time_s=self._frame_to_seconds(start_frame),
            end_time_s=self._frame_to_seconds(end_frame),
            start_datetime=self._frame_to_datetime_str(start_frame),
            end_datetime=self._frame_to_datetime_str(end_frame),
        )

    def _record_annotation(self, ann):
        """Append *ann* to the annotation list and show it in the table."""
        self.annotations.append(ann)
        self._tree_add(ann)

    def _add_event(self, behavior: str):
        """Record a single-frame event for *behavior* at the current frame."""
        frame = self.cur_frame_idx
        self._record_annotation(self._make_annotation(behavior, "event", frame, frame))

    def _end_active_bout(self, end_frame: int):
        """Close the open bout at *end_frame* (never before its start) and record it."""
        if self.active_bout is None:
            return

        sf = int(self.active_bout["start_frame"])
        ef = max(sf, int(end_frame))  # a bout cannot end before it starts
        self._record_annotation(self._make_annotation(self.active_bout["behavior"], "bout", sf, ef))

        self.active_bout = None

    def _refresh_active_bout_label(self):
        """Show which bout is open (if any) under the ethogram."""
        if self.active_bout is None:
            self.active_bout_var.set("Active bout: none")
        else:
            b = self.active_bout["behavior"]
            sf = self.active_bout["start_frame"]
            self.active_bout_var.set(f"Active bout: {b} (started at frame {sf})")

    def _tree_add(self, ann: "Annotation"):
        """Append a table row for *ann*, remember which annotation it shows, and return its item id."""
        vals = (
            ann.behavior, ann.kind, ann.start_frame, ann.end_frame,
            f"{ann.start_time_s:.6f}", f"{ann.end_time_s:.6f}",
            ann.start_datetime, ann.end_datetime
        )
        item = self.tree.insert("", tk.END, values=vals)
        self._ann_by_item[item] = ann
        return item

    def undo_last(self):
        """Remove the most recently recorded annotation (the last table row)."""
        if not self.annotations:
            return
        self.annotations.pop()
        children = self.tree.get_children()
        if children:
            # Rows and self.annotations are kept in the same order.
            last = children[-1]
            self.tree.delete(last)
            self._ann_by_item.pop(last, None)

    def delete_selected(self):
        """Delete the selected table rows and their annotations."""
        sel = self.tree.selection()
        if not sel:
            return

        for item in sel:
            self.tree.delete(item)
            self._ann_by_item.pop(item, None)

        self._rebuild_annotations_from_tree()

    def clear_all(self):
        """After confirmation, delete every recorded annotation (an open bout is kept)."""
        if not self.annotations and not self.tree.get_children():
            return
        if messagebox.askyesno("Clear all", "Delete ALL saved annotations?"):
            self.annotations = []
            for item in self.tree.get_children():
                self.tree.delete(item)
            self._ann_by_item.clear()

    def _rebuild_annotations_from_tree(self):
        """Re-derive self.annotations from the remaining table rows, in table order.

        The stored Annotation objects are reused instead of re-parsing the
        displayed strings. Each annotation therefore keeps its own video_path,
        even after another video has been loaded, and its full-precision
        times; the table shows times rounded to 6 decimals.
        """
        self.annotations = [
            self._ann_by_item[item]
            for item in self.tree.get_children()
            if item in self._ann_by_item
        ]

    def save_csv(self):
        """Export all annotations to a CSV file chosen by the user.

        An open bout can first be closed at the current frame so that it is
        included, even when nothing else has been recorded yet.
        """
        if self.active_bout is not None:
            if messagebox.askyesno("Active bout", "You have an active bout. End it at the current frame before saving?"):
                self._end_active_bout(end_frame=self.cur_frame_idx)
                self._refresh_active_bout_label()

        if not self.annotations:
            messagebox.showwarning("Nothing to save", "No annotations yet.")
            return

        default_name = "annotations.csv"
        if self.video_path:
            base = os.path.splitext(os.path.basename(self.video_path))[0]
            default_name = f"{base}_annotations.csv"

        out_path = filedialog.asksaveasfilename(
            title="Save CSV",
            defaultextension=".csv",
            initialfile=default_name,
            filetypes=[("CSV", "*.csv")]
        )
        if not out_path:
            return

        rows = [asdict(a) for a in self.annotations]
        start_input = "" if self.start_dt is None else self.start_dt.isoformat(sep=" ")
        for r in rows:
            r["fps_used"] = self.fps
            r["frame_count"] = self.frame_count
            # Each annotation's own video: rows may come from earlier videos in the session.
            r["video_filename"] = os.path.basename(r["video_path"])
            r["start_datetime_input"] = start_input

        try:
            pd.DataFrame(rows).to_csv(out_path, index=False)
        except OSError as err:
            # e.g. the file is open (locked) in a spreadsheet program.
            messagebox.showerror("Save failed", f"Could not write:\n{out_path}\n\n{err}")
            return

        messagebox.showinfo("Saved", f"Saved:\n{out_path}")

    def open_hyperlink(self, url: str):
        """Open *url* in the system web browser."""
        import webbrowser
        webbrowser.open(url)


def main():
    """Open the Behaviour Annotator window and run the Tk event loop until it closes."""
    window = tk.Tk()
    # Keep a name bound to the annotator for the lifetime of the event loop.
    _annotator = BehaviourAnnotatorGUI(window)
    window.mainloop()


if __name__ == "__main__":
    main()
