In [1]:
import vision_transformer as vits
import utils
from vision_transformer import DINOHead
from main_dino import DINOLoss
from vot.models.factory import (
    create_model,
    create_optimizer,
    dino_load_checkpoint,
    dino_save_checkpoint,
    MultiCropWrapper as VotMCW,
    PatchEmbeddingConv,
)
import MinkowskiEngine as ME
import yaml


In [2]:
dino = vits.__dict__["vit_small"](patch_size=16)
dino_head = DINOHead(dino.embed_dim, 65536, False)


In [3]:
with open("/home/hli/dl_ws/voxel-transformer/configs/2d_dino.yaml", "r") as file:
    config = yaml.load(file, Loader=yaml.FullLoader)
teacher = create_model(config["model"])
vot = teacher.backbone
vot_head = teacher.head


In [4]:
def get_num_params(model):
    res = 0
    for p in model.parameters():
        res += p.numel()
    return res


In [5]:
dino_head.last_layer.weight.std()


tensor(0.0361, grad_fn=<StdBackward0>)

In [6]:
vot_head.last_layer.weight.std()


tensor(0.0361, grad_fn=<StdBackward0>)

In [7]:
get_num_params(dino_head)


22352128

In [8]:
get_num_params(vot_head)


22352128

In [9]:
dino.cls_token.shape


torch.Size([1, 1, 384])

In [10]:
vot.class_token.class_position_embedding.shape


torch.Size([1, 1, 384])

In [11]:
import torch
from einops import rearrange

with torch.no_grad():
    vot.patch_embedding.head_conv.kernel[:] = rearrange(
        dino.patch_embed.proj.weight, "p c w h -> (h w) c p"
    )
    vot.patch_embedding.head_conv.bias[:] = dino.patch_embed.proj.bias

    vot.position_embedding.position_embedding[:] = dino.pos_embed[0, 1:]
    vot.class_token.class_position_embedding[:] = dino.pos_embed[0, 0]
    vot.class_token.class_embedding[:] = dino.cls_token

    vot.encoder.blocks.load_state_dict(dino.blocks.state_dict())
    vot.encoder.encoder_norm.load_state_dict(dino.norm.state_dict())

    vot_head.load_state_dict(dino_head.state_dict())

    # vot.patch_embedding = dino.patch_embed


In [12]:
from experiment.hybrid import Hybrid

hybrid_model = Hybrid(dino, vot)


In [13]:
import torch

input = torch.randn(1, 3, 224, 224)
sinput = ME.to_sparse(input)


In [14]:
if isinstance(vot.patch_embedding, PatchEmbeddingConv):
    patch_embeddings, patch_ids, sparse_patch_embeddings = vot.patch_embedding(sinput)
else:
    patch_embeddings = vot.patch_embedding(input)
    B, P, C = patch_embeddings.shape
    patch_ids = torch.arange(P).view(1, -1).expand(B, -1)
    sparse_patch_embeddings = None
patch_embeddings[0, torch.where(patch_ids == 0)[1]]

patch_embeddings = patch_embeddings + vot.position_embedding(patch_ids)
pos_embeddings = vot.position_embedding(patch_ids)

patch_embeddings, masks = vot.class_token(patch_embeddings, patch_ids < 0)
patch_embeddings[0, -1]

block = vot.encoder.blocks[0]
attn_layer = block.attn

B, N, C = patch_embeddings.shape
qkv = (
    attn_layer.qkv(patch_embeddings)
    .reshape(B, N, 3, attn_layer.heads, C // attn_layer.heads)
    .permute(2, 0, 3, 1, 4)
)
q, k, v = (
    qkv[0],
    qkv[1],
    qkv[2],
)

attn = (q @ k.transpose(-2, -1)) * attn_layer.scale
# if mask is not None:
#     attn = attn.masked_fill(mask.view(B, 1, 1, -1), float("-inf"))
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)

# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
# x = self.proj(x)
# x = self.proj_drop(x)

# x = x + self.drop_path(y)
# x = x + self.drop_path(self.mlp(self.norm2(x)))
# patch_embeddings, hidden_states = vot.encoder(patch_embeddings, masks)
# hidden_states[-1][0, -1]
# patch_embeddings[0, -1]

# vot_out = vot_head(patch_embeddings[:, -1])
# vot_out


In [15]:
x, patch_emb, pos_emb = dino.prepare_tokens(input)
patch_emb[0, 0]
x[0, 0]

block = dino.blocks[0]
attn_layer = block.attn

B, N, C = x.shape
qkv_dino = (
    attn_layer.qkv(x)
    .reshape(B, N, 3, attn_layer.num_heads, C // attn_layer.num_heads)
    .permute(2, 0, 3, 1, 4)
)
q_dino, k_dino, v_dino = qkv_dino[0], qkv_dino[1], qkv_dino[2]

attn_dino = (q_dino @ k_dino.transpose(-2, -1)) * attn_layer.scale


# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)

# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
# x = self.proj(x)
# x = self.proj_drop(x)

# y_dino, attn = block.attn(block.norm1(x))
# y_dino

# x = x + self.drop_path(y)
# x = x + self.drop_path(self.mlp(self.norm2(x)))

# hidden_states_dino = []
# for blk in dino.blocks:
#     x = blk(x)
#     hidden_states_dino.append(x)
# x[0, 0]

# x = dino.norm(x)
# x[0, 0]

# dino_out = dino_head(x[:, 0])
# dino_out


In [16]:
x, patch_emb, pos_emb = hybrid_model.dino.prepare_tokens(input)
masks = torch.zeros(x.size(0), x.size(1), dtype=bool)
# x[0,0]


In [17]:
# patch_emb, hidden_states, sparse_patch_emb = vot(sinput,debug=True)
# output, patch_emb, pos_emb, hidden_states = dino(input,debug=True)
vot_out = vot(sinput)
dino_out = dino(input)
hybrid_out = hybrid_model(input)
# vot.patch_embedding = dino.patch_embed
# vot_out_2 = vot(input)


In [18]:
(dino_out - vot_out).abs().max()


tensor(2.2352e-06, grad_fn=<MaxBackward1>)

In [19]:
(dino_out - hybrid_out).abs().max()


tensor(2.2352e-06, grad_fn=<MaxBackward1>)

In [20]:
dino_teacher = utils.MultiCropWrapper(
    dino,
    dino_head,
)
vot_teacher = VotMCW(vot, vot_head)


In [21]:
input = [torch.randn(1, 3, 224, 224) for _ in range(10)]
sinput = [ME.to_sparse(i) for i in input]


In [22]:
get_num_params(dino_teacher) == get_num_params(vot_teacher)


True

In [23]:
dino_out = dino_teacher(input)
vot_out = vot_teacher(sinput)


ValueError: not enough values to unpack (expected 4, got 2)

In [None]:
dino_out[2] - vot_out[2]


tensor([-4.4703e-08, -5.9605e-08,  3.1665e-08,  ...,  4.8429e-08,
         0.0000e+00, -7.0781e-08], grad_fn=<SubBackward0>)

In [None]:
from main_dino import DataAugmentationDINO
from torchvision import datasets

transform = DataAugmentationDINO(
    (0.4, 1.0),
    (0.05, 0.4),
    8,
)
data_path = "/mnt/raid/hli/datasets/imagenet/ilsvrc2012/train"
dataset = datasets.ImageFolder(data_path, transform=transform)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    num_workers=6,
    pin_memory=True,
    drop_last=True,
)




In [None]:
(images, _) = next(iter(data_loader))
simages = [ME.to_sparse(i) for i in images]


In [None]:
dino_out = dino_teacher(images)
vot_out = vot_teacher(simages)


In [None]:
(dino_out - vot_out).min()


tensor(-4.4703e-07, grad_fn=<MinBackward1>)