# Run this notebook in the `dinov2` conda environment.

In [1]:
import os
import sys
import time
import math
import random
import torch
import torchvision
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch.nn.functional as F

import logging
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [2]:
# Load Hiera-Base (224 px) from torch.hub with the "mae_in1k_ft_in1k" checkpoint.
# Everything below only runs inference, so switch to eval mode right away
# (nn.Module.eval() returns the module itself, so `model` is still the Hiera model).
model = torch.hub.load(
    "facebookresearch/hiera",
    model="hiera_base_224",
    pretrained=True,
    checkpoint="mae_in1k_ft_in1k",
).eval()

Using cache found in /home/sur06423/.cache/torch/hub/facebookresearch_hiera_main
Downloading: "https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth" to /home/sur06423/.cache/torch/hub/checkpoints/hiera_base_224.pth
100%|██████████| 590M/590M [00:23<00:00, 26.3MB/s]  


In [3]:
from torchinfo import summary
# Layer-by-layer summary for a single input of shape (batch=1, 3, 224, 224).
summary_input_size = [1, 3, 224, 224]
summary(model, input_size=summary_input_size)

Layer (type:depth-idx)                   Output Shape              Param #
Hiera                                    [1, 1000]                 301,056
├─PatchEmbed: 1-1                        [1, 3136, 96]             --
│    └─Conv2d: 2-1                       [1, 96, 56, 56]           14,208
├─Unroll: 1-2                            [1, 3136, 96]             --
├─ModuleList: 1-3                        --                        --
│    └─HieraBlock: 2-2                   [1, 3136, 96]             --
│    │    └─LayerNorm: 3-1               [1, 3136, 96]             192
│    │    └─MaskUnitAttention: 3-2       [1, 3136, 96]             37,248
│    │    └─Identity: 3-3                [1, 3136, 96]             --
│    │    └─LayerNorm: 3-4               [1, 3136, 96]             192
│    │    └─Mlp: 3-5                     [1, 3136, 96]             74,208
│    │    └─Identity: 3-6                [1, 3136, 96]             --
│    └─HieraBlock: 2-3                   [1, 3136, 96]            

In [None]:
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
import hiera
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

In [None]:
# Evaluation preprocessing: bicubic resize to (256/224) * input_size
# (256 px for a 224 px input), then center-crop to input_size.
input_size = 224

resize_size = int((256 / 224) * input_size)
transform_list = [
    transforms.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
]

# transform_vis stops after the crop so the result can be displayed;
# transform_norm additionally converts to a tensor and applies ImageNet
# mean/std normalization for the model.
transform_vis = transforms.Compose(transform_list)
transform_norm = transforms.Compose([
    *transform_list,
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

In [None]:
# Classify one example image with the ImageNet-1k head.
# NOTE(review): the original cell used `img_norm` without it being defined anywhere
# (the image-loading cell appears to have been deleted), so it failed on a fresh
# kernel. IMAGE_PATH is a placeholder -- TODO point it at a real image file.
IMAGE_PATH = "img/dog.jpg"

# convert("RGB") so grayscale / RGBA files still yield 3 channels for Normalize.
img = Image.open(IMAGE_PATH).convert("RGB")
img_vis = transform_vis(img)    # cropped PIL image, for display
img_norm = transform_norm(img)  # normalized tensor, for the model

# Pure inference: skip building the autograd graph.
with torch.no_grad():
    out = model(img_norm[None, ...])  # add a batch dimension of 1

# 207: golden retriever  (imagenet-1k)
out.argmax(dim=-1).item()

In [None]:
# Also request the intermediate feature maps from the model and print their shapes.
# Inference only, so no autograd graph is needed.
with torch.no_grad():
    _, intermediates = model(img_norm[None, ...], return_intermediates=True)

for feature_map in intermediates:
    print(feature_map.shape)

In [None]:
# Validation split in torchvision ImageFolder layout (one sub-directory per class).
# NOTE(review): absolute, machine-specific path -- consider making it configurable.
val_dir = "/net/polaris/storage/deeplearning/sur_data/binary_rgb_daa/split_0/val"

# Apply the model's preprocessing. Without a transform, ImageFolder returns PIL
# images, which DataLoader's default collate function cannot batch (TypeError).
val_dataset = ImageFolder(root=val_dir, transform=transform_norm)

In [None]:
# Wrap the validation Dataset in a DataLoader. Evaluation does not need
# shuffling, so batches follow the dataset's index order.
from torch.utils.data import DataLoader

val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=1,   # one image per batch
    num_workers=1,  # a single data-loading worker subprocess
    shuffle=False,
)