In [1]:
import os

import cv2
from datasets import DatasetDict, load_dataset, load_from_disk
from datasets.combine import concatenate_datasets
from matplotlib import pyplot as plt
import numpy as np
from transformers import AutoTokenizer, TimesformerForVideoClassification

# Number of frames sampled from each video by preprocess().
FRAMES_PER_VIDEO = 16

In [3]:
# Load the TimeSformer video classifier; this cell only reads its label map.
vision_model = TimesformerForVideoClassification.from_pretrained(
    "facebook/timesformer-base-finetuned-k600"
)

# Keep the class names that consist of a single word (no space in the label).
actions = {
    label
    for label in vision_model.config.id2label.values()
    if " " not in label
}

# Caption tokenizer; the EOS token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [7]:
# Reload the processed splits from disk and merge them into a single pool.
dataset = load_from_disk("dataset/processed/8frames_part2")
combined = concatenate_datasets([dataset[split] for split in ("train", "validation")])
combined

Dataset({
    features: ['videoID', 'pixel_values', 'labels'],
    num_rows: 25538
})

In [None]:
# Group row indices of `combined` by the first action word found in each caption.
#
# Only the token ids are needed here, so every other column is dropped first:
# iterating `combined` directly would also decode each row's `pixel_values`.
# Dropping columns keeps the row order, so `i` still indexes into `combined`.
labels_only = combined.remove_columns(
    [name for name in combined.column_names if name != "labels"]
)

action_idxs = {}
for i, item in enumerate(labels_only):
    caption = tokenizer.decode(item["labels"], skip_special_tokens=True)
    # A caption is filed under its first matching action only; captions that
    # mention no known action are left out of `action_idxs` entirely.
    # NOTE(review): words are split on single spaces, so a word with attached
    # punctuation (e.g. "dancing.") will not match — confirm this is intended.
    first_action = next(
        (word for word in caption.split(" ") if word in actions), None
    )
    if first_action is not None:
        action_idxs.setdefault(first_action, []).append(i)

In [None]:
# Per-action split: the leading share of each action's rows goes to train,
# the remainder to validation, so both splits cover the same actions.
TRAIN_FRACTION = 0.91

train_idxs, val_idxs = [], []
for idxs in action_idxs.values():
    cut = int(TRAIN_FRACTION * len(idxs))
    train_idxs += idxs[:cut]
    val_idxs += idxs[cut:]

# Overwrite the splits of the loaded DatasetDict with the new selection.
for split_name, split_idxs in (("train", train_idxs), ("validation", val_idxs)):
    dataset[split_name] = combined.select(split_idxs)
dataset

In [None]:
# Write the current `dataset` (re-split in the cell above) to disk.
dataset.save_to_disk("dataset/processed/k600")

In [None]:
def _count_frames(video_path):
    """Return how many frames of `video_path` can actually be decoded."""
    video = cv2.VideoCapture(video_path)
    try:
        total = 0
        while True:
            ok, _ = video.read()
            if not ok:
                return total
            total += 1
    finally:
        # Release the capture even if decoding raises.
        video.release()


def _read_frames(video_path, indices):
    """Decode the frames at the non-decreasing `indices` and return them as RGB."""
    video = cv2.VideoCapture(video_path)
    frames = []
    try:
        position = 0     # index of the frame the next read()/grab() yields
        previous = None  # index of the most recently decoded frame
        for target in indices:
            if target == previous:
                # Clips shorter than the number of samples produce repeated
                # indices; reuse the decoded frame instead of waiting for an
                # index the decoder has already moved past.
                frames.append(frames[-1])
                continue
            # Skip (without decoding) everything before the wanted frame.
            while position < target:
                if not video.grab():
                    raise ValueError(
                        "could not skip to frame %d of %s" % (target, video_path)
                    )
                position += 1
            ok, frame = video.read()
            if not ok:
                raise ValueError(
                    "could not decode frame %d of %s" % (target, video_path)
                )
            position += 1
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            previous = target
    finally:
        video.release()
    return frames


def preprocess(example):
    """Turn one caption record into sampled RGB frames plus caption token ids.

    Frames are sampled at evenly spaced positions over the whole clip
    (FRAMES_PER_VIDEO of them); when the clip has fewer frames than that,
    frames are repeated. The caption with the most space-separated words is
    tokenized (the first one wins on ties).

    Args:
        example: mapping with "videoID" (file stem under dataset/videos,
            looked up as .mp4 and falling back to .webm) and "enCap"
            (list of caption strings).

    Returns:
        dict with "pixel_values" (list of FRAMES_PER_VIDEO frames as returned
        by cv2.cvtColor with COLOR_BGR2RGB) and "labels" (token ids of the
        chosen caption, tokenized with padding="max_length").

    Raises:
        ValueError: if the record has no captions, if no frame of the video
            can be decoded, or if decoding stops early while sampling.
    """
    video_id = example["videoID"]
    captions = example["enCap"]
    if not captions:
        raise ValueError("no captions for video %s" % video_id)

    videos_path = "dataset/videos"
    video_path = os.path.join(videos_path, "%s.mp4" % video_id)
    if not os.path.isfile(video_path):
        video_path = os.path.join(videos_path, "%s.webm" % video_id)

    # First pass: count the frames by decoding the whole clip.
    frame_count = _count_frames(video_path)
    if frame_count == 0:
        raise ValueError("no decodable frames in %s" % video_path)

    # Fixed (evenly spaced) frame sampling.
    indices = np.linspace(
        0, frame_count, num=FRAMES_PER_VIDEO, endpoint=False
    ).astype(np.int64)

    # Second pass: decode only the sampled frames.
    frames = _read_frames(video_path, indices.tolist())

    # Longest caption by word count; max() keeps the first of equal lengths.
    caption = max(captions, key=lambda cap: len(cap.split(" ")))

    labels = tokenizer(caption, padding="max_length").input_ids
    return {"pixel_values": frames, "labels": labels}

In [2]:
# Load the VATEX caption JSON files as a train/validation DatasetDict.
data_files = dict(
    train="dataset/vatex_train_captions.json",
    validation="dataset/vatex_val_captions.json",
)
dataset = load_dataset("json", data_files=data_files)
dataset

# Optional: shrink both splits for a quick smoke test.
# dataset["train"] = dataset["train"].select(np.arange(6))
# dataset["validation"] = dataset["validation"].select(np.arange(3))

# Disabled: frame extraction + tokenization pass and saving its result to disk.
# dataset = dataset.map(function=preprocess, remove_columns=["enCap", "chCap"])
# dataset.save_to_disk("dataset/raw_frames_16")

Using custom data configuration default-dc0812067ce11954
Found cached dataset json (/home/922201615/.cache/huggingface/datasets/json/default-dc0812067ce11954/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['videoID', 'enCap', 'chCap'],
        num_rows: 22895
    })
    validation: Dataset({
        features: ['videoID', 'enCap', 'chCap'],
        num_rows: 2643
    })
})