# Import libraries

In [1]:
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import os
import numpy as np

# Utility functions

In [2]:
def get_device():
    """Determine and return the available computation device.

    Preference order is CUDA, then Apple's MPS backend, then CPU.

    Returns:
        str: One of "cuda", "mps" or "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"

    # `torch.backends.mps` is absent on older PyTorch builds; look it up
    # defensively so those installs fall back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"

def load_sam_model(model_type, checkpoint_path, device):
    """Build a SAM model from a checkpoint and place it on `device`.

    Args:
        model_type: Key into `sam_model_registry` selecting the builder.
        checkpoint_path: Passed to the builder as its `checkpoint` argument.
        device: Target device handed to the model's `.to(device=...)`.

    Returns:
        The constructed model after it has been moved to `device`.
    """
    print(f"Loading SAM model ({model_type})...")
    build_model = sam_model_registry[model_type]
    sam = build_model(checkpoint=checkpoint_path)
    sam.to(device=device)
    print("Model loaded successfully!")
    return sam

def calculate_model_size(model, checkpoint_path=None):
    """Report the parameter footprint of `model` and its checkpoint size.

    Args:
        model: Object exposing `parameters()`; each parameter must provide
            `nelement()` and `element_size()`.
        checkpoint_path: Optional path to a checkpoint file on disk.

    Returns:
        dict: `model_size_mb` is the total bytes held by the parameters,
        expressed in MiB. `checkpoint_size_mb` is the file size in MiB, or 0
        when no path is given or the path does not exist.
    """
    bytes_per_mb = 1024 * 1024

    # Only parameters are summed here; each contributes count * bytes-per-element.
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())

    on_disk_mb = 0
    if checkpoint_path and os.path.exists(checkpoint_path):
        on_disk_mb = os.path.getsize(checkpoint_path) / bytes_per_mb

    return {
        "model_size_mb": param_bytes / bytes_per_mb,
        "checkpoint_size_mb": on_disk_mb,
    }

def count_parameters(model):
    """Tally the parameters of `model`, split by whether they require grad.

    Args:
        model: Object exposing `parameters()`.

    Returns:
        dict: Element counts under the keys `total`, `trainable`
        (parameters with `requires_grad` set) and `non_trainable`
        (the remainder).
    """
    total = 0
    trainable = 0
    # Single pass: every parameter adds to the total, and those that
    # require gradients also add to the trainable count.
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count

    return {
        "total": total,
        "trainable": trainable,
        "non_trainable": total - trainable,
    }

def analyze_top_level_components(model):
    """Summarise each direct child module of `model`.

    Args:
        model: Object exposing `parameters()` and `named_children()`.

    Returns:
        list[dict]: One entry per direct child, in `named_children()` order,
        with its `name`, class name (`type`), parameter element count
        (`parameters`) and share of the whole model's parameters in percent
        (`percent_of_model`, 0 when the model has no parameters).
    """
    model_total = sum(p.numel() for p in model.parameters())

    def _describe(child_name, child):
        child_params = sum(p.numel() for p in child.parameters())
        # Guard the division so a parameter-free model reports 0 percent.
        share = (child_params / model_total * 100) if model_total > 0 else 0
        return {
            "name": child_name,
            "type": child.__class__.__name__,
            "parameters": child_params,
            "percent_of_model": share,
        }

    return [_describe(child_name, child) for child_name, child in model.named_children()]

# Define attention analysis functions
def analyze_attention_mechanisms(model):
    """Collect attention settings from every submodule that has `num_heads`.

    The module tree is walked depth-first, parents before children, following
    `named_children()` order. Any module with a `num_heads` attribute is
    recorded under its dotted path (the root itself uses the empty string).

    Args:
        model: Root module exposing `named_children()`.

    Returns:
        dict: Maps dotted module path to a dict with `num_heads`, plus
        `head_dim` and `embedding_dim` (each None when the attribute is
        missing).
    """
    found = {}
    # Explicit stack of (dotted_path, module) pairs still to visit.
    pending = [("", model)]

    while pending:
        path, node = pending.pop()

        if path not in found and hasattr(node, "num_heads"):
            found[path] = {
                "num_heads": node.num_heads,
                "head_dim": getattr(node, "head_dim", None),
                "embedding_dim": getattr(node, "embedding_dim", None),
            }

        children = [
            (f"{path}.{name}" if path else name, child)
            for name, child in node.named_children()
        ]
        # Push in reverse so the first child is popped (and visited) first.
        pending.extend(reversed(children))

    return found

def analyze_image_encoder_attention(model):
    """Summarise attention settings found on `model.image_encoder.blocks`.

    Args:
        model: Object that may carry an `image_encoder` attribute.

    Returns:
        dict: Always contains `has_image_encoder`. When the encoder has a
        `blocks` attribute it also contains `num_layers` and `blocks` (one
        `{"block_index", "num_heads"}` entry per block whose `attn` has
        `num_heads`). When the first block has an `attn` attribute,
        `attention_heads_per_layer` and `head_dimension` are filled from it
        (None if missing), and `attention_capacity_per_layer` is added when
        both values are ints.
    """
    if not hasattr(model, "image_encoder"):
        return {"has_image_encoder": False}

    encoder = model.image_encoder
    summary = {"has_image_encoder": True}

    # Without transformer blocks there is nothing further to report.
    if not hasattr(encoder, "blocks"):
        return summary

    layers = encoder.blocks
    layer_count = len(layers)
    summary["num_layers"] = layer_count

    # The first block's attention is used for the per-layer figures.
    if layer_count > 0 and hasattr(layers[0], "attn"):
        first_attn = layers[0].attn
        heads = getattr(first_attn, "num_heads", None)
        per_head = getattr(first_attn, "head_dim", None)

        summary["attention_heads_per_layer"] = heads
        summary["head_dimension"] = per_head

        if isinstance(heads, int) and isinstance(per_head, int):
            summary["attention_capacity_per_layer"] = heads * per_head

    # Per-block head counts, skipping blocks that lack the attributes.
    summary["blocks"] = [
        {"block_index": index, "num_heads": layer.attn.num_heads}
        for index, layer in enumerate(layers)
        if hasattr(layer, "attn") and hasattr(layer.attn, "num_heads")
    ]

    return summary

# Setup the model analysis

In [6]:
# Registry key and checkpoint file handed to load_sam_model below.
# NOTE(review): the checkpoint path is an absolute path on one machine;
# point it at your local copy before running this cell.
model_type = "vit_b"
checkpoint_path = "/Users/haki911/Documents/research/segment-anything/checkpoint/sam_vit_b_01ec64.pth"

# Get device and load model
device = get_device()
print(f"Using device: {device}")

# `model` is the object inspected by the analysis cells that follow.
model = load_sam_model(model_type, checkpoint_path, device)

Using device: mps
Loading SAM model (vit_b)...
Model loaded successfully!


# Analyze model size and parameters

In [16]:
# Calculate model size: parameter bytes and checkpoint file size, both in MB
size_info = calculate_model_size(model, checkpoint_path)
print(f"Model size in memory: {size_info['model_size_mb']:.2f} MB")
print(f"Checkpoint file size: {size_info['checkpoint_size_mb']:.2f} MB")
print("")
# Count parameters (only the total is reported)
param_info = count_parameters(model)
print(f"Total parameters: {param_info['total']:,}")
# Disabled: trainable / non-trainable breakdown from the same param_info dict.
# print(f"Trainable parameters: {param_info['trainable']:,}")
# print(f"Non-trainable parameters: {param_info['non_trainable']:,}")

Model size in memory: 357.57 MB
Checkpoint file size: 357.67 MB

Total parameters: 93,735,472


# Analyze top-level components

In [9]:
# Print a banner, then one two-line entry per direct child of `model`:
# its name and class, followed by its parameter count and share of the model.
print(f"{'='*15} TOP-LEVEL COMPONENTS {'='*15}")
components = analyze_top_level_components(model)
for comp in components:
    print(f"- {comp['name']}: {comp['type']}")
    print(f"  Parameters: {comp['parameters']:,} ({comp['percent_of_model']:.2f}% of model)")

- image_encoder: ImageEncoderViT
  Parameters: 89,670,912 (95.66% of model)
- prompt_encoder: PromptEncoder
  Parameters: 6,220 (0.01% of model)
- mask_decoder: MaskDecoder
  Parameters: 4,058,340 (4.33% of model)


# Image Encoder Attention Analysis

In [11]:
# Report the image encoder's layer count and attention heads per layer.
print(f"{'='*15} IMAGE ENCODER ATTENTION {'='*15}")
img_encoder_info = analyze_image_encoder_attention(model)

# Keys missing from the result are printed as 'Unknown'.
if img_encoder_info.get("has_image_encoder", False):
    print("Image Encoder Attention Analysis:")
    print(f"Number of transformer layers: {img_encoder_info.get('num_layers', 'Unknown')}")
    print(f"Attention heads per layer: {img_encoder_info.get('attention_heads_per_layer', 'Unknown')}")
else:
    print("The model does not have an image encoder component.")

Image Encoder Attention Analysis:
Number of transformer layers: 12
Attention heads per layer: 12


# Attention Mechanisms Analysis

In [12]:
# List every submodule that analyze_attention_mechanisms recorded,
# keyed by its dotted module path.
print(f"{'='*15} ALL ATTENTION MECHANISMS {'='*15}")
attention_info = analyze_attention_mechanisms(model)

print("Components with Attention Mechanisms:")
for name, info in attention_info.items():
    heads = info["num_heads"]
    head_dim = info["head_dim"]
    emb_dim = info["embedding_dim"]
    
    print(f"  - {name}:")
    print(f"    Number of heads: {heads}")
    # The optional lines are skipped when the value is falsy (e.g. None).
    if head_dim:
        print(f"    Head dimension: {head_dim}")
    if emb_dim:
        print(f"    Embedding dimension: {emb_dim}")

Components with Attention Mechanisms:
  - image_encoder.blocks.0.attn:
    Number of heads: 12
  - image_encoder.blocks.1.attn:
    Number of heads: 12
  - image_encoder.blocks.2.attn:
    Number of heads: 12
  - image_encoder.blocks.3.attn:
    Number of heads: 12
  - image_encoder.blocks.4.attn:
    Number of heads: 12
  - image_encoder.blocks.5.attn:
    Number of heads: 12
  - image_encoder.blocks.6.attn:
    Number of heads: 12
  - image_encoder.blocks.7.attn:
    Number of heads: 12
  - image_encoder.blocks.8.attn:
    Number of heads: 12
  - image_encoder.blocks.9.attn:
    Number of heads: 12
  - image_encoder.blocks.10.attn:
    Number of heads: 12
  - image_encoder.blocks.11.attn:
    Number of heads: 12
  - mask_decoder.transformer:
    Number of heads: 8
    Embedding dimension: 256
  - mask_decoder.transformer.layers.0.self_attn:
    Number of heads: 8
    Embedding dimension: 256
  - mask_decoder.transformer.layers.0.cross_attn_token_to_image:
    Number of heads: 8
    Embedding dimension: 256
  ... (remaining output truncated)

# Complete Model Architecture

In [13]:
# Print a banner followed by the model's own printed representation.
print(f"{'='*15} COMPLETE MODEL ARCHITECTURE {'='*15}")
print(model)

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d()
    )
