<a href="https://colab.research.google.com/github/achika86/Machine-Learning/blob/main/VLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
import math

In [2]:
class _ChannelLayerNorm(nn.LayerNorm):
    """LayerNorm over the channel axis of a ``[B, C, H, W]`` feature map.

    ``nn.LayerNorm`` normalizes the trailing dimension(s); for a convolutional
    feature map that is the width, not the channels.  This subclass moves the
    channels last, normalizes, and restores the original layout.  It adds no
    parameters or submodules of its own, so its ``state_dict`` keys are the
    same as those of a plain ``nn.LayerNorm``.
    """

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)      # [B, H, W, C]
        x = super().forward(x)
        return x.permute(0, 3, 1, 2)   # back to [B, C, H, W]


class HierarchicalVisionEncoder(nn.Module):
    """
    Hierarchical vision encoder that processes images at multiple scales
    and fuses the per-scale tokens with self-attention.

    The input batch is resized to every entry of ``scales``, passed through a
    shared backbone, projected to ``feature_dim`` channels by a per-scale
    adapter, optionally pooled to a fixed grid, and flattened to tokens.  The
    tokens of all scales are concatenated, run through multi-head
    self-attention, pooled to ``num_patches`` tokens, and offset by a learned
    position embedding.

    Args:
        backbone: Name of the backbone. Only ``'resnet50'`` is supported.
        scales: Square side lengths the input is resized to.
        feature_dim: Channel width of the output tokens.
        num_patches: Number of output tokens. With ``adaptive_pooling`` each
            scale is pooled to a ``floor(sqrt(num_patches))`` square grid.
        adaptive_pooling: Whether to pool every scale to that fixed grid.
        num_heads: Number of attention heads; ``feature_dim`` must be
            divisible by it.

    Raises:
        ValueError: If ``backbone`` is not supported or ``scales`` is empty.
    """

    def __init__(self,
                 backbone='resnet50',
                 scales=(224, 448, 896),
                 feature_dim=768,
                 num_patches=196,
                 adaptive_pooling=True,
                 num_heads=12):
        super().__init__()

        # Validate up front so a bad argument fails with a clear message
        # instead of a NameError on ``backbone_dim`` further down.
        if backbone != 'resnet50':
            raise ValueError(
                f"Unsupported backbone {backbone!r}; only 'resnet50' is available"
            )
        scales = list(scales)  # private copy; never alias the caller's object
        if not scales:
            raise ValueError("scales must contain at least one resolution")

        self.scales = scales
        self.feature_dim = feature_dim
        self.num_patches = num_patches

        # Shared backbone for all scales.  ``weights="IMAGENET1K_V1"`` is the
        # replacement for the deprecated ``pretrained=True`` argument.
        resnet = resnet50(weights="IMAGENET1K_V1")
        # Drop the final average-pool and classification layers so the
        # backbone returns a spatial feature map.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        backbone_dim = 2048

        # Scale-specific adaptation layers
        self.scale_adapters = nn.ModuleDict({
            str(scale): self._build_scale_adapter(backbone_dim, feature_dim)
            for scale in scales
        })

        # Attention mechanism for scale fusion
        self.scale_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=num_heads,
            batch_first=True
        )

        # Adaptive pooling for consistent output size
        if adaptive_pooling:
            patch_size = math.isqrt(num_patches)
            self.adaptive_pool = nn.AdaptiveAvgPool2d((patch_size, patch_size))
        else:
            self.adaptive_pool = None

        # Position embeddings
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_patches, feature_dim) * 0.02
        )

    @staticmethod
    def _build_scale_adapter(backbone_dim, feature_dim):
        """Return the 1x1 conv -> channel LayerNorm -> GELU adapter for one scale."""
        return nn.Sequential(
            nn.Conv2d(backbone_dim, feature_dim, 1),
            # Normalize over channels; a plain nn.LayerNorm would act on the
            # width axis of the [B, C, H, W] map and fail on a shape mismatch.
            _ChannelLayerNorm(feature_dim),
            nn.GELU()
        )

    def encode_scale(self, x, scale):
        """Encode image at a specific scale.

        Args:
            x: Image batch of shape ``[B, channels, H, W]``.
            scale: Target side length; must be one of ``self.scales``.

        Returns:
            Tokens of shape ``[B, tokens, feature_dim]``.
        """
        # Resize input to target scale
        x_scaled = F.interpolate(
            x, size=(scale, scale),
            mode='bilinear',
            align_corners=False
        )

        # Extract features using backbone
        features = self.backbone(x_scaled)

        # Apply scale-specific adaptation
        adapted_features = self.scale_adapters[str(scale)](features)

        # Adaptive pooling if enabled
        if self.adaptive_pool is not None:
            adapted_features = self.adaptive_pool(adapted_features)

        # Flatten spatial dimensions: [B, C, H, W] -> [B, H*W, C].
        # ``flatten`` (unlike ``view``) also accepts non-contiguous input.
        return adapted_features.flatten(2).transpose(1, 2)

    def forward(self, x):
        """Encode ``x`` at every scale and fuse the resulting tokens.

        Args:
            x: Image batch of shape ``[B, channels, H, W]``.

        Returns:
            A tuple ``(tokens, attention_weights)``: ``tokens`` has shape
            ``[B, num_patches, feature_dim]``; ``attention_weights`` is the
            second output of ``nn.MultiheadAttention`` computed over the
            concatenated tokens of all scales.
        """
        # Encode at each scale
        scale_features = [self.encode_scale(x, scale) for scale in self.scales]

        # Concatenate tokens from all scales
        # Shape: [batch_size, total_tokens, feature_dim]
        all_features = torch.cat(scale_features, dim=1)

        # Apply cross-scale attention
        attended_features, attention_weights = self.scale_attention(
            all_features, all_features, all_features
        )

        # Resample to the target number of patches.  This runs whenever the
        # count differs (more OR fewer tokens) so the position embedding
        # below always lines up.
        # NOTE(review): the pooling runs along the concatenated token axis, so
        # each output token averages tokens that are neighbours in
        # concatenation order rather than the same location across scales.
        # Kept as in the original; confirm this is the intended fusion.
        if attended_features.size(1) != self.num_patches:
            pooled_features = F.adaptive_avg_pool1d(
                attended_features.transpose(1, 2),
                self.num_patches
            ).transpose(1, 2)
        else:
            pooled_features = attended_features

        # Add position embeddings
        pooled_features = pooled_features + self.pos_embedding

        return pooled_features, attention_weights

In [3]:
# Build a dummy input batch in [batch_size, channels, height, width] layout:
# four RGB images (3 channels) of 224x224 pixels, filled with random values.
batch_size, channels = 4, 3
height = width = 224

sample_shape = (batch_size, channels, height, width)
sample_data = torch.randn(sample_shape)

print(f"Sample data tensor shape: {sample_data.shape}")

Sample data tensor shape: torch.Size([4, 3, 224, 224])
