In [4]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from transformers import ViTModel

  from .autonotebook import tqdm as notebook_tqdm


# Self Attention

In [3]:
# Model Definition
class SelfAttention(nn.Module):
    def __init__(self, feature_dim):
        super(SelfAttention, self).__init__()
        self.query_proj = nn.Linear(feature_dim, feature_dim)
        self.key_proj = nn.Linear(feature_dim, feature_dim)
        self.value_proj = nn.Linear(feature_dim, feature_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, exemplar_features, image_features):
        """
        Args:
            exemplar_features: Tensor of shape [batch_size, feature_dim] (exemplar extracted from bounding box).
            image_features: Tensor of shape [batch_size, num_patches, feature_dim] (global features).

        Returns:
            Attended feature representation.
        """
        query = self.query_proj(exemplar_features).unsqueeze(1)  # [B, 1, D]
        key = self.key_proj(image_features)  # [B, N, D]
        value = self.value_proj(image_features)  # [B, N, D]

        attention_scores = torch.bmm(query, key.transpose(1, 2)) / (key.size(-1) ** 0.5)  # [B, 1, N]
        attention_weights = self.softmax(attention_scores)  # [B, 1, N]

        attended_features = torch.bmm(attention_weights, value)  # [B, 1, D]
        return attended_features.squeeze(1)  # [B, D]

class ExemplarObjectCounter(nn.Module):
    def __init__(self, feature_dim=2048, num_classes=1):
        super(ExemplarObjectCounter, self).__init__()
        
        self.backbone = models.resnet50(pretrained=True)
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])  # Remove FC layers

        self.exemplar_extractor = models.resnet18(pretrained=True)
        self.exemplar_extractor = nn.Sequential(*list(self.exemplar_extractor.children())[:-1])  
        self.exemplar_fc = nn.Linear(512, feature_dim)  

        self.attention = SelfAttention(feature_dim)
        self.regressor = nn.Linear(feature_dim, num_classes)

    def forward(self, images, exemplars):
        """
        Args:
            images: Input images [B, 3, H, W].
            exemplars: Exemplar regions from bounding boxes [B, 3, 224, 224].

        Returns:
            Object count predictions.
        """
        features = self.backbone(images)  # [B, C, H', W']
        B, C, H, W = features.shape
        features = features.view(B, C, H * W).permute(0, 2, 1)  # [B, N, C]

        # Extract exemplar features
        exemplar_features = self.exemplar_extractor(exemplars).view(B, -1)  # [B, 512]
        exemplar_features = self.exemplar_fc(exemplar_features)  # [B, 2048]

        attended_features = self.attention(exemplar_features, features)  # [B, C]
        count_pred = self.regressor(attended_features)  # [B, 1]
        return count_pred

In [5]:
model = ExemplarObjectCounter()
print(model)



ExemplarObjectCounter(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (

# Cross Attention

In [2]:
# Model Definition
class CrossAttention(nn.Module):
    def __init__(self, feature_dim, num_heads=8):
        super(CrossAttention, self).__init__()
        self.cross_attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=num_heads, batch_first=True)

    def forward(self, exemplar_features, image_features):
        """
        Args:
            exemplar_features: [B, 1, D] - Query (Exemplar Features)
            image_features: [B, N, D] - Key & Value (Image Features)

        Returns:
            Attended Features [B, 1, D] -> Squeezed to [B, D]
        """
        attended_features, _ = self.cross_attention(query=exemplar_features, key=image_features, value=image_features)
        return attended_features.squeeze(1)  # Remove sequence dimension


class ExemplarObjectCounter(nn.Module):
    def __init__(self, feature_dim=2048, num_classes=1):
        super(ExemplarObjectCounter, self).__init__()
        
        self.backbone = models.resnet50(pretrained=True)
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])  # Remove FC layers

        self.exemplar_extractor = models.resnet18(pretrained=True)
        self.exemplar_extractor = nn.Sequential(*list(self.exemplar_extractor.children())[:-1])  
        self.exemplar_fc = nn.Linear(512, feature_dim)  

        self.cross_attention = CrossAttention(feature_dim, num_heads=8)  # Use Cross Attention

        self.regressor = nn.Linear(feature_dim, num_classes)

    def forward(self, images, exemplars):
        """
        Args:
            images: Input images [B, 3, H, W].
            exemplars: Exemplar regions from bounding boxes [B, 3, 224, 224].

        Returns:
            Object count predictions.
        """
        features = self.backbone(images)  # [B, C, H', W']
        B, C, H, W = features.shape
        features = features.view(B, C, H * W).permute(0, 2, 1)  # [B, N, C]

        # Extract exemplar features
        exemplar_features = self.exemplar_extractor(exemplars).view(B, -1)  # [B, 512]
        exemplar_features = self.exemplar_fc(exemplar_features).unsqueeze(1)  # [B, 1, 2048]

        attended_features = self.cross_attention(exemplar_features, features)  # [B, 2048]
        count_pred = self.regressor(attended_features)  # [B, 1]
        return count_pred

In [3]:
model = ExemplarObjectCounter()
print(model)



ExemplarObjectCounter(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (

# ViT

In [5]:
class ViTObjectCounter(nn.Module):
    def __init__(self, model_name="google/vit-base-patch16-224", num_classes=1):
        super(ViTObjectCounter, self).__init__()
        self.vit = ViTModel.from_pretrained(model_name)
        self.regressor = nn.Linear(self.vit.config.hidden_size, num_classes)  # Regression Head

    def forward(self, images):
        outputs = self.vit(images).last_hidden_state[:, 0, :]  # Use CLS token
        count_pred = self.regressor(outputs)
        return count_pred

In [6]:
model = ViTObjectCounter()
print(model)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTObjectCounter(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=76

# ViT Density Model

In [8]:
class ViTDensityModel(nn.Module):
    def __init__(self):
        super(ViTDensityModel, self).__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        
        # Density Map Decoder (Same as Before)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(768, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1)  # Output density map
        )

        # Count Regression Head
        self.count_regressor = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Convert [B, 1, H, W] → [B, 1, 1, 1]
            nn.Flatten(),             # [B, 1, 1, 1] → [B, 1]
            nn.Linear(1, 1)           # Final count prediction
        )

    def forward(self, x):
        features = self.vit(x).last_hidden_state
        features = features[:, 1:, :]  # Remove class token
        b, n, d = features.shape
        h, w = int(n**0.5), int(n**0.5)  # 14x14
        features = features.permute(0, 2, 1).contiguous().view(b, d, h, w)

        # Get density map
        density_map = self.decoder(features)

        # Get count prediction from density map
        count_pred = self.count_regressor(density_map)

        return density_map, count_pred  # Return both outputs

In [9]:
model = ViTDensityModel()
print(model)

ViTDensityModel(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768