In [4]:
# load the 10k model
import torch
from torchvision.models import VisionTransformer
import torch.nn as nn
import pickle
import os
import numpy as np
import pandas as pd

In [5]:
class BinaryViT(nn.Module):
    """Vision Transformer wrapper that emits a single output value per image.

    The backbone is stored under the attribute ``vit``, so every parameter
    name in this module's state dict is prefixed with ``vit.``.
    """

    def __init__(self):
        super().__init__()
        # Architecture hyperparameters, gathered in one place for readability.
        vit_config = dict(
            image_size=224,   # side length of the (square) input image
            patch_size=16,    # side length of each image patch
            hidden_dim=768,   # embedding width
            num_layers=12,    # number of transformer encoder layers
            num_heads=12,     # attention heads per layer
            mlp_dim=3072,     # width of the feed-forward MLP
            dropout=0.1,      # dropout probability
            num_classes=1,    # a single output unit
        )
        self.vit = VisionTransformer(**vit_config)

    def forward(self, x):
        """Pass ``x`` through the wrapped transformer and return its output as-is."""
        return self.vit(x)

In [6]:
# Read the "trained_with_10k" checkpoint onto the CPU.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute arbitrary code while being loaded.
weights = torch.load(
    "models/trained_with_10k.pth",
    map_location="cpu",
    weights_only=True,
)
model = BinaryViT()
# Strict loading is the default: a key mismatch between the checkpoint and the
# model raises instead of being silently ignored.
model.load_state_dict(weights)
# Put the model in evaluation mode (dropout is disabled). eval() returns the
# module itself, so as the last expression its repr becomes the cell output.
model.eval()

BinaryViT(
  (vit): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.1, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.1, inplace=False)
            (3): Linear(in_features=3072, out_features=768, bias=True)
            (4): Dropout(p=0.1, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(
          (ln_1): LayerNo