# ASEL Pipeline: Main Process Overview

This notebook implements **ASEL**, a ViT-based model with a learned patch selector for efficient inference on CIFAR‑10 and remote‑sensing datasets (AID, EuroSAT, RSSCN7). The pipeline in this notebook has **two concrete phases**, with caching and benchmarking built in.

---

## Phase 1 – CIFAR‑10 Warmup

1. **Model and Selector**
   - Backbone: `vit_tiny_patch16_224` from `timm` (classification head + transformer blocks).
   - Selector: a small MLP (`patch_selector`) that takes each patch embedding concatenated with a global feature and outputs an importance score per patch.

2. **Training Setup on CIFAR‑10**
   - Dataset: CIFAR‑10 loaded via `get_dataset('cifar10')`, with a single train/test split (no validation).
   - Checkpoint path: `./saved_models/cifar10_warmup.pth`.
   - If this file **exists**, it is loaded and **training on CIFAR‑10 is skipped**.
   - If it does **not** exist:
     - A new `ASEL(num_classes=10)` model is trained for **15 epochs** using `train_epoch`.
     - Dynamic keep‑ratio inside `train_epoch`:
       - For epochs `< 5`: `k_ratio = 1.0` (all patches kept).
       - For epochs `≥ 5`: `k_ratio` is sampled uniformly in `[0.15, 0.75]` for each batch.
     - Loss per batch:
       - `loss_cls = cross_entropy(logits, labels)`
       - `loss_sparsity = 0.02 * scores.mean()`
       - `loss = loss_cls + loss_sparsity`.
     - Optimizer during warmup:
       - `patch_selector` parameters: lr = `CONFIG['learning_rate_selector']` (1e‑4).
       - Backbone parameters (excluding head): lr = 5e‑5.
       - Classification head parameters: lr = 5e‑4.
     - After training, the model weights are saved to `cifar10_warmup.pth`.
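
The keep‑ratio schedule and per‑batch loss described above can be sketched in plain Python. This is a minimal illustration rather than the notebook's `train_epoch`: the helper names `sample_k_ratio`, `num_kept`, and `total_loss` are hypothetical, and the loss helper takes precomputed scalars instead of tensors.

```python
import random

WARMUP_EPOCHS = 5          # epochs during which every patch is kept
K_RANGE = (0.15, 0.75)     # keep-ratio range sampled per batch afterwards
SPARSITY_WEIGHT = 0.02     # weight on the mean selector score


def sample_k_ratio(epoch, rng=random):
    """Return the keep-ratio for one batch of the given epoch."""
    if epoch < WARMUP_EPOCHS:
        return 1.0
    return rng.uniform(*K_RANGE)


def num_kept(num_patches, k_ratio):
    """Number of patches kept; at least one patch always survives."""
    return max(1, int(num_patches * k_ratio))


def total_loss(loss_cls, mean_score):
    """Classification loss plus the sparsity penalty on selector scores."""
    return loss_cls + SPARSITY_WEIGHT * mean_score


if __name__ == "__main__":
    rng = random.Random(0)
    for epoch in (0, 4, 5, 14):
        r = sample_k_ratio(epoch, rng)
        # 196 = 14 x 14 patches for a 224 x 224 input with 16 x 16 patches
        print(epoch, round(r, 3), num_kept(196, r))
```

Because the ratio is redrawn for every batch, the model sees many different patch budgets within a single epoch, which is what lets one checkpoint be evaluated across the whole keep‑ratio grid later.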

3. **CIFAR‑10 Benchmarks**
   - Regardless of whether the model was loaded or trained, `run()` always:
     - Builds a deterministic test DataLoader using `create_loader(cifar_te, batch_size, shuffle=False)`.
     - Calls `run_benchmarks_and_plot(model, loader_te, 'cifar10', device)`.
   - `run_benchmarks_and_plot` evaluates ASEL at keep‑ratios `[0.1, 0.2, ..., 1.0]` for three policies in `forward_inference`:
     - `learned`: top‑k patches by selector scores.
     - `random`: random subset of patches with the same budget.
     - `central`: most central patches in the 14×14 patch grid.
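   - For each ratio, it records accuracy for all three policies, plus GFLOPs, batch latency, and throughput for the `learned` policy only. The latter three are measured by the `Benchmark` class on synthetic 224×224 inputs (a single image for GFLOPs, a batch of 64 for timing).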
   - Plots are saved in `CONFIG['results_path']` (`./benchmarks_results`) with names:
     - `{ds_name}_1_strategies.png` (accuracy vs keep‑ratio for learned/random/central + full ViT).
     - `{ds_name}_2_throughput.png` (accuracy vs throughput).
     - `{ds_name}_3_gflops.png` (GFLOPs vs keep‑ratio).
     - `{ds_name}_4_latency.png` (latency bar: full ViT vs 50% keep).
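
The `central` policy and the per‑ratio patch budgets can be reproduced without PyTorch. The sketch below is an assumption‑laden stand‑in: `central_order` and `budget` are hypothetical names, and ties between equally distant patches are broken by Python's stable sort, so the order within a ring may differ from the `np.argsort` result used in the notebook.

```python
GRID = 14  # vit_tiny_patch16_224 at 224x224 input: 14 x 14 = 196 patches


def central_order(grid=GRID):
    """Patch indices sorted by squared distance to the grid centre."""
    c = (grid - 1) / 2.0

    def dist(i):
        row, col = divmod(i, grid)
        return (row - c) ** 2 + (col - c) ** 2

    return sorted(range(grid * grid), key=dist)


def budget(ratio, num_patches=GRID * GRID):
    """Patches kept at a given keep-ratio (never fewer than one)."""
    return max(1, int(num_patches * ratio))


order = central_order()
ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
budgets = {r: budget(r) for r in ratios}

print(order[:4])   # the four innermost patches
print(budgets)     # patches kept at each benchmarked ratio
```

The `central` policy keeps the first `budget(r)` entries of this ordering for every image, whereas `learned` and `random` choose a different subset per image under the same budget.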

---

## Phase 2 – Transfer to Remote‑Sensing Datasets

4. **Target Datasets and Splits**
   - Targets: `['aid', 'eurosat', 'rsscn7']`.
   - `get_dataset(name)` uses:
     - For `'eurosat'`: torchvision `EuroSAT` with an **80/20 split** using a fixed generator `GEN`; the held‑out 20% (named `val_ds` in the code) is returned and used as the test set.
     - For `'aid'`, `'ucmerced'`, `'rsscn7'`:
       - Image folders read via `CleanImageFolder` (ignores hidden folders).
       - **80/20 train/test split** using `random_split(..., generator=GEN)` for deterministic behavior.
   - `create_loader` builds DataLoaders with:
     - `worker_init_fn=seed_worker` for reproducible workers.
     - `generator=GEN` when `shuffle=True` to keep shuffling deterministic.

5. **Transfer Learning from CIFAR‑10**
   - For each dataset `ds_name` in `['aid', 'eurosat', 'rsscn7']`, `run()` does:
     - Builds `save_name = ./saved_models/{ds_name}_finetuned.pth`.
     - Creates a fresh `ASEL(num_classes=n_cls)` model on the configured device.
     - If `save_name` exists:
       - Loads the checkpoint with `model_transfer.load_state_dict(torch.load(save_name))`.
       - **Skips all further training** for that dataset.
     - Otherwise (no saved model):
       - Loads CIFAR‑10 warmup weights from `cifar10_warmup.pth`.
       - Removes all head weights from the CIFAR‑10 state dict (`'head' not in k`) so the target dataset gets a new classification head.
       - Updates `model_transfer.state_dict()` with the non‑head weights and loads them.
       - Builds a train loader for the target dataset with `create_loader(train_ds, shuffle=True)`.
       - Trains for **25 epochs** using `train_epoch` with the same dynamic k‑ratio schedule and loss structure as in warmup. The epoch counter restarts at 0, so the first 5 fine‑tuning epochs again keep all patches.
       - Optimizer for transfer:
         - `patch_selector`: lr = `CONFIG['learning_rate_selector']` (1e‑4).
         - Backbone (excluding head): lr = 2e‑5.
         - New classification head: lr = 5e‑4.
       - Saves the fine‑tuned model to `./saved_models/{ds_name}_finetuned.pth`.
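
The head‑filtering step can be illustrated with plain dictionaries standing in for state dicts. This is a sketch under stated assumptions: `transfer_non_head` is a hypothetical helper, and the values are toy lists rather than tensors.

```python
def transfer_non_head(pretrained, target):
    """Copy every non-head entry of `pretrained` into a copy of `target`."""
    kept = {k: v for k, v in pretrained.items() if "head" not in k}
    merged = dict(target)   # leave the caller's dict untouched
    merged.update(kept)
    return merged


# Toy stand-ins for state dicts: lists instead of tensors.
cifar_state = {
    "backbone.blocks.0.attn.qkv.weight": [1.0],
    "patch_selector.0.weight": [2.0],
    "backbone.head.weight": ["cifar head"],
}
target_state = {
    "backbone.blocks.0.attn.qkv.weight": [0.0],
    "patch_selector.0.weight": [0.0],
    "backbone.head.weight": ["new head"],
}

merged = transfer_non_head(cifar_state, target_state)
print(merged["backbone.head.weight"])
```

Two consequences of this filter are worth keeping in mind: the `patch_selector` weights are transferred along with the backbone, and the substring test would also drop any parameter whose name merely contains `head`.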

6. **Benchmarks on Target Datasets**
   - For **every** target dataset (loaded or newly trained), `run()` always:
     - Creates a non‑shuffled test loader with `create_loader(test_ds, shuffle=False)`.
     - Calls `run_benchmarks_and_plot(model_transfer, test_loader, ds_name, device)`.
   - The same keep‑ratio grid, selector policies, and metrics as in CIFAR‑10 are used, and the plots are written to `./benchmarks_results` with filenames based on `ds_name`.
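
The latency and throughput figures reduce to simple arithmetic over a timed loop. The sketch below uses wall‑clock timing with a stand‑in workload so that it runs without a model or GPU; `time_batches` and `fake_forward` are hypothetical names. On a GPU, wall‑clock timing alone would be misleading because kernels launch asynchronously, which is why the notebook times with CUDA events and synchronizes.

```python
import time


def time_batches(fn, batch_size, iters=50, warmup=1):
    """Time `fn` over `iters` calls; return (latency_ms, images_per_second)."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    total_ms = (time.perf_counter() - start) * 1000.0
    latency_ms = total_ms / iters                            # time per batch
    throughput = (batch_size * iters) / (total_ms / 1000.0)  # images per second
    return latency_ms, throughput


def fake_forward():
    # Stand-in workload so the sketch runs without a model or GPU.
    sum(i * i for i in range(10_000))


latency_ms, throughput = time_batches(fake_forward, batch_size=64)
print(f"{latency_ms:.3f} ms/batch, {throughput:.0f} img/s")
```

Since both numbers derive from the same total time, throughput is always `batch_size * 1000 / latency_ms`; the two plots present the same measurement in different units.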

---

## Reproducibility Notes

- `set_seed(CONFIG['seed'])` sets seeds for Python, NumPy, and all PyTorch/CUDA RNGs and configures CuDNN for deterministic behavior.
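- `GEN` (a `torch.Generator`) and `seed_worker` are used consistently in `random_split` and DataLoaders so that dataset splits and shuffling are repeatable across runs that follow the same code path. `GEN` is shared and stateful: shuffled training loaders draw from it, so a run that loads a checkpoint and skips training reaches a later `random_split` with a different generator state and can produce a different train/test split than the run that trained the checkpoint.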
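
A shared, stateful generator yields the same split only if the same number of random draws precede it. The sketch below illustrates this with Python's `random.Random` standing in for `torch.Generator`; `split_indices` is a hypothetical helper, and the dedicated‑generator variant is a suggested alternative, not what the notebook does.

```python
import random


def split_indices(n, rng, train_frac=0.8):
    """Shuffle range(n) with `rng` and cut it into train/test index lists."""
    idx = list(range(n))
    rng.shuffle(idx)
    cut = int(train_frac * n)
    return idx[:cut], idx[cut:]


SEED = 42

# Shared RNG: the split depends on what consumed the RNG beforehand.
shared = random.Random(SEED)
_, test_fresh = split_indices(100, shared)

shared = random.Random(SEED)
shared.random()                      # e.g. a training shuffle ran first
_, test_after_use = split_indices(100, shared)

# Dedicated RNG per split: identical regardless of earlier activity.
_, test_a = split_indices(100, random.Random(SEED))
_, test_b = split_indices(100, random.Random(SEED))

print(test_fresh == test_after_use, test_a == test_b)
```

Creating a fresh, seeded generator for each split keeps the held‑out set fixed whether or not earlier training steps were skipped, which matters when a cached fine‑tuned checkpoint is evaluated in a later run.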


In [None]:
import os
import time
import random
import warnings
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm import tqdm
import timm

# ------------------------------------------------------------------------------
# Try importing FLOPs counting libraries (for Benchmarking only)
# ------------------------------------------------------------------------------
try:
    from fvcore.nn import FlopCountAnalysis
    HAS_FVCORE = True
except ImportError:
    HAS_FVCORE = False
    print("Warning: 'fvcore' not found. GFLOPs will be estimated theoretically.")

# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
CONFIG = {
    'seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_workers': 4,
    'batch_size': 64,
    'resize_dim': 224,
    'learning_rate_selector': 1e-4,
    'save_path': './saved_models',
    'results_path': './benchmarks_results'
}

DATASET_PATHS = {
    'aid': 'AID-data',
    'ucmerced': 'UCMerced_LandUse/Images',
    'rsscn7': './RSSCN7',
}

os.makedirs(CONFIG['save_path'], exist_ok=True)
os.makedirs(CONFIG['results_path'], exist_ok=True)

# ------------------------------------------------------------------------------
# REPRODUCIBILITY SETUP
# ------------------------------------------------------------------------------
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Ensure deterministic behavior in CuDNN (Trade-off: may be slower)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
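    # Note: only affects Python processes started after this point; the current
    # interpreter's hash randomization was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)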

set_seed(CONFIG['seed'])

# Generator for reproducible random_split
GEN = torch.Generator()
GEN.manual_seed(CONFIG['seed'])

# Worker init for DataLoaders
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# ==============================================================================
# 2. MODEL ARCHITECTURE
# ==============================================================================
class ASEL(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        # Load standard ViT
        self.backbone = timm.create_model('vit_tiny_patch16_224', pretrained=pretrained, num_classes=num_classes)
        self.embed_dim = self.backbone.embed_dim
        
        # The "Selector" Network
        self.patch_selector = nn.Sequential(
            nn.Linear(self.embed_dim * 2, 96), 
            nn.LayerNorm(96),
            nn.ReLU(),
            nn.Linear(96, 1),
            nn.Sigmoid() 
        )
        
        # (Added only for Benchmarking Plots: Pre-computed central indices)
        H, W = 14, 14
        center = (H - 1) / 2.0
        y, x = np.ogrid[:H, :W]
        dist = (x - center)**2 + (y - center)**2
        self.central_indices = torch.from_numpy(np.argsort(dist.flatten())).long()

    def _get_patch_embeddings(self, x):
        x = self.backbone.patch_embed(x)
        x = x + self.backbone.pos_embed[:, 1:]
        return x

    def _process_transformer(self, x_patches):
        B = x_patches.shape[0]
        cls_token = self.backbone.cls_token.expand(B, -1, -1) + self.backbone.pos_embed[:, :1]
        x = torch.cat((cls_token, x_patches), dim=1)
        x = self.backbone.pos_drop(x)
        x = self.backbone.blocks(x)
        x = self.backbone.norm(x)
        return self.backbone.head(x[:, 0])

    def _compute_importance_scores(self, x_patches):
        global_feat = x_patches.mean(dim=1, keepdim=True).expand(-1, x_patches.shape[1], -1)
        selector_input = torch.cat([x_patches, global_feat], dim=-1)
        return self.patch_selector(selector_input).squeeze(-1)

    # --------------------------------------------------------------------------
    # MODE A: TRAINING
    # --------------------------------------------------------------------------
    def forward_train(self, x_images, k_ratio):
        x_patches = self._get_patch_embeddings(x_images)
        scores = self._compute_importance_scores(x_patches)
        B, N, D = x_patches.shape
        k = int(N * k_ratio)
        if k < 1: k = 1

        _, topk_idx = torch.topk(scores, k, dim=1)
        mask_hard = torch.zeros_like(scores)
        mask_hard.scatter_(1, topk_idx, 1.0)
        
        # Straight-Through Estimator
        mask = mask_hard - scores.detach() + scores
        
        x_masked = x_patches * mask.unsqueeze(-1)
        logits = self._process_transformer(x_masked)
        return logits, scores

    # --------------------------------------------------------------------------
    # MODE B: INFERENCE (Adapted for Benchmarking Policies)
    # --------------------------------------------------------------------------
    def forward_inference(self, x_images, k_ratio, policy='learned'):
        x_patches = self._get_patch_embeddings(x_images)
        B, N, D = x_patches.shape
        k = int(N * k_ratio)
        if k < 1: k = 1

        # Select indices based on Policy
        if policy == 'learned':
            scores = self._compute_importance_scores(x_patches)
            _, topk_idx = torch.topk(scores, k, dim=1)
        elif policy == 'random':
            topk_idx = torch.stack([torch.randperm(N)[:k] for _ in range(B)]).to(x_patches.device)
        elif policy == 'central':
            indices = self.central_indices[:k].to(x_patches.device)
            topk_idx = indices.unsqueeze(0).expand(B, -1)
        else:
            raise ValueError("Unknown Policy")

        topk_idx_expanded = topk_idx.unsqueeze(-1).expand(-1, -1, D)
        x_kept = torch.gather(x_patches, 1, topk_idx_expanded)
        
        logits = self._process_transformer(x_kept)
        return logits

# ==============================================================================
# 3. TRAINING ENGINE
# ==============================================================================
def train_epoch(model, loader, optimizer, epoch, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    if epoch < 5:
        min_k, max_k = 1.0, 1.0 
    else:
        min_k, max_k = 0.15, 0.75

    pbar = tqdm(loader, desc=f"Train Ep {epoch}", leave=True)
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        
        current_k = random.uniform(min_k, max_k)
        
        logits, scores = model.forward_train(imgs, k_ratio=current_k)
        
        loss_cls = F.cross_entropy(logits, labels)
        loss_sparsity = 0.02 * scores.mean()
        loss = loss_cls + loss_sparsity
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        correct += (logits.argmax(1) == labels).sum().item()
        total += labels.size(0)
        pbar.set_postfix({"Loss": f"{loss.item():.4f}", "Acc": f"{correct/total:.2%}"})

# ==============================================================================
# 4. BENCHMARKING SUITE
# ==============================================================================
class Benchmark:
    @staticmethod
    def measure_metrics(model, device, k_ratio, policy='learned'):
        dummy_input = torch.randn(1, 3, 224, 224).to(device)
        batch_input = torch.randn(64, 3, 224, 224).to(device) 
        
        class Wrapper(nn.Module):
            def __init__(self, m, k, p):
                super().__init__()
                self.m = m
                self.k = k
                self.p = p

            def forward(self, x):
                return self.m.forward_inference(x, self.k, self.p)
        
        wrapped_model = Wrapper(model, k_ratio, policy)
        
        if HAS_FVCORE:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                flops_counter = FlopCountAnalysis(wrapped_model, dummy_input)
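                flops_counter.unsupported_ops_warnings(False)
                # Also silence the repeated "submodules ... were never called" notice
                flops_counter.uncalled_modules_warnings(False)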
                gflops = flops_counter.total() / 1e9
        else:
            gflops = 1.1 * k_ratio 
            
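        model.eval()
        n_iters = 50
        use_cuda = torch.device(device).type == 'cuda'

        with torch.no_grad():
            _ = model.forward_inference(batch_input, k_ratio, policy)  # Warmup
            if use_cuda:
                # CUDA kernels run asynchronously, so time with CUDA events
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
                for _ in range(n_iters):
                    _ = model.forward_inference(batch_input, k_ratio, policy)
                end_event.record()
                torch.cuda.synchronize()
                total_time_ms = start_event.elapsed_time(end_event)
            else:
                # CPU fallback: CUDA events are unavailable, use wall-clock time
                t0 = time.perf_counter()
                for _ in range(n_iters):
                    _ = model.forward_inference(batch_input, k_ratio, policy)
                total_time_ms = (time.perf_counter() - t0) * 1000.0

        latency_ms = total_time_ms / n_iters
        throughput = (batch_input.shape[0] * n_iters) / (total_time_ms / 1000)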
        
        return gflops, latency_ms, throughput

    @staticmethod
    def evaluate_accuracy(model, loader, device, k_ratio, policy):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for imgs, labels in loader:
                imgs, labels = imgs.to(device), labels.to(device)
                logits = model.forward_inference(imgs, k_ratio=k_ratio, policy=policy)
                correct += (logits.argmax(1) == labels).sum().item()
                total += imgs.size(0)
        return correct / total

def run_benchmarks_and_plot(model, test_loader, ds_name, device):
    print(f"\nRunning Benchmarks for {ds_name}...")
    ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    
    res = {
        'ratios': ratios,
        'acc_learned': [], 'acc_random': [], 'acc_central': [],
        'gflops': [], 'thr': [], 'lat': []
    }
    
    full_acc = Benchmark.evaluate_accuracy(model, test_loader, device, 1.0, 'learned')
    full_gflops, full_lat, full_thr = Benchmark.measure_metrics(model, device, 1.0)
    
    for r in ratios:
        acc_l = Benchmark.evaluate_accuracy(model, test_loader, device, r, 'learned')
        acc_r = Benchmark.evaluate_accuracy(model, test_loader, device, r, 'random')
        acc_c = Benchmark.evaluate_accuracy(model, test_loader, device, r, 'central')
        gf, lat, thr = Benchmark.measure_metrics(model, device, r, 'learned')
        
        res['acc_learned'].append(acc_l)
        res['acc_random'].append(acc_r)
        res['acc_central'].append(acc_c)
        res['gflops'].append(gf)
        res['thr'].append(thr)
        res['lat'].append(lat)
        
        print(f" Ratio {r:.1f} | L-Acc: {acc_l:.1%} | R-Acc: {acc_r:.1%} | FPS: {thr:.0f}")

    # PLOTS
    plt.figure(figsize=(8, 6))
    plt.plot(ratios, [x*100 for x in res['acc_learned']], 'r-o', lw=2, label='Ours (Learned)')
    plt.plot(ratios, [x*100 for x in res['acc_central']], 'b--s', alpha=0.7, label='Central')
    plt.plot(ratios, [x*100 for x in res['acc_random']], 'k--x', alpha=0.5, label='Random')
    plt.scatter([1.0], [full_acc*100], c='k', marker='*', s=200, zorder=10, label='Full ViT')
    plt.title(f'{ds_name}: Strategy Comparison')
    plt.xlabel('Keep Ratio'); plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.5); plt.legend()
    plt.savefig(f"{CONFIG['results_path']}/{ds_name}_1_strategies.png"); plt.close()

    plt.figure(figsize=(8, 6))
    plt.plot(res['thr'], [x*100 for x in res['acc_learned']], 'g-o', lw=2, label='Ours')
    plt.scatter([full_thr], [full_acc*100], c='k', marker='*', s=200, label='Full ViT')
    plt.title(f'{ds_name}: Accuracy vs Throughput')
    plt.xlabel('Throughput (img/s)'); plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.5); plt.legend()
    plt.savefig(f"{CONFIG['results_path']}/{ds_name}_2_throughput.png"); plt.close()

    plt.figure(figsize=(8, 6))
    plt.plot(ratios, res['gflops'], 'm-o', lw=2, label='Ours')
    plt.axhline(y=full_gflops, c='k', ls='--', label='Full ViT')
    plt.title(f'{ds_name}: GFLOPs Reduction')
    plt.xlabel('Keep Ratio'); plt.ylabel('GFLOPs')
    plt.grid(True, alpha=0.5); plt.legend()
    plt.savefig(f"{CONFIG['results_path']}/{ds_name}_3_gflops.png"); plt.close()

    lat_50 = res['lat'][ratios.index(0.5)]  # Latency at keep-ratio 0.5
    plt.figure(figsize=(6, 6))
    plt.bar(['Full ViT', 'Ours (50%)'], [full_lat, lat_50], color=['gray', 'green'], width=0.5)
    plt.title(f'{ds_name}: Batch Latency')
    plt.ylabel('Time (ms)')
    plt.text(0, full_lat, f"{full_lat:.1f}ms", ha='center', va='bottom', fontweight='bold')
    plt.text(1, lat_50, f"{lat_50:.1f}ms", ha='center', va='bottom', fontweight='bold')
    plt.savefig(f"{CONFIG['results_path']}/{ds_name}_4_latency.png"); plt.close()

# ==============================================================================
# 5. DATA LOADING (Robust & Deterministic)
# ==============================================================================
class CleanImageFolder(datasets.ImageFolder):
    def find_classes(self, directory):
        classes = sorted(entry.name for entry in os.scandir(directory) 
                         if entry.is_dir() and not entry.name.startswith('.'))
        if not classes:
            raise FileNotFoundError(f"No classes in {directory}")
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

def create_loader(dataset, batch_size, shuffle):
    # Enforce reproducibility inside the DataLoader
    return DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=shuffle, 
        num_workers=CONFIG['num_workers'], 
        worker_init_fn=seed_worker,  # Important for worker seeding
        generator=GEN if shuffle else None, # Important for shuffle seeding
        drop_last=False
    )

def get_dataset(name):
    tf = transforms.Compose([
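        # Integer interpolation codes are deprecated; 3 corresponds to bicubic
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),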
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    if name == 'cifar10':
        ds = datasets.CIFAR10(root='./data', train=True, download=True, transform=tf)
        test_ds = datasets.CIFAR10(root='./data', train=False, download=True, transform=tf)
        return ds, test_ds, 10
    elif name == 'eurosat':
        ds = datasets.EuroSAT(root='./data', download=True, transform=tf)
        train_len = int(0.8 * len(ds))
        # Use GEN for deterministic split
        train_ds, val_ds = random_split(ds, [train_len, len(ds)-train_len], generator=GEN)
        return train_ds, val_ds, 10
    else:
        path = DATASET_PATHS.get(name)
        if not path or not os.path.exists(path):
            raise FileNotFoundError(f"Path not found: {path}")
        ds = CleanImageFolder(root=path, transform=tf)
        train_len = int(0.8 * len(ds))
        # Use GEN for deterministic split
        train_ds, test_ds = random_split(ds, [train_len, len(ds)-train_len], generator=GEN)
        return train_ds, test_ds, len(ds.classes)

# ==============================================================================
# 6. MAIN PIPELINE (With Caching Logic)
# ==============================================================================
def run():
    # ------------------------------------------------------------------
    # PHASE 1: WARMUP ON CIFAR-10
    # ------------------------------------------------------------------
    print(f"\n{'='*40}\nPHASE 1: WARMUP ON CIFAR-10\n{'='*40}")
    cifar_path = f"{CONFIG['save_path']}/cifar10_warmup.pth"
    cifar_tr, cifar_te, cifar_n = get_dataset('cifar10')
    
    # Init Model
    model = ASEL(num_classes=cifar_n).to(CONFIG['device'])
    
    # CHECK: Do we have a saved model?
    if os.path.exists(cifar_path):
        print(f">> Found existing checkpoint: {cifar_path}")
        print(">> Loading model and skipping training...")
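        # map_location lets a checkpoint saved on GPU load on a CPU-only machine
        model.load_state_dict(torch.load(cifar_path, map_location=CONFIG['device']))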
    else:
        print(">> No checkpoint found. Starting training...")
        loader_tr = create_loader(cifar_tr, CONFIG['batch_size'], shuffle=True)
        
        # Optimizer
        head_params = list(map(id, model.backbone.head.parameters()))
        backbone_params = filter(lambda p: id(p) not in head_params, model.backbone.parameters())
        optimizer = optim.Adam([
            {'params': model.patch_selector.parameters(), 'lr': CONFIG['learning_rate_selector']},
            {'params': backbone_params, 'lr': 5e-5},
            {'params': model.backbone.head.parameters(), 'lr': 5e-4}
        ])
        
        for ep in range(15):
            train_epoch(model, loader_tr, optimizer, ep, CONFIG['device'])
        
        torch.save(model.state_dict(), cifar_path)
        print(">> Training finished. Model saved.")

    # Always run benchmarks (even if loaded)
    loader_te = create_loader(cifar_te, CONFIG['batch_size'], shuffle=False)
    run_benchmarks_and_plot(model, loader_te, 'cifar10', CONFIG['device'])
    
    # ------------------------------------------------------------------
    # PHASE 2: TRANSFER TO TARGET DATASETS
    # ------------------------------------------------------------------
    targets = ['aid', 'eurosat', 'rsscn7']
    
    for ds_name in targets:
        print(f"\n{'='*40}\nPHASE 2: TRANSFER TO {ds_name.upper()}\n{'='*40}")
        save_name = f"{CONFIG['save_path']}/{ds_name}_finetuned.pth"

        try:
            train_ds, test_ds, n_cls = get_dataset(ds_name)
        except Exception as e:
            print(f"Skipping {ds_name}: {e}")
            continue

        test_loader = create_loader(test_ds, CONFIG['batch_size'], shuffle=False)
        model_transfer = ASEL(num_classes=n_cls).to(CONFIG['device'])

        # CHECK: Do we have a saved fine-tuned model?
        if os.path.exists(save_name):
            print(f">> Found existing checkpoint: {save_name}")
            print(">> Loading model and skipping training...")
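            # map_location lets a checkpoint saved on GPU load on a CPU-only machine
            model_transfer.load_state_dict(torch.load(save_name, map_location=CONFIG['device']))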
        
        else:
            print(f">> No checkpoint found for {ds_name}. Training...")
            # Load weights from CIFAR-10 Warmup (Excluding Head)
            if not os.path.exists(cifar_path):
                raise RuntimeError("CIFAR-10 model missing! Cannot transfer learn.")
                
            pretrained_dict = torch.load(cifar_path, map_location=CONFIG['device'])
            model_dict = model_transfer.state_dict()
            # Filter head weights
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if 'head' not in k}
            model_dict.update(pretrained_dict)
            model_transfer.load_state_dict(model_dict)
            print(">> Loaded CIFAR-10 weights (Head re-initialized)")

            train_loader = create_loader(train_ds, CONFIG['batch_size'], shuffle=True)

            # Optimizer
            head_params = list(map(id, model_transfer.backbone.head.parameters()))
            backbone_params = filter(lambda p: id(p) not in head_params, model_transfer.backbone.parameters())
            optimizer = optim.Adam([
                {'params': model_transfer.patch_selector.parameters(), 'lr': CONFIG['learning_rate_selector']},
                {'params': backbone_params, 'lr': 2e-5},
                {'params': model_transfer.backbone.head.parameters(), 'lr': 5e-4}
            ])
            
            for ep in range(25):
                train_epoch(model_transfer, train_loader, optimizer, ep, CONFIG['device'])
            
            torch.save(model_transfer.state_dict(), save_name)
            print(f">> Training finished for {ds_name}. Model saved.")

        # Always run benchmarks
        run_benchmarks_and_plot(model_transfer, test_loader, ds_name, CONFIG['device'])
        print(f"Finished {ds_name}. Plots saved.")

if __name__ == "__main__":
    run()


PHASE 1: WARMUP ON CIFAR-10




>> No checkpoint found. Starting training...


Train Ep 0: 100%|██████████| 782/782 [00:25<00:00, 30.61it/s, Loss=0.2102, Acc=94.41%]
Train Ep 1: 100%|██████████| 782/782 [00:25<00:00, 30.53it/s, Loss=0.0308, Acc=98.19%]
Train Ep 2: 100%|██████████| 782/782 [00:25<00:00, 30.35it/s, Loss=0.0185, Acc=98.86%]
Train Ep 3: 100%|██████████| 782/782 [00:25<00:00, 30.50it/s, Loss=0.0010, Acc=98.93%]
Train Ep 4: 100%|██████████| 782/782 [00:25<00:00, 30.42it/s, Loss=0.0007, Acc=99.17%]
Train Ep 5: 100%|██████████| 782/782 [00:25<00:00, 30.48it/s, Loss=0.0314, Acc=89.65%]
Train Ep 6: 100%|██████████| 782/782 [00:25<00:00, 30.68it/s, Loss=0.0053, Acc=94.28%]
Train Ep 7: 100%|██████████| 782/782 [00:25<00:00, 30.43it/s, Loss=0.0008, Acc=94.87%]
Train Ep 8: 100%|██████████| 782/782 [00:25<00:00, 30.60it/s, Loss=0.2544, Acc=96.18%]
Train Ep 9: 100%|██████████| 782/782 [00:25<00:00, 30.24it/s, Loss=0.0174, Acc=96.22%]
Train Ep 10: 100%|██████████| 782/782 [00:25<00:00, 30.74it/s, Loss=0.9254, Acc=96.97%]
Train Ep 11: 100%|██████████| 782/782 [00:

>> Training finished. Model saved.

Running Benchmarks for cifar10...


The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
m.backbone.blocks.0.attn.attn_drop, m.backbone.blocks.1.attn.attn_drop, m.backbone.blocks.10.attn.attn_drop, m.backbone.blocks.11.attn.attn_drop, m.backbone.blocks.2.attn.attn_drop, m.backbone.blocks.3.attn.attn_drop, m.backbone.blocks.4.attn.attn_drop, m.backbone.blocks.5.attn.attn_drop, m.backbone.blocks.6.attn.attn_drop, m.backbone.blocks.7.attn.attn_drop, m.backbone.blocks.8.attn.attn_drop, m.backbone.blocks.9.attn.attn_drop, m.backbone.head_drop

 Ratio 0.1 | L-Acc: 72.9% | R-Acc: 65.9% | FPS: 37605




 Ratio 0.2 | L-Acc: 86.4% | R-Acc: 82.2% | FPS: 30053




 Ratio 0.3 | L-Acc: 90.9% | R-Acc: 89.0% | FPS: 26506




 Ratio 0.4 | L-Acc: 93.2% | R-Acc: 92.2% | FPS: 21197




 Ratio 0.5 | L-Acc: 94.7% | R-Acc: 94.0% | FPS: 19906




 Ratio 0.6 | L-Acc: 95.8% | R-Acc: 95.3% | FPS: 16715




 Ratio 0.7 | L-Acc: 96.4% | R-Acc: 95.9% | FPS: 13930


 Ratio 0.8 | L-Acc: 96.7% | R-Acc: 96.5% | FPS: 13356


 Ratio 0.9 | L-Acc: 97.0% | R-Acc: 96.7% | FPS: 10854


 Ratio 1.0 | L-Acc: 97.1% | R-Acc: 97.1% | FPS: 9944

PHASE 2: TRANSFER TO AID
>> No checkpoint found for aid. Training...
>> Loaded CIFAR-10 weights (Head re-initialized)


Train Ep 0: 100%|██████████| 125/125 [00:08<00:00, 14.57it/s, Loss=0.3077, Acc=69.95%]
Train Ep 1: 100%|██████████| 125/125 [00:08<00:00, 14.52it/s, Loss=0.2954, Acc=93.64%]
Train Ep 2: 100%|██████████| 125/125 [00:08<00:00, 14.41it/s, Loss=0.0865, Acc=97.81%]
Train Ep 3: 100%|██████████| 125/125 [00:08<00:00, 14.00it/s, Loss=0.0258, Acc=99.39%]
Train Ep 4: 100%|██████████| 125/125 [00:08<00:00, 14.27it/s, Loss=0.0107, Acc=99.91%]
Train Ep 5: 100%|██████████| 125/125 [00:08<00:00, 14.62it/s, Loss=0.2232, Acc=89.45%]
Train Ep 6: 100%|██████████| 125/125 [00:08<00:00, 14.47it/s, Loss=0.4249, Acc=93.96%]
Train Ep 7: 100%|██████████| 125/125 [00:08<00:00, 14.50it/s, Loss=0.6277, Acc=95.05%]
Train Ep 8: 100%|██████████| 125/125 [00:08<00:00, 14.06it/s, Loss=0.0723, Acc=97.00%]
Train Ep 9: 100%|██████████| 125/125 [00:08<00:00, 14.24it/s, Loss=0.0209, Acc=97.59%]
Train Ep 10: 100%|██████████| 125/125 [00:08<00:00, 14.52it/s, Loss=0.0271, Acc=97.72%]
Train Ep 11: 100%|██████████| 125/125 [00:

>> Training finished for aid. Model saved.

Running Benchmarks for aid...


 Ratio 0.1 | L-Acc: 64.3% | R-Acc: 55.9% | FPS: 37538


 Ratio 0.2 | L-Acc: 79.7% | R-Acc: 72.7% | FPS: 30023


 Ratio 0.3 | L-Acc: 86.8% | R-Acc: 80.1% | FPS: 26697


 Ratio 0.4 | L-Acc: 89.8% | R-Acc: 85.3% | FPS: 21242


 Ratio 0.5 | L-Acc: 91.8% | R-Acc: 87.6% | FPS: 19943


 Ratio 0.6 | L-Acc: 92.8% | R-Acc: 89.5% | FPS: 16771


 Ratio 0.7 | L-Acc: 93.2% | R-Acc: 89.8% | FPS: 14007


 Ratio 0.8 | L-Acc: 93.5% | R-Acc: 91.4% | FPS: 13457


 Ratio 0.9 | L-Acc: 93.1% | R-Acc: 91.7% | FPS: 10882


 Ratio 1.0 | L-Acc: 92.3% | R-Acc: 92.3% | FPS: 9994
Finished aid. Plots saved.

PHASE 2: TRANSFER TO EUROSAT
>> No checkpoint found for eurosat. Training...
>> Loaded CIFAR-10 weights (Head re-initialized)


Train Ep 0: 100%|██████████| 338/338 [00:11<00:00, 30.67it/s, Loss=0.0413, Acc=92.54%]
Train Ep 1: 100%|██████████| 338/338 [00:11<00:00, 30.56it/s, Loss=0.0049, Acc=98.42%]
Train Ep 2: 100%|██████████| 338/338 [00:11<00:00, 30.50it/s, Loss=0.0060, Acc=99.35%]
Train Ep 3: 100%|██████████| 338/338 [00:11<00:00, 30.46it/s, Loss=0.0869, Acc=99.59%]
Train Ep 4: 100%|██████████| 338/338 [00:11<00:00, 30.53it/s, Loss=0.0016, Acc=99.90%]
Train Ep 5: 100%|██████████| 338/338 [00:11<00:00, 30.21it/s, Loss=0.2111, Acc=96.00%]
Train Ep 6: 100%|██████████| 338/338 [00:11<00:00, 30.71it/s, Loss=0.0309, Acc=97.44%]
Train Ep 7: 100%|██████████| 338/338 [00:11<00:00, 30.59it/s, Loss=0.0007, Acc=98.00%]
Train Ep 8: 100%|██████████| 338/338 [00:11<00:00, 30.47it/s, Loss=0.0020, Acc=98.44%]
Train Ep 9: 100%|██████████| 338/338 [00:11<00:00, 30.37it/s, Loss=0.0061, Acc=98.88%]
Train Ep 10: 100%|██████████| 338/338 [00:11<00:00, 30.57it/s, Loss=0.0345, Acc=98.98%]
Train Ep 11: 100%|██████████| 338/338 [00:

>> Training finished for eurosat. Model saved.

Running Benchmarks for eurosat...


 Ratio 0.1 | L-Acc: 88.8% | R-Acc: 84.1% | FPS: 37616


 Ratio 0.2 | L-Acc: 94.6% | R-Acc: 93.3% | FPS: 29913


 Ratio 0.3 | L-Acc: 96.4% | R-Acc: 95.4% | FPS: 26566


 Ratio 0.4 | L-Acc: 97.4% | R-Acc: 96.8% | FPS: 21210


 Ratio 0.5 | L-Acc: 97.9% | R-Acc: 97.4% | FPS: 19902


 Ratio 0.6 | L-Acc: 98.1% | R-Acc: 97.6% | FPS: 16721


 Ratio 0.7 | L-Acc: 98.4% | R-Acc: 98.1% | FPS: 13943


 Ratio 0.8 | L-Acc: 98.4% | R-Acc: 98.2% | FPS: 13346


 Ratio 0.9 | L-Acc: 98.3% | R-Acc: 98.3% | FPS: 10842


 Ratio 1.0 | L-Acc: 98.4% | R-Acc: 98.4% | FPS: 9944
Finished eurosat. Plots saved.

PHASE 2: TRANSFER TO RSSCN7
>> No checkpoint found for rsscn7. Training...
>> Loaded CIFAR-10 weights (Head re-initialized)


Train Ep 0: 100%|██████████| 35/35 [00:01<00:00, 25.84it/s, Loss=0.5101, Acc=60.27%]
Train Ep 1: 100%|██████████| 35/35 [00:01<00:00, 25.67it/s, Loss=0.2476, Acc=89.06%]
Train Ep 2: 100%|██████████| 35/35 [00:01<00:00, 25.93it/s, Loss=0.2313, Acc=96.03%]
Train Ep 3: 100%|██████████| 35/35 [00:01<00:00, 25.67it/s, Loss=0.0242, Acc=99.11%]
Train Ep 4: 100%|██████████| 35/35 [00:01<00:00, 26.09it/s, Loss=0.0203, Acc=99.96%] 
Train Ep 5: 100%|██████████| 35/35 [00:01<00:00, 26.04it/s, Loss=0.2212, Acc=90.40%]
Train Ep 6: 100%|██████████| 35/35 [00:01<00:00, 26.04it/s, Loss=0.0867, Acc=95.18%]
Train Ep 7: 100%|██████████| 35/35 [00:01<00:00, 26.16it/s, Loss=0.2052, Acc=95.71%]
Train Ep 8: 100%|██████████| 35/35 [00:01<00:00, 26.12it/s, Loss=0.0944, Acc=97.54%]
Train Ep 9: 100%|██████████| 35/35 [00:01<00:00, 26.33it/s, Loss=0.0607, Acc=97.95%]
Train Ep 10: 100%|██████████| 35/35 [00:01<00:00, 26.22it/s, Loss=0.0136, Acc=98.21%]
Train Ep 11: 100%|██████████| 35/35 [00:01<00:00, 25.99it/s, Lo

>> Training finished for rsscn7. Model saved.

Running Benchmarks for rsscn7...



 Ratio 0.1 | L-Acc: 78.0% | R-Acc: 74.6% | FPS: 37628


 Ratio 0.2 | L-Acc: 85.4% | R-Acc: 84.6% | FPS: 30009


 Ratio 0.3 | L-Acc: 90.0% | R-Acc: 87.3% | FPS: 26588


 Ratio 0.4 | L-Acc: 91.6% | R-Acc: 89.8% | FPS: 21183


 Ratio 0.5 | L-Acc: 93.4% | R-Acc: 90.7% | FPS: 19936


 Ratio 0.6 | L-Acc: 93.4% | R-Acc: 92.3% | FPS: 16732


 Ratio 0.7 | L-Acc: 93.2% | R-Acc: 93.8% | FPS: 13960


 Ratio 0.8 | L-Acc: 93.0% | R-Acc: 94.1% | FPS: 13366


 Ratio 0.9 | L-Acc: 93.6% | R-Acc: 93.4% | FPS: 10859


 Ratio 1.0 | L-Acc: 93.8% | R-Acc: 93.8% | FPS: 9947
Finished rsscn7. Plots saved.
