In [1]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import widgets, interact
from IPython.display import display
import csidata

In [2]:
# Sample data: load one CSI recording and cut a short trace out of it for labelling.
CSI_SAMPLE_PATH = 'DeepSeg/01Data_PreProcess/Data_CsiAmplitudeCut/philipp/55philipp_uc_d.mat'
TRACE_LENGTH = 100  # number of leading samples of the trace that are kept

data = csidata.load(CSI_SAMPLE_PATH, csidata.FileFormat.MAT_7_3)
csi = data.csi

# First entry along the two leading axes of `csi`, truncated to TRACE_LENGTH samples.
csiTrace = csi[0, 0, :TRACE_LENGTH]
csiTrace.shape

(100,)

In [3]:
import subprocess
import threading
import io

# Worker thread created by display_video_frame to fetch a frame; None until the first request.
img_thread: threading.Thread = None
# ffmpeg child process assigned by get_video_frame; None until a worker has spawned one.
img_subprocess: subprocess.Popen = None

def display_video_frame(filename, frame_idx):
    """Request that frame ``frame_idx`` of ``filename`` be fetched on a worker thread.

    The fetch itself is done by ``get_video_frame`` on a background thread.
    Requests for the frame that was requested last (``prev_frame_idx``) are
    ignored. If a worker for a different frame is still running, its ffmpeg
    process is killed and the worker is joined before a new one is started.

    Args:
        filename: Path handed through to ``get_video_frame``.
        frame_idx: Index of the frame to fetch.
    """
    global img_thread, prev_frame_idx
    print(f'Requesting frame {frame_idx}')

    worker_busy = img_thread is not None and img_thread.is_alive()

    if prev_frame_idx == frame_idx:
        if worker_busy:
            print('Same frame being requested, wait for existing thread to finish')
        else:
            print('Frame already displayed')
        return

    if worker_busy:
        print('Stopping previous subprocess')
        # The worker may be alive without having spawned ffmpeg yet, in which
        # case the global is still None and there is nothing to kill.
        running_process = img_subprocess
        if running_process is not None:
            running_process.kill()
        img_thread.join()

    print('Creating new thread')
    img_thread = threading.Thread(target=get_video_frame, args=(filename, frame_idx))
    img_thread.start()
    prev_frame_idx = frame_idx


def get_video_frame(filename, frame_idx):
    """Extract one frame from a raw video file with ffmpeg and put it into ``image``.

    The ffmpeg process is stored in the global ``img_subprocess``. On a
    non-zero exit status the captured stderr is printed and ``image`` is
    left untouched.

    Args:
        filename: Path of the raw video; read as 1920x1080 yuv422p.
        frame_idx: Index of the frame selected by the ffmpeg ``select`` filter.
    """
    global image, img_subprocess

    print(f'FFMPEG get frame {frame_idx}')
    ffmpeg_command = [
        'ffmpeg', '-v', 'error',
        # input side: the file has no header, so format, layout and size are spelled out
        '-f', 'rawvideo', '-pix_fmt', 'yuv422p', '-s', '1920x1080', '-i', f'{filename}',
        # output side: keep only the requested frame and write it as PNG to stdout
        '-vf', fr'select=eq(n\,{frame_idx})', '-vframes', '1', '-f', 'image2', '-c:v', 'png', 'pipe:',
    ]
    img_subprocess = subprocess.Popen(
        ffmpeg_command,
        stdout=subprocess.PIPE,
        stdin=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    print('FFMPEG process started')
    png_bytes, stderr_bytes = img_subprocess.communicate()

    if img_subprocess.returncode == 0:
        print(f'Frame {frame_idx} received')
        image.value = png_bytes
        print(f'Set image for frame {frame_idx}')
    else:
        print(f'Error getting frame {frame_idx}')
        print(stderr_bytes)


In [4]:
def getNearestIdx(array, value):
    """Return the index of the element of ``array`` that lies closest to ``value``.

    When several elements are equally close, the first one wins.
    """
    distances = np.abs(array - value)
    return np.argmin(distances)

In [5]:
class Notifier:
    """Minimal observer registry: callbacks are registered once and invoked together."""

    def __init__(self):
        # Registered zero-argument callables, in registration order.
        self.func = []

    def register(self, func):
        """Add ``func`` (called without arguments) to the list of listeners."""
        self.func.append(func)

    def notify(self):
        """Call every registered listener in registration order."""
        for listener in self.func:
            listener()

    def notfy(self):
        """Misspelled historical name of ``notify``; kept so existing callers keep working."""
        self.notify()

In [6]:
import cv2

def display_frame_mp4(cap, frame_idx):
    """Seek ``cap`` to ``frame_idx`` and return that frame encoded as PNG bytes.

    Args:
        cap: Capture object offering ``set`` and ``read``
            (presumably a ``cv2.VideoCapture`` - confirm against the caller).
        frame_idx: Frame position to seek to before reading.

    Raises:
        ValueError: If ``cap.read()`` reports that no frame was read.
    """
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    frame_ok, frame = cap.read()
    if frame_ok:
        encoded = cv2.imencode('.png', frame)
        # imencode returns a pair; the second element holds the encoded buffer
        return encoded[1].tobytes()
    raise ValueError(f'Frame {frame_idx} not found')

In [7]:
class SegmentListEntry:
    """One row of the segment list: id label, bounds label and move/remove buttons."""

    def __init__(self, segment_list_manager, segment_bounds: tuple[int, int], segment_label: str, segment_id: int):
        """Build the row widgets and wire the buttons to the owning manager.

        Args:
            segment_list_manager: Object providing ``move_down``, ``move_up`` and
                ``removeSegment``; each is called with this entry when its button is clicked.
            segment_bounds: The bounds pair this row stands for.
            segment_label: Text shown for the bounds.
            segment_id: Number shown in front of the row.
        """
        self.segment_list_manager = segment_list_manager
        self.segment_bounds = segment_bounds
        self.segment_id = segment_id
        self.segment_type = None

        self.item = widgets.HBox()
        self.id_label = widgets.Label(str(segment_id))
        self.label = widgets.Label(segment_label)

        # purely visual gap between the id and the bounds label
        spacer = widgets.HTML("<span style='margin-right: 1em;'></span>")

        def make_button(caption, on_press):
            button = widgets.Button(description=caption, layout=widgets.Layout(width='auto'))
            button.on_click(on_press)
            return button

        # The manager is looked up at click time, not captured at construction time.
        move_down_button = make_button('⬇️', lambda _: self.segment_list_manager.move_down(self))
        move_up_button = make_button('⬆️', lambda _: self.segment_list_manager.move_up(self))
        remove_button = make_button('X', lambda _: self.segment_list_manager.removeSegment(self))

        self.item.children = [
            self.id_label,
            spacer,
            self.label,
            move_down_button,
            move_up_button,
            remove_button,
        ]

    def update_id(self, new_id: int):
        """Store ``new_id`` and show it in the id label."""
        self.segment_id = new_id
        self.id_label.value = str(new_id)


class SegmentListManager:
    """Keeps a column of ``SegmentListEntry`` rows in step with the shared segment list.

    ``segments`` is the same list object other components append to. Whenever
    the notifier fires, ``update`` reconciles the rows with that list: rows are
    reused where the bounds match, created for new segments, closed for
    vanished ones, and renumbered so that row ``i`` describes ``segments[i]``.
    """

    def __init__(self, segments: list[tuple[int, int]], segment_notifier: "Notifier", digits_padding: int) -> None:
        """
        Args:
            segments: Shared, mutable list of ``(start, end)`` bounds.
            segment_notifier: Notifier used to learn about and announce changes.
            digits_padding: Width the bounds are zero-padded to in the row label.
        """
        self.segment_notifier = segment_notifier
        self.segment_notifier.register(self.update)
        self.segment_list_items: list["SegmentListEntry"] = []
        self.segments = segments
        self.digits_padding = digits_padding
        self.vbox = widgets.VBox(layout=widgets.Layout(width='250px'))

    @staticmethod
    def _same_bounds(entry, segment_bounds) -> bool:
        """True when ``entry`` stands for exactly ``segment_bounds``."""
        return entry.segment_bounds[0] == segment_bounds[0] and entry.segment_bounds[1] == segment_bounds[1]

    def update(self) -> None:
        """Rebuild the row list so that it mirrors ``self.segments`` position by position."""
        reusable = list(self.segment_list_items)
        synced = []

        for idx, bounds in enumerate(self.segments):
            # Each existing row is consumed at most once, so segments with
            # identical bounds each end up with a row of their own.
            match_pos = next((pos for pos, row in enumerate(reusable) if self._same_bounds(row, bounds)), None)
            if match_pos is None:
                row = self.createSegmentEntry(bounds, idx)
            else:
                row = reusable.pop(match_pos)
                row.update_id(idx)  # keep the shown number equal to the list position
            synced.append(row)

        # Rows left over no longer correspond to any segment.
        for stale in reusable:
            stale.item.close()

        self.segment_list_items[:] = synced
        self.vbox.children = [row.item for row in self.segment_list_items]

    def createSegmentEntry(self, segment_bounds: tuple[int, int], segment_id, segment_type: str = None) -> "SegmentListEntry":
        """Create the row for ``segment_bounds``; ``segment_type`` is accepted but unused."""
        label = f"({segment_bounds[0]:>0{self.digits_padding}}, {segment_bounds[1]:>0{self.digits_padding}})"
        return SegmentListEntry(self, segment_bounds, label, segment_id)

    def removeSegment(self, item) -> None:
        """Delete the segment shown by row ``item`` and announce the change."""
        idx = self.segment_list_items.index(item)
        self.segments.pop(idx)
        # Drop exactly the clicked row; with duplicate bounds a bounds-based
        # lookup could pick a different, equal-looking row.
        self.segment_list_items.pop(idx).item.close()
        self.segment_notifier.notfy()

    def _swap(self, first: int, second: int) -> None:
        """Exchange two positions in both lists, renumber the rows and notify."""
        segs, rows = self.segments, self.segment_list_items
        segs[first], segs[second] = segs[second], segs[first]
        rows[first], rows[second] = rows[second], rows[first]
        rows[first].update_id(first)
        rows[second].update_id(second)
        self.segment_notifier.notfy()

    def move_up(self, item) -> None:
        """Move the segment of row ``item`` one position towards the start, if possible."""
        idx = self.segment_list_items.index(item)
        if idx > 0:
            self._swap(idx, idx - 1)

    def move_down(self, item) -> None:
        """Move the segment of row ``item`` one position towards the end, if possible."""
        idx = self.segment_list_items.index(item)
        if idx < len(self.segments) - 1:
            self._swap(idx, idx + 1)

    def getListBox(self):
        """Return the container widget that holds the rows."""
        return self.vbox

In [8]:
class CSIPlot:
    """Interactive matplotlib view of one CSI trace with a cursor and segment overlays.

    The plot draws the trace, a red cursor line, a green line plus orange span
    for a selection in progress, and a yellow span with an index label for
    every finished segment. Finished segments are appended to the shared
    ``segments`` list and announced through the notifier.
    """

    def __init__(self, csiTrace: np.ndarray, segments: list[tuple[int, int]], segment_notifier: "Notifier") -> None:
        """
        Args:
            csiTrace: 1-D array of samples to plot (it is indexed and sliced along one axis).
            segments: Shared, mutable list that finished ``(start, end)`` pairs are appended to.
            segment_notifier: Notifier; ``redraw_segments`` is registered on it.
        """
        self.notifier = segment_notifier
        self.notifier.register(self.redraw_segments)
        self.csiTrace = csiTrace
        self.output = widgets.Output()
        # Start index of the selection in progress; None means "no selection".
        # Index 0 is a valid start, so this must always be compared against None.
        self.segment_selection_start = None
        self.cursorLine = None
        self.segmentStartLine = None

        self.segments = segments

        self.segmentSelectionSpan = None
        self.segmentSpans: list = []    # artists returned by ax.axvspan, one per segment
        self.segment_labels: list = []  # artists returned by ax.text, one per segment

        # Fraction of the visible value range added above and below when rescaling y.
        self.margin = 0.1

        with self.output:
            with plt.ioff():
                self.fig, self.ax = plt.subplots()
            self.fig: plt.Figure
            self.ax: plt.Axes

            # Canvas options; presumably provided by the ipympl canvas - confirm.
            self.fig.canvas.header_visible = False
            self.fig.canvas.toolbar_visible = False
            self.fig.canvas.footer_visible = False
            self.fig.canvas.capture_scroll = True

            self.ax.plot(csiTrace)

            self.cursorLine = self.ax.axvline(0, color='r', alpha=0.5)

            self.segmentStartLine = self.ax.axvline(0, color='g', alpha=0.5)
            self.segmentStartLine.set_visible(False)

            self.segmentSelectionSpan = self.ax.axvspan(0, 0, color='orange', alpha=0.2)
            self.segmentSelectionSpan.set_visible(False)

        self.fig.canvas.mpl_connect('scroll_event', self.handle_scroll)

    def start_selection(self, xpos):
        """Begin a selection at ``xpos`` and show the start line and a zero-width span."""
        self.segment_selection_start = xpos
        self.segmentStartLine.set_xdata([xpos])
        self.segmentStartLine.set_visible(True)
        # Relies on the span artist exposing set_xy/set_width.
        self.segmentSelectionSpan.set_xy([xpos, 0])
        self.segmentSelectionSpan.set_width(0)
        self.segmentSelectionSpan.set_visible(True)
        with self.output:
            self.fig.canvas.draw_idle()

    def end_selection(self, xpos):
        """Finish the selection in progress at ``xpos``; does nothing when none is active."""
        if self.segment_selection_start is None:
            return

        self.segments.append((self.segment_selection_start, xpos))
        self.segment_selection_start = None
        self.segmentStartLine.set_visible(False)
        self.segmentSelectionSpan.set_visible(False)
        with self.output:
            self.redraw_segments()
            self.fig.canvas.draw_idle()
        self.notifier.notfy()

    def abort_selection(self):
        """Discard the selection in progress and hide its markers."""
        self.segment_selection_start = None
        self.segmentStartLine.set_visible(False)
        self.segmentSelectionSpan.set_visible(False)
        with self.output:
            self.fig.canvas.draw_idle()

    def handle_plot_click(self, xpos):
        """Toggle selection state at ``xpos``; positions outside the trace are ignored."""
        if xpos < 0 or xpos >= len(self.csiTrace):
            return

        if self.segment_selection_start is not None:
            self.end_selection(xpos)
        else:
            self.start_selection(xpos)

    def draw_cursor(self, xpos):
        """Move the cursor line to ``xpos``."""
        self.cursorLine.set_xdata([xpos])
        with self.output:
            self.fig.canvas.draw_idle()

    def draw_current_segment_selection_span(self, cursorPosition):
        """Stretch the selection span from the selection start to ``cursorPosition``."""
        if self.segment_selection_start is not None:
            self.segmentSelectionSpan.set_xy([self.segment_selection_start, 0])
            self.segmentSelectionSpan.set_width(cursorPosition - self.segment_selection_start)
            self.fig.canvas.draw_idle()

    def redraw_segments(self):
        """Replace all segment spans and labels with ones built from ``self.segments``."""
        for span in self.segmentSpans:
            span.remove()

        self.segmentSpans.clear()

        for segment in self.segments:
            self.segmentSpans.append(self.ax.axvspan(segment[0], segment[1], color='yellow', alpha=0.2))

        self.redraw_segment_labels()
        with self.output:
            self.fig.canvas.draw_idle()

    def redraw_segment_labels(self):
        """Recreate the index labels, centred on each segment at the top of the y-range."""
        for label in self.segment_labels:
            label.remove()

        self.segment_labels.clear()
        visible_yregion = self.ax.get_ylim()

        for idx, segment in enumerate(self.segments):
            self.segment_labels.append(self.ax.text((segment[0] + segment[1])/2, visible_yregion[1], str(idx), horizontalalignment='center', verticalalignment='top'))

    def move_cursor(self, xpos):
        """Move the cursor to ``int(xpos)`` and request the matching video frame.

        Reads the notebook-level ``tmp_file`` and calls ``display_video_frame``.
        """
        xpos = int(xpos)
        display_video_frame(tmp_file, xpos)
        self.draw_cursor(xpos)
        self.draw_current_segment_selection_span(xpos)

    def handle_scroll(self, event):
        """Zoom the x-axis around the mouse position and refit the y-axis to the visible data."""
        if event.inaxes != self.ax:  # ignore scroll event outside axis
            return
        zoom_factor = 0.1

        if event.button == 'up':      # zoom in
            scale = 1 - zoom_factor
        elif event.button == 'down':  # zoom out
            scale = 1 + zoom_factor
        else:                         # nothing to zoom for any other value
            return

        cur_xlim = self.ax.get_xlim()
        xdata = event.xdata
        new_xlim = [xdata - (xdata - cur_xlim[0]) * scale, xdata + (cur_xlim[1] - xdata) * scale]
        self.ax.set_xlim(new_xlim)

        # Limit the index window to the data range before looking up min/max.
        first_idx = max(round(new_xlim[0]), 0)
        last_idx = min(round(new_xlim[1]), len(self.csiTrace))
        visible = self.csiTrace[first_idx:last_idx]

        # The window is empty when the view lies between two samples or outside
        # the trace; the current y-limits are kept in that case.
        if len(visible) > 0:
            low, high = visible.min(), visible.max()
            # Pad by a fraction of the visible range so the limits always
            # enclose the data, whatever its sign.
            pad = (high - low) * self.margin
            if pad == 0:
                pad = abs(high) * self.margin or self.margin
            self.ax.set_ylim([low - pad, high + pad])

        self.redraw_segment_labels()
        self.fig.canvas.draw_idle()  # Redraw the figure to update the plot



In [9]:
%matplotlib widget
import os
import subprocess


# --- Video source -------------------------------------------------------------
# Everything the frame-fetching callbacks read (tmp_file, prev_frame_idx) is set
# up first, so it exists before any widget callback can run.
filename = "record/recordings/2024-09-04T13-11-58+02-00.mp4"
filename_stem = os.path.splitext(os.path.basename(filename))[0]
tmp_folder = "record/recordings/tmp"
os.makedirs(tmp_folder, exist_ok=True)  # ffmpeg does not create a missing output folder

tmp_file = f"{tmp_folder}/{filename_stem}.yuv"

# Dump the recording as headerless raw video so single frames can be cut out later.
# The command is an argument list rather than a split string, so paths containing
# spaces survive. The pixel format is pinned to the one get_video_frame passes to
# ffmpeg when reading the file back.
# NOTE(review): get_video_frame also assumes 1920x1080 - confirm the recording size.
video_raw = subprocess.run(
    ["ffmpeg", "-y", "-v", "error", "-i", filename,
     "-f", "rawvideo", "-pix_fmt", "yuv422p", tmp_file],
    stdout=subprocess.PIPE, check=True)
prev_frame_idx = -1  # index of the frame requested last; -1 = none yet

# --- Shared state and components ----------------------------------------------
segments = []

segment_notifier = Notifier()
plot = CSIPlot(csiTrace, segments, segment_notifier)
segment_manager = SegmentListManager(segments, segment_notifier, len(str(csiTrace.shape[0])))

# The slider and its interactive wrapper are created but not displayed (see below).
index_slider = widgets.IntSlider(min=0, max=csiTrace.shape[0] - 1, step=10, value=0)
interactiveDings = widgets.interactive(plot.move_cursor, xpos=index_slider)


def current_cursor_index():
    """Sample index the red cursor line currently sits on."""
    return int(plot.cursorLine.get_xdata()[0])


# --- Controls -------------------------------------------------------------------
mark_button = widgets.Button(description='Mark')
unmark_button = widgets.Button(description='Unmark')
# Mark at the cursor position: the slider is not displayed, so its value never
# leaves 0 and cannot serve as the position to mark.
mark_button.on_click(lambda _: plot.handle_plot_click(current_cursor_index()))
unmark_button.on_click(lambda _: plot.abort_selection())

controls_hbox = widgets.HBox([mark_button, unmark_button])

save_button = widgets.Button(description='Save')
save_button.on_click(lambda event: print(segments))


def handle_segment_click(x):
    """Start or finish a segment at the sample index nearest to ``x``."""
    plot.handle_plot_click(round(x))


# Mouse events only count while the pointer is inside the plot axes.
plot.fig.canvas.mpl_connect('button_press_event', lambda event: handle_segment_click(event.xdata) if event.inaxes == plot.ax else None)
plot.fig.canvas.mpl_connect('motion_notify_event', lambda event: plot.move_cursor(round(event.xdata)) if event.inaxes == plot.ax else None)

# Image widget that get_video_frame writes the fetched PNG bytes into.
image = widgets.Image()
image.width = 500

# --- Layout and display -----------------------------------------------------------
hbox_main_content = widgets.HBox([segment_manager.getListBox(), plot.fig.canvas, image])
hbox_main_content.layout.width = '100%'
plot.fig.canvas.layout.width = 'auto'

display(hbox_main_content)

# display(interactiveDings)  # enable to drive the cursor with the slider as well
display(controls_hbox)
display(save_button)


HBox(children=(VBox(layout=Layout(width='250px')), Canvas(capture_scroll=True, footer_visible=False, header_vi…

HBox(children=(Button(description='Mark', style=ButtonStyle()), Button(description='Unmark', style=ButtonStyle…

Button(description='Save', style=ButtonStyle())

Requesting frame -1
Frame already displayed
Requesting frame 3
Creating new thread
FFMPEG get frame 3
FFMPEG process started
Requesting frame 5
Stopping previous subprocess
Creating new thread
Error getting frame 3
b''
FFMPEG get frame 5
FFMPEG process started
Requesting frame 7
Stopping previous subprocess
Creating new thread
Error getting frame 5
b''
FFMPEG get frame 7
FFMPEG process started
Requesting frame 9
Stopping previous subprocess
Creating new thread
Error getting frame 7
b''
FFMPEG get frame 9
FFMPEG process started
Requesting frame 10
Stopping previous subprocess
Creating new thread
Error getting frame 9
b''
FFMPEG get frame 10
FFMPEG process started
Requesting frame 13
Stopping previous subprocess
Creating new thread
Error getting frame 10
b''
FFMPEG get frame 13
FFMPEG process started
Requesting frame 13
Same frame being requested, wait for existing thread to finish
Frame 13 received
Set image for frame 13
Requesting frame 13
Frame already displayed
Requesting frame 13
Fr