In [None]:
import torch
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
import kagglehub

# Download the ImageNet-100 dataset from Kaggle; `path` is its local root directory.
path = kagglehub.dataset_download("ambityga/imagenet100")
print(f"Path to dataset files: {path}")

In [None]:
!ls -l $path/val.X/n01773549

In [None]:
imagepath = f"{path}/val.X/n01773549/ILSVRC2012_val_00008316.JPEG"  # sample validation image used throughout

In [None]:
import cv2
# cv2.imread signals a missing/unreadable file by returning None rather than raising,
# so fail here with the path instead of deep inside cv2_imshow.
image = cv2.imread(imagepath)
if image is None:
    raise FileNotFoundError(f"could not read image: {imagepath}")
from google.colab.patches import cv2_imshow
cv2_imshow(image)

In [None]:
# ViT-B/32 hyper-parameters. Nothing below reads this directly (the pretrained model is
# loaded as a pickled object); it documents the architecture and allows rebuilding it.
from types import SimpleNamespace
config = SimpleNamespace()
config.classifier = "token"
config.hidden_size = 768
# Plain (non-hybrid) ViT: fixed 32x32 patches. Embeddings treats a "grid" key as the
# hybrid CNN variant with patch size img_size // 16 // grid, which for grid=(32, 32) and
# 224px inputs is 0 -- a module that cannot run.
config.patches = {
    "size": (32, 32),
}
config.transformer = {
    "num_heads": 12,
    "num_layers": 12,
    "attention_dropout_rate": 0.0,
    "dropout_rate": 0.1,
    "mlp_dim": 3072,
}

In [None]:
def np2th(weights, conv=False):
    """Wrap a numpy array as a torch tensor (sharing memory via torch.from_numpy).

    Args:
        weights: numpy array from the checkpoint.
        conv: if True, the array is a 4-D conv kernel in HWIO order and is permuted to
            PyTorch's OIHW order (axes 3, 2, 0, 1) first.
    """
    arr = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(arr)

In [None]:
def swish(x):
    """Swish / SiLU activation: x scaled by its own sigmoid gate."""
    gate = torch.sigmoid(x)
    return x * gate

# Activation name -> callable; Mlp looks up "gelu" here.
ACT2FN = dict(
    gelu=torch.nn.functional.gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
)

In [None]:
import math
import torch.nn as nn
from torch.nn import Linear, Dropout, LayerNorm, Softmax

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: reads `hidden_size`, `transformer["num_heads"]` and
            `transformer["attention_dropout_rate"]`.
        vis: if True, `forward` also returns the attention probabilities.

    Raises:
        ValueError: if `hidden_size` is not divisible by the number of heads.
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        if config.hidden_size % self.num_attention_heads != 0:
            # Truncating the head size would give q/k/v a narrower width than the output
            # projection expects and only fail later with a shape error in forward().
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_heads ({self.num_attention_heads})"
            )
        self.attention_head_size = config.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        # The output projection consumes the concatenated heads (all_head_size wide).
        self.out = Linear(self.all_head_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        # NOTE(review): the projection dropout also uses attention_dropout_rate rather
        # than dropout_rate -- confirm this is intended.
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split heads: (B, N, all_head_size) -> (B, heads, N, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Attend over the tokens of `hidden_states` (B, N, hidden_size).

        Returns:
            (output of shape (B, N, hidden_size),
             attention probabilities (B, heads, N, N) if vis else None)
        """
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        # Captured before dropout so visualizations see the true probabilities.
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge heads back: (B, heads, N, head_size) -> (B, N, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights

In [None]:
class Mlp(nn.Module):
    """Token-wise feed-forward block: fc1 -> GELU -> dropout -> fc2 -> dropout.

    Args:
        config: reads `hidden_size`, `transformer["mlp_dim"]` and
            `transformer["dropout_rate"]`.
    """

    def __init__(self, config):
        super(Mlp, self).__init__()
        width, inner = config.hidden_size, config.transformer["mlp_dim"]
        self.fc1 = Linear(width, inner)
        self.fc2 = Linear(inner, width)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights, then near-zero normal biases (same order: fc1, fc2)."""
        layers = (self.fc1, self.fc2)
        for fc in layers:
            nn.init.xavier_uniform_(fc.weight)
        for fc in layers:
            nn.init.normal_(fc.bias, std=1e-6)

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of `x`."""
        inner = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(inner))

In [None]:
from torch.nn.modules.utils import _pair
from torch.nn import Conv2d

class Embeddings(nn.Module):
    """Patch + position embeddings with a prepended learnable class token.

    With ``config.patches["size"]`` the image is cut into non-overlapping patches by a
    strided Conv2d. A ``config.patches["grid"]`` entry selects the hybrid variant, which
    expects a ``hybrid_model`` backbone attribute that this class does not create.

    Args:
        config: reads `patches`, `hidden_size` and `transformer["dropout_rate"]`.
        img_size: int or (H, W) input size.
        in_channels: number of image channels.

    Raises:
        ValueError: if the derived patch size is zero in either dimension.
    """

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            grid_size = config.patches["grid"]
            # Hybrid variant: presumably the backbone downsamples by 16 before patching.
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if min(patch_size) <= 0:
            # e.g. grid=(32, 32) with 224px input: 224 // 16 // 32 == 0.
            raise ValueError(
                f"patch size {patch_size} derived from config.patches={config.patches} "
                f"and img_size={img_size} must be positive"
            )

        self.patch_embeddings = Conv2d(in_channels=in_channels, out_channels=config.hidden_size, kernel_size=patch_size, stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """(B, C, H, W) images -> (B, 1 + n_patches, hidden_size) token embeddings.

        Raises:
            NotImplementedError: for the hybrid variant when no `hybrid_model` is attached.
        """
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.hybrid:
            hybrid_model = getattr(self, "hybrid_model", None)
            if hybrid_model is None:
                raise NotImplementedError(
                    "hybrid (patches['grid']) embeddings need a `hybrid_model` backbone, "
                    "which is not defined here; use patches['size'] instead"
                )
            x = hybrid_model(x)
        x = self.patch_embeddings(x)   # (B, hidden, h, w)
        x = x.flatten(2)               # (B, hidden, h*w)
        x = x.transpose(-1, -2)        # (B, h*w, hidden)
        x = torch.cat((cls_tokens, x), dim=1)  # class token goes first

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings

In [None]:
# Sub-paths inside each "Transformer/encoderblock_<n>" group of the checkpoint dict;
# Block.load_from joins them with pjoin(root, unit, leaf).
ATTENTION_Q, ATTENTION_K, ATTENTION_V, ATTENTION_OUT = (
    "MultiHeadDotProductAttention_1/" + leaf for leaf in ("query", "key", "value", "out")
)
FC_0, FC_1 = ("MlpBlock_3/Dense_" + str(i) for i in range(2))
ATTENTION_NORM, MLP_NORM = "LayerNorm_0", "LayerNorm_2"

In [None]:
def pjoin(a, b, c):
    """Join three checkpoint key components with '/', e.g. ('root', 'unit', 'kernel')."""
    return "/".join((a, b, c))

class Block(nn.Module):
    """Pre-LayerNorm encoder block: x + Attn(LN(x)), then x + MLP(LN(x)).

    Args:
        config: reads `hidden_size` (plus whatever Mlp/Attention read).
        vis: forwarded to Attention; if True, forward also returns attention probabilities.
    """

    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        """Return (block output, attention weights or None)."""
        attn_out, weights = self.attn(self.attention_norm(x))
        x = attn_out + x
        x = self.ffn(self.ffn_norm(x)) + x
        return x, weights

    def load_from(self, weights, n_block):
        """Copy encoder block `n_block` from the checkpoint dict `weights` in place.

        Reads keys "Transformer/encoderblock_<n_block>/<unit>/<leaf>". Dense kernels are
        transposed to PyTorch's (out, in) Linear layout; attention kernels/biases are first
        flattened to (hidden, hidden) / (-1,).
        """
        root = f"Transformer/encoderblock_{n_block}"

        def fetch(unit, leaf):
            return np2th(weights[pjoin(root, unit, leaf)])

        attn_layers = (
            (self.attn.query, ATTENTION_Q),
            (self.attn.key, ATTENTION_K),
            (self.attn.value, ATTENTION_V),
            (self.attn.out, ATTENTION_OUT),
        )
        mlp_layers = ((self.ffn.fc1, FC_0), (self.ffn.fc2, FC_1))
        norms = ((self.attention_norm, ATTENTION_NORM), (self.ffn_norm, MLP_NORM))
        square = (self.hidden_size, self.hidden_size)

        with torch.no_grad():
            attn_kernels = [fetch(unit, "kernel").view(*square).t() for _, unit in attn_layers]
            attn_biases = [fetch(unit, "bias").view(-1) for _, unit in attn_layers]
            for (layer, _), kernel in zip(attn_layers, attn_kernels):
                layer.weight.copy_(kernel)
            for (layer, _), bias in zip(attn_layers, attn_biases):
                layer.bias.copy_(bias)

            mlp_kernels = [fetch(unit, "kernel").t() for _, unit in mlp_layers]
            mlp_biases = [fetch(unit, "bias").t() for _, unit in mlp_layers]
            for (layer, _), kernel in zip(mlp_layers, mlp_kernels):
                layer.weight.copy_(kernel)
            for (layer, _), bias in zip(mlp_layers, mlp_biases):
                layer.bias.copy_(bias)

            for norm, unit in norms:
                norm.weight.copy_(fetch(unit, "scale"))
                norm.bias.copy_(fetch(unit, "bias"))

In [None]:
import copy

class Encoder(nn.Module):
    """Stack of `transformer["num_layers"]` Blocks followed by a final LayerNorm.

    Args:
        config: model config (hidden_size, transformer["num_layers"], ...).
        vis: if True, forward collects each block's attention probabilities.
    """

    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        # Each Block is freshly constructed, so no copy is needed to keep them independent.
        self.layer.extend(Block(config, vis) for _ in range(config.transformer["num_layers"]))

    def forward(self, hidden_states):
        """Return (normalized final hidden states, list of per-block attention weights)."""
        collected = []
        for block in self.layer:
            hidden_states, weights = block(hidden_states)
            if self.vis:
                collected.append(weights)
        return self.encoder_norm(hidden_states), collected

In [None]:
class Transformer(nn.Module):
    """Embeddings followed by the Encoder stack."""

    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        """Images -> (encoded tokens, per-layer attention weights)."""
        return self.encoder(self.embeddings(input_ids))

In [None]:
from torch.nn import CrossEntropyLoss
from scipy import ndimage

class VisionTransformer(nn.Module):
    """ViT classifier: Transformer backbone plus a linear head on the first token.

    Args:
        config: model config; `classifier` and `hidden_size` are read here, the rest is
            passed through to Transformer.
        img_size: input image size (int or (H, W)), forwarded to Embeddings.
        num_classes: number of logits produced by `head`.
        zero_head: if True, `load_from` zero-initializes `head` instead of loading it.
        vis: if True, the encoder collects per-layer attention probabilities.
    """
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier

        self.transformer = Transformer(config, img_size, vis)
        self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        """Run the model on a batch of images `x`.

        Returns:
            The cross-entropy loss if `labels` is given; otherwise `(logits, attn_weights)`,
            where logits is (B, num_classes) and attn_weights is the encoder's per-layer
            list (empty unless the model was built with vis=True).
        """
        x, attn_weights = self.transformer(x)
        # Classify from token 0, the class token that Embeddings prepends.
        logits = self.head(x[:, 0])

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            return loss
        else:
            return logits, attn_weights

    def load_from(self, weights):
        """Copy pretrained parameters from `weights` into this model in place.

        `weights` maps '/'-separated parameter paths ("head/kernel", "embedding/kernel",
        "Transformer/encoderblock_<n>/...", ...) to numpy arrays (np2th wraps them with
        torch.from_numpy). NOTE(review): the key naming looks like a converted JAX/Flax
        ViT checkpoint -- confirm against the checkpoint's source.
        """
        with torch.no_grad():
            if self.zero_head:
                nn.init.zeros_(self.head.weight)
                nn.init.zeros_(self.head.bias)
            else:
                self.head.weight.copy_(np2th(weights["head/kernel"]).t())
                self.head.bias.copy_(np2th(weights["head/bias"]).t())

            # conv=True reorders the patch-embedding kernel from HWIO to Conv2d's OIHW.
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                # Token count differs from the checkpoint: resample the square grid of
                # patch position embeddings to this model's grid size.
                print("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)

                if self.classifier == "token":
                    # Keep the class-token embedding (index 0) as-is; only resize the grid.
                    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
                    ntok_new -= 1
                else:
                    posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

                # Assumes both grids are square; int(sqrt) silently truncates otherwise.
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

                # Rescale only the two spatial axes; order=1 is linear spline interpolation.
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = np.concatenate([posemb_tok, posemb_grid], axis=1)
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder children are `layer` (ModuleList of Blocks named "0", "1", ...) and
            # `encoder_norm` (a LayerNorm with no children), so only Blocks get load_from.
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            # NOTE(review): Embeddings in this notebook never creates `hybrid_model`, so
            # this branch would raise AttributeError as written.
            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(weights["conv_root/kernel"], conv=True))
                gn_weight = np2th(weights["gn_root/scale"]).view(-1)
                gn_bias = np2th(weights["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(weights, n_block=bname, n_unit=uname)

In [None]:
# mimic module hierarchy from the pretrained model: register stand-in modules named
# "models" and "models.modeling" (presumably where the checkpoint's classes lived when it
# was pickled) so unpickling can resolve class references to the classes defined above.
import sys
import types

models = types.ModuleType("models")             # fake package "models"
modeling = types.ModuleType("models.modeling")  # fake submodule "models.modeling"
sys.modules.update({"models": models, "models.modeling": modeling})

# Expose each class under its own name inside the fake submodule.
vars(modeling).update(
    {cls.__name__: cls for cls in (VisionTransformer, Transformer, Encoder, Block, Embeddings, Attention, Mlp)}
)

In [None]:
# download the pretrained model
!wget http://www.agentspace.org/download/ViT-B_32.pth

In [None]:
# Load Model. weights_only=False unpickles arbitrary Python objects, so only do this with
# a trusted file. map_location + .to(device) put the model on the same device as the
# input tensor built below (otherwise CUDA input + CPU model is a device-mismatch error).
model = torch.load('ViT-B_32.pth', weights_only=False, map_location=device)
model = model.to(device)
model.eval()

In [None]:
# load image. convert("RGB") guards against non-RGB JPEGs (grayscale/CMYK), which would
# give ToTensor a 1- or 4-channel tensor that the 3-channel Normalize rejects, and would
# also break the RGB2GRAY conversion used for the attention overlays.
from PIL import Image
im = Image.open(imagepath).convert("RGB")

In [None]:
# preprocess: resize to 224 x 224, convert to a [0, 1] CHW tensor, then map each channel
# to [-1, 1] with mean = std = 0.5; finally add a batch dimension and move to `device`.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
blob = transform(im)[None].to(device)
print('image', blob.shape)

In [None]:
# Call the transformer model step by step (embeddings -> encoder -> head) instead of
# model(blob), so the intermediate tensors and attention maps are available.
def embed(self, x):
    """Replicate Embeddings.forward: patchify, prepend the class token, add positions."""
    B = x.shape[0]
    cls_tokens = self.cls_token.expand(B, -1, -1)  # nn.Parameter 1 x 1 x 768 -> B x 1 x 768

    if self.hybrid:
        x = self.hybrid_model(x)

    x = self.patch_embeddings(x)  # Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32)) -> 1 x 768 x 7 x 7
    x = x.flatten(2)  # 1 x 768 x 49
    x = x.transpose(-1, -2)  # 1 x 49 x 768
    x = torch.cat((cls_tokens, x), dim=1)  # 1 x 50 x 768

    embeddings = x + self.position_embeddings  # 1 x 50 x 768
    embeddings = self.dropout(embeddings)
    return embeddings


def encode(self, hidden_states):
    """Replicate Encoder.forward, collecting each block's attention probabilities."""
    attn_maps = []
    for layer_block in self.layer:
        hidden_states, coefs = layer_block(hidden_states)
        # coefs is None unless the blocks were built with vis=True.
        if self.vis:
            attn_maps.append(coefs)

    encoded = self.encoder_norm(hidden_states)
    return encoded, attn_maps


def lmhead(self, hidden_states):
    """Classification head applied to token 0 (the class token)."""
    logits = self.head(hidden_states[:, 0])  # Linear(in_features=768, out_features=1000, bias=True): 1 x 768 -> 1 x 1000
    return logits


# Inference only: no_grad avoids building an autograd graph through every encoder block.
with torch.no_grad():
    hidden_states = embed(model.transformer.embeddings, blob)  # 1 x 50 x 768
    hidden_states, att_maps = encode(model.transformer.encoder, hidden_states)  # 1 x 50 x 768, list of (1 x 12 x 50 x 50) per layer
    logits = lmhead(model, hidden_states)

print('logits', logits.shape)
print('att maps', [att_map.shape for att_map in att_maps])
print('att maps',[att_map.shape for att_map in att_maps])

In [None]:
# download ImageNet Labels
!wget http://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt

In [None]:
# imagenet_labels[i] is the text of line i (trailing newline kept), i.e. the name for logit i.
# The context manager closes the file; only a short sample is shown instead of every entry.
with open('ilsvrc2012_wordnet_lemmas.txt', encoding='utf-8') as labels_file:
    imagenet_labels = dict(enumerate(labels_file))
print(len(imagenet_labels), 'labels, e.g.', list(imagenet_labels.values())[:3])

In [None]:
# Present probabilities of the five most likely categories (topk avoids sorting all classes).
probs = logits.softmax(dim=-1)
top5_probs, top5 = probs.topk(5, dim=-1)  # both 1 x 5, highest probability first

print("Prediction:")
for prob, idx in zip(top5_probs[0].tolist(), top5[0].tolist()):
    # label strings keep their trailing newline, hence end=''
    print(f'{prob:.5f} : {imagenet_labels[idx]}', end='')

In [None]:
def draw_mask(img, mask, normalize=True):
    """Darken `img` wherever the low-resolution attention `mask` is small.

    Args:
        img: image array, (H, W) or (H, W, C); values expected in [0, 255].
        mask: 2-D array, resized to (H, W) with bilinear interpolation.
        normalize: if True (default), divide the mask by its maximum first, so the most
            attended region keeps full brightness. Raw attention weights are tiny (a row
            sums to 1 over all tokens), and multiplying by them directly collapses the
            image to a handful of intensity levels once cast back to uint8.

    Returns:
        uint8 array with the same shape as `img`.
    """
    H, W = img.shape[:2]
    mask = np.asarray(mask, dtype=np.float32)
    if normalize:
        peak = float(mask.max())
        if peak > 0:  # an all-zero mask is left as-is (black output) instead of dividing by 0
            mask = mask / peak
    mask_resized = cv2.resize(mask, (W, H), interpolation=cv2.INTER_LINEAR)
    if img.ndim == 3:
        # Broadcast over any number of channels instead of assuming exactly 3.
        mask_resized = mask_resized[:, :, None]
    result = (img.astype(np.float32) * mask_resized).clip(0, 255).astype(np.uint8)
    return result

In [None]:
import matplotlib.pyplot as plt
# Grayscale backdrop for the overlays; convert("RGB") first so non-RGB images don't
# break the RGB2GRAY conversion.
base = cv2.cvtColor(np.array(im.convert("RGB")), cv2.COLOR_RGB2GRAY)
assert att_maps, "no attention maps collected -- the model must be built with vis=True"
att_mats = torch.cat(att_maps)  # (num_layers, num_heads, N, N), e.g. (12, 12, 50, 50)
num_layers, num_heads, N, _ = att_mats.shape
grid_size = int(math.sqrt(N-1))  # N - 1 patch tokens (class token excluded)
assert grid_size * grid_size == N - 1, f"{N - 1} patch tokens do not form a square grid"

fig, axes = plt.subplots(num_layers, num_heads, figsize=(2 * num_heads, 2 * num_layers), squeeze=False)
for t in range(num_layers):
    for i in range(num_heads):
        head = att_mats[t, i]      # shape (N, N)
        # Row 0: how strongly the class token attends to each patch token.
        mask = head[0, 1:].reshape(grid_size, grid_size)
        mask = mask.detach().cpu().numpy()
        disp = draw_mask(base, mask)
        axes[t, i].imshow(disp, cmap='gray')
        axes[t, i].axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Joint attention (attention rollout)
# ===================================
# Stacked per-layer attention: (num_layers, num_heads, N, N).
print(att_mats.size())

In [None]:
# 1) Average attention over the heads axis -> (num_layers, N, N).
att = torch.mean(att_mats, dim=1)
print(att.size())

In [None]:
# visualize joined heads: head-averaged class-token attention, one row per layer.
fig, axes = plt.subplots(num_layers, 1, figsize=(2, 2 * num_layers), squeeze=False)
for t, ax in enumerate(axes[:, 0]):
    head = att[t]
    mask = head[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()
    disp = draw_mask(base, mask)
    ax.imshow(disp, cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# 2) Add residual connection to each layer: use (I + A), then row-normalize so every row
# is again a distribution. Recomputed from att_mats (rather than updating `att` in place)
# so re-running this cell does not add the identity a second time.
eye = torch.eye(N, dtype=att_mats.dtype, device=att_mats.device)
att = att_mats.mean(dim=1) + eye  # eye broadcasts over the layer dimension
att = att / att.sum(dim=-1, keepdim=True)
print(att.shape)

In [None]:
# visualize adjusted heads: class-token attention of (I + A) per layer, one row each.
fig, axes = plt.subplots(num_layers, 1, figsize=(2, 2 * num_layers), squeeze=False)
for t, ax in enumerate(axes[:, 0]):
    head = att[t]
    mask = head[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()
    disp = draw_mask(base, mask)
    ax.imshow(disp, cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:

# 3) Joint attention: chain the per-layer matrices from first to last layer, i.e.
#    joint_att = att[L-1] @ ... @ att[1] @ att[0].
joint_att = att[0]
for layer_att in att[1:]:
    joint_att = layer_att @ joint_att

print(joint_att.size())

In [None]:
# visualize the joint attention: class-token row over the patch grid.
mask = joint_att[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()
disp = draw_mask(base, mask)
fig, ax = plt.subplots(figsize=(2, 2))
ax.imshow(disp, cmap='gray')
ax.axis('off')
plt.tight_layout()
plt.show()