In [1]:
# imports

from torchvision import transforms
from torchinfo import summary
from torch import hub, device, cuda, load, Tensor, nn, cat, from_numpy, argmax
from timm.data import resolve_model_data_config, create_transform

import numpy as np
import os
from PIL import Image

device = device('cuda' if cuda.is_available() else 'cpu')
print(f'Device used: {device}')

  from .autonotebook import tqdm as notebook_tqdm


Device used: cuda


In [2]:
# load model

vision_transformer = hub.load('facebookresearch/dinov2', 'dinov2_vits14')
vision_transformer.eval()
vision_transformer.to(device)

pretr_head = load('/home/stud/afroehli/coding/dinov2_ood/pretrained_heads/dinov2_vits14_linear_head.pth')
pretr_head_big = load('/home/stud/afroehli/coding/dinov2_ood/pretrained_heads/dinov2_vits14_linear4_head.pth')

Using cache found in /home/stud/afroehli/.cache/torch/hub/facebookresearch_dinov2_main


  pretr_head = load('/home/stud/afroehli/coding/dinov2_ood/pretrained_heads/dinov2_vits14_linear_head.pth')
  pretr_head_big = load('/home/stud/afroehli/coding/dinov2_ood/pretrained_heads/dinov2_vits14_linear4_head.pth')


In [14]:
# experiments
#print(pretr_head)
#summary(pretr_head)
print(pretr_head['weight'].size())
print(pretr_head['bias'].size())
print(pretr_head_big['weight'].size())
print(pretr_head_big['bias'].size())

torch.Size([1000, 768])
torch.Size([1000])
torch.Size([1000, 1920])
torch.Size([1000])


In [3]:
base_path = '/home/stud/afroehli/datasets/ImageNet1k/imagenet1k/ILSVRC/Data/CLS-LOC/val_sorted/n01484850'
imgs_pths_one_class = os.listdir(base_path)

timm_model = 'vit_small_patch14_dinov2'
timm_model_conf = resolve_model_data_config(timm_model)
timm_model_conf['input_size'] = (3, 518, 518)

timm_transform = create_transform(**timm_model_conf, is_training=False)

imgs_transformed = [timm_transform(Image.open(base_path + '/' + img_path)) for img_path in imgs_pths_one_class]

In [4]:
from torch import nn

class LinearClassifier(nn.Module): 

    def __init__(self, in_features = 384, out_features = 1000):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(in_features= in_features, out_features=out_features),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.network(x)

In [13]:
img_on_dev = imgs_transformed[0]
img_on_dev = img_on_dev.unsqueeze(0)
img_on_dev = img_on_dev.to(device)
patch_tokens = vision_transformer.get_intermediate_layers(img_on_dev, return_class_token=True)

print(patch_tokens[0][0].shape)
print(patch_tokens[0][1].shape)
# p_token_cpu = patch_tokens[0].cpu().detach().numpy()
# p_token_mean = np.mean(p_token_cpu, axis=1)
# print(p_token_mean.shape)

torch.Size([1, 1369, 384])
torch.Size([1, 384])


In [7]:
from collections import OrderedDict

linear_cl = LinearClassifier(in_features=768, out_features=1000)

trained_model_weights = load('/home/stud/afroehli/coding/dinov2_ood/storage/model_checkpoints/linear_classifier_epoch_27.pth')
print(trained_model_weights['model_state'].keys())

model_params = OrderedDict()
model_params['network.0.weight'] = pretr_head['weight']
model_params['network.0.bias'] = pretr_head['bias']
print(model_params.keys())
linear_cl.load_state_dict(model_params)
linear_cl.eval()
linear_cl.to(device)

class_preds = []
check_preds = []
for imgs in imgs_transformed:

    last_layers = vision_transformer.get_intermediate_layers(imgs.unsqueeze(0).to(device), return_class_token=True)

    patch_tokens = last_layers[0][0]
    patch_tokens_trans = patch_tokens.cpu().detach().numpy()
    patch_tokens_mean = np.mean(patch_tokens_trans, axis=1)

    cls_token = last_layers[0][1].cpu().detach().numpy()

    cls_plus_patch = np.concatenate((cls_token[0], patch_tokens_mean[0]), axis=0)

    cls_plus_patch_trans = from_numpy(cls_plus_patch).unsqueeze(0).to(device)

    prediction = linear_cl(cls_plus_patch_trans)
    #print(prediction.argmax(1))
    pred_index = prediction.argmax(1).item()

    class_preds.append(pred_index)
    check_preds.append(argmax(prediction).item())

print(class_preds)
print(check_preds)

  trained_model_weights = load('/home/stud/afroehli/coding/dinov2_ood/storage/model_checkpoints/linear_classifier_epoch_27.pth')


odict_keys(['network.0.weight', 'network.0.bias'])
odict_keys(['network.0.weight', 'network.0.bias'])
[2, 2, 2, 2, 2, 3, 2, 2, 2, 148, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 149, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 4, 2, 2]
[2, 2, 2, 2, 2, 3, 2, 2, 2, 148, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 149, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 4, 2, 2]
