# 🔬 Comparative Analysis: AgentViT vs ASEL

## Overview

This section compares two approaches to efficient Vision Transformer inference:

### **Method 1: ASEL (Previous Cells)**
- **Approach**: Uses a learned patch selector network that assigns importance scores to patches during training.
- **Training**: The selector network is trained jointly with the ViT using a straight-through estimator.
- **Inference**: Top-k patches selected based on importance scores, physically removed from computation.
- **Key Features**:
  - Dynamic budgeting during training (`k_ratio` varies between 0.15 and 0.75).
  - Sparsity regularization encourages selecting fewer patches.
  - Three selection policies for comparison: learned, random, central.
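
The three selection policies can be sketched without any ViT machinery. This is only an illustration: the function names, the 4x4 grid, and the reading of "central" as "the k patches closest to the grid centre" are assumptions, not the ASEL implementation.

```python
import math
import random


def select_learned(scores, k):
    """Keep the k patches with the highest importance score."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


def select_random(num_patches, k, seed=0):
    """Keep k patches sampled uniformly without replacement."""
    rng = random.Random(seed)
    return sorted(rng.sample(range(num_patches), k))


def select_central(grid_size, k):
    """Keep the k patches whose grid cell lies closest to the image centre."""
    centre = (grid_size - 1) / 2.0

    def dist(i):
        row, col = divmod(i, grid_size)
        return math.hypot(row - centre, col - centre)

    ranked = sorted(range(grid_size * grid_size), key=dist)
    return sorted(ranked[:k])


grid_size = 4                           # hypothetical 4x4 patch grid (16 patches)
k = 4                                   # corresponds to k_ratio = 0.25
scores = [i / 16 for i in range(16)]    # toy scores that grow with the patch index

learned = select_learned(scores, k)
rand = select_random(grid_size * grid_size, k)
central = select_central(grid_size, k)
print("learned:", learned)
print("random: ", rand)
print("central:", central)
```

All three return sorted patch indices, so the same gather step can drop the unselected patches regardless of which policy produced them.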

### **Method 2: AgentViT (This Section)**
- **Approach**: A reinforcement learning agent (DQN) learns a patch selection policy.
- **Training**: RL agent trained with reward based on classification accuracy and computational efficiency.
- **Inference**: Agent selects patches based on learned Q-values and attention features.
- **Key Features**:
  - Experience replay buffer for stable RL training.
  - Epsilon-greedy exploration strategy.
  - Target network for stable Q-learning.
  - Global memory mechanism to remember class-specific patch importance patterns.
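
The target-network idea is easiest to see in isolation. The sketch below uses plain Python floats in place of tensors; `soft_update` and the toy parameter dictionaries are illustrative names, and `tau=0.1` is the default that the trainer in this notebook passes to the agent.

```python
def soft_update(target, online, tau):
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    return {name: tau * online[name] + (1.0 - tau) * target[name] for name in online}


online = {"w": 1.0, "b": -2.0}    # toy "online" Q-network parameters
target = {"w": 0.0, "b": 0.0}     # toy target-network parameters

for step in range(3):
    target = soft_update(target, online, tau=0.1)
    print(step, target)
```

The target parameters move only a fraction `tau` of the way towards the online parameters on each update, which keeps the bootstrapped Q-targets from shifting abruptly.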

## Comparison Methodology

Both methods will be evaluated on the same three datasets:
1. **AID** (Aerial Image Dataset)
2. **EuroSAT** (Satellite Image Dataset)  
3. **RSSCN7** (Remote Sensing Scene Classification)

We will compare:
- **Accuracy** at different patch retention ratios
- **Computational efficiency** (GFLOPs, throughput, latency)
- **Patch selection strategies** learned by each method
- **Transfer learning capability** from CIFAR-10 to remote sensing domains

---

In [1]:
# ==============================================================================
# COMMON IMPORTS AND DATASET UTILITIES (AgentViT Standalone)
# ==============================================================================
import os
import time
import random
import warnings
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# ------------------------------------------------------------------------------
# Basic configuration shared by AgentViT components (no ASEL dependency)
# ------------------------------------------------------------------------------
CONFIG = {
    'seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_workers': 4,
    'batch_size': 64,
    'resize_dim': 224,
    'save_path': './saved_models',
    'results_path': './benchmarks_results_agentvit',
}

os.makedirs(CONFIG['save_path'], exist_ok=True)
os.makedirs(CONFIG['results_path'], exist_ok=True)

def set_seed(seed: int):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(CONFIG['seed'])

# ------------------------------------------------------------------------------
# Dataset loading utilities (mirror Aya_final, but without using ASEL)
# ------------------------------------------------------------------------------
class CleanImageFolder(datasets.ImageFolder):
    def find_classes(self, directory):
        """Ignore hidden folders and .ipynb_checkpoints."""
        classes = sorted(
            entry.name
            for entry in os.scandir(directory)
            if entry.is_dir() and not entry.name.startswith('.')
        )
        if not classes:
            raise FileNotFoundError(f"No classes found in {directory}")
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

# Paths for custom datasets (adjust if your layout differs)
DATASET_PATHS = {
    'aid': 'AID-data',
    'ucmerced': 'UCMerced_LandUse/Images',
    'rsscn7': './RSSCN7',
}

def get_dataset(name: str):
    """Unified dataset loader for CIFAR-10, EuroSAT, and remote-sensing folders."""
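    tf = transforms.Compose([
        # Bicubic resize; the enum replaces the deprecated integer code 3
        transforms.Resize(
            (CONFIG['resize_dim'], CONFIG['resize_dim']),
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])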
    
    if name == 'cifar10':
        ds = datasets.CIFAR10(root='./data', train=True, download=True, transform=tf)
        test_ds = datasets.CIFAR10(root='./data', train=False, download=True, transform=tf)
        return ds, test_ds, 10
    elif name == 'eurosat':
        ds = datasets.EuroSAT(root='./data', download=True, transform=tf)
        train_len = int(0.8 * len(ds))
        train_ds, val_ds = random_split(ds, [train_len, len(ds) - train_len])
        return train_ds, val_ds, 10
    else:
        path = DATASET_PATHS.get(name)
        if not path or not os.path.exists(path):
            raise FileNotFoundError(f"Path not found for dataset '{name}': {path}")
        ds = CleanImageFolder(root=path, transform=tf)
        train_len = int(0.8 * len(ds))
        train_ds, test_ds = random_split(ds, [train_len, len(ds) - train_len])
        return train_ds, test_ds, len(ds.classes)

print("✓ Common imports and dataset utilities for AgentViT loaded (no ASEL dependency)")

✓ Common imports and dataset utilities for AgentViT loaded (no ASEL dependency)


## AgentViT Model Architecture

The AgentViT implementation includes:
1. **Modified ViT** with dynamic patch masking capability
2. **DQN Agent** for learning optimal patch selection
3. **Experience Replay** for stable RL training
4. **Global Memory** for class-specific patch importance patterns
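
The global memory in item 4 can be pictured as one running average per class of the binary patch masks chosen for that class. A minimal sketch with plain lists instead of tensors; the 4-patch masks and the class index are made up, and `alpha=0.01` corresponds to `memory_alpha` in the agent code.

```python
def ema_update(memory_row, mask, alpha=0.01):
    """Exponential moving average of binary masks for a single class."""
    return [(1.0 - alpha) * m + alpha * s for m, s in zip(memory_row, mask)]


num_classes, num_patches = 2, 4
global_memory = [[0.0] * num_patches for _ in range(num_classes)]

# Toy data: class 0 repeatedly selects patches 0 and 2; class 1 is never seen.
for _ in range(100):
    global_memory[0] = ema_update(global_memory[0], [1, 0, 1, 0])

print(global_memory)
```

Entries for frequently selected patches drift towards 1 while the others stay at 0. With `alpha = 0.01`, about 100 updates are needed before an entry reaches roughly 0.63, so the memory changes slowly.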

In [2]:
# ==============================================================================
# AGENTVIT: HELPER FUNCTIONS AND BUILDING BLOCKS
# ==============================================================================
try:
    from einops import rearrange
    from einops.layers.torch import Rearrange
    HAS_EINOPS = True
except ImportError:
    HAS_EINOPS = False
    print("Warning: 'einops' not found. Installing...")
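    import subprocess
    import sys
    # Install into the interpreter running this notebook, not whichever `pip` is on PATH
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'einops'])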
    from einops import rearrange
    from einops.layers.torch import Rearrange
    HAS_EINOPS = True

try:
    import gymnasium as gym
    HAS_GYM = True
except ImportError:
    HAS_GYM = False
    print("Warning: 'gymnasium' not found. Installing...")
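    import subprocess
    import sys
    # Install into the interpreter running this notebook, not whichever `pip` is on PATH
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'gymnasium'])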
    import gymnasium as gym
    HAS_GYM = True

from collections import namedtuple
import math
from sklearn.metrics import precision_score, recall_score, f1_score

def pair(t):
    """Convert single value to tuple pair."""
    return t if isinstance(t, tuple) else (t, t)

def posemb_sincos_2d(patches, temperature=10000, dtype=torch.float32):
    """Generate 2D sinusoidal positional embeddings."""
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
    y, x = torch.meshgrid(torch.arange(h, device=device), torch.arange(w, device=device), indexing='ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    omega = torch.arange(dim // 4, device=device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)

class FeedForward(nn.Module):
    """Feed-Forward Network in Transformer."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Multi-Head Self-Attention."""
    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
    
    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    """Stack of Transformer encoder blocks."""
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

print("✓ AgentViT helper functions and building blocks loaded")

✓ AgentViT helper functions and building blocks loaded


In [3]:
# ==============================================================================
# AGENTVIT: CORE MODEL WITH DYNAMIC PATCH SELECTION
# ==============================================================================
class SimpleAgentViT(nn.Module):
    """
    Vision Transformer with RL-based dynamic patch selection.
    Allows selective processing of image patches based on RL agent decisions.
    """
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3, dim_head=64):
        super().__init__()
        self.num_classes = num_classes
        self.selected_patches_mask = []
        
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        
        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by patch size'
        
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        self.num_patches = num_patches
        
        # Patch embedding
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        
        # Transformer
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
        self.to_latent = nn.Identity()
        
        # Classification head
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
    
    def forward(self, img):
        """Forward pass with dynamic patch masking."""
        x = self.to_patch_embedding(img)
        pe = posemb_sincos_2d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe
        
        # Apply patch mask if available (built on x's device so the indexing stays on the GPU)
        if len(self.selected_patches_mask) > 0:
            mask = torch.tensor(self.selected_patches_mask, dtype=torch.bool, device=x.device)
            x = x[:, mask, :]
        
        x = self.transformer(x)
        x = x.mean(dim=1)
        x = self.to_latent(x)
        return self.linear_head(x)
    
    def set_patches(self, action_q_values):
        """Convert Q-values to binary patch selection mask using mean threshold."""
        if isinstance(action_q_values, torch.Tensor):
            action_q_values = action_q_values.cpu().detach().numpy()
        threshold = np.mean(action_q_values)
        self.selected_patches_mask = [1 if val >= threshold else 0 for val in action_q_values]
    
    def get_patches(self):
        """Get current patch selection mask."""
        return self.selected_patches_mask
    
    def get_att(self, data):
        """Extract attention features for RL state representation."""
        with torch.no_grad():
            x = self.to_patch_embedding(data)
            pe = posemb_sincos_2d(x)
            x = rearrange(x, 'b ... d -> b (...) d') + pe
            
            # Get output from first attention layer
            if len(self.transformer.layers) > 0:
                attn_layer, _ = self.transformer.layers[0]
                x_normed = attn_layer.norm(x)
                qkv = attn_layer.to_qkv(x_normed).chunk(3, dim=-1)
                q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=attn_layer.heads), qkv)
                dots = torch.matmul(q, k.transpose(-1, -2)) * attn_layer.scale
                attn = attn_layer.attend(dots)
                # Average attention across heads and aggregate
                attn_features = attn.mean(dim=1).mean(dim=1)  # (batch, num_patches)
                return attn_features
            else:
                return x.mean(dim=-1)

print("✓ SimpleAgentViT model loaded")

✓ SimpleAgentViT model loaded
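
`set_patches` keeps every patch whose score is at or above the mean score, so the number of retained patches depends on the score distribution rather than on a fixed budget. The plain-Python sketch below reproduces that rule on made-up scores; `mean_threshold_mask` is a hypothetical helper name.

```python
def mean_threshold_mask(scores):
    """Binary mask: 1 where the score is >= the mean of all scores."""
    threshold = sum(scores) / len(scores)
    return [1 if s >= threshold else 0 for s in scores]


spread = [0.25, 0.5, 0.75, 0.5]     # mean 0.5: three patches kept
peaked = [0.0, 0.0, 0.0, 4.0]       # mean 1.0: one patch kept
constant = [1.0, 1.0, 1.0, 1.0]     # mean 1.0: every patch kept

masks = [mean_threshold_mask(s) for s in (spread, peaked, constant)]
for m in masks:
    print(m, "->", sum(m), "patches")
```

The constant case is why passing a vector of ones to `set_patches` selects all patches: every value equals the mean and therefore passes the `>=` test.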


In [4]:
# ==============================================================================
# AGENTVIT: RL ENVIRONMENT AND AGENT
# ==============================================================================

class ContinuousActionSpace:
    """Continuous action space for patch selection."""
    def __init__(self, num_patches, device):
        self.num_patches = num_patches
        self.device = device
    
    def sample(self):
        """Sample random action for exploration."""
        random_action = np.random.rand(self.num_patches)
        return torch.tensor(random_action, device=self.device, dtype=torch.float)

class ViTEnv(gym.Env):
    """Custom RL environment for ViT patch selection."""
    def __init__(self, vit_model, num_patches, optimizer, loss_weight, efficiency_weight,
                 device, target_num_patches=1):
        super().__init__()
        self.vit_model = vit_model
        self.optimizer = optimizer
        self.loss_weight = loss_weight
        self.efficiency_weight = efficiency_weight
        self.action_space = ContinuousActionSpace(num_patches, device)
        self.device = device
        self.target_num_patches = target_num_patches
        self.train_loss_history = []
        self.train_time_history = []
    
    def step_train(self, action_q_values, train_data, train_target):
        """Training step without reward computation."""
        self.vit_model.set_patches(action_q_values)
        self.vit_model.train()
        self.optimizer.zero_grad()
        output = F.log_softmax(self.vit_model(train_data), dim=1)
        loss = F.nll_loss(output, train_target)
        loss.backward()
        self.optimizer.step()
    
    def step_reward(self, action_q_values, train_data, train_target):
        """Training step with reward computation."""
        self.vit_model.set_patches(action_q_values)
        binary_mask = self.vit_model.get_patches()
        
        # Train and compute loss
        start_time = time.time()
        self.vit_model.train()
        self.optimizer.zero_grad()
        output = F.log_softmax(self.vit_model(train_data), dim=1)
        loss = F.nll_loss(output, train_target)
        loss.backward()
        self.optimizer.step()
        iteration_time = time.time() - start_time
        
        self.train_loss_history.append(loss.item())
        self.train_time_history.append(iteration_time)
        
        # Compute reward
        num_selected = binary_mask.count(1)
        efficiency_reward = -abs(num_selected - self.target_num_patches) / self.target_num_patches
        loss_improvement = self.train_loss_history[0] / (loss.item() + 1e-8)
        reward = loss_improvement * self.loss_weight + efficiency_reward * self.efficiency_weight
        
        next_state = self.get_state(train_data)
        return next_state, reward
    
    def get_state(self, data):
        """Get state representation (attention features)."""
        return self.vit_model.get_att(data)

Experience = namedtuple('Experience', ('state', 'next_state', 'reward'))

class ReplayBuffer:
    """Experience replay buffer for DQN."""
    def __init__(self, capacity, sample_batch_size):
        self.sample_batch_size = sample_batch_size
        self.memory = []
        self.capacity = capacity
    
    def push(self, *args):
        if len(self.memory) >= self.capacity:
            self.memory.pop(random.randint(0, len(self.memory) - 1))
        self.memory.append(Experience(*args))
    
    def sample(self):
        return random.sample(self.memory, self.sample_batch_size)
    
    def __len__(self):
        return len(self.memory)

class QNetwork(nn.Module):
    """Q-Network for predicting patch importance."""
    def __init__(self, num_patches):
        super().__init__()
        input_dim = num_patches * 2  # state + memory
        self.fc_layers = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, num_patches)
        )
    
    def forward(self, state):
        return self.fc_layers(state)

class DQNAgent:
    """DQN Agent for learning patch selection."""
    def __init__(self, replay_batch_size, num_patches, buffer_capacity, gamma, tau,
                 update_frequency, learning_rate, env, device, num_classes=None):
        self.replay_batch_size = replay_batch_size
        self.gamma = gamma
        self.tau = tau
        self.steps_since_update = 0
        self.update_frequency = update_frequency
        self.device = device
        self.env = env
        self.num_patches = num_patches
        self.num_classes = num_classes
        
        # Global memory for class-specific patterns
        if num_classes:
            self.global_memory = torch.zeros(num_classes, num_patches, device=device, dtype=torch.float32)
        else:
            self.global_memory = None
        self.memory_alpha = 0.01
        
        # Q-networks
        self.q_network = QNetwork(num_patches).to(device)
        self.target_network = QNetwork(num_patches).to(device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        
        self.optimizer = optim.AdamW(self.q_network.parameters(), lr=learning_rate, amsgrad=True)
        self.memory = ReplayBuffer(buffer_capacity, replay_batch_size)
    
    def augment_state(self, state, labels=None):
        """Augment state with class-specific memory."""
        state = state.to(self.device)
        batch_size = state.shape[0]
        
        if self.global_memory is None or self.num_classes is None:
            mem = torch.zeros((batch_size, self.num_patches), device=self.device, dtype=state.dtype)
        else:
            if labels is None:
                mean_mem = torch.mean(self.global_memory, dim=0, keepdim=True)
                mem = mean_mem.repeat(batch_size, 1)
            else:
                labels_cpu = labels.detach().cpu().long().numpy()
                mem_list = [self.global_memory[int(c)].unsqueeze(0) for c in labels_cpu]
                mem = torch.cat(mem_list, dim=0).to(self.device)
        
        return torch.cat([state, mem], dim=1)
    
    def update_global_memory(self, labels, binary_mask):
        """Update class-specific memory with EMA."""
        if self.global_memory is None:
            return
        if not isinstance(binary_mask, torch.Tensor):
            mask = torch.tensor(binary_mask, dtype=torch.float32, device=self.device)
        else:
            mask = binary_mask.to(self.device).float()
        
        labels_cpu = labels.detach().cpu().long().numpy()
        for c in labels_cpu:
            c = int(c)
            self.global_memory[c] = (1.0 - self.memory_alpha) * self.global_memory[c] + self.memory_alpha * mask
    
    def select_action(self, state, labels=None, epsilon=0.):
        """Epsilon-greedy action selection."""
        if random.random() > epsilon:
            with torch.no_grad():
                augmented = self.augment_state(state, labels)
                batch_q_values = self.q_network(augmented)
                action_q_values = torch.mean(batch_q_values, dim=0)
                return action_q_values
        else:
            return self.env.action_space.sample()
    
    def optimize_model(self):
        """DQN optimization step."""
        if len(self.memory) < self.replay_batch_size:
            return
        
        experiences = self.memory.sample()
        batch = Experience(*zip(*experiences))
        
        state_batch = torch.cat(batch.state).to(self.device)
        next_state_batch = torch.cat(batch.next_state).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)
        
        current_q = self.q_network(state_batch)
        with torch.no_grad():
            next_q = self.target_network(next_state_batch)
        
        target_q = (next_q * self.gamma) + reward_batch.unsqueeze(1)
        loss = F.smooth_l1_loss(current_q, target_q)
        
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.q_network.parameters(), 100)
        self.optimizer.step()
        
        self.steps_since_update += 1
        if self.steps_since_update >= self.update_frequency:
            self._soft_update()
            self.steps_since_update = 0
    
    def _soft_update(self):
        """Soft update of target network."""
        target_dict = self.target_network.state_dict()
        online_dict = self.q_network.state_dict()
        for key in online_dict:
            target_dict[key] = online_dict[key] * self.tau + target_dict[key] * (1 - self.tau)
        self.target_network.load_state_dict(target_dict)

print("✓ AgentViT RL components loaded")

✓ AgentViT RL components loaded
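
The reward in `ViTEnv.step_reward` combines a loss-improvement ratio with a penalty for deviating from the patch budget. A worked example in plain Python; the losses, patch counts, and unit weights are invented for illustration and are not tuned values from this notebook.

```python
def patch_reward(first_loss, current_loss, num_selected, target_num_patches,
                 loss_weight, efficiency_weight):
    """Same arithmetic as the reward in ViTEnv.step_reward."""
    efficiency_reward = -abs(num_selected - target_num_patches) / target_num_patches
    loss_improvement = first_loss / (current_loss + 1e-8)
    return loss_improvement * loss_weight + efficiency_reward * efficiency_weight


# Loss halved since the first reward step, but 147 patches kept against a budget of 98.
r_over = patch_reward(2.0, 1.0, num_selected=147, target_num_patches=98,
                      loss_weight=1.0, efficiency_weight=1.0)
# Same loss, budget met exactly.
r_on = patch_reward(2.0, 1.0, num_selected=98, target_num_patches=98,
                    loss_weight=1.0, efficiency_weight=1.0)

print(f"over budget: {r_over:.3f}, on budget: {r_on:.3f}")
```

The efficiency term is never positive, while the loss ratio grows without bound as the loss falls, so the two weights decide how strongly the agent is pushed towards the patch budget.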


## AgentViT Training and Evaluation Functions

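The trainer below jointly trains the ViT and the DQN agent: every batch updates the ViT on the patches the agent selected, and every `reward_frequency`-th batch also yields a reward that, after the first epoch, is stored in the replay buffer and used to update the agent.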

In [5]:
# ==============================================================================
# AGENTVIT: TRAINING ORCHESTRATOR
# ==============================================================================
class AgentViTTrainer:
    """Orchestrates joint training of ViT and DQN agent."""
    def __init__(self, vit_model, num_patches, num_epochs, env, train_loader, test_loader, device,
                 dqn_config):
        self.env = env
        self.num_epochs = num_epochs
        self.vit_model = vit_model
        self.num_patches = num_patches
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        
        # Extract DQN config
        self.epsilon_start = dqn_config.get('epsilon_start', 1.0)
        self.epsilon_end = dqn_config.get('epsilon_end', 0.01)
        self.epsilon_decay = dqn_config.get('epsilon_decay', 20000)
        self.reward_frequency = dqn_config.get('reward_frequency', 50)
        
        # Create DQN agent
        self.dqn_agent = DQNAgent(
            replay_batch_size=dqn_config.get('replay_batch_size', 8),
            num_patches=num_patches,
            buffer_capacity=dqn_config.get('buffer_capacity', 64),
            gamma=dqn_config.get('gamma', 0.95),
            tau=dqn_config.get('tau', 0.1),
            update_frequency=dqn_config.get('update_frequency', 2),
            learning_rate=dqn_config.get('learning_rate', 0.01),
            env=self.env,
            device=device,
            num_classes=vit_model.num_classes
        )
        
        self.val_acc_history = []
        self.val_loss_history = []
    
    def train(self):
        """Main training loop."""
        total_iterations = 0
        epsilon = self.epsilon_start
        
        for epoch in range(self.num_epochs):
            print(f'\n{"="*60}\nEpoch {epoch+1}/{self.num_epochs}\n{"="*60}')
            
            for batch_idx, (images, labels) in enumerate(self.train_loader):
                images, labels = images.to(self.device), labels.to(self.device)
                total_iterations += 1
                
                # Get state and select action
                state = self.env.get_state(images)
                action = self.dqn_agent.select_action(state, labels, epsilon)
                
                # Training step
                if batch_idx % self.reward_frequency != 0:
                    self.env.step_train(action, images, labels)
                else:
                    next_state, reward = self.env.step_reward(action, images, labels)
                    reward_tensor = torch.full((images.size(0),), reward, dtype=torch.float32)
                    
                    if epoch > 0:  # Skip first epoch
                        binary_mask = self.vit_model.get_patches()
                        if binary_mask:
                            self.dqn_agent.update_global_memory(labels, binary_mask)
                        
                        aug_state = self.dqn_agent.augment_state(state, labels)
                        aug_next = self.dqn_agent.augment_state(next_state, labels)
                        self.dqn_agent.memory.push(aug_state, aug_next, reward_tensor)
                        self.dqn_agent.optimize_model()
                
                # Decay epsilon
                epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
                          math.exp(-1. * total_iterations / self.epsilon_decay)
            
            # Validation
            val_loss, val_acc = self.evaluate()
            self.val_acc_history.append(val_acc)
            self.val_loss_history.append(val_loss)
            print(f'Validation - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%')
    
    def evaluate(self):
        """Evaluate on test set with all patches."""
        original_mask = self.vit_model.get_patches()
        self.vit_model.set_patches(torch.ones(self.num_patches))
        
        self.vit_model.eval()
        total_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                output = F.log_softmax(self.vit_model(images), dim=1)
                loss = F.nll_loss(output, labels, reduction='sum')
                total_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(labels).sum().item()
                total += labels.size(0)
        
        if original_mask:
            self.vit_model.set_patches(torch.tensor(original_mask, dtype=torch.float))
        
        avg_loss = total_loss / total
        accuracy = 100.0 * correct / total
        return avg_loss, accuracy

print("✓ AgentViT trainer loaded")

✓ AgentViT trainer loaded
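
The exploration rate in the training loop decays exponentially with the global iteration count. The sketch evaluates that schedule with the trainer's default values (`epsilon_start=1.0`, `epsilon_end=0.01`, `epsilon_decay=20000`); the helper name `epsilon_at` is hypothetical.

```python
import math


def epsilon_at(iteration, start=1.0, end=0.01, decay=20000):
    """Exponentially decayed exploration rate, as in AgentViTTrainer.train."""
    return end + (start - end) * math.exp(-1.0 * iteration / decay)


for it in (0, 20000, 60000, 200000):
    print(f"iteration {it:>6}: epsilon = {epsilon_at(it):.4f}")
```

With a decay constant of 20000 iterations, the agent still takes a random action on roughly 37% of batches after 20000 iterations, so short runs stay largely exploratory unless `epsilon_decay` is lowered.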


## AgentViT Benchmarking Functions

Functions to evaluate AgentViT performance at different patch retention ratios, matching the ASEL evaluation methodology.
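
Each retention ratio is turned into an integer patch budget with `k = max(1, int(num_patches * k_ratio))`. The sketch tabulates that budget for a 224x224 input. The patch size of 16 (196 patches) is an assumption, since the model is not instantiated in the cells shown here.

```python
def patch_budget(num_patches, k_ratio):
    """Number of patches kept at a given retention ratio (at least one)."""
    return max(1, int(num_patches * k_ratio))


image_size, patch_size = 224, 16          # patch_size is an assumed value
num_patches = (image_size // patch_size) ** 2

ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
budgets = {r: patch_budget(num_patches, r) for r in ratios}

for r, k in budgets.items():
    print(f"k_ratio {r:.1f} -> {k:3d} of {num_patches} patches")
```

Because `int()` truncates, the budget rounds down, and the `max(1, ...)` guard ensures that at least one patch always reaches the transformer.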

In [6]:
# ==============================================================================
# AGENTVIT: BENCHMARKING UTILITIES
# ==============================================================================
class AgentViTBenchmark:
    """Benchmarking utilities for AgentViT evaluation."""
    
    @staticmethod
    def set_patch_ratio(model, k_ratio):
        """Keep the first k patches in raster order, with k = max(1, int(num_patches * k_ratio))."""
        num_patches = model.num_patches
        k = max(1, int(num_patches * k_ratio))
        # Positional, not importance-ranked: patches 0..k-1 are kept, the rest are dropped
        mask = [1 if i < k else 0 for i in range(num_patches)]
        model.selected_patches_mask = mask
    
    @staticmethod
    def evaluate_accuracy_at_ratio(model, loader, device, k_ratio, use_learned_selection=False, dqn_agent=None):
        """
        Evaluate accuracy at a specific patch retention ratio.
        
        Args:
            model: AgentViT model
            loader: Data loader
            device: Device
            k_ratio: Patch retention ratio (0.0 to 1.0)
            use_learned_selection: If True and dqn_agent provided, use agent's selection
            dqn_agent: Optional DQN agent for learned selection
        """
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                
                if use_learned_selection and dqn_agent is not None:
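                    # Use learned patch selection from agent. Labels are withheld so the
                    # class-specific memory cannot leak ground truth at test time;
                    # augment_state then falls back to the class-averaged memory.
                    state = model.get_att(images)
                    action = dqn_agent.select_action(state, labels=None, epsilon=0.0)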
                    model.set_patches(action)
                    
                    # Adjust to meet k_ratio constraint
                    current_mask = model.get_patches()
                    num_selected = sum(current_mask)
                    target_k = max(1, int(model.num_patches * k_ratio))
                    
                    if num_selected != target_k:
                        # Adjust mask to match target ratio
                        if isinstance(action, torch.Tensor):
                            action_np = action.cpu().numpy()
                        else:
                            action_np = action
                        top_k_indices = np.argsort(action_np)[-target_k:]
                        new_mask = [1 if i in top_k_indices else 0 for i in range(model.num_patches)]
                        model.selected_patches_mask = new_mask
                else:
                    # Use simple top-k selection
                    AgentViTBenchmark.set_patch_ratio(model, k_ratio)
                
                # Forward pass
                output = model(images)
                pred = output.argmax(dim=1)
                correct += pred.eq(labels).sum().item()
                total += labels.size(0)
        
        return correct / total if total > 0 else 0.0
    
    @staticmethod
    def measure_metrics(model, device, k_ratio):
        """
        Measure computational metrics (GFLOPs, latency, throughput).
        Simplified version matching Benchmark class from Prunable ViT.
        """
        # Fixed benchmark batch (64 images at 224x224); the model is assumed to be on `device`
        batch_input = torch.randn(64, 3, 224, 224).to(device)
        
        # Set patch ratio
        AgentViTBenchmark.set_patch_ratio(model, k_ratio)
        
        # Estimate GFLOPs (theoretical approximation)
        gflops = 1.1 * k_ratio  # Simplified estimate
        
        # Measure latency and throughput
        model.eval()
        start_event = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
        end_event = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
        
        with torch.no_grad():
            _ = model(batch_input)  # Warmup
            
            if torch.cuda.is_available():
                start_event.record()
                for _ in range(50):
                    _ = model(batch_input)
                end_event.record()
                torch.cuda.synchronize()
                total_time_ms = start_event.elapsed_time(end_event)
            else:
                start = time.time()
                for _ in range(50):
                    _ = model(batch_input)
                total_time_ms = (time.time() - start) * 1000
            
            latency_ms = total_time_ms / 50
            throughput = (64 * 50) / (total_time_ms / 1000)
        
        return gflops, latency_ms, throughput

def run_agentvit_benchmarks(model, test_loader, ds_name, device, dqn_agent=None):
    """
    Run comprehensive benchmarks for AgentViT, matching Prunable ViT evaluation.
    
    Args:
        model: Trained AgentViT model
        test_loader: Test data loader
        ds_name: Dataset name for plot labels
        device: Device
        dqn_agent: Optional trained DQN agent for learned selection
    """
    print(f"\n{'='*60}\nRunning AgentViT Benchmarks for {ds_name}...\n{'='*60}")
    
    ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    
    res = {
        'ratios': ratios,
        'acc_learned': [],
        'acc_random': [],
        'gflops': [],
        'thr': [],
        'lat': []
    }
    
    # Baseline (100%)
    full_acc = AgentViTBenchmark.evaluate_accuracy_at_ratio(
        model, test_loader, device, 1.0, use_learned_selection=False
    )
    full_gflops, full_lat, full_thr = AgentViTBenchmark.measure_metrics(model, device, 1.0)
    
    for r in ratios:
        print(f"Evaluating at ratio {r:.1f}...")
        
        # Learned selection (using DQN agent if available)
        acc_learned = AgentViTBenchmark.evaluate_accuracy_at_ratio(
            model, test_loader, device, r, use_learned_selection=(dqn_agent is not None), dqn_agent=dqn_agent
        )
        
        # Baseline: same keep ratio without the DQN agent (use_learned_selection=False)
        acc_random = AgentViTBenchmark.evaluate_accuracy_at_ratio(
            model, test_loader, device, r, use_learned_selection=False
        )
        
        # Metrics
        gf, lat, thr = AgentViTBenchmark.measure_metrics(model, device, r)
        
        res['acc_learned'].append(acc_learned)
        res['acc_random'].append(acc_random)
        res['gflops'].append(gf)
        res['thr'].append(thr)
        res['lat'].append(lat)
        
        print(f"  Ratio {r:.1f} | Learned-Acc: {acc_learned:.1%} | Random-Acc: {acc_random:.1%} | FPS: {thr:.0f}")
    
    # --- PLOTTING (matching Prunable ViT style) ---
    results_path = CONFIG['results_path']
    os.makedirs(results_path, exist_ok=True)  # ensure the output directory exists before saving plots
    
    # 1. Strategy Comparison
    plt.figure(figsize=(8, 6))
    plt.plot(ratios, [x*100 for x in res['acc_learned']], 'r-o', lw=2, label='AgentViT (Learned)')
    plt.plot(ratios, [x*100 for x in res['acc_random']], 'k--x', alpha=0.5, label='Random')
    plt.scatter([1.0], [full_acc*100], c='k', marker='*', s=200, zorder=10, label='Full ViT')
    plt.title(f'{ds_name}: AgentViT Strategy Comparison')
    plt.xlabel('Keep Ratio')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.5)
    plt.legend()
    plt.savefig(f"{results_path}/{ds_name}_agentvit_1_strategies.png")
    plt.close()
    
    # 2. Acc vs Throughput
    plt.figure(figsize=(8, 6))
    plt.plot(res['thr'], [x*100 for x in res['acc_learned']], 'g-o', lw=2, label='AgentViT')
    plt.scatter([full_thr], [full_acc*100], c='k', marker='*', s=200, label='Full ViT')
    plt.title(f'{ds_name}: AgentViT Accuracy vs Throughput')
    plt.xlabel('Throughput (img/s)')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.5)
    plt.legend()
    plt.savefig(f"{results_path}/{ds_name}_agentvit_2_throughput.png")
    plt.close()
    
    # 3. GFLOPs
    plt.figure(figsize=(8, 6))
    plt.plot(ratios, res['gflops'], 'm-o', lw=2, label='AgentViT')
    plt.axhline(y=full_gflops, c='k', ls='--', label='Full ViT')
    plt.title(f'{ds_name}: AgentViT GFLOPs Reduction')
    plt.xlabel('Keep Ratio')
    plt.ylabel('GFLOPs')
    plt.grid(True, alpha=0.5)
    plt.legend()
    plt.savefig(f"{results_path}/{ds_name}_agentvit_3_gflops.png")
    plt.close()
    
    # 4. Latency Bar
    lat_50 = res['lat'][ratios.index(0.5)]  # latency measured at keep ratio 0.5
    plt.figure(figsize=(6, 6))
    plt.bar(['Full ViT', 'AgentViT (50%)'], [full_lat, lat_50], color=['gray', 'green'], width=0.5)
    plt.title(f'{ds_name}: AgentViT Batch Latency')
    plt.ylabel('Time (ms)')
    plt.text(0, full_lat, f"{full_lat:.1f}ms", ha='center', va='bottom', fontweight='bold')
    plt.text(1, lat_50, f"{lat_50:.1f}ms", ha='center', va='bottom', fontweight='bold')
    plt.savefig(f"{results_path}/{ds_name}_agentvit_4_latency.png")
    plt.close()
    
    print(f"Finished AgentViT benchmarks for {ds_name}. Plots saved.\n")
    
    return res

print("✓ AgentViT benchmarking functions loaded")

✓ AgentViT benchmarking functions loaded


## AgentViT Experimental Pipeline

Testing AgentViT on the same three datasets used for ASEL evaluation (AID, EuroSAT, RSSCN7),
but **training AgentViT independently on each dataset** (no CIFAR-10 transfer learning).

In [7]:
# ==============================================================================
# AGENTVIT: MAIN EXPERIMENTAL PIPELINE
# ==============================================================================
def run_agentvit_experiments():
    """
    Main pipeline for AgentViT experiments.

    Trains and evaluates AgentViT **separately on each dataset** (AID, EuroSAT, RSSCN7),
    without any CIFAR-10 warmup or transfer learning, so that AgentViT keeps its
    original training scheme while still being evaluated on the same datasets as ASEL.
    """
    print(f"\n{'#'*80}\n# AGENTVIT EXPERIMENTAL PIPELINE\n{'#'*80}\n")

    # Configuration for AgentViT (keeps original structure, no transfer learning)
    AGENTVIT_CONFIG = {
        'seed': 42,
        'device': CONFIG['device'],
        'batch_size': 32,  # Smaller batch for RL stability
        'num_workers': 4,
        'resize_dim': 224,

        # ViT architecture (matching image size)
        'patch_size': 16,  # 224/16 = 14x14 = 196 patches
        'embed_dim': 128,
        'depth': 6,
        'heads': 8,
        'mlp_dim': 512,
        'vit_lr': 1e-3,

        # RL/DQN parameters
        'epsilon_start': 1.0,
        'epsilon_end': 0.01,
        'epsilon_decay': 20000,
        'replay_batch_size': 8,
        'buffer_capacity': 64,
        'gamma': 0.95,
        'tau': 0.1,
        'update_frequency': 2,
        'dqn_lr': 0.01,
        'reward_frequency': 50,

        # Reward function
        'target_patches': 98,  # ~50% of 196 patches
        'efficiency_weight': 20.0,
        'loss_weight': 1.0,

        # Training epochs per dataset
        'num_epochs': 20,
    }

    set_seed(AGENTVIT_CONFIG['seed'])

    # ------------------------------------------------------------------
    # TRAIN AND EVALUATE ON EACH TARGET DATASET (NO TRANSFER)
    # ------------------------------------------------------------------
    target_datasets = ['aid', 'eurosat', 'rsscn7']

    # Number of patches is fixed by image/patch size
    num_patches_per_side = AGENTVIT_CONFIG['resize_dim'] // AGENTVIT_CONFIG['patch_size']
    num_patches = num_patches_per_side * num_patches_per_side

    for ds_name in target_datasets:
        print(f"\n{'='*60}\nTRAINING AGENTVIT ON {ds_name.upper()} (from scratch)\n{'='*60}")

        try:
            # Load dataset
            train_ds, test_ds, n_cls = get_dataset(ds_name)
            train_loader = DataLoader(
                train_ds,
                batch_size=AGENTVIT_CONFIG['batch_size'],
                shuffle=True,
                num_workers=AGENTVIT_CONFIG['num_workers'],
            )
            test_loader = DataLoader(
                test_ds,
                batch_size=AGENTVIT_CONFIG['batch_size'],
                shuffle=False,
                num_workers=AGENTVIT_CONFIG['num_workers'],
                drop_last=True,
            )

            # Create AgentViT model for this dataset
            agentvit_model = SimpleAgentViT(
                image_size=AGENTVIT_CONFIG['resize_dim'],
                patch_size=AGENTVIT_CONFIG['patch_size'],
                num_classes=n_cls,
                dim=AGENTVIT_CONFIG['embed_dim'],
                depth=AGENTVIT_CONFIG['depth'],
                heads=AGENTVIT_CONFIG['heads'],
                mlp_dim=AGENTVIT_CONFIG['mlp_dim'],
            ).to(AGENTVIT_CONFIG['device'])

            # Optimizer for ViT
            vit_optimizer = optim.Adam(agentvit_model.parameters(), lr=AGENTVIT_CONFIG['vit_lr'])

            # Create RL environment
            env = ViTEnv(
                vit_model=agentvit_model,
                num_patches=num_patches,
                optimizer=vit_optimizer,
                loss_weight=AGENTVIT_CONFIG['loss_weight'],
                efficiency_weight=AGENTVIT_CONFIG['efficiency_weight'],
                device=AGENTVIT_CONFIG['device'],
                target_num_patches=AGENTVIT_CONFIG['target_patches'],
            )

            # DQN config
            dqn_config = {
                'epsilon_start': AGENTVIT_CONFIG['epsilon_start'],
                'epsilon_end': AGENTVIT_CONFIG['epsilon_end'],
                'epsilon_decay': AGENTVIT_CONFIG['epsilon_decay'],
                'replay_batch_size': AGENTVIT_CONFIG['replay_batch_size'],
                'buffer_capacity': AGENTVIT_CONFIG['buffer_capacity'],
                'gamma': AGENTVIT_CONFIG['gamma'],
                'tau': AGENTVIT_CONFIG['tau'],
                'update_frequency': AGENTVIT_CONFIG['update_frequency'],
                'learning_rate': AGENTVIT_CONFIG['dqn_lr'],
                'reward_frequency': AGENTVIT_CONFIG['reward_frequency'],
            }

            # Create trainer (keeps original AgentViT training logic)
            trainer = AgentViTTrainer(
                vit_model=agentvit_model,
                num_patches=num_patches,
                num_epochs=AGENTVIT_CONFIG['num_epochs'],
                env=env,
                train_loader=train_loader,
                test_loader=test_loader,
                device=AGENTVIT_CONFIG['device'],
                dqn_config=dqn_config,
            )

            # Train on this dataset
            print(f"\nTraining AgentViT on {ds_name}...")
            trainer.train()

            # Benchmark on this dataset
            print(f"\nBenchmarking AgentViT on {ds_name}...")
            run_agentvit_benchmarks(
                agentvit_model,
                test_loader,
                ds_name,
                AGENTVIT_CONFIG['device'],
                dqn_agent=trainer.dqn_agent,
            )

            # Save trained model for this dataset
            save_path = f"{CONFIG['save_path']}/agentvit_{ds_name}.pth"
            torch.save({
                'model_state': agentvit_model.state_dict(),
                'dqn_state': trainer.dqn_agent.q_network.state_dict(),
                'config': AGENTVIT_CONFIG,
            }, save_path)
            print(f"✓ Saved AgentViT model for {ds_name} to {save_path}")

        except Exception as e:
            print(f"Error processing {ds_name}: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\n{'#'*80}\n# AGENTVIT EXPERIMENTS COMPLETED\n{'#'*80}\n")

print("✓ AgentViT experimental pipeline ready (no transfer learning)")

✓ AgentViT experimental pipeline ready (no transfer learning)


## Execute AgentViT Experiments

Run this cell to execute the complete AgentViT experimental pipeline on all three datasets.

In [8]:
# ==============================================================================
# RUN AGENTVIT EXPERIMENTS
# ==============================================================================
# Uncomment the call below to run the AgentViT experiments.
# Note: this will take several hours depending on your hardware.
# Recommended: start with one dataset (e.g., only AID) by
# temporarily editing target_datasets inside run_agentvit_experiments().

# run_agentvit_experiments()

print("""
╔════════════════════════════════════════════════════════════════════════╗
║                   AGENTVIT EXPERIMENTS READY TO RUN                    ║
╚════════════════════════════════════════════════════════════════════════╝

To execute the AgentViT experiments:
1. Uncomment the line: run_agentvit_experiments()
2. Run this cell

The pipeline will:
- Train AgentViT **from scratch** on each of: AID, EuroSAT, and RSSCN7
- Use the original RL-based AgentViT training logic (no CIFAR-10 transfer)
- Generate benchmark plots for each dataset (accuracy vs keep ratio, throughput, GFLOPs, latency)
- Save all results to the configured paths

Expected runtime: several hours in total (depending on GPU)

Results will be saved to: {0}
Models will be saved to: {1}
""".format(CONFIG['results_path'], CONFIG['save_path']))


╔════════════════════════════════════════════════════════════════════════╗
║                   AGENTVIT EXPERIMENTS READY TO RUN                    ║
╚════════════════════════════════════════════════════════════════════════╝

To execute the AgentViT experiments:
1. Uncomment the line: run_agentvit_experiments()
2. Run this cell

The pipeline will:
- Train AgentViT **from scratch** on each of: AID, EuroSAT, and RSSCN7
- Use the original RL-based AgentViT training logic (no CIFAR-10 transfer)
- Generate benchmark plots for each dataset (accuracy vs keep ratio, throughput, GFLOPs, latency)
- Save all results to the configured paths

Expected runtime: several hours in total (depending on GPU)



In [9]:
# Run the AgentViT experimental pipeline (DISABLED to avoid retraining)
# If you ever want to retrain AgentViT models, uncomment the lines below
# and run this cell manually.
# if __name__ == "__main__":
#     run_agentvit_experiments()

print("AgentViT training is disabled. Notebook will only use existing checkpoints in 'saved_models' and run comparisons.")

AgentViT training is disabled. Notebook will only use existing checkpoints in 'saved_models' and run comparisons.


## Conceptual Summary: ASEL vs AgentViT

### ASEL (Original Pipeline)
- Learns a continuous importance score for each patch using a small selector network.
- During training, a straight-through estimator is used to keep top-k patches while still allowing gradients to flow.
- At inference, the model physically removes the patches that are not selected. Selection follows one of three policies: learned, random, or central.
- Uses a **CIFAR-10 warmup + transfer learning** strategy: first train on CIFAR-10, then fine-tune on AID, EuroSAT, and RSSCN7.
- Evaluation measures how accuracy changes as we reduce the number of patches, along with GFLOPs, throughput, and latency.
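
The inference-time selection described above can be sketched in plain Python. This is a minimal illustration rather than the notebook's implementation: the helper names `keep_count`, `central_order`, and `select_patches` are hypothetical, scores are a plain list instead of a tensor, and ties in the central ordering are broken by index order, which may differ from what `np.argsort` produces.

```python
import random


def keep_count(num_patches, k_ratio):
    """Number of patches kept for a given ratio, never fewer than one."""
    return max(1, int(num_patches * k_ratio))


def central_order(height=14, width=14):
    """Patch indices sorted by squared distance to the grid center."""
    cy, cx = (height - 1) / 2.0, (width - 1) / 2.0
    return sorted(
        range(height * width),
        key=lambda i: (i % width - cx) ** 2 + (i // width - cy) ** 2,
    )


def select_patches(scores, k_ratio, policy="learned", grid=(14, 14), rng=None):
    """Return the indices of the kept patches under one of three policies."""
    n = len(scores)
    k = keep_count(n, k_ratio)
    if policy == "learned":
        # Highest importance scores first
        return sorted(range(n), key=lambda i: scores[i], reverse=True)[:k]
    if policy == "random":
        rng = rng or random.Random()
        return rng.sample(range(n), k)
    if policy == "central":
        # Assumes len(scores) equals grid[0] * grid[1]
        return central_order(*grid)[:k]
    raise ValueError(f"Unknown policy: {policy}")


print(keep_count(196, 0.5))                         # 98
print(select_patches([0.1, 0.9, 0.5, 0.7], 0.5))    # [1, 3]
```

With 196 patches (a 14x14 grid), a keep ratio of 0.5 retains 98 patches, and very small ratios are clamped to a single patch.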

### AgentViT (Added Pipeline)
- Treats patch selection as a reinforcement learning problem.
- A DQN agent observes attention-based features from the ViT and outputs Q-values for each patch.
- Q-values are converted into binary masks (selected/discarded) and used to restrict computation to selected patches.
- The reward balances two components:
  - **Classification quality**: how much the loss improves
  - **Efficiency**: how close the number of selected patches is to a target budget
- The agent is trained with experience replay, target networks, and epsilon-greedy exploration.
- In this baseline notebook, **AgentViT is trained independently on each dataset (AID, EuroSAT, RSSCN7)**, without CIFAR-10 warmup or transfer learning.
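
The two training ingredients above, the reward and the epsilon-greedy schedule, can be sketched as follows. The exact formulas used inside `ViTEnv` and `DQNAgent` are not shown in this section, so both functions are assumptions: an exponential epsilon decay and a reward equal to the weighted loss improvement minus a weighted, normalized distance to the patch budget. The function names are hypothetical; the default values are taken from `AGENTVIT_CONFIG`.

```python
import math


def epsilon_at(step, eps_start=1.0, eps_end=0.01, decay=20000):
    """Assumed exponential schedule: starts at eps_start, approaches eps_end."""
    return eps_end + (eps_start - eps_end) * math.exp(-step / decay)


def reward(prev_loss, loss, n_selected,
           target_patches=98, num_patches=196,
           loss_weight=1.0, efficiency_weight=20.0):
    """Assumed reward: loss improvement minus a penalty for missing the budget."""
    quality = loss_weight * (prev_loss - loss)
    budget_gap = abs(n_selected - target_patches) / num_patches
    return quality - efficiency_weight * budget_gap


# On-budget selection keeps only the quality term; keeping every patch is penalized.
on_budget = reward(prev_loss=1.0, loss=0.8, n_selected=98)
all_patches = reward(prev_loss=1.0, loss=0.8, n_selected=196)
print(on_budget > all_patches)
```

Under these assumptions, an efficiency weight of 20 makes the budget term dominate: selecting all 196 patches costs 10 reward units, far more than a typical per-step loss improvement.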

### Comparison Protocol
- **ASEL**: CIFAR-10 warmup, then transfer and fine-tune on AID, EuroSAT, and RSSCN7.
- **AgentViT**: Direct training on each target dataset from scratch using its RL-based patch selection.
- **Metrics**:
  - Accuracy at multiple keep ratios (10%–100% of patches)
  - GFLOPs reduction as patches are pruned
  - Throughput (images per second) and latency (ms) for a fixed batch size
- **Interpretation**:
  - If AgentViT maintains higher accuracy at lower keep ratios, it indicates a more effective patch selection policy on that dataset, even without transfer learning.
  - If ASEL achieves similar accuracy with simpler training and transfer, it may be more practical despite lacking RL flexibility.

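This setup allows a side-by-side comparison between a **deterministic learned selector** (ASEL, with transfer learning) and an **RL-based adaptive selector** (AgentViT, trained per dataset) on the same datasets, at the same image resolution, and with the same evaluation metrics. Because the two pipelines also differ in backbone and training regime, accuracy gaps reflect those differences as well as the selection policy.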
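
The throughput and latency figures listed under Metrics come from simple arithmetic over a timed loop. The sketch below mirrors the computation in the `measure_metrics` helpers (50 timed runs at a batch size of 64); the function name `latency_and_throughput` is hypothetical and the timing value in the example is illustrative, not a measurement.

```python
def latency_and_throughput(total_time_ms, batch_size=64, num_runs=50):
    """Convert the total time of a timed loop into per-batch latency and img/s."""
    latency_ms = total_time_ms / num_runs                         # ms per batch
    throughput = (batch_size * num_runs) / (total_time_ms / 1000.0)  # images per second
    return latency_ms, throughput


# Illustrative input: 50 batches of 64 images taking 2.5 s in total
print(latency_and_throughput(2500.0))  # (50.0, 1280.0)
```

Note that the latency reported this way is per batch of 64 images, not per image, which is why the latency plots are titled "Batch Latency".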

## Comparative Plots from Saved Models

This section loads the **saved ASEL checkpoints** (`aid_finetuned.pth`, `eurosat_finetuned.pth`, `rsscn7_finetuned.pth`)
and the **saved AgentViT checkpoints** (`agentvit_aid.pth`, `agentvit_eurosat.pth`, `agentvit_rsscn7.pth`) from the `saved_models`
directory, then computes and plots **comparative curves** for ASEL vs AgentViT on each dataset.
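
Before running the comparison, it can help to confirm that all six checkpoint files are present. The sketch below uses the file names listed above; the helper names `expected_checkpoints` and `missing_checkpoints` are hypothetical, and the `save_dir` argument stands in for `CONFIG['save_path']`.

```python
import os

DATASETS = ("aid", "eurosat", "rsscn7")


def expected_checkpoints(save_dir):
    """Map each dataset to its expected ASEL and AgentViT checkpoint paths."""
    return {
        ds: {
            "asel": os.path.join(save_dir, f"{ds}_finetuned.pth"),
            "agentvit": os.path.join(save_dir, f"agentvit_{ds}.pth"),
        }
        for ds in DATASETS
    }


def missing_checkpoints(save_dir):
    """Return the expected checkpoint paths that do not exist on disk."""
    return [
        path
        for pair in expected_checkpoints(save_dir).values()
        for path in pair.values()
        if not os.path.exists(path)
    ]


# Example: report anything missing from a local 'saved_models' directory
for path in missing_checkpoints("saved_models"):
    print(f"Missing checkpoint: {path}")
```

An empty result means both methods can be loaded for every dataset; otherwise the comparison for the affected dataset will raise `FileNotFoundError`.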

In [10]:
# ==============================================================================
# ASEL MODEL DEFINITION AND BENCHMARK HELPERS (FROM ASEL2, INFERENCE ONLY)
# ==============================================================================
import timm

try:
    from fvcore.nn import FlopCountAnalysis
    HAS_FVCORE = True
except ImportError:
    HAS_FVCORE = False
    print("Warning: 'fvcore' not found. ASEL GFLOPs will be estimated theoretically.")

class ASEL(nn.Module):
    """
    ASEL model (same as PrunableViT in ASEL2) used here for
    loading saved checkpoints and running benchmarks (no training).
    """
    def __init__(self, num_classes, pretrained=False):
        super().__init__()
        self.backbone = timm.create_model('vit_tiny_patch16_224', pretrained=pretrained, num_classes=num_classes)
        self.embed_dim = self.backbone.embed_dim
        self.patch_selector = nn.Sequential(
            nn.Linear(self.embed_dim * 2, 96),
            nn.LayerNorm(96),
            nn.ReLU(),
            nn.Linear(96, 1),
            nn.Sigmoid(),
        )
        H, W = 14, 14
        center = (H - 1) / 2.0
        y, x = np.ogrid[:H, :W]
        dist = (x - center) ** 2 + (y - center) ** 2
        self.central_indices = torch.from_numpy(np.argsort(dist.flatten())).long()

    def _get_patch_embeddings(self, x):
        x = self.backbone.patch_embed(x)
        x = x + self.backbone.pos_embed[:, 1:]
        return x

    def _process_transformer(self, x_patches):
        B = x_patches.shape[0]
        cls_token = self.backbone.cls_token.expand(B, -1, -1) + self.backbone.pos_embed[:, :1]
        x = torch.cat((cls_token, x_patches), dim=1)
        x = self.backbone.pos_drop(x)
        x = self.backbone.blocks(x)
        x = self.backbone.norm(x)
        return self.backbone.head(x[:, 0])

    def _compute_importance_scores(self, x_patches):
        global_feat = x_patches.mean(dim=1, keepdim=True).expand(-1, x_patches.shape[1], -1)
        selector_input = torch.cat([x_patches, global_feat], dim=-1)
        return self.patch_selector(selector_input).squeeze(-1)

    def forward_inference(self, x_images, k_ratio, policy='learned'):
        x_patches = self._get_patch_embeddings(x_images)
        B, N, D = x_patches.shape
        k = int(N * k_ratio)
        if k < 1:
            k = 1

        if policy == 'learned':
            scores = self._compute_importance_scores(x_patches)
            _, topk_idx = torch.topk(scores, k, dim=1)
        elif policy == 'random':
            topk_idx = torch.stack([torch.randperm(N)[:k] for _ in range(B)]).to(x_patches.device)
        elif policy == 'central':
            indices = self.central_indices[:k].to(x_patches.device)
            topk_idx = indices.unsqueeze(0).expand(B, -1)
        else:
            raise ValueError("Unknown policy for ASEL")

        topk_idx_expanded = topk_idx.unsqueeze(-1).expand(-1, -1, D)
        x_kept = torch.gather(x_patches, 1, topk_idx_expanded)
        return self._process_transformer(x_kept)

class ASELBenchmark:
    @staticmethod
    def measure_metrics(model, device, k_ratio, policy='learned'):
        dummy_input = torch.randn(1, 3, 224, 224).to(device)
        batch_input = torch.randn(64, 3, 224, 224).to(device)

        class Wrapper(nn.Module):
            def __init__(self, m, k, p):
                super().__init__()
                self.m, self.k, self.p = m, k, p
            def forward(self, x):
                return self.m.forward_inference(x, self.k, self.p)

        wrapped_model = Wrapper(model, k_ratio, policy).to(device)

        if HAS_FVCORE:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                flops_counter = FlopCountAnalysis(wrapped_model, dummy_input)
                flops_counter.unsupported_ops_warnings(False)
                gflops = flops_counter.total() / 1e9
        else:
            gflops = 1.1 * k_ratio

        model.eval()
        start_event = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None
        end_event = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None

        with torch.no_grad():
            _ = model.forward_inference(batch_input, k_ratio, policy)
            if torch.cuda.is_available():
                start_event.record()
                for _ in range(50):
                    _ = model.forward_inference(batch_input, k_ratio, policy)
                end_event.record()
                torch.cuda.synchronize()
                total_time_ms = start_event.elapsed_time(end_event)
            else:
                start = time.time()
                for _ in range(50):
                    _ = model.forward_inference(batch_input, k_ratio, policy)
                total_time_ms = (time.time() - start) * 1000

            latency_ms = total_time_ms / 50
            throughput = (64 * 50) / (total_time_ms / 1000)

        return gflops, latency_ms, throughput

    @staticmethod
    def evaluate_accuracy(model, loader, device, k_ratio, policy='learned'):
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for imgs, labels in loader:
                imgs, labels = imgs.to(device), labels.to(device)
                logits = model.forward_inference(imgs, k_ratio=k_ratio, policy=policy)
                correct += (logits.argmax(1) == labels).sum().item()
                total += imgs.size(0)
        return correct / total if total > 0 else 0.0

print("✓ ASEL model and benchmark helpers loaded (for comparison)")

✓ ASEL model and benchmark helpers loaded (for comparison)


In [11]:
# ==============================================================================
# COMPARATIVE PLOTS: ASEL VS AGENTVIT FROM SAVED MODELS
# ==============================================================================
ASEL_CHECKPOINTS = {
    'aid': os.path.join(CONFIG['save_path'], 'aid_finetuned.pth'),
    'eurosat': os.path.join(CONFIG['save_path'], 'eurosat_finetuned.pth'),
    'rsscn7': os.path.join(CONFIG['save_path'], 'rsscn7_finetuned.pth'),
}

def _load_state_dict_flexible(obj):
    """Handle different checkpoint formats for robustness."""
    if isinstance(obj, dict):
        for key in ['model_state', 'state_dict']:
            if key in obj and isinstance(obj[key], dict):
                return obj[key]
        if all(isinstance(v, torch.Tensor) for v in obj.values()):
            return obj
    return obj

def load_asel_model(ds_name, device):
    ckpt_path = ASEL_CHECKPOINTS.get(ds_name)
    if ckpt_path is None or not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"ASEL checkpoint not found for {ds_name}: {ckpt_path}")

    _, test_ds, n_cls = get_dataset(ds_name)
    model = ASEL(num_classes=n_cls, pretrained=False).to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    state_dict = _load_state_dict_flexible(ckpt)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model, test_ds

def load_agentvit_model(ds_name, device):
    ckpt_path = os.path.join(CONFIG['save_path'], f'agentvit_{ds_name}.pth')
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"AgentViT checkpoint not found for {ds_name}: {ckpt_path}")

    ckpt = torch.load(ckpt_path, map_location=device)
    cfg = ckpt.get('config', {})
    image_size = cfg.get('resize_dim', CONFIG['resize_dim'])
    patch_size = cfg.get('patch_size', 16)
    embed_dim = cfg.get('embed_dim', 128)
    depth = cfg.get('depth', 6)
    heads = cfg.get('heads', 8)
    mlp_dim = cfg.get('mlp_dim', 512)

    _, test_ds, n_cls = get_dataset(ds_name)
    model = SimpleAgentViT(
        image_size=image_size,
        patch_size=patch_size,
        num_classes=n_cls,
        dim=embed_dim,
        depth=depth,
        heads=heads,
        mlp_dim=mlp_dim,
    ).to(device)
    state_dict = _load_state_dict_flexible(ckpt.get('model_state', ckpt))
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    num_patches_per_side = image_size // patch_size
    num_patches = num_patches_per_side * num_patches_per_side
    dqn_cfg = cfg if cfg else {}
    dqn_agent = DQNAgent(
        replay_batch_size=dqn_cfg.get('replay_batch_size', 8),
        num_patches=num_patches,
        buffer_capacity=dqn_cfg.get('buffer_capacity', 64),
        gamma=dqn_cfg.get('gamma', 0.95),
        tau=dqn_cfg.get('tau', 0.1),
        update_frequency=dqn_cfg.get('update_frequency', 2),
        learning_rate=dqn_cfg.get('dqn_lr', 0.01),
        env=None,
        device=device,
        num_classes=n_cls,
    )
    if isinstance(ckpt, dict) and 'dqn_state' in ckpt:
        dqn_agent.q_network.load_state_dict(ckpt['dqn_state'], strict=False)
    return model, dqn_agent, test_ds

def compare_asel_agentvit_on_dataset(ds_name):
    device = CONFIG['device']
    print(f"\n{'='*80}\nCOMPARATIVE EVALUATION: {ds_name.upper()} (ASEL vs AgentViT)\n{'='*80}")

    asel_model, asel_test_ds = load_asel_model(ds_name, device)
    agentvit_model, dqn_agent, agent_test_ds = load_agentvit_model(ds_name, device)

    # Use the same test set for both methods
    _, test_ds, _ = get_dataset(ds_name)
    test_loader = DataLoader(
        test_ds,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=CONFIG['num_workers'],
        drop_last=True,
    )

    ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    res = {
        'ratios': ratios,
        'asel_acc': [],
        'agentvit_acc': [],
        'asel_gflops': [],
        'agentvit_gflops': [],
        'asel_thr': [],
        'agentvit_thr': [],
        'asel_lat': [],
        'agentvit_lat': [],
    }

    for r in ratios:
        print(f"  Evaluating keep ratio {r:.1f}...")
        acc_asel = ASELBenchmark.evaluate_accuracy(asel_model, test_loader, device, r, policy='learned')
        gf_asel, lat_asel, thr_asel = ASELBenchmark.measure_metrics(asel_model, device, r, policy='learned')
        acc_agent = AgentViTBenchmark.evaluate_accuracy_at_ratio(
            agentvit_model, test_loader, device, r, use_learned_selection=True, dqn_agent=dqn_agent
        )
        gf_agent, lat_agent, thr_agent = AgentViTBenchmark.measure_metrics(agentvit_model, device, r)

        res['asel_acc'].append(acc_asel)
        res['agentvit_acc'].append(acc_agent)
        res['asel_gflops'].append(gf_asel)
        res['agentvit_gflops'].append(gf_agent)
        res['asel_thr'].append(thr_asel)
        res['agentvit_thr'].append(thr_agent)
        res['asel_lat'].append(lat_asel)
        res['agentvit_lat'].append(lat_agent)

    results_path = CONFIG['results_path']
    os.makedirs(results_path, exist_ok=True)

    # 1. Accuracy vs Keep Ratio
    plt.figure(figsize=(8, 6))
    plt.plot(ratios, [x * 100 for x in res['asel_acc']], 'b-o', lw=2, label='ASEL (learned)')
    plt.plot(ratios, [x * 100 for x in res['agentvit_acc']], 'r-s', lw=2, label='AgentViT (learned)')
    plt.title(f'{ds_name.upper()}: Accuracy vs Keep Ratio (ASEL vs AgentViT)')
    plt.xlabel('Keep Ratio')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.5)
    plt.legend()
    plt.savefig(os.path.join(results_path, f'{ds_name}_asel_vs_agentvit_1_acc_keep.png'))
    plt.close()

    # 2. Accuracy vs Throughput
    plt.figure(figsize=(8, 6))
    plt.plot(res['asel_thr'], [x * 100 for x in res['asel_acc']], 'b-o', lw=2, label='ASEL')
    plt.plot(res['agentvit_thr'], [x * 100 for x in res['agentvit_acc']], 'r-s', lw=2, label='AgentViT')
    plt.title(f'{ds_name.upper()}: Accuracy vs Throughput (ASEL vs AgentViT)')
    plt.xlabel('Throughput (img/s)')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.5)
    plt.legend()
    plt.savefig(os.path.join(results_path, f'{ds_name}_asel_vs_agentvit_2_acc_throughput.png'))
    plt.close()

    # 3. GFLOPs vs Keep Ratio
    plt.figure(figsize=(8, 6))
    plt.plot(ratios, res['asel_gflops'], 'b-o', lw=2, label='ASEL')
    plt.plot(ratios, res['agentvit_gflops'], 'r-s', lw=2, label='AgentViT')
    plt.title(f'{ds_name.upper()}: GFLOPs vs Keep Ratio (ASEL vs AgentViT)')
    plt.xlabel('Keep Ratio')
    plt.ylabel('GFLOPs')
    plt.grid(True, alpha=0.5)
    plt.legend()
    plt.savefig(os.path.join(results_path, f'{ds_name}_asel_vs_agentvit_3_gflops.png'))
    plt.close()

    # 4. Latency at 50% keep ratio
    mid_idx = ratios.index(0.5)  # position of keep ratio 0.5
    asel_lat_50 = res['asel_lat'][mid_idx]
    agent_lat_50 = res['agentvit_lat'][mid_idx]
    plt.figure(figsize=(6, 6))
    plt.bar(['ASEL (50%)', 'AgentViT (50%)'], [asel_lat_50, agent_lat_50], color=['blue', 'red'], width=0.5)
    plt.title(f'{ds_name.upper()}: Batch Latency at 50% Keep (ASEL vs AgentViT)')
    plt.ylabel('Time (ms)')
    plt.text(0, asel_lat_50, f'{asel_lat_50:.1f}ms', ha='center', va='bottom', fontweight='bold')
    plt.text(1, agent_lat_50, f'{agent_lat_50:.1f}ms', ha='center', va='bottom', fontweight='bold')
    plt.savefig(os.path.join(results_path, f'{ds_name}_asel_vs_agentvit_4_latency_50.png'))
    plt.close()

    print(f"✓ Saved comparative ASEL vs AgentViT plots for {ds_name} to {results_path}")
    return res

def run_asel_agentvit_comparisons():
    for ds_name in ['aid', 'eurosat', 'rsscn7']:
        try:
            compare_asel_agentvit_on_dataset(ds_name)
        except Exception as e:
            print(f"Error comparing {ds_name}: {e}")
            import traceback
            traceback.print_exc()

print("✓ Comparative plotting utilities (ASEL vs AgentViT) ready")

✓ Comparative plotting utilities (ASEL vs AgentViT) ready


In [12]:
run_asel_agentvit_comparisons()


COMPARATIVE EVALUATION: AID (ASEL vs AgentViT)


  Evaluating keep ratio 0.1...


The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
m.backbone.blocks.0.attn.attn_drop, m.backbone.blocks.1.attn.attn_drop, m.backbone.blocks.10.attn.attn_drop, m.backbone.blocks.11.attn.attn_drop, m.backbone.blocks.2.attn.attn_drop, m.backbone.blocks.3.attn.attn_drop, m.backbone.blocks.4.attn.attn_drop, m.backbone.blocks.5.attn.attn_drop, m.backbone.blocks.6.attn.attn_drop, m.backbone.blocks.7.attn.attn_drop, m.backbone.blocks.8.attn.attn_drop, m.backbone.blocks.9.attn.attn_drop, m.backbone.head_drop


  Evaluating keep ratio 0.2...




  Evaluating keep ratio 0.3...


The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
m.backbone.blocks.0.attn.attn_drop, m.backbone.blocks.1.attn.attn_drop, m.backbone.blocks.10.attn.attn_drop, m.backbone.blocks.11.attn.attn_drop, m.backbone.blocks.2.attn.attn_drop, m.backbone.blocks.3.attn.attn_drop, m.backbone.blocks.4.attn.attn_drop, m.backbone.blocks.5.attn.attn_drop, m.backbone.blocks.6.attn.attn_drop, m.backbone.blocks.7.attn.attn_drop, m.backbone.blocks.8.attn.attn_drop, m.backbone.blocks.9.attn.attn_drop, m.backbone.head_drop


  Evaluating keep ratio 0.4...


The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
m.backbone.blocks.0.attn.attn_drop, m.backbone.blocks.1.attn.attn_drop, m.backbone.blocks.10.attn.attn_drop, m.backbone.blocks.11.attn.attn_drop, m.backbone.blocks.2.attn.attn_drop, m.backbone.blocks.3.attn.attn_drop, m.backbone.blocks.4.attn.attn_drop, m.backbone.blocks.5.attn.attn_drop, m.backbone.blocks.6.attn.attn_drop, m.backbone.blocks.7.attn.attn_drop, m.backbone.blocks.8.attn.attn_drop, m.backbone.blocks.9.attn.attn_drop, m.backbone.head_drop


  Evaluating keep ratio 0.5...


The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
m.backbone.blocks.0.attn.attn_drop, m.backbone.blocks.1.attn.attn_drop, m.backbone.blocks.10.attn.attn_drop, m.backbone.blocks.11.attn.attn_drop, m.backbone.blocks.2.attn.attn_drop, m.backbone.blocks.3.attn.attn_drop, m.backbone.blocks.4.attn.attn_drop, m.backbone.blocks.5.attn.attn_drop, m.backbone.blocks.6.attn.attn_drop, m.backbone.blocks.7.attn.attn_drop, m.backbone.blocks.8.attn.attn_drop, m.backbone.blocks.9.attn.attn_drop, m.backbone.head_drop


  Evaluating keep ratio 0.6...


The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
m.backbone.blocks.0.attn.attn_drop, m.backbone.blocks.1.attn.attn_drop, m.backbone.blocks.10.attn.attn_drop, m.backbone.blocks.11.attn.attn_drop, m.backbone.blocks.2.attn.attn_drop, m.backbone.blocks.3.attn.attn_drop, m.backbone.blocks.4.attn.attn_drop, m.backbone.blocks.5.attn.attn_drop, m.backbone.blocks.6.attn.attn_drop, m.backbone.blocks.7.attn.attn_drop, m.backbone.blocks.8.attn.attn_drop, m.backbone.blocks.9.attn.attn_drop, m.backbone.head_drop


  Evaluating keep ratio 0.7...


The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
m.backbone.blocks.0.attn.attn_drop, m.backbone.blocks.1.attn.attn_drop, m.backbone.blocks.10.attn.attn_drop, m.backbone.blocks.11.attn.attn_drop, m.backbone.blocks.2.attn.attn_drop, m.backbone.blocks.3.attn.attn_drop, m.backbone.blocks.4.attn.attn_drop, m.backbone.blocks.5.attn.attn_drop, m.backbone.blocks.6.attn.attn_drop, m.backbone.blocks.7.attn.attn_drop, m.backbone.blocks.8.attn.attn_drop, m.backbone.blocks.9.attn.attn_drop, m.backbone.head_drop


  Evaluating keep ratio 0.8...


The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
m.backbone.blocks.0.attn.attn_drop, m.backbone.blocks.1.attn.attn_drop, m.backbone.blocks.10.attn.attn_drop, m.backbone.blocks.11.attn.attn_drop, m.backbone.blocks.2.attn.attn_drop, m.backbone.blocks.3.attn.attn_drop, m.backbone.blocks.4.attn.attn_drop, m.backbone.blocks.5.attn.attn_drop, m.backbone.blocks.6.attn.attn_drop, m.backbone.blocks.7.attn.attn_drop, m.backbone.blocks.8.attn.attn_drop, m.backbone.blocks.9.attn.attn_drop, m.backbone.head_drop


  Evaluating keep ratio 0.9...


The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
m.backbone.blocks.0.attn.attn_drop, m.backbone.blocks.1.attn.attn_drop, m.backbone.blocks.10.attn.attn_drop, m.backbone.blocks.11.attn.attn_drop, m.backbone.blocks.2.attn.attn_drop, m.backbone.blocks.3.attn.attn_drop, m.backbone.blocks.4.attn.attn_drop, m.backbone.blocks.5.attn.attn_drop, m.backbone.blocks.6.attn.attn_drop, m.backbone.blocks.7.attn.attn_drop, m.backbone.blocks.8.attn.attn_drop, m.backbone.blocks.9.attn.attn_drop, m.backbone.head_drop


  Evaluating keep ratio 1.0...


The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
m.backbone.blocks.0.attn.attn_drop, m.backbone.blocks.1.attn.attn_drop, m.backbone.blocks.10.attn.attn_drop, m.backbone.blocks.11.attn.attn_drop, m.backbone.blocks.2.attn.attn_drop, m.backbone.blocks.3.attn.attn_drop, m.backbone.blocks.4.attn.attn_drop, m.backbone.blocks.5.attn.attn_drop, m.backbone.blocks.6.attn.attn_drop, m.backbone.blocks.7.attn.attn_drop, m.backbone.blocks.8.attn.attn_drop, m.backbone.blocks.9.attn.attn_drop, m.backbone.head_drop


✓ Saved comparative ASEL vs AgentViT plots for aid to ./benchmarks_results_agentvit

COMPARATIVE EVALUATION: EUROSAT (ASEL vs AgentViT)
  Evaluating keep ratio 0.1...
  Evaluating keep ratio 0.2...
  Evaluating keep ratio 0.3...
  Evaluating keep ratio 0.4...
  Evaluating keep ratio 0.5...
  Evaluating keep ratio 0.6...
  Evaluating keep ratio 0.7...
  Evaluating keep ratio 0.8...
  Evaluating keep ratio 0.9...
  Evaluating keep ratio 1.0...

✓ Saved comparative ASEL vs AgentViT plots for eurosat to ./benchmarks_results_agentvit

COMPARATIVE EVALUATION: RSSCN7 (ASEL vs AgentViT)
  Evaluating keep ratio 0.1...
  Evaluating keep ratio 0.2...
  Evaluating keep ratio 0.3...
  Evaluating keep ratio 0.4...
  Evaluating keep ratio 0.5...
  Evaluating keep ratio 0.6...
  Evaluating keep ratio 0.7...
  Evaluating keep ratio 0.8...
  Evaluating keep ratio 0.9...
  Evaluating keep ratio 1.0...

✓ Saved comparative ASEL vs AgentViT plots for rsscn7 to ./benchmarks_results_agentvit
