In [161]:
from src.datasets.imagenet1k import make_imagenet1k
import yaml
import torch
import pprint
from src.transforms import make_transforms
from src.masks.multiblock import MaskCollator as MBMaskCollator
from src.utils.distributed import (
    init_distributed,
    AllReduce
)

from src.helper import (
    load_checkpoint,
    init_model,
    init_opt)

In [162]:
# Path of the YAML experiment config used by the rest of the notebook.
# Alternative config that was previously assigned (and immediately
# overwritten) here: './configs/in100_vitt_ep1.yaml'
fname = './configs/dev.yaml'

In [163]:
# Parse the YAML config selected by `fname` into the nested dict `args`.
# NOTE(review): yaml.FullLoader accepts more than plain scalars/lists/dicts;
# if the configs only hold plain data, yaml.safe_load would be the stricter
# choice -- confirm before switching.
with open(fname, 'r') as config_stream:
    args = yaml.load(config_stream, Loader=yaml.FullLoader)

# Echo the parsed configuration once the file handle has been released.
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(args)

{   'data': {   'batch_size': 64,
                'color_jitter_strength': 0.0,
                'crop_scale': [0.3, 1.0],
                'crop_size': 224,
                'image_folder': 'data/imagenet_100/',
                'num_workers': 10,
                'pin_mem': True,
                'root_path': '/localscratch/hsun409/',
                'use_color_distortion': False,
                'use_gaussian_blur': False,
                'use_horizontal_flip': False},
    'logging': {   'folder': '/localscratch/hsun409/logs/ijepa/test/',
                   'write_tag': 'jepa'},
    'mask': {   'allow_overlap': False,
                'aspect_ratio': [0.75, 1.5],
                'enc_mask_scale': [0.85, 1.0],
                'min_keep': 10,
                'num_enc_masks': 1,
                'num_pred_masks': 4,
                'patch_size': 14,
                'pred_mask_scale': [0.15, 0.2]},
    'meta': {   'copy_data': False,
                'load_checkpoint': False,
                'mo

In [164]:
# Flatten the nested config dict `args` into module-level names, one section
# at a time. Sections are looked up in the same order as their first use.

# -- META
meta_cfg = args['meta']
use_bfloat16 = meta_cfg['use_bfloat16']
model_name = meta_cfg['model_name']
load_model = meta_cfg['load_checkpoint']
r_file = meta_cfg['read_checkpoint']
copy_data = meta_cfg['copy_data']
pred_depth = meta_cfg['pred_depth']
pred_emb_dim = meta_cfg['pred_emb_dim']


# -- DATA
data_cfg = args['data']
# augmentation switches
use_gaussian_blur = data_cfg['use_gaussian_blur']
use_horizontal_flip = data_cfg['use_horizontal_flip']
use_color_distortion = data_cfg['use_color_distortion']
color_jitter = data_cfg['color_jitter_strength']
# loader and crop settings
batch_size = data_cfg['batch_size']
pin_mem = data_cfg['pin_mem']
num_workers = data_cfg['num_workers']
root_path = data_cfg['root_path']
image_folder = data_cfg['image_folder']
crop_size = data_cfg['crop_size']
crop_scale = data_cfg['crop_scale']

# -- MASK
mask_cfg = args['mask']
# whether to allow overlap b/w context and target blocks
allow_overlap = mask_cfg['allow_overlap']
# patch-size for model training
patch_size = mask_cfg['patch_size']
# number of context blocks
num_enc_masks = mask_cfg['num_enc_masks']
# min number of patches in context block
min_keep = mask_cfg['min_keep']
# scale of context blocks
enc_mask_scale = mask_cfg['enc_mask_scale']
# number of target blocks
num_pred_masks = mask_cfg['num_pred_masks']
# scale of target blocks
pred_mask_scale = mask_cfg['pred_mask_scale']
# aspect ratio of target blocks
aspect_ratio = mask_cfg['aspect_ratio']

# -- OPTIMIZATION
opt_cfg = args['optimization']
ema = opt_cfg['ema']
# scheduler scale factor (def: 1.0)
ipe_scale = opt_cfg['ipe_scale']
# weight-decay values are coerced to float explicitly
wd = float(opt_cfg['weight_decay'])
final_wd = float(opt_cfg['final_weight_decay'])
num_epochs = opt_cfg['epochs']
warmup = opt_cfg['warmup']
start_lr = opt_cfg['start_lr']
lr = opt_cfg['lr']
final_lr = opt_cfg['final_lr']

# -- LOGGING
log_cfg = args['logging']
folder = log_cfg['folder']
tag = log_cfg['write_tag']

In [5]:
# Device selection. The CUDA auto-detection below is commented out, so the
# notebook currently places every tensor on the CPU.
# if not torch.cuda.is_available():
#     device = torch.device('cpu')
# else:
#     device = torch.device('cuda:0')
#     torch.cuda.set_device(device)

# Hard-coded CPU device (a plain string, passed to `.to(device)` further down).
device = 'cpu'

In [6]:
# Image augmentation pipeline, configured from the DATA settings unpacked above.
transform = make_transforms(
    crop_size=crop_size,
    crop_scale=crop_scale,
    gaussian_blur=use_gaussian_blur,
    horizontal_flip=use_horizontal_flip,
    color_distortion=use_color_distortion,
    color_jitter=color_jitter)


# Mask collator, configured from the MASK settings; it is handed to the loader
# below as its `collator`. Batches coming out of that loader are unpacked in
# the loop as (udata, masks_enc, masks_pred).
mask_collator = MBMaskCollator(
    input_size=crop_size,
    patch_size=patch_size,
    pred_mask_scale=pred_mask_scale,
    enc_mask_scale=enc_mask_scale,
    aspect_ratio=aspect_ratio,
    nenc=num_enc_masks,
    npred=num_pred_masks,
    allow_overlap=allow_overlap,
    min_keep=min_keep)

# The returned world size and rank are forwarded to the loader factory below.
world_size, rank = init_distributed()

# Build the training data loader; the first returned value is not used here.
_, unsupervised_loader, unsupervised_sampler = make_imagenet1k(
        transform=transform,
        batch_size=batch_size,
        collator=mask_collator,
        pin_mem=pin_mem,
        training=True,
        num_workers=num_workers,
        world_size=world_size,
        rank=rank,
        root_path=root_path,
        image_folder=image_folder,
        copy_data=copy_data,
        drop_last=True)

# Pull a single batch for interactive inspection: the loop body runs once and
# then breaks. If the loader yields nothing, none of the loop names are bound.
for itr, (udata, masks_enc, masks_pred) in enumerate(unsupervised_loader):

    def load_imgs():
        # -- unsupervised imgs
        # Move the image tensor and every mask tensor onto `device`.
        imgs = udata[0].to(device, non_blocking=True)
        masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
        masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]
        return (imgs, masks_1, masks_2)
    # NOTE: rebinds masks_enc / masks_pred to the lists of device-side tensors.
    imgs, masks_enc, masks_pred = load_imgs()
    break

INFO:root:making imagenet data transforms
INFO:root:SLURM vars not set (distributed training not available)
INFO:root:data-path /localscratch/hsun409/data/imagenet_100/train/
INFO:root:Initialized ImageNet
INFO:root:ImageNet dataset created
INFO:root:ImageNet unsupervised data loader created


In [168]:
# Shape of the first element of the fetched batch (the tensor that was moved
# to `device` as `imgs` above).
udata[0].shape

torch.Size([64, 3, 224, 224])

In [167]:
# Shape of the second element of the fetched batch -- one entry per sample
# (presumably the class labels; confirm against the dataset implementation).
udata[1].shape

torch.Size([64])

In [7]:
import math
from functools import partial
import numpy as np

import torch
import torch.nn as nn

from src.utils.tensors import (
    trunc_normal_,
    repeat_interleave_batch
)
from src.masks.utils import apply_masks


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 2-D sine/cosine position table for a square grid.

    embed_dim: width of every position vector
    grid_size: int, number of cells along each side of the grid
    cls_token: when True, an all-zero row is prepended to the table
    return:
    pos_embed: [grid_size*grid_size, embed_dim], or
               [1+grid_size*grid_size, embed_dim] when cls_token is True
    """
    axis = np.arange(grid_size, dtype=float)
    # meshgrid is called with the width axis first, so mesh[0] holds the
    # column coordinate and mesh[1] the row coordinate of every cell.
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a pair of coordinate grids into one sine/cosine table.

    embed_dim: output width; must be even because each of the two grids
               receives embed_dim // 2 channels
    grid: indexable with 0 and 1; grid[0] fills the first half of the
          channels and grid[1] the second half
    return: array of shape (H*W, embed_dim)
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # One (H*W, D/2) table per coordinate grid, joined along the channel axis.
    halves = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 1-D sine/cosine position table for positions 0..grid_size-1.

    embed_dim: width of every position vector
    grid_size: int, number of positions
    cls_token: when True, an all-zero row is prepended to the table
    return:
    pos_embed: [grid_size, embed_dim], or [1+grid_size, embed_dim] when
               cls_token is True
    """
    positions = np.arange(grid_size, dtype=float)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine encoding of a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions; it is flattened to size (M,)
    out: (M, D) -- the first D/2 columns are sines, the last D/2 cosines
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Frequencies fall geometrically from 1 towards 1/10000.
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.)
    inv_freq = 1. / 10000**exponents   # (D/2,)

    # Outer product of positions and frequencies: (M, D/2)
    angles = np.einsum('m,d->md', pos.reshape(-1), inv_freq)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Stochastic depth: zero out whole samples of `x` at random.

    x: tensor whose first dimension indexes the samples
    drop_prob: probability of dropping a sample; 0 or None disables dropping
    training: dropping is only applied when this is True
    return: `x` itself when dropping is disabled, otherwise a new tensor in
            which each sample is either zeroed or scaled by 1 / (1 - drop_prob)
    """
    # `not drop_prob` covers 0.0 as well as None (the default of DropPath),
    # which would otherwise fail on the subtraction below.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0:
        # Everything is dropped; dividing by keep_prob would produce NaNs.
        return torch.zeros_like(x)
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Module wrapper around `drop_path` (stochastic depth per sample).

    The drop probability is stored as given and handed to `drop_path`
    together with the module's current training flag.
    """
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)


class MLP(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    hidden_features and out_features fall back to in_features when they are
    not given. The same dropout module is applied after both linear layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        output_width = out_features or in_features
        hidden_width = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention weights.

    forward takes x of shape (B, N, C) and returns a tuple
    (output of shape (B, N, C), attention of shape (B, num_heads, N, N)).
    The logits are scaled by qk_scale when given, otherwise by
    (dim // num_heads) ** -0.5.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        # A single projection produces queries, keys and values at once.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, per_head), then split into q / k / v.
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        merged = (attn @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(merged))
        return out, attn


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual path.

    forward(x) returns the transformed tokens; with return_attention=True it
    returns the attention weights of this block instead and skips the rest.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

    def forward(self, x, return_attention=False):
        attn_out, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            return attn_weights
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A convolution whose kernel size and stride both equal patch_size turns a
    (B, C, H, W) image batch into (B, num_patches, embed_dim) tokens.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # Unpacking four values keeps the requirement of a 4-D input.
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)
        return feature_map.flatten(2).transpose(1, 2)


class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models

    channels: output channels of each stem convolution, in order
    strides: stride of each stem convolution, in order
    img_size: side length of the (square) input, either an int or a
              list/tuple whose first element is the side length
    in_chans: number of input image channels
    batch_norm: insert BatchNorm2d after every 3x3 convolution (the 3x3
                convolutions then carry no bias)

    forward maps (B, C, H, W) images to (B, num_patches, channels[-1]) tokens.
    """

    def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + list(channels)
        for i in range(len(channels) - 2):
            stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
                               stride=strides[i], padding=1, bias=(not batch_norm))]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i+1])]
            stem += [nn.ReLU(inplace=True)]
        # Final 1x1 convolution projects to the last channel width.
        stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
        self.stem = nn.Sequential(*stem)

        # Compute the number of patches. The default img_size is a plain int,
        # which cannot be indexed, so accept both an int and a sequence.
        stride_prod = int(np.prod(strides))
        side = img_size[0] if isinstance(img_size, (list, tuple)) else img_size
        self.num_patches = (side // stride_prod)**2

    def forward(self, x):
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)



class VisionTransformer(nn.Module):
    """ Vision Transformer

    Encoder with fixed 2-D sine/cosine position embeddings (no class token),
    an optional "position drop" mode in which a fraction of the kept patches
    receive a learned token instead of their position embedding, and a
    light-weight decoder that predicts, per token, a distribution over patch
    positions.

    predictor_embed_dim, predictor_depth and any extra **kwargs are accepted
    for signature compatibility but are not read by this class.
    """
    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        decoder_embed_dim=256,
        decoder_num_heads=2,
        decoder_depth=2,
        **kwargs
    ):
        super().__init__()

        # ---------------------------------------------------------------------- #
        # Encoder settings
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Set up the stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Set up the encoder blocks
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                  attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Patch settings
        self.patch_embed = PatchEmbed(img_size=img_size[0],
                                      patch_size=patch_size,
                                      in_chans=in_chans,
                                      embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Position settings: a frozen sin-cos table with one row per patch
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Mask token settings: learned stand-in for a dropped position embedding
        self.mask_pos_token = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Decoder settings (a light weight decoder just for position prediction)
        # Require additional parameters:
        # - decoder_embed_dim
        # - decoder_num_heads
        # - decoder_depth
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        # The decoder reuses the encoder's drop-path schedule. The index is
        # clamped so a decoder deeper than the encoder does not run off the
        # end of `dpr`; an empty schedule (depth == 0) falls back to 0.
        decoder_dpr = [dpr[min(i, len(dpr) - 1)] if dpr else 0.0
                       for i in range(decoder_depth)]
        self.decoder_blocks = nn.ModuleList([
            Block(dim=decoder_embed_dim, num_heads=decoder_num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=decoder_dpr[i],
                  norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, num_patches, bias=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Weight Initialization
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()
        # ----------------------------------------------------------------------

    def fix_init_weight(self):
        """Rescale the residual-branch output weights of every encoder block by
        1 / sqrt(2 * layer_index) and initialise the mask position token."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        # Also initialize the mask_pos_token
        torch.nn.init.normal_(self.mask_pos_token, std=.02)

    def _init_weights(self, m):
        """Per-module initialiser used with `self.apply`."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def apply_pos_drop_mask(self, x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
        """
        Keep the patches selected by `mask`; a random fraction of them gets
        `mask_pos_token` added instead of its position embedding.

        x: patch tokens before any position information, shape (B, N, D)
        pos_embed: position table, repeated over the batch inside
        mask_pos_token: learned token added to the position-dropped patches
        mask: (B, N_m) indices of the patches to keep
        pos_drop_ratio: fraction of the N_m kept patches whose position
                        embedding is dropped
        return: (x, pos_bool, dropped_patch_ids)
            x: (B, N_m, D) tokens, in the same order as `mask`
            pos_bool: (B, N_m) bool, True where the position was dropped;
                      select those tokens with
                      x[pos_bool.unsqueeze(-1).expand(-1, -1, D)].reshape(B, -1, D)
            dropped_patch_ids: (B, num_dropped) original patch indices of the
                      position-dropped tokens, ordered like that selection so
                      they can be used directly as labels

        NOTE(review): assumes apply_masks gathers tokens along dim 1 in the
        order given by the index tensor -- confirm in src.masks.utils.
        """
        B, _, D = x.shape
        device = x.device

        # Number of kept patches, and how many of them lose their position
        N_m = mask.size(1)
        num_pos_to_drop = min(N_m, int(N_m * pos_drop_ratio))

        # Random permutation of the kept patches, independently per sample
        random_tensor = torch.rand(B, N_m, device=device)
        shuffled_indices = random_tensor.argsort(dim=1)
        shuffled_mask = mask.gather(1, shuffled_indices)

        # The first num_pos_to_drop shuffled patches lose their position
        mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
        mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

        # Apply the masks to x
        x_no_pos = apply_masks(x, [mask_no_pos])
        x_keep_pos = apply_masks(x, [mask_keep_pos])

        # Apply pos_embed and mask_pos_token accordingly
        mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1).to(device)
        x_no_pos = x_no_pos + mask_pos_tokens

        pos_embed = pos_embed.repeat(B, 1, 1).to(device)
        pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
        x_keep_pos = x_keep_pos + pos_embed_masked

        # Concatenate the results and undo the shuffle to restore mask order
        x = torch.cat([x_no_pos, x_keep_pos], dim=1)
        restored_indices = torch.argsort(shuffled_indices, dim=1)
        x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

        # Boolean marker of the dropped positions, built in shuffled order and
        # then restored to the same order as x
        shuffled_pos_drop_mask = torch.zeros(B, N_m, dtype=torch.bool, device=device)
        shuffled_pos_drop_mask[:, :num_pos_to_drop] = True
        pos_bool = shuffled_pos_drop_mask.gather(1, restored_indices)

        # Labels must follow the order in which x[pos_bool] yields tokens
        # (restored order). `mask_no_pos` holds the same indices but in
        # shuffled order, which would pair tokens with the wrong labels.
        dropped_patch_ids = mask[pos_bool].reshape(B, num_pos_to_drop)

        return x, pos_bool, dropped_patch_ids

    def forward_decoder(self, x):
        """Map encoder tokens to per-token logits over the patch positions."""
        x = self.decoder_embed(x)
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)  # from decoder_embed_dim to num_patches

        return x

    def forward(self, x, masks=None, pos_drop_ratio=0, use_decoder=False):
        """
        x: image batch fed to the patch embedding
        masks: index tensor or list of index tensors selecting the patches to
               keep; exactly one mask is required when pos_drop_ratio is set
        pos_drop_ratio: fraction of kept patches whose position embedding is
                        replaced by the mask position token (0 disables)
        use_decoder: also run the position decoder (needs pos_drop_ratio)
        return: encoder tokens, or (tokens, logits, pos_bool, dropped_patch_ids)
                when use_decoder is True
        """
        if masks is not None and not isinstance(masks, list):
            masks = [masks]

        # -- patchify x
        x = self.patch_embed(x)

        # -- position table matching the token grid of x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)

        if not pos_drop_ratio:
            # When we do not drop the positional embeddings:
            x = x + pos_embed

            if masks is not None:
                x = apply_masks(x, masks)

        else:
            assert masks is not None and len(masks) == 1, 'Only one mask is needed for the context.'
            x, pos_bool, mask_no_pos = self.apply_pos_drop_mask(
                x, pos_embed, self.mask_pos_token, masks[0], pos_drop_ratio)

        # -- fwd prop
        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        if use_decoder:
            assert pos_drop_ratio, 'The function is only tested when pos are dropped.'
            logits = self.forward_decoder(x)
            return x, logits, pos_bool, mask_no_pos

        return x  # The classical IJEPA

    def interpolate_pos_encoding(self, x, pos_embed):
        """
        Return a position table with one row per token of `x`.

        The table is returned unchanged when the token counts already agree.
        Otherwise it is resized bicubically as a square grid. This model has
        no class token, so every row of the table takes part in the resize.

        Raises ValueError when either token count is not a perfect square.
        """
        npatch = x.shape[1]
        N = pos_embed.shape[1]
        if npatch == N:
            return pos_embed

        dim = x.shape[-1]
        src_side = int(round(math.sqrt(N)))
        dst_side = int(round(math.sqrt(npatch)))
        if src_side * src_side != N or dst_side * dst_side != npatch:
            raise ValueError(
                'Position embeddings can only be interpolated between square grids '
                '(got %d source and %d target positions).' % (N, npatch))

        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, src_side, src_side, dim).permute(0, 3, 1, 2),
            size=(dst_side, dst_side),
            mode='bicubic',
            align_corners=False,
        )
        # reshape (not view): the permuted tensor is not contiguous
        return pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)


def vit_tiny(patch_size=16, **kwargs):
    """Build the tiny VisionTransformer variant.

    Fixed settings: embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
    qkv_bias=True and LayerNorm with eps=1e-6. Extra keyword arguments are
    forwarded to VisionTransformer.
    """
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)




# Embedding width for each named ViT variant.
VIT_EMBED_DIMS = dict(
    vit_tiny=192,
    vit_small=384,
    vit_base=768,
    vit_large=1024,
    vit_huge=1280,
    vit_giant=1408,
)


In [49]:
import math
from functools import partial
import numpy as np

import torch
import torch.nn as nn

from src.utils.tensors import (
    trunc_normal_,
    repeat_interleave_batch
)
from src.masks.utils import apply_masks


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 2-D sine/cosine position table for a square grid.

    embed_dim: width of every position vector
    grid_size: int, number of cells along each side of the grid
    cls_token: when True, an all-zero row is prepended to the table
    return:
    pos_embed: [grid_size*grid_size, embed_dim], or
               [1+grid_size*grid_size, embed_dim] when cls_token is True
    """
    axis = np.arange(grid_size, dtype=float)
    # meshgrid is called with the width axis first, so mesh[0] holds the
    # column coordinate and mesh[1] the row coordinate of every cell.
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a pair of coordinate grids into one sine/cosine table.

    embed_dim: output width; must be even because each of the two grids
               receives embed_dim // 2 channels
    grid: indexable with 0 and 1; grid[0] fills the first half of the
          channels and grid[1] the second half
    return: array of shape (H*W, embed_dim)
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # One (H*W, D/2) table per coordinate grid, joined along the channel axis.
    halves = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 1-D sine/cosine position table for positions 0..grid_size-1.

    embed_dim: width of every position vector
    grid_size: int, number of positions
    cls_token: when True, an all-zero row is prepended to the table
    return:
    pos_embed: [grid_size, embed_dim], or [1+grid_size, embed_dim] when
               cls_token is True
    """
    positions = np.arange(grid_size, dtype=float)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine encoding of a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions; it is flattened to size (M,)
    out: (M, D) -- the first D/2 columns are sines, the last D/2 cosines
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Frequencies fall geometrically from 1 towards 1/10000.
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.)
    inv_freq = 1. / 10000**exponents   # (D/2,)

    # Outer product of positions and frequencies: (M, D/2)
    angles = np.einsum('m,d->md', pos.reshape(-1), inv_freq)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Stochastic depth: zero out whole samples of `x` at random.

    x: tensor whose first dimension indexes the samples
    drop_prob: probability of dropping a sample; 0 or None disables dropping
    training: dropping is only applied when this is True
    return: `x` itself when dropping is disabled, otherwise a new tensor in
            which each sample is either zeroed or scaled by 1 / (1 - drop_prob)
    """
    # `not drop_prob` covers 0.0 as well as None (the default of DropPath),
    # which would otherwise fail on the subtraction below.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0:
        # Everything is dropped; dividing by keep_prob would produce NaNs.
        return torch.zeros_like(x)
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Module wrapper around `drop_path` (stochastic depth per sample).

    The drop probability is stored as given and handed to `drop_path`
    together with the module's current training flag.
    """
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)


class MLP(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    hidden_features and out_features fall back to in_features when they are
    not given. The same dropout module is applied after both linear layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        output_width = out_features or in_features
        hidden_width = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention weights.

    forward takes x of shape (B, N, C) and returns a tuple
    (output of shape (B, N, C), attention of shape (B, num_heads, N, N)).
    The logits are scaled by qk_scale when given, otherwise by
    (dim // num_heads) ** -0.5.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        # A single projection produces queries, keys and values at once.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, per_head), then split into q / k / v.
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        merged = (attn @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(merged))
        return out, attn


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual path.

    forward(x) returns the transformed tokens; with return_attention=True it
    returns the attention weights of this block instead and skips the rest.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

    def forward(self, x, return_attention=False):
        attn_out, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            return attn_weights
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A convolution whose kernel size and stride both equal patch_size turns a
    (B, C, H, W) image batch into (B, num_patches, embed_dim) tokens.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # Unpacking four values keeps the requirement of a 4-D input.
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)
        return feature_map.flatten(2).transpose(1, 2)


class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models

    channels: output channels of each stem convolution, in order
    strides: stride of each stem convolution, in order
    img_size: side length of the (square) input, either an int or a
              list/tuple whose first element is the side length
    in_chans: number of input image channels
    batch_norm: insert BatchNorm2d after every 3x3 convolution (the 3x3
                convolutions then carry no bias)

    forward maps (B, C, H, W) images to (B, num_patches, channels[-1]) tokens.
    """

    def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + list(channels)
        for i in range(len(channels) - 2):
            stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
                               stride=strides[i], padding=1, bias=(not batch_norm))]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i+1])]
            stem += [nn.ReLU(inplace=True)]
        # Final 1x1 convolution projects to the last channel width.
        stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
        self.stem = nn.Sequential(*stem)

        # Compute the number of patches. The default img_size is a plain int,
        # which cannot be indexed, so accept both an int and a sequence.
        stride_prod = int(np.prod(strides))
        side = img_size[0] if isinstance(img_size, (list, tuple)) else img_size
        self.num_patches = (side // stride_prod)**2

    def forward(self, x):
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)



class VisionTransformer(nn.Module):
    """Vision Transformer encoder with optional positional-embedding dropping.

    On top of a ViT encoder (patch embedding, fixed 2-D sin-cos positional
    embedding, transformer blocks, final norm) this class adds:

    * ``mask_pos_token``: a learnable token that is added, instead of the
      positional embedding, to a random subset of the kept patches
      (see ``apply_pos_drop_mask``);
    * a light-weight decoder (``forward_decoder``) that maps every encoder
      output token to ``num_patches`` logits.

    ``predictor_embed_dim``, ``predictor_depth`` and ``**kwargs`` are accepted
    but not used by this class.
    """
    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        decoder_embed_dim=256,
        decoder_num_heads=2,
        decoder_depth=2,
        **kwargs
    ):
        super().__init__()

        # ---------------------------------------------------------------------- #
        # Encoder settings
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Stochastic depth decay rule: one drop-path rate per encoder block,
        # increasing linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Set up the encoder blocks
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                  attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Patch settings (only the first entry of img_size is used)
        self.patch_embed = PatchEmbed(img_size=img_size[0],
                                      patch_size=patch_size,
                                      in_chans=in_chans,
                                      embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Position settings: fixed (non-trainable) 2-D sin-cos table without a
        # class-token entry.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Mask token settings: learnable stand-in for a dropped positional embedding
        self.mask_pos_token = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Decoder settings (a light weight decoder just for position prediction)
        # Requires additional parameters:
        # - decoder_embed_dim
        # - decoder_num_heads
        # - decoder_depth
        # NOTE(review): the decoder blocks reuse the first entries of the
        # encoder's drop-path schedule (dpr has `depth` entries), so
        # decoder_depth > depth raises IndexError -- confirm this reuse is intended.
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.decoder_blocks = nn.ModuleList([
            Block(dim=decoder_embed_dim, num_heads=decoder_num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, num_patches, bias=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Weight initialization
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()
        # ----------------------------------------------------------------------

    def fix_init_weight(self):
        """Rescale the residual-branch output weights by depth and init the mask token."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        # Only the encoder blocks are rescaled; layer ids start at 1.
        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        # mask_pos_token is a bare Parameter, so self.apply(...) never visits it.
        torch.nn.init.normal_(self.mask_pos_token, std=.02)

    def _init_weights(self, m):
        """Per-module initializer passed to ``self.apply``."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def apply_pos_drop_mask(self, x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
        """Select the patches listed in ``mask`` and drop the position of some of them.

        A random ``int(N_m * pos_drop_ratio)`` of the ``N_m`` entries in each row
        of ``mask`` receive ``mask_pos_token`` instead of their positional
        embedding; the remaining entries receive ``pos_embed`` as usual.

        Args:
            x: patch tokens of shape (B, N, D), without positional embedding.
            pos_embed: positional embedding; it is repeated B times along dim 0
                before being indexed.
            mask_pos_token: token added to the position-dropped patches.
            mask: (B, N_m) indices handed to ``apply_masks``; presumably the
                indices of the patches to keep -- confirm against ``apply_masks``.
            pos_drop_ratio: fraction of the N_m kept patches whose position is dropped.

        Returns:
            x: (B, N_m, D) tokens, in the same order as the entries of ``mask``.
            pos_bool: (B, N_m) bool tensor, True where the position was dropped.
                Use it as ``x[pos_bool.unsqueeze(-1).expand(-1, -1, D)].reshape(B, -1, D)``.
            pos_labels: (B, num_dropped) entries of ``mask`` whose position was
                dropped, sorted ascending per row (used as labels).
        """
        B, _, D = x.shape
        device = x.device

        # Determine the number of positions to drop in the masked area
        N_m = mask.size(1)  # number of entries per row of the mask
        num_pos_to_drop = int(N_m * pos_drop_ratio)

        # Shuffle every row of the mask with an independent random permutation
        random_tensor = torch.rand(B, N_m, device=device)
        shuffled_indices = random_tensor.argsort(dim=1)
        shuffled_mask = mask.gather(1, shuffled_indices)

        # Split the shuffled mask: the head loses its position, the tail keeps it
        mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
        mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

        # Apply the masks to x
        x_no_pos = apply_masks(x, [mask_no_pos])
        x_keep_pos = apply_masks(x, [mask_keep_pos])

        # Add mask_pos_token / pos_embed accordingly
        mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1).to(device)
        x_no_pos = x_no_pos + mask_pos_tokens

        pos_embed = pos_embed.repeat(B, 1, 1).to(device)
        pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
        x_keep_pos = x_keep_pos + pos_embed_masked

        # Concatenate in shuffled order, then invert the permutation so that the
        # tokens are back in the order of the original mask.
        x = torch.cat([x_no_pos, x_keep_pos], dim=1)
        restored_indices = torch.argsort(shuffled_indices, dim=1)
        x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

        # Boolean marker built in shuffled order (the first num_pos_to_drop
        # entries are the dropped ones) and un-shuffled the same way as x.
        shuffled_pos_drop_mask = torch.zeros(B, N_m, dtype=torch.bool, device=device)
        shuffled_pos_drop_mask[:, :num_pos_to_drop] = True
        pos_bool = shuffled_pos_drop_mask.gather(1, restored_indices)

        # NOTE(review): sorting the labels matches the order of x[pos_bool] only
        # if every row of `mask` is itself sorted ascending -- confirm with the
        # mask collator.
        pos_labels = torch.sort(mask_no_pos.detach(), dim=1).values

        return x, pos_bool, pos_labels

    def forward_decoder(self, x):
        """Map encoder tokens to ``num_patches`` logits per token."""
        x = self.decoder_embed(x)
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)  # from decoder_embed_dim to num_patches

        return x

    def forward(self, x, masks=None, pos_drop_ratio=0, use_decoder=False):
        """Encode a batch of images.

        Args:
            x: input images, passed to ``self.patch_embed``.
            masks: a mask or a list of masks passed to ``apply_masks``; exactly
                one mask is required when ``pos_drop_ratio`` is non-zero.
            pos_drop_ratio: fraction of kept patches whose positional embedding
                is replaced by ``mask_pos_token``; 0 disables the feature.
            use_decoder: also run the position decoder; requires a non-zero
                ``pos_drop_ratio``.

        Returns:
            The encoded tokens ``x`` when ``use_decoder`` is False, otherwise
            ``(x, logits, pos_bool, pos_labels)``.
        """
        if use_decoder:
            # pos_bool / pos_labels only exist on the pos-drop path, so fail
            # before running the encoder instead of after it.
            assert pos_drop_ratio, 'The function is only tested when pos are dropped.'

        if masks is not None and not isinstance(masks, list):
            masks = [masks]

        # -- patchify x
        x = self.patch_embed(x)

        # -- positional embedding matching the number of tokens in x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)

        if not pos_drop_ratio:
            # Classical path: every patch gets its positional embedding.
            x += pos_embed

            if masks is not None:
                x = apply_masks(x, masks)

        else:
            # masks=None used to fail with an opaque TypeError from len(None).
            assert masks is not None and len(masks) == 1, \
                'Only one mask is needed for the context.'
            x, pos_bool, pos_labels = self.apply_pos_drop_mask(
                x, pos_embed, self.mask_pos_token, masks[0], pos_drop_ratio)

        # -- fwd prop
        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        if use_decoder:
            logits = self.forward_decoder(x)
            return x, logits, pos_bool, pos_labels

        return x  # The classical IJEPA

    def interpolate_pos_encoding(self, x, pos_embed):
        """Return ``pos_embed`` resized to the number of tokens in ``x``.

        When the token counts match, ``pos_embed`` is returned unchanged.
        Otherwise the first entry is treated as a class-token embedding and the
        remaining entries are bicubically interpolated as a square grid.

        NOTE(review): ``self.pos_embed`` is built with ``cls_token=False``, so
        it has no class-token entry; the resize path then reshapes
        ``num_patches - 1`` entries into a square grid, which looks like it
        cannot succeed for this class. Confirm before feeding inputs whose size
        differs from ``img_size``.
        """
        npatch = x.shape[1] - 1
        N = pos_embed.shape[1] - 1
        if npatch == N:
            return pos_embed
        class_emb = pos_embed[:, 0]
        pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=math.sqrt(npatch / N),
            mode='bicubic',
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)


def vit_tiny(patch_size=16, **kwargs):
    """Build the tiny ViT: 192-dim embeddings, 12 blocks, 3 heads, LayerNorm eps=1e-6.

    Any extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )




# Lookup table from model name to an embedding dimension; the 'vit_tiny' entry
# matches the embed_dim used by the vit_tiny() factory.
VIT_EMBED_DIMS = dict(
    vit_tiny=192,
    vit_small=384,
    vit_base=768,
    vit_large=1024,
    vit_huge=1280,
    vit_giant=1408,
)


In [56]:
import math
from functools import partial
import numpy as np

import torch
import torch.nn as nn

from src.utils.tensors import (
    trunc_normal_,
    repeat_interleave_batch
)
from src.masks.utils import apply_masks


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a 2-D sine/cosine positional embedding table for a square grid.

    grid_size: int of the grid height and width
    cls_token: when True, an all-zero row is prepended to the table
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis_h = np.arange(grid_size, dtype=float)
    axis_w = np.arange(grid_size, dtype=float)
    # meshgrid receives the w axis first
    coords = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
    coords = coords.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
    if not cls_token:
        return table
    zero_row = np.zeros([1, embed_dim])
    return np.concatenate([zero_row, table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a 2-axis coordinate grid: the first half of the output columns
    encodes grid[0], the second half encodes grid[1].

    embed_dim: output dimension for each position (must be even)
    return: (H*W, D)
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # one (H*W, D/2) block per coordinate axis
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a 1-D sine/cosine positional embedding table for positions 0..grid_size-1.

    grid_size: int of the grid length
    cls_token: when True, an all-zero row is prepended to the table
    return:
    pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
    """
    positions = np.arange(grid_size, dtype=float)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not cls_token:
        return table
    zero_row = np.zeros([1, embed_dim])
    return np.concatenate([zero_row, table], axis=0)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine encoding of a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; it is flattened to size (M,)
    out: (M, D), sines in the first D/2 columns and cosines in the last D/2
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    # frequencies 1 / 10000^(k / (D/2)) for k = 0 .. D/2 - 1
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.)
    freqs = 1. / 10000**exponents   # (D/2,)

    flat_pos = pos.reshape(-1)   # (M,)
    angles = np.outer(flat_pos, freqs)   # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of ``x`` and rescale the survivors.

    Returns ``x`` itself when not training or when ``drop_prob`` is 0.
    Otherwise each sample (index along dim 0) is kept with probability
    ``1 - drop_prob`` and divided by that probability, or set to zero.
    """
    if not training or drop_prob == 0.:
        return x
    survive_prob = 1 - drop_prob
    # one draw per sample, broadcast over every remaining dimension
    per_sample_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    noise = torch.rand(per_sample_shape, dtype=x.dtype, device=x.device)
    # floor(survive_prob + U[0, 1)) is 1 with probability survive_prob, else 0
    keep_mask = (survive_prob + noise).floor_()
    return x.div(survive_prob) * keep_mask


class DropPath(nn.Module):
    """Stochastic depth as a module: applies ``drop_path`` per sample.

    Intended for the main path of residual blocks; the drop is only active
    while the module is in training mode.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)


class MLP(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are not given.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        out_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        # the same dropout module is applied after both linear stages
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention weights.

    ``forward`` returns ``(output, attn)`` where ``output`` has the shape of
    the input (B, N, C) and ``attn`` is (B, num_heads, N, N), taken after
    softmax and attention dropout.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        # default scaling is 1/sqrt(head_dim) unless qk_scale is supplied
        self.scale = qk_scale or per_head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split along the first dim
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(scores.softmax(dim=-1))

        # merge the heads back into the channel dimension
        merged = (attn @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(merged))
        return out, attn


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual branch.

    With ``return_attention=True`` the forward pass stops after the attention
    layer and returns its attention weights instead of the block output.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # stochastic depth is only instantiated for a positive rate
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, return_attention=False):
        attn_out, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            return attn_weights
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A single convolution whose kernel size and stride both equal ``patch_size``
    performs the patch extraction and the linear projection in one step.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # Square grid: the same number of patches along each side.
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Turn a (B, C, H, W) batch into a (B, num_tokens, embed_dim) token sequence."""
        # The 4-way unpack rejects anything that is not a 4-D tensor.
        _batch, _channels, _height, _width = x.shape
        feature_map = self.proj(x)
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)


class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models

    The stem is a stack of (3x3 conv -> optional BatchNorm -> ReLU) layers
    followed by a final 1x1 convolution that produces the token channels.

    Args:
        channels: output channels of each convolution, in order (list or tuple).
        strides: stride of each convolution; ``strides[-1]`` belongs to the
            final 1x1 convolution.
        img_size: input side length, either an int (e.g. ``224``) or a sequence
            whose first element is the side length (e.g. ``[224]``).
        in_chans: number of input image channels.
        batch_norm: insert BatchNorm after each 3x3 conv; those convs then
            carry no bias.
    """

    def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + list(channels)
        for i in range(len(channels) - 2):
            stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
                               stride=strides[i], padding=1, bias=(not batch_norm))]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i+1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
        self.stem = nn.Sequential(*stem)

        # Compute the number of patches.
        # The default img_size is a plain int, so indexing it unconditionally
        # (as before) raised TypeError; accept both an int and a sequence.
        side = img_size if isinstance(img_size, int) else img_size[0]
        stride_prod = int(np.prod(strides))
        self.num_patches = (side // stride_prod)**2

    def forward(self, x):
        """Run the stem and flatten its spatial output into a (B, tokens, C) sequence."""
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)



class VisionTransformer(nn.Module):
    """Vision Transformer encoder -- debug variant of the pos-drop path.

    Compared with the regular encoder, a non-zero ``pos_drop_ratio`` makes
    ``forward`` return right after ``apply_pos_drop_mask`` (no transformer
    blocks, no decoder), and ``apply_pos_drop_mask`` adds neither the
    positional embedding nor ``mask_pos_token``. This lets a notebook compare
    the selected tokens directly against the raw patch embeddings.

    ``predictor_embed_dim``, ``predictor_depth`` and ``**kwargs`` are accepted
    but not used by this class.
    """
    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        decoder_embed_dim=256,
        decoder_num_heads=2,
        decoder_depth=2,
        **kwargs
    ):
        super().__init__()

        # ---------------------------------------------------------------------- #
        # Encoder settings
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Stochastic depth decay rule: one drop-path rate per encoder block,
        # increasing linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Set up the encoder blocks
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                  attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Patch settings (only the first entry of img_size is used)
        self.patch_embed = PatchEmbed(img_size=img_size[0],
                                      patch_size=patch_size,
                                      in_chans=in_chans,
                                      embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Position settings: fixed (non-trainable) 2-D sin-cos table without a
        # class-token entry.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Mask token settings: learnable stand-in for a dropped positional embedding
        self.mask_pos_token = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Decoder settings (a light weight decoder just for position prediction)
        # Requires additional parameters:
        # - decoder_embed_dim
        # - decoder_num_heads
        # - decoder_depth
        # NOTE(review): the decoder blocks reuse the first entries of the
        # encoder's drop-path schedule (dpr has `depth` entries), so
        # decoder_depth > depth raises IndexError -- confirm this reuse is intended.
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.decoder_blocks = nn.ModuleList([
            Block(dim=decoder_embed_dim, num_heads=decoder_num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, num_patches, bias=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Weight initialization
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()
        # ----------------------------------------------------------------------

    def fix_init_weight(self):
        """Rescale the residual-branch output weights by depth and init the mask token."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        # Only the encoder blocks are rescaled; layer ids start at 1.
        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        # mask_pos_token is a bare Parameter, so self.apply(...) never visits it.
        torch.nn.init.normal_(self.mask_pos_token, std=.02)

    def _init_weights(self, m):
        """Per-module initializer passed to ``self.apply``."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def apply_pos_drop_mask(self, x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
        """Debug variant: select the patches in ``mask`` without adding anything to them.

        The bookkeeping (random split into position-dropped / position-kept
        patches, ``pos_bool``, ``pos_labels``) is performed, but neither
        ``pos_embed`` nor ``mask_pos_token`` is added to the tokens; both
        arguments are accepted only to keep the regular signature.

        Args:
            x: patch tokens of shape (B, N, D).
            pos_embed: unused in this variant.
            mask_pos_token: unused in this variant.
            mask: (B, N_m) indices handed to ``apply_masks``; presumably the
                indices of the patches to keep -- confirm against ``apply_masks``.
            pos_drop_ratio: fraction of the N_m kept patches marked as dropped.

        Returns:
            x: (B, N_m, D) selected tokens, in the same order as ``mask``.
            pos_bool: (B, N_m) bool tensor, True where the position is marked dropped.
                Use it as ``x[pos_bool.unsqueeze(-1).expand(-1, -1, D)].reshape(B, -1, D)``.
            pos_labels: (B, num_dropped) dropped entries of ``mask``, sorted ascending per row.
            x_initial: clone of the input ``x`` taken before any selection.
            mask_no_pos: (B, num_dropped) dropped entries of ``mask`` in shuffled order.
        """
        B, _, D = x.shape
        device = x.device
        x_initial = x.clone()

        # Determine the number of positions to drop in the masked area
        N_m = mask.size(1)  # number of entries per row of the mask
        num_pos_to_drop = int(N_m * pos_drop_ratio)

        # Shuffle every row of the mask with an independent random permutation
        random_tensor = torch.rand(B, N_m, device=device)
        shuffled_indices = random_tensor.argsort(dim=1)
        shuffled_mask = mask.gather(1, shuffled_indices)

        # Split the shuffled mask: the head is marked dropped, the tail is kept
        mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
        mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

        # Apply the masks to x. Nothing is added afterwards in this variant, so
        # the selected tokens stay comparable with x_initial.
        x_no_pos = apply_masks(x, [mask_no_pos])
        x_keep_pos = apply_masks(x, [mask_keep_pos])

        # Concatenate in shuffled order, then invert the permutation so that the
        # tokens are back in the order of the original mask.
        x = torch.cat([x_no_pos, x_keep_pos], dim=1)
        restored_indices = torch.argsort(shuffled_indices, dim=1)
        x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

        # Boolean marker built in shuffled order (the first num_pos_to_drop
        # entries are the dropped ones) and un-shuffled the same way as x.
        shuffled_pos_drop_mask = torch.zeros(B, N_m, dtype=torch.bool, device=device)
        shuffled_pos_drop_mask[:, :num_pos_to_drop] = True
        pos_bool = shuffled_pos_drop_mask.gather(1, restored_indices)

        # NOTE(review): sorting the labels matches the order of x[pos_bool] only
        # if every row of `mask` is itself sorted ascending -- confirm with the
        # mask collator.
        pos_labels = torch.sort(mask_no_pos.detach(), dim=1).values

        return x, pos_bool, pos_labels, x_initial, mask_no_pos

    def forward_decoder(self, x):
        """Map encoder tokens to ``num_patches`` logits per token."""
        x = self.decoder_embed(x)
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)  # from decoder_embed_dim to num_patches

        return x

    def forward(self, x, masks=None, pos_drop_ratio=0, use_decoder=False):
        """Encode a batch of images, or return pos-drop debug tensors.

        Args:
            x: input images, passed to ``self.patch_embed``.
            masks: a mask or a list of masks passed to ``apply_masks``; exactly
                one mask is required when ``pos_drop_ratio`` is non-zero.
            pos_drop_ratio: when non-zero, the method returns early with the
                output of ``apply_pos_drop_mask``.
            use_decoder: has no effect when ``pos_drop_ratio`` is non-zero
                (early return); with ``pos_drop_ratio == 0`` it fails the
                assertion below.

        Returns:
            ``(x, pos_bool, pos_labels, x_initial, mask_no_pos)`` when
            ``pos_drop_ratio`` is non-zero, otherwise the encoded tokens ``x``.
        """
        if masks is not None and not isinstance(masks, list):
            masks = [masks]

        # -- patchify x
        x = self.patch_embed(x)

        # -- positional embedding matching the number of tokens in x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)

        if not pos_drop_ratio:
            # Classical path: every patch gets its positional embedding.
            x += pos_embed

            if masks is not None:
                x = apply_masks(x, masks)

        else:
            # masks=None used to fail with an opaque TypeError from len(None).
            assert masks is not None and len(masks) == 1, \
                'Only one mask is needed for the context.'
            x, pos_bool, pos_labels, x_initial, mask_no_pos = self.apply_pos_drop_mask(
                x, pos_embed, self.mask_pos_token, masks[0], pos_drop_ratio)
            # Debug: hand the raw tensors back without running the encoder.
            return x, pos_bool, pos_labels, x_initial, mask_no_pos

        # -- fwd prop
        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        if use_decoder:
            # Only reached with pos_drop_ratio == 0 in this variant, so the
            # assertion always fires here.
            assert pos_drop_ratio, 'The function is only tested when pos are dropped.'
            logits = self.forward_decoder(x)
            return x, logits, pos_bool, pos_labels

        return x  # The classical IJEPA

    def interpolate_pos_encoding(self, x, pos_embed):
        """Return ``pos_embed`` resized to the number of tokens in ``x``.

        When the token counts match, ``pos_embed`` is returned unchanged.
        Otherwise the first entry is treated as a class-token embedding and the
        remaining entries are bicubically interpolated as a square grid.

        NOTE(review): ``self.pos_embed`` is built with ``cls_token=False``, so
        it has no class-token entry; the resize path then reshapes
        ``num_patches - 1`` entries into a square grid, which looks like it
        cannot succeed for this class. Confirm before feeding inputs whose size
        differs from ``img_size``.
        """
        npatch = x.shape[1] - 1
        N = pos_embed.shape[1] - 1
        if npatch == N:
            return pos_embed
        class_emb = pos_embed[:, 0]
        pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=math.sqrt(npatch / N),
            mode='bicubic',
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)


def vit_tiny(patch_size=16, **kwargs):
    """Build the tiny ViT: 192-dim embeddings, 12 blocks, 3 heads, LayerNorm eps=1e-6.

    Any extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs,
    )




# Lookup table from model name to an embedding dimension; the 'vit_tiny' entry
# matches the embed_dim used by the vit_tiny() factory.
VIT_EMBED_DIMS = dict(
    vit_tiny=192,
    vit_small=384,
    vit_base=768,
    vit_large=1024,
    vit_huge=1280,
    vit_giant=1408,
)


In [57]:
import random
import torch.nn.functional as F
# Seed the Python, NumPy and PyTorch RNGs before the model is built.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# NOTE(review): VisionTransformer.__init__ takes img_size (not crop_size) and
# has no pred_depth / pred_emb_dim parameters, so these three keyword arguments
# land in its unused **kwargs and appear to have no effect -- confirm intent.
encoder = vit_tiny(patch_size=patch_size,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim).to(device)

In [58]:
# Patches per image; this is also the width of the decoder's output layer.
num_patches = encoder.patch_embed.num_patches

In [59]:
# With the debug forward() defined above, a non-zero pos_drop_ratio returns
# early, so use_decoder=True has no effect and `result` is a 5-tuple.
result = encoder(imgs, masks_enc, pos_drop_ratio=0.4, use_decoder=True)

here


In [53]:
# NOTE(review): this cell carries execution count 53, i.e. it ran before the
# debug VisionTransformer cell (56). It expects the 4-tuple returned when the
# decoder runs and cannot unpack the 5-tuple the debug forward() returns.
x, logits, pos_bool, labels = result

# Keep only the logits of the tokens flagged by pos_bool.
logits = logits[pos_bool.unsqueeze(-1).expand(
    -1, -1, num_patches)].reshape(
        batch_size, -1, num_patches)

# Move the num_patches (class) dimension to position 1 for cross_entropy.
loss_pos = F.cross_entropy(logits.permute(0, 2, 1), labels)

In [60]:
# Unpack the 5-tuple returned by the debug forward() on the pos-drop path.
x, pos_bool, labels, x_initial, mask_no_pos = result

In [54]:
# NOTE(review): repeats the filtering and loss from the execution-count-53 cell.
# It needs `logits` from a decoder run, which the debug forward() does not
# produce -- presumably a leftover from the earlier class version.
logits = logits[pos_bool.unsqueeze(-1).expand(
    -1, -1, num_patches)].reshape(
        batch_size, -1, num_patches)

loss_pos = F.cross_entropy(logits.permute(0, 2, 1), labels)

In [55]:
loss_pos

tensor(5.6027, grad_fn=<NllLoss2DBackward0>)

In [61]:
# Unpack the 5-tuple returned by the debug forward() on the pos-drop path
# (same statement as the execution-count-60 cell).
x, pos_bool, labels, x_initial, mask_no_pos = result

In [62]:
x_initial.shape

torch.Size([64, 256, 192])

In [63]:
mask_no_pos

tensor([[  2,  24, 140,  ...,  57, 109,  26],
        [144, 123,  93,  ..., 224,  45, 163],
        [ 58,   4, 139,  ..., 108,  25,  73],
        ...,
        [ 58,  13, 108,  ..., 142, 138, 109],
        [ 34, 115, 112,  ...,  94,  64,  97],
        [130,  81, 131,  ...,  30, 100,  98]])

In [64]:
# Sort each row of the dropped-position indices ascending, the same ordering
# apply_pos_drop_mask uses for its pos_labels output.
sorted_tensor, sorted_indices = torch.sort(mask_no_pos, dim=1)

In [65]:
sorted_tensor

tensor([[  0,   1,   2,  ..., 156, 157, 158],
        [  3,   5,  12,  ..., 225, 226, 238],
        [  3,   4,   8,  ..., 140, 142, 154],
        ...,
        [  1,  13,  14,  ..., 138, 139, 142],
        [  3,   7,  14,  ..., 129, 142, 144],
        [  0,  12,  18,  ..., 130, 131, 144]])

In [66]:
# Reference selection: apply the sorted dropped-position indices to the
# untouched patch embeddings via apply_masks.
x_drop_pos_from_initial = apply_masks(x_initial, [sorted_tensor])

In [69]:
# Selection under test: pick the tokens flagged by pos_bool out of the tensor
# returned by apply_pos_drop_mask.
x_drop_pos_from_function = x[pos_bool.unsqueeze(-1).expand(
    -1, -1, x_initial.size(2))].reshape(
        batch_size, -1, x_initial.size(2))

In [73]:
# Count of differing elements; 0 means pos_bool and the sorted indices select
# the same tokens in the same order.
(x_drop_pos_from_function != x_drop_pos_from_initial).sum()

tensor(0)

In [74]:
import math
from functools import partial
import numpy as np

import torch
import torch.nn as nn

from src.utils.tensors import (
    trunc_normal_,
    repeat_interleave_batch
)
from src.masks.utils import apply_masks


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 2-D sine/cosine positional-embedding table for a square grid.

    embed_dim: width of each positional embedding
    grid_size: int of the grid height and width
    cls_token: when true, a zero row is prepended for a class token
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis_coords = np.arange(grid_size, dtype=float)
    # meshgrid is called as (w, h), so the width coordinates come first in the stack
    mesh = np.stack(np.meshgrid(axis_coords, axis_coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a two-channel coordinate grid as concatenated 1-D sin/cos embeddings.

    embed_dim: output width; must be even, half is spent on each grid channel
    grid: indexable with two entries (grid[0], grid[1]) holding the coordinates
    return: array of shape (H*W, embed_dim)
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # first half of the features comes from grid[0], second half from grid[1]
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]  # 2 x (H*W, D/2)
    return np.concatenate(per_axis, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 1-D sine/cosine positional-embedding table for positions 0..grid_size-1.

    grid_size: int of the grid length
    cls_token: when true, a zero row is prepended for a class token
    return:
    pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
    """
    positions = np.arange(grid_size, dtype=float)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not cls_token:
        return table
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, table], axis=0)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine embedding of a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; it is flattened to size (M,)
    out: (M, D) -- the first D/2 columns are sines, the last D/2 are cosines
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    # frequencies 1 / 10000^(k / (D/2)) for k = 0 .. D/2 - 1
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.)
    freqs = 1. / 10000**exponents   # (D/2,)

    flat_pos = pos.reshape(-1)   # (M,)
    angles = flat_pos[:, None] * freqs[None, :]   # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: zero out whole samples of ``x`` with probability ``drop_prob``.

    Returns ``x`` itself when ``drop_prob`` is 0 or ``training`` is false. Otherwise each
    sample along dim 0 is either zeroed or rescaled by ``1 / (1 - drop_prob)``.
    """
    if drop_prob == 0. or not training:
        return x
    survive_prob = 1 - drop_prob
    # one random draw per sample, broadcast over all remaining dims
    gate_shape = [x.shape[0]] + [1] * (x.ndim - 1)
    gate = torch.rand(gate_shape, dtype=x.dtype, device=x.device)
    gate = (survive_prob + gate).floor_()  # 1 with probability survive_prob, else 0
    return x.div(survive_prob) * gate


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).

    ``drop_prob`` may be ``None`` (the default) or ``0``; both mean "never drop", so the
    module behaves as an identity in that case, in training as well as in eval mode.
    """
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Short-circuit here so the default drop_prob=None never reaches the arithmetic
        # in drop_path (``1 - None`` would raise a TypeError in training mode).
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)


class MLP(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when left unset.
    The same dropout module is applied after the activation and after fc2.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention weights.

    ``forward`` takes a (B, N, C) tensor and returns ``(output, attn)`` where ``output`` is
    (B, N, C) and ``attn`` is the post-softmax, post-dropout weight tensor of shape
    (B, num_heads, N, N).
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # default scaling is 1/sqrt(head_dim) unless an explicit (truthy) qk_scale is given
        self.scale = qk_scale if qk_scale else (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q / k / v
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(scores.softmax(dim=-1))

        # merge the heads back into the channel dimension
        merged = (attn @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged)), attn


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each with a residual connection.

    Stochastic depth (``DropPath``) wraps both residual branches when ``drop_path > 0``;
    otherwise an identity is used.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

    def forward(self, x, return_attention=False):
        """Return the block output, or only the attention weights when ``return_attention`` is set."""
        attn_out, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            return attn_weights
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class PatchEmbed(nn.Module):
    """Image to patch embedding.

    A strided convolution (kernel = stride = ``patch_size``) turns a (B, C, H, W) image into
    a (B, num_patches, embed_dim) token sequence.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # the unpacking raises unless x is 4-D
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)


class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models

    channels: output channels of each stem layer; the last entry is the embedding width
    strides: stride of each stem layer (the last one is used by the final 1x1 convolution)
    img_size: input side length, either a scalar or a sequence whose first entry is used
    in_chans: number of input image channels
    batch_norm: when true, each 3x3 convolution is bias-free and followed by BatchNorm2d
    """

    def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + list(channels)
        for i in range(len(channels) - 2):
            stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
                               stride=strides[i], padding=1, bias=(not batch_norm))]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i+1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
        self.stem = nn.Sequential(*stem)

        # Compute the number of patches.
        # The default img_size is a plain int, so only index into it when it is a sequence.
        stride_prod = int(np.prod(strides))
        side = img_size if np.isscalar(img_size) else img_size[0]
        self.num_patches = (side // stride_prod)**2

    def forward(self, x):
        """Run the stem and flatten the feature map to (B, num_patches, channels[-1])."""
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)


class VisionTransformerPredictor(nn.Module):
    """Transformer predictor operating on context tokens plus learnable mask tokens.

    Context tokens are projected from ``embed_dim`` to ``predictor_embed_dim``, concatenated
    with mask tokens that carry the positional embeddings selected by the target masks, run
    through ``depth`` transformer blocks, and the outputs at the mask-token positions are
    projected back to ``embed_dim``.

    Positional embeddings are fixed 2-D sin/cos tables (``requires_grad=False``).
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(
        self,
        num_patches,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=6,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        # encoder width -> predictor width
        self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
        # single learnable token, repeated for every position to be predicted
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        # -- fixed (non-trainable) 2-D sin/cos positional table; the grid side is
        # -- int(sqrt(num_patches)), so num_patches is expected to be a perfect square
        self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim),
                                                requires_grad=False)
        predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1],
                                                      int(num_patches**.5),
                                                      cls_token=False)
        self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0))
        # --
        self.predictor_blocks = nn.ModuleList([
            Block(
                dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.predictor_norm = norm_layer(predictor_embed_dim)
        # predictor width -> encoder width
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
        # ------
        # weight initialisation: mask token first, then every sub-module, then depth rescaling
        self.init_std = init_std
        trunc_normal_(self.mask_token, std=self.init_std)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Divide each block's attention-projection and MLP fc2 weights by sqrt(2 * k), k being the 1-based block index."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.predictor_blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Per-module initialiser used with ``self.apply``.

        Linear / Conv2d: truncated-normal weights (std = ``init_std``) and zero bias.
        LayerNorm: zero bias and unit weight.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, masks_x, masks):
        """Predict representations at the positions selected by ``masks``.

        x: context tokens from the encoder (last dim ``embed_dim``)
        masks_x: mask or list of masks that selected the context tokens
        masks: mask or list of masks selecting the positions to predict
        returns: predictions for the mask tokens, projected back to ``embed_dim``

        Raises AssertionError when either mask argument is None.
        """
        assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices'

        if not isinstance(masks_x, list):
            masks_x = [masks_x]

        if not isinstance(masks, list):
            masks = [masks]

        # -- Batch Size
        # presumably x stacks one copy of the batch per context mask along dim 0 — verify against caller
        B = len(x) // len(masks_x)

        # -- map from encoder-dim to predictor-dim
        x = self.predictor_embed(x)

        # -- add positional embedding to x tokens
        x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)
        x += apply_masks(x_pos_embed, masks_x)

        # number of context tokens; used below to slice the mask-token outputs off the end
        _, N_ctxt, D = x.shape

        # -- concat mask tokens to x
        pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
        pos_embs = apply_masks(pos_embs, masks)
        # NOTE(review): repeat_interleave_batch is defined in src.utils.tensors; presumably it
        # repeats the per-mask batches len(masks_x) times — confirm
        pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
        # --
        pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
        # --
        pred_tokens += pos_embs
        x = x.repeat(len(masks), 1, 1)
        x = torch.cat([x, pred_tokens], dim=1)

        # -- fwd prop
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_norm(x)

        # -- return preds for mask tokens
        x = x[:, N_ctxt:]
        x = self.predictor_proj(x)

        return x


class VisionTransformer(nn.Module):
    """ViT encoder with optional positional-embedding dropping and a light position decoder.

    Without ``pos_drop_ratio`` the model behaves like the classical I-JEPA encoder: patchify,
    add fixed 2-D sin/cos positional embeddings, keep the tokens selected by ``masks``, run
    the transformer blocks.

    With ``pos_drop_ratio`` > 0 a random subset of the kept tokens receives the learnable
    ``mask_pos_token`` instead of its positional embedding. With ``use_decoder=True`` a small
    decoder additionally maps every token to ``num_patches`` logits (one per patch position).

    ``predictor_embed_dim``, ``predictor_depth`` and any extra keyword arguments are accepted
    but not used by this class.
    """
    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        decoder_embed_dim=256,
        decoder_num_heads=2,
        decoder_depth=2,
        **kwargs
    ):
        super().__init__()

        # ---------------------------------------------------------------------- #
        # Encoder settings
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Set up the stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Set up the encoder blocks
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                  attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Patch settings
        self.patch_embed = PatchEmbed(img_size=img_size[0],
                                      patch_size=patch_size,
                                      in_chans=in_chans,
                                      embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Position settings: fixed (non-trainable) 2-D sin/cos table, no class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Mask token settings: learnable stand-in for a dropped positional embedding
        self.mask_pos_token = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Decoder settings (a light weight decoder just for position prediction)
        # Require additional parameters:
        # - decoder_embed_dim
        # - decoder_num_heads
        # - decoder_depth
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        # The decoder reuses the head of the encoder's stochastic-depth schedule. The index is
        # clamped so that decoder_depth > depth no longer raises an IndexError; rates for
        # decoder_depth <= depth are unchanged.
        decoder_dpr = [dpr[min(i, len(dpr) - 1)] if dpr else 0.0 for i in range(decoder_depth)]
        self.decoder_blocks = nn.ModuleList([
            Block(dim=decoder_embed_dim, num_heads=decoder_num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=decoder_dpr[i], norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, num_patches, bias=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Weight Initialization
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()
        # ----------------------------------------------------------------------

    def fix_init_weight(self):
        """Rescale encoder block weights by depth and initialise ``mask_pos_token``.

        Block k (1-based) has its attention-projection and MLP fc2 weights divided by
        sqrt(2 * k). Only the encoder blocks are rescaled; the decoder blocks are not.
        """
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        # $$$$ Also initialize the mask_pos_token
        torch.nn.init.normal_(self.mask_pos_token, std=.02)

    def _init_weights(self, m):
        """Per-module initialiser used with ``self.apply``.

        Linear / Conv2d: truncated-normal weights (std = ``init_std``) and zero bias.
        LayerNorm: zero bias and unit weight.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def apply_pos_drop_mask(self, x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
        """Keep the tokens selected by ``mask`` and drop the positional embedding of a random subset.

        x: patch tokens before any positional embedding is added, (B, N, D)
        pos_embed: positional table with a leading dim of 1; it is repeated B times
        mask_pos_token: learnable token added in place of a dropped positional embedding
        mask: (B, N_m) patch indices to keep
        pos_drop_ratio: fraction of the N_m kept tokens whose position is dropped

        Returns ``(x, pos_bool, pos_labels)``:
        - x: (B, N_m, D) tokens, in the same order as ``mask``
        - pos_bool: (B, N_m) boolean, True where the positional embedding was dropped
        - pos_labels: (B, num_dropped) patch indices of the dropped positions, sorted
          ascending along dim 1
        """
        B, _, D = x.shape  # Original shape of x
        device = x.device

        # Determine the number of positions to drop in the masked area
        N_m = mask.size(1)  # Number of patches to keep after the mask is applied
        num_pos_to_drop = int(N_m * pos_drop_ratio)

        # Shuffle mask along the last dimension
        random_tensor = torch.rand(B, N_m, device=device)
        shuffled_indices = random_tensor.argsort(dim=1)
        shuffled_mask = mask.gather(1, shuffled_indices)

        # Split the mask into two: one for mask_pos_token, one for keeping pos_embed
        mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
        mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

        # Apply the masks to x
        x_no_pos = apply_masks(x, [mask_no_pos])
        x_keep_pos = apply_masks(x, [mask_keep_pos])

        # Apply pos_embed and mask_pos_token accordingly
        mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1).to(device)
        x_no_pos = x_no_pos + mask_pos_tokens

        pos_embed = pos_embed.repeat(B, 1, 1).to(device)
        pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
        x_keep_pos = x_keep_pos + pos_embed_masked

        # Concatenate the results and invert the shuffle to restore the order of `mask`
        x = torch.cat([x_no_pos, x_keep_pos], dim=1)
        restored_indices = torch.argsort(shuffled_indices, dim=1)
        x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

        # Create a boolean mask in the shuffled order
        shuffled_pos_drop_mask = torch.zeros(B, N_m, dtype=torch.bool, device=device)
        shuffled_pos_drop_mask[:, :num_pos_to_drop] = True  # Mark the first num_pos_to_drop as True

        # Restore the order of the boolean mask to match the restored x
        pos_bool = shuffled_pos_drop_mask.gather(1, restored_indices)

        # pos_bool selects from x the tokens whose positional embeddings were dropped:
        #   x_ = x[pos_bool.unsqueeze(-1).expand(-1, -1, D)].reshape(B, -1, D)
        # mask_no_pos, in contrast, holds the original patch indices of those tokens and is
        # used as the labels. The labels are sorted ascending, which matches the order of the
        # tokens selected by pos_bool only when each row of `mask` is itself ascending.
        pos_labels = torch.sort(mask_no_pos.detach(), dim=1).values

        return x, pos_bool, pos_labels

    def forward_decoder(self, x):
        """Map encoder tokens to per-token logits over the ``num_patches`` positions."""
        x = self.decoder_embed(x)
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)  # from decoder_embed_dim to num_patches

        return x

    def forward(self, x, masks=None, pos_drop_ratio=0, use_decoder=False):
        """Encode an image batch.

        x: images, (B, C, H, W)
        masks: optional mask or list of masks with the patch indices to keep
        pos_drop_ratio: fraction of kept tokens whose positional embedding is dropped;
            a falsy value disables dropping
        use_decoder: also run the position decoder (requires a truthy ``pos_drop_ratio``)

        Returns the encoded tokens when ``use_decoder`` is false, otherwise the tuple
        ``(x, logits, pos_bool, pos_labels)``.
        """
        if masks is not None and not isinstance(masks, list):
            masks = [masks]

        # -- patchify x
        x = self.patch_embed(x)

        # -- positional embedding matching the number of patches in x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)

        # When we do not drop the positional embeddings:
        if not pos_drop_ratio:
            x += pos_embed

            if masks is not None:
                x = apply_masks(x, masks)

            pos_bool, pos_labels = None, None

        else:
            # masks=None is rejected here rather than failing on len(None)
            assert masks is not None and len(masks) == 1, 'Only one mask is needed for the context.'
            x, pos_bool, pos_labels = self.apply_pos_drop_mask(
                x, pos_embed, self.mask_pos_token, masks[0], pos_drop_ratio)

        # -- fwd prop
        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        if use_decoder:
            assert pos_drop_ratio, 'The function is only tested when pos are dropped.'
            logits = self.forward_decoder(x)
            return x, logits, pos_bool, pos_labels

        return x  # The classical IJEPA

    def interpolate_pos_encoding(self, x, pos_embed):
        """Return ``pos_embed`` resized to the number of tokens in ``x``.

        x: tokens of shape (B, npatch, dim)
        pos_embed: table of shape (1, N, dim) laid out on a square grid, without class token

        When npatch == N the table is returned as is. Otherwise it is resampled with bicubic
        interpolation to a sqrt(npatch) x sqrt(npatch) grid; both N and npatch must then be
        perfect squares.
        """
        npatch = x.shape[1]
        N = pos_embed.shape[1]
        if npatch == N:
            return pos_embed
        dim = x.shape[-1]
        old_side = int(round(math.sqrt(N)))
        new_side = int(round(math.sqrt(npatch)))
        assert old_side * old_side == N and new_side * new_side == npatch, \
            'Positional embeddings can only be interpolated between square grids.'
        # an explicit target size avoids the rounding ambiguity of a float scale_factor
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, old_side, old_side, dim).permute(0, 3, 1, 2),
            size=(new_side, new_side),
            mode='bicubic',
        )
        return pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)


def vit_predictor(**kwargs):
    """Build a VisionTransformerPredictor with mlp_ratio=4, QKV bias and LayerNorm(eps=1e-6).

    All keyword arguments are forwarded to ``VisionTransformerPredictor``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformerPredictor(
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_tiny(patch_size=16, **kwargs):
    """Build a ViT-Tiny encoder: embed_dim=192, depth=12, 3 heads, LayerNorm(eps=1e-6).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_small(patch_size=16, **kwargs):
    """Build a ViT-Small encoder: embed_dim=384, depth=12, 6 heads, LayerNorm(eps=1e-6).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_base(patch_size=16, **kwargs):
    """Build a ViT-Base encoder: embed_dim=768, depth=12, 12 heads, LayerNorm(eps=1e-6).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_large(patch_size=16, **kwargs):
    """Build a ViT-Large encoder: embed_dim=1024, depth=24, 16 heads, LayerNorm(eps=1e-6).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_huge(patch_size=16, **kwargs):
    """Build a ViT-Huge encoder: embed_dim=1280, depth=32, 16 heads, LayerNorm(eps=1e-6).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_giant(patch_size=16, **kwargs):
    """Build a ViT-Giant encoder: embed_dim=1408, depth=40, 16 heads, mlp_ratio=48/11, LayerNorm(eps=1e-6).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48/11,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


# Embedding width of each ViT variant, keyed by factory name; the values mirror the
# `embed_dim` passed by the corresponding `vit_*` function above.
VIT_EMBED_DIMS = {
    'vit_tiny': 192,
    'vit_small': 384,
    'vit_base': 768,
    'vit_large': 1024,
    'vit_huge': 1280,
    'vit_giant': 1408,
}


In [151]:
"""
Notes on the edits:
    - Add a lightweight decoder with linear head for pos prediction
    - Add a careful pos dropping strategy based on current masking approach
"""

from src.masks.utils import apply_masks
from functools import partial
from src.utils.tensors import (
    trunc_normal_,
    repeat_interleave_batch
)

import torch.nn as nn
import numpy as np
import math
import torch


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 2-D sine/cosine positional-embedding table for a square grid.

    embed_dim: width of each positional embedding
    grid_size: int of the grid height and width
    cls_token: when true, a zero row is prepended for a class token
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis_coords = np.arange(grid_size, dtype=float)
    # meshgrid is called as (w, h), so the width coordinates come first in the stack
    mesh = np.stack(np.meshgrid(axis_coords, axis_coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a two-channel coordinate grid as concatenated 1-D sin/cos embeddings.

    embed_dim: output width; must be even, half is spent on each grid channel
    grid: indexable with two entries (grid[0], grid[1]) holding the coordinates
    return: array of shape (H*W, embed_dim)
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # first half of the features comes from grid[0], second half from grid[1]
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]  # 2 x (H*W, D/2)
    return np.concatenate(per_axis, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 1-D sine/cosine positional-embedding table for positions 0..grid_size-1.

    grid_size: int of the grid length
    cls_token: when true, a zero row is prepended for a class token
    return:
    pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
    """
    positions = np.arange(grid_size, dtype=float)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not cls_token:
        return table
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, table], axis=0)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine embedding of a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; it is flattened to size (M,)
    out: (M, D) -- the first D/2 columns are sines, the last D/2 are cosines
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    # frequencies 1 / 10000^(k / (D/2)) for k = 0 .. D/2 - 1
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.)
    freqs = 1. / 10000**exponents   # (D/2,)

    flat_pos = pos.reshape(-1)   # (M,)
    angles = flat_pos[:, None] * freqs[None, :]   # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: zero out whole samples of ``x`` with probability ``drop_prob``.

    Returns ``x`` itself when ``drop_prob`` is 0 or ``training`` is false. Otherwise each
    sample along dim 0 is either zeroed or rescaled by ``1 / (1 - drop_prob)``.
    """
    if drop_prob == 0. or not training:
        return x
    survive_prob = 1 - drop_prob
    # one random draw per sample, broadcast over all remaining dims
    gate_shape = [x.shape[0]] + [1] * (x.ndim - 1)
    gate = torch.rand(gate_shape, dtype=x.dtype, device=x.device)
    gate = (survive_prob + gate).floor_()  # 1 with probability survive_prob, else 0
    return x.div(survive_prob) * gate


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).

    ``drop_prob`` may be ``None`` (the default) or ``0``; both mean "never drop", so the
    module behaves as an identity in that case, in training as well as in eval mode.
    """
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Short-circuit here so the default drop_prob=None never reaches the arithmetic
        # in drop_path (``1 - None`` would raise a TypeError in training mode).
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)


class MLP(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when left unset.
    The same dropout module is applied after the activation and after fc2.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention weights.

    ``forward`` takes a (B, N, C) tensor and returns ``(output, attn)`` where ``output`` is
    (B, N, C) and ``attn`` is the post-softmax, post-dropout weight tensor of shape
    (B, num_heads, N, N).
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # default scaling is 1/sqrt(head_dim) unless an explicit (truthy) qk_scale is given
        self.scale = qk_scale if qk_scale else (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q / k / v
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(scores.softmax(dim=-1))

        # merge the heads back into the channel dimension
        merged = (attn @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged)), attn


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each with a residual connection.

    Stochastic depth (``DropPath``) wraps both residual branches when ``drop_path > 0``;
    otherwise an identity is used.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

    def forward(self, x, return_attention=False):
        """Return the block output, or only the attention weights when ``return_attention`` is set."""
        attn_out, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            return attn_weights
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class PatchEmbed(nn.Module):
    """Image to patch embedding.

    A strided convolution (kernel = stride = ``patch_size``) turns a (B, C, H, W) image into
    a (B, num_patches, embed_dim) token sequence.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # the unpacking raises unless x is 4-D
        _batch, _chans, _height, _width = x.shape
        feature_map = self.proj(x)
        tokens = feature_map.flatten(2)
        return tokens.transpose(1, 2)


class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models

    channels: output channels of each stem layer; the last entry is the embedding width
    strides: stride of each stem layer (the last one is used by the final 1x1 convolution)
    img_size: input side length, either a scalar or a sequence whose first entry is used
    in_chans: number of input image channels
    batch_norm: when true, each 3x3 convolution is bias-free and followed by BatchNorm2d
    """

    def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + list(channels)
        for i in range(len(channels) - 2):
            stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
                               stride=strides[i], padding=1, bias=(not batch_norm))]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i+1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
        self.stem = nn.Sequential(*stem)

        # Compute the number of patches.
        # The default img_size is a plain int, so only index into it when it is a sequence.
        stride_prod = int(np.prod(strides))
        side = img_size if np.isscalar(img_size) else img_size[0]
        self.num_patches = (side // stride_prod)**2

    def forward(self, x):
        """Run the stem and flatten the feature map to (B, num_patches, channels[-1])."""
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)


class VisionTransformerPredictor(nn.Module):
    """Transformer that predicts representations for masked (target) patches.

    Context tokens are projected from ``embed_dim`` to ``predictor_embed_dim``,
    concatenated with mask tokens that carry the positional embedding of the
    target patches, run through ``depth`` blocks, and the outputs at the mask
    token slots are projected back to ``embed_dim``.
    """
    def __init__(
        self,
        num_patches,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=6,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
        # Stochastic depth decay rule: linear ramp from 0 to drop_path_rate.
        drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, depth)]

        # Frozen positional table, filled with 2-D sin-cos values below.
        self.predictor_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, predictor_embed_dim),
            requires_grad=False)
        sincos_table = get_2d_sincos_pos_embed(
            self.predictor_pos_embed.shape[-1],
            int(num_patches**.5),
            cls_token=False)
        self.predictor_pos_embed.data.copy_(torch.from_numpy(sincos_table).float().unsqueeze(0))

        # One block per drop-path rate (the list has exactly ``depth`` entries).
        stack = []
        for rate in drop_path_rates:
            stack.append(Block(
                dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=rate, norm_layer=norm_layer))
        self.predictor_blocks = nn.ModuleList(stack)
        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)

        # Weight initialisation
        self.init_std = init_std
        trunc_normal_(self.mask_token, std=self.init_std)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Divide each block's output projections by sqrt(2 * (1-based layer index))."""
        for layer_number, block in enumerate(self.predictor_blocks, start=1):
            divisor = math.sqrt(2.0 * layer_number)
            block.attn.proj.weight.data.div_(divisor)
            block.mlp.fc2.weight.data.div_(divisor)

    def _init_weights(self, m):
        """Per-module initialiser used with ``self.apply``."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            # Linear and Conv2d share the same rule: truncated-normal weights, zero bias.
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, masks_x, masks):
        """Predict representations at the positions selected by ``masks``.

        Args:
            x: context tokens from the encoder (last dim ``embed_dim``).
            masks_x: mask or list of masks used to pick the positional
                embeddings of the context tokens.
            masks: mask or list of masks used to pick the positional
                embeddings of the tokens to predict.

        Returns:
            Predictions for the mask-token slots, projected to ``embed_dim``.
        """
        assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices'

        masks_x = masks_x if isinstance(masks_x, list) else [masks_x]
        masks = masks if isinstance(masks, list) else [masks]

        # -- Batch Size
        B = len(x) // len(masks_x)

        # -- map from encoder-dim to predictor-dim
        x = self.predictor_embed(x)

        # -- add positional embedding to x tokens
        x += apply_masks(self.predictor_pos_embed.repeat(B, 1, 1), masks_x)

        _, N_ctxt, _ = x.shape

        # -- build mask tokens tagged with the target positions
        target_pos = apply_masks(self.predictor_pos_embed.repeat(B, 1, 1), masks)
        target_pos = repeat_interleave_batch(target_pos, B, repeat=len(masks_x))
        pred_tokens = self.mask_token.repeat(target_pos.size(0), target_pos.size(1), 1)
        pred_tokens += target_pos

        # -- concat mask tokens to (repeated) context tokens
        x = torch.cat([x.repeat(len(masks), 1, 1), pred_tokens], dim=1)

        # -- fwd prop
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_norm(x)

        # -- return preds for mask tokens only
        return self.predictor_proj(x[:, N_ctxt:])


class VisionTransformer(nn.Module):
    """Vision Transformer encoder with optional positional-embedding dropping.

    On top of the patch embedding, the frozen sin-cos positional embedding and
    the transformer blocks, this variant can replace the positional embedding
    of a random subset of the kept (context) patches with a learned
    ``mask_pos_token``, and run a lightweight decoder that outputs, for every
    token, logits over all patch positions.
    """
    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        decoder_embed_dim=256,
        decoder_num_heads=2,
        decoder_depth=2,
        **kwargs
    ):
        """
        Args:
            img_size: sequence whose first element is the (square) input side.
            patch_size: side of each square patch.
            in_chans: number of input image channels.
            embed_dim: encoder token width.
            predictor_embed_dim, predictor_depth: accepted but not used here.
            depth: number of encoder blocks.
            num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate,
            attn_drop_rate, norm_layer: forwarded to the ``Block`` modules.
            drop_path_rate: end value of the linear stochastic-depth ramp.
            init_std: std of the truncated-normal weight initialisation.
            decoder_embed_dim, decoder_num_heads, decoder_depth: width, head
                count and depth of the position-prediction decoder.
            **kwargs: ignored.
        """
        super().__init__()

        # ---------------------------------------------------------------------- #
        # Encoder settings
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Set up the stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Set up the encoder blocks
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                  attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Patch settings
        self.patch_embed = PatchEmbed(img_size=img_size[0],
                                      patch_size=patch_size,
                                      in_chans=in_chans,
                                      embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Position settings: frozen table filled with 2-D sin-cos values
        # (no class-token slot).
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Mask token settings: learned stand-in added to tokens whose
        # positional embedding is dropped.
        self.mask_pos_token = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Decoder settings (a light weight decoder just for position prediction)
        # Require additional parameters:
        # - decoder_embed_dim
        # - decoder_num_heads
        # - decoder_depth
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        # NOTE(review): the decoder reuses the first ``decoder_depth`` entries of
        # the encoder drop-path schedule, so ``decoder_depth`` must not exceed
        # ``depth`` (otherwise ``dpr[i]`` raises IndexError).
        self.decoder_blocks = nn.ModuleList([
            Block(dim=decoder_embed_dim, num_heads=decoder_num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        # One logit per patch position.
        self.decoder_pred = nn.Linear(decoder_embed_dim, num_patches, bias=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Weight Initialization
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()
        # ----------------------------------------------------------------------

    def fix_init_weight(self):
        """Rescale encoder block projections by depth and initialise ``mask_pos_token``."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        # Also initialize the mask_pos_token (fixed std, independent of init_std)
        torch.nn.init.normal_(self.mask_pos_token, std=.02)

    def _init_weights(self, m):
        """Per-module initialiser used with ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def apply_pos_drop_mask(self, x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
        """Select the kept patches and drop the positional embedding of some of them.

        Args:
            x: patch tokens, shape (B, N, D), without positional embedding.
            pos_embed: positional table; it is repeated B times along dim 0.
            mask_pos_token: token added in place of a dropped positional
                embedding; it is repeated to (B, num_dropped, ...).
            mask: patch indices to keep, shape (B, N_m).
            pos_drop_ratio: fraction of the N_m kept patches whose positional
                embedding is replaced by ``mask_pos_token``.

        Returns:
            x: (B, N_m, D) tokens in the same order as ``mask``.
            pos_bool: (B, N_m) bool, True where the positional embedding was
                dropped.
            pos_labels: (B, int(N_m * pos_drop_ratio)) patch indices of the
                dropped positions, sorted ascending per row.
        """
        # ---------------------------------------------------------- #
        # Preparation
        B, _, D = x.shape  # Original shape of x
        device = x.device

        # Determine the number of positions to drop in the masked area
        N_m = mask.size(1)  # Number of patches to keep after the mask is applied
        num_pos_to_drop = int(N_m * pos_drop_ratio)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Drop the positions
        # Shuffle mask along the last dimension (independent permutation per row)
        random_tensor = torch.rand(B, N_m, device=device)
        shuffled_indices = random_tensor.argsort(dim=1)
        shuffled_mask = mask.gather(1, shuffled_indices)

        # Split the mask into two: one for keeping pos_embed, one for mask_pos_token
        mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
        mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

        # Apply the masks to x
        x_no_pos = apply_masks(x, [mask_no_pos])
        x_keep_pos = apply_masks(x, [mask_keep_pos])

        # Case 1: Replace pos_embed with mask_pos_token
        mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1).to(device)
        x_no_pos = x_no_pos + mask_pos_tokens

        # Case 2: Retain the pos_embed
        pos_embed = pos_embed.repeat(B, 1, 1).to(device)
        pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
        x_keep_pos = x_keep_pos + pos_embed_masked
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Get the variables that we want
        # Concatenate the results and undo the shuffle to restore the order of `mask`
        x = torch.cat([x_no_pos, x_keep_pos], dim=1)
        restored_indices = torch.argsort(shuffled_indices, dim=1)
        x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

        # Create a boolean mask in the shuffled order
        shuffled_pos_drop_mask = torch.zeros(B, N_m, dtype=torch.bool, device=device)
        shuffled_pos_drop_mask[:, :num_pos_to_drop] = True  # Mark the first num_pos_to_drop as True

        # Restore the order of the boolean mask to match the restored x
        pos_bool = shuffled_pos_drop_mask.gather(1, restored_indices)

        # pos_bool selects the tokens whose positional embeddings were dropped:
        #     x_ = x[pos_bool.unsqueeze(-1).expand(-1, -1, D)].reshape(B, -1, D)
        # mask_no_pos holds the original patch indices of those tokens and is
        # used as labels.
        # NOTE(review): the sorted labels line up with x[pos_bool] only if every
        # row of `mask` is itself sorted ascending -- confirm against the mask
        # collator.
        pos_labels = torch.sort(mask_no_pos.detach(), dim=1).values
        # ----------------------------------------------------------

        return x, pos_bool, pos_labels

    def forward_decoder(self, x):
        """Map encoder tokens to per-token logits over the patch positions."""
        x = self.decoder_embed(x)
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)  # from decoder_embed_dim to num_patches

        return x

    def forward(self, x, masks=None, pos_drop_ratio=0, use_decoder=False):
        """
        masks: a list of masks; for context there should only be one mask.
            each mask has shape (batch_size, no. patches to keep)

        pos_drop_ratio: the ratio to drop the positions from the context patches.
            A non-zero ratio requires exactly one mask.

        use_decoder: we apply a lightweight decoder for position prediction.
            if we use decoder, we return additional pos_logits, pos_bool, pos_labels

        Returns ``x`` alone, or ``(x, pos_logits, pos_bool, pos_labels)`` when
        ``use_decoder`` is set.

        Usage of the decoder outputs:
            pos_logits.shape = [B, N_m, N_patches]
            pos_labels.shape = [B, int(N_m * pos_drop_ratio)]
            Labels exist only for tokens whose positional embedding was
            dropped, so select those logits first:
                logits = pos_logits[pos_bool.unsqueeze(-1).expand(
                         -1, -1, N_patches)].reshape(
                         batch_size, -1, N_patches)
            Here N_patches is the number of classes for positions:
                loss_pos = F.cross_entropy(logits.permute(0, 2, 1), pos_labels)
        """
        # ---------------------------------------------------------- #
        # Handle the mask
        if masks is not None and not isinstance(masks, list):
            masks = [masks]
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Patchify x
        x = self.patch_embed(x)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Get the positional embeddings for each patch
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # When we do not drop the positional embeddings:
        if not pos_drop_ratio:
            x += pos_embed
            if masks is not None:
                x = apply_masks(x, masks)
            pos_bool, pos_labels = None, None

        # ---------------------------------------------------------- #
        # When we drop the positional embeddings:
        else:
            # A missing mask now fails this assertion instead of raising an
            # opaque TypeError from len(None).
            assert masks is not None and len(masks) == 1, 'Only one mask is needed for the context.'
            x, pos_bool, pos_labels = self.apply_pos_drop_mask(
                x, pos_embed, self.mask_pos_token, masks[0], pos_drop_ratio)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Forward and apply norm
        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Get the position prediction results
        if use_decoder:
            assert pos_drop_ratio, 'Only tested when pos are dropped.'
            pos_logits = self.forward_decoder(x)
            return x, pos_logits, pos_bool, pos_labels

        # If not use decoder, just classical IJEPA
        return x
        # ----------------------------------------------------------

    def interpolate_pos_encoding(self, x, pos_embed):
        """Resize ``pos_embed`` so it has one entry per token of ``x``.

        Returns ``pos_embed`` unchanged when the token counts already match.
        When both counts are perfect squares the table is treated as a square
        grid of patch positions (the layout this model builds) and resized
        bicubically. Any other combination falls back to the layout where the
        first entry is a class-token embedding that is kept as is while the
        remaining entries are resized.
        """
        n_tokens = x.shape[1]
        n_pos = pos_embed.shape[1]
        if n_tokens == n_pos:
            return pos_embed
        dim = x.shape[-1]

        src_side = int(round(math.sqrt(n_pos)))
        dst_side = int(round(math.sqrt(n_tokens)))
        if src_side * src_side == n_pos and dst_side * dst_side == n_tokens:
            # Patch-only table: every entry is a grid position, so nothing may
            # be split off as a class token.
            grid = pos_embed.reshape(1, src_side, src_side, dim).permute(0, 3, 1, 2)
            grid = nn.functional.interpolate(grid, size=(dst_side, dst_side), mode='bicubic')
            return grid.permute(0, 2, 3, 1).reshape(1, -1, dim)

        # Class-token layout: entry 0 is kept, the rest is resized.
        npatch = n_tokens - 1
        N = n_pos - 1
        class_emb = pos_embed[:, 0]
        patch_pos = pos_embed[:, 1:]
        patch_pos = nn.functional.interpolate(
            patch_pos.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=math.sqrt(npatch / N),
            mode='bicubic',
        )
        patch_pos = patch_pos.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), patch_pos), dim=1)


def vit_predictor(**kwargs):
    """Build a ``VisionTransformerPredictor`` with the shared ViT block settings.

    Fixes ``mlp_ratio=4``, ``qkv_bias=True`` and a LayerNorm with ``eps=1e-6``;
    all other constructor arguments come from ``kwargs``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformerPredictor(
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_tiny(patch_size=16, **kwargs):
    """Build the tiny encoder: width 192, 12 blocks, 3 heads."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_small(patch_size=16, **kwargs):
    """Build the small encoder: width 384, 12 blocks, 6 heads."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_base(patch_size=16, **kwargs):
    """Build the base encoder: width 768, 12 blocks, 12 heads."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_large(patch_size=16, **kwargs):
    """Build the large encoder: width 1024, 24 blocks, 16 heads."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_huge(patch_size=16, **kwargs):
    """Build the huge encoder: width 1280, 32 blocks, 16 heads."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


def vit_giant(patch_size=16, **kwargs):
    """Build the giant encoder: width 1408, 40 blocks, 16 heads, MLP ratio 48/11."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48/11,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs)


# Encoder embedding width of every ViT variant, keyed by the name of the
# factory function that builds it. The values mirror the ``embed_dim`` each
# factory passes to ``VisionTransformer``, so keep them in sync.
VIT_EMBED_DIMS = {
    'vit_tiny': 192,
    'vit_small': 384,
    'vit_base': 768,
    'vit_large': 1024,
    'vit_huge': 1280,
    'vit_giant': 1408,
}


In [155]:
"""
Notes on the edits:
    - Add a lightweight decoder with linear head for pos prediction
    - Add a careful pos dropping strategy based on current masking approach
"""

from src.masks.utils import apply_masks
from typing import List, Tuple
from functools import partial
from src.utils.tensors import (
    trunc_normal_,
    repeat_interleave_batch
)

import torch.nn as nn
import numpy as np
import math
import torch


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build the 2-D sin-cos positional table for a square grid.

    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim], or
        [1+grid_size*grid_size, embed_dim] with a leading all-zero row when
        cls_token is set
    """
    coords = np.arange(grid_size, dtype=float)
    # Height and width share the same coordinates; meshgrid takes width first.
    mesh = np.stack(np.meshgrid(coords, coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode the two coordinate planes of ``grid`` and join them feature-wise.

    ``grid[0]`` and ``grid[1]`` are each encoded with ``embed_dim // 2``
    features; the result has one row per grid point and ``embed_dim`` columns.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # One (H*W, D/2) table per coordinate plane, first plane first.
    per_plane = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[plane]) for plane in (0, 1)]
    return np.concatenate(per_plane, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build the 1-D sin-cos positional table for positions 0..grid_size-1.

    grid_size: int of the grid length
    return:
    pos_embed: [grid_size, embed_dim], or [1+grid_size, embed_dim] with a
        leading all-zero row when cls_token is set
    """
    positions = np.arange(grid_size, dtype=float)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sin-cos encode a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; it is flattened to size (M,)
    out: (M, D) -- first D/2 columns are sines, last D/2 are cosines
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    # Frequencies 1 / 10000^(i / (D/2)) for i = 0 .. D/2 - 1
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.)
    freqs = 1. / 10000**exponents   # (D/2,)

    flat_pos = pos.reshape(-1)   # (M,)
    angles = flat_pos[:, None] * freqs[None, :]   # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    Args:
        x: input tensor; dim 0 indexes the samples.
        drop_prob: probability of dropping a sample.
        training: dropping only happens when this is True.

    Returns:
        ``x`` itself when ``drop_prob`` is 0 or ``training`` is False.
        Otherwise a tensor in which each sample is either zeroed or scaled by
        ``1 / (1 - drop_prob)``. With ``drop_prob >= 1`` every sample is
        dropped and the result is all zeros.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.:
        # Everything is dropped. Dividing by keep_prob == 0 would produce
        # inf * 0 = NaN, so return zeros directly.
        return torch.zeros_like(x)
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 keeps the sample, 0 drops it
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).

    Args:
        drop_prob: probability of dropping a sample; ``None`` (the default)
            disables dropping.
    """
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Treat the ``None`` default as "never drop"; passing None through
        # would fail on ``1 - drop_prob`` in training mode.
        drop_prob = 0. if self.drop_prob is None else self.drop_prob
        return drop_path(x, drop_prob, self.training)


class MLP(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Falsy sizes (None or 0) fall back to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # The same dropout module is applied after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head self-attention that also returns its attention weights."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        # A falsy qk_scale falls back to 1 / sqrt(per-head width).
        self.scale = qk_scale or per_head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Return ``(output, attention)`` for a (batch, tokens, channels) input.

        ``attention`` is the post-dropout weight tensor of shape
        (batch, heads, tokens, tokens).
        """
        batch, tokens, channels = x.shape
        heads = self.num_heads
        # One projection yields queries, keys and values; split them per head.
        packed = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        weights = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed)), weights


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        # Stochastic depth only when a positive rate is requested.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)

    def forward(self, x, return_attention=False):
        """Run the block; with ``return_attention`` return only the attention weights."""
        attn_out, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            return attn_weights
        x = x + self.drop_path(attn_out)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping patches and embed each patch.

    A single convolution whose kernel size equals its stride (``patch_size``)
    turns every patch into one ``embed_dim``-wide token.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        # Token count for a square ``img_size`` x ``img_size`` input.
        self.num_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Turn a 4-D image batch into a (batch, tokens, embed_dim) sequence."""
        # The unpacking doubles as a check that the input has exactly 4 dims.
        _batch, _chans, _height, _width = x.shape
        tokens = self.proj(x)
        tokens = tokens.flatten(2)
        return tokens.transpose(1, 2)


class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models

    The stem is ``len(channels) - 1`` blocks of 3x3 conv (+ optional BatchNorm)
    + ReLU, followed by a final 1x1 conv that produces ``channels[-1]``
    features per token.

    Args:
        channels: output channel count of every conv layer, in order.
        strides: stride of every conv layer, in order (same length as
            ``channels``).
        img_size: input side length, either as an int or as a sequence whose
            first element is the side length.
        in_chans: number of input image channels.
        batch_norm: insert BatchNorm2d after each 3x3 conv (the conv bias is
            disabled in that case).
    """

    def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + list(channels)
        for i in range(len(channels) - 2):
            stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
                               stride=strides[i], padding=1, bias=(not batch_norm))]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i+1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
        self.stem = nn.Sequential(*stem)

        # Compute the number of patches.
        # The default ``img_size`` is a plain int, so it must not be indexed
        # unconditionally; sequences keep their original "first element" meaning.
        side = img_size if isinstance(img_size, int) else img_size[0]
        stride_prod = int(np.prod(strides))
        self.num_patches = (side // stride_prod)**2

    def forward(self, x):
        """Run the stem and flatten the spatial grid into a token sequence."""
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)


class VisionTransformerPredictor(nn.Module):
    """Transformer that predicts representations for masked (target) patches.

    Context tokens are projected from ``embed_dim`` to ``predictor_embed_dim``,
    concatenated with mask tokens that carry the positional embedding of the
    target patches, run through ``depth`` blocks, and the outputs at the mask
    token slots are projected back to ``embed_dim``.
    """
    def __init__(
        self,
        num_patches,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=6,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        **kwargs
    ):
        super().__init__()
        self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
        # Stochastic depth decay rule: linear ramp from 0 to drop_path_rate.
        drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, depth)]

        # Frozen positional table, filled with 2-D sin-cos values below.
        self.predictor_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, predictor_embed_dim),
            requires_grad=False)
        sincos_table = get_2d_sincos_pos_embed(
            self.predictor_pos_embed.shape[-1],
            int(num_patches**.5),
            cls_token=False)
        self.predictor_pos_embed.data.copy_(torch.from_numpy(sincos_table).float().unsqueeze(0))

        # One block per drop-path rate (the list has exactly ``depth`` entries).
        stack = []
        for rate in drop_path_rates:
            stack.append(Block(
                dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=rate, norm_layer=norm_layer))
        self.predictor_blocks = nn.ModuleList(stack)
        self.predictor_norm = norm_layer(predictor_embed_dim)
        self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)

        # Weight initialisation
        self.init_std = init_std
        trunc_normal_(self.mask_token, std=self.init_std)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Divide each block's output projections by sqrt(2 * (1-based layer index))."""
        for layer_number, block in enumerate(self.predictor_blocks, start=1):
            divisor = math.sqrt(2.0 * layer_number)
            block.attn.proj.weight.data.div_(divisor)
            block.mlp.fc2.weight.data.div_(divisor)

    def _init_weights(self, m):
        """Per-module initialiser used with ``self.apply``."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            # Linear and Conv2d share the same rule: truncated-normal weights, zero bias.
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, masks_x, masks):
        """Predict representations at the positions selected by ``masks``.

        Args:
            x: context tokens from the encoder (last dim ``embed_dim``).
            masks_x: mask or list of masks used to pick the positional
                embeddings of the context tokens.
            masks: mask or list of masks used to pick the positional
                embeddings of the tokens to predict.

        Returns:
            Predictions for the mask-token slots, projected to ``embed_dim``.
        """
        assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices'

        masks_x = masks_x if isinstance(masks_x, list) else [masks_x]
        masks = masks if isinstance(masks, list) else [masks]

        # -- Batch Size
        B = len(x) // len(masks_x)

        # -- map from encoder-dim to predictor-dim
        x = self.predictor_embed(x)

        # -- add positional embedding to x tokens
        x += apply_masks(self.predictor_pos_embed.repeat(B, 1, 1), masks_x)

        _, N_ctxt, _ = x.shape

        # -- build mask tokens tagged with the target positions
        target_pos = apply_masks(self.predictor_pos_embed.repeat(B, 1, 1), masks)
        target_pos = repeat_interleave_batch(target_pos, B, repeat=len(masks_x))
        pred_tokens = self.mask_token.repeat(target_pos.size(0), target_pos.size(1), 1)
        pred_tokens += target_pos

        # -- concat mask tokens to (repeated) context tokens
        x = torch.cat([x.repeat(len(masks), 1, 1), pred_tokens], dim=1)

        # -- fwd prop
        for blk in self.predictor_blocks:
            x = blk(x)
        x = self.predictor_norm(x)

        # -- return preds for mask tokens only
        return self.predictor_proj(x[:, N_ctxt:])


class FeatAvgPool(nn.Module):
    """Average a token sequence over its sequence axis.

    Maps a (batch, seq_len, dims) tensor to (batch, dims).
    """
    def __init__(self):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        # bs, seq_len, dims = x.shape
        x = x.permute((0, 2, 1))  # pooling acts on the last axis, so move seq_len there
        # Squeeze only the pooled axis; a bare squeeze() would also drop the
        # batch axis when the batch holds a single sample.
        return self.avg_pool(x).squeeze(-1)


class VisionTransformer(nn.Module):
    """
    Vision Transformer encoder with optional position dropping.

    Besides the usual patch embedding + transformer blocks, the model can:
    - replace the positional embedding of a random subset of the context
      tokens with a learned `mask_pos_token` (`pos_drop_ratio` in forward),
    - run a light-weight decoder that outputs, for every token, logits over
      all patch positions (`use_pos_predictor` in forward),
    - return pooled features for evaluation (`out_feat_keys` in forward).
    """
    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        decoder_embed_dim=256,
        decoder_num_heads=2,
        decoder_depth=2,
        **kwargs
    ):
        """
        img_size: sequence whose first element is the input resolution.
        patch_size, in_chans, embed_dim: forwarded to PatchEmbed.
        depth, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate,
            attn_drop_rate, norm_layer: encoder block settings.
        drop_path_rate: upper end of the linear stochastic-depth schedule.
        init_std: std of the truncated-normal weight initialisation.
        decoder_embed_dim, decoder_num_heads, decoder_depth: settings of the
            position-prediction decoder.
        predictor_embed_dim, predictor_depth, **kwargs: accepted but not used
            by this class.
        """
        super().__init__()

        # ---------------------------------------------------------------------- #
        # Encoder settings
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Set up the stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Set up the encoder blocks
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                  attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.avg_pool = FeatAvgPool()
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Patch settings
        self.patch_embed = PatchEmbed(img_size=img_size[0],
                                      patch_size=patch_size,
                                      in_chans=in_chans,
                                      embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Position settings (fixed 2D sin-cos table, no class token)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Mask token settings: learned stand-in for a dropped positional embedding
        self.mask_pos_token = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Decoder settings (a light weight decoder just for position prediction)
        # Require additional parameters:
        # - decoder_embed_dim
        # - decoder_num_heads
        # - decoder_depth
        # The decoder reuses the encoder drop-path schedule block by block.
        # The index is clamped so a decoder deeper than the encoder does not
        # raise an IndexError; values for i < depth are unchanged.
        decoder_dpr = [dpr[min(i, depth - 1)] if depth > 0 else 0.0
                       for i in range(decoder_depth)]
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.decoder_blocks = nn.ModuleList([
            Block(dim=decoder_embed_dim, num_heads=decoder_num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=decoder_dpr[i],
                  norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, num_patches, bias=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Weight Initialization
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()
        # ----------------------------------------------------------------------

    def fix_init_weight(self):
        """
        Rescale the output projections of every encoder block by
        1 / sqrt(2 * layer_id) (layer_id starting at 1) and initialise
        mask_pos_token. Decoder blocks are not rescaled.
        """
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        # $$$$ Also initialize the mask_pos_token
        torch.nn.init.normal_(self.mask_pos_token, std=.02)

    def _init_weights(self, m):
        """
        Per-module initialiser used with self.apply: truncated normal for
        Linear / Conv2d weights, zeros for biases, identity for LayerNorm.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def interpolate_pos_encoding(self, x, pos_embed):
        """
        Resize pos_embed to the number of tokens in x.

        x: patch tokens, shape (B, npatch, D).
        pos_embed: positional table, shape (1, N, D), WITHOUT a class token
            (it is built with cls_token=False in __init__).

        Returns pos_embed itself when npatch == N, otherwise a bicubically
        resized table of shape (1, npatch, D). Both grids are assumed to be
        square.
        """
        npatch = x.shape[1]
        N = pos_embed.shape[1]
        if npatch == N:
            return pos_embed
        dim = x.shape[-1]
        old_side = int(round(math.sqrt(N)))
        new_side = int(round(math.sqrt(npatch)))
        # An explicit target size avoids the off-by-one that a floating point
        # scale_factor can produce after flooring.
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, old_side, old_side, dim).permute(0, 3, 1, 2),
            size=(new_side, new_side),
            mode='bicubic',
        )
        return pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)

    def apply_pos_drop_mask(self, x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
        """
        Helper used in forward to drop positions.

        Keeps the tokens selected by `mask`; a random `pos_drop_ratio`
        fraction of them receives `mask_pos_token` instead of its positional
        embedding.

        x: patch tokens before any positional embedding, shape (B, N, D).
        pos_embed: positional table, repeated over the batch here.
        mask_pos_token: learned token, repeated over batch and dropped slots.
        mask: (B, N_m) patch indices to keep.
        pos_drop_ratio: fraction of the N_m kept tokens that lose their position.

        Returns (x, pos_bool, pos_labels):
        - x: (B, N_m, D), in the same token order as `mask`.
        - pos_bool: (B, N_m), True where the position was dropped.
        - pos_labels: (B, int(N_m * pos_drop_ratio)), the patch indices of
          the dropped tokens, sorted ascending.

        NOTE(review): apply_masks is defined elsewhere; presumably it gathers
        tokens at the given indices -- confirm. The rows selected by pos_bool
        follow the order of `mask`, so they line up with the sorted
        pos_labels only if every row of `mask` is already ascending -- verify
        against the mask collator.
        """
        # ---------------------------------------------------------- #
        # Preparation
        B, _, D = x.shape  # Original shape of x
        device = x.device

        # Determine the number of positions to drop in the masked area
        N_m = mask.size(1)  # Number of patches to keep after the mask is applied
        num_pos_to_drop = int(N_m * pos_drop_ratio)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Drop the positions
        # Shuffle mask along the last dimension (independent permutation per row)
        random_tensor = torch.rand(B, N_m, device=device)
        shuffled_indices = random_tensor.argsort(dim=1)
        shuffled_mask = mask.gather(1, shuffled_indices)

        # Split the mask into two: one for keeping pos_embed, one for mask_pos_token
        mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
        mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

        # Apply the masks to x
        x_no_pos = apply_masks(x, [mask_no_pos])
        x_keep_pos = apply_masks(x, [mask_keep_pos])

        # Case 1: Replace pos_embed with mask_pos_token
        mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1).to(device)
        x_no_pos = x_no_pos + mask_pos_tokens

        # Case 2: Retain the pos_embed
        pos_embed = pos_embed.repeat(B, 1, 1).to(device)
        pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
        x_keep_pos = x_keep_pos + pos_embed_masked
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Get the variables that we want
        # Concatenate the results and undo the shuffle to restore the order of `mask`
        x = torch.cat([x_no_pos, x_keep_pos], dim=1)
        restored_indices = torch.argsort(shuffled_indices, dim=1)
        x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

        # Create a boolean mask in the shuffled order
        shuffled_pos_drop_mask = torch.zeros(B, N_m, dtype=torch.bool, device=device)
        shuffled_pos_drop_mask[:, :num_pos_to_drop] = True  # Mark the first num_pos_to_drop as True

        # Restore the order of the boolean mask to match the restored x
        pos_bool = shuffled_pos_drop_mask.gather(1, restored_indices)

        # pos_bool is applied on x to get the tokens whose positional
        # embeddings were dropped, e.g.
        # x_ = x[pos_bool.unsqueeze(-1).expand(-1, -1, D)].reshape(B, -1, D)
        # In contrast, mask_no_pos contains the original patch indices
        # and is used as labels.

        # Notice that the labels are sorted in ascending patch order
        pos_labels = torch.sort(mask_no_pos.detach(), dim=1).values

        # x.shape = (B, N_m, D)
        # pos_bool.shape = (B, N_m)
        # pos_labels.shape = (B, int(N_m * pos_drop_ratio))
        # ----------------------------------------------------------

        return x, pos_bool, pos_labels

    def forward_decoder(self, x):
        """
        Used in forward to get the position logits: maps every token from
        embed_dim to one logit per patch position (num_patches).
        """

        x = self.decoder_embed(x)
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)  # from decoder_embed_dim to num_patches

        return x

    def forward(self, x,
                masks=None,
                pos_drop_ratio: float=0,
                use_pos_predictor: bool=False,
                out_feat_keys: List[str]=None):
        """
        masks: a list of masks; for context there should only be one mask.
            each mask has shape (batch_size, no. patches to keep)

        pos_drop_ratio: the ratio to drop the positions from the context patches

        use_pos_predictor: apply the lightweight decoder for position
            prediction. If set, additionally return pos_logits, pos_bool,
            pos_labels. Requires a non-zero pos_drop_ratio.

        out_feat_keys: if given, skip everything else and return the list
            produced by get_intermediate_features.

        Returns x, or (x, pos_logits, pos_bool, pos_labels) when
        use_pos_predictor is set.
        """
        # ---------------------------------------------------------- #
        # Get features for the evaluation; see methods at the end
        if out_feat_keys:
            return self.get_intermediate_features(x, masks, out_feat_keys)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Handle the mask
        if masks is not None and not isinstance(masks, list):
            masks = [masks]
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Patchify x
        x = self.patch_embed(x)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Get the positional embeddings for each patch
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # When we do not drop the positional embeddings:
        if not pos_drop_ratio:
            x = x + pos_embed
            if masks is not None:
                x = apply_masks(x, masks)
            pos_bool, pos_labels = None, None

        # ---------------------------------------------------------- #
        # When we drop the positional embeddings:
        else:
            # Checking for None first gives the assertion message instead of
            # a TypeError from len(None)
            assert masks is not None and len(masks) == 1, 'Only one mask is needed for the context.'
            x, pos_bool, pos_labels = self.apply_pos_drop_mask(
                x, pos_embed, self.mask_pos_token, masks[0], pos_drop_ratio)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Forward and apply norm
        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Get the position prediction results
        if use_pos_predictor:
            assert pos_drop_ratio, 'Only tested when pos are dropped.'

            # Usage for the logits
            # logits.shape = [B, N_m, N_patches]
            # pos_labels.shape = [B, int(N_m * pos_drop_ratio)]
            # We don't predict the labels for those with pos_emb
            # so we should do:
            # logits = logits[pos_bool.unsqueeze(-1).expand(
            #          -1, -1, N_patches)].reshape(
            #          batch_size, -1, N_patches)
            # Here N_patches essentially is N_classes for positions
            # loss_pos = F.cross_entropy(logits.permute(0, 2, 1), labels)
            pos_logits = self.forward_decoder(x)
            return x, pos_logits, pos_bool, pos_labels

        # If the position predictor is not used, just classical IJEPA
        return x
        # ----------------------------------------------------------

    # ======================================================================= #
    # Helper functions to get features (will be called with no_grad for eval)
    def prepare_tokens(self, x: torch.Tensor, masks=None) -> torch.Tensor:
        """
        Patchify x, add the positional embedding and, if masks are given,
        apply them. Positions are never dropped here.
        """

        # ---------------------------------------------------------- #
        # We shouldn't use mask for eval, but let's just keep it here
        if masks is not None and not isinstance(masks, list):
            masks = [masks]
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Patchify x
        x = self.patch_embed(x)
        # ----------------------------------------------------------

        # ---------------------------------------------------------- #
        # Get the positional embeddings for each patch
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)

        # In evaluation we will always add the pos_emb
        x = x + pos_embed

        # We shouldn't use mask for eval, but let's just keep it here
        if masks is not None:
            x = apply_masks(x, masks)
        # ----------------------------------------------------------

        return x

    def get_intermediate_features(self, x, masks=None,
            names: List[str]=None) -> List[torch.Tensor]:
        """
        Given a list of feature names, return one output per recognised name,
        in the order requested. Unrecognised names produce no output.

        To align with ijepa, the available features are:
        - lastpool: average-pooled output of the last block (after norm)
        - concatpool4: average-pooled outputs of the last four blocks
          (the last one after norm), concatenated along the feature axis
        """
        if names is None:
            names = []

        # Prepare tokens (patchify and add positional encoding)
        x = self.prepare_tokens(x, masks)

        # Determine the number of layers to keep based on requested features
        keep_last_n = 1 if 'lastpool' in names else 0
        if any(name.startswith('concatpool') for name in names):
            keep_last_n = max(keep_last_n, 4)  # Keep last 4 for concatPOOL4

        # Buffer to store outputs of the required last N layers
        interms_buffer = collections.deque(maxlen=keep_last_n)

        # Forward propagation
        for i, blk in enumerate(self.blocks):
            x = blk(x)

            # Append to buffer if in the last N layers
            if i >= len(self.blocks) - keep_last_n:
                interms_buffer.append(x)

        if self.norm is not None:
            x = self.norm(x)
            # The buffer is empty when no known feature was requested;
            # writing to [-1] would then raise an IndexError.
            if interms_buffer:
                interms_buffer[-1] = x

        output = []
        for name in names:
            if name == 'lastpool':
                output.append(self.avg_pool(interms_buffer[-1]))
            elif name.startswith('concatpool'):
                concat_features = torch.cat([self.avg_pool(layer) for layer in interms_buffer], dim=-1)
                output.append(concat_features)

        return output


def vit_predictor(**kwargs):
    """Build a VisionTransformerPredictor with the shared block defaults; extra settings come from kwargs."""
    return VisionTransformerPredictor(
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_tiny(patch_size=16, **kwargs):
    """Build the tiny encoder: width 192, 12 blocks, 3 heads."""
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_small(patch_size=16, **kwargs):
    """Build the small encoder: width 384, 12 blocks, 6 heads."""
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_base(patch_size=16, **kwargs):
    """Build the base encoder: width 768, 12 blocks, 12 heads."""
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_large(patch_size=16, **kwargs):
    """Build the large encoder: width 1024, 24 blocks, 16 heads."""
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_huge(patch_size=16, **kwargs):
    """Build the huge encoder: width 1280, 32 blocks, 16 heads."""
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


def vit_giant(patch_size=16, **kwargs):
    """Build the giant encoder: width 1408, 40 blocks, 16 heads, MLP ratio 48/11."""
    return VisionTransformer(
        patch_size=patch_size,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48/11,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)


# Embedding width (`embed_dim`) used by each encoder factory function above,
# keyed by the factory function's name.
VIT_EMBED_DIMS = {
    'vit_tiny': 192,
    'vit_small': 384,
    'vit_base': 768,
    'vit_large': 1024,
    'vit_huge': 1280,
    'vit_giant': 1408,
}


In [160]:
# Display the module tree of the encoder (built with vit_tiny in another cell)
encoder

VisionTransformer(
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
  (avg_pool): FeatAvgPool(
    (avg_pool): AdaptiveAvgPool1d(output_size=1)
  )
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(14, 14), stride=(14, 1

In [156]:
import random
import torch.nn.functional as F

# Seed every RNG right before building the model so the initial weights
# are reproducible
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# VisionTransformer takes the input resolution as `img_size` (a sequence whose
# first element is used). A `crop_size` keyword is swallowed by **kwargs and
# silently ignored, which only works while crop_size equals the default 224.
# The predictor-only settings (pred_depth, pred_emb_dim) are not encoder
# arguments and would be ignored the same way, so they are not passed.
encoder = vit_tiny(patch_size=patch_size,
    img_size=[crop_size],
    decoder_embed_dim=256,
    decoder_num_heads=2,
    decoder_depth=2).to(device)

# Number of patch positions = number of classes for position prediction
num_patches = encoder.patch_embed.num_patches

In [158]:
# Encoder pass with 40% of the context positions dropped and the position
# decoder enabled
result = encoder(imgs, masks_enc, pos_drop_ratio=0.4, use_pos_predictor=True)
x, pos_logits, pos_bool, pos_labels = result

# Keep only the logits of tokens whose positional embedding was dropped.
# Indexing with the (B, N_m) boolean mask selects whole rows, which yields the
# same values as the element-wise expanded mask.
pos_logits = pos_logits[pos_bool].reshape(batch_size, -1, num_patches)

# cross_entropy expects the class axis (patch positions) second
pos_loss = F.cross_entropy(pos_logits.transpose(1, 2), pos_labels)
pos_loss

tensor(5.6027, grad_fn=<NllLoss2DBackward0>)

In [145]:
# `forward` names this flag `use_pos_predictor`; the former `use_decoder`
# keyword is not accepted and raises a TypeError.
result = encoder(imgs, masks_enc, pos_drop_ratio=0.4, use_pos_predictor=True)

x, pos_logits, pos_bool, pos_labels = result

# Keep only the logits of tokens whose positional embedding was dropped
pos_logits = pos_logits[pos_bool.unsqueeze(-1).expand(
    -1, -1, num_patches)].reshape(
        batch_size, -1, num_patches)

# cross_entropy expects the class axis (patch positions) second
pos_loss = F.cross_entropy(pos_logits.permute(0, 2, 1), pos_labels)
pos_loss

tensor(5.6027, grad_fn=<NllLoss2DBackward0>)

In [159]:
# Positions are dropped but the position predictor is off, so `forward`
# returns only the encoded tokens (no logits / labels to unpack).
result = encoder(imgs, masks_enc, pos_drop_ratio=0.4, use_pos_predictor=False)

# Spot-check a single activation: first sample, first token, first feature
print(result[0][0][0])

tensor(-2.0316, grad_fn=<SelectBackward0>)


In [146]:
# `forward` names this flag `use_pos_predictor`; the former `use_decoder`
# keyword is not accepted and raises a TypeError.
# Positions are dropped but the predictor is off, so only the encoded tokens
# are returned.
result = encoder(imgs, masks_enc, pos_drop_ratio=0.4, use_pos_predictor=False)

# Spot-check a single activation: first sample, first token, first feature
print(result[0][0][0])

tensor(-2.0316, grad_fn=<SelectBackward0>)


In [147]:
# `forward` names this flag `use_pos_predictor`; the former `use_decoder`
# keyword is not accepted and raises a TypeError.
# No position dropping: context mask applied, every kept token has its
# positional embedding.
result = encoder(imgs, masks_enc, pos_drop_ratio=0, use_pos_predictor=False)

# Spot-check a single activation: first sample, first token, first feature
print(result[0][0][0])

tensor(-2.3050, grad_fn=<SelectBackward0>)


In [148]:
# `forward` names this flag `use_pos_predictor`; the former `use_decoder`
# keyword is not accepted and raises a TypeError.
# No mask and no position dropping: all patches are encoded.
result = encoder(imgs, None, pos_drop_ratio=0, use_pos_predictor=False)

# Spot-check a single activation: first sample, first token, first feature
print(result[0][0][0])

tensor(-2.3018, grad_fn=<SelectBackward0>)


In [103]:
# Expanded drop flag of sample 0, token 4: one identical entry per patch class
pos_bool[0, 4].expand(num_patches)

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [104]:
# Shape of the position logits before selecting the position-dropped tokens
logits.shape

torch.Size([64, 73, 256])

In [105]:
# Keep only the logits of tokens whose position was dropped. The (B, N_m)
# boolean mask selects whole rows, equivalent to the expanded element mask.
# Overwrites `logits`, so re-running this cell without refreshing it fails.
logits = logits[pos_bool].reshape(batch_size, -1, num_patches)

In [106]:
# Shape of the position logits after keeping only the position-dropped tokens
logits.shape

torch.Size([64, 29, 256])

In [107]:
# Shape of the position labels; should match the first two axes of the selected logits
labels.shape

torch.Size([64, 29])

In [84]:
# Unpack the encoder output: tokens, position logits, dropped-position flags
# and position labels
x, logits, pos_bool, labels = result

# Row-wise boolean indexing keeps the logits of position-dropped tokens only
logits = logits[pos_bool].reshape(batch_size, -1, num_patches)

# cross_entropy expects the class axis (patch positions) second
loss_pos = F.cross_entropy(logits.transpose(1, 2), labels)

In [79]:
# Display the position-prediction cross-entropy loss
loss_pos

tensor(5.6027, grad_fn=<NllLoss2DBackward0>)

In [80]:
# Shape of the dropped-position flags: one boolean per kept context token
pos_bool.shape

torch.Size([64, 73])

In [81]:
# Shape of the position labels: one patch index per position-dropped token
labels.shape

torch.Size([64, 29])