### Explore the possible outputs of DINO, the interface designed by meta, and extending it for LORA

In [2]:
import torch
import torch.nn as nn
torch.hub.set_dir("../pretrained_weights")

In [3]:
# load the backbone model
device = 'cpu' #0 if torch.cuda.is_available() else "cpu"
dino_backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').to(device).eval()

Using cache found in ../pretrained_weights/facebookresearch_dino_main


In [4]:
print(dino_backbone)

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(8, 8), stride=(8, 8))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Identity()
)


### Look at the user interface of DINO (returning class token, intermed. outputs, etc)

In [5]:
in_tensor = torch.randn(8, 3, 224, 224).to(device) # BxCxHxW
out = dino_backbone.forward(in_tensor)
out_intermediate = dino_backbone.get_intermediate_layers(in_tensor, n=3) # eg. the last 3 layers

In [6]:
print(out.shape) # class token
print(len(out_intermediate))
print(out_intermediate[0].shape)
out_intermediate[-1][:, 0] == out

torch.Size([8, 384])
3
torch.Size([8, 785, 384])


tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

### Look at sub-components of dino_backbone

In [7]:
# go deeper into the structure
print(dino_backbone.blocks[0])

Block(
  (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (attn): Attention(
    (qkv): Linear(in_features=384, out_features=1152, bias=True)
    (attn_drop): Dropout(p=0.0, inplace=False)
    (proj): Linear(in_features=384, out_features=384, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (drop_path): Identity()
  (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (mlp): Mlp(
    (fc1): Linear(in_features=384, out_features=1536, bias=True)
    (act): GELU(approximate='none')
    (fc2): Linear(in_features=1536, out_features=384, bias=True)
    (drop): Dropout(p=0.0, inplace=False)
  )
)


In [8]:
# isolate the linear weights in qkv
print(dino_backbone.blocks[0].attn.qkv)
print(dino_backbone.blocks[0].attn.qkv.weight.requires_grad)

Linear(in_features=384, out_features=1152, bias=True)
True


In [9]:
# we replace the qkv in Attention.qkv with a custom module
class Lora_qkv(nn.Module):
    def __init__(self, qkv, qa, qb, va, vb):
        super().__init__()
        self.qkv = qkv
        self.qa = qa
        self.qb = qb
        self.va = va
        self.vb = vb
        self.dim = self.qkv.in_features

    def forward(self, x):
        B, N, C = x.shape
        old_qkv = self.qkv(x) # B, N, 3C
        new_q = self.qb(self.qa(x)) # B, N, C
        new_v = self.vb(self.va(x)) # B, N, C
        
        old_qkv[:, :, : self.dim] += new_q # first C channels belong to q
        old_qkv[:, :, -self.dim:] += new_v # last C channels belong to v
        return old_qkv

# implementing the qv lora mechanism
# design similar to https://github.com/BeileiCui/SurgicalDINO/blob/main/surgicaldino.py
class Lora_vit(nn.Module):
    def __init__(self, base_vit, lora_rank=4, full_ft=False):
        super().__init__()
        if not full_ft:
            # constrain the model to only train lora weights
            for param in base_vit.parameters():
                param.requires_grad = False
        self.base = base_vit
        
        self.r = lora_rank
        self.in_ftrs = self.base.blocks[0].attn.qkv.in_features
        out_ftrs_qkv = self.base.blocks[0].attn.qkv.out_features
        assert out_ftrs_qkv % 3 == 0
        self.out_ftrs = out_ftrs_qkv // 3
        
        self.initialize_lora_layers()

    def initialize_lora_layers(self):
        # instantiate lora weights for each of the blocks
        qa, qb = [], []
        va, vb = [], []
        for i, block in enumerate(self.base.blocks):
            qa.append(nn.Linear(self.in_ftrs, self.r, bias=False))
            qb.append(nn.Linear(self.r, self.out_ftrs, bias=False))
            va.append(nn.Linear(self.in_ftrs, self.r, bias=False))
            vb.append(nn.Linear(self.r, self.out_ftrs, bias=False))
            block.attn.qkv = Lora_qkv(block.attn.qkv, qa[i], qb[i], va[i], vb[i]) 

    def forward(self, x): # class token
        return self.base.forward(x)

    def get_intermediate_layers(self, x, n=1):
        return self.base.get_intermediate_layers(x, n)


In [10]:
dino_backbone_2 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').to(device).eval()
lora_test = Lora_vit(dino_backbone_2, lora_rank=4)

Using cache found in ../pretrained_weights/facebookresearch_dino_main


In [11]:
print(lora_test)

Lora_vit(
  (base): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(8, 8), stride=(8, 8))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Lora_qkv(
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (qa): Linear(in_features=384, out_features=4, bias=False)
            (qb): Linear(in_features=4, out_features=384, bias=False)
            (va): Linear(in_features=384, out_features=4, bias=False)
            (vb): Linear(in_features=4, out_features=384, bias=False)
          )
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, element

In [12]:
# isolate the linear weights in qkv
print(lora_test.base.blocks[0].attn.qkv)
print(lora_test.base.blocks[0].attn.qkv.qkv.weight.requires_grad)
print(lora_test.base.blocks[0].attn.qkv.qa.weight.requires_grad)

Lora_qkv(
  (qkv): Linear(in_features=384, out_features=1152, bias=True)
  (qa): Linear(in_features=384, out_features=4, bias=False)
  (qb): Linear(in_features=4, out_features=384, bias=False)
  (va): Linear(in_features=384, out_features=4, bias=False)
  (vb): Linear(in_features=4, out_features=384, bias=False)
)
False
True


In [13]:
out_lora = lora_test.forward(in_tensor)
out_lora_intermediate = lora_test.get_intermediate_layers(in_tensor, n=3) # eg. the last 3 layers

In [14]:
print(out_lora.shape) # class token
print(len(out_lora_intermediate))
print(out_lora_intermediate[0].shape)
print(out_lora_intermediate[-1][:, 0] == out_lora)
out_lora == out

torch.Size([8, 384])
3
torch.Size([8, 785, 384])
tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])


tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [15]:
# Print model's state_dict
print("Model's state_dict:")
lora_test.train()
for param_tensor in lora_test.state_dict():
    print(param_tensor, "\t", lora_test.state_dict()[param_tensor].size())

Model's state_dict:
base.cls_token 	 torch.Size([1, 1, 384])
base.pos_embed 	 torch.Size([1, 785, 384])
base.patch_embed.proj.weight 	 torch.Size([384, 3, 8, 8])
base.patch_embed.proj.bias 	 torch.Size([384])
base.blocks.0.norm1.weight 	 torch.Size([384])
base.blocks.0.norm1.bias 	 torch.Size([384])
base.blocks.0.attn.qkv.qkv.weight 	 torch.Size([1152, 384])
base.blocks.0.attn.qkv.qkv.bias 	 torch.Size([1152])
base.blocks.0.attn.qkv.qa.weight 	 torch.Size([4, 384])
base.blocks.0.attn.qkv.qb.weight 	 torch.Size([384, 4])
base.blocks.0.attn.qkv.va.weight 	 torch.Size([4, 384])
base.blocks.0.attn.qkv.vb.weight 	 torch.Size([384, 4])
base.blocks.0.attn.proj.weight 	 torch.Size([384, 384])
base.blocks.0.attn.proj.bias 	 torch.Size([384])
base.blocks.0.norm2.weight 	 torch.Size([384])
base.blocks.0.norm2.bias 	 torch.Size([384])
base.blocks.0.mlp.fc1.weight 	 torch.Size([1536, 384])
base.blocks.0.mlp.fc1.bias 	 torch.Size([1536])
base.blocks.0.mlp.fc2.weight 	 torch.Size([384, 1536])
base.bl

In [40]:
# try saving and loading state_dict
path = '../lora_test.pt'
torch.save(lora_test.state_dict(), path)

dino_backbone_3 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8').to(device).eval()
lora_test_new = Lora_vit(dino_backbone_3, lora_rank=4)

out_lora_2 = lora_test_new.forward(in_tensor)
print(out_lora_2 == out_lora) # should be false

lora_test_new.load_state_dict(torch.load(path))
out_lora_3 = lora_test_new.forward(in_tensor)
print(out_lora_3 == out_lora) # should be true

Using cache found in ../pretrained_weights/facebookresearch_dino_main


tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])


In [17]:
# check that the weight gradient behavior is consistent with before load
print(lora_test_new.base.blocks[0].attn.qkv)
print(lora_test_new.base.blocks[0].attn.qkv.qkv.weight.requires_grad) # should be equal to full_ft
print(lora_test_new.base.blocks[0].attn.qkv.qa.weight.requires_grad) # should be true

Lora_qkv(
  (qkv): Linear(in_features=384, out_features=1152, bias=True)
  (qa): Linear(in_features=384, out_features=4, bias=False)
  (qb): Linear(in_features=4, out_features=384, bias=False)
  (va): Linear(in_features=384, out_features=4, bias=False)
  (vb): Linear(in_features=4, out_features=384, bias=False)
)
False
True


In [42]:
trained_path = '../logs/training_1/checkpoint0009.pth'

def load_lora_vit_from_dino_ckpt(model, ckpt_path):
    # we are loading the teacher model from the DINO checkpoint
    teacher = torch.load(ckpt_path)['teacher']
    
    # edit the dictionary to remove the projector and rename backbone entries
    for k in list(teacher.keys()):
        if 'backbone' in k:
            teacher[k.replace('backbone.', '')] = teacher.pop(k)
        else:
            teacher.pop(k)

    # model is the lora_vit model with a consistent lora rank as the checkpoint
    model.load_state_dict(teacher, strict=False)

In [43]:
out3 = lora_test_new.forward(in_tensor)
load_lora_vit_from_dino_ckpt(lora_test_new, trained_path)
out4 = lora_test_new.forward(in_tensor)
out4 == out3 # should be false

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])