In [1]:
import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

# Load the model
# torch.hub resolves the 'dinov2_vitl14' entry point of the
# facebookresearch/dinov2 repository (served from the local hub cache when one
# exists). nn.Module.eval() returns the module itself, so the call can be
# chained to leave `dinov2` already switched to evaluation mode.
dinov2 = torch.hub.load(
    repo_or_dir='facebookresearch/dinov2',
    model='dinov2_vitl14',
).eval()

# Define image transforms
# Side length used for both the resize and the centre crop.
INPUT_SIZE = 224
# Per-channel mean / std used for normalisation (these are the values commonly
# published as the ImageNet statistics).
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]

# Steps are applied in list order: resize, centre-crop, convert to tensor,
# then normalise each channel.
preprocessing_steps = [
    T.Resize(INPUT_SIZE),
    T.CenterCrop(INPUT_SIZE),
    T.ToTensor(),
    T.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
]
transform = T.Compose(preprocessing_steps)

# Helper function to load and process an image
def load_image(image_path):
    """Load an image from disk and return it preprocessed as a single-item batch.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (typically a
            filesystem path).

    Returns:
        The result of applying the module-level ``transform`` to the image
        converted to RGB, with a leading dimension of size 1 added by
        ``unsqueeze(0)``.
    """
    # Image.open is lazy and keeps the underlying file open. The context
    # manager guarantees the handle is released even for multi-frame formats
    # or when decoding/conversion raises. convert() returns an independent
    # in-memory image, so it stays usable after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    return transform(rgb_image).unsqueeze(0)

# Example usage
# image = load_image('path/to/your/image.jpg')
# with torch.no_grad():
#     features = dinov2(image)

Using cache found in /home/arda/.cache/torch/hub/facebookresearch_dinov2_main
    PyTorch 2.5.1+cu124 with CUDA 1204 (you have 2.0.0+cu117)
    Python  3.9.20 (you have 3.9.20)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


In [6]:
import sys

# Make the local dinov2 checkout importable. The membership check keeps this
# cell idempotent: re-running it no longer appends a duplicate sys.path entry
# each time.
DINOV2_REPO_DIR = '/home/arda/dinov2'
if DINOV2_REPO_DIR not in sys.path:
    sys.path.append(DINOV2_REPO_DIR)

from distillation.datasets.CustomDataset import CustomDataset

# Directory handed to CustomDataset as its image source.
TRAIN_IMG_DIR = '/home/arda/data/train2017'

# Reuses the `transform` pipeline defined in the first cell; presumably
# CustomDataset applies it to each image it loads — TODO confirm against the
# CustomDataset implementation.
dataset = CustomDataset(img_dir=TRAIN_IMG_DIR, transform=transform)



In [8]:
from distillation.models.ModelWrapper import ModelWrapper

# Instantiate the project's ModelWrapper for a ResNet model, pointing it at the
# distilled checkpoint file. The trailing options (depth, out_features,
# freeze_at, norm_type) are passed as ordinary keyword arguments; this is
# equivalent to unpacking a dict literal holding the same keys.
resnet50 = ModelWrapper(
    model_type='resnet',
    n_patches=256,
    target_feature=['res5'],
    feature_matcher_config=None,
    checkpoint_path='/home/arda/dinov2/distillation/checkpoints/resnet50_distilled.pth',
    depth=50,
    out_features=['res4', 'res5'],
    freeze_at=0,
    norm_type='BN',
)

tensor([[[-1.8782, -1.8782, -1.8782,  ..., -0.4054, -0.4054, -0.3883],
         [-1.8610, -1.8610, -1.8782,  ..., -0.3883, -0.3541, -0.3369],
         [-1.8610, -1.8782, -1.8782,  ..., -0.3541, -0.3027, -0.3027],
         ...,
         [ 0.0569,  0.0398,  0.0398,  ..., -0.1486, -0.1657, -0.2513],
         [ 0.0227,  0.0227,  0.0398,  ..., -0.1657, -0.1999, -0.2513],
         [ 0.0227,  0.0569,  0.1083,  ..., -0.1999, -0.2171, -0.2684]],

        [[-1.8431, -1.8431, -1.8431,  ...,  0.3803,  0.3803,  0.4153],
         [-1.8431, -1.8431, -1.8431,  ...,  0.3978,  0.4153,  0.4153],
         [-1.8431, -1.8431, -1.8431,  ...,  0.3978,  0.3978,  0.4153],
         ...,
         [ 0.5903,  0.6078,  0.6078,  ...,  0.2402,  0.1877,  0.1877],
         [ 0.6078,  0.6254,  0.6254,  ...,  0.2227,  0.2052,  0.1877],
         [ 0.6078,  0.6254,  0.6604,  ...,  0.2227,  0.2052,  0.2052]],

        [[-1.4210, -1.4384, -1.4384,  ...,  0.7054,  0.6879,  0.6356],
         [-1.4210, -1.4210, -1.4210,  ...,  0