Copyright 2021 DeepMind Technologies Limited

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Data Loading

Colab demonstrating how to download and load the compressed video data for Kinetics600.
Download the data from [here](https://www.deepmind.com/open-source/kinetics) into a folder and unzip it, then set `_DATA_FOLDER` to the path of the unzipped folder.

In [None]:
# Path to the unzipped folder holding the `<split>.tfrecord-*` shards.
_DATA_FOLDER = '' # @param {type: 'string'}
# Which split to load; label features are only parsed for non-'test' splits.
_SPLIT = 'train' # @param {type: 'string'} ['train', 'valid', 'test']

In [None]:
# @title Installation.
# @markdown This can be skipped if you have installed and are running this locally.

# Installs the `haiku` and `jax` packages imported further down.
!pip install dm-haiku
!pip install jax

In [None]:
# @title Installation.
# @markdown This can be skipped if you have installed and are running this locally.

!mkdir /content/compressed_vision
!touch /content/compressed_vision/__init__.py
!mkdir /content/compressed_vision/utils/
!touch /content/compressed_vision/utils/__init__.py
!wget -O /content/compressed_vision/utils/video_utils.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/compressed_vision/utils/video_utils.py

In [None]:
# @title Imports
import collections
import json
import os
from typing import Iterable, List, Mapping
import sklearn
from sklearn import decomposition
from IPython.display import Image as display_image

import tensorflow as tf
import pickle
import jax
import numpy as np
import haiku as hk
from PIL import Image

from compressed_vision.utils import video_utils

In [None]:
# @title Load an example
# Glob over every TFRecord shard of the chosen split; shard order is shuffled.
train_path = os.path.join(_DATA_FOLDER, f"{_SPLIT}.tfrecord-*")
filenames = tf.data.Dataset.list_files(train_path, shuffle=True)
ds = tf.data.TFRecordDataset(filenames)

# Schema of each serialized example. `compressed_repr` is itself a serialized
# tensor and is decoded separately when parsing.
feature_dict = {
    'video_id': tf.io.FixedLenFeature([], tf.string),
    'start_sec': tf.io.FixedLenFeature([], tf.float32),
    'end_sec': tf.io.FixedLenFeature([], tf.float32),
    'compressed_repr': tf.io.FixedLenFeature([], tf.string),
}

# Label features are only requested for non-test splits.
if _SPLIT != 'test':
  feature_dict['label'] = tf.io.FixedLenFeature([], tf.string)
  feature_dict['label_id'] = tf.io.FixedLenFeature([], tf.int64)

# Additionally shuffle individual records within a 100-element buffer.
ds = ds.shuffle(buffer_size=100)

def parse_example(row):
  """Decodes one serialized record using the global `feature_dict` schema.

  Args:
    row: a serialized example string tensor, as produced by TFRecordDataset.

  Returns:
    A dict of tensors keyed like `feature_dict`. `compressed_repr` is
    additionally decoded from its serialized form into a uint8 tensor.
  """
  parsed = tf.io.parse_example(row, feature_dict)
  serialized_codes = parsed['compressed_repr']
  parsed['compressed_repr'] = tf.io.parse_tensor(serialized_codes, tf.uint8)
  return parsed

ds = ds.map(parse_example)
# take(2) keeps two examples and batch(2) groups them together, so this
# dataset yields a single batch containing 2 elements.
batch = ds.take(2).batch(2)

# NOTE: `item` remains bound to the last (here: only) batch after the loop;
# the following cell reads item['compressed_repr'].
for item in batch:
  for k, v in item.items():
    if len(v.shape) > 1:
      # Multi-dimensional features (e.g. compressed_repr): show the shape only.
      print(f'{k}: {v.shape}')
    else:
      # Rank-1 features, one value per batch element: show the values.
      print(f'{k}: {v.numpy()}')


In [None]:
# @title Project codes of compressed representation to quantized representation.
# @markdown And plot the PCA of this representation.

!wget 'https://storage.googleapis.com/dm_compressed_vision/data_mappings/41861759_2_cr%3D192.pkl' -O /tmp/model_path.pkl

with open('/tmp/model_path.pkl', 'rb') as f:
  embed_state = pickle.load(f)

def quantize(encoding_indices, embeddings=None):
  """Looks up the embedding vector for every code index.

  Args:
    encoding_indices: integer array of code indices, of any shape.
    embeddings: optional embedding table. Defaults to
      `embed_state['embeddings']`. Its first two axes are swapped before the
      lookup, so the indices select rows of the swapped table. Presumably the
      stored layout is (embedding_dim, num_codes); TODO confirm.

  Returns:
    A JAX array of shape `encoding_indices.shape + w.shape[1:]`, where `w` is
    the axis-swapped table.
  """
  if embeddings is None:
    embeddings = embed_state['embeddings']
  w = embeddings.swapaxes(1, 0)
  w = jax.device_put(w)  # Required when embeddings is a NumPy array.
  # The whole index array indexes the first axis, gathering one row per index.
  return w[(encoding_indices,)]

# Quantize only the [0, 0] slice of the codes. The lookup appends a trailing
# embedding axis, so this equals quantizing everything and then taking [0, 0],
# but it avoids materializing vectors that would be discarded.
quantizations = quantize(
    encoding_indices=item['compressed_repr'].numpy()[0, 0])
pca_decomp = sklearn.decomposition.PCA(3)
# Flatten all positions into rows of embedding vectors, using the table's own
# embedding width instead of a hard-coded size. Fit a 3-component PCA, then
# restore the original layout with 3 output channels.
x_new = pca_decomp.fit_transform(
    quantizations.reshape(-1, quantizations.shape[-1]))
x_new = x_new.reshape(quantizations.shape[:-1] + (3,))

# Shift to mean 0.5 with unit std, then clip to [0, 1]. Values more than 0.5
# std from the mean saturate at 0 or 1.
vis_x = (x_new - x_new.mean()) / (x_new.std()) + 0.5
vis_x = np.clip(vis_x, 0, 1)

v = video_utils.video_reshaper(vis_x)
video_utils.save_video((v * 255).astype(np.uint8), '/tmp/display.gif')
display_image(filename='/tmp/display.gif', embed=True)