# Beyond Visible Spectrum: AI for Agriculture 2026 — Task 2
## Full Pipeline: SSL Pretraining → Fine-Tuning → Ensemble

**Strategy:**
- Stage 1: MAE pretraining on unlabeled Sentinel-2 (12-band) data
- Stage 2: Fine-tune 3 models (ViT-Base, Swin-Tiny, ConvNeXt-Small)
- Stage 3: Weighted soft-voting ensemble + TTA

**Hardware:** Optimized for Kaggle T4 x2 (16GB VRAM each)

**Target:** 0.90+ accuracy on crop disease classification (Aphid, Rust, RPH, Blast)

In [39]:
# Install dependencies
!pip install -q timm einops rasterio scikit-learn torchmetrics

In [40]:
import os
import glob
import random
import numpy as np
import pandas as pd
import json
from pathlib import Path
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import rasterio
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as T
import timm
from einops import rearrange
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, classification_report
from torchmetrics import Accuracy

# Reproducibility
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

seed_everything(42)

# Device setup — use both T4s with DataParallel
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print(f'GPU count: {torch.cuda.device_count()}')
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f'  GPU {i}: {torch.cuda.get_device_name(i)}')

Using device: cuda
GPU count: 2
  GPU 0: Tesla T4
  GPU 1: Tesla T4


## 1. Configuration

In [41]:
class CFG:
    # ─── Paths ────────────────────────────────────────────────────
    # UPDATE THESE to match your Kaggle dataset paths
    S2A_ROOT       = '/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026p2/s2a'  # unlabeled SSL data
    LABELED_ROOT   = '/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026p2/ICPR02'  # labeled data
    OUTPUT_DIR     = '/kaggle/working/'

    # ─── Data ─────────────────────────────────────────────────────
    NUM_CLASSES    = 4          # Aphid, Rust, RPH, Blast
    IMG_SIZE       = 64         # patch size (adjust if your tiles are larger)
    IN_CHANNELS    = 16         # 12 Sentinel-2 bands + 4 vegetation indices
    BANDS          = ['B1','B2','B3','B4','B5','B6','B7','B8','B8A','B9','B11','B12']
    CLASS_NAMES    = ['Aphid', 'Rust', 'RPH', 'Blast']
    CLASS2IDX      = {c: i for i, c in enumerate(CLASS_NAMES)}

    # ─── SSL Pretraining ──────────────────────────────────────────
    SSL_EPOCHS     = 50         # increase to 100-200 if time permits
    SSL_BATCH_SIZE = 64         # per GPU; effective batch = 128 with 2x T4
    SSL_LR         = 1.5e-4
    SSL_MASK_RATIO = 0.75       # MAE masking ratio
    SSL_PATCH_SIZE = 8          # ViT patch size (8x8 for 64x64 images)

    # ─── Fine-tuning ──────────────────────────────────────────────
    FT_EPOCHS      = 40
    FT_BATCH_SIZE  = 32
    FT_LR          = 5e-5
    LR_DECAY       = 0.75       # layer-wise LR decay factor
    WEIGHT_DECAY   = 0.05
    WARMUP_EPOCHS  = 5
    LABEL_SMOOTH   = 0.1
    N_FOLDS        = 5

    # ─── Ensemble ─────────────────────────────────────────────────
    # Weights: [ViT-MAE, Swin-Tiny, ConvNeXt-Small]
    ENSEMBLE_WEIGHTS = [0.45, 0.30, 0.25]
    TTA_AUGS         = 8        # number of TTA augmentations

## 2. Data Loading & Preprocessing

In [42]:
# ─── Band Loading Utilities ───────────────────────────────────────────────────

def load_sentinel2_patch(folder_path, img_size=CFG.IMG_SIZE):
    """
    Load all 12 Sentinel-2 bands from a folder of .tif files.
    Returns numpy array of shape (12, H, W), normalized to [0, 1].
    """
    bands = []
    for band_name in CFG.BANDS:
        tif_path = os.path.join(folder_path, f'{band_name}.tif')
        if not os.path.exists(tif_path):
            # Try alternative naming
            candidates = glob.glob(os.path.join(folder_path, f'*{band_name}*.tif'))
            tif_path = candidates[0] if candidates else None
        
        if tif_path and os.path.exists(tif_path):
            with rasterio.open(tif_path) as src:
                arr = src.read(1).astype(np.float32)
        else:
            arr = np.zeros((img_size, img_size), dtype=np.float32)
        
        # Resize if needed
        if arr.shape != (img_size, img_size):
            from PIL import Image
            arr = np.array(Image.fromarray(arr).resize((img_size, img_size), Image.BILINEAR))
        
        bands.append(arr)
    
    img = np.stack(bands, axis=0)  # (12, H, W)
    
    # Normalize each band to [0, 1] using percentile clipping
    for i in range(img.shape[0]):
        p2, p98 = np.percentile(img[i], [2, 98])
        img[i] = np.clip(img[i], p2, p98)
        if p98 > p2:
            img[i] = (img[i] - p2) / (p98 - p2)
        else:
            img[i] = np.zeros_like(img[i])
    
    return img


def compute_vegetation_indices(img):
    """
    Compute 4 vegetation indices from 12-band Sentinel-2 image.
    img: numpy (12, H, W) normalized [0,1]
    Returns: (16, H, W) — original 12 bands + 4 indices
    
    Band mapping (0-indexed):
      0:B1, 1:B2, 2:B3, 3:B4, 4:B5, 5:B6, 6:B7,
      7:B8, 8:B8A, 9:B9, 10:B11, 11:B12
    """
    eps = 1e-6
    B2, B4, B5, B6 = img[1], img[3], img[4], img[5]
    B8, B8A, B11, B12 = img[7], img[8], img[10], img[11]

    NDVI  = (B8  - B4)  / (B8  + B4  + eps)          # vegetation health
    NDRE  = (B8A - B5)  / (B8A + B5  + eps)          # red-edge disease
    PSRI  = (B4  - B2)  / (B6  + eps)                # plant senescence (rust)
    SWIR  = B11         / (B12 + eps)                 # moisture stress

    indices = np.stack([NDVI, NDRE, PSRI, SWIR], axis=0)  # (4, H, W)
    
    # Normalize indices to [0, 1]
    for i in range(4):
        mn, mx = indices[i].min(), indices[i].max()
        if mx > mn:
            indices[i] = (indices[i] - mn) / (mx - mn)
    
    return np.concatenate([img, indices], axis=0)  # (16, H, W)


print('Data utilities defined.')

Data utilities defined.


In [43]:
# ─── FIXED Dataset Discovery ─────────────────────────────────────────────────

# Override paths — data lives inside 'kaggle' subfolder
CFG.LABELED_ROOT = '/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026p2/ICPR02/kaggle'
CFG.S2A_ROOT     = '/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026p2/s2a'

# Update class names to match exact folder names on disk
CFG.CLASS_NAMES = ['Aphid', 'Blast', 'RPH', 'Rust']
CFG.CLASS2IDX   = {c: i for i, c in enumerate(CFG.CLASS_NAMES)}
CFG.NUM_CLASSES = 4

def discover_labeled_data(root):
    """
    Structure: root/ClassName/sample_hash/ containing B1.tif ... B12.tif
    """
    records = []
    for cls in CFG.CLASS_NAMES:
        cls_dir = os.path.join(root, cls)
        if not os.path.isdir(cls_dir):
            print(f'  WARNING: class folder not found: {cls_dir}')
            continue
        samples = [s for s in os.listdir(cls_dir)
                   if os.path.isdir(os.path.join(cls_dir, s))]
        for sample in samples:
            records.append({
                'path':      os.path.join(cls_dir, sample),
                'label':     cls,
                'label_idx': CFG.CLASS2IDX[cls]
            })
        print(f'  {cls:10s}: {len(samples)} samples')

    df = pd.DataFrame(records)
    print(f'\nTotal labeled samples: {len(df)}')
    return df


def discover_unlabeled_data(root, max_samples=5000):
    """
    Find unlabeled Sentinel-2 folders for SSL pretraining.
    """
    folders = []
    for dirpath, dirnames, filenames in os.walk(root):
        if any(f.endswith('.tif') for f in filenames):
            folders.append(dirpath)
        if len(folders) >= max_samples:
            break
    print(f'Found {len(folders)} unlabeled S2 folders')
    return folders


def discover_test_data(root):
    """
    Test set: root/evaluation/sample_hash/
    """
    eval_dir = os.path.join(root, 'evaluation')
    if not os.path.isdir(eval_dir):
        print(f'WARNING: evaluation folder not found at {eval_dir}')
        return []
    folders = [os.path.join(eval_dir, s)
               for s in sorted(os.listdir(eval_dir))
               if os.path.isdir(os.path.join(eval_dir, s))]
    print(f'Test samples: {len(folders)}')
    return folders


# Discover data
labeled_df        = discover_labeled_data(CFG.LABELED_ROOT)
unlabeled_folders = discover_unlabeled_data(CFG.S2A_ROOT)
test_folders      = discover_test_data(CFG.LABELED_ROOT)

  Aphid     : 290 samples
  Blast     : 75 samples
  RPH       : 495 samples
  Rust      : 40 samples

Total labeled samples: 900
Found 0 unlabeled S2 folders
Test samples: 40


In [44]:
# ─── PyTorch Datasets ────────────────────────────────────────────────────────

class S2UnlabeledDataset(Dataset):
    """Unlabeled Sentinel-2 dataset for SSL pretraining."""
    
    def __init__(self, folders, img_size=CFG.IMG_SIZE):
        self.folders = folders
        self.img_size = img_size
    
    def __len__(self):
        return len(self.folders)
    
    def __getitem__(self, idx):
        try:
            img = load_sentinel2_patch(self.folders[idx], self.img_size)
            img = compute_vegetation_indices(img)   # (16, H, W)
            return torch.FloatTensor(img)
        except Exception as e:
            # Return zeros on error — SSL can tolerate occasional bad samples
            return torch.zeros(CFG.IN_CHANNELS, self.img_size, self.img_size)


class S2LabeledDataset(Dataset):
    """Labeled dataset for fine-tuning."""
    
    def __init__(self, df, img_size=CFG.IMG_SIZE, augment=True):
        self.df = df.reset_index(drop=True)
        self.img_size = img_size
        self.augment = augment
    
    def __len__(self):
        return len(self.df)
    
    def _augment(self, img):
        """Spectral + spatial augmentations."""
        # Random horizontal flip
        if random.random() > 0.5:
            img = np.flip(img, axis=2).copy()
        # Random vertical flip
        if random.random() > 0.5:
            img = np.flip(img, axis=1).copy()
        # Random 90° rotation
        k = random.randint(0, 3)
        img = np.rot90(img, k=k, axes=(1, 2)).copy()
        # Spectral dropout: randomly zero out 1-2 bands
        if random.random() > 0.7:
            n_drop = random.randint(1, 2)
            drop_idx = random.sample(range(12), n_drop)  # only drop raw bands, not indices
            for d in drop_idx:
                img[d] = 0.0
        return img
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        try:
            img = load_sentinel2_patch(row['path'], self.img_size)
            img = compute_vegetation_indices(img)  # (16, H, W)
            if self.augment:
                img = self._augment(img)
        except Exception:
            img = np.zeros((CFG.IN_CHANNELS, self.img_size, self.img_size), dtype=np.float32)
        
        return torch.FloatTensor(img), torch.tensor(row['label_idx'], dtype=torch.long)


print('Datasets defined.')

Datasets defined.


## 3. Stage 1: MAE Self-Supervised Pretraining

In [45]:
# ─── MAE Model ───────────────────────────────────────────────────────────────
# Masked Autoencoder adapted for 16-channel Sentinel-2 input

class PatchEmbed(nn.Module):
    """2D image to patch embeddings — supports arbitrary number of input channels."""
    
    def __init__(self, img_size=64, patch_size=8, in_chans=16, embed_dim=384):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        # x: (B, C, H, W) → (B, n_patches, embed_dim)
        x = self.proj(x)                    # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)    # (B, n_patches, embed_dim)
        return x


class TransformerBlock(nn.Module):
    """Standard Transformer block with pre-norm."""
    
    def __init__(self, dim, n_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn  = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        mlp_dim    = int(dim * mlp_ratio)
        self.mlp   = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim), nn.Dropout(dropout)
        )
    
    def forward(self, x):
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class MAEEncoder(nn.Module):
    """ViT-Small encoder (memory-efficient for T4)."""
    
    def __init__(self,
                 img_size=CFG.IMG_SIZE,
                 patch_size=CFG.SSL_PATCH_SIZE,
                 in_chans=CFG.IN_CHANNELS,
                 embed_dim=384,           # ViT-Small dimension
                 depth=12,
                 n_heads=6):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.n_patches   = self.patch_embed.n_patches
        self.embed_dim   = embed_dim
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
        
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        
        self._init_weights()
    
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None: nn.init.zeros_(m.bias)
    
    def random_masking(self, x, mask_ratio):
        """
        Random token masking for MAE.
        Returns: x_masked, mask, ids_restore
        """
        B, N, D = x.shape
        keep = int(N * (1 - mask_ratio))
        
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        
        ids_keep    = ids_shuffle[:, :keep]
        x_masked    = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
        
        mask = torch.ones(B, N, device=x.device)
        mask[:, :keep] = 0
        mask = torch.gather(mask, 1, ids_restore)
        
        return x_masked, mask, ids_restore
    
    def forward(self, x, mask_ratio=0.0):
        x = self.patch_embed(x)
        x = x + self.pos_embed[:, 1:, :]  # add positional embedding (no cls)
        
        mask, ids_restore = None, None
        if mask_ratio > 0:
            x, mask, ids_restore = self.random_masking(x, mask_ratio)
        
        cls = self.cls_token.expand(x.shape[0], -1, -1) + self.pos_embed[:, :1, :]
        x   = torch.cat([cls, x], dim=1)
        
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        
        return x, mask, ids_restore


class MAEDecoder(nn.Module):
    """Lightweight MAE decoder — only used during pretraining."""
    
    def __init__(self, n_patches, encoder_dim=384, decoder_dim=192,
                 depth=4, n_heads=4, patch_size=8, in_chans=16):
        super().__init__()
        self.n_patches    = n_patches
        self.patch_size   = patch_size
        self.in_chans     = in_chans
        
        self.proj         = nn.Linear(encoder_dim, decoder_dim)
        self.mask_token   = nn.Parameter(torch.zeros(1, 1, decoder_dim))
        self.pos_embed    = nn.Parameter(torch.zeros(1, n_patches + 1, decoder_dim))
        
        self.blocks       = nn.ModuleList([
            TransformerBlock(decoder_dim, n_heads) for _ in range(depth)
        ])
        self.norm         = nn.LayerNorm(decoder_dim)
        self.head         = nn.Linear(decoder_dim, patch_size * patch_size * in_chans)
        
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
    
    def forward(self, x, ids_restore):
        x = self.proj(x)
        B, n_keep_plus1, D = x.shape
        n_keep = n_keep_plus1 - 1  # subtract cls token
        
        # Expand mask tokens
        mask_tokens = self.mask_token.expand(B, self.n_patches - n_keep, -1)
        x_no_cls    = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # (B, N, D)
        x_no_cls    = torch.gather(x_no_cls, 1,
                          ids_restore.unsqueeze(-1).expand(-1, -1, D))
        x           = torch.cat([x[:, :1, :], x_no_cls], dim=1)   # add cls back
        x           = x + self.pos_embed
        
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        x = self.head(x[:, 1:, :])  # remove cls, predict patches
        return x


class MAE(nn.Module):
    """Full MAE model for pretraining."""
    
    def __init__(self, mask_ratio=CFG.SSL_MASK_RATIO):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.encoder    = MAEEncoder()
        self.decoder    = MAEDecoder(
            n_patches   = self.encoder.n_patches,
            encoder_dim = self.encoder.embed_dim
        )
    
    def patchify(self, x):
        """x: (B, C, H, W) → (B, n_patches, patch_size²×C)"""
        P = CFG.SSL_PATCH_SIZE
        B, C, H, W = x.shape
        x = x.reshape(B, C, H//P, P, W//P, P)
        x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, (H//P)*(W//P), P*P*C)
        return x
    
    def forward(self, x):
        target = self.patchify(x)
        
        latent, mask, ids_restore = self.encoder(x, mask_ratio=self.mask_ratio)
        pred = self.decoder(latent, ids_restore)
        
        # MSE loss only on masked patches
        loss = ((pred - target) ** 2)
        loss = loss.mean(dim=-1)          # (B, N)
        loss = (loss * mask).sum() / (mask.sum() + 1e-6)
        return loss


print('MAE model defined.')

MAE model defined.


In [46]:
# ─── SSL Pretraining Loop ─────────────────────────────────────────────────────

def pretrain_mae(folders, save_path=os.path.join(CFG.OUTPUT_DIR, 'mae_pretrained.pth')):
    """
    Pretrain MAE on unlabeled Sentinel-2 data.
    Returns path to saved encoder weights.
    """
    print('\n' + '='*60)
    print('STAGE 1: MAE Self-Supervised Pretraining')
    print('='*60)
    
    # Dataset
    dataset = S2UnlabeledDataset(folders)
    loader  = DataLoader(dataset, batch_size=CFG.SSL_BATCH_SIZE,
                         shuffle=True, num_workers=4,
                         pin_memory=True, drop_last=True)
    print(f'SSL dataset size: {len(dataset)}')
    print(f'Steps per epoch:  {len(loader)}')
    
    # Model
    model = MAE()
    if torch.cuda.device_count() > 1:
        print(f'Using {torch.cuda.device_count()} GPUs for pretraining')
        model = nn.DataParallel(model)
    model = model.to(device)
    
    # Optimizer & Scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CFG.SSL_LR,
        betas=(0.9, 0.95), weight_decay=0.05
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=CFG.SSL_EPOCHS
    )
    scaler = torch.cuda.amp.GradScaler()  # mixed precision
    
    best_loss = float('inf')
    for epoch in range(1, CFG.SSL_EPOCHS + 1):
        model.train()
        total_loss = 0
        
        pbar = tqdm(loader, desc=f'SSL Epoch {epoch}/{CFG.SSL_EPOCHS}')
        for imgs in pbar:
            imgs = imgs.to(device, non_blocking=True)
            
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                loss = model(imgs)
                if isinstance(loss, torch.Tensor) is False:
                    loss = loss.mean()  # DataParallel returns tuple
            
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            
            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        scheduler.step()
        avg_loss = total_loss / len(loader)
        print(f'Epoch {epoch:3d} | Loss: {avg_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.6f}')
        
        if avg_loss < best_loss:
            best_loss = avg_loss
            # Save encoder only (decoder not needed for fine-tuning)
            enc_state = model.module.encoder.state_dict() if hasattr(model, 'module') \
                        else model.encoder.state_dict()
            torch.save(enc_state, save_path)
            print(f'  ✓ Saved encoder (loss={best_loss:.4f})')
    
    print(f'\nPretraining complete. Best loss: {best_loss:.4f}')
    return save_path


# Run pretraining (skip if already done)
ssl_weights_path = os.path.join(CFG.OUTPUT_DIR, 'mae_pretrained.pth')
if not os.path.exists(ssl_weights_path) and len(unlabeled_folders) > 0:
    ssl_weights_path = pretrain_mae(unlabeled_folders, ssl_weights_path)
else:
    print(f'SSL weights found / no unlabeled data: {ssl_weights_path}')

SSL weights found / no unlabeled data: /kaggle/working/mae_pretrained.pth


## 4. Stage 2: Fine-Tuning Three Models

In [47]:
# ─── Model 1: MAE ViT-Base (SSL pretrained) ───────────────────────────────────
class ViTClassifier(nn.Module):
    def __init__(self, encoder, num_classes=CFG.NUM_CLASSES, dropout=0.2):
        super().__init__()
        self.encoder = encoder
        dim = encoder.embed_dim
        self.head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Dropout(dropout),
            nn.Linear(dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes)
        )
    
    def forward(self, x):
        out, _, _ = self.encoder(x, mask_ratio=0.0)
        cls = out[:, 0, :]
        return self.head(cls)


def build_vit_model(ssl_weights_path):
    encoder = MAEEncoder()
    if os.path.exists(ssl_weights_path):
        state = torch.load(ssl_weights_path, map_location='cpu')
        encoder.load_state_dict(state, strict=False)
        print('Loaded MAE pretrained encoder weights')
    else:
        print('WARNING: No SSL weights found, training ViT from scratch')
    return ViTClassifier(encoder)


# ─── Model 2: Swin-Tiny (fixed for 64x64 input + 16 channels) ────────────────
def build_swin_model(num_classes=CFG.NUM_CLASSES):
    model = timm.create_model(
        'swin_tiny_patch4_window7_224',
        pretrained=True,
        num_classes=num_classes,
        img_size=CFG.IMG_SIZE          # ← KEY FIX: tell timm to use 64x64
    )

    old_proj   = model.patch_embed.proj
    old_weight = old_proj.weight.data  # (96, 3, 4, 4)

    new_proj = nn.Conv2d(
        CFG.IN_CHANNELS, old_proj.out_channels,
        kernel_size=old_proj.kernel_size,
        stride=old_proj.stride,
        padding=old_proj.padding,
        bias=False
    )
    with torch.no_grad():
        repeats    = CFG.IN_CHANNELS // 3 + 1
        new_weight = old_weight.repeat(1, repeats, 1, 1)[:, :CFG.IN_CHANNELS, :, :]
        new_weight = new_weight * (3.0 / CFG.IN_CHANNELS)
        new_proj.weight.data = new_weight

    model.patch_embed.proj = new_proj
    print(f'Swin-Tiny: adapted to {CFG.IN_CHANNELS}-channel {CFG.IMG_SIZE}x{CFG.IMG_SIZE} input')
    return model


# ─── Model 3: ConvNeXt-Small (16 channels, no size restriction) ───────────────
def build_convnext_model(num_classes=CFG.NUM_CLASSES):
    model = timm.create_model(
        'convnext_small.fb_in22k',
        pretrained=True,
        num_classes=num_classes
    )

    old_conv   = model.stem[0]
    old_weight = old_conv.weight.data  # (96, 3, 4, 4)

    new_conv = nn.Conv2d(
        CFG.IN_CHANNELS, old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        bias=False
    )
    with torch.no_grad():
        repeats    = CFG.IN_CHANNELS // 3 + 1
        new_weight = old_weight.repeat(1, repeats, 1, 1)[:, :CFG.IN_CHANNELS, :, :]
        new_weight = new_weight * (3.0 / CFG.IN_CHANNELS)
        new_conv.weight.data = new_weight

    model.stem[0] = new_conv
    print(f'ConvNeXt-Small: adapted to {CFG.IN_CHANNELS}-channel input')
    return model


print('All model builders defined.')

All model builders defined.


In [48]:
# ─── Training Utilities ───────────────────────────────────────────────────────

class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced datasets."""
    
    def __init__(self, gamma=2.0, label_smooth=CFG.LABEL_SMOOTH):
        super().__init__()
        self.gamma        = gamma
        self.label_smooth = label_smooth
    
    def forward(self, logits, targets):
        B, C = logits.shape
        # Label smoothing
        with torch.no_grad():
            soft_targets = torch.full_like(logits, self.label_smooth / (C - 1))
            soft_targets.scatter_(1, targets.unsqueeze(1), 1 - self.label_smooth)
        
        log_probs = F.log_softmax(logits, dim=-1)
        probs     = log_probs.exp()
        focal_w   = (1 - probs) ** self.gamma
        loss      = -(soft_targets * focal_w * log_probs).sum(dim=-1)
        return loss.mean()


def mixup_data(x, y, alpha=0.2):
    """MixUp augmentation."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0
    idx = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[idx]
    return mixed_x, y, y[idx], lam


def get_layer_wise_optimizer(model, base_lr, decay=CFG.LR_DECAY):
    """
    Layer-wise learning rate decay:
    - Deepest (earliest) layers get lower LR
    - Final classification head gets full LR
    """
    param_groups = []
    named_params = list(model.named_parameters())
    n_layers = len(named_params)
    
    for i, (name, param) in enumerate(named_params):
        layer_scale = decay ** (n_layers - i - 1)
        # Head layers get full LR
        if any(k in name for k in ['head', 'classifier', 'fc']):
            layer_scale = 1.0
        param_groups.append({'params': [param], 'lr': base_lr * layer_scale})
    
    return torch.optim.AdamW(param_groups, weight_decay=CFG.WEIGHT_DECAY)


def get_cosine_schedule_with_warmup(optimizer, warmup_epochs, total_epochs):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return epoch / max(1, warmup_epochs)
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1 + np.cos(np.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


print('Training utilities defined.')

Training utilities defined.


In [49]:
# ─── Fine-Tuning Loop ─────────────────────────────────────────────────────────

def train_one_epoch(model, loader, optimizer, criterion, scaler, epoch):
    model.train()
    total_loss, correct, total = 0, 0, 0
    
    for imgs, labels in tqdm(loader, desc=f'  Train epoch {epoch}', leave=False):
        imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
        
        # MixUp
        imgs, y_a, y_b, lam = mixup_data(imgs, labels)
        
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            logits = model(imgs)
            loss   = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total   += labels.size(0)
    
    return total_loss / len(loader), correct / total


@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    all_preds, all_labels = [], []
    
    for imgs, labels in tqdm(loader, desc='  Validating', leave=False):
        imgs = imgs.to(device, non_blocking=True)
        with torch.cuda.amp.autocast():
            logits = model(imgs)
        preds = logits.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())
    
    acc = accuracy_score(all_labels, all_preds)
    f1  = f1_score(all_labels, all_preds, average='macro')
    return acc, f1, all_preds, all_labels


def finetune_model(model, train_df, val_df, model_name, save_path):
    """Fine-tune a model on labeled data."""
    print(f'\nFine-tuning: {model_name}')
    
    # Datasets
    train_ds = S2LabeledDataset(train_df, augment=True)
    val_ds   = S2LabeledDataset(val_df,   augment=False)
    
    # Class-balanced sampler
    class_counts = train_df['label_idx'].value_counts().sort_index().values
    weights      = 1.0 / class_counts[train_df['label_idx'].values]
    sampler      = WeightedRandomSampler(weights, len(weights), replacement=True)
    
    train_loader = DataLoader(train_ds, batch_size=CFG.FT_BATCH_SIZE,
                              sampler=sampler, num_workers=4, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=CFG.FT_BATCH_SIZE * 2,
                              shuffle=False, num_workers=4, pin_memory=True)
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    
    
    optimizer = get_layer_wise_optimizer(model, CFG.FT_LR)
    scheduler = get_cosine_schedule_with_warmup(optimizer, CFG.WARMUP_EPOCHS, CFG.FT_EPOCHS)
    criterion = FocalLoss()
    scaler    = torch.cuda.amp.GradScaler()
    
    best_f1 = 0
    history = []
    
    for epoch in range(1, CFG.FT_EPOCHS + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, criterion, scaler, epoch)
        val_acc, val_f1, _, _ = evaluate(model, val_loader)
        scheduler.step()
        
        history.append({'epoch': epoch, 'tr_loss': tr_loss, 'tr_acc': tr_acc,
                         'val_acc': val_acc, 'val_f1': val_f1})
        
        print(f'  Epoch {epoch:3d} | tr_loss={tr_loss:.4f} tr_acc={tr_acc:.4f} '
              f'val_acc={val_acc:.4f} val_f1={val_f1:.4f}', end='')
        
        if val_f1 > best_f1:
            best_f1 = val_f1
            state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
            torch.save(state, save_path)
            print(' ✓ saved', end='')
        print()
    
    print(f'  Best val F1: {best_f1:.4f}')
    return best_f1, history


print('Fine-tuning loop defined.')

Fine-tuning loop defined.


In [50]:
# ─── DEBUG CELL — run this and paste the output ───────────────────────────────
import os, glob

LABELED_ROOT = '/kaggle/input/beyond-visible-spectrum-ai-for-agriculture-2026p2/ICPR02'

print("=== TOP-LEVEL CONTENTS ===")
if os.path.exists(LABELED_ROOT):
    for item in sorted(os.listdir(LABELED_ROOT))[:30]:
        full = os.path.join(LABELED_ROOT, item)
        kind = 'DIR' if os.path.isdir(full) else 'FILE'
        print(f"  [{kind}] {item}")
else:
    print(f"PATH DOES NOT EXIST: {LABELED_ROOT}")
    # Try to find it
    print("\n=== Searching for ICPR02 ===")
    for root, dirs, files in os.walk('/kaggle/input'):
        for d in dirs:
            if 'ICPR' in d or 'icpr' in d.lower():
                print(f"  Found: {os.path.join(root, d)}")

print("\n=== FIRST SUBFOLDER CONTENTS ===")
if os.path.exists(LABELED_ROOT):
    items = sorted(os.listdir(LABELED_ROOT))
    if items:
        first = os.path.join(LABELED_ROOT, items[0])
        if os.path.isdir(first):
            sub = sorted(os.listdir(first))[:20]
            print(f"Inside '{items[0]}':")
            for s in sub:
                full2 = os.path.join(first, s)
                kind  = 'DIR' if os.path.isdir(full2) else 'FILE'
                print(f"  [{kind}] {s}")
            # Go one level deeper if subfolders exist
            subdirs = [s for s in sub if os.path.isdir(os.path.join(first, s))]
            if subdirs:
                deeper = os.path.join(first, subdirs[0])
                print(f"\n  Inside '{subdirs[0]}':")
                for x in sorted(os.listdir(deeper))[:20]:
                    print(f"    {x}")

print("\n=== JSON FILES (first 10) ===")
jsons = glob.glob(os.path.join(LABELED_ROOT, '**/*.json'), recursive=True)
print(f"Total JSON files: {len(jsons)}")
for j in jsons[:10]:
    print(f"  {j}")
    with open(j) as f:
        import json
        print(f"    contents: {json.load(f)}")

print("\n=== TIF FILES sample (first 5) ===")
tifs = glob.glob(os.path.join(LABELED_ROOT, '**/*.tif'), recursive=True)
print(f"Total TIF files: {len(tifs)}")
for t in tifs[:5]:
    print(f"  {t}")

=== TOP-LEVEL CONTENTS ===
  [DIR] kaggle

=== FIRST SUBFOLDER CONTENTS ===
Inside 'kaggle':
  [DIR] Aphid
  [DIR] Blast
  [DIR] RPH
  [DIR] Rust
  [DIR] evaluation

  Inside 'Aphid':
    0041231a3f6f4fa9b07a04234cef4627
    00e6adf1215344a0aa3d396aa50eff0c
    018a98b4251441f0b9e73b9d286541f2
    0299e35c64b74d3d896f7a22227cde31
    035d3057f7af4f1c9b47e2a325b71be2
    03b38c96e4d8428cbd726ebf8e9a368a
    044515d231cc403e80b6a1e2b45862b7
    05688b9f038941e9b291fab9aaf140ee
    05690e774e08431896edaded9196e3e5
    05915a998f9b4a7e923f6ee96643f2c2
    066c5e545c9d4ddd9a3c637ca0cccc17
    0801155fd3a749a0932140c1b0b7ab42
    083263e8f4774f95b02656705c9aa298
    08924b82960441eb82444741c8102c93
    09210ee84323441db5e4dd5967c2a3b6
    09bdf2946f6d466c8b8a7ecd5ce3e231
    09c6fee5137f4613bf3591055ac0ce9e
    09d84961c05548c386ef69eb2d008a7a
    0b5c38795f354b82ad57aede75610004
    0d86f8fa182043e4ad8df395e9079bcf

=== JSON FILES (first 10) ===
Total JSON files: 0

=== TIF FILES sample (fi

In [None]:
# ─── Run Fine-Tuning for All 3 Models ─────────────────────────────────────────

# 5-fold CV — we use fold 0 for simplicity; run all folds for full ensemble diversity
skf = StratifiedKFold(n_splits=CFG.N_FOLDS, shuffle=True, random_state=42)
splits = list(skf.split(labeled_df, labeled_df['label_idx']))

# Use fold 0
train_idx, val_idx = splits[0]
train_df = labeled_df.iloc[train_idx]
val_df   = labeled_df.iloc[val_idx]
print(f'Train: {len(train_df)} | Val: {len(val_df)}')

model_configs = [
    ('vit_mae',   build_vit_model,    {'ssl_weights_path': ssl_weights_path}),
    ('swin_tiny', build_swin_model,   {}),
    ('convnext',  build_convnext_model, {}),
]

all_results = {}

for model_name, builder, kwargs in model_configs:
    save_path = os.path.join(CFG.OUTPUT_DIR, f'{model_name}_best.pth')
    model     = builder(**kwargs)
    best_f1, history = finetune_model(model, train_df, val_df, model_name, save_path)
    all_results[model_name] = {'best_f1': best_f1, 'save_path': save_path, 'history': history}
    # Free GPU memory between models
    del model
    torch.cuda.empty_cache()

print('\n' + '='*60)
print('Fine-tuning Summary:')
for name, res in all_results.items():
    print(f'  {name:15s} | Best Val F1: {res["best_f1"]:.4f}')

Train: 720 | Val: 180

Fine-tuning: vit_mae


                                                                

  Epoch   1 | tr_loss=0.8018 tr_acc=0.2458 val_acc=0.5222 val_f1=0.1822 ✓ saved


                                                                

  Epoch   2 | tr_loss=0.7862 tr_acc=0.2792 val_acc=0.3667 val_f1=0.1896 ✓ saved


                                                                

  Epoch   3 | tr_loss=0.7859 tr_acc=0.2597 val_acc=0.4500 val_f1=0.3193 ✓ saved


                                                                

  Epoch   4 | tr_loss=0.7727 tr_acc=0.2833 val_acc=0.4556 val_f1=0.2609


                                                                

  Epoch   5 | tr_loss=0.7709 tr_acc=0.2764 val_acc=0.4222 val_f1=0.3515 ✓ saved


                                                                

  Epoch   6 | tr_loss=0.7514 tr_acc=0.2944 val_acc=0.2833 val_f1=0.2120


                                                                

  Epoch   7 | tr_loss=0.7493 tr_acc=0.3139 val_acc=0.4889 val_f1=0.2963


                                                                

  Epoch   8 | tr_loss=0.7371 tr_acc=0.3264 val_acc=0.4556 val_f1=0.3198


                                                                

  Epoch   9 | tr_loss=0.7378 tr_acc=0.3472 val_acc=0.5389 val_f1=0.3630 ✓ saved


                                                                 

  Epoch  10 | tr_loss=0.7346 tr_acc=0.2861 val_acc=0.4000 val_f1=0.3114


                                                                 

  Epoch  11 | tr_loss=0.7117 tr_acc=0.3764 val_acc=0.5111 val_f1=0.3480


                                                                 

  Epoch  12 | tr_loss=0.7293 tr_acc=0.3417 val_acc=0.4444 val_f1=0.3390


                                                                 

  Epoch  13 | tr_loss=0.7110 tr_acc=0.2833 val_acc=0.4444 val_f1=0.3304


                                                                 

  Epoch  14 | tr_loss=0.7255 tr_acc=0.3194 val_acc=0.4278 val_f1=0.3391


                                                                 

  Epoch  15 | tr_loss=0.7266 tr_acc=0.3181 val_acc=0.4444 val_f1=0.3375


                                                                 

  Epoch  16 | tr_loss=0.7106 tr_acc=0.3069 val_acc=0.4611 val_f1=0.3385


                                                                 

  Epoch  17 | tr_loss=0.7167 tr_acc=0.3236 val_acc=0.4444 val_f1=0.3303


                                                                 

  Epoch  18 | tr_loss=0.6983 tr_acc=0.3306 val_acc=0.4944 val_f1=0.3599


                                                                 

  Epoch  19 | tr_loss=0.7057 tr_acc=0.3736 val_acc=0.4611 val_f1=0.3486


                                                                 

  Epoch  20 | tr_loss=0.7168 tr_acc=0.3333 val_acc=0.5056 val_f1=0.3675 ✓ saved


                                                                 

  Epoch  21 | tr_loss=0.7043 tr_acc=0.3361 val_acc=0.4778 val_f1=0.3245


                                                                 

  Epoch  22 | tr_loss=0.7157 tr_acc=0.3486 val_acc=0.4778 val_f1=0.3200


                                                                 

  Epoch  23 | tr_loss=0.7107 tr_acc=0.3542 val_acc=0.5444 val_f1=0.3548


                                                                 

  Epoch  24 | tr_loss=0.7018 tr_acc=0.3667 val_acc=0.5611 val_f1=0.3427


                                                                 

  Epoch  25 | tr_loss=0.6930 tr_acc=0.3458 val_acc=0.5556 val_f1=0.3610


                                                                 

  Epoch  26 | tr_loss=0.6922 tr_acc=0.3792 val_acc=0.5556 val_f1=0.3603


                                                                 

  Epoch  27 | tr_loss=0.6890 tr_acc=0.3333 val_acc=0.5389 val_f1=0.3903 ✓ saved


                                                                 

  Epoch  28 | tr_loss=0.7042 tr_acc=0.3431 val_acc=0.5000 val_f1=0.3667


                                                                 

  Epoch  29 | tr_loss=0.6887 tr_acc=0.3500 val_acc=0.4722 val_f1=0.3539


                                                                 

  Epoch  30 | tr_loss=0.6927 tr_acc=0.3333 val_acc=0.5056 val_f1=0.3710


                                                                 

  Epoch  31 | tr_loss=0.7198 tr_acc=0.3625 val_acc=0.5000 val_f1=0.3491


                                                                 

  Epoch  32 | tr_loss=0.6843 tr_acc=0.3889 val_acc=0.4944 val_f1=0.3455


                                                                 

  Epoch  33 | tr_loss=0.7064 tr_acc=0.3708 val_acc=0.5111 val_f1=0.3738


                                                                 

  Epoch  34 | tr_loss=0.6820 tr_acc=0.3264 val_acc=0.5000 val_f1=0.3667


                                                                 

  Epoch  35 | tr_loss=0.6840 tr_acc=0.3056 val_acc=0.4944 val_f1=0.3633


                                                                 

  Epoch  36 | tr_loss=0.6731 tr_acc=0.3917 val_acc=0.4944 val_f1=0.3637


                                                                 

  Epoch  37 | tr_loss=0.6991 tr_acc=0.3181 val_acc=0.5000 val_f1=0.3673


                                                                 

  Epoch  38 | tr_loss=0.6934 tr_acc=0.3764 val_acc=0.5000 val_f1=0.3673


                                                                 

  Epoch  39 | tr_loss=0.6950 tr_acc=0.3583 val_acc=0.5000 val_f1=0.3673


                                                                 

  Epoch  40 | tr_loss=0.7018 tr_acc=0.3764 val_acc=0.5000 val_f1=0.3673
  Best val F1: 0.3903




Swin-Tiny: adapted to 16-channel 64x64 input

Fine-tuning: swin_tiny


                                                                

  Epoch   1 | tr_loss=0.9066 tr_acc=0.2444 val_acc=0.2333 val_f1=0.1871 ✓ saved


                                                                

  Epoch   2 | tr_loss=0.8481 tr_acc=0.2611 val_acc=0.3389 val_f1=0.2268 ✓ saved


                                                                

  Epoch   3 | tr_loss=0.7466 tr_acc=0.2917 val_acc=0.4500 val_f1=0.3977 ✓ saved


                                                                

  Epoch   4 | tr_loss=0.6600 tr_acc=0.4056 val_acc=0.5500 val_f1=0.4127 ✓ saved


                                                                

  Epoch   5 | tr_loss=0.6377 tr_acc=0.3736 val_acc=0.5056 val_f1=0.4221 ✓ saved


                                                                

  Epoch   6 | tr_loss=0.6000 tr_acc=0.4306 val_acc=0.5389 val_f1=0.4521 ✓ saved


                                                                

  Epoch   7 | tr_loss=0.5901 tr_acc=0.4458 val_acc=0.6333 val_f1=0.5217 ✓ saved


                                                                

  Epoch   8 | tr_loss=0.5121 tr_acc=0.4917 val_acc=0.7056 val_f1=0.4911


                                                                

  Epoch   9 | tr_loss=0.5206 tr_acc=0.4944 val_acc=0.6333 val_f1=0.5132


                                                                 

  Epoch  10 | tr_loss=0.5173 tr_acc=0.5472 val_acc=0.5278 val_f1=0.4734


                                                                 

  Epoch  11 | tr_loss=0.4978 tr_acc=0.5333 val_acc=0.6111 val_f1=0.5076


                                                                 

  Epoch  12 | tr_loss=0.4616 tr_acc=0.5458 val_acc=0.6667 val_f1=0.5470 ✓ saved


                                                                 

  Epoch  13 | tr_loss=0.4914 tr_acc=0.4569 val_acc=0.6000 val_f1=0.5184


                                                                 

  Epoch  14 | tr_loss=0.4833 tr_acc=0.4431 val_acc=0.5833 val_f1=0.5408


                                                                 

  Epoch  15 | tr_loss=0.4914 tr_acc=0.5139 val_acc=0.7667 val_f1=0.6056 ✓ saved


                                                                 

  Epoch  16 | tr_loss=0.4292 tr_acc=0.5667 val_acc=0.6111 val_f1=0.5465


                                                                 

  Epoch  17 | tr_loss=0.4469 tr_acc=0.5889 val_acc=0.6667 val_f1=0.5407


                                                                 

  Epoch  18 | tr_loss=0.4204 tr_acc=0.5583 val_acc=0.7333 val_f1=0.5530


                                                                 

  Epoch  19 | tr_loss=0.3889 tr_acc=0.5750 val_acc=0.6611 val_f1=0.5526


                                                                 

  Epoch  20 | tr_loss=0.4249 tr_acc=0.5444 val_acc=0.6833 val_f1=0.5190


                                                                 

  Epoch  21 | tr_loss=0.4145 tr_acc=0.5653 val_acc=0.7000 val_f1=0.5788


                                                                 

  Epoch  22 | tr_loss=0.4301 tr_acc=0.5292 val_acc=0.6611 val_f1=0.5648


                                                                 

  Epoch  23 | tr_loss=0.4141 tr_acc=0.5278 val_acc=0.7389 val_f1=0.6051


                                                                 

  Epoch  24 | tr_loss=0.4313 tr_acc=0.5403 val_acc=0.6722 val_f1=0.5256


                                                                 

  Epoch  25 | tr_loss=0.4083 tr_acc=0.5889 val_acc=0.6889 val_f1=0.5584


                                                                 

  Epoch  26 | tr_loss=0.4172 tr_acc=0.6222 val_acc=0.6778 val_f1=0.5870


                                                                 

  Epoch  27 | tr_loss=0.4208 tr_acc=0.5194 val_acc=0.6722 val_f1=0.5701


                                                                 

  Epoch  28 | tr_loss=0.4210 tr_acc=0.6667 val_acc=0.6944 val_f1=0.5753


                                                                 

  Epoch  29 | tr_loss=0.3950 tr_acc=0.6528 val_acc=0.6722 val_f1=0.5538


                                                                 

  Epoch  30 | tr_loss=0.4252 tr_acc=0.6569 val_acc=0.7444 val_f1=0.5766


                                                                 

  Epoch  31 | tr_loss=0.3768 tr_acc=0.5528 val_acc=0.7333 val_f1=0.5822


                                                                 

  Epoch  32 | tr_loss=0.4253 tr_acc=0.5972 val_acc=0.6889 val_f1=0.5711


                                                                 

  Epoch  33 | tr_loss=0.3383 tr_acc=0.6194 val_acc=0.7056 val_f1=0.5583


                                                                 

  Epoch  34 | tr_loss=0.3700 tr_acc=0.5000 val_acc=0.7500 val_f1=0.5807


                                                                 

  Epoch  35 | tr_loss=0.3835 tr_acc=0.4917 val_acc=0.7056 val_f1=0.5616


                                                                 

  Epoch  36 | tr_loss=0.3868 tr_acc=0.6722 val_acc=0.7278 val_f1=0.5707


                                                                 

  Epoch  37 | tr_loss=0.3431 tr_acc=0.7361 val_acc=0.7222 val_f1=0.5665


                                                                 

  Epoch  38 | tr_loss=0.3861 tr_acc=0.5028 val_acc=0.7278 val_f1=0.5707


                                                                 

  Epoch  39 | tr_loss=0.3801 tr_acc=0.5014 val_acc=0.7278 val_f1=0.5707


                                                                 

  Epoch  40 | tr_loss=0.3748 tr_acc=0.7042 val_acc=0.7278 val_f1=0.5707
  Best val F1: 0.6056


model.safetensors:   0%|          | 0.00/265M [00:00<?, ?B/s]

ConvNeXt-Small: adapted to 16-channel input

Fine-tuning: convnext


                                                                

  Epoch   1 | tr_loss=0.8822 tr_acc=0.2208 val_acc=0.5500 val_f1=0.1774 ✓ saved


                                                                

  Epoch   2 | tr_loss=0.7872 tr_acc=0.2667 val_acc=0.1278 val_f1=0.1197


                                                                

  Epoch   3 | tr_loss=0.7719 tr_acc=0.2500 val_acc=0.3056 val_f1=0.1378


                                                                

  Epoch   4 | tr_loss=0.7525 tr_acc=0.3056 val_acc=0.2722 val_f1=0.2454 ✓ saved


                                                                

  Epoch   5 | tr_loss=0.7044 tr_acc=0.3736 val_acc=0.2111 val_f1=0.2574 ✓ saved


                                                                

  Epoch   6 | tr_loss=0.6211 tr_acc=0.4514 val_acc=0.6111 val_f1=0.5119 ✓ saved


                                                                

  Epoch   7 | tr_loss=0.5742 tr_acc=0.4625 val_acc=0.5111 val_f1=0.4118


                                                                

  Epoch   8 | tr_loss=0.5371 tr_acc=0.4361 val_acc=0.3611 val_f1=0.3567


  Train epoch 9:   0%|          | 0/23 [00:00<?, ?it/s]

## 5. Stage 3: Ensemble + Test-Time Augmentation

In [None]:
# ─── TTA Helper ───────────────────────────────────────────────────────────────

def tta_augment(img_tensor, aug_idx):
    """
    8 TTA augmentations: 4 rotations × 2 flips.
    img_tensor: (C, H, W)
    """
    k    = aug_idx % 4          # rotation: 0, 90, 180, 270
    flip = aug_idx // 4         # 0: no flip, 1: horizontal flip
    
    img = img_tensor
    if k > 0:
        img = torch.rot90(img, k=k, dims=[1, 2])
    if flip:
        img = torch.flip(img, dims=[2])
    return img


@torch.no_grad()
def predict_with_tta(model, loader, n_tta=CFG.TTA_AUGS):
    """
    Get softmax predictions with TTA.
    Returns: numpy array (N, num_classes)
    """
    model.eval()
    all_probs = []
    
    for imgs, _ in tqdm(loader, desc='  Predicting with TTA', leave=False):
        B = imgs.shape[0]
        batch_probs = torch.zeros(B, CFG.NUM_CLASSES)
        
        for aug_i in range(n_tta):
            aug_imgs = torch.stack([tta_augment(imgs[j], aug_i) for j in range(B)])
            aug_imgs = aug_imgs.to(device, non_blocking=True)
            with torch.cuda.amp.autocast():
                logits = model(aug_imgs)
            batch_probs += F.softmax(logits.cpu().float(), dim=-1)
        
        batch_probs /= n_tta
        all_probs.append(batch_probs.numpy())
    
    return np.concatenate(all_probs, axis=0)


print('TTA utilities defined.')

In [None]:
# ─── Ensemble Inference ───────────────────────────────────────────────────────

def load_model_for_inference(model_name, weights_path):
    """Load a saved model for inference."""
    builders = {
        'vit_mae':   lambda: build_vit_model(ssl_weights_path),
        'swin_tiny': build_swin_model,
        'convnext':  build_convnext_model,
    }
    model = builders[model_name]()
    state = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state, strict=False)
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.eval()
    return model


def run_ensemble(val_df_or_test_df, is_test=False):
    """
    Run the 3-model weighted ensemble with TTA.
    Returns final predictions and (if val) accuracy/F1.
    """
    dataset = S2LabeledDataset(val_df_or_test_df, augment=False)
    loader  = DataLoader(dataset, batch_size=CFG.FT_BATCH_SIZE * 2,
                         shuffle=False, num_workers=4, pin_memory=True)
    
    model_names   = ['vit_mae', 'swin_tiny', 'convnext']
    ens_weights   = CFG.ENSEMBLE_WEIGHTS
    ensemble_probs = np.zeros((len(dataset), CFG.NUM_CLASSES))
    
    for name, w in zip(model_names, ens_weights):
        save_path = all_results[name]['save_path']
        if not os.path.exists(save_path):
            print(f'  Skipping {name} (weights not found)')
            continue
        
        print(f'  Loading {name} (weight={w:.2f})...')
        model = load_model_for_inference(name, save_path)
        probs = predict_with_tta(model, loader)
        ensemble_probs += w * probs
        del model
        torch.cuda.empty_cache()
    
    final_preds = ensemble_probs.argmax(axis=1)
    
    if not is_test:
        true_labels = val_df_or_test_df['label_idx'].values
        acc = accuracy_score(true_labels, final_preds)
        f1  = f1_score(true_labels, final_preds, average='macro')
        print(f'\nEnsemble Val Accuracy: {acc:.4f}')
        print(f'Ensemble Val F1 (macro): {f1:.4f}')
        print('\nClassification Report:')
        print(classification_report(true_labels, final_preds,
                                     target_names=CFG.CLASS_NAMES))
        return final_preds, acc, f1, ensemble_probs
    
    return final_preds, ensemble_probs


print('\n' + '='*60)
print('STAGE 3: Ensemble + TTA Evaluation')
print('='*60)
val_preds, val_acc, val_f1, val_probs = run_ensemble(val_df)

In [None]:
# ─── Optimize Ensemble Weights ────────────────────────────────────────────────
# Tune weights on validation set using scipy minimize

from scipy.optimize import minimize

# Collect per-model val predictions
model_val_probs = {}
for name in ['vit_mae', 'swin_tiny', 'convnext']:
    save_path = all_results.get(name, {}).get('save_path', '')
    if os.path.exists(save_path):
        val_ds     = S2LabeledDataset(val_df, augment=False)
        val_loader = DataLoader(val_ds, batch_size=CFG.FT_BATCH_SIZE * 2,
                                shuffle=False, num_workers=4)
        model = load_model_for_inference(name, save_path)
        probs = predict_with_tta(model, val_loader)
        model_val_probs[name] = probs
        del model
        torch.cuda.empty_cache()

true_labels = val_df['label_idx'].values
model_names_available = list(model_val_probs.keys())

def neg_f1(weights):
    weights = np.array(weights)
    weights = np.abs(weights) / np.abs(weights).sum()  # normalize to sum=1
    combined = sum(w * model_val_probs[n] for w, n in zip(weights, model_names_available))
    preds = combined.argmax(axis=1)
    return -f1_score(true_labels, preds, average='macro')

# Optimize
x0      = [1/len(model_names_available)] * len(model_names_available)
result  = minimize(neg_f1, x0, method='Nelder-Mead',
                   options={'maxiter': 1000, 'xatol': 1e-4})
opt_w   = np.abs(result.x) / np.abs(result.x).sum()

print('\nOptimized ensemble weights:')
for name, w in zip(model_names_available, opt_w):
    print(f'  {name:15s}: {w:.4f}')

# Apply optimized weights
combined = sum(w * model_val_probs[n] for w, n in zip(opt_w, model_names_available))
opt_preds = combined.argmax(axis=1)
opt_acc   = accuracy_score(true_labels, opt_preds)
opt_f1    = f1_score(true_labels, opt_preds, average='macro')
print(f'\nOptimized Ensemble Val Accuracy: {opt_acc:.4f}')
print(f'Optimized Ensemble Val F1:       {opt_f1:.4f}')

# Save optimized weights for test inference
CFG.ENSEMBLE_WEIGHTS = opt_w.tolist()
np.save(os.path.join(CFG.OUTPUT_DIR, 'ensemble_weights.npy'), opt_w)

## 6. Generate Test Predictions & Submission

In [None]:
# ─── Test Dataset ─────────────────────────────────────────────────────────────

class S2TestDataset(Dataset):
    """Test dataset — no labels, returns image + filename."""
    
    def __init__(self, test_paths, img_size=CFG.IMG_SIZE):
        self.paths    = test_paths
        self.img_size = img_size
    
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, idx):
        path = self.paths[idx]
        try:
            img = load_sentinel2_patch(path, self.img_size)
            img = compute_vegetation_indices(img)
        except Exception:
            img = np.zeros((CFG.IN_CHANNELS, self.img_size, self.img_size), dtype=np.float32)
        fname = os.path.basename(path)
        # Return dummy label 0 so we can reuse DataLoader
        return torch.FloatTensor(img), torch.tensor(0, dtype=torch.long)


def discover_test_data(root):
    """Find test folders — adapt this based on actual test folder structure."""
    test_folders = []
    test_root = os.path.join(root, 'test')  # adjust if needed
    if os.path.isdir(test_root):
        for item in sorted(os.listdir(test_root)):
            p = os.path.join(test_root, item)
            if os.path.isdir(p):
                test_folders.append(p)
    return test_folders


test_folders = discover_test_data(CFG.LABELED_ROOT)
print(f'Found {len(test_folders)} test samples')

if test_folders:
    test_ds = S2TestDataset(test_folders)
    test_loader = DataLoader(test_ds, batch_size=CFG.FT_BATCH_SIZE * 2,
                             shuffle=False, num_workers=4)

    # Ensemble over test
    test_probs = np.zeros((len(test_folders), CFG.NUM_CLASSES))
    for name, w in zip(model_names_available, opt_w):
        save_path = all_results[name]['save_path']
        model = load_model_for_inference(name, save_path)
        probs = predict_with_tta(model, test_loader)
        test_probs += w * probs
        del model
        torch.cuda.empty_cache()

    test_preds = test_probs.argmax(axis=1)
    IDX2CLASS  = {v: k for k, v in CFG.CLASS2IDX.items()}
    pred_labels = [IDX2CLASS[p] for p in test_preds]
    
    # Build submission
    submission = pd.DataFrame({
        'Id':       [os.path.basename(f) for f in test_folders],
        'Category': pred_labels
    })
    sub_path = os.path.join(CFG.OUTPUT_DIR, 'submission.csv')
    submission.to_csv(sub_path, index=False)
    print(f'\nSubmission saved: {sub_path}')
    print(submission.head(10))
    print('\nPrediction distribution:')
    print(submission['Category'].value_counts())
else:
    print('No test folder found — run on val data as a sanity check')
    print(f'\nFinal Validation F1: {opt_f1:.4f}')

## 7. Hyperparameter Tuning Guide

### To push from 0.90 → 0.93:

**1. More SSL epochs:**
```python
CFG.SSL_EPOCHS = 100  # or 200 if you have time
```

**2. Pseudo-labeling (semi-supervised loop):**
After first fine-tune, predict on unlabeled S2 data with high confidence threshold (>0.95), add to labeled set, retrain.

**3. Run all 5 folds and ensemble fold checkpoints:**
```python
# Train on all 5 folds → 5 checkpoints per model → 15 models total
# This is the strongest possible ensemble
```

**4. Add temporal averaging:**
```python
# For samples with multiple time acquisitions:
# Load embeddings from T timesteps, mean-pool, then classify
```

**5. Increase image size:**
```python
CFG.IMG_SIZE = 128  # needs more VRAM — reduce batch size to 16
```

### Key hyperparams to tune first (highest ROI):
| Parameter | Default | Try |
|---|---|---|
| `SSL_EPOCHS` | 50 | 100-200 |
| `FT_LR` | 5e-5 | 1e-5, 1e-4 |
| `SSL_MASK_RATIO` | 0.75 | 0.6, 0.8 |
| `FT_BATCH_SIZE` | 32 | 16, 64 |
| `IMG_SIZE` | 64 | 128 |