In [None]:
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
# NOTE(review): assumes a perceiver/train/ directory already exists relative to
# the current directory. On a fresh Colab kernel it does not (the install cell
# below creates it), in which case this %cd only prints an error and leaves the
# working directory unchanged; the following `%cd ../` cell then still runs.
%cd perceiver/train/


In [None]:
# NOTE(review): meant to undo part of the preceding `%cd perceiver/train/`.
# If that cd did not happen (e.g. fresh kernel), this moves one level ABOVE the
# starting directory, and later cells that use paths relative to the starting
# directory may then fail.
%cd ../

In [None]:
# Install dependencies for Google Colab.
# If you want to run this notebook on your own machine, you can skip this cell.
!pip install dm-haiku
!pip install einops

!mkdir /content/perceiver
!touch /content/perceiver/__init__.py
!wget -O /content/perceiver/io_processors.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/io_processors.py
!wget -O /content/perceiver/perceiver.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/perceiver.py
!wget -O /content/perceiver/position_encoding.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/position_encoding.py

%cd perceiver/
!mkdir train
!wget -O /content/perceiver/train/launch_local.sh https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/train/launch_local.sh
!wget -O /content/perceiver/train/autoaugment.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/train/autoaugment.py
!wget -O /content/perceiver/train/dataset.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/train/dataset.py
!wget -O /content/perceiver/train/experiment.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/train/experiment.py
!wget -O /content/perceiver/train/utils.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/train/utils.py
%cd ../

In [None]:
#@title Imports

import base64
import functools
import os
import pickle
import ssl
import re
import tempfile

from urllib import request

import cv2
import haiku as hk
import imageio
import jax
import jax.numpy as jnp
import numpy as np
import scipy.io.wavfile

from IPython.display import HTML

from perceiver import perceiver, io_processors


In [None]:
#@title Helper functions for the UCF101 dataset

# Utilities to fetch videos from UCF101 dataset
UCF_ROOT = 'https://www.crcv.ucf.edu/THUMOS14/UCF101/UCF101/'
# As of July 2020, crcv.ucf.edu doesn't use a certificate accepted by the
# default Colab environment anymore.
# NOTE(review): this context skips certificate and hostname verification, so
# anything downloaded through it is not authenticated.
unverified_context = ssl._create_unverified_context()
# Listing of UCF101 video names, filled on first use by list_ucf_videos().
_VIDEO_LIST = None
# Fresh temporary directory (created when this cell runs) caching downloads.
_CACHE_DIR = tempfile.mkdtemp()

def list_ucf_videos():
  """Lists videos available in UCF101 dataset.

  The directory index at UCF_ROOT is downloaded on the first call and the
  result is cached in the module-level `_VIDEO_LIST`; later calls reuse it.

  Returns:
    A new sorted list of the unique `v_*.avi` file names found in the index.
  """
  global _VIDEO_LIST
  if not _VIDEO_LIST:
    # Close the HTTP response deterministically instead of leaking it.
    with request.urlopen(UCF_ROOT, context=unverified_context) as response:
      index = response.read().decode('utf-8')
    # Raw string: '\w' and '\.' are invalid escape sequences in a plain string
    # literal (SyntaxWarning on recent Python versions).
    videos = re.findall(r'(v_[\w_]+\.avi)', index)
    _VIDEO_LIST = sorted(set(videos))
  return list(_VIDEO_LIST)

def fetch_ucf_video(video):
  """Fetches a video and caches it in the local filesystem.

  Args:
    video: file name of a UCF101 video, e.g. one returned by list_ucf_videos().

  Returns:
    Path of the cached copy inside `_CACHE_DIR`. The file is downloaded only
    if it is not already present there.
  """
  cache_path = os.path.join(_CACHE_DIR, video)
  if not os.path.exists(cache_path):
    urlpath = request.urljoin(UCF_ROOT, video)
    print('Fetching %s => %s' % (urlpath, cache_path))
    with request.urlopen(urlpath, context=unverified_context) as response:
      data = response.read()
    # Write under a temporary name and rename afterwards, so an interrupted
    # write never leaves a truncated file that later calls treat as cached.
    partial_path = cache_path + '.part'
    with open(partial_path, 'wb') as f:
      f.write(data)
    os.replace(partial_path, cache_path)
  return cache_path

# Utilities to open video files using CV2
def crop_center_square(frame):
  """Returns the largest square centred in `frame`.

  Args:
    frame: array whose first two axes are (height, width); any further axes
      (e.g. colour channels) are kept as-is.

  Returns:
    A slice of `frame` (a view for NumPy arrays) of side min(height, width).
  """
  height, width = frame.shape[:2]
  side = min(height, width)
  top = height // 2 - side // 2
  left = width // 2 - side // 2
  return frame[top:top + side, left:left + side]

def load_video(path, max_frames=0, resize=(224, 224)):
  """Decodes a video into centre-cropped, resized RGB frames scaled to [0, 1].

  Args:
    path: path of a video file readable by OpenCV.
    max_frames: stop after this many frames; 0 (the default) reads them all.
    resize: target size passed to cv2.resize, i.e. (width, height).

  Returns:
    np.array of the collected frames divided by 255.0. If no frame can be
    read, this is an empty array.
  """
  capture = cv2.VideoCapture(path)
  rgb_frames = []
  try:
    ok, bgr = capture.read()
    while ok:
      square = cv2.resize(crop_center_square(bgr), resize)
      # OpenCV decodes frames as BGR; reorder the channels to RGB.
      rgb_frames.append(square[:, :, [2, 1, 0]])
      if len(rgb_frames) == max_frames:
        break
      ok, bgr = capture.read()
  finally:
    capture.release()
  return np.array(rgb_frames) / 255.0

def to_gif(images):
  """Renders frames with values in [0, 1] as an inline animated GIF.

  The GIF is written to ./animation.gif (25 fps) and then embedded into the
  returned HTML as a base64 data URI.
  """
  gif_path = './animation.gif'
  frames_u8 = np.clip(images * 255, 0, 255).astype(np.uint8)
  imageio.mimsave(gif_path, frames_u8, fps=25)
  with open(gif_path, 'rb') as gif_file:
    encoded = base64.b64encode(gif_file.read()).decode('utf-8')
  return HTML('<img src="data:image/gif;base64,%s"/>' % encoded)

def play_audio(data, sample_rate=48000):
  """Returns an inline HTML audio player for `data` sampled at `sample_rate`.

  The samples are written to tmp_audio.wav in the current directory and then
  embedded into the returned HTML as a base64 data URI.
  """
  scipy.io.wavfile.write('tmp_audio.wav', sample_rate, data)

  with open('./tmp_audio.wav', 'rb') as wav_file:
    wav_bytes = wav_file.read()
  encoded = base64.b64encode(wav_bytes).decode('utf-8')
  markup = '<audio controls src="data:audio/wav;base64,%s"/>' % encoded
  return HTML(markup)

def table(elements):
  """Lays out HTML display objects side by side in a single-row table.

  Args:
    elements: objects exposing their HTML markup as `.data` (e.g. the values
      returned by to_gif / play_audio).
  """
  cells = ''.join('<td>%s</td>' % element.data for element in elements)
  return HTML('<table><tr>%s</tr></table>' % cells)

In [None]:
#@title Load video and audio from UCF

video_names = list_ucf_videos()
video_path = fetch_ucf_video(video_names[0])

# Extract audio using FFMPEG and encode as pcm float wavfile (only format readable by scipy.io.wavfile).
!yes | ffmpeg -i "$video_path"  -c copy  -f wav -map 0:a pcm_f32le -ar 48000 output.wav

sample_rate, audio = scipy.io.wavfile.read("output.wav")
if audio.dtype == np.int16:
  audio = audio.astype(np.float32) / 2**15
elif audio.dtype != np.float32:
  raise ValueError('Unexpected datatype. Model expects sound samples to lie in [-1, 1]')

video = load_video(video_path)

In [None]:
#@title Kinetics 700 Classes
# Human-readable names of the Kinetics-700 classes (700 entries). Predicted
# label indices are looked up directly in this list, so its order is assumed to
# match the class order used when training the checkpoint.
KINETICS_CLASSES = ["abseiling", "acting in play", "adjusting glasses", "air drumming", 
"alligator wrestling", "answering questions", "applauding", "applying cream", 
"archaeological excavation", "archery", "arguing", "arm wrestling", 
"arranging flowers", "arresting", "assembling bicycle", "assembling computer", 
"attending conference", "auctioning", "baby waking up", "backflip (human)", 
"baking cookies", "bandaging", "barbequing", "bartending", 
"base jumping", "bathing dog", "battle rope training", "beatboxing", 
"bee keeping", "being excited", "being in zero gravity", "belly dancing", 
"bench pressing", "bending back", "bending metal", "biking through snow", 
"blasting sand", "blending fruit", "blowdrying hair", "blowing bubble gum", 
"blowing glass", "blowing leaves", "blowing nose", "blowing out candles", 
"bobsledding", "bodysurfing", "bookbinding", "bottling", 
"bouncing ball (not juggling)", "bouncing on bouncy castle", "bouncing on trampoline", "bowling", 
"braiding hair", "breading or breadcrumbing", "breakdancing", "breaking boards", 
"breaking glass", "breathing fire", "brush painting", "brushing floor", 
"brushing hair", "brushing teeth", "building cabinet", "building lego", 
"building sandcastle", "building shed", "bulldozing", "bungee jumping", 
"burping", "busking", "calculating", "calligraphy", 
"canoeing or kayaking", "capoeira", "capsizing", "card stacking", 
"card throwing", "carrying baby", "carrying weight", "cartwheeling", 
"carving ice", "carving marble", "carving pumpkin", "carving wood with a knife", 
"casting fishing line", "catching fish", "catching or throwing baseball", "catching or throwing frisbee", 
"catching or throwing softball", "celebrating", "changing gear in car", "changing oil", 
"changing wheel (not on bike)", "chasing", "checking tires", "checking watch", 
"cheerleading", "chewing gum", "chiseling stone", "chiseling wood", 
"chopping meat", "chopping wood", "clam digging", "clapping", 
"clay pottery making", "clean and jerk", "cleaning gutters", "cleaning pool", 
"cleaning shoes", "cleaning toilet", "cleaning windows", "climbing a rope", 
"climbing ladder", "climbing tree", "closing door", "coloring in", 
"combing hair", "contact juggling", "contorting", "cooking chicken", 
"cooking egg", "cooking on campfire", "cooking sausages (not on barbeque)", "cooking scallops", 
"cosplaying", "coughing", "counting money", "country line dancing", 
"cracking back", "cracking knuckles", "cracking neck", "crawling baby", 
"crocheting", "crossing eyes", "crossing river", "crying", 
"cumbia", "curling (sport)", "curling eyelashes", "curling hair", 
"cutting apple", "cutting cake", "cutting nails", "cutting orange", 
"cutting pineapple", "cutting watermelon", "dancing ballet", "dancing charleston", 
"dancing gangnam style", "dancing macarena", "deadlifting", "dealing cards", 
"decorating the christmas tree", "decoupage", "delivering mail", "digging", 
"dining", "directing traffic", "disc golfing", "diving cliff", 
"docking boat", "dodgeball", "doing aerobics", "doing jigsaw puzzle", 
"doing laundry", "doing nails", "doing sudoku", "drawing", 
"dribbling basketball", "drinking shots", "driving car", "driving tractor", 
"drooling", "drop kicking", "drumming fingers", "dumpster diving", 
"dunking basketball", "dyeing eyebrows", "dyeing hair", "eating burger", 
"eating cake", "eating carrots", "eating chips", "eating doughnuts", 
"eating hotdog", "eating ice cream", "eating nachos", "eating spaghetti", 
"eating watermelon", "egg hunting", "embroidering", "entering church", 
"exercising arm", "exercising with an exercise ball", "extinguishing fire", "faceplanting", 
"falling off bike", "falling off chair", "feeding birds", "feeding fish", 
"feeding goats", "fencing (sport)", "fidgeting", "filling cake", 
"filling eyebrows", "finger snapping", "fixing bicycle", "fixing hair", 
"flint knapping", "flipping bottle", "flipping pancake", "fly tying", 
"flying kite", "folding clothes", "folding napkins", "folding paper", 
"front raises", "frying vegetables", "gargling", "geocaching", 
"getting a haircut", "getting a piercing", "getting a tattoo", "giving or receiving award", 
"gold panning", "golf chipping", "golf driving", "golf putting", 
"gospel singing in church", "grinding meat", "grooming cat", "grooming dog", 
"grooming horse", "gymnastics tumbling", "hammer throw", "hand washing clothes", 
"head stand", "headbanging", "headbutting", "helmet diving", 
"herding cattle", "high fiving", "high jump", "high kick", 
"historical reenactment", "hitting baseball", "hockey stop", "holding snake", 
"home roasting coffee", "hopscotch", "hoverboarding", "huddling", 
"hugging (not baby)", "hugging baby", "hula hooping", "hurdling", 
"hurling (sport)", "ice climbing", "ice fishing", "ice skating", 
"ice swimming", "inflating balloons", "installing carpet", "ironing", 
"ironing hair", "javelin throw", "jaywalking", "jetskiing", 
"jogging", "juggling balls", "juggling fire", "juggling soccer ball", 
"jumping bicycle", "jumping into pool", "jumping jacks", "jumping sofa", 
"jumpstyle dancing", "karaoke", "kicking field goal", "kicking soccer ball", 
"kissing", "kitesurfing", "knitting", "krumping", 
"land sailing", "laughing", "lawn mower racing", "laying bricks", 
"laying concrete", "laying decking", "laying stone", "laying tiles", 
"leatherworking", "letting go of balloon", "licking", "lifting hat", 
"lighting candle", "lighting fire", "listening with headphones", "lock picking", 
"long jump", "longboarding", "looking at phone", "looking in mirror", 
"luge", "lunge", "making a cake", "making a sandwich", 
"making balloon shapes", "making bubbles", "making cheese", "making horseshoes", 
"making jewelry", "making latte art", "making paper aeroplanes", "making pizza", 
"making slime", "making snowman", "making sushi", "making tea", 
"making the bed", "marching", "marriage proposal", "massaging back", 
"massaging feet", "massaging legs", "massaging neck", "massaging person's head", 
"metal detecting", "milking cow", "milking goat", "mixing colours", 
"moon walking", "mopping floor", "mosh pit dancing", "motorcycling", 
"mountain climber (exercise)", "moving baby", "moving child", "moving furniture", 
"mowing lawn", "mushroom foraging", "needle felting", "news anchoring", 
"opening bottle (not wine)", "opening coconuts", "opening door", "opening present", 
"opening refrigerator", "opening wine bottle", "packing", "paragliding", 
"parasailing", "parkour", "passing American football (in game)", "passing American football (not in game)", 
"passing soccer ball", "peeling apples", "peeling banana", "peeling potatoes", 
"person collecting garbage", "petting animal (not cat)", "petting cat", "petting horse", 
"photobombing", "photocopying", "picking apples", "picking blueberries", 
"pillow fight", "pinching", "pirouetting", "planing wood", 
"planting trees", "plastering", "playing accordion", "playing american football", 
"playing badminton", "playing bagpipes", "playing basketball", "playing bass guitar", 
"playing beer pong", "playing billiards", "playing blackjack", "playing cards", 
"playing cello", "playing checkers", "playing chess", "playing clarinet", 
"playing controller", "playing cricket", "playing cymbals", "playing darts", 
"playing didgeridoo", "playing dominoes", "playing drums", "playing field hockey", 
"playing flute", "playing gong", "playing guitar", "playing hand clapping games", 
"playing harmonica", "playing harp", "playing ice hockey", "playing keyboard", 
"playing kickball", "playing laser tag", "playing lute", "playing mahjong", 
"playing maracas", "playing marbles", "playing monopoly", "playing netball", 
"playing nose flute", "playing oboe", "playing ocarina", "playing organ", 
"playing paintball", "playing pan pipes", "playing piano", "playing piccolo", 
"playing pinball", "playing ping pong", "playing poker", "playing polo", 
"playing recorder", "playing road hockey", "playing rounders", "playing rubiks cube", 
"playing saxophone", "playing scrabble", "playing shuffleboard", "playing slot machine", 
"playing squash or racquetball", "playing tennis", "playing trombone", "playing trumpet", 
"playing ukulele", "playing violin", "playing volleyball", "playing with trains", 
"playing xylophone", "poaching eggs", "poking bellybutton", "pole vault", 
"polishing furniture", "polishing metal", "popping balloons", "pouring beer", 
"pouring milk", "pouring wine", "preparing salad", "presenting weather forecast", 
"pretending to be a statue", "pull ups", "pulling espresso shot", "pulling rope (game)", 
"pumping fist", "pumping gas", "punching bag", "punching person (boxing)", 
"push up", "pushing car", "pushing cart", "pushing wheelbarrow", 
"pushing wheelchair", "putting in contact lenses", "putting on eyeliner", "putting on foundation", 
"putting on lipstick", "putting on mascara", "putting on sari", "putting on shoes", 
"putting wallpaper on wall", "raising eyebrows", "reading book", "reading newspaper", 
"recording music", "repairing puncture", "riding a bike", "riding camel", 
"riding elephant", "riding mechanical bull", "riding mule", "riding or walking with horse", 
"riding scooter", "riding snow blower", "riding unicycle", "ripping paper", 
"roasting marshmallows", "roasting pig", "robot dancing", "rock climbing", 
"rock scissors paper", "roller skating", "rolling eyes", "rolling pastry", 
"rope pushdown", "running on treadmill", "sailing", "salsa dancing", 
"saluting", "sanding floor", "sanding wood", "sausage making", 
"sawing wood", "scrambling eggs", "scrapbooking", "scrubbing face", 
"scuba diving", "seasoning food", "separating eggs", "setting table", 
"sewing", "shaking hands", "shaking head", "shaping bread dough", 
"sharpening knives", "sharpening pencil", "shaving head", "shaving legs", 
"shearing sheep", "shining flashlight", "shining shoes", "shoot dance", 
"shooting basketball", "shooting goal (soccer)", "shooting off fireworks", "shopping", 
"shot put", "shouting", "shoveling snow", "shredding paper", 
"shucking oysters", "shuffling cards", "shuffling feet", "side kick", 
"sieving", "sign language interpreting", "silent disco", "singing", 
"sipping cup", "situp", "skateboarding", "ski ballet", 
"ski jumping", "skiing crosscountry", "skiing mono", "skiing slalom", 
"skipping rope", "skipping stone", "skydiving", "slacklining", 
"slapping", "sled dog racing", "sleeping", "slicing onion", 
"smashing", "smelling feet", "smoking", "smoking hookah", 
"smoking pipe", "snatch weight lifting", "sneezing", "snorkeling", 
"snowboarding", "snowkiting", "snowmobiling", "somersaulting", 
"spelunking", "spinning plates", "spinning poi", "splashing water", 
"spray painting", "spraying", "springboard diving", "square dancing", 
"squat", "squeezing orange", "stacking cups", "stacking dice", 
"standing on hands", "staring", "steer roping", "steering car", 
"sticking tongue out", "stomping grapes", "stretching arm", "stretching leg", 
"sucking lolly", "surfing crowd", "surfing water", "surveying", 
"sweeping floor", "swimming backstroke", "swimming breast stroke", "swimming butterfly stroke", 
"swimming front crawl", "swimming with dolphins", "swimming with sharks", "swing dancing", 
"swinging baseball bat", "swinging on something", "sword fighting", "sword swallowing", 
"tackling", "tagging graffiti", "tai chi", "taking photo", 
"talking on cell phone", "tango dancing", "tap dancing", "tapping guitar", 
"tapping pen", "tasting beer", "tasting food", "tasting wine", 
"testifying", "texting", "threading needle", "throwing axe", 
"throwing ball (not baseball or American football)", "throwing discus", "throwing knife", "throwing snowballs", 
"throwing tantrum", "throwing water balloon", "tickling", "tie dying", 
"tightrope walking", "tiptoeing", "tobogganing", "tossing coin", 
"tossing salad", "training dog", "trapezing", "treating wood", 
"trimming or shaving beard", "trimming shrubs", "trimming trees", "triple jump", 
"twiddling fingers", "tying bow tie", "tying knot (not on a tie)", "tying necktie", 
"tying shoe laces", "unboxing", "uncorking champagne", "unloading truck", 
"using a microscope", "using a paint roller", "using a power drill", "using a sledge hammer", 
"using a wrench", "using atm", "using bagging machine", "using circular saw", 
"using inhaler", "using megaphone", "using puppets", "using remote controller (not gaming)", 
"using segway", "vacuuming car", "vacuuming floor", "visiting the zoo", 
"wading through mud", "wading through water", "waiting in line", "waking up", 
"walking on stilts", "walking the dog", "walking through snow", "walking with crutches", 
"washing dishes", "washing feet", "washing hair", "washing hands", 
"watching tv", "water skiing", "water sliding", "watering plants", 
"waving hand", "waxing armpits", "waxing back", "waxing chest", 
"waxing eyebrows", "waxing legs", "weaving basket", "weaving fabric", 
"welding", "whistling", "windsurfing", "winking", 
"wood burning (art)", "wrapping present", "wrestling", "writing", 
"yarn spinning", "yawning", "yoga", "zumba"]

In [None]:
# Visualize inputs: the decoded video frames and the extracted audio track.
# Play the audio back at the sample rate actually read from output.wav rather
# than relying on play_audio's 48 kHz default.
table([to_gif(video), play_audio(audio, sample_rate=sample_rate)])

In [None]:
#@title Model construction
# Number of video frames per clip fed to the model (also the time extent of
# the image position encoding).
NUM_FRAMES = 16
# Audio sample rate in Hz; matches the `-ar 48000` used when extracting audio.
AUDIO_SAMPLE_RATE = 48000
# Frame rate assumed for the source video when pairing audio with frames.
VIDEO_FPS = 25
# Raw audio samples that cover the duration of one video frame.
AUDIO_SAMPLES_PER_FRAME = AUDIO_SAMPLE_RATE // VIDEO_FPS
# Raw audio samples grouped into one audio patch by the pre/postprocessors.
SAMPLES_PER_PATCH = 16
# Size of the Kinetics-700 label vocabulary.
NUM_CLASSES = 700
# Spatial size of the image position-encoding grid; presumably the 224-pixel
# frames from load_video divided by spatial_downsample=4 — confirm.
IMG_SZ = 56

def video_autoencoder(images, audio, subsampling):
  """Builds the multimodal Perceiver IO autoencoder and applies it to a clip.

  Uses Haiku modules, so it has to run inside a Haiku transform (the notebook
  wraps it with `hk.transform_with_state`). The label modality is given
  `mask_probs` 1.0 and fed an all-zero input, so presumably the class is
  predicted from video and audio only — confirm against io_processors.

  Args:
    images: video clip. `images.shape[0]` is used as the batch size and
      `images.shape[:4]` as the image decoder output shape — presumably
      (batch, frames, height, width, channels).
    audio: raw audio for the same clip; assumed to hold
      NUM_FRAMES * AUDIO_SAMPLES_PER_FRAME samples per example to match the
      audio position-encoding resolution below.
    subsampling: dict with index arrays under 'image' and 'audio' selecting
      which output points to decode, plus a 'label' entry.

  Returns:
    The model outputs for the requested points, keyed by modality.
  """
  n_audio_samples = NUM_FRAMES * AUDIO_SAMPLES_PER_FRAME
  input_preprocessor = io_processors.MultimodalPreprocessor(
      min_padding_size=4,
      modalities={
          'audio': io_processors.AudioPreprocessor(
              position_encoding_type='fourier',
              fourier_position_encoding_kwargs=dict(
                  num_bands=192,
                  max_resolution=(n_audio_samples,),
                  sine_only=False,
                  concat_pos=True,
              ),
              n_extra_pos_mlp=0,
              prep_type='patches',
              # Must agree with the AudioPostprocessor and the audio decoder.
              samples_per_patch=SAMPLES_PER_PATCH),
          'image': io_processors.ImagePreprocessor(
              position_encoding_type='fourier',
              fourier_position_encoding_kwargs=dict(
                  num_bands=32,
                  max_resolution=(NUM_FRAMES, IMG_SZ, IMG_SZ),
                  sine_only=False,
                  concat_pos=True,
              ),
              n_extra_pos_mlp=0,
              prep_type='patches',
              spatial_downsample=4,
              temporal_downsample=1),
          'label': io_processors.OneHotPreprocessor(),
      },
      mask_probs={'image': 0.0, 'audio': 0.0, 'label': 1.0},
  )

  output_postprocessor = io_processors.MultimodalPostprocessor(
      modalities={
          'audio': io_processors.AudioPostprocessor(
              samples_per_patch=SAMPLES_PER_PATCH),
          'image': io_processors.ProjectionPostprocessor(
              num_outputs=3),
          'label': io_processors.ClassificationPostprocessor(
              num_classes=NUM_CLASSES),
      })

  encoder = perceiver.PerceiverEncoder(
      num_self_attends_per_block=8,
      # Weights won't be shared if num_blocks is set to 1.
      num_blocks=1,
      z_index_dim=28*28*1,
      num_z_channels=512,
      num_cross_attend_heads=1,
      num_self_attend_heads=8,
      cross_attend_widening_factor=1,
      self_attend_widening_factor=1,
      dropout_prob=0.0,
      z_pos_enc_init_scale=0.02,
      cross_attention_shape_for_attn='kv',
      name='encoder')

  # Number of output points decoded per modality for this call.
  subsampled_index_dims = {
      'audio': subsampling['audio'].shape[0],
      'image': subsampling['image'].shape[0],
      'label': 1,
  }
  image_decoder = perceiver.BasicVideoAutoencodingDecoder(
      # Autoencoding, don't pass inputs to the queries.
      concat_preprocessed_input=False,
      subsampled_index_dims=subsampling['image'],
      output_shape=images.shape[:4],
      num_z_channels=1024,
      output_num_channels=512,
      use_query_residual=False,
      position_encoding_type='fourier',
      fourier_position_encoding_kwargs=dict(
          num_bands=32,
          max_resolution=(NUM_FRAMES, IMG_SZ, IMG_SZ),
          sine_only=False,
          concat_pos=True,
      ),
  )

  decoder = perceiver.MultimodalDecoder(
      # Autoencoding, don't pass inputs to the queries.
      concat_preprocessed_input=False,
      subsampled_index_dims=subsampled_index_dims,
      # Modality specific decoders are used ONLY to generate queries.
      # All modalities are decoded together using a unified decoder.
      modalities={
          'audio': perceiver.BasicDecoder(
              # Autoencoding, don't pass inputs to the queries.
              concat_preprocessed_input=False,
              subsampled_index_dims=subsampling['audio'],
              output_index_dims=(n_audio_samples // SAMPLES_PER_PATCH,),
              num_z_channels=1024,
              output_num_channels=512,
              use_query_residual=False,
              position_encoding_type='fourier',
              fourier_position_encoding_kwargs=dict(
                  num_bands=192,
                  max_resolution=(n_audio_samples,),
                  sine_only=False,
                  concat_pos=True,
              ),
          ),
          'image': image_decoder,
          'label': perceiver.ClassificationDecoder(
              # Autoencoding, don't pass inputs to the queries.
              concat_preprocessed_input=False,
              num_classes=NUM_CLASSES,
              num_z_channels=1024,
              use_query_residual=False,
              position_encoding_type='trainable',
              trainable_position_encoding_kwargs=dict(
                  num_channels=1024,
                  init_scale=0.02,
              ),
          ),
      },
      num_outputs=None,
      output_num_channels=512,
      use_query_residual=False,)

  model = perceiver.Perceiver(
      input_preprocessor=input_preprocessor,
      encoder=encoder,
      decoder=decoder,
      output_postprocessor=output_postprocessor)

  # The label input is a zero placeholder of width NUM_CLASSES.
  return model({'image': images,
                'audio': audio,
                'label': np.zeros((images.shape[0], NUM_CLASSES))},
               is_training=False, subsampled_output_points=subsampling)


# Turn the model function into a Haiku (init, apply) pair that takes explicit
# params, state and rng; it is used below as
# `video_autoencoder.apply(params, state, rng, images, audio, subsampling)`.
# Note: this rebinds `video_autoencoder`, shadowing the raw function above.
video_autoencoder = hk.transform_with_state(video_autoencoder)

In [None]:
#@title Model application


def autoencode_video(params, state, rng, images, audio, nchunks=128):
  """Reconstructs a clip by decoding its output points chunk by chunk.

  The image and audio output points are split into `nchunks` contiguous groups,
  and each group is decoded with a separate `video_autoencoder.apply` call
  (presumably to bound the memory used by the decoder queries).

  Args:
    params: Haiku parameters of `video_autoencoder`.
    state: Haiku state of `video_autoencoder`; threaded through every call.
    rng: JAX PRNG key passed to every call.
    images: video; as used in this notebook, (batch, frames, height, width,
      channels). Every position except the batch and last axes is an output
      point.
    audio: audio of shape (batch, samples, channels); the number of output
      points is samples // SAMPLES_PER_PATCH.
    nchunks: number of chunks to split the output points into. When the point
      counts are not divisible by `nchunks`, the remainder is spread over the
      chunks instead of being dropped.

  Returns:
    Dict with 'label' (the label output of the last chunk), and 'image' and
    'audio' reshaped to the shapes of `images` and `audio`.

  Raises:
    ValueError: if `nchunks` is smaller than 1.
  """
  if nchunks < 1:
    raise ValueError('nchunks must be >= 1, got %r' % (nchunks,))
  # Loop-invariant point counts.
  num_image_points = int(np.prod(images.shape[1:-1]))
  num_audio_points = audio.shape[1] // SAMPLES_PER_PATCH

  def _bounds(total, chunk_idx):
    # [start, end) of chunk `chunk_idx`. Equal to fixed-size chunks of
    # total // nchunks when `total` divides evenly, and covers every point
    # otherwise.
    return total * chunk_idx // nchunks, total * (chunk_idx + 1) // nchunks

  image_parts, audio_parts = [], []
  label = None
  for chunk_idx in range(nchunks):
    image_start, image_end = _bounds(num_image_points, chunk_idx)
    audio_start, audio_end = _bounds(num_audio_points, chunk_idx)
    subsampling = {
        'image': jnp.arange(image_start, image_end),
        'audio': jnp.arange(audio_start, audio_end),
        'label': None,
    }
    output, state = video_autoencoder.apply(
        params, state, rng, images, audio, subsampling)
    label = output['label']
    image_parts.append(output['image'])
    audio_parts.append(output['audio'])

  # Concatenate once instead of re-copying the growing arrays every iteration.
  return {
      'label': label,
      'image': jnp.reshape(jnp.concatenate(image_parts, axis=1), images.shape),
      'audio': jnp.reshape(jnp.concatenate(audio_parts, axis=1), audio.shape),
  }

In [None]:
#@title Load parameters from checkpoint

!wget -O video_autoencoding_checkpoint.pystate https://storage.googleapis.com/perceiver_io/video_autoencoding_checkpoint.pystate

rng = jax.random.PRNGKey(42)
with open("video_autoencoding_checkpoint.pystate", "rb") as f:
  params = pickle.loads(f.read())

state = {}

In [None]:
# Auto-encode the first NUM_FRAMES (16) frames of the video and the matching
# span of the first audio channel. Sizes come from the model constants so the
# inputs always agree with the model's position-encoding resolution.
reconstruction = autoencode_video(
    params, state, rng,
    video[None, :NUM_FRAMES],
    audio[None, :NUM_FRAMES * AUDIO_SAMPLES_PER_FRAME, 0:1])

In [None]:
# Show the reconstructed first clip (batch element 0): video and audio side by
# side.
table([
    to_gif(reconstruction["image"][0]),
    play_audio(np.array(reconstruction["audio"][0])),
])

In [None]:
# Kinetics 700 Labels: print the five most probable classes with their softmax
# probabilities.
probabilities = jax.nn.softmax(reconstruction["label"])
scores, indices = jax.lax.top_k(probabilities, k=5)

for score, index in zip(scores[0], indices[0]):
  print("%s: %s" % (KINETICS_CLASSES[index], score))

In [None]:
# Auto-encode the entire video, one chunk of NUM_FRAMES frames at a time.

# Partition video and audio into NUM_FRAMES-frame chunks, dropping the tail
# that does not fill a whole chunk in either stream.
audio_samples_per_chunk = NUM_FRAMES * AUDIO_SAMPLES_PER_FRAME
# View mono audio of shape (samples,) as (samples, 1); multi-channel audio is
# left unchanged.
audio_2d = np.reshape(audio, (audio.shape[0], -1))
num_chunks = min(video.shape[0] // NUM_FRAMES,
                 audio_2d.shape[0] // audio_samples_per_chunk)
if num_chunks == 0:
  raise ValueError('Need at least %d frames of video with matching audio.'
                   % NUM_FRAMES)
nframes = num_chunks * NUM_FRAMES
# Reshape with NumPy so the chunks stay in host memory; a chunk is moved to the
# accelerator only when it is passed to `encode`.
video_chunks = np.reshape(
    video[:nframes], (num_chunks, NUM_FRAMES) + video.shape[1:])
audio_chunks = np.reshape(
    audio_2d[:num_chunks * audio_samples_per_chunk],
    (num_chunks, audio_samples_per_chunk, audio_2d.shape[1]))

encode = jax.jit(functools.partial(autoencode_video, params, state, rng))

# Logically, what we do is the following code. We write out the loop to allocate
# GPU memory for only one chunk
#
# reconstruction = jax.vmap(encode, in_axes=1, out_axes=1)(
#     video_chunks[None, :], audio_chunks[None, :, :, 0:1])

chunks = []
for i in range(num_chunks):
  reconstruction = encode(video_chunks[None, i], audio_chunks[None, i, :, 0:1])
  # Copy each chunk's outputs back to host memory.
  chunks.append(jax.tree_util.tree_map(np.asarray, reconstruction))

# Stack the per-chunk outputs along a new chunk axis at position 1, giving
# (batch, num_chunks, ...) arrays. (`jax.tree_util.tree_map` accepts several
# trees; the old `jax.tree_multimap` alias no longer exists.)
stacked = jax.tree_util.tree_map(lambda *args: np.stack(args, axis=1), *chunks)


def merge_chunk_axis(x):
  """(batch, num_chunks, chunk_len, ...) -> (batch, num_chunks * chunk_len, ...)."""
  return np.reshape(x, (x.shape[0], -1) + x.shape[3:])


# Merge the chunk axis into the time axis while keeping the batch axis, so that
# element [0] is the whole reconstructed video / audio track.
reconstruction = {
    'image': merge_chunk_axis(stacked['image']),
    'audio': merge_chunk_axis(stacked['audio']),
    # One label prediction per chunk: (batch, num_chunks, num_classes).
    'label': stacked['label'],
}

In [None]:
# Show the reconstruction assembled by the previous cell (batch element 0):
# video and audio side by side.
table([
    to_gif(reconstruction['image'][0]),
    play_audio(np.array(reconstruction['audio'][0])),
])

In [None]:
!pip install jaxline
!pip install optax

In [None]:
!pip install tensorflow-addons 
!pip install tensorflow-probability tensorflow-datasets

In [None]:
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A reference training pipeline for Perceiver/Perceiver IO on ImageNet.

We use the Jaxline (https://github.com/deepmind/jaxline) training framework.
Two sets of hyperparameters are provided, the hyperparameters we used for the
Perceiver IO paper, and scaled-down hyperparameters for local testing.
This script should run out-of-the-box with the local hyperparameters.
The scaled-up hyperparameters require a distributed learning setup to run,
and this script will need to be adapted to your specific setup.
"""

import functools
from typing import Generator, Mapping, Text, Tuple

from absl import app
from absl import flags
from absl import logging
import haiku as hk
import jax
import jax.numpy as jnp
from jaxline import base_config
from jaxline import experiment
from jaxline import platform
from jaxline import utils as jl_utils
from ml_collections import config_dict
import numpy as np
import optax


import io_processors
import perceiver


from train import dataset
from train import utils

# Parsed absl command-line flag values.
FLAGS = flags.FLAGS

# Type of the optimizer state: a tuple of these optax state classes (presumably
# mirroring the optimizer chain built elsewhere in this file — confirm).
OptState = Tuple[optax.TraceState, optax.ScaleByScheduleState, optax.ScaleState]
# Mapping from a name to a JAX array, e.g. scalar metrics.
Scalars = Mapping[Text, jnp.ndarray]


# Size of the training split; used to convert epochs into training steps.
N_TRAIN_EXAMPLES = dataset.Split.TRAIN_AND_VALID.num_examples
# Number of ImageNet classes.
N_CLASSES = 1000
# Only local/debug parameters are supported out of the box.
# To use the scaled-up hyperparameters, please adapt this script to your
# training setup and set this flag to False
IS_LOCAL = True


def get_training_steps(batch_size, n_epochs):
  """Returns how many optimizer steps cover `n_epochs` passes over the data.

  Args:
    batch_size: number of examples consumed per step.
    n_epochs: number of passes over the N_TRAIN_EXAMPLES training examples.

  Returns:
    (N_TRAIN_EXAMPLES * n_epochs) // batch_size, i.e. a trailing partial batch
    is not counted.
  """
  total_examples = N_TRAIN_EXAMPLES * n_epochs
  return total_examples // batch_size


def get_config():
  """Return config object for training.

  Starts from Jaxline's base config. When `IS_LOCAL` is True,
  `_default_or_debug` picks the scaled-down debug values (fewer encoder
  self-attends/blocks, 32x32 images). Otherwise the full-size Perceiver IO
  ImageNet values are used.

  Returns:
    A locked `ConfigDict` with two parts:
      - the Jaxline training-loop settings;
      - `experiment_kwargs.config`, whose `config` key matches the `config`
        argument of `Experiment.__init__`.
  """
  use_debug_settings = IS_LOCAL
  config = base_config.get_base_config()

  # Experiment config.
  local_batch_size = 2
  # Modify this to adapt to your custom distributed learning setup
  num_devices = 1
  # Global (all-device) training batch size.
  config.train_batch_size = local_batch_size * num_devices
  config.n_epochs = 110

  def _default_or_debug(default_value, debug_value):
    # Selects the scaled-down value when running with local/debug settings.
    return debug_value if use_debug_settings else default_value

  n_train_examples = N_TRAIN_EXAMPLES
  num_classes = N_CLASSES

  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              optimizer=dict(
                  base_lr=5e-4,
                  max_norm=10.0,  # < 0 to turn off.
                  schedule_type='constant_cosine',
                  weight_decay=1e-1,
                  decay_pos_embs=True,
                  scale_by_batch=True,
                  # Per-schedule kwargs; presumably only the set matching
                  # `schedule_type` is read by
                  # `utils.get_learning_rate_schedule` — confirm.
                  cosine_decay_kwargs=dict(
                      init_value=0.0,
                      warmup_epochs=0,
                      end_value=0.0,
                  ),
                  step_decay_kwargs=dict(
                      decay_boundaries=[0.5, 0.8, 0.95],
                      decay_rate=0.1,
                  ),
                  constant_cosine_decay_kwargs=dict(
                      constant_fraction=0.5,
                      end_value=0.0,
                  ),
                  optimizer='lamb',
                  # Optimizer-specific kwargs:
                  adam_kwargs=dict(
                      b1=0.9,
                      b2=0.999,
                      eps=1e-8,
                  ),
                  lamb_kwargs=dict(
                      b1=0.9,
                      b2=0.999,
                      eps=1e-6,
                  ),
              ),
              # Don't specify output_channels - it's not used for
              # classifiers.
              model=dict(
                  perceiver_kwargs=dict(
                      input_preprocessor=dict(
                          prep_type='pixels',
                          # Channels for conv/conv1x1 preprocessing:
                          num_channels=64,
                          # -------------------------
                          # Position encoding arguments:
                          # -------------------------
                          position_encoding_type='fourier',
                          concat_or_add_pos='concat',
                          spatial_downsample=1,
                          # If >0, project position to this size:
                          project_pos_dim=-1,
                          trainable_position_encoding_kwargs=dict(
                              # 258 = 2 * (2 * 64) + 2, presumably the Fourier
                              # feature size for 2-D positions with
                              # num_bands=64, sin+cos and concat_pos — confirm.
                              num_channels=258,  # Match default # for Fourier.
                              init_scale=0.02,
                          ),
                          fourier_position_encoding_kwargs=dict(
                              num_bands=64,
                              max_resolution=(224, 224),
                              sine_only=False,
                              concat_pos=True,
                          ),
                      ),
                      encoder=dict(
                          num_self_attends_per_block=_default_or_debug(6, 2),
                          # Weights won't be shared if num_blocks is set to 1.
                          num_blocks=_default_or_debug(8, 2),
                          z_index_dim=512,
                          num_z_channels=1024,
                          num_cross_attend_heads=1,
                          num_self_attend_heads=8,
                          cross_attend_widening_factor=1,
                          self_attend_widening_factor=1,
                          dropout_prob=0.0,
                          # Position encoding for the latent array.
                          z_pos_enc_init_scale=0.02,
                          cross_attention_shape_for_attn='kv',
                          use_query_residual=True,
                          ),
                      decoder=dict(
                          num_z_channels=1024,
                          use_query_residual=True,
                          # Position encoding for the output logits.
                          position_encoding_type='trainable',
                          trainable_position_encoding_kwargs=dict(
                              num_channels=1024,
                              init_scale=0.02,
                          ),
                      ),
                  ),
              ),
              training=dict(
                  images_per_epoch=n_train_examples,
                  label_smoothing=0.1,
                  # One-way references: these follow the top-level values.
                  n_epochs=config.get_oneway_ref('n_epochs'),
                  batch_size=config.get_oneway_ref('train_batch_size')
              ),
              data=dict(
                  num_classes=num_classes,
                  # Run on smaller images to debug.
                  im_dim=_default_or_debug(224, 32),
                  augmentation=dict(
                      # Typical randaug params:
                      # num_layers in [1, 3]
                      # magnitude in [5, 30]
                      # Set randaugment to None to disable.
                      # NOTE(review): num_layers=4 lies outside the typical
                      # [1, 3] range noted above; confirm this is intended.
                      randaugment=dict(
                          num_layers=4,
                          magnitude=5),
                      cutmix=True,
                      # Mixup alpha should be in [0, 1].
                      # Set to None to disable.
                      mixup_alpha=0.2,
                  ),
                  ),
              evaluation=dict(
                  # Parsed with `dataset.Split.from_string`.
                  subset='test',
                  batch_size=2,
              ),
          )
      )
  )

  # Training loop config.
  # `get_oneway_ref` returns ml_collections FieldReferences. So
  # `training_steps` presumably stays in sync if `train_batch_size` or
  # `n_epochs` are overridden later (e.g. from the command line) — confirm.
  config.training_steps = get_training_steps(
      config.get_oneway_ref('train_batch_size'),
      config.get_oneway_ref('n_epochs'))
  config.log_train_data_interval = 60
  config.log_tensors_interval = 60
  config.save_checkpoint_interval = 300
  config.eval_specific_checkpoint_dir = ''
  config.best_model_eval_metric = 'eval_top_1_acc'
  config.checkpoint_dir = '/tmp/perceiver_imagnet_checkpoints'
  config.train_checkpoint_all_hosts = False

  # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
  config.lock()

  return config


class Experiment(experiment.AbstractExperiment):
  """ImageNet classification experiment for Perceiver / Perceiver IO.

  Implements Jaxline's `AbstractExperiment` interface:
    - `step` runs one data-parallel training update. It is pmapped over
      devices, and gradients are summed across devices with `psum`.
    - `evaluate` runs one pass over the configured evaluation split using the
      first replica's parameters.
  """

  # A map from object properties that will be checkpointed to their name
  # in a checkpoint. Currently we assume that these are all sharded
  # device arrays.
  CHECKPOINT_ATTRS = {
      '_params': 'params',
      '_state': 'state',
      '_opt_state': 'opt_state',
  }

  def __init__(self, mode, init_rng, config):
    """Initializes experiment.

    Args:
      mode: Jaxline mode, forwarded to the base class.
      init_rng: PRNG key broadcast to all local devices to initialize
        parameters in `_initialize_train`.
      config: experiment config with `optimizer`, `model`, `training`, `data`
        and `evaluation` sections (see `get_config`).
    """
    super().__init__(mode=mode, init_rng=init_rng)

    self.mode = mode
    self.init_rng = init_rng
    self.config = config

    # Checkpointed experiment state. `_initialize_train` replicates it across
    # local devices, unless it was already restored.
    self._params = None
    self._state = None
    self._opt_state = None

    # Input pipelines. The train input is built lazily on the first `step`.
    # `_eval_input` is unused: `_eval_epoch` builds a fresh eval pipeline.
    self._train_input = None
    self._eval_input = None

    self.forward = hk.transform_with_state(self._forward_fn)

    # NOTE: We "donate" the `params, state, opt_state` arguments which allows
    # JAX (on some backends) to reuse the device memory associated with these
    # inputs to store the outputs of our function (which also start with
    # `params, state, opt_state`).
    self._update_func = jax.pmap(self._update_func, axis_name='i',
                                 donate_argnums=(0, 1, 2))
    self._eval_batch = jax.jit(self._eval_batch)

  def _forward_fn(
      self,
      inputs: dataset.Batch,
      is_training: bool,
  ) -> jnp.ndarray:
    """Builds the Perceiver classifier and applies it to `inputs['images']`.

    Must run inside `hk.transform_with_state` (see `self.forward`).

    Args:
      inputs: batch dict; only its 'images' entry is read.
      is_training: forwarded to the model.

    Returns:
      Classification logits, presumably shaped (batch, num_classes) — confirm
      against `perceiver.ClassificationDecoder`.
    """
    images = inputs['images']

    perceiver_kwargs = self.config.model.perceiver_kwargs
    input_preprocessor = io_processors.ImagePreprocessor(
        **perceiver_kwargs['input_preprocessor'])
    encoder = perceiver.PerceiverEncoder(**perceiver_kwargs['encoder'])
    decoder = perceiver.ClassificationDecoder(
        self.config.data.num_classes,
        **perceiver_kwargs['decoder'])
    model = perceiver.Perceiver(
        encoder=encoder,
        decoder=decoder,
        input_preprocessor=input_preprocessor)

    return model(images, is_training=is_training)

  #  _             _
  # | |_ _ __ __ _(_)_ __
  # | __| '__/ _` | | '_ \
  # | |_| | | (_| | | | | |
  #  \__|_|  \__,_|_|_| |_|
  #

  def step(self, global_step: int, rng: jnp.ndarray,
           *unused_args, **unused_kwargs):
    """Runs one training step and returns the logged scalars.

    On the first call this builds the input pipeline and optimizer, and
    initializes the parameters if they were not restored.
    """
    if self._train_input is None:
      self._initialize_train()

    inputs = next(self._train_input)

    self._params, self._state, self._opt_state, scalars = self._update_func(
        self._params, self._state, self._opt_state, inputs, rng, global_step)

    # Scalars are pmean-ed across devices, so the first replica is enough.
    return jl_utils.get_first(scalars)

  def _initialize_train(self):
    """Builds the train pipeline, LR schedule and optimizer; inits params.

    Parameters and optimizer state are initialized only when they have not
    already been restored (e.g. from a checkpoint).
    """
    self._train_input = jl_utils.py_prefetch(self._build_train_input)

    total_batch_size = self.config.training.batch_size
    # True division: `steps_per_epoch` may be fractional.
    steps_per_epoch = (
        self.config.training.images_per_epoch / self.config.training.batch_size)
    total_steps = self.config.training.n_epochs * steps_per_epoch
    # Learning-rate schedule; `_update_func` also evaluates it for logging.
    self._lr_schedule = utils.get_learning_rate_schedule(
        total_batch_size, steps_per_epoch, total_steps, self.config.optimizer)

    self._optimizer = utils.make_optimizer(
        self.config.optimizer,
        self._lr_schedule)

    # Check we haven't already restored params
    if self._params is None:
      logging.info('Initializing parameters.')

      # One device-sharded batch is consumed to give `init` concrete shapes.
      inputs = next(self._train_input)

      init_net = jax.pmap(lambda *a: self.forward.init(*a, is_training=True))
      init_opt = jax.pmap(self._optimizer.init)

      # Init uses the same RNG key on all hosts+devices to ensure everyone
      # computes the same initial state.
      init_rng = jl_utils.bcast_local_devices(self.init_rng)

      self._params, self._state = init_net(init_rng, inputs)
      self._opt_state = init_opt(self._params)

  def _load_data(self, split, is_training, batch_dims):
    """Wrapper for dataset loading using this experiment's data config."""

    return dataset.load(
        split=split,
        is_training=is_training,
        batch_dims=batch_dims,
        im_dim=self.config.data.im_dim,
        augmentation_settings=self.config.data.augmentation,
        )

  def _build_train_input(self) -> Generator[dataset.Batch, None, None]:
    """Returns the training iterator, batched as [local devices, per-device].

    Raises:
      ValueError: if the global batch size is not divisible by the total
        device count.
    """
    num_devices = jax.device_count()
    global_batch_size = self.config.training.batch_size
    per_device_batch_size, ragged = divmod(global_batch_size, num_devices)

    if ragged:
      raise ValueError(
          f'Global batch size {global_batch_size} must be divisible by '
          f'num devices {num_devices}')

    split = dataset.Split.TRAIN_AND_VALID

    return self._load_data(
        split=split,
        is_training=True,
        batch_dims=[jax.local_device_count(), per_device_batch_size])

  def _one_hot(self, value):
    """One-hot encoding potentially over a sequence of labels."""
    y = jax.nn.one_hot(value, self.config.data.num_classes)
    return y

  def _loss_fn(
      self,
      params: hk.Params,
      state: hk.State,
      inputs: dataset.Batch,
      rng: jnp.ndarray,
  ) -> Tuple[jnp.ndarray, Tuple[Scalars, hk.State]]:
    """Computes the per-device training loss.

    Label pipeline:
      1. Labels are one-hot encoded.
      2. If the batch carries 'mix_labels', they are mixed with those
         (weighted by 'ratio').
      3. The result is label-smoothed.

    Returns:
      `(loss / jax.device_count(), (unscaled loss and accuracy scalars,
      new state))`. Dividing by the device count makes the cross-device
      `psum` of gradients equal the gradient of the mean loss.

    Raises:
      ValueError: if `label_smoothing` is outside [0, 1).
    """
    logits, state = self.forward.apply(
        params, state, rng, inputs, is_training=True)

    label = self._one_hot(inputs['labels'])
    # Handle cutmix/mixup label mixing:
    if 'mix_labels' in inputs:
      logging.info('Using mixup or cutmix!')
      mix_label = self._one_hot(inputs['mix_labels'])
      mix_ratio = inputs['ratio'][:, None]
      label = mix_ratio * label + (1. - mix_ratio) * mix_label

    # Apply label-smoothing to one-hot labels.
    label_smoothing = self.config.training.label_smoothing
    if not 0. <= label_smoothing < 1.:
      raise ValueError(
          f'label_smoothing is {label_smoothing} and should be in [0, 1)')
    if label_smoothing > 0:
      smooth_positives = 1. - label_smoothing
      smooth_negatives = label_smoothing / self.config.data.num_classes
      label = smooth_positives * label + smooth_negatives

    loss_w_batch = utils.softmax_cross_entropy(logits, label)
    loss = jnp.mean(loss_w_batch, dtype=loss_w_batch.dtype)
    scaled_loss = loss / jax.device_count()

    metrics = utils.topk_correct(logits, inputs['labels'], prefix='')
    metrics = jax.tree_util.tree_map(jnp.mean, metrics)

    top_1_acc = metrics['top_1_acc']
    top_5_acc = metrics['top_5_acc']

    loss_scalars = dict(
        loss=loss,
        top_1_acc=top_1_acc,
        top_5_acc=top_5_acc,
    )

    return scaled_loss, (loss_scalars, state)

  def _update_func(
      self,
      params: hk.Params,
      state: hk.State,
      opt_state: OptState,
      inputs: dataset.Batch,
      rng: jnp.ndarray,
      global_step: int,
  ) -> Tuple[hk.Params, hk.State, OptState, Scalars]:
    """Applies an update to parameters and returns new state.

    Runs per device under `jax.pmap` (axis name 'i').
    """
    # This function computes the gradient of the first output of loss_fn and
    # passes through the other arguments unchanged.
    grad_loss_fn = jax.grad(self._loss_fn, has_aux=True)
    scaled_grads, (loss_scalars, state) = grad_loss_fn(
        params, state, inputs, rng)
    grads = jax.lax.psum(scaled_grads, axis_name='i')

    # Grab the learning rate to log before performing the step.
    learning_rate = self._lr_schedule(global_step)

    # Compute and apply updates via our optimizer.
    updates, opt_state = self._optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    # Parameter count for logging. Shapes are static under pmap, so this is
    # computed once at trace time.
    n_params = sum(
        np.prod(leaf.shape) for leaf in jax.tree_util.tree_leaves(params))

    # Scalars to log (note: we log the mean across all hosts/devices).
    scalars = {'learning_rate': learning_rate,
               'n_params (M)': float(n_params/1e6),
               'global_gradient_norm': optax.global_norm(grads)}
    # Prefix each training metric, e.g. 'loss' -> 'train_loss'.
    loss_scalars = {f'train_{k}': v for k, v in loss_scalars.items()}
    scalars.update(loss_scalars)
    scalars = jax.lax.pmean(scalars, axis_name='i')

    return params, state, opt_state, scalars

  #                  _
  #   _____   ____ _| |
  #  / _ \ \ / / _` | |
  # |  __/\ V / (_| | |
  #  \___| \_/ \__,_|_|
  #

  def evaluate(self, global_step, rng, **unused_args):
    """Runs a full evaluation pass and returns host-side mean scalars."""
    global_step = np.array(jl_utils.get_first(global_step))
    scalars = jax.device_get(self._eval_epoch(jl_utils.get_first(rng)))

    logging.info('[Step %d] Eval scalars: %s', global_step, scalars)
    return scalars

  def _eval_batch(
      self,
      params: hk.Params,
      state: hk.State,
      inputs: dataset.Batch,
      rng: jnp.ndarray,
  ) -> Scalars:
    """Evaluates a batch, returning values meant to be summed over batches."""
    logits, _ = self.forward.apply(
        params, state, rng, inputs, is_training=False)

    labels = self._one_hot(inputs['labels'])
    loss = utils.softmax_cross_entropy(logits, labels)

    metrics = utils.topk_correct(logits, inputs['labels'], prefix='')
    metrics = jax.tree_util.tree_map(jnp.mean, metrics)
    top_1_acc = metrics['top_1_acc']
    top_5_acc = metrics['top_5_acc']

    bs = logits.shape[0]

    # Turn per-batch means back into per-batch sums (shape (1,)).
    top_1_acc = jnp.expand_dims(top_1_acc, axis=0) * bs
    top_5_acc = jnp.expand_dims(top_5_acc, axis=0) * bs

    # NOTE: Returned values will be summed and finally divided by num_samples.
    return {
        'eval_loss': loss,
        'eval_top_1_acc': top_1_acc, 'eval_top_5_acc': top_5_acc}

  def _build_eval_input(self) -> Generator[dataset.Batch, None, None]:
    """Returns an iterator over the configured evaluation subset."""
    split = dataset.Split.from_string(self.config.evaluation.subset)

    return self._load_data(
        split=split,
        is_training=False,
        batch_dims=[self.config.evaluation.batch_size])

  def _eval_epoch(self, rng):
    """Evaluates an epoch.

    Sums the per-batch scalars from `_eval_batch` over the evaluation split,
    then divides them by the number of examples.

    Raises:
      ValueError: if the evaluation pipeline yields no batches.
    """
    num_samples = 0.
    summed_scalars = None

    params = jl_utils.get_first(self._params)
    state = jl_utils.get_first(self._state)

    for inputs in self._build_eval_input():
      num_samples += inputs['labels'].shape[0]
      scalars = self._eval_batch(params, state, inputs, rng)

      # Accumulate the sum of scalars for each step. `_eval_batch` returns a
      # flat dict, so a key-wise sum avoids the removed `jax.tree_multimap`.
      scalars = jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
      if summed_scalars is None:
        summed_scalars = scalars
      else:
        summed_scalars = {
            name: summed_scalars[name] + value
            for name, value in scalars.items()
        }

    if summed_scalars is None:
      # Otherwise the division below would silently produce None.
      raise ValueError('Evaluation dataset yielded no batches.')

    mean_scalars = jax.tree_util.tree_map(
        lambda x: x / num_samples, summed_scalars)
    return mean_scalars


if __name__ == '__main__':
  # Jaxline entry point. `--config` is required and names the config module,
  # e.g.:
  #   python perceiver/train/experiment.py \
  #       --config=perceiver/train/experiment.py --logtostderr
  flags.mark_flag_as_required('config')
  main = functools.partial(platform.main, Experiment)
  app.run(main)


In [None]:
# Launch Jaxline training from the shell. The same file is passed as the
# `--config` module. Paths are relative to the current working directory.
! python perceiver/train/experiment.py --config=perceiver/train/experiment.py --logtostderr


In [None]:
import tensorflow as tf
from absl import flags as absl_flags

# Register a dummy absl flag `f`. Jupyter launches the kernel with
# `-f <connection file>`; presumably this flag keeps absl flag parsing (e.g.
# `app.run` inside the notebook) from rejecting that argument.
# The membership check keeps the cell re-runnable: defining the same absl flag
# twice raises DuplicateFlagError.
if 'f' not in absl_flags.FLAGS:
  tf.compat.v1.flags.DEFINE_string('f', '', '')

In [None]:
# Extra dependencies. absl-py provides the `app`/`flags`/`logging` modules
# imported by the training script.
!pip install jupyter-console 
!pip install absl-py


In [None]:
# Make the launcher executable for its owner and run it. `chmod 777` would
# also make it world-writable, which is unnecessary. The script is run via
# `bash`, so the execute bit only matters if it is later invoked directly.
!chmod u+x /content/perceiver/train/launch_local.sh
!bash ./perceiver/train/launch_local.sh

In [None]:
# Go up one directory.
# NOTE(review): the result depends on the working directory left by earlier
# `%cd` cells; an absolute path would make re-runs reproducible.
%cd ../

In [None]:
# Execute the downloaded modules in this kernel's namespace so their
# definitions are available interactively.
%run "/content/perceiver/io_processors.py"
%run  "/content/perceiver/perceiver.py"
%run "/content/perceiver/position_encoding.py"
%run "/content/perceiver/train/dataset.py"

In [None]:
import sys

# Make the downloaded Perceiver sources importable as top-level modules: the
# package root (e.g. `io_processors`) and the `train/` helpers.
# The membership check keeps this cell idempotent, so re-running it does not
# keep appending duplicate entries to sys.path.
for _path in ('/content/perceiver', '/content/perceiver/train'):
  if _path not in sys.path:
    sys.path.append(_path)

import io_processors

In [None]:
# NOTE(review): altair is not imported by any visible cell; confirm it is
# still needed.
!pip install altair

In [None]:
# Launch Jaxline training; experiment.py is also passed as the `--config`
# module. Paths are relative to the current working directory.
!python ./perceiver/train/experiment.py --config=./perceiver/train/experiment.py --logtostderr


In [None]:
# NOTE(review): tensorflow_addons is not imported by any visible cell;
# confirm it is still needed.
!pip install tensorflow-addons

#  Mine: customized ImageNet input pipeline (adapted from `train/dataset.py`)


In [None]:
import enum
from typing import Any, Generator, Mapping, Optional, Sequence, Text, Tuple

import jax
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

from train import autoaugment


# A batch: mapping from feature name (e.g. 'images', 'labels') to array.
Batch = Mapping[Text, np.ndarray]
# Per-channel RGB mean / standard deviation on a 0-255 pixel scale, used by
# `_normalize_image`.
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
# Lets tf.data tune parallelism / prefetch buffer sizes dynamically.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Default side length (pixels) images are resized to. It also fixes the eval
# center-crop ratio INPUT_DIM / (INPUT_DIM + 32).
INPUT_DIM = 224  # The number of pixels in the image resize.


class Split(enum.Enum):
  """ImageNet dataset split.

  TRAIN and VALID partition the TFDS train data: VALID is its first 10000
  examples and TRAIN the rest (see `_shard`). TRAIN_AND_VALID is their union.
  """
  TRAIN = 1
  TRAIN_AND_VALID = 2
  VALID = 3
  TEST = 4

  @classmethod
  def from_string(cls, name: Text) -> 'Split':
    """Parses a case-insensitive split name.

    'validation' is accepted as an alias of VALID.

    Raises:
      KeyError: if `name` is not a known split name.
    """
    member_name = name.upper()
    if member_name == 'VALIDATION':
      member_name = 'VALID'
    return cls[member_name]

  @property
  def num_examples(self):
    """Number of examples in this split."""
    if self is Split.TEST:
      return 50000
    if self is Split.VALID:
      return 10000
    if self is Split.TRAIN:
      return 1271167
    return 1281167  # Split.TRAIN_AND_VALID


def load(
    split: Split,
    *,
    is_training: bool,
    # batch_dims should be:
    # [device_count, per_device_batch_size] or [total_batch_size]
    batch_dims: Sequence[int],
    augmentation_settings: Mapping[str, Any],
    # The shape to which images are resized.
    im_dim: int = INPUT_DIM,
    threadpool_size: int = 48,
    max_intra_op_parallelism: int = 1,
) -> Generator[Batch, None, None]:
  """Loads the given split of the dataset.

  Reads this host's shard of `imagenet2012:5.*.*` via TFDS, with images left
  encoded. It then decodes, crops and augments each image, optionally applies
  mixup/cutmix, and yields numpy batches whose leading dimensions are
  `batch_dims`.

  Args:
    split: which split to read; only this host's shard (see `_shard`) is used.
    is_training: if True, the data is repeated, shuffled and read
      non-deterministically, and images get training-time augmentation.
    batch_dims: leading batch dimensions of every yielded array.
    augmentation_settings: mapping with 'randaugment', 'cutmix' and
      'mixup_alpha' entries.
    im_dim: side length images are resized to.
    threadpool_size: size of the tf.data private threadpool.
    max_intra_op_parallelism: tf.data max intra-op parallelism.

  Yields:
    Dicts of numpy arrays with 'images' and 'labels'. When training with
    mixup/cutmix, the keys are whatever `my_mixup` / `my_cutmix` /
    `my_mixup_cutmix` produce (defined elsewhere).

  Raises:
    ValueError: when not training, if the split size is not divisible by the
      total batch size. This function is a generator, so the error surfaces on
      the first `next()`, not when `load` is called.
  """
  # Each host reads its own contiguous [start, end) slice of the split.
  # NOTE(review): newer JAX releases deprecate jax.host_id/host_count in favor
  # of jax.process_index/process_count; confirm against the installed version.
  start, end = _shard(split, jax.host_id(), jax.host_count())

  im_size = (im_dim, im_dim)

  # Examples per yielded batch on this host.
  total_batch_size = np.prod(batch_dims)

  # `_to_tfds_split` (defined elsewhere) maps our Split to a TFDS split name.
  tfds_split = tfds.core.ReadInstruction(_to_tfds_split(split),
                                         from_=start, to=end, unit='abs')

  # SkipDecoding keeps images as encoded bytes; they are decoded (and cropped)
  # later in `_preprocess_image`.
  ds = tfds.load('imagenet2012:5.*.*', split=tfds_split,
                 decoders={'image': tfds.decode.SkipDecoding()})

  # NOTE(review): newer TF releases rename `experimental_threading` /
  # `experimental_deterministic` to `threading` / `deterministic`.
  options = tf.data.Options()
  options.experimental_threading.private_threadpool_size = threadpool_size
  options.experimental_threading.max_intra_op_parallelism = (
      max_intra_op_parallelism)
  options.experimental_optimization.map_parallelization = True
  if is_training:
    # Trade determinism for throughput during training only.
    options.experimental_deterministic = False
  ds = ds.with_options(options)

  if is_training:
    if jax.host_count() > 1:
      # Only cache if we are reading a subset of the dataset.
      ds = ds.cache()
    ds = ds.repeat()
    ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0)

  else:
    # NOTE(review): this checks the whole split, not this host's shard; with
    # several hosts a shard may still not divide evenly — confirm.
    if split.num_examples % total_batch_size != 0:
      raise ValueError(f'Test/valid must be divisible by {total_batch_size}')

  def crop_augment_preprocess(example):
    # Per-example decode/crop/augment plus, when training, the random
    # ingredients (cutmix mask, mixup ratio) used by the batched mixing below.
    image, _ = _preprocess_image(
        example['image'], is_training, im_size, augmentation_settings)

    label = tf.cast(example['label'], tf.int32)

    out = {'images': image, 'labels': label}

    if is_training:
      if augmentation_settings['cutmix']:
        # `cutmix_padding` is defined elsewhere.
        out['mask'] = cutmix_padding(*im_size)
        out['cutmix_ratio'] = tf.reduce_mean(out['mask'])
      if augmentation_settings['mixup_alpha'] is not None:
        beta = tfp.distributions.Beta(
            augmentation_settings['mixup_alpha'],
            augmentation_settings['mixup_alpha'])
        out['mixup_ratio'] = beta.sample()
    return out

  ds = ds.map(crop_augment_preprocess, num_parallel_calls=AUTOTUNE)

  # Mixup/cutmix by temporarily batching (using the per-device batch size):
  use_cutmix = augmentation_settings['cutmix']
  use_mixup = augmentation_settings['mixup_alpha'] is not None
  if is_training and (use_cutmix or use_mixup):
    inner_batch_size = batch_dims[-1]
    # Apply mixup, cutmix, or mixup + cutmix on batched data.
    # We use data from 2 batches to produce 1 mixed batch.
    # NOTE(review): my_mixup / my_cutmix / my_mixup_cutmix are not defined in
    # this cell; confirm they exist before `load` is first iterated.
    ds = ds.batch(inner_batch_size * 2)
    if not use_cutmix and use_mixup:
      ds = ds.map(my_mixup, num_parallel_calls=AUTOTUNE)
    elif use_cutmix and not use_mixup:
      ds = ds.map(my_cutmix, num_parallel_calls=AUTOTUNE)
    elif use_cutmix and use_mixup:
      ds = ds.map(my_mixup_cutmix, num_parallel_calls=AUTOTUNE)

    # Unbatch for further processing.
    ds = ds.unbatch()

  # Batch the innermost dimension first, so arrays end up shaped
  # batch_dims + per-example shape.
  for batch_size in reversed(batch_dims):
    ds = ds.batch(batch_size)

  ds = ds.prefetch(AUTOTUNE)

  yield from tfds.as_numpy(ds)


def _shard(
    split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]:
  """Returns [start, end) for the given shard index.

  The split is divided into `num_shards` contiguous ranges whose sizes differ
  by at most one, larger ranges first. This is the same partition as
  `np.array_split(np.arange(split.num_examples), num_shards)`, but computed
  arithmetically. It avoids materializing an index array with one entry per
  example (over a million for the training splits).

  Args:
    split: the split to shard.
    shard_index: index of this shard, in [0, num_shards).
    num_shards: total number of shards.

  Returns:
    `(start, end)` example indices into the underlying TFDS split. For
    `Split.TRAIN` the range is offset past the first `Split.VALID.num_examples`
    examples, which form the VALID split.
  """
  assert 0 <= shard_index < num_shards
  base_size, remainder = divmod(split.num_examples, num_shards)
  # The first `remainder` shards each hold one extra example.
  start = shard_index * base_size + min(shard_index, remainder)
  end = start + base_size + (1 if shard_index < remainder else 0)
  if split == Split.TRAIN:
    # Note that our TRAIN=TFDS_TRAIN[10000:] and VALID=TFDS_TRAIN[:10000].
    offset = Split.VALID.num_examples
    start += offset
    end += offset
  return start, end


def _preprocess_image(
    image_bytes: tf.Tensor,
    is_training: bool,
    image_size: Sequence[int],
    augmentation_settings: Mapping[str, Any],
) -> Tuple[tf.Tensor, tf.Tensor]:
  """Decodes, crops, optionally augments, resizes and normalizes one image.

  Args:
    image_bytes: encoded image, or an already decoded one; the crop helpers
      handle both.
    is_training: selects a random crop + horizontal flip (+ RandAugment when
      configured) instead of a center crop.
    image_size: target (height, width) for the bicubic resize.
    augmentation_settings: only read when training. Its 'randaugment' entry
      (None disables it) supplies `num_layers` and `magnitude`.

  Returns:
    `(image, im_shape)`: the normalized float image and the crop shape
    reported by the crop helper.
  """
  if not is_training:
    image, im_shape = _decode_and_center_crop(image_bytes)
  else:
    image, im_shape = _decode_and_random_crop(image_bytes)
    image = tf.image.random_flip_left_right(image)
  assert image.dtype == tf.uint8

  # Training-only RandAugment (https://arxiv.org/abs/1909.13719); it maps
  # uint8 images to uint8 images.
  randaugment = augmentation_settings['randaugment'] if is_training else None
  if randaugment is not None:
    image = autoaugment.distort_image_with_randaugment(
        image,
        num_layers=randaugment['num_layers'],
        magnitude=randaugment['magnitude'])

  # NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without
  # clamping overshoots. This means values returned will be outside the range
  # [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]).
  resized = tf.image.resize(image, image_size, tf.image.ResizeMethod.BICUBIC)
  return _normalize_image(resized), im_shape

##########################################################
# Utilities to open video files using OpenCV (cv2)
def crop_center_square(frame):
  """Returns the largest centered square crop of `frame` (height, width, ...)."""
  height, width = frame.shape[:2]
  side = min(height, width)
  top = height // 2 - side // 2
  left = width // 2 - side // 2
  return frame[top:top + side, left:left + side]

def load_video(path, max_frames=0, resize=(224, 224)):
  """Reads a video file into an array of RGB frames scaled to [0, 1].

  Each frame is center-cropped to a square, resized, and reordered from
  OpenCV's BGR channel order to RGB.

  Args:
    path: path of the video file.
    max_frames: stop after this many frames. 0 (or any value that the frame
      count never equals, such as a negative number) reads the whole video.
    resize: `dsize` passed to `cv2.resize`, i.e. (width, height).

  Returns:
    Float array of frames with values in [0, 1]. It has shape (0,) if no frame
    could be read.
  """
  # cv2 is not among this cell's top-level imports. Import it here so the
  # function does not fail with NameError.
  import cv2  # pylint: disable=g-import-not-at-top

  cap = cv2.VideoCapture(path)
  frames = []
  try:
    while True:
      ret, frame = cap.read()
      if not ret:
        break
      frame = crop_center_square(frame)
      frame = cv2.resize(frame, resize)
      frame = frame[:, :, [2, 1, 0]]  # BGR -> RGB.
      frames.append(frame)

      if len(frames) == max_frames:
        break
  finally:
    # Always release the capture handle, even if decoding raises.
    cap.release()
  return np.array(frames) / 255.0
############################################

def _normalize_image(image: tf.Tensor) -> tf.Tensor:
  """Standardizes each RGB channel with MEAN_RGB / STDDEV_RGB.

  Expects a floating-point image whose last axis holds the 3 RGB channels, on
  the same 0-255 scale as the constants.
  """
  channel_mean = tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)
  channel_std = tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)
  return (image - channel_mean) / channel_std



def _decode_whole_image(image_bytes: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
  """Decodes a full JPEG (no cropping) into a 3-channel image.

  Returns:
    `(image, im_shape)`, where `im_shape` is the int32 shape reported by
    `tf.io.extract_jpeg_shape`.
  """
  decoded = tf.io.decode_jpeg(image_bytes, channels=3)
  jpeg_shape = tf.io.extract_jpeg_shape(image_bytes, output_type=tf.int32)
  return decoded, jpeg_shape


def _center_crop(image, crop_dim):
  """Center crops an image to a `crop_dim` x `crop_dim` square.

  Args:
    image: image tensor laid out as (height, width, channels). Static height
      and width are used when known. Otherwise they are read from
      `tf.shape(image)` at run time, so images with unknown static
      dimensions (e.g. freshly decoded JPEGs) also work. Previously that case
      failed with `None - crop_dim`.
    crop_dim: side of the square crop; must not exceed either image side.

  Returns:
    The centered crop. When the margin is odd, the extra pixel is trimmed
    from the top/left.
  """
  image_height = image.shape[0]
  image_width = image.shape[1]
  if image_height is None or image_width is None:
    dynamic_shape = tf.shape(image)
    if image_height is None:
      image_height = dynamic_shape[0]
    if image_width is None:
      image_width = dynamic_shape[1]
  offset_height = ((image_height - crop_dim) + 1) // 2
  offset_width = ((image_width - crop_dim) + 1) // 2
  return tf.image.crop_to_bounding_box(
      image, offset_height, offset_width, crop_dim, crop_dim)


def _decode_and_center_crop(
    image_bytes: tf.Tensor,
    jpeg_shape: Optional[tf.Tensor] = None,
) -> Tuple[tf.Tensor, tf.Tensor]:
  """Takes a centered square crop of side INPUT_DIM / (INPUT_DIM + 32) * min side.

  With INPUT_DIM = 224 the crop keeps 224/256 of the shorter side. Resizing
  to the final size is left to the caller.

  Args:
    image_bytes: either a JPEG string tensor, which is decoded and cropped in
      one step, or an already decoded image, which is cropped directly.
    jpeg_shape: optional precomputed image shape (height first, then width).
      Extracted from `image_bytes` when omitted.

  Returns:
    `(image, im_shape)` where `im_shape` is `[crop_size, crop_size]`.
  """
  is_encoded = image_bytes.dtype == tf.dtypes.string
  if jpeg_shape is None:
    jpeg_shape = (tf.image.extract_jpeg_shape(image_bytes) if is_encoded
                  else tf.shape(image_bytes))

  height = jpeg_shape[0]
  width = jpeg_shape[1]

  crop_fraction = INPUT_DIM / (INPUT_DIM + 32)
  short_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_size = tf.cast(crop_fraction * short_side, tf.int32)

  top = (height - crop_size + 1) // 2
  left = (width - crop_size + 1) // 2

  if is_encoded:
    image = tf.image.decode_and_crop_jpeg(
        image_bytes, tf.stack([top, left, crop_size, crop_size]), channels=3)
  else:
    image = tf.image.crop_to_bounding_box(
        image_bytes, top, left, crop_size, crop_size)

  return image, tf.stack([crop_size, crop_size])

