In [1]:
import json
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from matplotlib.text import Text
from matplotlib.patches import Rectangle
from matplotlib.widgets import Slider
import numpy as np
from pathlib import Path
from collections import defaultdict
from typing import Literal, Optional
from PIL import Image
%matplotlib qt6

In [2]:
class VideoVisualizer:
    """Interactive matplotlib viewer for a single video of a COCO-style tracking dataset.

    Frames are loaded from ``splits_dir / split / video_dir / 'img1'``. The
    annotations of the matching COCO images are drawn as coloured, labelled boxes:
    one ``Rectangle`` + ``Text`` pair per track id, reused across frames.

    Controls (both Latin and Russian keyboard layouts are accepted):
        a / ф, d / в   previous / next frame
        t / е          show / hide box labels
        b / и          show / hide all boxes (labels disappear together with the boxes)
        ctrl + click   select the track under the cursor; clicking the selected
                       track again, or empty space, clears the selection
    """

    def __init__(
            self,
            datasets: dict,
            splits_dir: Path,
            split: Literal['train', 'valid', 'test'],
            video_dir: str
        ) -> None:
        """Index the frames and annotations of ``video_dir``.

        Args:
            datasets: mapping ``split -> COCO dict``. Each COCO dict is read for its
                ``'images'``, ``'annotations'`` and ``'categories'`` lists.
            splits_dir: directory that contains one sub-directory per split.
            split: key of ``datasets`` to visualise.
            video_dir: name of the directory directly above ``img1`` in the COCO
                ``file_name`` of each frame.

        Raises:
            ValueError: if no image of ``split`` belongs to ``video_dir``.
        """
        self.datasets: dict[str, dict[str, list[dict]]] = datasets
        self.splits_dir = Path(splits_dir)
        self.split = split
        self.video_dir = video_dir

        self.img_ids: list[int] = []      # COCO image ids of the video, in slider order
        self.img_names: list[str] = []    # frame file names inside imgs_path, parallel to img_ids
        self.imgs_path: Path = self.splits_dir / self.split / self.video_dir / 'img1'

        self.img_id: Optional[int] = None  # COCO id of the frame currently shown

        self.anns_by_frame: defaultdict[int, list[dict]] = defaultdict(list)
        self.category_names: Optional[dict[int, str]] = None

        # Frame size; filled in from the first image of this video by _collect_images().
        self.IMG_WIDTH: Optional[int] = None
        self.IMG_HEIGHT: Optional[int] = None

        self.bboxes: Optional[dict[int, tuple[Rectangle, Text]]] = None
        self.bbox_colors = {1: '#3498db', 2: '#2ecc71', 3: '#e74c3c'}  # category_id -> colour
        self.bboxes_visible = True
        self.bbox_labels_visible = True
        self.visible_track_ids: set[int] = set()
        self.selected_track_id: Optional[int] = None
        self.default_linewidth = 2.0

        self.ax: Axes = None
        self.frame_slider: Slider = None
        self.img_show: AxesImage = None

        self._collect_images()
        if not self.img_ids:
            # Fail early with a clear message instead of an IndexError on the first frame.
            raise ValueError(
                f"no frames of video '{self.video_dir}' found in split '{self.split}'")
        self._collect_annotations()


    def _collect_images(self) -> None:
        """Collect ids, file names and frame size of the images belonging to ``video_dir``.

        COCO ``file_name`` values are split on ``'/img1/'``. Frames keep the order in
        which they appear in the COCO ``images`` list. This assumes that list is
        already ordered by frame (TODO confirm for new converters).
        """
        for image in self.datasets[self.split]['images']:
            dir_name, file_name = image['file_name'].split('/img1/')
            # video_dir is compared with the name of the directory directly above img1.
            # Matching on the full prefix was awkward when working with different datasets.
            if dir_name.split('/')[-1] == self.video_dir:
                if not self.img_ids:
                    # Use this video's own resolution, not the split's first image.
                    self.IMG_WIDTH = image['width']
                    self.IMG_HEIGHT = image['height']
                self.img_ids.append(image['id'])
                self.img_names.append(file_name)


    def _collect_annotations(self) -> None:
        """Group the split's annotations by image id (this video only) and map category ids to names."""
        img_ids_set = set(self.img_ids)
        for ann in self.datasets[self.split]['annotations']:
            if ann['image_id'] in img_ids_set:
                self.anns_by_frame[ann['image_id']].append(ann)
        self.category_names = {category['id']: category['name'] for category in self.datasets[self.split]['categories']}


    def _init_bboxes(self) -> None:
        """Create one Rectangle + Text pair per track id of this video.

        Artists start at NaN coordinates, so nothing is drawn. _update_bboxes()
        repositions them for each frame. Colour is looked up in ``bbox_colors`` by
        category id, falling back to yellow. The label uses the category of the
        first annotation seen for the track.
        """
        for anns in self.anns_by_frame.values():
            for ann in anns:
                track_id = ann['track_id']
                if track_id in self.bboxes:
                    continue
                category_id = ann['category_id']
                color = self.bbox_colors.get(category_id, 'yellow')

                rect = Rectangle(
                    (np.nan, np.nan), np.nan, np.nan,
                    linewidth=self.default_linewidth, edgecolor=color, facecolor='none')
                self.ax.add_patch(rect)

                label_text = f'{self.category_names[category_id]} (ID:{track_id})'
                label = self.ax.text(
                    np.nan, np.nan, label_text, color='white', fontsize=9,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor=color, edgecolor='white'))

                self.bboxes[track_id] = (rect, label)


    def _update_bboxes(self) -> None:
        """Reposition and restyle every track's box and label for the current frame.

        With a track selected, it gets a thicker, opaque box raised above the
        others, which are faded. Boxes of tracks absent from the current frame, all
        boxes while boxes are hidden, and all labels while labels are hidden are
        moved to NaN coordinates so they are not drawn.
        """
        self.visible_track_ids.clear()

        if self.bboxes_visible:
            for ann in self.anns_by_frame[self.img_id]:
                track_id = ann['track_id']
                # `is None` so that a track with id 0 can also be selected.
                if self.selected_track_id is None:
                    linewidth = self.default_linewidth
                    alpha = 0.75
                    zorder = 0
                else:
                    is_selected = track_id == self.selected_track_id
                    linewidth = 1.5 * self.default_linewidth if is_selected else self.default_linewidth
                    alpha = 1.0 if is_selected else 0.5
                    zorder = 1 if is_selected else 0
                rect, label = self.bboxes[track_id]
                x, y, w, h = ann['bbox']
                rect.set_bounds(x, y, w, h)
                rect.set_alpha(alpha)
                rect.set_linewidth(linewidth)
                rect.set_zorder(zorder)
                if self.bbox_labels_visible:
                    label.set_position((x, y))
                    label.set_alpha(alpha)
                    label.get_bbox_patch().set_alpha(alpha)
                    label.set_zorder(zorder)
                self.visible_track_ids.add(track_id)

        for track_id, (rect, label) in self.bboxes.items():
            track_not_visible = track_id not in self.visible_track_ids
            if track_not_visible:
                rect.set_bounds(np.nan, np.nan, np.nan, np.nan)
            if not self.bbox_labels_visible or track_not_visible:
                label.set_position((np.nan, np.nan))


    def _show_bboxes(self, event) -> None:
        """Key handler: ``b`` / ``и`` shows or hides all boxes.

        Only ``bboxes_visible`` is toggled. While boxes are hidden no track counts
        as visible, so its labels are hidden as well, and the label setting chosen
        with ``t`` is kept. Toggling both flags here would desynchronise them once
        ``t`` had been pressed.
        """
        if event.key in ('b', 'и'):
            self.bboxes_visible = not self.bboxes_visible
            self._update_bboxes()
            self.ax.figure.canvas.draw_idle()


    def _show_bbox_labels(self, event) -> None:
        """Key handler: ``t`` / ``е`` shows or hides the box labels."""
        if event.key in ('t', 'е'):
            self.bbox_labels_visible = not self.bbox_labels_visible
            self._update_bboxes()
            self.ax.figure.canvas.draw_idle()


    def _select_track(self, event) -> None:
        """Mouse handler: ctrl + click toggles selection of the track under the cursor.

        Clicks outside the image axes (e.g. on the slider) are ignored. Their
        ``xdata``/``ydata`` are either ``None`` or in another axes' coordinates.
        """
        if event.key != 'control' or event.inaxes is not self.ax:
            return
        x, y = event.xdata, event.ydata
        new_selected_track_id = None
        for ann in self.anns_by_frame[self.img_id]:
            x0, y0, w, h = ann['bbox']
            if 0 <= x - x0 <= w and 0 <= y - y0 <= h:
                new_selected_track_id = ann['track_id']
                break
        # Clicking the selected track again deselects it; clicking empty space gives None.
        if self.selected_track_id == new_selected_track_id:
            self.selected_track_id = None
        else:
            self.selected_track_id = new_selected_track_id
        self._update_bboxes()
        self.ax.figure.canvas.draw_idle()


    def _init_frame(self, fig: Figure) -> None:
        """Create the image axes, a 1-based frame slider below it, and an empty image artist."""
        self.ax = fig.add_subplot(1, 1, 1)
        self.ax.set_xticks(np.linspace(0, self.IMG_WIDTH, 5))
        self.ax.set_yticks(np.linspace(self.IMG_HEIGHT, 0, 5))

        ax_pos = self.ax.get_position()
        slider_width = 0.6
        slider_height = 0.02
        left = ax_pos.x0 + (ax_pos.width - slider_width) / 2.0   # centre the slider under the image
        bottom = ax_pos.y0 - 0.08
        ax_slider = fig.add_axes([left, bottom, slider_width, slider_height])
        self.frame_slider = Slider(ax=ax_slider, label='frame', valmin=1, valmax=len(self.img_ids), valstep=1, valinit=1, color='#405bc9')

        # Placeholder image; the extent puts the origin at the top-left, matching COCO boxes.
        self.img_show = self.ax.imshow([[np.nan]], extent=(0, self.IMG_WIDTH, self.IMG_HEIGHT, 0))


    def _frames_switch(self, event) -> None:
        """Key handler: ``d`` / ``в`` goes to the next frame and ``a`` / ``ф`` to the previous one."""
        if event.key in ('d', 'в') and self.frame_slider.val < self.frame_slider.valmax:
            self.frame_slider.set_val(self.frame_slider.val + self.frame_slider.valstep)
        elif event.key in ('a', 'ф') and self.frame_slider.val > self.frame_slider.valmin:
            self.frame_slider.set_val(self.frame_slider.val - self.frame_slider.valstep)


    def _update_frame(self, val) -> None:
        """Slider callback: display the frame selected by the slider and refresh its boxes.

        ``val`` is ignored; the slider's current value is read directly, so
        ``None`` may be passed for the initial draw.
        """
        # The slider is 1-based; int() guarantees a valid list index.
        frame_idx = int(self.frame_slider.val) - 1
        self.img_id = self.img_ids[frame_idx]

        img_path = self.imgs_path / self.img_names[frame_idx]
        # Decode inside a context manager so the file handle is released right away.
        with Image.open(img_path) as img:
            self.img_show.set_array(np.asarray(img))

        self._update_bboxes()
        self.ax.figure.canvas.draw_idle()


    def run(self, fig: Figure) -> None:
        """Build the viewer inside ``fig``, connect the input handlers and show the first frame."""
        self._init_frame(fig)
        self.bboxes = {}
        self._init_bboxes()

        fig.canvas.mpl_connect('key_press_event', self._frames_switch)
        fig.canvas.mpl_connect('key_press_event', self._show_bboxes)
        fig.canvas.mpl_connect('key_press_event', self._show_bbox_labels)
        fig.canvas.mpl_connect('button_press_event', self._select_track)

        self.frame_slider.on_changed(self._update_frame)
        self._update_frame(None)

        self.ax.set_title(f'{self.imgs_path.as_posix()} | {len(self.img_ids)} frames | {len(self.bboxes)} objects')

- `a` и `d` — перемотка на предыдущий и следующий кадр соответственно

- `ctrl + leftclick` — кликнуть на бокс для выделения / снятия выделения (если кликнуть на уже выделенный бокс или просто в пустоту, выделение снимется)

- `t` — показать/скрыть тексты лейблов к боксам

- `b` — показать/скрыть все боксы (тексты лейблов тоже при этом скрываются)

In [4]:
coco_dir = Path('../converted_labels/coco')

# Load the COCO annotation file of every split.
# JSON is UTF-8 by spec, so the encoding is passed explicitly instead of relying on
# the platform default (e.g. cp1251 on a Russian-locale Windows machine).
# The loop variable is `split_name`, so the global `split` set in the config cell is not clobbered.
datasets: dict[str, dict[str, list[dict]]] = {}
for split_name in ['train', 'valid', 'test']:
    with open(coco_dir / f'{split_name}.json', encoding='utf-8') as f:
        datasets[split_name] = json.load(f)
    print(f"{split_name}: {len(datasets[split_name]['images'])} images, {len(datasets[split_name]['annotations'])} annotations")

train: 42750 images, 690624 annotations
valid: 43500 images, 707367 annotations
test: 36750 images, 529105 annotations


In [5]:
# Frames are read from <splits_dir>/<split>/<video_dir>/img1.
splits_dir = Path('../data')
split = 'train'          # one of 'train', 'valid', 'test'
video_dir = 'SNGS-060'   # name of the video directory (the parent of img1)

In [6]:
# Index the configured video, then attach the interactive viewer to a fresh 16:9 figure.
visualizer = VideoVisualizer(
    datasets=datasets,
    splits_dir=splits_dir,
    split=split,
    video_dir=video_dir,
)

fig = plt.figure(figsize=(16, 9))
visualizer.run(fig=fig)