In [None]:
# Render matplotlib figures inline, directly below the cell that produces them.
%matplotlib inline
# Reload modules automatically before each cell executes, so that edits to
# imported source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import tempfile
from IPython.display import HTML
from base64 import b64encode

from vicas.dataset import ViCaSDataset, ViCaSVideo
from vicas.caption_parsing import parse_caption

## Dataset API

`ViCaSDataset` is a wrapper class to easily iterate over all videos.

**TODO:** Set `annotations_dir` to the directory path where all the JSON annotations are saved

In [None]:
# Configuration: edit `annotations_dir` before running this cell.
annotations_dir = "/path/to/vicas/json/annotations/dir"
video_frames_dir = "demo_data/video_frames"
split = None # can be set to 'train', 'val' or 'test' to load a particular split

# Fail fast with an actionable message if the placeholder path above was not
# replaced, instead of surfacing a less obvious error from inside the loader.
if not os.path.isdir(annotations_dir):
    raise FileNotFoundError(
        f"annotations_dir does not exist or is not a directory: {annotations_dir!r}. "
        "Set it to the directory containing the ViCaS JSON annotation files."
    )

# Index the videos found under `annotations_dir` (optionally restricted to one split).
dataset = ViCaSDataset(
    annotations_dir,
    split=split,
    video_frames_dir=video_frames_dir,
)
print(f"Indexed {len(dataset)} videos from the dataset")

## Video API

`ViCaSVideo` is a wrapper for each video. You can instantiate it through the dataset, or separately by calling `ViCaSVideo.from_json(...)` with the path to the JSON file for that video.

In [None]:
example_video_id = 9505  # video frames for a few example videos are provided in 'demo_data'.
# Look the video up by its ID through the dataset wrapper.
video = dataset.parse_video(example_video_id)
# Alternate: load the video directly from its JSON annotation file
# ViCaSVideo.from_json(f"{annotations_dir}/{example_video_id:06d}.json")

#### Visualization

After running `video.visualize()`, two outputs will be saved to disk:
- The visualization will be saved as a video to `<viz_temp_dir>/video.mp4`
- The individual video frames of the visualization will be saved to a sub-directory at `<viz_temp_dir>/frames`

In [None]:
# All demo output lives under <system temp dir>/ViCaS_demo/<zero-padded video ID>.
demo_root = os.path.join(tempfile.gettempdir(), "ViCaS_demo")
viz_temp_dir = os.path.join(demo_root, f"{example_video_id:06d}")

# Create the directory together with any missing parents; re-running is harmless.
os.makedirs(viz_temp_dir, exist_ok=True)
print(f"Visualization output will be saved to: {viz_temp_dir}")

In [None]:
# Write the visualization into `viz_temp_dir`; as described above, this produces
# `video.mp4` and a `frames` sub-directory containing the individual frames.
video.visualize(viz_temp_dir)

In [None]:
# helper function to play video
def play_video(filename, width=600):
    """Embed an MP4 file in the notebook as an inline HTML5 video player.

    The whole file is read into memory and inlined as a base64 ``data:`` URI,
    so large videos produce correspondingly large cell output.

    Args:
        filename: Path to the MP4 file to embed.
        width: Value of the ``width`` attribute of the ``<video>`` element.
            Defaults to 600, the previously hard-coded value.

    Returns:
        An ``HTML`` display object containing a looping, auto-playing
        ``<video>`` element with playback controls.
    """
    with open(filename, 'rb') as fh:
        # Deliberately not called `video`, which is the ViCaSVideo object
        # used throughout the rest of the notebook.
        video_bytes = fh.read()
    src = 'data:video/mp4;base64,' + b64encode(video_bytes).decode()
    return HTML(
        f'<video width={width} controls autoplay loop>'
        f'<source src="{src}" type="video/mp4"></video>'
    )

# play the visualization video written by `video.visualize(viz_temp_dir)`;
# the returned HTML object is the cell's last expression, so it is rendered as the output
play_video(os.path.join(viz_temp_dir, "video.mp4"))

#### Captions

The `ViCaSVideo` object contains multiple properties for the caption:
- `caption_orig_raw`: This is the original, human-written caption with our custom syntax for marking phrase grounding
- `caption_orig_parsed`: This is the same as above, but with the custom syntax stripped away i.e. a standard caption
- `caption_gpt_raw`: The result of using GPT4 to remove errors and improve the wording of `caption_orig_raw`.
- `caption_gpt_parsed`: This is the same as above, but with the custom syntax stripped away.

**NOTE:** The GPT-improved captions are used for training and evaluation.

In [None]:
# Show the GPT-improved caption in both forms: first with the phrase-grounding
# syntax, then as plain text with that syntax stripped.
print("Caption with phrase-grounding syntax: " + video.caption_gpt_raw)
print("Parsed caption without grounding syntax: " + video.caption_gpt_parsed)

**NOTE:** We also provide an API to parse the raw caption with syntax if you want to programmatically extract the phrase-grounding information:

In [None]:
# Parse the raw caption (with grounding syntax) into a caption object;
# its `parsed` attribute holds the plain-text caption.
caption_obj = parse_caption(video.caption_gpt_raw)
print("Parsed caption without grounding syntax: " + caption_obj.parsed)  # same as the parsed caption printed above

You can call print on a `VideoCaption` to pretty-print all the attributes of the caption: the raw and parsed version, and a list of grounding phrases including the object IDs, the string indices of the phrase, and the phrase itself.

In [None]:
# Pretty-print the parsed caption object, including its grounding phrases.
print(caption_obj)

#### Language-Guided Video Instance Segmentation (LG-VIS)

The LG-VIS prompts and associated masks can be obtained by calling `video.parse_lgvis()`. This function returns an iterator over the prompts. Each element is a tuple with four entries (five when `return_viz=True`):
- **prompt** *(str)*: The text prompt
- **masks** *(List[List[np.ndarray]])*: The object masks. The inner list is over different objects (a single prompt can reference multiple objects). The outer list is over time/frames.
- **track_ids** *(List[int])*: The IDs of the objects.
- **filenames** *(List[str])*: The filenames of the video frames (same length as the `masks`).
- Optional: **viz_frames** *(List[np.ndarray])*: If `return_viz` is set to `True`, a list of image frames with the prompt and masks visualized is also returned.

In [None]:
# Report how many LG-VIS prompts this video has.
print(f"This video contains {video.num_lgvis_prompts} LG-VIS prompts")

In [None]:
# Layout of the preview grid: at most n_rows * n_cols frames are shown per prompt.
n_rows, n_cols = 2, 3

for prompt, masks, track_ids, filenames, viz_frames in video.parse_lgvis(return_viz=True): # iterate over prompts
    print("Prompt: " + prompt)
    print(f"There are {len(masks)} frame-level masks")
    print(f"This prompt references {len(masks[0])} object tracks with IDs {track_ids}")
    print(f"Each mask array has shape {masks[0][0].shape} and dtype {masks[0][0].dtype}")

    # Sample evenly spaced frames. Never request more panels than there are
    # frames: otherwise truncating the linspace values to int repeats indices
    # and the same frame would be drawn in several panels.
    n_shown = min(n_rows * n_cols, len(viz_frames))
    frame_indices = np.linspace(0, len(viz_frames) - 1, n_shown).astype(int).tolist()
    sampled_frames = [viz_frames[i] for i in frame_indices]  # keep `viz_frames` itself intact

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 or single-row grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8), squeeze=False)
    for ax, frame in zip(axes.flat, sampled_frames):
        ax.imshow(frame[:, :, ::-1]) # convert image BGR to RGB
    for ax in axes.flat:
        ax.axis('off')  # also blanks any panels left unused for short videos

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so memory does not grow with the number of prompts