In [1]:
from comet_ml import Experiment
from ast import parse
from logging import root
import re
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler, SequentialSampler
import numpy as np

from torchvision import transforms

from pytorchvideo.models import x3d
# from pytorchvideo.data import RandomClipSampler, UniformClipSampler
from pytorchvideo.data import Ucf101, RandomClipSampler, UniformClipSampler, Kinetics
# from compression_dataset import segUcf101, segkinetics

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)

from tqdm import tqdm
from collections import OrderedDict
import itertools
import os
import pickle

import imageio
import imgaug as ia
import imgaug.augmenters as iaa
import matplotlib.pyplot as plt
import argparse
import configparser
from fractions import Fraction
from compression_transform import SameClipSampler


In [2]:
class LimitDataset(torch.utils.data.Dataset):
    """Indexable wrapper around an iterable dataset.

    Exposes ``__getitem__``/``__len__`` on top of an object that is only
    iterable, so it can be passed to a ``DataLoader`` as a regular dataset.
    The reported length is the wrapped dataset's ``num_videos``.
    """

    def __init__(self, dataset):
        """Store ``dataset`` and open the iterator that items are pulled from."""
        super().__init__()
        self.dataset = dataset
        # iter(dataset) is evaluated exactly once, so both slots of the chain
        # refer to the same iterator object. Once it is exhausted, the second
        # slot yields nothing further and next() raises StopIteration.
        shared_iter = iter(dataset)
        self.dataset_iter = itertools.chain(shared_iter, shared_iter)

    def __len__(self):
        """Return ``num_videos`` of the wrapped dataset."""
        return self.dataset.num_videos

    def __getitem__(self, index):
        """Return the next item of the underlying iterator; ``index`` is not used."""
        return next(self.dataset_iter)

In [3]:
def get_kinetics(subset, config):
    """
    Build the Kinetics400 dataset for one split.

    Args:
        subset (str): "train" or "val"; used as the split directory name
            under ``config['DATA_PATH']``.
        config: mapping-like settings section. Keys read here:
            ``'VIDEO_NUM_SUBSAMPLED'`` (converted with ``int``),
            ``'DATA_PATH'`` (root directory of the dataset) and
            ``'CLIP_DURATION'`` (converted with ``Fraction`` then ``float``,
            so both ``"2.5"`` and ``"8/3"`` style strings are accepted).
    Returns:
        pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset: the dataset
    """
    # NOTE(review): the random scale/crop/flip augmentations below are applied
    # for every subset, including "val" -- confirm this is intended.
    video_transform = Compose([
        UniformTemporalSubsample(int(config['VIDEO_NUM_SUBSAMPLED'])),
        transforms.Lambda(lambda x: x / 255.),
        Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
        RandomShortSideScale(min_size=256, max_size=320),
        RandomCrop(224),
        RandomHorizontalFlip(),
    ])

    transform = Compose([
        ApplyTransformToKey(key="video", transform=video_transform),
        # Identity transform: the label is passed through unchanged.
        ApplyTransformToKey(
            key="label",
            transform=transforms.Lambda(lambda x: x),
        ),
        RemoveKey("audio"),
    ])

    # os.path.join gives the same path as plain concatenation when DATA_PATH
    # ends with a separator, and also works when it does not.
    split_dir = os.path.join(config['DATA_PATH'], subset)

    dataset = Kinetics(
        data_path=split_dir,
        video_path_prefix=split_dir,
        clip_sampler=SameClipSampler(
            clip_duration=float(Fraction(config['CLIP_DURATION']))),
        video_sampler=SequentialSampler,
        decode_audio=False,
        transform=transform,
    )

    return dataset

In [4]:
def make_loader(dataset, config, batch_size=8):
    """
    Create a data loader over ``dataset``.

    Args:
        dataset (pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset):
            dataset obtained from get_dataset; it is wrapped in LimitDataset
            before being handed to the DataLoader.
        config: mapping-like settings section; ``'NUM_WORKERS'`` is read and
            converted with ``int``.
        batch_size (int): number of samples per batch. Defaults to 8, the
            value that used to be hard-coded, so existing two-argument calls
            behave as before. Pass e.g. ``int(config['BATCH_SIZE'])`` to drive
            it from the config, if the config defines that key.
    Returns:
        torch.utils.data.DataLoader: the created data loader
    """
    return DataLoader(
        LimitDataset(dataset),
        batch_size=batch_size,
        drop_last=True,  # an incomplete final batch is discarded
        num_workers=int(config['NUM_WORKERS']),
    )

In [5]:
def get_dataset(dataset, subset, config):
    """
    Look up a dataset by name.

    Args:
        dataset (str): dataset name; only "Kinetics400" is handled here
        subset (str): "train" or "val"
        config: settings object forwarded unchanged to get_kinetics
    Returns:
        The value returned by get_kinetics when ``dataset`` is "Kinetics400";
        ``False`` for any other name.
    """
    # TODO: pass the JPEG compression setting through to get_kinetics
    if dataset != "Kinetics400":
        # No other dataset has a branch here, so callers get False back.
        return False
    return get_kinetics(subset, config)

In [7]:
# Load the Kinetics400 settings and build the validation dataset.
config_parser = configparser.ConfigParser()
# ConfigParser.read() skips files it cannot open and returns the list of files
# it actually parsed, so an empty result means config.ini was not read.
if not config_parser.read('config.ini'):
    raise FileNotFoundError("config.ini could not be read from the working directory")
# `config` is the Kinetics400 section; the loader cell passes it to make_loader.
config = config_parser['Kinetics400']
dataset = get_dataset("Kinetics400", "val", config)
print('num_video', dataset.num_videos)

num_video 19881


In [8]:
# Build the loader that the batch-inspection cells iterate over.
sample_loader = make_loader(dataset=dataset, config=config)

RandomClipSampler

In [46]:
# RandomClipSampler
# NOTE(review): this cell does not select the clip sampler itself; what it
# prints depends on how sample_loader was built when the cell was run.
# Print the video names of the first 11 batches (indices 0-10).
for i, inputs in enumerate(itertools.islice(sample_loader, 11)):
    print(i)
    print(inputs["video_name"])

0
['0wR5jVB-WPk_000417_000427.mp4', '3caPS4FHFF8_000036_000046.mp4', '3yaoNwz99xM_000062_000072.mp4', '6IbvOJxXnOo_000047_000057.mp4', '6_4kjPiQr7w_000191_000201.mp4', '9EnSwbXxu5g_000056_000066.mp4', '9eTXpQq9QTA_000126_000136.mp4', 'ASFOR-PFsiM_000040_000050.mp4', 'BJ8ZBTYXyAE_000076_000086.mp4', 'BXXE1G7O5ec_000030_000040.mp4']
1
['qMGREwwh4xQ_000068_000078.mp4', 'qNcTjhOhg0M_000005_000015.mp4', 'rlnvv7YF2J0_000042_000052.mp4', 'ursxaisH0eI_000004_000014.mp4', 'v-YI47Q0Mv4_000000_000010.mp4', 'xbwxysJ8Cjs_000015_000025.mp4', '-1iWWmJmwzI_000048_000058.mp4', '-kdfl_GwOtE_000000_000010.mp4', '-shfV8JDCMk_000001_000011.mp4', '0l74QJmidNw_000016_000026.mp4']
2
['rTDOx0cndCc_000005_000015.mp4', 'sA3FAtCfxpw_000003_000013.mp4', 'sUb1RJUivRU_000160_000170.mp4', 'tI1t5ugTvGY_000299_000309.mp4', 'uHR8pakSKkI_000055_000065.mp4', 'upXS1qv56Ww_000077_000087.mp4', 'wCTJ9bKn_sI_000180_000190.mp4', 'wJTOt6reIpg_000015_000025.mp4', 'wdrRBrKBycc_000434_000444.mp4', '06A7IVDdaXo_000091_000101.mp4']
3

UniformClipSampler

In [52]:
# UniformClipSampler
# NOTE(review): this cell does not select the clip sampler itself; what it
# prints depends on how sample_loader was built when the cell was run.
# Print the video names of the first 11 batches (indices 0-10).
for i, inputs in enumerate(itertools.islice(sample_loader, 11)):
    print(i)
    print(inputs["video_name"])

0
['0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4']
1
['qMGREwwh4xQ_000068_000078.mp4', 'qMGREwwh4xQ_000068_000078.mp4', 'qMGREwwh4xQ_000068_000078.mp4', 'qMGREwwh4xQ_000068_000078.mp4', 'qMGREwwh4xQ_000068_000078.mp4', 'qMGREwwh4xQ_000068_000078.mp4', 'qMGREwwh4xQ_000068_000078.mp4', 'qMGREwwh4xQ_000068_000078.mp4', 'qMGREwwh4xQ_000068_000078.mp4', 'qMGREwwh4xQ_000068_000078.mp4']
2
['rTDOx0cndCc_000005_000015.mp4', 'rTDOx0cndCc_000005_000015.mp4', 'rTDOx0cndCc_000005_000015.mp4', 'rTDOx0cndCc_000005_000015.mp4', 'rTDOx0cndCc_000005_000015.mp4', 'rTDOx0cndCc_000005_000015.mp4', 'rTDOx0cndCc_000005_000015.mp4', 'rTDOx0cndCc_000005_000015.mp4', 'rTDOx0cndCc_000005_000015.mp4', 'rTDOx0cndCc_000005_000015.mp4']
3

num_workers = 0 (batches are loaded in the main process, without worker subprocesses)


UniformClipSampler

In [65]:
# num_workers = 0
# UniformClipSampler
# NOTE(review): neither setting is applied in this cell; both depend on how
# sample_loader was built when the cell was run.
# Print the video names of the first 21 batches (indices 0-20).
for i, inputs in enumerate(itertools.islice(sample_loader, 21)):
    print(i)
    print(inputs["video_name"])

0
['0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4']
1
['0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '0wR5jVB-WPk_000417_000427.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4']
2
['3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4', '3caPS4FHFF8_000036_000046.mp4']
3

# Inspect the contents of each batch

In [9]:
# Print the labels of the first 21 batches (indices 0-20).
# The loop is bounded so the cell finishes on its own instead of walking the
# whole loader; raise the limit to inspect more batches.
for i, inputs in enumerate(itertools.islice(sample_loader, 21)):
    print(i)
    print(inputs["label"])

0
tensor([0, 0, 0, 0, 0, 0, 0, 0])
1
tensor([49, 49, 49, 49, 49, 49, 50, 50])
2
tensor([99, 99, 99, 99, 99, 99, 99, 99])
3
tensor([150, 150, 150, 150, 150, 150, 150, 150])
4
tensor([200, 200, 200, 200, 200, 200, 200, 200])
5
tensor([249, 249, 249, 249, 250, 250, 250, 250])
6
tensor([299, 299, 299, 299, 299, 299, 299, 300])
7
tensor([350, 350, 350, 350, 350, 350, 350, 350])
8
tensor([0, 0, 0, 0, 0, 0, 0, 0])
9
tensor([50, 50, 50, 50, 50, 50, 50, 50])
10
tensor([ 99, 100, 100, 100, 100, 100, 100, 100])
11
tensor([150, 150, 150, 150, 150, 150, 150, 150])
12
tensor([200, 200, 200, 200, 200, 200, 200, 200])
13
tensor([250, 250, 250, 250, 250, 250, 250, 250])
14
tensor([300, 300, 300, 300, 300, 300, 300, 300])
15
tensor([350, 350, 350, 350, 350, 350, 350, 350])
16
tensor([0, 0, 0, 0, 0, 0, 0, 0])
17
tensor([50, 50, 50, 50, 50, 50, 50, 50])
18
tensor([100, 100, 100, 100, 100, 100, 100, 100])
19
tensor([150, 150, 150, 150, 150, 150, 150, 150])
20
tensor([200, 200, 200, 200, 200, 200, 200, 200]

KeyboardInterrupt: 