In [2]:
# Enable IPython's autoreload extension in mode 2 for this kernel session.
%load_ext autoreload
%autoreload 2

import os
# Set before TensorFlow is imported below.
# NOTE(review): presumably this silences TensorFlow's native (C++) log output — confirm the level semantics.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models

import sys
# Add the parent directory to the module search path (position 1);
# presumably so that the `configs` package imported next resolves — verify against the repo layout.
sys.path.insert(1, '../')
from configs.config import get_config

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Load the experiment configuration and display it as the cell output.
config = get_config()
config

callback_config:
  checkpoint_filepath: wandb/model_{epoch}
  early_patience: 6
  rlrp_factor: 0.2
  rlrp_patience: 3
  save_best_only: true
  use_earlystopping: true
  use_reduce_lr_on_plateau: false
  viz_num_images: 100
dataset_config:
  batch_size: 8
  num_crops:
  - 2
  - 5
  shuffle_buffer: 100
  size_crops:
  - 224
  - 96
  use_options: true
model_config:
  anchor_tau: 0.01
  attention_probs_dropout_prob: 0.0
  backbone: ViT
  dropout_rate: 0.5
  hidden_act: gelu
  hidden_dropout_prob: 0.0
  hidden_size: 768
  initializer_range: 0.02
  intermediate_size: 3072
  is_encoder_decoder: false
  layer_norm_eps: 1.0e-12
  mask_ratio: 0.75
  model_img_channels: 3
  model_img_height: 224
  model_img_width: 224
  norm_pix_loss: false
  num_attention_heads: 12
  num_channels: 3
  num_hidden_layers: 12
  num_prototypes: 10
  patch_size: 16
  post_gap_dropout: false
  proj_hidden_size: 1024
  proj_output_dim: 256
  proj_use_bn: true
  qkv_bias: true
  target_tau: 0.1
seed: 0
train_config:
  e

In [4]:
import tensorflow as tf
from tensorflow.keras import models

from transformers.models.vit_mae.configuration_vit_mae import ViTMAEConfig

from msn.model.encoder import ViTMAEEmbeddings
from msn.model.encoder import TFViTMAEEncoder
from msn.model.projection import ProjectionHead

In [5]:
def get_vit_mae_configs(args):
    """Build a ``ViTMAEConfig`` from the project's model configuration.

    Args:
        args: Object exposing the model hyper-parameters as attributes
            (the notebook passes ``config.model_config``). The attributes
            read are exactly the ones forwarded below.

    Returns:
        A ``ViTMAEConfig`` carrying the encoder settings together with the
        projection-head settings (``proj_hidden_size``, ``proj_output_dim``,
        ``proj_use_bn``), which are forwarded as extra keyword arguments.
    """
    custom_config = ViTMAEConfig(
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_hidden_layers,
        num_attention_heads=args.num_attention_heads,
        intermediate_size=args.intermediate_size,
        hidden_act=args.hidden_act,
        hidden_dropout_prob=args.hidden_dropout_prob,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        initializer_range=args.initializer_range,
        layer_norm_eps=args.layer_norm_eps,
        is_encoder_decoder=args.is_encoder_decoder,
        # Hugging Face ViT-style patch embeddings index image_size as
        # (height, width); the previous (width, height) order only worked
        # because the configured images are square.
        # NOTE(review): confirm msn.model.encoder.ViTMAEEmbeddings keeps that convention.
        image_size=(args.model_img_height, args.model_img_width),
        patch_size=(args.patch_size, args.patch_size),
        num_channels=args.model_img_channels,
        qkv_bias=args.qkv_bias,
        mask_ratio=args.mask_ratio,
        norm_pix_loss=args.norm_pix_loss,
        # Projection-head settings are not native ViTMAEConfig fields; they
        # are passed through as additional config attributes.
        proj_hidden_size=args.proj_hidden_size,
        proj_output_dim=args.proj_output_dim,
        proj_use_bn=args.proj_use_bn,
    )

    return custom_config

In [6]:
# Convert the model section of the project config into a ViTMAEConfig and display it.
args = config.model_config
custom_config = get_vit_mae_configs(args)
custom_config

ViTMAEConfig {
  "attention_probs_dropout_prob": 0.0,
  "decoder_hidden_size": 512,
  "decoder_intermediate_size": 2048,
  "decoder_num_attention_heads": 16,
  "decoder_num_hidden_layers": 8,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": [
    224,
    224
  ],
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "mask_ratio": 0.75,
  "model_type": "vit_mae",
  "norm_pix_loss": false,
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": [
    16,
    16
  ],
  "proj_hidden_size": 1024,
  "proj_output_dim": 256,
  "proj_use_bn": true,
  "qkv_bias": true,
  "transformers_version": "4.23.1"
}

In [8]:
class TFMAEViTModelWithProjection(tf.keras.Model):
    """ViT-MAE encoder whose [CLS] token is fed through a projection head."""

    def __init__(self, config, **kwargs):
        """Create the sub-layers from the project's model configuration.

        Args:
            config: Model hyper-parameters; converted to a ``ViTMAEConfig``
                via ``get_vit_mae_configs``.
            **kwargs: Forwarded unchanged to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)

        vit_config = get_vit_mae_configs(config)
        self.config = vit_config

        # Sub-layers are created in the same order as they are applied in `call`.
        self.embeddings = ViTMAEEmbeddings(vit_config, name="embeddings")
        self.encoder = TFViTMAEEncoder(vit_config, name="encoder")
        self.layernorm = tf.keras.layers.LayerNormalization(
            name="layernorm", epsilon=vit_config.layer_norm_eps
        )
        self.projection_head = ProjectionHead(vit_config, name="projection_head")

    def call(
        self,
        pixel_values=None,
        noise: tf.Tensor = None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training: bool = False,
    ) -> tf.Tensor:
        """Embed, encode, normalise and project a batch of images.

        Args:
            pixel_values: Input images handed to the embedding layer.
            noise: Optional tensor handed to the embedding layer.
            head_mask: Must be ``None``; any other value is rejected.
            output_attentions: Forwarded to the encoder.
            output_hidden_states: Forwarded to the encoder.
            return_dict: Forwarded to the encoder.
            training: Forwarded to the embedding layer and the encoder.

        Returns:
            The projection head's output for the first token of every
            sequence.

        Raises:
            NotImplementedError: If ``head_mask`` is not ``None``.
        """
        # The embedding layer also returns a mask and restore indices; neither
        # is needed to compute the projection, so they are discarded.
        tokens, _mask, _ids_restore = self.embeddings(
            pixel_values=pixel_values, noise=noise, training=training
        )

        # Head masking is unsupported: reject an explicit mask, otherwise hand
        # the encoder one `None` entry per hidden layer.
        if head_mask is not None:
            raise NotImplementedError
        layer_head_masks = [None] * self.config.num_hidden_layers

        encoded = self.encoder(
            tokens,
            head_mask=layer_head_masks,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # First element of the encoder output is the token sequence.
        normalized = self.layernorm(inputs=encoded[0])

        # Only the [CLS] token (sequence position 0) goes to the projection head.
        cls_token = normalized[:, 0, :]
        return self.projection_head(cls_token)


In [10]:
# Instantiate two separate networks from the same model configuration.
anchor = TFMAEViTModelWithProjection(config.model_config)
target = TFMAEViTModelWithProjection(config.model_config)

In [11]:
from msn.utils import build_and_clone_model

In [12]:
# Rebind both names to whatever build_and_clone_model returns.
# NOTE(review): presumably this builds both networks and copies the anchor's
# weights into the target — confirm against msn.utils.
anchor, target = build_and_clone_model(anchor, target, config)

In [13]:
# Print the per-layer parameter summary of the anchor network.
anchor.summary()

Model: "tfmae_vi_t_model_with_projection_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embeddings (ViTMAEEmbedding  multiple                 742656    
 s)                                                              
                                                                 
 encoder (TFViTMAEEncoder)   multiple                  85054464  
                                                                 
 layernorm (LayerNormalizati  multiple                 1536      
 on)                                                             
                                                                 
 projection_head (Projection  multiple                 2107648   
 Head)                                                           
                                                                 
Total params: 87,906,304
Trainable params: 87,750,912
Non-trainable params: 155,392
______________

In [14]:
# Smoke test: run one random batch through the anchor network.
# The input shape is taken from the loaded config instead of repeating its
# values as literals, so the cell keeps working when the config changes.
smoke_model_cfg = config.model_config
smoke_batch_size = config.dataset_config.batch_size
pixel_values = tf.random.normal(
    (
        smoke_batch_size,
        smoke_model_cfg.model_img_height,
        smoke_model_cfg.model_img_width,
        smoke_model_cfg.model_img_channels,
    )
)
projections = anchor(pixel_values=pixel_values)

# One projection vector per image, with the configured output width.
assert tuple(projections.shape) == (
    smoke_batch_size,
    smoke_model_cfg.proj_output_dim,
), projections.shape
print(projections)

tf.Tensor(
[[-0.08931544 -0.27112478  0.02645697 ...  0.3183299  -0.06450961
  -0.6526876 ]
 [-0.19172618 -0.18916884  0.02150304 ...  0.24803305  0.03818919
  -0.521096  ]
 [-0.12475991 -0.11191332 -0.08389263 ...  0.5131528   0.1558488
  -0.44654652]
 ...
 [-0.16175449 -0.1768989   0.03484544 ...  0.34487978  0.11445557
  -0.5612816 ]
 [-0.1758292  -0.19360761 -0.02152263 ...  0.3325669   0.07130273
  -0.5162564 ]
 [-0.14243533 -0.23717216 -0.05518706 ...  0.40793493 -0.01335798
  -0.55936456]], shape=(8, 256), dtype=float32)
