In [None]:
import tensorflow as tf
import numpy as np
import sys
# Embedding Model
class Embedder(tf.keras.Model):
  """Keras model that turns a video into a sequence of embeddings.

  Frames are encoded by a truncated, ImageNet-pretrained ResNet50V2 (cut at
  `conv4_block3_out`). Every `num_context_steps` consecutive frames are then
  stacked, passed through two Conv3D + BatchNorm + ReLU stages, max-pooled over
  the context and spatial axes, and projected by two ReLU Dense layers followed
  by a linear embedding layer.

  Args:
    embedding_size: Number of units of the final embedding layer.
    num_context_steps: Number of consecutive frames grouped into one embedding.
      The number of frames per video must be a multiple of this value for the
      reshapes in `call` to keep videos separate.
    dropout_rate: Rate of the dropout applied before the conv stack and before
      each Dense layer. NOTE(review): 0.1 is an assumed default -- the layer
      was referenced in `call` but never defined; confirm against the
      reference implementation.
  """

  def __init__(self, embedding_size,
               num_context_steps, dropout_rate=0.1):
    super().__init__()

    # Will download pre-trained ResNet50V2 here. `pooling` only affects the
    # final output of the full network, which is not used: the model is cut at
    # an intermediate layer below.
    base_model = tf.keras.applications.resnet_v2.ResNet50V2(include_top=False,
                                        weights='imagenet',
                                        pooling='max')
    layer = 'conv4_block3_out'
    self.num_context_steps = num_context_steps
    self.base_model = tf.keras.Model(
        inputs=base_model.input,
        outputs=base_model.get_layer(layer).output)
    self.conv_layers = [tf.keras.layers.Conv3D(256, 3, padding='same') for _ in range(2)]
    self.bn_layers = [tf.keras.layers.BatchNormalization() for _ in range(2)]

    # `call` applies dropout in two places; without this attribute the first
    # forward pass raises AttributeError.
    self.dropout = tf.keras.layers.Dropout(dropout_rate)
    self.fc_layers = [tf.keras.layers.Dense(256,activation=tf.nn.relu) for _ in range(2)]

    self.embedding_layer = tf.keras.layers.Dense(embedding_size)

  def call(self, frames, training):
    """Embeds a batch of videos.

    Args:
      frames: 5-D input unpacked as (batch, num_frames, height, width,
        channels).
      training: Forwarded to the backbone, dropout and batch-norm layers.

    Returns:
      Tensor of shape (batch, num_frames // num_context_steps, embedding_size).
    """
    batch_size, _, h, w, c = frames.shape
    if batch_size is None:
      # Static batch dimension is unknown (symbolic input); use the dynamic one.
      batch_size = tf.shape(frames)[0]

    # Fold the time axis into the batch so the 2-D backbone sees single frames.
    frames = tf.reshape(frames, [-1, h, w, c])
    x = self.base_model(frames, training=training)

    # Regroup consecutive frames into context windows for the 3-D convolutions.
    _, h, w, c = x.shape
    x = tf.reshape(x, [-1, self.num_context_steps, h, w, c])
    x = self.dropout(x, training=training)
    for conv_layer, bn_layer in zip(self.conv_layers, self.bn_layers):
      x = conv_layer(x)
      x = bn_layer(x, training=training)
      x = tf.nn.relu(x)

    # Global max over the context, height and width axes -> (windows, channels).
    x = tf.reduce_max(x, [1, 2, 3])
    _, c = x.shape
    # Restore the per-video layout: (batch, windows_per_video, channels).
    x = tf.reshape(x, [batch_size, -1, c])
    for fc_layer in self.fc_layers:
      x = self.dropout(x, training=training)
      x = fc_layer(x)
    x = self.embedding_layer(x)
    return x

In [None]:
import sys
sys.path.append('/home/c1l1mo/projects/VideoAlignment/model')
from transformer.resnet50.resnet50 import ResNet50
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

# Embedding Model
class torch_Embedder(nn.Module):
  """PyTorch embedder: backbone features -> 3-D conv stack -> per-window embeddings.

  Every `num_context_steps` consecutive frame feature maps are stacked, passed
  through two Conv3d + BatchNorm3d + ReLU stages, globally max-pooled, and
  projected by two ReLU Linear layers followed by a linear embedding layer.

  NOTE(review): the TensorFlow `Embedder` in this notebook also applies dropout
  before the conv stack and before each dense layer; this port has none. The
  two only agree in inference mode.

  Args:
    embedding_size: Output width of the final linear layer.
    num_context_steps: Number of consecutive frames grouped into one embedding.
      The number of frames per video must be a multiple of this value.
  """

  def __init__(self, embedding_size, num_context_steps=5):
    super().__init__()

    # Frame-level feature extractor from the local project package.
    # NOTE(review): assumed to emit 1024-channel feature maps, since the first
    # Conv3d below expects 1024 input channels; what `tcc=True` selects is not
    # visible from this notebook -- confirm in the ResNet50 module.
    self.resnet = ResNet50(tcc=True)
    self.num_context_steps = num_context_steps
    self.conv_layers = nn.ModuleList([nn.Conv3d(1024, 256, kernel_size=3, padding="same"),nn.Conv3d(256, 256, kernel_size=3, padding="same")])
    self.bn_layers = nn.ModuleList([nn.BatchNorm3d(256)
                                      for _ in range(2)])
    self.maxpool = nn.AdaptiveMaxPool3d(1)
    self.fc_layers = nn.ModuleList([nn.Linear(256, 256)
                                      for _ in range(2)])
    self.embedding_layer = nn.Linear(256, embedding_size)

  def forward(self, frames):
    """Embeds a batch of videos.

    Args:
      frames: 5-D tensor whose first axis is the batch; the notebook feeds
        (batch, num_frames, channels, height, width).

    Returns:
      Tensor of shape (batch, windows_per_video, embedding_size).

    Raises:
      ValueError: If `frames` is not 5-dimensional.
    """
    if frames.dim() != 5:
      raise ValueError(
          f"expected a 5-D frames tensor, got shape {tuple(frames.shape)}")
    batch_size = frames.shape[0]

    x = self.resnet(frames)
    # Group consecutive frames into context windows. The trailing
    # (channels, height, width) dims are taken from the backbone output rather
    # than hard-coded to 1024x14x14, so inputs other than 224x224 also work.
    x = x.reshape(-1, self.num_context_steps, *x.shape[-3:])
    # Conv3d wants channels first: (windows, channels, context, height, width).
    x = x.permute(0, 2, 1, 3, 4)

    for conv_layer, bn_layer in zip(self.conv_layers,
                                    self.bn_layers):
      x = conv_layer(x)
      x = bn_layer(x)
      x = F.relu(x)

    # Global max over context/height/width, then restore the per-video layout.
    x = self.maxpool(x)
    x = x.reshape(batch_size, -1, x.shape[1])

    for fc_layer in self.fc_layers:
      x = fc_layer(x)
      x = F.relu(x)

    x = self.embedding_layer(x)

    return x

In [None]:
import torch
# Shared random input clip: 2 videos x 100 frames of 224x224 RGB, channels-last
# for TensorFlow. A fixed seed keeps the clip identical across kernel restarts,
# so the TF and PyTorch outputs can be compared reproducibly.
video = np.random.default_rng(0).random((2, 100, 224, 224, 3), dtype=np.float32)
# Channels-first view (batch, frames, channels, height, width) of the same
# buffer for PyTorch; `from_numpy` shares memory with `video`.
torch_video = torch.from_numpy(video).permute(0, 1, 4, 2, 3)

In [None]:
# Ask TensorFlow which physical GPU devices it can see; the list is displayed
# as the cell output (an empty list means TensorFlow found no GPU).
gpus = tf.config.list_physical_devices(device_type='GPU')
gpus

In [None]:
# Build the TensorFlow embedder and run it in inference mode on the shared clip.
tf_model = Embedder(embedding_size=128, num_context_steps=5)
# Call the model created above; the previous `model` name was never defined.
tf_embs = tf_model(video, training=False)
tf_embs.shape

In [None]:
# Build the PyTorch embedder and run it on the channels-first copy of the clip.
torch_model = torch_Embedder(128)
# Inference mode (batch-norm running stats, no autograd graph), mirroring the
# `training=False` call used for the TensorFlow model.
torch_model.eval()
with torch.no_grad():
  # Feed the torch tensor, not the channels-last NumPy `video`, and call the
  # model created above; the previous `model` name was never defined.
  torch_embs = torch_model(torch_video)
torch_embs.shape