In [1]:
import json
import random
import zipfile
import numpy as np
import cv2 as cv
from joblib import Parallel, delayed
from timeit import default_timer as timer
from torchvision.transforms.functional import crop as crop_image
from os.path import exists, join, basename, isdir
from os import makedirs, remove, listdir, rmdir, rename
from six.moves import urllib
from PIL import Image

In [15]:
import torch.utils.data as data
import torch
from torchvision.transforms import CenterCrop, RandomHorizontalFlip, RandomVerticalFlip, RandomRotation
import random
from torch import nn

In [3]:
# --- Sampling limits ---------------------------------------------------------
# get_training_set() keeps at most this many extracted patches.
MAX_TRAINING_SAMPLES: int = 500_000
# get_validation_set() draws at most this many frame triplets.
MAX_VALIDATION_SAMPLES: int = 100

# --- Per-sample preprocessing ------------------------------------------------
# Size handed to CenterCrop by both PatchDataset and ValidationDataset.
CROP_SIZE: int = 128
# Probability with which PatchDataset swaps the two input frames of a sample.
RANDOM_TEMPORAL_ORDER_SWAP_PROB: float = 0.5

In [4]:
def prepare_dataset(dataset_dir="", davis_dir=""):
    """Extract training patches from DAVIS frames and save them as JSON.

    Frame triplets are collected from ``davis_dir`` at 480p, patches are
    extracted from them, shuffled in place with the global ``random`` state,
    and written to ``<dataset_dir>/patches.json``.

    Args:
        dataset_dir: Directory that receives ``patches.json``. The default
            ``""`` writes into the current working directory.
        davis_dir: Root of the DAVIS dataset (the directory that contains
            ``JPEGImages``).

    Returns:
        The shuffled list of patch descriptors (the same data that was
        written to disk).
    """
    json_path = join(dataset_dir, 'patches.json')

    tuples = tuples_from_davis(davis_dir, res='480p')
    patches = extract_patches(
        tuples,
        max_per_frame=20,
        trials_per_tuple=30,
        flow_threshold=25.0,
        jumpcut_threshold=8e-3,
    )

    random.shuffle(patches)

    with open(json_path, 'w') as f:
        json.dump(patches, f)

    return patches

In [5]:
def tuples_from_davis(davis_dir, res='480p'):
    """Collect non-overlapping frame triplets from a DAVIS directory tree.

    Every sub-directory of ``<davis_dir>/JPEGImages/<res>`` is treated as one
    video. Its image files are sorted by path and grouped into consecutive,
    non-overlapping triplets; up to two trailing frames that do not fill a
    triplet are dropped.

    Args:
        davis_dir: Root of the DAVIS dataset.
        res: Name of the resolution sub-directory.

    Returns:
        A list of ``(first, middle, last)`` path tuples, ordered by video
        directory and then by frame path.
    """
    subdir = join(davis_dir, "JPEGImages", res)

    # listdir() order is arbitrary; sorting keeps the result (and any seeded
    # sampling done on it by callers) reproducible across machines.
    video_dirs = sorted(
        path
        for path in (join(subdir, name) for name in listdir(subdir))
        if isdir(path)
    )

    tuples = []
    for video_dir in video_dirs:
        frame_paths = sorted(
            path
            for path in (join(video_dir, name) for name in listdir(video_dir))
            if is_image(path)
        )

        # Step by three so that triplets never share a frame.
        for start in range(0, len(frame_paths) - 2, 3):
            tuples.append(tuple(frame_paths[start:start + 3]))

    return tuples

In [6]:
def extract_patches(tuples, max_per_frame=1, trials_per_tuple=100, flow_threshold=0.0,jumpcut_threshold=np.inf):
    """Sample 150x150 patch positions from frame triplets.

    For every triplet, ``trials_per_tuple`` random positions are tried. A
    trial is rejected when ``is_jumpcut`` flags either adjacent frame pair,
    or - with probability ``1 - min(1, avg_flow / flow_threshold)`` - when
    the flow between the outer frames is small. Of the surviving trials, the
    ``max_per_frame`` with the largest flow are kept.

    Args:
        tuples: Sequence of ``(left, middle, right)`` frame paths.
        max_per_frame: Maximum number of patches kept per triplet.
        trials_per_tuple: Number of random positions tried per triplet.
        flow_threshold: Flow magnitude at which a patch is always kept.
            A non-positive value disables flow-based filtering.
        jumpcut_threshold: Threshold forwarded to ``is_jumpcut``.

    Returns:
        A list of dicts with keys ``left_frame``, ``middle_frame``,
        ``right_frame``, ``patch_i`` (row), ``patch_j`` (column) and
        ``avg_flow``.
    """
    patch_h, patch_w = 150, 150
    n_tuples = len(tuples)
    all_patches = []
    jumpcuts = 0
    flowfiltered = 0
    too_small = 0
    total_iters = n_tuples * trials_per_tuple

    # Reverse the channel axis: load_img yields RGB, so this gives BGR arrays.
    to_reversed_channels = lambda x: np.array(x)[:, :, ::-1]

    for tup in tuples:
        left, middle, right = (load_img(x) for x in tup)
        img_w, img_h = left.size

        # A frame smaller than the patch has no valid position; skip it
        # instead of letting random.randint raise on an empty range.
        if img_h < patch_h or img_w < patch_w:
            too_small += 1
            continue

        left = to_reversed_channels(left)
        middle = to_reversed_channels(middle)
        right = to_reversed_channels(right)

        selected_patches = []

        for _ in range(trials_per_tuple):

            i = random.randint(0, img_h - patch_h)
            j = random.randint(0, img_w - patch_w)

            left_patch = left[i:i + patch_h, j:j + patch_w, :]
            right_patch = right[i:i + patch_h, j:j + patch_w, :]
            middle_patch = middle[i:i + patch_h, j:j + patch_w, :]

            if is_jumpcut(left_patch, middle_patch, jumpcut_threshold) or \
                    is_jumpcut(middle_patch, right_patch, jumpcut_threshold):
                jumpcuts += 1
                continue

            # Plain float keeps the descriptor JSON-serialisable.
            avg_flow = float(simple_flow(left_patch, right_patch))
            # Guard the division: the default threshold is 0.0.
            if flow_threshold > 0 and random.random() > avg_flow / flow_threshold:
                flowfiltered += 1
                continue

            selected_patches.append({
                "left_frame": tup[0],
                "middle_frame": tup[1],
                "right_frame": tup[2],
                "patch_i": i,
                "patch_j": j,
                "avg_flow": avg_flow
            })

        # Sort in place so the slice below really keeps the highest-flow
        # patches (a bare sorted() call would throw its result away).
        selected_patches.sort(key=lambda p: p['avg_flow'], reverse=True)
        all_patches += selected_patches[:max_per_frame]

    # Avoid dividing by zero when there were no tuples or no trials.
    denom = total_iters or 1
    print('===> Processed {} tuples, {} patches extracted, {} discarded as jumpcuts, {} filtered by flow'.format(
        n_tuples, len(all_patches), 100.0 * jumpcuts / denom, 100.0 * flowfiltered / denom
    ))
    if too_small:
        print('===> Skipped {} tuples whose frames are smaller than the {}x{} patch'.format(
            too_small, patch_h, patch_w
        ))

    return all_patches

In [7]:
def is_image(file_path):
    """Return True when ``file_path`` ends in .png, .jpg or .jpeg.

    The comparison ignores case, so files such as ``FRAME.JPG`` are
    recognised as well.
    """
    return file_path.lower().endswith((".png", ".jpg", ".jpeg"))

In [8]:
def is_jumpcut(frame1, frame2, threshold=np.inf):
    """Tell whether two frames differ enough to count as a jump cut.

    For each of the three channels an 8-bin histogram over the range
    (0, 255) is built for both frames and divided by the per-channel pixel
    count of ``frame1``. The frames count as a jump cut as soon as the mean
    squared difference of the histograms exceeds ``threshold`` for any
    channel.
    """
    pixels_per_channel = frame1.size / 3

    def normalised_histogram(channel):
        counts, _ = np.histogram(channel.reshape(-1), 8, (0, 255))
        return counts / pixels_per_channel

    for c in range(3):
        delta = normalised_histogram(frame1[:, :, c]) - normalised_histogram(frame2[:, :, c])
        exceeded = (delta ** 2).mean() > threshold
        if exceeded:
            return exceeded
    # No channel exceeded the threshold; hand back the last comparison.
    return exceeded


def simple_flow(frame1, frame2):
    """Return the norm of the mean optical-flow vector between two frames.

    The flow field comes from OpenCV's ``calcOpticalFlowSF``. NaN entries are
    left out of the per-component mean: they are zeroed before summing and
    the sum is divided by the number of non-NaN entries.
    """
    flow = cv.optflow.calcOpticalFlowSF(frame1, frame2, layers=3, averaging_block_size=2, max_flow=4)
    missing = np.isnan(flow)
    valid_counts = (~missing).sum(axis=(0, 1))
    flow[missing] = 0
    mean_vector = flow.sum(axis=(0, 1)) / valid_counts
    return np.linalg.norm(mean_vector)

In [9]:
def load_patch(patch, patch_size=(150, 150)):
    """Load the three frames of a patch descriptor and crop them.

    Args:
        patch: Dict with the keys ``left_frame``, ``middle_frame``,
            ``right_frame`` (image paths), ``patch_i`` (top row) and
            ``patch_j`` (left column), as produced by ``extract_patches``.
        patch_size: ``(height, width)`` of the crop. The default matches the
            150x150 patches sampled by ``extract_patches``; it replaces a
            lookup on a ``config`` object that this notebook never defines.

    Returns:
        A tuple ``(left, middle, right)`` of cropped images.
    """
    paths = (patch['left_frame'], patch['middle_frame'], patch['right_frame'])
    top, left = patch['patch_i'], patch['patch_j']
    height, width = patch_size
    return tuple(crop_image(load_img(path), top, left, height, width) for path in paths)

In [10]:
def load_img(file_path):
    """Open an image file and return it converted to RGB.

    The source image is used as a context manager so that its file handle is
    released as soon as the converted copy exists, rather than whenever the
    original image object is garbage-collected.
    """
    with Image.open(file_path) as img:
        return img.convert('RGB')

In [11]:
def pil_to_numpy(x_pil):
    """Convert an image to a float array with axis 2 moved to the front.

    Values are divided by 255.0, so 8-bit input ends up in the range [0, 1].
    """
    scaled = np.asarray(x_pil) / 255.0
    return np.moveaxis(scaled, 2, 0)


def pil_to_tensor(x_pil):
    """Convert an image to a float32 torch tensor via ``pil_to_numpy``."""
    return torch.from_numpy(pil_to_numpy(x_pil)).float()


def numpy_to_pil(x_np):
    """Convert an array with the channel axis first back to an RGB image.

    Works on a copy of the input: values are multiplied by 255, clipped to
    [0, 255], axis 0 is moved to the end, and the result is cast to uint8
    before being handed to ``Image.fromarray``.
    """
    pixels = x_np.copy()  # leave the caller's array untouched
    pixels *= 255.0
    channels_last = np.moveaxis(pixels.clip(0, 255), 0, 2)
    return Image.fromarray(channels_last.astype(np.uint8), mode='RGB')

In [12]:
class PatchDataset(data.Dataset):
    """Training dataset that serves augmented frame-triplet patches.

    Each item is ``(input, target)``: ``target`` is the middle frame and
    ``input`` is the two outer frames concatenated along dimension 0. One
    augmentation is drawn per item and applied to all three frames, the
    frames are centre-cropped with ``CROP_SIZE``, and the two outer frames
    swap places with probability ``RANDOM_TEMPORAL_ORDER_SWAP_PROB``.

    No lambdas are stored on the instance, so the dataset can be pickled
    (needed by multi-process data loading that spawns its workers).
    """

    def __init__(self, patches):
        """Store the patch descriptors and build the transforms.

        Args:
            patches: Sequence of patch descriptors accepted by ``load_patch``.
        """
        super().__init__()
        self.patches = patches
        self.crop = CenterCrop(CROP_SIZE)
        # Every entry is deterministic (fixed 90 degree rotation, flips with
        # probability 1.0, or no-op); the randomness lies only in which
        # entry gets picked for an item.
        self.random_transforms = [
            RandomRotation((90, 90)),
            RandomVerticalFlip(1.0),
            RandomHorizontalFlip(1.0),
            PatchDataset._identity,
        ]
        self.load_patch = load_patch

        print('Dataset ready with {} tuples.'.format(len(patches)))

    @staticmethod
    def _identity(x):
        """Return ``x`` unchanged (the 'no augmentation' choice)."""
        return x

    def get_aug_transform(self):
        """Pick one of ``self.random_transforms`` uniformly at random."""
        return random.choice(self.random_transforms)

    @staticmethod
    def random_temporal_order_swap(x1, x2):
        """Return ``(x2, x1)`` with the configured probability, else ``(x1, x2)``."""
        if random.random() <= RANDOM_TEMPORAL_ORDER_SWAP_PROB:
            return x2, x1
        else:
            return x1, x2

    def __getitem__(self, index):
        """Load, augment and crop the patch at ``index``.

        Returns:
            ``(input, target)`` as described in the class docstring.
        """
        frames = self.load_patch(self.patches[index])
        # Draw once so that all three frames receive the same augmentation.
        aug_transform = self.get_aug_transform()
        x1, target, x2 = (pil_to_tensor(self.crop(aug_transform(x))) for x in frames)
        x1, x2 = self.random_temporal_order_swap(x1, x2)
        stacked = torch.cat((x1, x2), dim=0)
        return stacked, target

    def __len__(self):
        """Number of patch descriptors."""
        return len(self.patches)

In [13]:
class ValidationDataset(data.Dataset):
    """Validation dataset built from frame-path triplets.

    Each item is ``(input, target)``: all three frames are loaded and
    centre-cropped with ``CROP_SIZE``; ``target`` is the middle frame and
    ``input`` is the two outer frames concatenated along dimension 0.
    No augmentation is applied.
    """

    def __init__(self, tuples):
        """Keep the ``(first, middle, last)`` path triplets and build the crop."""
        super().__init__()
        self.crop = CenterCrop(CROP_SIZE)
        self.tuples = tuples

    def __len__(self):
        """Number of triplets."""
        return len(self.tuples)

    def __getitem__(self, index):
        """Load and crop the triplet at ``index``; return ``(input, target)``."""

        def prepare(path):
            return pil_to_tensor(self.crop(load_img(path)))

        before, target, after = map(prepare, self.tuples[index])
        return torch.cat((before, after), dim=0), target

In [14]:
def get_training_set():
    """Build the training dataset.

    Runs ``prepare_dataset`` and wraps at most ``MAX_TRAINING_SAMPLES`` of
    the returned patches in a ``PatchDataset``.
    """
    return PatchDataset(prepare_dataset()[:MAX_TRAINING_SAMPLES])

def get_validation_set():
    """Build the validation dataset.

    Collects 480p frame triplets from the DAVIS test directory and samples
    at most ``MAX_VALIDATION_SAMPLES`` of them with a private generator
    seeded with 42, leaving the global ``random`` state untouched.
    """
    davis_17_test = "path to davis_17_test"  # placeholder path, set before use
    candidates = tuples_from_davis(davis_17_test, res='480p')
    rng = random.Random(42)
    chosen = rng.sample(candidates, min(len(candidates), MAX_VALIDATION_SAMPLES))
    return ValidationDataset(chosen)