In [1]:
!git clone https://github.com/wkcn/TinyViT.git

Cloning into 'TinyViT'...
remote: Enumerating objects: 88, done.[K
remote: Counting objects: 100% (88/88), done.[K
remote: Compressing objects: 100% (72/72), done.[K
remote: Total 88 (delta 11), reused 88 (delta 11), pack-reused 0 (from 0)[K
Receiving objects: 100% (88/88), 629.01 KiB | 15.34 MiB/s, done.
Resolving deltas: 100% (11/11), done.


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!git clone https://github.com/kyegomez/SparseAttention.git

Cloning into 'SparseAttention'...
remote: Enumerating objects: 170, done.[K
remote: Counting objects: 100% (45/45), done.[K
remote: Compressing objects: 100% (44/44), done.[K
remote: Total 170 (delta 25), reused 1 (delta 1), pack-reused 125 (from 2)[K
Receiving objects: 100% (170/170), 2.22 MiB | 32.05 MiB/s, done.
Resolving deltas: 100% (33/33), done.


In [4]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("lyfora/processed-imagenet-dataset-224")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'processed-imagenet-dataset-224' dataset.
Path to dataset files: /kaggle/input/processed-imagenet-dataset-224


In [5]:
import sys
import os
import torch

# Add TinyViT to system path to enable imports
if '/content/TinyViT' not in sys.path:
    sys.path.append('/content/TinyViT')

try:
    # Attempt to import the specific model function
    from models.tiny_vit import tiny_vit_5m_224

    # Instantiate the model
    model = tiny_vit_5m_224(pretrained=False)
    print("Successfully loaded 'tiny_vit_5m_224' model from the TinyViT library.")

except ImportError as e:
    print(f"Error importing model: {e}")
    print("Listing TinyViT directory contents for debugging:")
    if os.path.exists('/content/TinyViT'):
        print(os.listdir('/content/TinyViT'))

  @register_model
  @register_model
  @register_model
  @register_model
  @register_model


Successfully loaded 'tiny_vit_5m_224' model from the TinyViT library.


In [1]:
# Install a compatible version of timm
!pip install timm==0.9.10



In [4]:
import json
import torch
import os
import inspect
from models.tiny_vit import TinyViT

# Paths
config_path = '/content/custom_tinyvit_config.json'
weights_path = '/content/custom_tinyvit_5m_sparse.pth'

# 1. Load Config
try:
    with open(config_path, 'r') as f:
        custom_config = json.load(f)
    print("Loaded custom config.")
except FileNotFoundError:
    print(f"Config not found at {config_path}")
    custom_config = {}

# 2. Prepare Model Arguments
# Default parameters for TinyViT-21M
model_kwargs = {
    'embed_dims': [96, 192, 384, 576],
    'depths': [2, 2, 6, 2],
    'num_heads': [3, 6, 12, 18],
    'window_sizes': [7, 7, 14, 7],  # Standard default
    'drop_path_rate': 0.2,
    'num_classes': 1000
}

# Update with custom config values (e.g. window_sizes=[14, 14, 14, 7])
# We filter out keys that valid TinyViT.__init__ doesn't accept (like 'sparse_flags' if unsupported)
sig = inspect.signature(TinyViT.__init__)
valid_keys = set(sig.parameters.keys())

for k, v in custom_config.items():
    if k in valid_keys:
        model_kwargs[k] = v
        print(f"Overriding default '{k}' with custom value: {v}")
    else:
        print(f"Skipping custom config key '{k}' (not supported by TinyViT class)")

# 3. Instantiate Model
print("Instantiating TinyViT directly with custom arguments...")
try:
    model = TinyViT(**model_kwargs)
    print("Model instantiated successfully.")
except Exception as e:
    print(f"Error instantiating model: {e}")

# 4. Load Weights
if os.path.exists(weights_path):
    try:
        checkpoint = torch.load(weights_path, map_location='cpu')
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

        # strict=False allows ignoring missing keys (like if sparse args imply extra layers)
        msg = model.load_state_dict(state_dict, strict=False)
        print(f"Loaded weights from {weights_path}")
        print(f"Load result: {msg}")
    except Exception as e:
        print(f"Error loading weights: {e}")
else:
    print(f"Weights file not found at {weights_path}")

Config not found at /content/custom_tinyvit_config.json
Instantiating TinyViT directly with custom arguments...
Model instantiated successfully.
Weights file not found at /content/custom_tinyvit_5m_sparse.pth


In [5]:
import json
import torch
import os
import inspect
import sys

# Ensure TinyViT is in path
if '/content/TinyViT' not in sys.path:
    sys.path.append('/content/TinyViT')

from models.tiny_vit import TinyViT

# Paths
config_path = '/content/custom_tinyvit_config.json'
weights_path = '/content/custom_tinyvit_5m_sparse.pth'

# 1. Load Config
try:
    with open(config_path, 'r') as f:
        custom_config = json.load(f)
    print("Loaded custom config.")
except FileNotFoundError:
    print(f"Config not found at {config_path}")
    custom_config = {}

# 2. Prepare Model Arguments
# Default parameters for TinyViT-21M
model_kwargs = {
    'embed_dims': [96, 192, 384, 576],
    'depths': [2, 2, 6, 2],
    'num_heads': [3, 6, 12, 18],
    'window_sizes': [7, 7, 14, 7],
    'drop_path_rate': 0.2,
    'num_classes': 1000
}

# Update with custom config values
sig = inspect.signature(TinyViT.__init__)
valid_keys = set(sig.parameters.keys())

for k, v in custom_config.items():
    if k in valid_keys:
        model_kwargs[k] = v
        print(f"Overriding default '{k}' with custom value: {v}")

# 3. Instantiate Model
print("Instantiating TinyViT directly with custom arguments...")
try:
    model = TinyViT(**model_kwargs)
    print("Model instantiated successfully.")
except Exception as e:
    print(f"Error instantiating model: {e}")

# 4. Load Weights
if os.path.exists(weights_path):
    try:
        checkpoint = torch.load(weights_path, map_location='cpu')
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

        msg = model.load_state_dict(state_dict, strict=False)
        print(f"Loaded weights from {weights_path}")
        print(f"Load result: {msg}")
    except Exception as e:
        print(f"Error loading weights: {e}")
else:
    print(f"Weights file not found at {weights_path}")

Config not found at /content/custom_tinyvit_config.json
Instantiating TinyViT directly with custom arguments...
Model instantiated successfully.
Weights file not found at /content/custom_tinyvit_5m_sparse.pth


In [8]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import cv2
import numpy as np

# 1. Efficiently find the FIRST image in the dataset path without listing everything
print(f"Searching for images in: {path}")

def find_first_image(directory):
    extensions = {'.jpg', '.jpeg', '.png'}
    for root, dirs, files in os.walk(directory):
        for file in files:
            if os.path.splitext(file)[1].lower() in extensions:
                return os.path.join(root, file)
    return None

img_path = find_first_image(path)

if not img_path:
    print("No images found in the dataset path.")
else:
    print(f"Using image: {img_path}")

    # 2. Preprocess the image
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    raw_image = Image.open(img_path).convert('RGB')
    input_tensor = transform(raw_image).unsqueeze(0)  # Add batch dimension

    # Display original image
    plt.imshow(raw_image)
    plt.title("Input Image")
    plt.axis('off')
    plt.show()

NameError: name 'path' is not defined

In [9]:
# Container to store qkv outputs
qkv_outputs = []

# Hook function to capture output of the qkv layer
def hook_fn(module, input, output):
    # output shape: (Batch, N, 3 * C)
    qkv_outputs.append(output.detach().cpu())

# Register hook on the last block's qkv layer
try:
    # Access the attention module of the last block
    attn_module = model.layers[-1].blocks[-1].attn
    target_layer = attn_module.qkv

    # Get num_heads for reshaping later (usually stored in the attn module)
    num_heads = getattr(attn_module, 'num_heads', 18) # Default to 18 if not found, based on previous dump

    handle = target_layer.register_forward_hook(hook_fn)
    print(f"Hook registered on: {target_layer}")

    # Run inference
    model.eval()
    with torch.no_grad():
        _ = model(input_tensor)

    # Remove hook
    handle.remove()

    if qkv_outputs:
        # Process collected qkv
        # Shape: (1, N, 3*C)
        qkv = qkv_outputs[0] # (1, 49, 1728)
        B, N, C_total = qkv.shape

        # Reshape to (B, N, 3, Num_Heads, Head_Dim)
        # C_total = 3 * Num_Heads * Head_Dim
        head_dim = C_total // (3 * num_heads)

        qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
        # Shape: (3, B, Num_Heads, N, Head_Dim)

        q, k, v = qkv[0], qkv[1], qkv[2]

        # Calculate Attention: (Q @ K.T) * scale
        scale = head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        # Shape: (B, Num_Heads, N, N)

        print(f"Calculated attention shape: {attn.shape}")

        # Average over heads and batch
        attn_mean = attn[0].mean(dim=0) # (N, N)

        # Visualize
        side = int(np.sqrt(attn_mean.shape[0])) # Should be 7
        center_idx = (side * side) // 2

        attn_map = attn_mean[center_idx, :].reshape(side, side)

        # Upsample to image size for visualization
        attn_map_resized = cv2.resize(attn_map.numpy(), (224, 224), interpolation=cv2.INTER_CUBIC)

        # Plot
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(raw_image)
        plt.title("Original Image")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(raw_image)
        plt.imshow(attn_map_resized, cmap='jet', alpha=0.5)
        plt.title("Attention Map (Last Layer, Center Pixel)")
        plt.axis('off')
        plt.show()

    else:
        print("No QKV outputs captured.")

except AttributeError as e:
    print(f"Error accessing layer: {e}")
    print("Please verify the model structure.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

Hook registered on: Linear(in_features=576, out_features=1728, bias=True)
An unexpected error occurred: name 'input_tensor' is not defined


In [None]:
import os
import sys

print("Restarting runtime to apply library changes and fix torch error...\n")
print("After the restart, please re-run your cells starting from the imports.")

# This command kills the current process, causing Colab to automatically restart the kernel
os.kill(os.getpid(), 9)

In [None]:
import timm
import torch

# Load TinyViT-5M-224
model_og = timm.create_model('tiny_vit_5m_224', pretrained=True)
model.eval()

In [None]:
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt

# Ensure model_og is in eval mode
model_og.eval()

# Container for qkv outputs
og_qkv_outputs = []

def og_hook_fn(module, input, output):
    og_qkv_outputs.append(output.detach().cpu())

# Register Hook on the last block of the last stage
# Based on the printed structure: model_og.stages[3].blocks[1].attn.qkv
try:
    # TinyViT usually has 4 stages (0, 1, 2, 3)
    target_layer = model_og.stages[-1].blocks[-1].attn.qkv
    # Get num_heads (seen in output as 10 for stage 3)
    num_heads = model_og.stages[-1].blocks[-1].attn.num_heads

    handle = target_layer.register_forward_hook(og_hook_fn)
    print(f"Hook registered on: {target_layer}")

    # Run Inference
    # Ensuring input_tensor is on the same device as model
    device = next(model_og.parameters()).device
    input_tensor = input_tensor.to(device)

    with torch.no_grad():
        _ = model_og(input_tensor)

    handle.remove()

    # Process and Visualize
    if og_qkv_outputs:
        qkv = og_qkv_outputs[0]
        B, N, C_total = qkv.shape

        # Reshape: (B, N, 3, Num_Heads, Head_Dim)
        head_dim = C_total // (3 * num_heads)
        qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Calculate Attention: (Q @ K.T) * scale
        scale = head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)

        # Average over heads
        attn_mean = attn[0].mean(dim=0)

        # Visualize Center Pixel Attention
        side = int(np.sqrt(attn_mean.shape[0]))
        center_idx = (side * side) // 2
        attn_map = attn_mean[center_idx, :].reshape(side, side)

        # Resize to image size
        attn_map_resized = cv2.resize(attn_map.numpy(), (224, 224), interpolation=cv2.INTER_CUBIC)

        # Plot
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(raw_image)
        plt.title("Original Image")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(raw_image)
        plt.imshow(attn_map_resized, cmap='jet', alpha=0.5)
        plt.title("Attention Map (model_og)")
        plt.axis('off')
        plt.show()
    else:
        print("No QKV outputs captured.")

except Exception as e:
    print(f"Error processing model_og: {e}")