In [None]:
import cv2
import json
import numpy as np
import re
import subprocess
import torch

from datetime import datetime
from multiprocessing import Process, JoinableQueue as Queue
from os import listdir, makedirs, path
from queue import Empty as QueueEmptyException
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Camera metadata; its keys name the per-camera subdirectories of VIDEO_PATH.
CAMERA_DB_PATH = "./metadata/cameras.json"

# Input videos (one subdirectory per camera), the aggregate video metadata DB,
# and the directory that receives one JSON file per processed video.
VIDEO_PATH = "../../vids/FULL-0801"
VIDEO_DB_PATH = "./metadata/videos.json"
VIDEO_OUT_PATH = "./metadata/videos"

# TrOCR checkpoint used to read the date/time text drawn on the video frames.
OCR_MODEL = 'microsoft/trocr-large-printed'
# Release cached, unused GPU memory (presumably left over from a previous run of this cell).
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = TrOCRProcessor.from_pretrained(OCR_MODEL)
model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL).to(device)

In [None]:
# Load the existing per-video metadata DB (updated in place by the processing loop)
# and the camera metadata (its keys are the camera directories that get processed).
with open(VIDEO_DB_PATH, "r") as video_db:
  video_data = json.load(video_db)

with open(CAMERA_DB_PATH, "r") as camera_db:
  camera_data = json.load(camera_db)

In [None]:
DATETIME_PATTERN = r'([0-9]{1,2})[-/]([0-9]{1,2})[-/]([0-9]{2,4}) ([0-9]{1,2}):?([0-9]{1,2}):?([0-9]{1,2})'
DATETIME_FORMAT = '%d%m%Y%H%M%S%z'
# Value parsed when the OCR text cannot be turned into a valid date: 08/01/2023 00:00:00 -03:00.
# NOTE(review): with the day-first DATETIME_FORMAT this is 8 January 2023; confirm that is the
# intended default (the input folder is named "FULL-0801").
FALLBACK_DATETIME = "08012023000000-0300"

def string_to_epoch(datetime_string):
  """Parse an OCR'd "dd/mm/yyyy HH:MM:SS" string into a Unix timestamp (seconds).

  The first DATETIME_PATTERN match is used, with day first, then month and year; "-" and "/"
  are both accepted as separators and the time colons are optional. The time is interpreted
  at a fixed UTC offset of -03:00. Some common OCR misreads are repaired:
    * a two-digit year is expanded to 20yy, and the year 2028 is read as 2023 (3 misread as 8);
    * minutes/seconds of the form "8x" (always invalid) become "3x".

  Unparseable input is not an error: if nothing matches, the input is not a string, or the
  repaired fields do not form a valid date, FALLBACK_DATETIME is used instead.

  Returns:
    int: seconds since the Unix epoch.
  """
  try:
    matches = list(re.search(DATETIME_PATTERN, datetime_string).groups())
  except (AttributeError, TypeError):
    # AttributeError: no match (re.search returned None); TypeError: input is not a string.
    return int(datetime.strptime(FALLBACK_DATETIME, DATETIME_FORMAT).timestamp())

  # Zero-pad every field to two digits (a four-digit year keeps only its last two here).
  matches = [('00' + m)[-2:] for m in matches]
  matches[2] = ('20' + matches[2])[-4:]
  if matches[2] == '2028':
    matches[2] = '2023'
  matches[4] = re.sub(r"8([0-9])", r"3\1", matches[4])
  matches[5] = re.sub(r"8([0-9])", r"3\1", matches[5])
  with_utc_offset = "".join(matches) + "-0300"

  try:
    dt = datetime.strptime(with_utc_offset, DATETIME_FORMAT)
  except ValueError:
    # Fields matched the pattern but are not a real date/time (e.g. day 31 in February).
    dt = datetime.strptime(FALLBACK_DATETIME, DATETIME_FORMAT)

  return int(dt.timestamp())

In [None]:
class Stamp:
  """A (video position, wall-clock time) pair.

  ``seconds`` is a position in the video and ``timestamp`` is the time read from the
  frame at that position (in this notebook, the integer returned by string_to_epoch).
  """

  def __init__(self, seconds, timestamp):
    self.seconds = seconds
    self.timestamp = timestamp

  def __str__(self):
    return str(self.stamp())

  def stamp(self):
    """Return the pair as a fresh ``[timestamp, seconds]`` list (the layout stored under "seek")."""
    pair = [self.timestamp, self.seconds]
    return pair

In [None]:
def get_frames(vid, frame, n=7, step=1):
  """Read ``n`` frames spaced ``step`` apart, centred on ``frame`` where possible.

  The window is shifted, never shrunk, so it stays inside the video. It never starts
  before frame 0 and, when the video is long enough, ends at or before the last frame.
  Each item is the image returned by ``vid.read()``, so a failed read yields ``None``.

  Args:
    vid: an opened cv2.VideoCapture.
    frame: centre frame index. A float is accepted; the start position is truncated to an int.
    n: number of frames to read.
    step: distance in frames between consecutive reads.

  Returns:
    A list of ``n`` frames.
  """
  frame_count = vid.get(cv2.CAP_PROP_FRAME_COUNT)
  # Clamp to the end first and to 0 last. For videos shorter than n*step the end bound
  # is negative, and clamping in the other order would seek to a negative position.
  start = min(frame - (n // 2) * step, frame_count - n * step)
  start = int(max(0, start))

  frames = []
  for i in range(n):
    vid.set(cv2.CAP_PROP_POS_FRAMES, start + i * step)
    _, img = vid.read()  # success flag ignored: img is None when the read fails
    frames.append(img)
  return frames

In [None]:
def get_max_count(txts):
  """Return the most frequent element of ``txts`` (a majority vote over OCR readings).

  Ties are broken in favour of the element that appears first in ``txts``.

  Raises:
    ValueError: if ``txts`` is empty.
  """
  counts = {}
  for txt in txts:
    counts[txt] = counts.get(txt, 0) + 1
  if not counts:
    raise ValueError("get_max_count() needs at least one candidate text")
  # max() returns the first maximal key in insertion order, i.e. the earliest-seen text
  # among equal counts. This is the same winner a stable descending sort would put first.
  return max(counts, key=counts.get)

In [None]:
def ocr(imgs, groups=1):
  """Run TrOCR on ``imgs`` and majority-vote the decoded text within each group.

  The decoded strings are reshaped to ``(groups, -1)``, which splits ``imgs`` into
  ``groups`` consecutive groups of equal size. ``len(imgs)`` must therefore be a
  multiple of ``groups``.

  Returns:
    One string per group: the most frequent decoding in that group (see get_max_count).
  """
  batch = processor(images=imgs, return_tensors="pt").pixel_values.to(device)
  decoded = processor.batch_decode(model.generate(batch), skip_special_tokens=True)
  per_group = np.array(decoded).reshape(groups, -1)
  return [get_max_count(candidates) for candidates in per_group]

In [None]:
def get_stamps(vid, keyframes):
  """OCR the on-screen date/time around each keyframe and return one Stamp per keyframe.

  For each keyframe, a window of frames around it (get_frames defaults) is cropped to the
  top-right region and OCR'd as one majority-voted group. The crop covers the top 11% of
  the height and the right half of the width. The winning text is parsed with
  string_to_epoch. Each Stamp pairs that epoch with the keyframe's position in whole
  seconds (``int(keyframe // fps)``).

  NOTE(review): OpenCV decodes frames as BGR and the crops reach the processor unconverted;
  confirm the channel order does not matter for this overlay text.
  """
  fps = vid.get(cv2.CAP_PROP_FPS)
  crop_left = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH) / 2)
  crop_bottom = int(0.11 * vid.get(cv2.CAP_PROP_FRAME_HEIGHT))

  crops = []
  for keyframe in keyframes:
    for img in get_frames(vid, keyframe):
      crops.append(img[0:crop_bottom, crop_left:])

  texts = ocr(crops, groups=len(keyframes))

  stamps = []
  for text, keyframe in zip(texts, keyframes):
    stamps.append(Stamp(int(keyframe // fps), string_to_epoch(text)))
  return stamps

In [None]:
# For every camera directory in cameras.json, fill in the missing length and start/end-time
# metadata of each .mp4 that has no output JSON yet. Each result is written to
# VIDEO_OUT_PATH/<camera>/<video>.json and also updated in place in video_data.
for io_dir in sorted(camera_data.keys()):
  input_dir_path = path.join(VIDEO_PATH, io_dir)
  output_dir_path = path.join(VIDEO_OUT_PATH, io_dir)
  input_files = sorted([f for f in listdir(input_dir_path) if f.endswith("mp4")])
  makedirs(output_dir_path, exist_ok=True)

  for io_file in input_files:
    input_file_path = path.join(input_dir_path, io_file)
    # Swap only the extension; str.replace would also rewrite an "mp4" inside the name.
    file_data_out_path = path.join(output_dir_path, path.splitext(io_file)[0] + ".json")

    # An existing output file means this video was already processed.
    if path.exists(file_data_out_path):
      continue

    print("processing:", io_file)

    if io_file not in video_data:
      video_data[io_file] = {
        "name": io_file,
        "camera": io_dir,
      }
    file_data = video_data[io_file]

    needs_length = not ("length_seconds" in file_data and "length_frames" in file_data)
    needs_times = not ("time_start" in file_data and "time_end" in file_data)

    vid = None
    try:
      if needs_length or needs_times:
        vid = cv2.VideoCapture(input_file_path)
        fps = vid.get(cv2.CAP_PROP_FPS)
        # Both branches below divide by fps, so skip files that report no frame rate.
        # The `finally` clause still releases the capture on this `continue`.
        if not fps > 0:
          continue
        length_frames = vid.get(cv2.CAP_PROP_FRAME_COUNT)

      if needs_length:
        file_data["length_frames"] = int(length_frames)
        file_data["length_seconds"] = int(length_frames // fps)

      if needs_times:
        last_frame = length_frames - 1
        length_seconds_fps = int(length_frames // fps)

        # 17 evenly spaced keyframes from 0 to last_frame inclusive, OCR'd in two batches
        # (presumably to bound the size of each OCR batch).
        num_keyframes = 16
        keyframes_0 = [int(i * last_frame / num_keyframes) for i in range(num_keyframes // 2)]
        keyframes_1 = [int(i * last_frame / num_keyframes) for i in range(num_keyframes // 2, num_keyframes + 1)]
        stamps = get_stamps(vid, keyframes_0) + get_stamps(vid, keyframes_1)

        file_data["time_start"] = stamps[0].timestamp
        file_data["time_end"] = stamps[-1].timestamp
        # Continuous when the OCR'd wall-clock span and the video duration differ by < 2 s.
        file_data["continuous"] = abs((stamps[-1].timestamp - stamps[0].timestamp) - length_seconds_fps) < 2
        file_data["seek"] = [s.stamp() for s in stamps]
    finally:
      if vid is not None:
        vid.release()

    with open(file_data_out_path, "w") as f:
      json.dump(file_data, f)

In [None]:
# Display the in-memory metadata (updated by the processing loop) before it is written back.
video_data

In [None]:
# Persist the updated metadata DB, pretty-printed with a 2-space indent.
with open(VIDEO_DB_PATH, "w") as video_db:
  json.dump(video_data, video_db, indent=2)

#### TEST OCR

In [None]:
# Grab 7 frames from the start and 7 from the end of one video for an OCR smoke test.
# NOTE(review): `input_file_path` is whatever the processing loop visited last. On a fresh
# kernel, assign it explicitly here to test a specific file.
vid = cv2.VideoCapture(input_file_path)
fps = vid.get(cv2.CAP_PROP_FPS)
length_frames = vid.get(cv2.CAP_PROP_FRAME_COUNT)
last_frame = length_frames - 1
length_seconds_fps = int(length_frames // fps)

frames_0 = get_frames(vid, 0, 7)
frames_n = get_frames(vid, last_frame - 7, 7)

# Use the same crop as get_stamps (top 11% of the height, right half of the width), so the
# test exercises the region the pipeline actually OCRs rather than a fixed-size box.
crop_x = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH) / 2)
crop_h = int(0.11 * vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
rgb_0 = [f[0:crop_h, crop_x:] for f in frames_0]
rgb_n = [f[0:crop_h, crop_x:] for f in frames_n]

In [None]:
%%time
pixel_values = processor(images=rgb_0, return_tensors="pt").pixel_values

generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_text)

#### Recursive Stamping

In [None]:
def stamp_center(vid, stamp_0, stamp_1, recursive=False):
  """OCR a new Stamp at the midpoint of an interval whose video and wall-clock spans disagree.

  A midpoint Stamp is produced only when both conditions hold:
    * the two stamps are more than 1 s apart in the video, and
    * their OCR'd timestamps differ from that span by more than 1 s (presumably a gap in
      the recording or an OCR misread).

  Args:
    vid: the cv2.VideoCapture the stamps were taken from.
    stamp_0, stamp_1: Stamps with ``stamp_0.seconds < stamp_1.seconds``.
    recursive: when True, keep bisecting both halves until every sub-interval is consistent
      or at most 1 s long. Defaults to False, which returns only the single midpoint Stamp.

  Returns:
    A list of new Stamps ordered by video position. It is empty if the interval needs no
    refinement.
  """
  diff_seconds = stamp_1.seconds - stamp_0.seconds
  diff_timestamp = stamp_1.timestamp - stamp_0.timestamp

  if diff_seconds <= 1 or abs(diff_seconds - diff_timestamp) <= 1:
    return []

  # True midpoint. The previous formula added stamp_0.seconds a second time, shifting the
  # point off-centre for any interval not starting at 0. It also called an undefined
  # `get_stamp`; the stamp is now taken via get_stamps.
  center_seconds = (stamp_0.seconds + stamp_1.seconds) / 2
  center_frame = int(round(center_seconds * vid.get(cv2.CAP_PROP_FPS)))
  stamp_c = get_stamps(vid, [center_frame])[0]

  left_center, right_center = [], []
  # Recurse only when the new stamp strictly splits the interval. Otherwise frame rounding
  # could return an endpoint and the same interval would be bisected forever.
  if recursive and stamp_0.seconds < stamp_c.seconds < stamp_1.seconds:
    left_center = stamp_center(vid, stamp_0, stamp_c, recursive=True)
    right_center = stamp_center(vid, stamp_c, stamp_1, recursive=True)

  return left_center + [stamp_c] + right_center