<a href="https://colab.research.google.com/github/VRSFXECE/VRS/blob/main/Copy_of_ret1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from transformers import ViTModel, ViTImageProcessor

class EnhancedCattleViT(nn.Module):
    """Pretrained ViT backbone topped with a small MLP classification head.

    The backbone is loaded with ``ViTModel.from_pretrained`` (no built-in
    classification head). The hidden state of the first token ([CLS]) is fed
    through ``LayerNorm -> Dropout -> Linear -> GELU -> Dropout -> Linear`` to
    produce one logit per breed.

    Args:
        num_breeds: Number of output classes (width of the final linear layer).
        model_name: Checkpoint identifier passed to ``from_pretrained`` for
            both the backbone and the image processor.

    Attributes set in ``__init__``:
        backbone: The pretrained ``ViTModel``.
        classifier: The ``nn.Sequential`` classification head.
        processor: The ``ViTImageProcessor`` loaded from the same checkpoint.
    """

    def __init__(self, num_breeds=10, model_name="google/vit-base-patch16-224"):
        super().__init__()

        # Load pretrained ViT backbone (without classification head)
        self.backbone = ViTModel.from_pretrained(model_name)
        hidden_size = self.backbone.config.hidden_size

        # Custom classification head with more capacity, already initialized
        self.classifier = self._build_classifier(hidden_size, num_breeds)

        self.processor = ViTImageProcessor.from_pretrained(model_name)

    def _build_classifier(self, hidden_size: int, num_breeds: int) -> nn.Sequential:
        """Create the classification head and initialize each of its layers.

        Args:
            hidden_size: Feature width produced by the backbone.
            num_breeds: Number of output logits.

        Returns:
            The initialized ``nn.Sequential`` head.
        """
        head = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size // 2, num_breeds),
        )
        # ``apply`` visits every submodule. Calling ``_init_weights`` directly
        # on the Sequential container would match neither isinstance branch
        # and silently leave the layers at their default initialization.
        head.apply(self._init_weights)
        return head

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize a single layer; meant to be used with ``Module.apply``.

        ``nn.Linear`` gets weights drawn from N(0, 0.02) and a zero bias;
        ``nn.LayerNorm`` gets unit weight and zero bias. Any other module type
        is left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return classification logits for a batch of images.

        Args:
            pixel_values: Image tensor forwarded unchanged to the backbone
                (presumably the output of ``self.processor`` - confirm with
                the calling code).

        Returns:
            Logits whose last dimension has ``num_breeds`` entries.
        """
        # Share the feature path with ``get_features`` so both stay in sync.
        return self.classifier(self.get_features(pixel_values))

    def get_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Extract features before classification.

        Returns:
            The backbone's last hidden state at token index 0 ([CLS]) for
            every item in the batch.
        """
        outputs = self.backbone(pixel_values=pixel_values)
        return outputs.last_hidden_state[:, 0, :]  # [CLS] token