In [1]:
import vision_transformer as vits
import utils
from vision_transformer import DINOHead
from main_dino import DINOLoss
from vot.models.factory import (
    create_model,
    create_optimizer,
    dino_load_checkpoint,
    dino_save_checkpoint,
    MultiCropWrapper as VotMCW,
    PatchEmbeddingConv
)
import MinkowskiEngine as ME
import yaml
import torch



In [2]:
# Reference DINO ViT-S/16 backbones plus a DINOHead sized to the backbone width.
# `dino2` is only referenced by commented-out code; it is still built so the
# random-init sequence of every later model stays unchanged.
build_vit_small = vits.__dict__["vit_small"]
dino = build_vit_small(patch_size=16)
dino2 = build_vit_small(patch_size=16)
dino_head = DINOHead(dino.embed_dim, 65536, False)


In [3]:
# Build the VoT student and teacher from the same "model" config section;
# the teacher's weights are overwritten from the student further down.
config_path = "/home/hli/dl_ws/voxel-transformer/configs/2d_dino.yaml"
with open(config_path, "r") as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
model_cfg = config["model"]
vot_student, vot_teacher = create_model(model_cfg), create_model(model_cfg)
vot, vot_head = vot_student.backbone, vot_student.head


In [4]:
def get_num_params(model, trainable_only=False):
    """Count the scalar parameters of a module.

    Args:
        model: anything exposing ``.parameters()`` (e.g. a ``torch.nn.Module``).
        trainable_only: if True, count only parameters with ``requires_grad``
            set (useful to exclude frozen teacher weights). Defaults to False,
            which counts every parameter as before.

    Returns:
        int: total number of elements over the selected parameter tensors
        (0 for a module without parameters).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


In [5]:
# Copy DINO's ViT weights into the VoT backbone and head so the two implementations
# can be compared output-for-output in the cells below.
import torch
from einops import rearrange

with torch.no_grad():
    # Patch embedding: DINO's patch projection weight -> VoT head_conv kernel.
    # NOTE(review): assumes proj.weight is (out_ch, in_ch, k_dim2, k_dim3) as for nn.Conv2d.
    # The pattern labels dim 2 `w` and dim 3 `h`, so the flattened kernel offset is
    # dim3_index * size(dim2) + dim2_index. The zero student-output difference reported
    # below suggests this matches head_conv's kernel-offset order — confirm before changing.
    vot.patch_embedding.head_conv.kernel[:] = rearrange(dino.patch_embed.proj.weight, "p c w h -> (h w) c p")
    vot.patch_embedding.head_conv.bias[:] = dino.patch_embed.proj.bias
    
    # DINO keeps a single position table whose row 0 is used for the class token;
    # split it into VoT's patch position table and its class position embedding.
    vot.position_embedding.position_embedding[:] = dino.pos_embed[0,1:]
    vot.class_token.class_position_embedding[:] = dino.pos_embed[0,0]
    vot.class_token.class_embedding[:] = dino.cls_token
    
    # Transformer blocks and final norm: load_state_dict is strict by default, so these
    # require VoT's modules to expose the same parameter keys as DINO's.
    vot.encoder.blocks.load_state_dict(dino.blocks.state_dict())
    vot.encoder.encoder_norm.load_state_dict(dino.norm.state_dict())
    
    # Projection head: strict load again, so vot_head must match dino_head's keys.
    vot_head.load_state_dict(dino_head.state_dict())
    
    # vot.patch_embedding = dino.patch_embed
    # vot_teacher.backbone.patch_embedding = dino2.patch_embed
    
    

In [6]:
# Wrap the DINO backbone/head so it accepts the multi-crop list, then build an
# identically shaped DINO teacher (same constructor, so same embed_dim).
dino_student = utils.MultiCropWrapper(dino, dino_head)
dino_t = vits.__dict__["vit_small"](patch_size=16)
dino_head_t = DINOHead(dino_t.embed_dim, 65536, False)
dino_teacher = utils.MultiCropWrapper(dino_t, dino_head_t)

# Teachers start as exact copies of their students and are excluded from autograd.
for teacher, student in ((dino_teacher, dino_student), (vot_teacher, vot_student)):
    teacher.load_state_dict(student.state_dict())
for teacher in (dino_teacher, vot_teacher):
    for param in teacher.parameters():
        param.requires_grad = False


In [7]:
from main_dino import DataAugmentationDINO
from torchvision import datasets

# Multi-crop augmentation; presumably (global crop scale, local crop scale, number of
# local crops), i.e. 2 global + 8 local = 10 crops per image.
global_crops_scale = (0.4, 1.0)
local_crops_scale = (0.05, 0.4)
n_local_crops = 8
transform = DataAugmentationDINO(global_crops_scale, local_crops_scale, n_local_crops)

# DataLoader does not shuffle by default, so the first batch drawn below is the
# first `batch_size` images of the folder ordering.
data_path = "/mnt/raid/hli/datasets/imagenet/ilsvrc2012/train"
dataset = datasets.ImageFolder(data_path, transform=transform)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    num_workers=6,
    pin_memory=True,
    drop_last=True,
)




In [8]:
# One batch: a list of per-crop tensors (labels discarded) plus a sparse copy of each crop.
batch = next(iter(data_loader))
images, _ = batch
simages = list(map(ME.to_sparse, images))


In [9]:
# Student forward passes over every crop; the outputs are compared in the next cell.
dino_out, vot_out = dino_student(images), vot_student(images)


In [10]:
# Largest absolute difference between the DINO and VoT student outputs.
torch.max(torch.abs(dino_out - vot_out))


tensor(0., grad_fn=<MaxBackward1>)

In [11]:
# Teachers only see the first two crops (presumably the two global views).
global_crops = images[:2]
dino_out_teacher = dino_teacher(global_crops)
vot_out_teacher = vot_teacher(global_crops)


In [12]:
# Largest absolute difference between the DINO and VoT teacher outputs.
torch.max(torch.abs(dino_out_teacher - vot_out_teacher))


tensor(0.)

In [13]:
# DINO student vs. its weight-identical teacher on the rows both computed.
torch.max(torch.abs(dino_out[: dino_out_teacher.shape[0]] - dino_out_teacher))


tensor(1.3411e-07, grad_fn=<MaxBackward1>)

In [14]:
# VoT student vs. its weight-identical teacher on the rows both computed.
torch.max(torch.abs(vot_out[: dino_out_teacher.shape[0]] - vot_out_teacher))

tensor(1.3411e-07, grad_fn=<MaxBackward1>)

In [15]:
# Separate loss objects per model, since each may keep its own running state
# (e.g. a teacher center) — TODO confirm against the loss implementation.
loss_cfg = config["loss"]
dino_loss = create_model(loss_cfg)
dino_loss_vot = create_model(loss_cfg)


In [16]:
# dino_loss = DINOLoss(
#     65536,
#     10,  # total number of crops = 2 global crops + local_crops_number
#     0.04,
#     0.04,
#     0,
#     100,
# )

# dino_loss_vot = DINOLoss(
#     65536,
#     10,  # total number of crops = 2 global crops + local_crops_number
#     0.04,
#     0.04,
#     0,
#     100,
# )


In [17]:
# Third argument is presumably the epoch index (0 here).
loss_dino, loss_vot = (
    dino_loss(dino_out, dino_out_teacher, 0),
    dino_loss_vot(vot_out, vot_out_teacher, 0),
)


In [18]:
# One AdamW optimizer per student with identical hyperparameters, applied to plain
# `.parameters()`, so both models receive the same kind of update and stay comparable.
# NOTE(review): the previous `params_groups = utils.get_params_groups(...)` results were
# never passed to AdamW (and the first was immediately overwritten), so they were removed.
# If parameter groups are wanted, pass them to BOTH optimizers and re-check parity:
# DINO and VoT store some tensors with different shapes, so they may not group identically.
LR = 2e-2
dino_optimizer = torch.optim.AdamW(dino_student.parameters(), lr=LR)
vot_optimizer = torch.optim.AdamW(vot_student.parameters(), lr=LR)


In [19]:
# One optimization step for both students, in the original order:
# clear both gradient buffers, backprop both losses, then step both optimizers.
optimizers = (dino_optimizer, vot_optimizer)
for optimizer in optimizers:
    optimizer.zero_grad()
for objective in (loss_dino, loss_vot):
    objective.backward()
for optimizer in optimizers:
    optimizer.step()


In [20]:
# with torch.no_grad():
#     m = 0.996
#     for param_q, param_k in zip(
#         vot_student.module.parameters(), vot_teacher.parameters()
#     ):
#         param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
#     for param_q, param_k in zip(
#         dino_student.module.parameters(), dino_teacher.parameters()
#     ):
#         param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)


In [21]:
# First entry of the multi-crop list, inspected layer by layer in the following cells.
first_crop_batch = images[0]
input = first_crop_batch  # NOTE: shadows the builtin `input`; later cells rely on this name.

In [25]:
# Re-run VoT's front end (patch embedding -> position embedding -> class token)
# step by step on `input` so the intermediates can be compared with DINO's.
if isinstance(vot.patch_embedding, PatchEmbeddingConv):
    # `sinput` was previously undefined (NameError on a fresh kernel). Build it the same
    # way `simages` is built from `images` (ME.to_sparse); this embedding presumably
    # expects a MinkowskiEngine sparse tensor — confirm.
    sinput = ME.to_sparse(input)
    patch_embeddings, patch_ids, sparse_patch_embeddings = vot.patch_embedding(
        sinput
    )
else:
    patch_embeddings = vot.patch_embedding(input)
    B, P, C = patch_embeddings.shape
    patch_ids = torch.arange(P).view(1, -1).expand(B, -1)
    sparse_patch_embeddings = None

# Embedding of the patch with id 0, captured before the class token is prepended
# (the bare expression here was previously computed and silently discarded).
first_patch_embedding = patch_embeddings[0, torch.where(patch_ids == 0)[1][0]]

# Compute the position embeddings once and reuse them, so the tensor compared against
# DINO later is exactly the one that was added.
pos_embeddings = vot.position_embedding(patch_ids)
patch_embeddings = patch_embeddings + pos_embeddings

# In the dense branch patch_ids come from arange, so this mask is all False;
# negative ids presumably mark padding in the sparse path — confirm.
patch_embeddings, masks = vot.class_token(patch_embeddings, patch_ids < 0)
patch_embeddings[0, 0]

# block = vot.encoder.blocks[0]
# attn_layer = block.attn

# B, N, C = patch_embeddings.shape
# qkv = (
#     attn_layer.qkv(patch_embeddings)
#     .reshape(B, N, 3, attn_layer.heads, C // attn_layer.heads)
#     .permute(2, 0, 3, 1, 4)
# )
# q, k, v = (
#     qkv[0],
#     qkv[1],
#     qkv[2],
# )

# attn = (q @ k.transpose(-2, -1)) * attn_layer.scale
# if mask is not None:
#     attn = attn.masked_fill(mask.view(B, 1, 1, -1), float("-inf"))
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)

# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
# x = self.proj(x)
# x = self.proj_drop(x)

# x = x + self.drop_path(y)
# x = x + self.drop_path(self.mlp(self.norm2(x)))
# patch_embeddings, hidden_states = vot.encoder(patch_embeddings, masks)
# hidden_states[-1][0, -1]
# patch_embeddings[0, -1]

# vot_out = vot_head(patch_embeddings[:, -1])
# vot_out

tensor([-0.0406, -0.0160, -0.0508, -0.0788, -0.0277,  0.0516,  0.0071,  0.0134,
        -0.0638,  0.0393,  0.0709,  0.0069, -0.0521, -0.0743,  0.0380,  0.0685,
        -0.0932,  0.0174,  0.0303,  0.0507, -0.0652, -0.0673, -0.0420,  0.0147,
         0.0627, -0.0602,  0.0704,  0.0489, -0.0409, -0.0781,  0.0263, -0.0107,
         0.0243,  0.0628, -0.0437,  0.0511, -0.0152,  0.0177,  0.0448, -0.0543,
        -0.0644, -0.0366, -0.0384,  0.0656, -0.0806,  0.0610, -0.0516, -0.0928,
         0.0188, -0.0287,  0.0320,  0.0562, -0.0133,  0.0606,  0.0432,  0.0212,
         0.0243,  0.0748,  0.0561, -0.0433,  0.0810,  0.0507, -0.0499,  0.0503,
        -0.0699,  0.0397,  0.0663,  0.0137, -0.0827, -0.0110,  0.0114, -0.1011,
         0.0384, -0.0285,  0.0589, -0.0755, -0.0154, -0.0271,  0.0640, -0.0073,
        -0.0058,  0.0028, -0.0493,  0.0092, -0.0685, -0.0549, -0.0716, -0.0608,
        -0.0339, -0.0545, -0.0264,  0.0459,  0.0831, -0.0317,  0.0181,  0.0473,
        -0.0142,  0.0448,  0.0282, -0.05

In [23]:
# DINO's token preparation on the same crop batch (this prepare_tokens returns three values);
# `patch_emb` / `pos_emb` are compared against VoT's intermediates below.
dino_tokens = dino.prepare_tokens(input)
x, patch_emb, pos_emb = dino_tokens
patch_emb[0, 0]
# x[0, 0]

# block = dino.blocks[0]
# attn_layer = block.attn

# B, N, C = x.shape
# qkv_dino = (
#     attn_layer.qkv(x)
#     .reshape(B, N, 3, attn_layer.num_heads, C // attn_layer.num_heads)
#     .permute(2, 0, 3, 1, 4)
# )
# q_dino, k_dino, v_dino = qkv_dino[0], qkv_dino[1], qkv_dino[2]

# attn_dino = (q_dino @ k_dino.transpose(-2, -1)) * attn_layer.scale
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)

# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
# x = self.proj(x)
# x = self.proj_drop(x)
        
# y_dino, attn = block.attn(block.norm1(x))
# y_dino

# x = x + self.drop_path(y)
# x = x + self.drop_path(self.mlp(self.norm2(x)))

# hidden_states_dino = []
# for blk in dino.blocks:
#     x = blk(x)
#     hidden_states_dino.append(x)
# x[0, 0]

# x = dino.norm(x)
# x[0, 0]

# dino_out = dino_head(x[:, 0])
# dino_out


tensor([-21.6251,  23.6143,  23.8853,  15.8956,  24.5549,  25.0263, -22.1332,
         23.9854, -24.2066,  23.8880,  24.1964,  12.5992,   6.4231, -16.7931,
         -6.6260, -15.9072, -23.9110,  20.5405,   5.2969,   9.2776, -23.5279,
         15.5631,  23.1687, -24.6560,   1.4549,  11.8108, -24.4652,  -9.6431,
        -23.9609,   9.0096,   4.4318, -24.5280, -24.6686,  24.3940, -24.0790,
         11.9851,  -7.0437,  17.2459,  11.9501, -24.3065,  24.7551, -10.6594,
         24.6832,  24.0830,  12.1820, -22.6551,  11.5486, -25.9461, -23.7650,
         11.8115,   5.9917, -19.7801,  10.0713, -24.1953, -17.2989,  24.4109,
         24.5908, -22.9933, -21.6915, -24.7751,  24.0575,  18.9463, -19.1986,
        -23.5262,  25.0119,  23.7467, -12.4631,  25.0144, -24.8882,  17.6349,
        -24.3782, -24.8354, -23.7090, -23.8478,  10.3990, -24.6180, -23.9910,
        -24.6055,  24.6195,  24.4928, -11.3367, -23.7406, -24.7030, -19.7857,
         23.9775,  24.8067,  23.8190, -16.6945,  12.2922,  25.26

In [31]:
# Drift between the patch position tables (DINO row 0 is the class position, hence 1:).
torch.max(torch.abs(vot.position_embedding.position_embedding - dino.pos_embed[0, 1:]))

tensor(3.5027e-06, grad_fn=<MaxBackward1>)

In [35]:
# Same comparison for the gradients left by the backward pass.
torch.max(torch.abs(vot.position_embedding.position_embedding.grad - dino.pos_embed.grad[0, 1:]))

tensor(8.7311e-11)

In [26]:
# DINO vs. VoT position embeddings from the step-by-step traces (patch rows only).
torch.max(torch.abs(pos_emb[0, 1:] - pos_embeddings[0]))

tensor(3.5027e-06, grad_fn=<MaxBackward1>)

In [None]:
# DINO vs. VoT token tensors from the step-by-step traces.
torch.max(torch.abs(patch_embeddings - patch_emb))

tensor(0., grad_fn=<MaxBackward1>)