In [1]:
from src.datasets.imagenet1k import make_imagenet1k
import yaml
import torch
import pprint
from src.transforms import make_transforms
from src.masks.multiblock import MaskCollator as MBMaskCollator
from src.utils.distributed import (
    init_distributed,
    AllReduce
)

from src.helper import (
    load_checkpoint,
    init_model,
    init_opt)

In [2]:
fname = './configs/in100_vitt_ep1.yaml'

In [3]:
with open(fname, 'r') as y_file:
    args = yaml.load(y_file, Loader=yaml.FullLoader)
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(args)

{   'data': {   'batch_size': 64,
                'color_jitter_strength': 0.0,
                'crop_scale': [0.3, 1.0],
                'crop_size': 224,
                'image_folder': 'data/imagenet_100/',
                'num_workers': 10,
                'pin_mem': True,
                'root_path': '/localscratch/hsun409/',
                'use_color_distortion': False,
                'use_gaussian_blur': False,
                'use_horizontal_flip': False},
    'logging': {   'folder': '/localscratch/hsun409/logs/ijepa/test/',
                   'write_tag': 'jepa'},
    'mask': {   'allow_overlap': False,
                'aspect_ratio': [0.75, 1.5],
                'enc_mask_scale': [0.85, 1.0],
                'min_keep': 10,
                'num_enc_masks': 1,
                'num_pred_masks': 4,
                'patch_size': 14,
                'pred_mask_scale': [0.15, 0.2]},
    'meta': {   'copy_data': False,
                'load_checkpoint': False,
                'mo

In [4]:
use_bfloat16 = args['meta']['use_bfloat16']
model_name = args['meta']['model_name']
load_model = args['meta']['load_checkpoint']
r_file = args['meta']['read_checkpoint']
copy_data = args['meta']['copy_data']
pred_depth = args['meta']['pred_depth']
pred_emb_dim = args['meta']['pred_emb_dim']


# -- DATA
use_gaussian_blur = args['data']['use_gaussian_blur']
use_horizontal_flip = args['data']['use_horizontal_flip']
use_color_distortion = args['data']['use_color_distortion']
color_jitter = args['data']['color_jitter_strength']
# --
batch_size = args['data']['batch_size']
pin_mem = args['data']['pin_mem']
num_workers = args['data']['num_workers']
root_path = args['data']['root_path']
image_folder = args['data']['image_folder']
crop_size = args['data']['crop_size']
crop_scale = args['data']['crop_scale']
# --

# -- MASK
allow_overlap = args['mask']['allow_overlap']  # whether to allow overlap b/w context and target blocks
patch_size = args['mask']['patch_size']  # patch-size for model training
num_enc_masks = args['mask']['num_enc_masks']  # number of context blocks
min_keep = args['mask']['min_keep']  # min number of patches in context block
enc_mask_scale = args['mask']['enc_mask_scale']  # scale of context blocks
num_pred_masks = args['mask']['num_pred_masks']  # number of target blocks
pred_mask_scale = args['mask']['pred_mask_scale']  # scale of target blocks
aspect_ratio = args['mask']['aspect_ratio']  # aspect ratio of target blocks
# --

# -- OPTIMIZATION
ema = args['optimization']['ema']
ipe_scale = args['optimization']['ipe_scale']  # scheduler scale factor (def: 1.0)
wd = float(args['optimization']['weight_decay'])
final_wd = float(args['optimization']['final_weight_decay'])
num_epochs = args['optimization']['epochs']
warmup = args['optimization']['warmup']
start_lr = args['optimization']['start_lr']
lr = args['optimization']['lr']
final_lr = args['optimization']['final_lr']

# -- LOGGING
folder = args['logging']['folder']
tag = args['logging']['write_tag']

In [5]:
if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)

# device = 'cpu'

In [6]:
device

device(type='cuda', index=0)

In [7]:
transform = make_transforms(
    crop_size=crop_size,
    crop_scale=crop_scale,
    gaussian_blur=use_gaussian_blur,
    horizontal_flip=use_horizontal_flip,
    color_distortion=use_color_distortion,
    color_jitter=color_jitter)

INFO:root:making imagenet data transforms


In [8]:
pprint.pprint(transform)

Compose(
    RandomResizedCrop(size=(224, 224), scale=(0.3, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=warn)
    ToTensor()
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
)


In [9]:
mask_collator = MBMaskCollator(
    input_size=crop_size,
    patch_size=patch_size,
    pred_mask_scale=pred_mask_scale,
    enc_mask_scale=enc_mask_scale,
    aspect_ratio=aspect_ratio,
    nenc=num_enc_masks,
    npred=num_pred_masks,
    allow_overlap=allow_overlap,
    min_keep=min_keep)

In [10]:
world_size, rank = init_distributed()

INFO:root:SLURM vars not set (distributed training not available)


In [11]:
image_folder

'data/imagenet_100/'

In [12]:
_, unsupervised_loader, unsupervised_sampler = make_imagenet1k(
        transform=transform,
        batch_size=batch_size,
        collator=mask_collator,
        pin_mem=pin_mem,
        training=True,
        num_workers=num_workers,
        world_size=world_size,
        rank=rank,
        root_path=root_path,
        image_folder=image_folder,
        copy_data=copy_data,
        drop_last=True)

INFO:root:data-path /localscratch/hsun409/data/imagenet_100/train/
INFO:root:Initialized ImageNet
INFO:root:ImageNet dataset created
INFO:root:ImageNet unsupervised data loader created


In [13]:
for itr, (udata, masks_enc, masks_pred) in enumerate(unsupervised_loader):

    def load_imgs():
        # -- unsupervised imgs
        imgs = udata[0].to(device, non_blocking=True)
        masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
        masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]
        return (imgs, masks_1, masks_2)
    imgs, masks_enc, masks_pred = load_imgs()
    break

In [14]:
# masks = masks_enc 
# mask = masks[0]

In [15]:
imgs.shape

torch.Size([64, 3, 224, 224])

In [236]:
import math
from functools import partial
import numpy as np

import torch
import torch.nn as nn

from src.utils.tensors import (
    trunc_normal_,
    repeat_interleave_batch
)
from src.masks.utils import apply_masks


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=float)
    grid_w = np.arange(grid_size, dtype=float)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid length
    return:
    pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid = np.arange(grid_size, dtype=float)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega   # (D/2,)

    pos = pos.reshape(-1)   # (M,)
    out = np.einsum('m,d->md', pos, omega)   # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models
    """

    def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + channels
        for i in range(len(channels) - 2):
            stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
                               stride=strides[i], padding=1, bias=(not batch_norm))]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i+1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
        self.stem = nn.Sequential(*stem)

        # Comptute the number of patches
        stride_prod = int(np.prod(strides))
        self.num_patches = (img_size[0] // stride_prod)**2

    def forward(self, x):
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)



class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        decoder_embed_dim=256,
        decoder_num_heads=2,
        decoder_depth=2,
        **kwargs
    ):
        super().__init__()

        # ---------------------------------------------------------------------- #
        # Encoder settings
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Set up the stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Set up the encoder blocks
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                  attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Patch settings
        self.patch_embed = PatchEmbed(img_size=img_size[0],
                                      patch_size=patch_size,
                                      in_chans=in_chans,
                                      embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Position settings
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Mask token settings
        self.mask_pos_token = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Decoder settings (a light weight decoder just for position prediction)
        # Require additional parameters:
        # - decoder_emebed_dim
        # - decoder_num_heads
        # - decoder_depth
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.decoder_blocks = nn.ModuleList([
            Block(dim=decoder_embed_dim, num_heads=decoder_num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, num_patches, bias=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Weight Initialiazation
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()
        # ----------------------------------------------------------------------

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        # $$$$ Also initialize the mask_pos_token
        torch.nn.init.normal_(self.mask_pos_token, std=.02)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def apply_pos_drop_mask(self, x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
        B, N, D = x.shape
        device = x.device

        # Determine the number of positions to drop in the masked area
        num_pos_to_drop = int(mask.size(1) * pos_drop_ratio)

        # Shuffle mask along the last dimension
        random_tensor = torch.rand(B, mask.size(1), device=device)
        shuffled_indices = random_tensor.argsort(dim=1)
        shuffled_mask = mask.gather(1, shuffled_indices)

        # Split the mask into two: one for keeping pos_embed, one for mask_pos_token
        mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
        mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

        # Apply the masks to x
        x_no_pos = apply_masks(x, [mask_no_pos])
        x_keep_pos = apply_masks(x, [mask_keep_pos])

        # Apply pos_embed and mask_pos_token accordingly
        mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1).to(device)
        x_no_pos = x_no_pos + mask_pos_tokens

        pos_embed = pos_embed.repeat(B, 1, 1).to(device)
        pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
        x_keep_pos = x_keep_pos + pos_embed_masked

        # Concatenate the results and shuffle again to restore the original order
        x = torch.cat([x_no_pos, x_keep_pos], dim=1)
        restored_indices = torch.argsort(shuffled_indices, dim=1)
        x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

        return x
                
    def apply_pos_drop_mask(self, x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
        B, _, D = x.shape  # Original shape of x
        device = x.device
        x_initial = x.clone()  # Save the original x for later use


        # Determine the number of positions to drop in the masked area
        N_m = mask.size(1)  # Number of patches to keep after the mask is applied
        num_pos_to_drop = int(N_m * pos_drop_ratio)

        # Shuffle mask along the last dimension
        random_tensor = torch.rand(B, N_m, device=device)
        shuffled_indices = random_tensor.argsort(dim=1)
        shuffled_mask = mask.gather(1, shuffled_indices)

        # Split the mask into two: one for keeping pos_embed, one for mask_pos_token
        mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
        mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

        # Apply the masks to x
        x_no_pos = apply_masks(x, [mask_no_pos])
        x_keep_pos = apply_masks(x, [mask_keep_pos])

        # Apply pos_embed and mask_pos_token accordingly
        mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1).to(device)
        # x_no_pos = x_no_pos + mask_pos_tokens

        pos_embed = pos_embed.repeat(B, 1, 1).to(device)
        pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
        # x_keep_pos = x_keep_pos + pos_embed_masked

        # Concatenate the results and shuffle again to restore the original order
        x = torch.cat([x_no_pos, x_keep_pos], dim=1)
        restored_indices = torch.argsort(shuffled_indices, dim=1)
        x_restored = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

        # # Create a mask to identify positions with dropped positional embeddings
        # pos_drop_mask = torch.zeros(B, N_m, dtype=torch.bool, device=device)  # Adjust mask size to [B, N_m]
        # drop_indices_restored = restored_indices[:, :num_pos_to_drop]

        # # Use advanced indexing to set the dropped positions to True
        # batch_indices = torch.arange(B, device=device).view(-1, 1)
        # pos_drop_mask[batch_indices, drop_indices_restored] = True

        # Create a boolean mask in the shuffled order
        shuffled_pos_drop_mask = torch.zeros(B, N_m, dtype=torch.bool, device=device)
        shuffled_pos_drop_mask[:, :num_pos_to_drop] = True  # Mark the first num_pos_to_drop as True

        # Restore the order of the boolean mask to match x_restored
        pos_drop_mask = shuffled_pos_drop_mask.gather(1, restored_indices)


        return x_no_pos, x_restored, pos_drop_mask, x_initial, mask_no_pos


    def forward_decoder(self, x):

        x = self.decoder_embed(x)
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        return x

    def forward(self, x, masks=None, pos_drop_ratio=0, use_decoder=False):
        if masks is not None:
            if not isinstance(masks, list):
                masks = [masks]

        # -- patchify x
        x = self.patch_embed(x)
        B, N, D = x.shape

        # -- add positional embedding to x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)

        # When we do not drop the positional embeddings:
        if not pos_drop_ratio:
            x += pos_embed
            if masks is not None: x = apply_masks(x, masks)
            mask_test = None, None

        else:
            assert len(masks) == 1, 'Only one mask is needed for the context.'
            x_no_pos, x, pos_drop_mask, x_initial, mask_no_pos = self.apply_pos_drop_mask(
                x, pos_embed, self.mask_pos_token, masks[0], pos_drop_ratio)
            
            return x_no_pos, x, pos_drop_mask, x_initial, mask_no_pos

        # -- fwd prop
        for i, blk in enumerate(self.blocks):
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        if use_decoder:
            x = self.forward_decoder(x)
            x = self.decoder_pred(x)  # from embed_dim to num_patches
            return x, pos_drop_mask 

        return x

    def interpolate_pos_encoding(self, x, pos_embed):
        npatch = x.shape[1] - 1
        N = pos_embed.shape[1] - 1
        if npatch == N:
            return pos_embed
        class_emb = pos_embed[:, 0]
        pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=math.sqrt(npatch / N),
            mode='bicubic',
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model




VIT_EMBED_DIMS = {
    'vit_tiny': 192,
    'vit_small': 384,
    'vit_base': 768,
    'vit_large': 1024,
    'vit_huge': 1280,
    'vit_giant': 1408,
}


In [315]:
a = torch.randn(4, 4)
print(a)


torch.argsort(a, dim=1)

tensor([[ 0.5080,  0.7931,  0.0418, -0.9466],
        [ 0.4480,  1.2964, -1.4254, -1.4578],
        [ 0.7382, -0.0849, -0.0025, -0.9821],
        [ 1.0544,  1.7277, -0.5480,  1.0907]])


tensor([[3, 2, 0, 1],
        [3, 2, 0, 1],
        [3, 1, 2, 0],
        [2, 0, 3, 1]])

In [313]:
del encoder

In [316]:
x.shape

torch.Size([64, 80, 192])

In [314]:
import random
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

encoder = vit_tiny(patch_size=patch_size,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim).to(device)

result = encoder(imgs, masks_enc, pos_drop_ratio=0.2)
x_no_pos, x, pos_drop_mask, x_initial, mask_no_pos = result
kk = x[pos_drop_mask.unsqueeze(-1).expand(-1, -1, 192)].view(batch_size, -1, 192)
dd = apply_masks(x_initial, [mask_no_pos])
sorted_tensor, sorted_indices = torch.sort(mask_no_pos, dim=1)
ee = apply_masks(x_initial, [sorted_tensor])
kk == ee

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

In [237]:
import random
random.seed(42)

In [238]:
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7f2089e2c110>

In [239]:
del encoder

In [240]:
encoder = vit_tiny(patch_size=patch_size,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim).to(device)

In [241]:
result = encoder(imgs, masks_enc, pos_drop_ratio=0.2)

In [242]:
len(result)

5

In [243]:
x_no_pos, x, pos_drop_mask, x_initial, mask_no_pos = result

In [244]:
kk = x[pos_drop_mask.unsqueeze(-1).expand(-1, -1, 192)].view(batch_size, -1, 192)

In [246]:
x_initial.shape

torch.Size([64, 256, 192])

In [247]:
mask_no_pos.shape

torch.Size([64, 16])

In [248]:
dd = apply_masks(x_initial, [mask_no_pos])

In [305]:
mask_no_pos

tensor([[ 21,   1, 110,  ..., 109,   4,  75],
        [ 66,  16,  80,  ...,  32,  46,  81],
        [ 44,  67, 130,  ...,  34,  84, 179],
        ...,
        [192,  45, 198,  ...,  59,   6, 201],
        [ 27,  14,   7,  ...,  48,  28,  67],
        [141,  60, 200,  ..., 204, 196, 110]], device='cuda:0')

In [306]:
sorted_tensor, sorted_indices = torch.sort(mask_no_pos, dim=1)


In [307]:
sorted_tensor

tensor([[  1,   3,   4,  ..., 109, 110, 178],
        [  6,  16,  19,  ...,  92,  97, 109],
        [ 20,  32,  33,  ..., 160, 179, 192],
        ...,
        [  6,   9,  12,  ..., 193, 198, 201],
        [  3,   6,   7,  ..., 112, 119, 158],
        [ 26,  29,  42,  ..., 196, 200, 204]], device='cuda:0')

In [312]:
sorted_tensor.shape

torch.Size([64, 16])

In [308]:
ee = apply_masks(x_initial, [sorted_tensor])

In [249]:
mask_no_pos.shape

torch.Size([64, 16])

In [250]:
(mask_no_pos[0].max())

tensor(178, device='cuda:0')

In [251]:
(mask_no_pos[0].argmax())

tensor(5, device='cuda:0')

In [289]:
mask_no_pos[0]

tensor([ 21,   1, 110,  13,   3, 178,  32,  27,  40,  16,  49,  51,  24, 109,
          4,  75], device='cuda:0')

In [290]:
mask_no_pos[0].argsort()

tensor([ 1,  4, 14,  3,  9,  0, 12,  7,  6,  8, 10, 11, 15, 13,  2,  5],
       device='cuda:0')

In [311]:
kk == ee

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

In [301]:
kk[0, 4, 1]

tensor(0.1167, device='cuda:0', grad_fn=<SelectBackward0>)

In [295]:
dd[0, -2, 1]

tensor(0.5047, device='cuda:0', grad_fn=<SelectBackward0>)

In [302]:
x_initial[0, 16, 1]

tensor(0.1167, device='cuda:0', grad_fn=<SelectBackward0>)

In [257]:
dd.shape

torch.Size([64, 16, 192])

In [118]:
x_no_pos.shape

torch.Size([64, 16, 192])

In [119]:
x.shape

torch.Size([64, 80, 192])

In [120]:
pos_drop_mask.shape

torch.Size([64, 80])

In [124]:
kk = x[pos_drop_mask.unsqueeze(-1).expand(-1, -1, 192)].view(batch_size, -1, 192)

In [125]:
kk.shape

torch.Size([64, 16, 192])

In [128]:
x_no_pos[0][0]

tensor([-2.3067,  0.4483, -1.2291,  0.8508, -0.2626, -2.4275, -1.2877,  0.6114,
        -0.4339, -0.5747,  0.1223, -0.1771,  0.8479, -1.1140,  0.1051,  1.4932,
         2.1093, -0.1131,  1.2370, -0.1574,  0.3884, -1.4170,  0.1403,  0.3876,
        -0.0950,  0.8172, -0.4414,  1.6031, -0.8563,  0.2816,  0.4774, -0.1175,
        -1.3723, -0.8352,  1.1400,  0.7687, -0.6594, -0.2641,  1.5124, -0.8623,
        -0.9539,  1.1319, -1.2194, -0.4562, -0.1077, -1.0779, -0.3600, -1.2942,
         1.5568,  0.7144, -0.9447, -0.4813, -0.0278, -0.4884, -0.9560, -0.2270,
        -0.6050, -0.3337, -0.8963,  0.2019, -1.4995, -1.3187, -1.0790, -0.7769,
        -0.7151, -0.6373, -0.5240, -0.0088, -0.3877,  0.0799,  0.9536,  0.4070,
        -0.4604, -0.8042, -0.0260,  0.8930, -2.5984, -0.9909, -0.3840,  0.5446,
        -0.8373, -0.9194,  0.9641, -0.9640,  0.8740,  0.4703,  0.8944,  4.0453,
         0.8769, -1.2336, -0.6158, -0.0587, -0.0862,  0.2332, -0.2544, -1.8226,
        -0.0378,  0.1936, -0.2143,  0.46

In [129]:
kk[0][0]

tensor([-2.9658e+00,  2.7459e-01, -1.2602e+00,  9.4208e-01, -4.1614e-01,
        -2.6972e+00, -1.3175e+00,  5.2243e-01, -6.0261e-02, -2.7417e-01,
         2.6612e-02,  3.2121e-02,  8.2448e-01, -1.1622e+00,  3.0622e-02,
         1.4931e+00,  1.9149e+00, -1.0901e-01,  1.4340e+00, -7.0940e-02,
         4.5717e-01, -1.6009e+00,  1.7435e-01, -2.3522e-01, -4.1109e-01,
         1.2431e+00, -1.5995e-01,  1.8360e+00, -9.3732e-01,  4.5272e-01,
         6.7320e-01, -2.5089e-02, -1.1962e+00, -1.1338e+00,  1.1345e+00,
         7.2322e-01, -7.5964e-01, -2.0787e-01,  1.8721e+00, -8.4484e-01,
        -1.4141e+00,  1.1036e+00, -1.0093e+00, -3.3924e-01,  3.2352e-01,
        -1.2067e+00,  1.6980e-01, -8.9869e-01,  1.4213e+00,  9.3691e-01,
        -9.3960e-01, -1.3784e-01, -1.9907e-01, -7.0311e-01, -6.2364e-01,
        -2.2677e-01, -8.4872e-01, -5.4525e-01, -9.6562e-01,  1.4417e-02,
        -1.7230e+00, -1.3977e+00, -1.3073e+00, -1.1918e+00, -7.9442e-01,
        -1.0835e+00, -5.7867e-01, -2.7028e-03, -7.4

In [114]:
pred, mask_test = encoder(imgs, masks_enc, pos_drop_ratio=0.2, use_decoder=True)

ValueError: too many values to unpack (expected 2)

In [106]:
pred.shape

torch.Size([64, 80, 256])

In [107]:
mask_test.shape

torch.Size([64, 80])

In [86]:
mask_test.shape

torch.Size([64, 80])

In [87]:
pred.shape

torch.Size([64, 80, 256])

In [95]:
x_1.shape

torch.Size([3, 3])

In [96]:
x_1 = torch.tensor([[[1, 1], [2, 2], [3, 3]], 
                    [[1, 1], [2, 2], [3, 3]], 
                    [[1, 1], [2, 2], [3, 3]]])

In [97]:
mask_1 = torch.tensor([[0, 1], [0, 1], [1, 2]])

In [98]:
apply_masks(x_1, [mask_1])

tensor([[[1, 1],
         [2, 2]],

        [[1, 1],
         [2, 2]],

        [[2, 2],
         [3, 3]]])

In [82]:
mask_test.unsqueeze(-1).expand(-1, -1, encoder.patch_embed.num_patches).shape

torch.Size([64, 80, 256])

In [89]:
kk = pred[mask_test.unsqueeze(-1).expand(-1, -1, encoder.patch_embed.num_patches)].view(batch_size, -1, encoder.patch_embed.num_patches)

In [91]:
kk.shape

torch.Size([64, 16, 256])

In [85]:
pred[mask_test.unsqueeze(-1).expand(-1, -1, encoder.patch_embed.num_patches)].shape

torch.Size([262144])

In [71]:
pred_no_pos = apply_masks(pred, [mask_test])

RuntimeError: gather(): Expected dtype int64 for index

In [68]:
pred.shape

torch.Size([64, 80, 256])

In [69]:
mask_test.shape

torch.Size([64, 80])

In [70]:
result.shape

torch.Size([64, 80, 192])

In [24]:
masks_enc[0].shape

torch.Size([64, 80])

In [32]:
full_mask = torch.zeros((3, 4), dtype=torch.bool, device=device)

In [33]:
full_mask 

tensor([[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]], device='cuda:0')

In [34]:
mask_no_pos = torch.tensor([[0, 1],
                            [0, 1],
                            [1, 2]], device=device)

In [35]:
full_mask.scatter_(1, mask_no_pos, True)

tensor([[ True,  True, False, False],
        [ True,  True, False, False],
        [False,  True,  True, False]], device='cuda:0')

In [324]:
import math
from functools import partial
import numpy as np

import torch
import torch.nn as nn

from src.utils.tensors import (
    trunc_normal_,
    repeat_interleave_batch
)
from src.masks.utils import apply_masks


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=float)
    grid_w = np.arange(grid_size, dtype=float)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid length
    return:
    pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid = np.arange(grid_size, dtype=float)
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega   # (D/2,)

    pos = pos.reshape(-1)   # (M,)
    out = np.einsum('m,d->md', pos, omega)   # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class ConvEmbed(nn.Module):
    """
    3x3 Convolution stems for ViT following ViTC models
    """

    def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
        super().__init__()
        # Build the stems
        stem = []
        channels = [in_chans] + channels
        for i in range(len(channels) - 2):
            stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
                               stride=strides[i], padding=1, bias=(not batch_norm))]
            if batch_norm:
                stem += [nn.BatchNorm2d(channels[i+1])]
            stem += [nn.ReLU(inplace=True)]
        stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
        self.stem = nn.Sequential(*stem)

        # Comptute the number of patches
        stride_prod = int(np.prod(strides))
        self.num_patches = (img_size[0] // stride_prod)**2

    def forward(self, x):
        p = self.stem(x)
        return p.flatten(2).transpose(1, 2)



class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(
        self,
        img_size=[224],
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        predictor_embed_dim=384,
        depth=12,
        predictor_depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        decoder_embed_dim=256,
        decoder_num_heads=2,
        decoder_depth=2,
        **kwargs
    ):
        super().__init__()

        # ---------------------------------------------------------------------- #
        # Encoder settings
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Set up the stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Set up the encoder blocks
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                  attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Patch settings
        self.patch_embed = PatchEmbed(img_size=img_size[0],
                                      patch_size=patch_size,
                                      in_chans=in_chans,
                                      embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Position settings
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.patch_embed.num_patches**.5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Mask token settings
        self.mask_pos_token = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Decoder settings (a light weight decoder just for position prediction)
        # Require additional parameters:
        # - decoder_emebed_dim
        # - decoder_num_heads
        # - decoder_depth
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.decoder_blocks = nn.ModuleList([
            Block(dim=decoder_embed_dim, num_heads=decoder_num_heads,
                  mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, num_patches, bias=True)
        # ----------------------------------------------------------------------

        # ---------------------------------------------------------------------- #
        # Weight Initialiazation
        self.init_std = init_std
        self.apply(self._init_weights)
        self.fix_init_weight()
        # ----------------------------------------------------------------------

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        # $$$$ Also initialize the mask_pos_token
        torch.nn.init.normal_(self.mask_pos_token, std=.02)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def apply_pos_drop_mask(self, x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
        B, N, D = x.shape
        device = x.device

        # Determine the number of positions to drop in the masked area
        num_pos_to_drop = int(mask.size(1) * pos_drop_ratio)

        # Shuffle mask along the last dimension
        random_tensor = torch.rand(B, mask.size(1), device=device)
        shuffled_indices = random_tensor.argsort(dim=1)
        shuffled_mask = mask.gather(1, shuffled_indices)

        # Split the mask into two: one for keeping pos_embed, one for mask_pos_token
        mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
        mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

        # Apply the masks to x
        x_no_pos = apply_masks(x, [mask_no_pos])
        x_keep_pos = apply_masks(x, [mask_keep_pos])

        # Apply pos_embed and mask_pos_token accordingly
        mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1).to(device)
        x_no_pos = x_no_pos + mask_pos_tokens

        pos_embed = pos_embed.repeat(B, 1, 1).to(device)
        pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
        x_keep_pos = x_keep_pos + pos_embed_masked

        # Concatenate the results and shuffle again to restore the original order
        x = torch.cat([x_no_pos, x_keep_pos], dim=1)
        restored_indices = torch.argsort(shuffled_indices, dim=1)
        x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

        return x

    def apply_pos_drop_mask(self, x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
        B, _, D = x.shape  # Original shape of x
        device = x.device

        # Determine the number of positions to drop in the masked area
        N_m = mask.size(1)  # Number of patches to keep after the mask is applied
        num_pos_to_drop = int(N_m * pos_drop_ratio)

        # Shuffle mask along the last dimension
        random_tensor = torch.rand(B, N_m, device=device)
        shuffled_indices = random_tensor.argsort(dim=1)
        shuffled_mask = mask.gather(1, shuffled_indices)

        # Split the mask into two: one for keeping pos_embed, one for mask_pos_token
        mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
        mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

        # Apply the masks to x
        x_no_pos = apply_masks(x, [mask_no_pos])
        x_keep_pos = apply_masks(x, [mask_keep_pos])

        # Apply pos_embed and mask_pos_token accordingly
        mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1).to(device)
        x_no_pos = x_no_pos + mask_pos_tokens

        pos_embed = pos_embed.repeat(B, 1, 1).to(device)
        pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
        x_keep_pos = x_keep_pos + pos_embed_masked

        # Concatenate the results and shuffle again to restore the original order
        x = torch.cat([x_no_pos, x_keep_pos], dim=1)
        restored_indices = torch.argsort(shuffled_indices, dim=1)
        x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

        # Create a boolean mask in the shuffled order
        shuffled_pos_drop_mask = torch.zeros(B, N_m, dtype=torch.bool, device=device)
        shuffled_pos_drop_mask[:, :num_pos_to_drop] = True  # Mark the first num_pos_to_drop as True

        # Restore the order of the boolean mask to match x_restored
        pos_bool = shuffled_pos_drop_mask.gather(1, restored_indices)

        # The pos_drop_bool is used to apply on x to get the ones
        # whose positional embeddings are dropped
        # to apply it, you should you use it like
        # x_ = x[pos_drop_bool.unsqueeze(-1).expand(-1, -1, D)].reshape(B, -1, D)
        # Differently, mask_no_pos contains the original indices (no. of the patch)
        # and will be used as labels

        return x, pos_bool, mask_no_pos


    def forward_decoder(self, x):

        x = self.decoder_embed(x)
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)  # from decoder_embed_dim to num_patches

        return x

    def forward(self, x, masks=None, pos_drop_ratio=0, use_decoder=False):
        if masks is not None:
            if not isinstance(masks, list):
                masks = [masks]

        # -- patchify x
        x = self.patch_embed(x)
        B, N, D = x.shape

        # -- add positional embedding to x
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)

        # When we do not drop the positional embeddings:
        if not pos_drop_ratio:
            x += pos_embed

            if masks is not None:
                x = apply_masks(x, masks)

        else:
            assert len(masks) == 1, 'Only one mask is needed for the context.'
            x, pos_bool, mask_no_pos = self.apply_pos_drop_mask(
                x, pos_embed, self.mask_pos_token, masks[0], pos_drop_ratio)

        # -- fwd prop
        for i, blk in enumerate(self.blocks):
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        if use_decoder:
            assert pos_drop_ratio, 'The function is only tested when pos are dropped.'
            logits = self.forward_decoder(x)
            return x, logits, pos_bool, mask_no_pos

        else:
            return x  # The classical IJEPA

    def interpolate_pos_encoding(self, x, pos_embed):
        npatch = x.shape[1] - 1
        N = pos_embed.shape[1] - 1
        if npatch == N:
            return pos_embed
        class_emb = pos_embed[:, 0]
        pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=math.sqrt(npatch / N),
            mode='bicubic',
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)


def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model




VIT_EMBED_DIMS = {
    'vit_tiny': 192,
    'vit_small': 384,
    'vit_base': 768,
    'vit_large': 1024,
    'vit_huge': 1280,
    'vit_giant': 1408,
}


In [325]:
del encoder

In [326]:
import random
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

encoder = vit_tiny(patch_size=patch_size,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim).to(device)

# result = encoder(imgs, masks_enc, pos_drop_ratio=0.2)
# x_no_pos, x, pos_drop_mask, x_initial, mask_no_pos = result
# kk = x[pos_drop_mask.unsqueeze(-1).expand(-1, -1, 192)].view(batch_size, -1, 192)
# dd = apply_masks(x_initial, [mask_no_pos])
# sorted_tensor, sorted_indices = torch.sort(mask_no_pos, dim=1)
# ee = apply_masks(x_initial, [sorted_tensor])
# kk == ee

In [327]:
result = encoder(imgs, masks_enc)

In [328]:
result.shape

torch.Size([64, 80, 192])

In [329]:
result = encoder(imgs, masks_enc, pos_drop_ratio=0.4, use_decoder=True)

In [351]:
x, logits, pos_bool, labels = result

In [352]:
logits.shape

torch.Size([64, 80, 256])

In [354]:
pos_bool.shape

torch.Size([64, 80])

In [356]:
pos_bool

tensor([[False,  True, False,  ..., False, False, False],
        [False, False, False,  ...,  True, False,  True],
        [False, False,  True,  ...,  True, False,  True],
        ...,
        [ True,  True, False,  ..., False,  True,  True],
        [False, False,  True,  ...,  True, False, False],
        [False,  True, False,  ..., False,  True, False]], device='cuda:0')

In [343]:
encoder.patch_embed.num_patches

256

In [344]:
logits = logits[pos_bool.unsqueeze(-1).expand(-1, -1, encoder.patch_embed.num_patches)].reshape(batch_size, -1, encoder.patch_embed.num_patches)

In [None]:
logits = logits[pos_bool.unsqueeze(-1).expand(-1, -1, encoder.patch_embed.num_patches)].reshape(batch_size, -1, encoder.patch_embed.num_patches)
loss_pos = F.cross_entropy(logits.permute(0, 2, 1), labels)

In [345]:
logits.shape

torch.Size([64, 32, 256])

In [348]:
import torch.nn.functional as F

In [349]:
loss_pos = F.cross_entropy(logits.permute(0, 2, 1), labels)

In [350]:
loss_pos

tensor(5.6013, device='cuda:0', grad_fn=<NllLoss2DBackward0>)

In [346]:
labels.shape

torch.Size([64, 32])

In [25]:
B, N, D = result.shape

In [272]:
pos_predictor = nn.Linear(encoder.embed_dim, 
                          encoder.patch_embed.num_patches, 
                          bias=True)

In [243]:
pred = pos_predictor(result)

In [244]:
pred.shape

torch.Size([64, 77, 256])

In [246]:
k = F.softmax(pred, dim=-1)

In [247]:
k.shape

torch.Size([64, 77, 256])

In [233]:
import torch.nn.functional as F

In [241]:
labels = torch.arange(90).repeat(N, 1)

In [256]:
labels.shape

torch.Size([64, 77])

In [None]:
    N, L = mask.shape
    num_vis = pred.shape[1]
    labels = torch.arange(L).repeat(N, 1).to(pred.device).detach()
    labels = torch.gather(labels, dim=1, index=ids_keep)

In [237]:
mask.shape

torch.Size([64, 77])

In [253]:
labels = mask

In [255]:
labels.shape

torch.Size([64, 77])

In [252]:
pred.shape

torch.Size([64, 77, 256])

In [259]:
F.cross_entropy(pred.permute(0, 2, 1), labels)

tensor(5.6813, grad_fn=<NllLoss2DBackward0>)

In [251]:
pred.shape

torch.Size([64, 77, 256])

In [249]:
pred.permute(0, 2, 1).shape

torch.Size([64, 256, 77])

In [250]:
mask.shape

torch.Size([64, 77])

In [236]:
F.cross_entropy(r, mask)

RuntimeError: Expected target size [64, 256], got [64, 77]

In [None]:
           # Calculate the number of positions to drop for the current mask
            num_to_drop = int(mask.shape[1] * pos_drop_ratio)
            print(num_to_drop)

            # Randomly select the indices to be dropped
            drop_indices = torch.randperm(mask.shape[1], device=device)[:num_to_drop]
            
            # Gather the indices from the mask that will be dropped
            drop_indices_masked = torch.index_select(mask, 1, drop_indices)
            
            # Create a mask for positions that are not dropped
            non_drop_indices = torch.ones(N, dtype=torch.bool, device=device)
            non_drop_indices[drop_indices_masked] = False
            
            # Add positional embeddings to the non-dropped positions
            x[b_idx, non_drop_indices] += pos_embed[b_idx, non_drop_indices]
            
            # Add mask_pos_token to the dropped positions
            x[b_idx, drop_indices_masked] = mask_pos_token.expand_as(x[b_idx, drop_indices_masked])

In [96]:
def apply_pos_drop_mask(x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
    B, N, D = x.shape
    device = x.device

    # Determine the number of positions to drop in the masked area
    num_pos_to_drop = int(mask.size(1) * pos_drop_ratio)

    # Shuffle mask along the last dimension
    random_tensor = torch.rand(B, N_m)
    perms = random_tensor.argsort(dim=1)
    shuffled_mask = mask.gather(1, perms)

    # Split the mask into two: one for keeping pos_embed, one for mask_pos_token
    mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
    mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

    # Apply the masks to x
    x_keep_pos = apply_masks(x, [mask_keep_pos])
    x_no_pos = apply_masks(x, [mask_no_pos])

    # Apply pos_embed and mask_pos_token accordingly
    pos_embed = pos_embed.repeat(B, 1, 1)
    pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
    x_keep_pos = x_keep_pos + pos_embed_masked

    mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1)
    x_no_pos = x_no_pos + mask_pos_tokens

    # Concatenate the results and shuffle again to restore the original order
    x_concat = torch.cat([x_keep_pos, x_no_pos], dim=1)


    return x_restored


In [183]:
def apply_pos_drop_mask(x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
    B, N, D = x.shape
    device = x.device

    # Determine the number of positions to drop in the masked area
    num_pos_to_drop = int(mask.size(1) * pos_drop_ratio)

    # Shuffle mask along the last dimension
    random_tensor = torch.rand(B, N_m)
    shuffled_indices = random_tensor.argsort(dim=1)
    shuffled_mask = mask.gather(1, shuffled_indices)

    # Split the mask into two: one for keeping pos_embed, one for mask_pos_token
    mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
    mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

    # Apply the masks to x
    x_no_pos = apply_masks(x, [mask_no_pos])
    x_keep_pos = apply_masks(x, [mask_keep_pos])

    # Apply pos_embed and mask_pos_token accordingly
    mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1)
    x_no_pos = x_no_pos + mask_pos_tokens

    pos_embed = pos_embed.repeat(B, 1, 1)
    pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
    x_keep_pos = x_keep_pos + pos_embed_masked

    # Concatenate the results and shuffle again to restore the original order
    x = torch.cat([x_no_pos, x_keep_pos], dim=1)
    restored_indices = torch.argsort(shuffled_indices, dim=1)
    x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

    return x


In [179]:
x = encoder_.patch_embed(imgs)
pos_embed = encoder_.interpolate_pos_encoding(x, encoder_.pos_embed)
x.shape

torch.Size([64, 256, 192])

In [180]:
    B, N, D = x.shape
    device = x.device

    # Determine the number of positions to drop in the masked area
    num_pos_to_drop = int(mask.size(1) * pos_drop_ratio)

    # Shuffle mask along the last dimension
    random_tensor = torch.rand(B, N_m)
    shuffled_indices = random_tensor.argsort(dim=1)
    shuffled_mask = mask.gather(1, shuffled_indices)

    # Split the mask into two: one for keeping pos_embed, one for mask_pos_token
    mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
    mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

    # Apply the masks to x
    x_no_pos = apply_masks(x, [mask_no_pos])
    x_keep_pos = apply_masks(x, [mask_keep_pos])

    # # Apply pos_embed and mask_pos_token accordingly
    # mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1)
    # x_no_pos = x_no_pos + mask_pos_tokens

    # pos_embed = pos_embed.repeat(B, 1, 1)
    # pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
    # x_keep_pos = x_keep_pos + pos_embed_masked

    # Concatenate the results and shuffle again to restore the original order
    x_concat = torch.cat([x_no_pos, x_keep_pos], dim=1)
    restored_indices = torch.argsort(shuffled_indices, dim=1)
    x_restored = x_concat.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))


In [181]:
x_ = apply_masks(x, [mask])

In [182]:
x_ == x_restored

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

In [166]:
x_restored[0][0]

tensor([ 0.0208, -0.7616,  0.4506, -0.1220, -0.9622,  0.4491,  0.5946,  0.2430,
        -0.6636,  1.4046,  0.0318, -0.6683, -0.9617,  1.0837, -0.2786,  0.0447,
         0.8781,  1.4150,  0.0974,  0.2760, -0.0513, -1.3648, -0.4059,  0.4779,
        -0.4468, -1.4697, -1.6139, -0.2443, -1.0343, -0.3135, -0.0034, -0.0478,
        -0.7251,  0.5278,  0.0551, -0.2267,  1.5084,  0.0426,  1.6751,  0.5833,
        -0.0367, -0.3562,  0.3727,  0.2418,  0.6569,  0.3649,  1.3766, -0.3369,
         0.6524,  0.2756, -0.6972,  0.2613,  0.7315, -0.9653, -1.7160,  0.3165,
         1.5662,  0.6354,  0.3665, -1.1100,  0.7692,  0.0204,  1.2124,  1.1574,
        -0.3337,  0.6425, -0.4855, -0.4052, -0.9536, -0.0757,  0.0626, -0.3468,
         0.0157,  0.7922, -0.4741, -1.3195,  0.9651,  0.2866, -0.0285,  1.4460,
         0.8951, -0.9021, -0.7752,  0.1132,  0.1739, -1.2930, -0.1893,  0.2564,
         0.3589,  0.3718,  0.2388, -0.2554, -0.1476,  0.8896,  0.5945, -0.2822,
         0.0288,  0.9150,  0.7384,  0.30

In [160]:
torch.argsort(perms, dim=1)

tensor([[41, 69, 59,  ..., 60, 45, 35],
        [15, 21, 51,  ..., 47, 33,  6],
        [68,  5, 40,  ..., 66, 48, 44],
        ...,
        [36, 41, 61,  ..., 16,  1, 71],
        [36, 68, 39,  ..., 72, 56, 38],
        [47, 19, 76,  ..., 50, 20, 29]])

In [161]:
perms

tensor([[15,  9, 59,  ..., 53, 16, 21],
        [10, 38, 48,  ...,  3, 31, 71],
        [ 9, 39, 70,  ..., 58, 21, 18],
        ...,
        [ 4, 75, 51,  ..., 10, 56, 32],
        [42, 53, 51,  ...,  9, 15, 36],
        [51,  9, 21,  ...,  8, 35,  2]])

In [159]:
mask

tensor([[  0,   1,   2,  ..., 164, 165, 166],
        [  0,   1,   2,  ..., 198, 199, 200],
        [  0,   1,   2,  ..., 146, 147, 148],
        ...,
        [  6,   7,   8,  ..., 226, 227, 228],
        [  0,   1,   2,  ..., 131, 132, 141],
        [  0,   1,   2,  ..., 149, 158, 160]])

In [158]:
shuffled_mask 

tensor([[ 16,   9, 126,  ..., 110,  17,  28],
        [ 17,  66,  94,  ...,   3,  52, 195],
        [ 16,  67, 133,  ..., 112,  35,  32],
        ...,
        [ 10, 227, 201,  ...,  23, 206, 142],
        [ 76,  94,  92,  ...,  16,  28,  64],
        [ 99,  16,  35,  ...,  14,  65,   2]])

In [124]:
    mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
    mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

In [145]:
mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1)

In [148]:
x_no_pos = x_no_pos + mask_pos_tokens

In [150]:
x_concat = torch.cat([x_keep_pos, x_no_pos], dim=1)

In [151]:
x_concat.shape

torch.Size([64, 77, 192])

In [149]:
x_no_pos.shape

torch.Size([64, 15, 192])

In [147]:
mask_pos_tokens.shape

torch.Size([64, 15, 192])

In [139]:
x = encoder_.patch_embed(imgs)
pos_embed = encoder_.interpolate_pos_encoding(x, encoder_.pos_embed)
x.shape

torch.Size([64, 256, 192])

In [140]:
    pos_embed = pos_embed.repeat(B, 1, 1)
    pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])

In [142]:
    x_keep_pos = apply_masks(x, [mask_keep_pos])
    x_no_pos = apply_masks(x, [mask_no_pos])

In [143]:
x_keep_pos.shape

torch.Size([64, 62, 192])

In [144]:
x_keep_pos = x_keep_pos + pos_embed_masked

In [141]:
pos_embed_masked.shape

torch.Size([64, 62, 192])

In [136]:
pos_embed_ = pos_embed.repeat(B, 1, 1)

In [137]:
pos_embed_.shape

torch.Size([64, 256, 192])

In [97]:
B, N, D = x.shape
num_pos_to_drop = int(mask.size(1) * pos_drop_ratio)

In [106]:
import torch

# Example tensor
B, N = 2, 3
tensor = torch.tensor([[1, 2, 3], [1, 2, 3]])

# Create random permutations for each row in a batched way
perms = torch.randperm(N).repeat(B, 1)
print(perms)

# Shuffle each row according to its permutation
shuffled_tensor = tensor.gather(1, perms)

print(shuffled_tensor)


tensor([[0, 1, 2],
        [0, 1, 2]])
tensor([[1, 2, 3],
        [1, 2, 3]])


In [112]:
B, N_m = mask.shape

In [121]:
mask

tensor([[  0,   1,   2,  ..., 164, 165, 166],
        [  0,   1,   2,  ..., 198, 199, 200],
        [  0,   1,   2,  ..., 146, 147, 148],
        ...,
        [  6,   7,   8,  ..., 226, 227, 228],
        [  0,   1,   2,  ..., 131, 132, 141],
        [  0,   1,   2,  ..., 149, 158, 160]])

In [120]:
random_tensor = torch.rand(B, N_m)
perms = random_tensor.argsort(dim=1)
shuffled_mask = mask.gather(1, perms)

In [122]:
shuffled_mask

tensor([[162, 142,  80,  ..., 126,  83,  65],
        [ 32,  81, 192,  ...,  53,  14,  33],
        [ 27,  94, 128,  ...,  12, 142,  99],
        ...,
        [226,  12, 217,  ...,  71,   7, 218],
        [ 36, 141,  92,  ...,  12,  33,  46],
        [ 66,  48,   5,  ...,  96,  35,  39]])

In [107]:
import torch

# Example tensor
B, N = 2, 3
tensor = torch.tensor([[1, 2, 3], [1, 2, 3]])

# Generate a random tensor and use argsort to get unique permutations for each row
random_tensor = torch.rand(B, N)
perms = random_tensor.argsort(dim=1)

# Shuffle each row according to its unique permutation
shuffled_tensor = tensor.gather(1, perms)

print(shuffled_tensor)


tensor([[1, 2, 3],
        [2, 1, 3]])


In [109]:
random_tensor = torch.rand(B, N)
random_tensor

tensor([[0.1963, 0.0848, 0.1815],
        [0.3819, 0.5320, 0.8059]])

In [110]:
perms = random_tensor.argsort(dim=1)
perms

tensor([[1, 2, 0],
        [0, 1, 2]])

In [101]:
shuffled_indices = torch.randperm(mask.size(1), device=device).repeat(B, 1)
shuffled_indices

tensor([[49, 22, 65,  ...,  1, 71, 13],
        [49, 22, 65,  ...,  1, 71, 13],
        [49, 22, 65,  ...,  1, 71, 13],
        ...,
        [49, 22, 65,  ...,  1, 71, 13],
        [49, 22, 65,  ...,  1, 71, 13],
        [49, 22, 65,  ...,  1, 71, 13]])

In [None]:
    shuffled_indices = torch.randperm(mask.size(1), device=device).repeat(B, 1)
    shuffled_mask = torch.gather(mask, 1, shuffled_indices)

In [49]:
imgs.shape

torch.Size([64, 3, 224, 224])

In [51]:
x = encoder_.patch_embed(imgs)

In [87]:
embed_dim = encoder_.embed_dim

In [88]:
mask_pos_token = nn.Parameter(torch.zeros(1, 1, embed_dim), requires_grad=True)

In [91]:
B = imgs.shape[0]

In [93]:
mask_pos_tokens = mask_pos_token.repeat(B, 4, 1)

In [95]:
mask_pos_tokens.shape

torch.Size([64, 4, 192])

In [89]:
mask_pos_token.shape

torch.Size([1, 1, 192])

In [84]:
x.shape

torch.Size([64, 77, 192])

In [72]:
encoder_.pos_embed.shape

torch.Size([1, 256, 192])

In [81]:
 int(77* 0.2)

15

In [53]:
B, N, D = x.shape

In [73]:
# Create a mask for all positions initially set to keep (1)
full_mask = torch.ones((B, N), dtype=torch.bool, device=device)

In [82]:
x = apply_masks(x, masks)

In [83]:
x.shape

torch.Size([64, 77, 192])

In [74]:
full_mask.shape

torch.Size([64, 256])

In [80]:
mask.size()

torch.Size([64, 77])

In [77]:
# Flatten the mask for easier processing
flat_mask = mask.flatten()  # This mask is of size (B, no. patches kepts for context)

# Determine total number of positions to drop across the batch
total_num_to_drop = int(flat_mask.numel() * pos_drop_ratio)

In [78]:
total_num_to_drop

985

In [None]:
    # Create a mask for all positions initially set to keep (1)
    full_mask = torch.ones((B, N), dtype=torch.bool, device=device)
    
    # Flatten the mask for easier processing
    flat_mask = mask.flatten()

    # Determine total number of positions to drop across the batch
    total_num_to_drop = int(flat_mask.numel() * pos_drop_ratio)

    # Randomly select indices to be dropped from the flattened mask
    drop_indices = torch.randperm(flat_mask.numel(), device=device)[:total_num_to_drop]
    
    # Mark the selected indices for dropping in the full mask
    full_mask.view(-1)[flat_mask[drop_indices]] = False

    # Apply the positional embeddings for non-dropped positions
    x[full_mask] += pos_embed.view(B * N, -1)[full_mask.view(-1)]
    
    # Apply mask_pos_token for dropped positions
    x[~full_mask] = mask_pos_token

In [55]:
full_mask

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [56]:
flat_mask = mask.flatten()

In [70]:
mask

tensor([[  0,   1,   2,  ..., 164, 165, 166],
        [  0,   1,   2,  ..., 198, 199, 200],
        [  0,   1,   2,  ..., 146, 147, 148],
        ...,
        [  6,   7,   8,  ..., 226, 227, 228],
        [  0,   1,   2,  ..., 131, 132, 141],
        [  0,   1,   2,  ..., 149, 158, 160]])

In [57]:
flat_mask.shape

torch.Size([4928])

In [58]:
pos_drop_ratio = 0.2

In [59]:
flat_mask.numel()

4928

In [60]:
total_num_to_drop = int(flat_mask.numel() * pos_drop_ratio)

In [61]:
total_num_to_drop

985

In [35]:
mask = masks_enc[0]
pos_drop_ratio = 0.2

In [62]:
drop_indices = torch.randperm(flat_mask.numel(), device=device)[:total_num_to_drop]

In [67]:
full_mask.shape

torch.Size([64, 256])

In [69]:
flat_mask.numel() 

4928

In [68]:
64 * 256

16384

In [65]:
full_mask.view(-1).shape
#[flat_mask[drop_indices]] = False

torch.Size([16384])

In [37]:
print(mask.shape[1])

77


In [48]:
len(masks)

1

In [44]:
masks = masks_enc

In [45]:
masks

[tensor([[  0,   1,   2,  ..., 164, 165, 166],
         [  0,   1,   2,  ..., 198, 199, 200],
         [  0,   1,   2,  ..., 146, 147, 148],
         ...,
         [  6,   7,   8,  ..., 226, 227, 228],
         [  0,   1,   2,  ..., 131, 132, 141],
         [  0,   1,   2,  ..., 149, 158, 160]])]

In [46]:
masks[0]

tensor([[  0,   1,   2,  ..., 164, 165, 166],
        [  0,   1,   2,  ..., 198, 199, 200],
        [  0,   1,   2,  ..., 146, 147, 148],
        ...,
        [  6,   7,   8,  ..., 226, 227, 228],
        [  0,   1,   2,  ..., 131, 132, 141],
        [  0,   1,   2,  ..., 149, 158, 160]])

In [47]:
masks[0].shape

torch.Size([64, 77])

In [36]:
            num_to_drop = int(mask.shape[1] * pos_drop_ratio)
            print(num_to_drop)

15


In [38]:
drop_indices = torch.randperm(mask.shape[1], device=device)[:num_to_drop]

In [39]:
drop_indices

tensor([19, 53, 71, 66, 55, 24, 75, 20, 60, 72, 45, 29, 32, 47, 13])

In [29]:
masks_enc[0].shape

torch.Size([64, 77])

In [33]:
torch.randperm(10)[:5]

tensor([9, 7, 5, 4, 0])

In [None]:
    def forward_encoder(self, x, mask_ratio, pos_mask_ratio):
        outs = {}
        inputs = x.detach().clone()

        # embed patches w/o [cls] token
        x = self.patch_embed(x)
        N, L, D = x.shape

        # generate mask
        ids_keep, mask, ids_restore, ids_remove = self.random_masking(x, mask_ratio)
        outs['mask'], outs['ids_keep'], outs['ids_restore'] = mask, ids_keep, ids_restore
        # gather patch embeddings and position embeddings
        x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        pos_embed_all = self.pos_embed[:, 1:, :].data.repeat(N, 1, 1)  # w/o [cls] token
        pos_embed_vis = torch.gather(pos_embed_all, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)).detach()

        # random masking for position embedding
        ids_keep_pos, mask_pos, ids_restore_pos, ids_remove_pos = self.random_masking(x, pos_mask_ratio)
        outs['mask_pos'], outs['ids_keep_pos'], outs['ids_restore_pos'] = mask_pos, ids_keep_pos, ids_restore_pos

        # gather position embeddings
        pos_embed = torch.gather(pos_embed_vis, dim=1, index=ids_keep_pos.unsqueeze(-1).repeat(1, 1, D))

        # append mask tokens to position embeddings
        mask_pos_length = mask_pos.sum().item()
        if self.mask_token_type == 'param':
            mask_pos_tokens = self.mask_pos_token.repeat(N, mask_pos_length, 1)
        elif self.mask_token_type == 'zeros':
            mask_pos_tokens = torch.zeros((N, mask_pos_length, self.embed_dim)).to(x.device)
        elif self.mask_token_type == 'wrong_pos':
            removed_pos_embed = torch.gather(pos_embed_vis, dim=1, index=ids_remove_pos.unsqueeze(-1).repeat(1, 1, D))
            # convert to numpy, since numpy shuffles the first dimension, we have to transpose first
            removed_pos_embed = removed_pos_embed.detach().cpu().permute(1, 0, 2).numpy()        # [N, L, D] -> [L, N, D]
            np.random.shuffle(removed_pos_embed)
            # restore to torch
            removed_pos_embed = torch.from_numpy(removed_pos_embed).permute(1, 0, 2)    # [L, N, D] -> [N, L, D]
            mask_pos_tokens = removed_pos_embed.to(x.device)
        else:
            raise Exception('unknown mask_token_type: {}'.format(self.mask_token_type))

        pos_embed = torch.cat([pos_embed, mask_pos_tokens], dim=1)

        # restore position embeddings before adding
        pos_embed = torch.gather(pos_embed, dim=1, index=ids_restore_pos.unsqueeze(-1).repeat(1, 1, D))

        # add position embedding w/o [cls] token
        x = x + pos_embed

        if self.shuffle:
            # generate shuffle indexes first
            ids_keep_shuffle, _, ids_restore_shuffle, _ = self.random_masking(x, 0.)

            # gather
            x = torch.gather(x, dim=1, index=ids_keep_shuffle.unsqueeze(-1).repeat(1, 1, D))
            outs['ids_restore_shuffle'] = ids_restore_shuffle

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # get last self-attention
        if self.attn_guide:
            # get attentions
            # attn = self.get_last_attention(inputs)
            # attn = attn[:, :, 0, 1:].mean(1)    # [N, num_patches]
            # outs['attn_full'] = attn

            # get similarities
            attn = self.get_feature_similarity(inputs)
            outs['attn_full'] = attn

            # gather visible patches
            attn = torch.gather(attn, dim=1, index=ids_keep)
            outs['attn'] = attn / attn.sum(-1, keepdims=True)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        outs['x'] = self.norm(x)

        return outs

In [16]:
masks_enc[0]

tensor([[  0,   1,   2,  ..., 164, 165, 166],
        [  0,   1,   2,  ..., 198, 199, 200],
        [  0,   1,   2,  ..., 146, 147, 148],
        ...,
        [  6,   7,   8,  ..., 226, 227, 228],
        [  0,   1,   2,  ..., 131, 132, 141],
        [  0,   1,   2,  ..., 149, 158, 160]])

In [17]:
masks_enc[0].shape

torch.Size([64, 77])

In [18]:
# encoder, predictor = init_model(
#     device=device,
#     patch_size=patch_size,
#     crop_size=crop_size,
#     pred_depth=pred_depth,
#     pred_emb_dim=pred_emb_dim,
#     model_name=model_name)

In [19]:
imgs.shape

torch.Size([64, 3, 224, 224])

In [20]:
masks_enc[0].shape

torch.Size([64, 77])

In [21]:
del encoder_

NameError: name 'encoder_' is not defined

In [22]:
encoder_ = vit_tiny(patch_size=patch_size,
    crop_size=crop_size,
    pred_depth=pred_depth,
    pred_emb_dim=pred_emb_dim)

In [23]:
encoder_ = encoder_.to(device)

In [25]:
d_ = encoder_(imgs, masks_enc)

x shape torch.Size([64, 256, 192])
pos_embed shape torch.Size([1, 256, 192])


In [30]:
d_.shape

torch.Size([64, 77, 192])

In [52]:
(d == d_).sum()

tensor(1093632, device='cuda:0')

In [22]:
masks_enc[0][0]

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,
         29,  30,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,
         44,  45,  46,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,
         59,  60,  61,  62,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,
         74,  75,  76,  77,  78,  80,  81,  82,  83,  84,  85,  86,  87,  88,
         89,  90,  91,  92,  93,  94,  96, 109, 110, 112, 125, 126],
       device='cuda:0')

In [23]:
masks_enc[0][1]

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  16,  17,  18,  19,  20,  21,  22,  23,  32,  33,  34,  35,  36,
         37,  38,  39,  48,  49,  50,  51,  52,  53,  64,  65,  66,  67,  68,
         69,  80,  81,  82,  83,  84,  85,  96,  97,  98,  99, 100, 101, 109,
        110, 112, 113, 114, 115, 116, 117, 125, 126, 128, 129, 130, 131, 132,
        133, 141, 142, 144, 145, 146, 147, 148, 149, 157, 158, 160, 161, 162,
        173, 174, 176, 177, 178, 189, 190, 192, 193, 194, 202, 203],
       device='cuda:0')

In [185]:
N_m == mask.size(1)

True

In [184]:
def apply_pos_drop_mask(x, pos_embed, mask_pos_token, mask, pos_drop_ratio):
    B, N, D = x.shape
    device = x.device

    # Determine the number of positions to drop in the masked area
    num_pos_to_drop = int(mask.size(1) * pos_drop_ratio)

    # Shuffle mask along the last dimension
    random_tensor = torch.rand(B, N_m)
    shuffled_indices = random_tensor.argsort(dim=1)
    shuffled_mask = mask.gather(1, shuffled_indices)

    # Split the mask into two: one for keeping pos_embed, one for mask_pos_token
    mask_no_pos = shuffled_mask[:, :num_pos_to_drop]
    mask_keep_pos = shuffled_mask[:, num_pos_to_drop:]

    # Apply the masks to x
    x_no_pos = apply_masks(x, [mask_no_pos])
    x_keep_pos = apply_masks(x, [mask_keep_pos])

    # Apply pos_embed and mask_pos_token accordingly
    mask_pos_tokens = mask_pos_token.repeat(B, num_pos_to_drop, 1)
    x_no_pos = x_no_pos + mask_pos_tokens

    pos_embed = pos_embed.repeat(B, 1, 1)
    pos_embed_masked = apply_masks(pos_embed, [mask_keep_pos])
    x_keep_pos = x_keep_pos + pos_embed_masked

    # Concatenate the results and shuffle again to restore the original order
    x = torch.cat([x_no_pos, x_keep_pos], dim=1)
    restored_indices = torch.argsort(shuffled_indices, dim=1)
    x = x.gather(1, restored_indices.unsqueeze(-1).expand(-1, -1, D))

    return x


In [64]:
imgs.shape

torch.Size([64, 3, 224, 224])

In [63]:
d.shape

torch.Size([64, 256, 192])

In [65]:
patch_size

14

In [14]:
encoder_ = vit_tiny().to(device)

In [17]:
del encoder_

In [15]:
imgs.shape

torch.Size([64, 3, 224, 224])

In [16]:
dd = encoder_(imgs, masks_enc)

../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [632,0,0], thread: [32,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [632,0,0], thread: [33,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [632,0,0], thread: [34,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [632,0,0], thread: [35,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [632,0,0], thread: [36,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): b

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [78]:
masks_enc[0].shape

torch.Size([64, 63])

In [76]:
dd.shape

torch.Size([64, 196, 192])