In [None]:
import tator

from tator.util import clone_localization_list
import os
from os import makedirs
from os.path import isdir, join
from pathlib import Path
import cv2
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [None]:
# Base URL of the Tator deployment this notebook talks to.
HOST = "https://tator.whoi.edu"

# The API token is read from the first line of a local text file rather than being
# written into the notebook. The context manager closes the file handle immediately
# instead of leaving it open for the lifetime of the kernel.
with open("tator_token.txt", "r") as token_file:
    TOKEN = token_file.readlines()[0].strip()

# Numeric ids of the Tator entities referenced by the cells below.
FISH_DETECTION_PROJECT = 2
ANIMAL_BBOX_LOCALIZATION_TYPE = 2
CUREE_VIDEO_TYPE = 2
DIVER_SURVEY_TYPE = 3

In [None]:
# Build the Tator API client from the host URL and token; every later cell calls it as `api`.
api = tator.get_api(HOST, TOKEN)

In [None]:
# Display the project list returned by the server (shown as the cell output).
api.get_project_list()

In [None]:
# Display the versions of the fish-detection project.
# NOTE(review): presumably these are the ids that the `version_ids` lists in later cells refer to.
api.get_version_list(FISH_DETECTION_PROJECT)

## Getting / deleting localization lists

In [None]:
# Get localization lists by media_id
# Fill in `media_ids` with the media to query, then run the cell to see how many
# animal bounding-box localizations come back.
media_ids = []

by_media_query = {
    "type": ANIMAL_BBOX_LOCALIZATION_TYPE,
    "media_id": media_ids,
}
localization_list = api.get_localization_list(FISH_DETECTION_PROJECT, **by_media_query)
print(len(localization_list))


In [None]:
# Get localization lists by version
# Counts the animal bounding-box localizations stored in the listed version(s).
version_ids = [5]

by_version_query = {
    "type": ANIMAL_BBOX_LOCALIZATION_TYPE,
    "version": version_ids,
}
localization_list = api.get_localization_list(FISH_DETECTION_PROJECT, **by_version_query)
print(len(localization_list))

In [None]:
# Delete localization lists
# Removes the animal bounding-box localizations attached to the media listed in `media_ids`.
media_ids = []

if not media_ids:
    # NOTE(review): an empty media filter presumably matches every localization of this
    # type in the project, so refuse to send such a request by accident.
    print("No media ids specified; nothing deleted.")
else:
    ret = api.delete_localization_list(
        project=FISH_DETECTION_PROJECT,
        type=ANIMAL_BBOX_LOCALIZATION_TYPE,  # was misspelled (LOCALIZTION) and raised NameError
        media_id=media_ids
    )
    print(ret)

## Create YOLO training dataset from Version

In [None]:
# generate a list of media ids corresponding to diver / curee videos
media_list = api.get_media_list(FISH_DETECTION_PROJECT)

# Seed the two media types that the dataset cell looks up so those keys always exist,
# even when the project currently holds no media of that type.
mediaId_by_mediaType = {DIVER_SURVEY_TYPE: [], CUREE_VIDEO_TYPE: []}
for media in media_list:
    # setdefault keeps a media of any other type from raising KeyError; such media are
    # simply grouped under their own type id.
    mediaId_by_mediaType.setdefault(media.meta, []).append(media.id)
print(mediaId_by_mediaType)

In [None]:
# Create new YOLO-style dataset from a version
#
# For every state in the chosen version(s) and media type, this cell loads the frame
# (from local disk when the file is found, otherwise from tator) together with its
# localizations and writes:
#   images/<media name>_f<frame>.png       the frame as loaded
#   labels/<media name>_f<frame>.txt       one "class cx cy w h" row per localization
#   groundtruth/<media name>_f<frame>.png  the frame with the boxes outlined in red
# With dry_run = True frames and localizations are still fetched and counted, but
# nothing is written to disk.

#***********************
#Modify these things
#***********************
version_ids = [5]
media_type = CUREE_VIDEO_TYPE # DIVER_SURVEY_TYPE or CUREE_VIDEO_TYPE
dataset_base_path = "/data_nvme/dxy/datasets" # /media/data/warp_data
offline_frame_dir = "/data_nvme/dxy/biomap"
dry_run = False
#***********************
#***********************

dataset_name = f"wrs_{'curee' if media_type == CUREE_VIDEO_TYPE else 'diver'}_yolo_dataset_v{version_ids[0]}"
output_dir = f"{dataset_base_path}/{dataset_name}/"

if not dry_run:
    # Create each sub-directory unconditionally: the dataset directory may already exist
    # (e.g. from an interrupted run) without all of its sub-directories.
    for subdir in ("images", "labels", "groundtruth"):
        makedirs(join(output_dir, subdir), exist_ok=True)

if os.path.exists(offline_frame_dir):
    print("You have a specified a path for offline frames that exists. Will try to load frames from here!")
    print("Note this code path has some assumptions about video/frame/folder naming")
    print("Ideally, some metadata in media would contain a pointer to where to find the frames locally")
    print("Be careful about off by one errors!")
    print("TODO: standardize on a pipeline for diver and curee videos\n")
    trying_offline_frames = True
else:
    print("Offline frame path does not exist. Will load frames from tator!\n")
    trying_offline_frames = False

# Iterate through all verified images corresponding to the media ids in the media type we care about
# (this is sketchy if you have multiple state types)
state_list = api.get_state_list(FISH_DETECTION_PROJECT, version=version_ids, media_id=mediaId_by_mediaType[media_type])
print(f"Num verified frames for media type {media_type}: {len(state_list)} ")

num_localizations = 0
all_localizations = []

for state in tqdm(state_list):
    # get file info
    media = api.get_media(state.media[0])

    # this shouldn't happen since we query state_list by media_id now
    assert media.meta == media_type, f"media {media.id} has type {media.meta}, expected {media_type}"

    # get localizations
    localizations = api.get_localization_list(FISH_DETECTION_PROJECT, version=version_ids, media_id=state.media, frame=state.frame)

    # Count before the dry-run check so a dry run reports the real totals.
    all_localizations.append(localizations)
    num_localizations += len(localizations)

    # CUREE videos will have something like warpauv_3_xavier4_2022-11-03-10-12-06_forward.mp4 as the media.name
    bag_name = Path(media.name).stem.split('_forward')[0]
    frame_str = str(state.frame).zfill(6) # NOTE: arbitrary amount of padding here
    frame_path = f"{offline_frame_dir}/{bag_name}/forward/vanilla/frame_{frame_str}.png"
    if os.path.exists(frame_path):
        image = Image.open(frame_path)
        # the local frame must have the resolution tator reports for the media
        assert image.height == media.height
        assert image.width == media.width
    else:
        if trying_offline_frames:
            print(f"Could not find {frame_path}; grabbing from tator instead")
        # get PIL image from tator; `media` was already fetched above, so it is reused
        # here instead of requesting the same media a second time
        imgpath = api.get_frame(state.media[0], frames=[state.frame])
        images = tator.util.get_images(imgpath, media, width=media.width, height=media.height, num_images=1) #note: this function can only retrieve a max of 32 images at a time for some reason
        assert(len(images) == 1) # should only be 1 image per state, otherwise somethings wrong
        image = images[0]

    if dry_run:
        continue

    # Common file stem shared by the image, label and groundtruth outputs of this frame.
    sample_name = media.name + f"_f{frame_str}"

    # save localizations in YOLO format: x/y are shifted by half the box size to obtain
    # the box centre.
    # NOTE(review): assumes localization coordinates are fractions of the frame size, as
    # the drawing code below scales them by the image size -- confirm.
    cls = 0
    yolo_cxywhs = np.array([[cls, localization.x+localization.width/2, localization.y+localization.height/2, localization.width, localization.height] for localization in localizations])
    label_path = join(output_dir, "labels", sample_name + ".txt")
    if len(yolo_cxywhs) > 0:
        np.savetxt(label_path, yolo_cxywhs, fmt="%i %f %f %f %f")
    else:
        # no localizations: the row format cannot be applied to an empty array, so the
        # array is saved without it
        np.savetxt(label_path, yolo_cxywhs)
    # save image
    image.save(join(output_dir, "images", sample_name + ".png"))

    # save groundtruth (drawn onto the same image object, after the clean copy was saved)
    draw = ImageDraw.Draw(image)
    img_w, img_h = image.size
    for cxywh in yolo_cxywhs:
        [c,x,y,w,h] = cxywh
        draw.rectangle([int((x-w/2)*img_w), int((y-h/2)*img_h), int((x+w/2)*img_w), int((y+h/2)*img_h)], outline=(255,0,0))
    image.save(join(output_dir, "groundtruth", sample_name + ".png"))

print("Total localizations: ", num_localizations)



In [None]:
# s = api.get_state_list(FISH_DETECTION_PROJECT, version=version_ids, media_id=[417])
# Scratch inspection cell: show the media type of every media in the project, then the
# height of media 436.
all_media = api.get_media_list(FISH_DETECTION_PROJECT)
# Only the last expression of a cell is displayed, so the type list has to be printed
# explicitly to be visible.
print([x.meta for x in all_media])
m = api.get_media(436)
m.height

In [None]:
# Fetch media 417 and display the full media object as the cell output.
m = api.get_media(417)
m

## Generate train/test/val splits using yolov5

In [None]:
# Split the dataset images into train / val / test lists with yolov5's autosplit, then
# print per-split statistics (image count, background images, object count and mean box
# size in pixels).
# NOTE(review): 1280 is assigned to the height and 720 to the width -- confirm this
# matches the video orientation; it only affects the printed width/height, not the area.
image_height = 1280
image_width = 720
output_dir = "/data_nvme/dxy/datasets/wrs_curee_yolo_dataset_v5"

# Make the local yolov5 checkout (expected next to the notebook's working directory) importable.
import sys
curr_repo = Path(os.path.abspath(os.getcwd()))
yolov5_dir = curr_repo / 'yolov5'
print(f"YOLOV5 directory: {yolov5_dir}")
if str(yolov5_dir) not in sys.path:
    sys.path.append(str(yolov5_dir))

from yolov5.utils.dataloaders import autosplit

print(output_dir)
autosplit(Path(output_dir) / "images", weights=(0.8, 0.1, 0.1))

strs = ["train", "test", "val"]
for split in strs:
    split_file = Path(output_dir) / f"autosplit_{split}.txt"
    if not split_file.exists():
        # NOTE(review): autosplit presumably writes a split file only when at least one
        # image lands in that split -- report it instead of crashing.
        print(f"---------------------")
        print(f"{split}: no split file found at {split_file}")
        continue

    with open(split_file) as f:
        image_files = [line.strip() for line in f if line.strip()]

    objects = 0
    areas, widths, heights = [], [], []
    num_bg = 0
    num_images = len(image_files)
    for img_file in image_files:
        # Swap only the leading directory and the final suffix, rather than replacing
        # every "images" / "png" substring, so file names containing those substrings
        # stay intact.
        # NOTE(review): assumes each listed path is relative to the dataset root and
        # starts with the images directory -- confirm against autosplit's output.
        image_rel = Path(img_file)
        label_rel = Path("labels", *image_rel.parts[1:]).with_suffix(".txt")
        with open(Path(output_dir) / label_rel) as label_f:
            # ignore blank lines so they are neither counted nor unpacked
            labels = [line for line in label_f if line.strip()]
        for label in labels:
            # each row is "class cx cy w h"; only the box size is needed here
            _cls, _cx, _cy, w, h = label.split()
            width = image_width * float(w)
            height = image_height * float(h)
            areas.append(width * height)
            widths.append(width)
            heights.append(height)

        objects += len(labels)
        num_bg += len(labels) == 0
    print(f"---------------------")
    print(f"{num_images} total images, {num_bg} background")
    print(f"{split}: {objects} objects")
    if areas:
        print(f"average area of {np.mean(areas)}, ({np.mean(widths)} x {np.mean(heights)})")
    else:
        # np.mean of an empty list would emit a RuntimeWarning and print nan
        print("average area of n/a (no objects in this split)")

## Deleting versions (layers)

In [None]:
# Display the project's versions so the id to delete in the next cell can be looked up.
api.get_version_list(project=FISH_DETECTION_PROJECT)

In [None]:
# Delete a whole version
# The version is printed first and is only deleted when the prompt is answered with
# exactly "y"; any other answer leaves it untouched.
version_id = 4
print(api.get_version(id=version_id))
res = input(f"Are you sure you want to delete version {version_id}? (y/n)")
if res != 'y':
    print("no action taken")
else:
    print("deleting...")
    api.delete_version(id=version_id)

## Cloning layers

In [None]:
# Clone the localizations of the queried media into another version.
# Each mapping translates a source id (key) to a destination id (value).
query_params = {'project': 2, 'media_id': [437, 438]}
dest_project = 2
version_mapping = {2: 5}
media_mapping = {437: 437, 438: 438}
localization_type_mapping = {2: 2}
created_ids = []
# Initialised up front so the summary below still works when the generator yields
# nothing (otherwise the final print raises NameError).
num_created = 0
generator = clone_localization_list(api, query_params, dest_project, version_mapping,
                                    media_mapping, localization_type_mapping)
for num_created, num_total, response, id_map in generator:
    print(f"Created {num_created} of {num_total} localizations...")
    # collect the ids reported by each response
    created_ids += response.id
print(f"Finished creating {num_created} localizations!")