# **Original ViT-B/16 (torchvision reference)**

In [1]:
from torchvision.models import vit_b_16
# torchvision's reference ViT-B/16, built without pretrained weights; its printed
# module tree is the blueprint the custom implementation below reproduces.
ViT_B_16 = vit_b_16(weights=None)
ViT_B_16

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

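# **Custom ViT-B/16**

A from-scratch re-implementation whose parameter names match the keys of torchvision's `vit_b_16` checkpoint, so the pretrained weights can be loaded key-by-key.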

In [2]:
import torch
import torch.nn as nn
from torch.nn import MultiheadAttention
from collections import OrderedDict

class MLPBlock(nn.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The two Linear layers are registered as ``linear_1`` / ``linear_2`` so their
    parameters line up with checkpoints saved under those keys (the checkpoint
    loaded later in this notebook reports no missing MLP keys). The parameter-free
    layers keep numeric names, mirroring the printed torchvision ``MLPBlock``.

    Args:
        in_features: input (and output) feature size.
        hidden_features: width of the hidden layer.
        dropout_prob: probability used by both Dropout layers.
    """

    def __init__(self, in_features, hidden_features, dropout_prob):
        super(MLPBlock, self).__init__()
        named_layers = (
            ('linear_1', nn.Linear(in_features, hidden_features, bias=True)),
            ('1', nn.GELU(approximate='none')),
            ('2', nn.Dropout(dropout_prob)),
            ('linear_2', nn.Linear(hidden_features, in_features, bias=True)),
            ('4', nn.Dropout(dropout_prob)),
        )
        # Registration order defines the forward order (see forward()).
        for layer_name, layer in named_layers:
            self.add_module(layer_name, layer)

    def forward(self, x):
        """Feed ``x`` through every registered layer, in registration order."""
        hidden = x
        for submodule in self.children():
            hidden = submodule(hidden)
        return hidden

class EncoderBlock(nn.Module):
    """One pre-norm Transformer encoder layer, laid out like torchvision's ViT ``EncoderBlock``.

    Submodule names (``ln_1``, ``self_attention``, ``dropout``, ``ln_2``, ``mlp``)
    match the torchvision layout printed above, so pretrained checkpoints load by key.

    Args:
        embed_dim: token embedding size (input and output feature dimension).
        num_heads: number of attention heads.
        mlp_hidden_dim: hidden width of the feed-forward ``MLPBlock``.
        dropout_prob: dropout probability for the attention weights, the attention
            residual branch, and inside the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_hidden_dim, dropout_prob):
        super(EncoderBlock, self).__init__()
        self.ln_1 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.self_attention = MultiheadAttention(embed_dim, num_heads, dropout=dropout_prob, batch_first=True)
        self.dropout = nn.Dropout(dropout_prob)
        self.ln_2 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.mlp = MLPBlock(embed_dim, mlp_hidden_dim, dropout_prob)

    def forward(self, x):
        """Apply attention then the MLP, each as a pre-norm residual ``x + f(LayerNorm(x))``.

        Args:
            x: token tensor; with ``batch_first=True`` attention this is expected to be
               (batch, seq_len, embed_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Pre-norm residuals (as in ViT / torchvision): normalise *before* each
        # sub-layer and add the un-normalised input back. A post-norm ordering
        # (add, then LayerNorm) does not match how the pretrained weights behave.
        normed = self.ln_1(x)
        attn_output, _ = self.self_attention(normed, normed, normed, need_weights=False)
        x = x + self.dropout(attn_output)

        # MLPBlock already ends with its own Dropout, so no extra dropout here.
        x = x + self.mlp(self.ln_2(x))
        return x

class Encoder(nn.Module):
    """Positional embedding + stacked encoder blocks + final LayerNorm.

    Attribute names (``dropout``, ``pos_embedding``, ``layers.encoder_layer_{i}``,
    ``ln``) line up with the torchvision ViT ``Encoder`` layout printed above.

    Args:
        embed_dim: token embedding size.
        num_layers: how many ``EncoderBlock``s to stack.
        num_heads: attention heads per block.
        mlp_hidden_dim: hidden width of each block's MLP.
        dropout_prob: dropout probability (embedding dropout and inside every block).
        num_patches: number of image patches; the positional table has
            ``num_patches + 1`` rows (one extra for the prepended class token).
    """

    def __init__(self, embed_dim, num_layers, num_heads, mlp_hidden_dim, dropout_prob, num_patches):
        super(Encoder, self).__init__()
        self.dropout = nn.Dropout(dropout_prob)

        # One learned vector per token position; the leading 1 broadcasts over the batch.
        seq_len = num_patches + 1
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, embed_dim))

        blocks = OrderedDict(
            (f'encoder_layer_{idx}', EncoderBlock(embed_dim, num_heads, mlp_hidden_dim, dropout_prob))
            for idx in range(num_layers)
        )
        self.layers = nn.Sequential(blocks)
        self.ln = nn.LayerNorm(embed_dim, eps=1e-6)

    def forward(self, x):
        """Add positional embeddings, apply dropout, run every block, then normalise."""
        tokens = self.dropout(x + self.pos_embedding)
        return self.ln(self.layers(tokens))

class VisionTransformer(nn.Module):
    """ViT image classifier whose parameter names follow torchvision's ``vit_b_16``.

    The defaults reproduce ViT-B/16: 224px input, 16px patches, 12 layers, 12 heads,
    768-d embeddings and a 3072-d MLP.

    Args:
        img_size: expected height and width of the (square) input images.
        patch_size: side length of the non-overlapping square patches.
        in_channels: number of input image channels.
        num_classes: number of output logits.
        embed_dim: token embedding size.
        num_layers: number of encoder blocks.
        num_heads: attention heads per block.
        mlp_hidden_dim: hidden width of each block's MLP.
        dropout_prob: dropout probability used throughout the encoder.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, num_layers=12, num_heads=12, mlp_hidden_dim=3072, dropout_prob=0.0):
        super(VisionTransformer, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Patch embedding: kernel == stride, so each patch maps to one embed_dim vector.
        self.conv_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size))

        # Learnable class token prepended to the patch sequence; its final state feeds the head.
        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.encoder = Encoder(embed_dim, num_layers, num_heads, mlp_hidden_dim, dropout_prob, self.num_patches)

        # Wrapped in a Sequential so the parameter keys are ``heads.head.*``, matching torchvision's naming.
        self.heads = nn.Sequential(OrderedDict([('head', nn.Linear(embed_dim, num_classes))]))

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image batch of shape (batch, in_channels, height, width). Height and
               width must each produce ``img_size // patch_size`` patches, because the
               positional embedding has a fixed number of rows.

        Returns:
            Logits of shape (batch, num_classes).

        Raises:
            ValueError: if ``x`` is not 4-D or its spatial size gives a different patch grid.
        """
        if x.dim() != 4:
            raise ValueError(f"Expected a 4-D (batch, channels, height, width) tensor, got shape {tuple(x.shape)}")
        height, width = x.shape[-2:]
        grid = self.img_size // self.patch_size
        # Without this check a wrong size either fails later with an opaque broadcast
        # error, or (e.g. 112x448 -> 7x28 = 196 patches) runs with misaligned positions.
        if height // self.patch_size != grid or width // self.patch_size != grid:
            raise ValueError(
                f"Expected images of size {self.img_size}x{self.img_size} "
                f"({grid}x{grid} patches of {self.patch_size}px), got {height}x{width}"
            )

        x = self.conv_proj(x)  # [batch, embed_dim, grid, grid]
        x = x.flatten(2)  # [batch, embed_dim, num_patches]
        x = x.transpose(1, 2)  # [batch, num_patches, embed_dim]

        batch_size = x.size(0)
        class_token = self.class_token.expand(batch_size, -1, -1)  # [batch, 1, embed_dim]
        x = torch.cat((class_token, x), dim=1)  # [batch, num_patches + 1, embed_dim]

        x = self.encoder(x)
        x = x[:, 0]  # final state of the class token
        return self.heads(x)

# ViT-B/16 architecture (all defaults) with a 2-way classification head.
custom_model = VisionTransformer(num_classes=2)

In [3]:
custom_model  # display the module tree to compare it with torchvision's repr from the first cell

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (linear_1): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (linear_2): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
 

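# **Loading pretrained weights**

Copy every backbone tensor from the torchvision ViT-B/16 checkpoint into the custom model. The classification head is skipped because the custom model has its own freshly initialised 2-class head.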

In [4]:
# Define paths
root_dir = "/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/"
models_root_dir = root_dir + 'DeepfakeEmpiricalStudy_Models/'
vit_weights_path = models_root_dir + 'vit_b_16-c867db91.pth'

# Load the checkpoint's state_dict.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered file cannot execute code on load (recent PyTorch warns when it is unset).
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only runtime;
# load_state_dict later copies the values into the model's own tensors anyway.
vit_weights = torch.load(vit_weights_path, map_location='cpu', weights_only=True)

# Get the current state_dict of your custom model
model_state_dict = custom_model.state_dict()

  vit_weights = torch.load(models_root_dir + 'vit_b_16-c867db91.pth')


In [5]:
# Load every parameter except the classification head ("heads.*"): the custom model
# has its own num_classes-sized head, so those tensors are selected by name rather
# than by their position at the end of the state_dict.
head_prefix = 'heads.'
layers_to_load = [k for k in model_state_dict if not k.startswith(head_prefix)]
keys_to_load = set(layers_to_load)  # set for O(1) membership tests

# Keep only checkpoint tensors whose key is one of the model's backbone keys.
filtered_weights = {k: v for k, v in vit_weights.items() if k in keys_to_load}

# strict=False tolerates the intentionally missing head parameters.
load_result = custom_model.load_state_dict(filtered_weights, strict=False)

# strict=False would also silently tolerate backbone keys that failed to match
# (e.g. a checkpoint with different parameter names), so check for that explicitly.
missing_backbone = [k for k in load_result.missing_keys if not k.startswith(head_prefix)]
assert not missing_backbone, f"Backbone weights not found in checkpoint: {missing_backbone}"

load_result

_IncompatibleKeys(missing_keys=['heads.head.weight', 'heads.head.bias'], unexpected_keys=[])