In [1]:
import functools
import os
from typing import Any, Optional, Callable, Sequence

from absl import logging
import flax.linen as nn
from flax.linen.linear import default_kernel_init
from immutabledict import immutabledict
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
from scenic.common_lib import video_utils
from scenic.model_lib.base_models import base_model
from scenic.model_lib.base_models import classification_model
from scenic.model_lib.base_models import model_utils as base_model_utils
from scenic.model_lib.base_models.classification_model import ClassificationModel
from scenic.model_lib.layers import attention_layers
from scenic.model_lib.layers import nn_layers
from scenic.projects.baselines import vit
from scenic.projects.vivit import model_utils
# Type of a kernel initializer callable: presumably (rng key, shape, dtype)
# -> array, following the Flax initializer convention — TODO confirm.
Initializer = Callable[[jnp.ndarray, Sequence[int], jnp.dtype], jnp.ndarray]
# Maps the axis index of a [batch, time, space, dim] tensor (as produced by
# _reshape_to_time_space) to the suffix used in per-axis layer names, e.g.
# 'LayerNorm_time' / 'MultiHeadDotProductAttention_space'.
_AXIS_TO_NAME = immutabledict({
    1: 'time',
    2: 'space',
})

# Named kernel initializers; Encoder picks 'xavier' for the attention kernels.
KERNEL_INITIALIZERS = immutabledict({
    'zero': nn.initializers.zeros,
    'xavier': nn.initializers.xavier_uniform(),
})
# Basic classification metrics. Each value pairs a metric function with what
# appears to be its normalizer (num_examples). NOTE(review): not referenced
# by any visible cell of this notebook.
ViViT_CLASSIFICATION_METRICS_BASIC = immutabledict({
    'accuracy': (base_model_utils.weighted_correctly_classified,
                 base_model_utils.num_examples),
    'loss': (base_model_utils.weighted_unnormalized_softmax_cross_entropy,
             base_model_utils.num_examples)
})

# The basic metrics plus top-5 accuracy.
ViViT_CLASSIFICATION_METRICS = immutabledict({
    **ViViT_CLASSIFICATION_METRICS_BASIC,
    'accuracy_top_5': (functools.partial(
        base_model_utils.weighted_topk_correctly_classified,
        k=5), base_model_utils.num_examples),
})


def _reshape_to_time_space(x, temporal_dims):
  if x.ndim == 3:
    b, thw, d = x.shape
    assert thw % temporal_dims == 0
    hw = thw // temporal_dims
    x = jnp.reshape(x, [b, temporal_dims, hw, d])
  assert x.ndim == 4
  return x


def embed_2d_patch(x, patches, embedding_dim):
  """Standard ViT method of embedding input patches.

  Args:
    x: Input of shape [n, h, w, c].
    patches: Config exposing `get('size')` and `size` == (fh, fw).
    embedding_dim: If truthy, patches are projected by a strided 'VALID'
      convolution to this many channels. Otherwise each patch is flattened
      into the channel axis.

  Returns:
    [n, h // fh, w // fw, embedding_dim] on the projection path, or
    [n, h // fh, w // fw, fh * fw * c] on the flattening path. Trailing
    rows/columns that do not fill a whole patch are dropped on both paths.
  """

  n, h, w, c = x.shape

  assert patches.get('size') is not None, ('patches.size is now the only way '
                                           'to define the patches')

  fh, fw = patches.size
  gh, gw = h // fh, w // fw

  if embedding_dim:
    x = nn.Conv(
        embedding_dim, (fh, fw),
        strides=(fh, fw),
        padding='VALID',
        name='embedding')(x)
  else:
    # This path often results in excessive padding: b/165788633
    # Crop to a whole number of patches so the reshape below is valid and
    # the result matches the 'VALID' convolution on the other branch.
    x = x[:, :gh * fh, :gw * fw, :]
    x = jnp.reshape(x, [n, gh, fh, gw, fw, c])
    x = jnp.transpose(x, [0, 1, 3, 2, 4, 5])
    x = jnp.reshape(x, [n, gh, gw, -1])

  return x


def embed_3d_patch(x,
                   patches,
                   embedding_dim,
                   kernel_init_method,
                   name='embedding'):
  """Tokenizes a video by projecting non-overlapping 3D tubelets.

  Args:
    x: Video input passed straight to a 3-D `nn.Conv`.
    patches: Config whose `size` is unpacked as (height, width, time).
    embedding_dim: Number of output channels per token (must be truthy).
    kernel_init_method: 'central_frame_initializer' or
      'average_frame_initializer'; any other value selects Flax's default
      kernel initializer.
    name: Name of the convolution layer.

  Returns:
    The convolution output; kernel and stride are (time, height, width).
  """

  assert patches.get('size') is not None, 'patches.size must be defined'
  assert len(patches.size) == 3, 'patches.size must have 3 elements'
  assert embedding_dim, 'embedding_dim must be specified'

  patch_h, patch_w, patch_t = patches.size
  tubelet = (patch_t, patch_h, patch_w)

  # kernel_init_method -> (initializer factory, log message).
  init_choices = {
      'central_frame_initializer': (
          model_utils.central_frame_initializer,
          'Using central frame initializer for input embedding'),
      'average_frame_initializer': (
          model_utils.average_frame_initializer,
          'Using average frame initializer for input embedding'),
  }
  fallback = (lambda: default_kernel_init,
              'Using default initializer for input embedding')
  make_initializer, message = init_choices.get(kernel_init_method, fallback)
  kernel_initializer = make_initializer()
  logging.info(message)

  conv = nn.Conv(
      embedding_dim,
      tubelet,
      strides=tubelet,
      padding='VALID',
      name=name,
      kernel_init=kernel_initializer)
  return conv(x)





In [2]:
def temporal_encode(x,
                    patches,
                    hidden_size,
                    return_1d=True,
                    name='embedding',
                    kernel_init_method='central_frame_initializer'):
  """Encode video for feeding into ViT using 3D (tubelet) patch embedding.

  Args:
    x: Video batch; must be 5-D, e.g. [n, t, h, w, c].
    patches: Patch config forwarded to `embed_3d_patch`.
    hidden_size: Embedding dimension of each token.
    return_1d: If True, flatten the (t, h, w) token grid into a single axis.
    name: Name of the embedding convolution.
    kernel_init_method: Initializer selector forwarded to `embed_3d_patch`.
      Defaults to the central-frame initializer (previously hard-coded).

  Returns:
    (tokens, temporal_dims). With `return_1d`, tokens are
    [n, t' * h' * w', hidden_size]; otherwise the 5-D embedding output.
    `temporal_dims` is t', the size of axis 1 of the embedding output.

  Raises:
    ValueError: If `x` is not 5-D.
  """
  if x.ndim != 5:
    raise ValueError(f'Expected a 5-D video tensor, got shape {x.shape}.')

  x = embed_3d_patch(x, patches, hidden_size, kernel_init_method, name)
  temporal_dims = x.shape[1]
  if return_1d:
    n, t, h, w, c = x.shape
    # jnp (not np) so this stays traceable under jit / inside Flax modules.
    x = jnp.reshape(x, [n, t * h * w, c])

  assert x.size > 0, ('Found zero tokens after temporal encoding. '
                      'Perhaps one of the patch sizes is such that '
                      'floor(dim_size / patch_size) = 0?')

  return x, temporal_dims

In [3]:
class EncoderFactorizedSelfAttentionBlock(nn.Module):
  """Encoder block with factorized self-attention.

  Tokens are viewed as [batch, time, space, dim]. Pre-LayerNorm residual
  self-attention runs along one axis at a time, in `attention_order`, and is
  followed by a residual MLP block.

  Attributes:
    mlp_dim: Dimension of the mlp on top of attention block.
    num_heads: Number of heads.
    temporal_dims: Number of temporal positions in the flattened input. The
      token axis is split into (temporal_dims, num_tokens // temporal_dims).
    attention_kernel_initializer: Initializer to use for attention layers.
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout for attention heads.
    droplayer_p: Probability of dropping a layer. Accepted but not used by
      this block.
    attention_order: The order to do the attention. Choice of {time_space,
      space_time}.
    dtype: the dtype of the computation (default: float32).
  """
  mlp_dim: int
  num_heads: int
  temporal_dims: int
  attention_kernel_initializer: Initializer
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  droplayer_p: Optional[float] = None
  attention_order: str = 'time_space'
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, inputs: jnp.ndarray, *, deterministic: bool = False):
    """Applies factorized attention then the MLP block.

    Args:
      inputs: Tokens of shape [batch, num_tokens, dim]; num_tokens must be
        divisible by `temporal_dims`.
      deterministic: If True, dropout is disabled and no 'dropout' RNG is
        required.

    Returns:
      Array with the same shape as `inputs`.
    """
    b, thw, d = inputs.shape
    inputs = _reshape_to_time_space(inputs, self.temporal_dims)

    if self.attention_order == 'time_space':
      attention_axes = (1, 2)
    elif self.attention_order == 'space_time':
      attention_axes = (2, 1)
    else:
      raise ValueError(f'Invalid attention order {self.attention_order}.')

    def _run_attention_on_axis(inputs, axis, two_d_shape):
      """Reshapes the input and runs residual self-attention on one axis."""
      axis_name = _AXIS_TO_NAME[axis]
      inputs = model_utils.reshape_to_1d_factorized(inputs, axis=axis)
      x = nn.LayerNorm(
          dtype=self.dtype, name='LayerNorm_{}'.format(axis_name))(
              inputs)
      # Called as a submodule (not init/apply) so its weights belong to this
      # block's parameters and can be trained / loaded from a checkpoint.
      x = nn.SelfAttention(
          num_heads=self.num_heads,
          kernel_init=self.attention_kernel_initializer,
          broadcast_dropout=False,
          dropout_rate=self.attention_dropout_rate,
          dtype=self.dtype,
          name='MultiHeadDotProductAttention_{}'.format(axis_name))(
              x, deterministic=deterministic)
      x = nn.Dropout(rate=self.dropout_rate)(x, deterministic)
      x = x + inputs
      return model_utils.reshape_to_2d_factorized(
          x, axis=axis, two_d_shape=two_d_shape)

    x = inputs
    two_d_shape = inputs.shape
    for axis in attention_axes:
      x = _run_attention_on_axis(x, axis, two_d_shape)

    # MLP block.
    x = jnp.reshape(x, [b, thw, d])
    y = nn.LayerNorm(dtype=self.dtype, name='LayerNorm_mlp')(x)
    y = attention_layers.MlpBlock(
        mlp_dim=self.mlp_dim,
        dtype=self.dtype,
        dropout_rate=self.dropout_rate,
        activation_fn=nn.gelu,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        name='MlpBlock')(
            y, deterministic=deterministic)
    return x + y


In [4]:
class Encoder(nn.Module):
  """Transformer Encoder built from factorized self-attention blocks.

  Attributes:
    temporal_dims: Number of temporal positions in the input token sequence.
    mlp_dim: Dimension of the mlp on top of attention block.
    num_layers: Number of layers.
    num_heads: Number of attention heads.
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout for attention heads.
    stochastic_droplayer_rate: Probability of dropping a layer linearly
      grows from 0 to the provided value. The per-layer value is passed to
      each block as `droplayer_p`.
    dtype: dtype of the computation inside the encoder blocks.
    positional_embedding: The type of positional embedding to use. Supported
      values are {learned_1d, sinusoidal_1d, sinusoidal_3d, none}.
    normalise_output: If True, perform layernorm on the output.
  """

  temporal_dims: Optional[int]
  mlp_dim: int
  num_layers: int
  num_heads: int
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  stochastic_droplayer_rate: float = 0.0
  dtype: jnp.dtype = jnp.float32
  positional_embedding: str = 'learned_1d'
  normalise_output: bool = True

  @nn.compact
  def __call__(self, inputs: jnp.ndarray, *, train: bool = True):
    """Applies Transformer model on the inputs.

    Args:
      inputs: Tokens of shape (batch, len, emb).
      train: If True, dropout is active (requires a 'dropout' RNG).

    Returns:
      Encoded tokens with the same shape as `inputs`.
    """
    assert inputs.ndim == 3  # (batch, len, emb)
    dtype = jax.dtypes.canonicalize_dtype(self.dtype)

    if self.positional_embedding == 'learned_1d':
      x = vit.AddPositionEmbs(
          posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
          name='posembed_input')(inputs)
    elif self.positional_embedding == 'sinusoidal_1d':
      x = attention_layers.Add1DPositionEmbedding(
          posemb_init=None)(inputs)
    elif self.positional_embedding == 'sinusoidal_3d':
      batch, num_tokens, hidden_dim = inputs.shape
      height = width = int(np.sqrt(num_tokens // self.temporal_dims))
      if height * width * self.temporal_dims != num_tokens:
        raise ValueError('Input is assumed to be square for sinusoidal init.')
      inputs_reshape = inputs.reshape([batch, self.temporal_dims, height, width,
                                       hidden_dim])
      x = attention_layers.AddFixedSinCosPositionEmbedding()(inputs_reshape)
      x = x.reshape([batch, num_tokens, hidden_dim])
    elif self.positional_embedding == 'none':
      x = inputs
    else:
      raise ValueError(
          f'Unknown positional embedding {self.positional_embedding}')
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

    encoder_block = functools.partial(
        EncoderFactorizedSelfAttentionBlock,
        attention_order='space_time',
        attention_kernel_initializer=KERNEL_INITIALIZERS['xavier'],
        temporal_dims=self.temporal_dims)

    # Input Encoder
    for lyr in range(self.num_layers):
      droplayer_p = (
          lyr / max(self.num_layers - 1, 1)) * self.stochastic_droplayer_rate
      # Called as a submodule (not init/apply with a fixed key) so every
      # layer owns distinct, trainable parameters under 'encoderblock_{lyr}'.
      x = encoder_block(
          mlp_dim=self.mlp_dim,
          num_heads=self.num_heads,
          dropout_rate=self.dropout_rate,
          attention_dropout_rate=self.attention_dropout_rate,
          droplayer_p=droplayer_p,
          name=f'encoderblock_{lyr}',
          dtype=dtype)(
              x, deterministic=not train)

    if self.normalise_output:
      encoded = nn.LayerNorm(name='encoder_norm')(x)
    else:
      encoded = x

    return encoded

In [5]:
class ViViT(nn.Module):
  """Vision Transformer model for Video (feature extractor, no logits head).

  Attributes:
    mlp_dim: Dimension of the mlp on top of attention block.
    num_layers: Number of layers.
    num_heads: Number of self-attention heads.
    patches: Configuration of the patches extracted in the stem of the model.
    hidden_size: Size of the hidden state of the output of model's stem.
    representation_size: Size of the representation layer in the model's head.
      if None, we skip the extra projection + tanh activation at the end.
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout for attention heads.
    stochastic_droplayer_rate: Probability of dropping a layer. Linearly
      increases from 0 to the provided value.
    classifier: How tokens are pooled into one vector. 'gap', 'gmp' and 'gsp'
      take the mean / max / sum over tokens. 'token' (or '0') takes the first
      token; no class token is prepended.
    dtype: JAX data type for activations.
  """

  mlp_dim: int
  num_layers: int
  num_heads: int
  patches: ml_collections.ConfigDict
  hidden_size: int
  representation_size: Optional[int] = None
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  stochastic_droplayer_rate: float = 0.
  classifier: str = 'gap'
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, x, *, train: bool = True, debug: bool = False):
    """Encodes a video batch into pooled features.

    Args:
      x: Video batch; must be 5-D (checked by `temporal_encode`).
      train: If True, dropout is active (requires a 'dropout' RNG).
      debug: Unused; kept for interface compatibility.

    Returns:
      [batch, hidden_size] features, or [batch, representation_size] when
      `representation_size` is set.

    Raises:
      ValueError: For an unknown `classifier`.
    """
    del debug  # Unused.

    x, temporal_dims = temporal_encode(x, self.patches, self.hidden_size)

    # Called as a submodule (not init/apply with fixed keys) so the encoder
    # weights are part of this model's variables and can be trained/loaded.
    x = Encoder(
        temporal_dims=temporal_dims,
        mlp_dim=self.mlp_dim,
        num_layers=self.num_layers,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        stochastic_droplayer_rate=self.stochastic_droplayer_rate,
        dtype=self.dtype,
        name='Transformer')(
            x, train=train)

    if self.classifier in ('token', '0'):
      # NOTE(review): no class token is prepended. The factorized encoder
      # blocks require num_tokens % temporal_dims == 0, which an extra token
      # would break, so this is simply the first patch token.
      x = x[:, 0]
    elif self.classifier in ('gap', 'gmp', 'gsp'):
      pool_fn = {'gap': jnp.mean, 'gmp': jnp.max, 'gsp': jnp.sum}[
          self.classifier]
      x = pool_fn(x, axis=tuple(range(1, x.ndim - 1)))
    else:
      raise ValueError(f'Unknown classifier {self.classifier}.')

    # Apply the head to the pooled encoder output (not the raw input video).
    if self.representation_size is not None:
      x = nn.Dense(self.representation_size, name='pre_logits')(x)
      x = nn.tanh(x)
    else:
      x = nn_layers.IdentityLayer(name='pre_logits')(x)

    return x
    


In [11]:
import cv2
import numpy as np

def _resize_shorter_side_and_center_crop(frame, min_resize, crop_size):
    """Resize so the shorter side equals `min_resize` (aspect kept), then center-crop a square."""
    height, width = frame.shape[:2]
    scale = min_resize / min(height, width)
    # max() guards against float rounding leaving a side one pixel short.
    new_width = max(min_resize, int(round(width * scale)))
    new_height = max(min_resize, int(round(height * scale)))
    resized = cv2.resize(frame, (new_width, new_height))  # cv2 takes (width, height)
    top = (new_height - crop_size) // 2
    left = (new_width - crop_size) // 2
    return resized[top:top + crop_size, left:left + crop_size]


def preprocess_video(video_path, output_fps, min_resize=256, crop_size=224, zero_centering=True):
    """Decode a video, subsample it to about `output_fps`, then resize and crop.

    Frames are read with OpenCV and converted from BGR to RGB. Each is
    resized so its shorter side equals `min_resize` (aspect ratio preserved)
    and center-cropped to `crop_size` x `crop_size`.

    Args:
        video_path: Path to a video file readable by `cv2.VideoCapture`.
        output_fps: Target sampling rate. One frame is kept out of every
            round(source_fps / output_fps) frames (at least 1), so a source
            slower than `output_fps` keeps every frame.
        min_resize: Length of the shorter side after resizing.
        crop_size: Side of the square center crop; must be <= `min_resize`.
        zero_centering: If True, map pixels from [0, 255] to [-1, 1] as
            float32. Otherwise keep the decoded frame values and dtype.

    Returns:
        Array of shape (1, num_frames, crop_size, crop_size, 3).

    Raises:
        ValueError: If `output_fps <= 0` or `crop_size > min_resize`.
        IOError: If the video cannot be opened.
    """
    if output_fps <= 0:
        raise ValueError(f'output_fps must be positive, got {output_fps}')
    if crop_size > min_resize:
        raise ValueError(
            f'crop_size ({crop_size}) must not exceed min_resize ({min_resize})')

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f'Could not open video: {video_path}')

    processed_frames = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Keep one frame out of every `interval`; some containers report 0 fps,
        # in which case every frame is kept.
        interval = max(1, int(round(fps / output_fps)))

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Process immediately instead of buffering full-resolution frames.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = _resize_shorter_side_and_center_crop(frame, min_resize, crop_size)
            if zero_centering:
                # [0, 255] -> [-1, 1]; float32 matches JAX's default dtype.
                frame = (frame.astype(np.float32) / 255.0 - 0.5) / 0.5
            processed_frames.append(frame)

            for _ in range(interval - 1):
                if not cap.grab():  # Skip frames without decoding them.
                    break
    finally:
        cap.release()

    # reshape (rather than expand_dims) also yields the right shape when no
    # frames were decoded.
    num_frames = len(processed_frames)
    return np.array(processed_frames).reshape(1, num_frames, crop_size, crop_size, 3)

import jax.numpy as jnp

# Example usage: decode one clip sampled at `output_fps` and move it to JAX.
# NOTE(review): absolute local path; consider a configurable data directory.
video_path = r"E:\train\1-1004\A.Beautiful.Mind.2001__#00-01-45_00-02-50_label_A.mp4"
output_fps = 16

# Fail fast with a clear error if the clip is missing.
if not os.path.isfile(video_path):
    raise FileNotFoundError(f"Video not found: {video_path}")

processed_frames_tensor = preprocess_video(video_path, output_fps)
print(f"Processed frames tensor shape: {processed_frames_tensor.shape}")

# Convert NumPy array to JAX NumPy array
jax_array = jnp.asarray(processed_frames_tensor)



Processed frames tensor shape: (1, 781, 224, 224, 3)


In [7]:
import ml_collections
import haiku as hk
import jax



# ViT-Base-sized encoder over 16x16 spatial, 2-frame tubelets.
ipatches = ml_collections.ConfigDict()
ipatches.size = (16, 16, 2)  # unpacked by embed_3d_patch as (height, width, time)

model = ViViT(
    mlp_dim=3072,
    num_layers=12,
    num_heads=12,
    representation_size=None,
    patches=ipatches,
    hidden_size=768,
    classifier='token',
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    stochastic_droplayer_rate=0.0,
    dtype=jnp.float32,
)

init_rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
variables = model.init(init_rngs, jax_array)

# Feature extraction: disable dropout with train=False. A 'dropout' RNG is
# still supplied; it is harmless when unused.
foutput = model.apply(
    variables, jax_array, train=False, rngs={'dropout': jax.random.key(2)})


In [12]:
import numpy as np

# Save features next to the source clip. The name is derived from
# `video_path` so the .npy always matches the video that was encoded.
np.save(os.path.splitext(video_path)[0] + '.npy', np.asarray(foutput))