# CSIRO Image2Biomass - Swin-V2 with Mixup Augmentation

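This notebook trains the `AdvancedSwinHydra` model with **Mixup** augmentation. Mixup helps the model generalize by creating "virtual" training examples. Two images from the same batch are blended as $\tilde{x} = \lambda x_i + (1-\lambda) x_j$ with $\lambda \sim \mathrm{Beta}(\alpha, \alpha)$. The same $\lambda$ blends their biomass targets, their NDVI/height values and their one-hot species/state labels. Mixup is applied to a random fraction (`mixup_prob`) of training batches.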

In [1]:
import os, sys, functools
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import GroupKFold

DATA_DIR = "/kaggle/input/csiro-biomass"
CHECKPOINT_DIR = "./models_checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Prediction targets and their weights in the competition's weighted R^2.
TARGET_COLUMNS = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'GDM_g', 'Dry_Total_g']
TARGET_WEIGHTS = [0.1, 0.1, 0.1, 0.2, 0.5]

CONFIG = {
    "model_name": "swinv2_large_window12_192.ms_in22k",
    "img_h": 384,
    "img_w": 768,
    "batch_size": 8,
    "lr": 5e-6,
    "epochs": 50,
    "n_splits": 5,
    "mixup_prob": 0.4,  # probability that a training batch is mixed
    "alpha": 1.0,       # Beta(alpha, alpha) parameter for the mixing coefficient
    "device": "cuda",
    "seed": 42,         # seeds NumPy and torch below for repeatable runs
    # UPDATE THIS PATH TO YOUR SAVED WEIGHTS FILE
    "resume_path": None
    # "resume_path": "/kaggle/input/swin-mixup-kfold/pytorch/default/4/models_checkpoints/best_swinv2_large_ft_fold4.pth" 
}

# Seed NumPy (mixup on/off decision and mixing coefficient) and torch (mixup
# permutation, DataLoader shuffling, head initialisation) so a rerun reproduces
# the same augmentation and fold results. cuDNN kernels may still be
# nondeterministic on GPU.
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])



In [None]:
class AdvancedSwinHydra(nn.Module):
    """Swin-V2 backbone with auxiliary metadata heads and one MLP head per biomass target.

    Pooled backbone features feed three auxiliary linear heads:
    - a 2-value regression head (trained against NDVI and average height in
      ``train_fold``);
    - species logits;
    - region logits.

    The argmax species/region predictions are embedded and concatenated with
    the features and the regression outputs. Each target then gets its own
    ``Linear -> GELU -> Linear`` head.

    Submodule attribute names are unchanged, so existing state_dicts still load.

    Args:
        model_name: timm model identifier. If creating it fails, the name with
            its ``.tag`` suffix stripped is tried instead.
        num_species: number of species classes.
        num_region: number of region (State) classes.
        num_targets: number of biomass regression heads (defaults to
            ``len(TARGET_COLUMNS)``).
        species_emb_dim: width of the species embedding.
        region_emb_dim: width of the region embedding.
    """

    def __init__(self, model_name=CONFIG['model_name'], num_species=15, num_region=4,
                 num_targets=len(TARGET_COLUMNS), species_emb_dim=32, region_emb_dim=16):
        super().__init__()
        backbone_kwargs = dict(
            pretrained=True,
            num_classes=0, global_pool='avg',
            img_size=(CONFIG['img_h'], CONFIG['img_w']),
        )
        try:
            self.backbone = timm.create_model(model_name, **backbone_kwargs)
        except Exception:
            # Retry with the bare architecture name (pretrained tag stripped),
            # presumably for timm versions that do not know the tag. If this also
            # fails, Python chains the first exception into the traceback.
            self.backbone = timm.create_model(model_name.split('.')[0], **backbone_kwargs)

        embed_dim = self.backbone.num_features
        self.meta_reg = nn.Linear(embed_dim, 2)
        self.meta_cls = nn.Linear(embed_dim, num_species)
        self.meta_rgn = nn.Linear(embed_dim, num_region)
        self.species_emb = nn.Embedding(num_species, species_emb_dim)
        self.region_emb = nn.Embedding(num_region, region_emb_dim)

        # Fused vector = backbone features + 2 regression outputs + both embeddings.
        fusion_dim = embed_dim + 2 + species_emb_dim + region_emb_dim
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Linear(fusion_dim, 256), nn.GELU(), nn.Linear(256, 1))
            for _ in range(num_targets)
        ])

    def forward(self, x, return_meta=False):
        """Predict the biomass targets for a batch of images.

        Args:
            x: image batch, presumably (B, 3, img_h, img_w) normalized tensors (TODO confirm).
            return_meta: if True, also return the auxiliary head outputs.

        Returns:
            ``out`` with one column per target head, or
            ``(out, p_reg, p_cls, p_rgn)`` when ``return_meta`` is True.
        """
        feat = self.backbone(x)
        p_reg = self.meta_reg(feat)
        p_cls = self.meta_cls(feat)
        p_rgn = self.meta_rgn(feat)

        # argmax is not differentiable. The embedding tables are trained through
        # the target heads, but no gradient reaches meta_cls / meta_rgn this way.
        s_emb = self.species_emb(torch.argmax(p_cls, dim=-1))
        rgn_emb = self.region_emb(torch.argmax(p_rgn, dim=-1))

        fusion = torch.cat([feat, p_reg, s_emb, rgn_emb], dim=-1)
        out = torch.cat([head(fusion) for head in self.heads], dim=-1)

        if return_meta:
            return out, p_reg, p_cls, p_rgn
        return out

In [None]:
def apply_mixup(images, bio_gt, reg_gt, cls_gt, plc_gt, num_species, num_region, prob=None, alpha=None):
    """Blend each sample with a randomly chosen partner from the same batch (Mixup).

    With probability ``prob`` the batch is mixed as
    ``v' = lam * v + (1 - lam) * v[perm]``, using one shared ``lam ~ Beta(alpha, alpha)``.
    This applies to images, biomass targets, NDVI/height targets and the one-hot
    species/region labels. Otherwise the inputs are returned unchanged (labels
    still converted to one-hot) with ``lam = 1.0``.

    Args:
        images, bio_gt, reg_gt: batched tensors whose first dimension is the batch.
        cls_gt, plc_gt: integer class-index tensors (species / region), one per sample.
        num_species, num_region: number of classes for the one-hot encodings.
        prob: probability of mixing this batch; defaults to ``CONFIG['mixup_prob']``.
        alpha: Beta-distribution parameter; defaults to ``CONFIG['alpha']``.

    Returns:
        ``(images, bio, reg, cls_onehot, plc_onehot, lam)`` where ``lam`` is
        always a plain Python float.
    """
    if prob is None:
        prob = CONFIG['mixup_prob']
    if alpha is None:
        alpha = CONFIG['alpha']

    cls_onehot = torch.nn.functional.one_hot(cls_gt, num_species).float()
    plc_onehot = torch.nn.functional.one_hot(plc_gt, num_region).float()

    if np.random.rand() > prob:
        return images, bio_gt, reg_gt, cls_onehot, plc_onehot, 1.0

    # Same RNG call order as before: torch permutation first, then the NumPy Beta draw.
    index = torch.randperm(images.size(0)).to(images.device)
    lam = float(np.random.beta(alpha, alpha))

    def blend(t):
        return lam * t + (1 - lam) * t[index]

    return blend(images), blend(bio_gt), blend(reg_gt), blend(cls_onehot), blend(plc_onehot), lam

In [None]:
class AdvancedDataset(Dataset):
    """One sample per dataframe row: (image, biomass targets, NDVI/height, species id, region id).

    Args:
        df: one row per image. Must contain ``image_path``, the ``TARGET_COLUMNS``,
            ``Pre_GSHH_NDVI``, ``Height_Ave_cm``, ``Species`` and ``State``.
        img_dir: directory that ``image_path`` is relative to.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=array)['image']``. When None, the raw HxWx3 uint8
            NumPy array is returned.
        species_map, region_map: label -> integer-id lookups; unseen labels raise KeyError.
    """

    def __init__(self, df, img_dir, transform=None, species_map=None, region_map=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.species_map = species_map
        self.region_map = region_map

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(os.path.join(self.img_dir, row['image_path'])) as im:
            img = np.array(im.convert('RGB'))
        if self.transform:
            img = self.transform(image=img)['image']

        bio = torch.tensor(row[TARGET_COLUMNS].values.astype(np.float32))
        reg = torch.tensor([row['Pre_GSHH_NDVI'], row['Height_Ave_cm']], dtype=torch.float32)
        cls = torch.tensor(self.species_map[row['Species']], dtype=torch.long)
        plc = torch.tensor(self.region_map[row['State']], dtype=torch.long)

        return img, bio, reg, cls, plc

def competition_metric(y_true, y_pred, weights=None):
    """Weighted R^2 over every (sample, target) pair, as used for validation.

    Each column ``j`` of the predictions carries weight ``weights[j]``. The score is
    ``1 - sum(w * (y - y_hat)^2) / sum(w * (y - y_bar_w)^2)``, where ``y_bar_w`` is
    the weighted mean of all true values.

    Args:
        y_true, y_pred: array-likes with ``len(weights)`` values per sample,
            shape (N, T). A 1-D input of length T is treated as one sample.
        weights: per-target weights; defaults to ``TARGET_WEIGHTS``.

    Returns:
        The score as a Python float. Returns 0.0 when the weighted variance of
        ``y_true`` is zero or there are no samples, since R^2 is undefined there.
    """
    w = np.asarray(TARGET_WEIGHTS if weights is None else weights, dtype=np.float64)
    y_t = np.asarray(y_true, dtype=np.float64).reshape(-1, w.size)
    y_p = np.asarray(y_pred, dtype=np.float64).reshape(-1, w.size)
    if y_t.size == 0:
        return 0.0

    w_full = np.broadcast_to(w, y_t.shape)
    avg = np.sum(w_full * y_t) / np.sum(w_full)
    res = np.sum(w_full * (y_t - y_p) ** 2)
    tot = np.sum(w_full * (y_t - avg) ** 2)
    return float(1.0 - res / tot) if tot != 0 else 0.0

In [None]:
def train_fold(fold, train_df, val_df, species_map, region_map):
    """Train one cross-validation fold and keep the checkpoint with the best validation score.

    The model is trained with mixup (``apply_mixup``), mixed precision and
    gradient-norm clipping. After every epoch it is scored on ``val_df`` with
    ``competition_metric``. Whenever the score improves, the model's state_dict
    is written to ``CHECKPOINT_DIR/best_swinv2_large_region_fold{fold}.pth``
    (keys carry a ``module.`` prefix when wrapped in DataParallel).

    Args:
        fold: fold index, used in log lines and the checkpoint filename.
        train_df, val_df: wide dataframes (one row per image) for this fold.
        species_map, region_map: label -> id lookups shared by all folds.

    Returns:
        The best validation score seen (``-inf`` if ``CONFIG['epochs']`` is 0).
    """
    device = CONFIG['device']
    num_species = len(species_map)
    num_regions = len(region_map)

    train_tf = A.Compose([A.Resize(CONFIG['img_h'], CONFIG['img_w']), A.HorizontalFlip(), A.Normalize(), ToTensorV2()])
    val_tf = A.Compose([A.Resize(CONFIG['img_h'], CONFIG['img_w']), A.Normalize(), ToTensorV2()])
    train_ds = AdvancedDataset(train_df, DATA_DIR, train_tf, species_map, region_map)
    val_ds = AdvancedDataset(val_df, DATA_DIR, val_tf, species_map, region_map)

    loader_t = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=2, drop_last=True)
    loader_v = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False)

    model = AdvancedSwinHydra(num_species=num_species, num_region=num_regions).to(device)

    # LOAD PREVIOUS WEIGHTS
    resume_path = CONFIG['resume_path']
    if resume_path is not None and os.path.exists(resume_path):
        print(f"Resuming from: {CONFIG['resume_path']}")
        # The checkpoint is a plain state_dict; weights_only refuses arbitrary pickled objects.
        sd = torch.load(resume_path, map_location=device, weights_only=True)
        # Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
        model.load_state_dict({k.replace('module.', ''): v for k, v in sd.items()})

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    optimizer = optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['epochs'])
    crit_bio = nn.HuberLoss()
    crit_reg = nn.MSELoss()
    # BCE-with-logits accepts the soft (mixed) one-hot label targets from apply_mixup.
    crit_cls = nn.BCEWithLogitsLoss()
    crit_plc = nn.BCEWithLogitsLoss()
    scaler = torch.amp.GradScaler('cuda')

    best_r2 = -float('inf')
    for epoch in range(CONFIG['epochs']):
        model.train()
        epoch_loss = 0
        for imgs, bios, regs, clss, plcs in loader_t:
            imgs, bios, regs, clss, plcs = (t.to(device) for t in (imgs, bios, regs, clss, plcs))
            imgs_m, bios_m, regs_m, clss_m, plcs_m, _ = apply_mixup(imgs, bios, regs, clss, plcs, num_species, num_regions)

            optimizer.zero_grad()
            with torch.amp.autocast('cuda', enabled=True):
                p_bio, p_reg, p_cls, p_plc = model(imgs_m, return_meta=True)
                loss = (crit_bio(p_bio, bios_m)
                        + 0.1 * crit_reg(p_reg, regs_m)
                        + 0.2 * crit_cls(p_cls, clss_m)
                        + 0.2 * crit_plc(p_plc, plcs_m))

            scaler.scale(loss).backward()
            # backward() produced gradients multiplied by the loss scale. Unscale them
            # first so max_norm=1.0 bounds the true gradient norm. scaler.step() then
            # skips the update if inf/NaN gradients were found.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()

        model.eval()
        all_p, all_t = [], []
        with torch.no_grad():
            for imgs, bios, _, _, _ in loader_v:
                all_p.append(model(imgs.to(device)).cpu().numpy())
                all_t.append(bios.numpy())

        r2 = competition_metric(np.vstack(all_t), np.vstack(all_p))
        print(f"Fold {fold} | Ep {epoch+1} | Loss: {epoch_loss/len(loader_t):.4f} | R2: {r2:.4f}")
        if r2 > best_r2:
            best_r2 = r2
            torch.save(model.state_dict(), f"{CHECKPOINT_DIR}/best_swinv2_large_region_fold{fold}.pth")
        scheduler.step()
    return best_r2

In [3]:
df = pd.read_csv(os.path.join(DATA_DIR, "train.csv"))
# Long -> wide: one row per image, one column per target_name.
df_wide = df.pivot_table(
    index=['image_path', 'Sampling_Date', 'Species', 'State', 'Pre_GSHH_NDVI', 'Height_Ave_cm'],
    columns='target_name', values='target',
).reset_index()
species_map = {s: i for i, s in enumerate(sorted(df_wide['Species'].unique()))}
region_map = {rg: i for i, rg in enumerate(sorted(df_wide['State'].unique()))}

print(f"Unique species in training: {len(species_map)}")
print(f"Unique location in training: {len(region_map)}")

# Group by sampling date so images from one date never span train and validation.
gkf = GroupKFold(n_splits=CONFIG['n_splits'])
fold_scores = []
for fold, (train_idx, val_idx) in enumerate(gkf.split(df_wide, groups=df_wide['Sampling_Date'])):
    print(f"\n--- Starting Fold {fold} ---")
    score = train_fold(fold, df_wide.iloc[train_idx], df_wide.iloc[val_idx], species_map, region_map)
    fold_scores.append(score)
    print(f"Fold {fold} Best R2 Score: {score:.4f}")

print(f"\nCV R2: {np.mean(fold_scores):.4f} +/- {np.std(fold_scores):.4f} over {len(fold_scores)} folds")

Unique species in training: 15
Unique location in training: 4
