# CSIRO Swin-V2 Wide-Tiled Supreme

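**Architecture Update:**
1. **Fixed Aspect Ratio**: Both the global and the tiled inputs are now 384x768 (H x W).
2. **Wide Tiles**: The image is split into left and right halves, each half is resized individually to 384x384, and the two are concatenated horizontally into a single 384x768 input for the Swin-V2 backbone.
3. **Swin-V2-Base**: One shared backbone encodes both views; the concatenated features feed an auxiliary NDVI/height head, a species classifier (whose predicted class selects a learned embedding), and five per-target regression heads.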

In [None]:
import os, sys, functools, json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import GroupKFold

# Rebind print so every message is flushed immediately; long training loops
# otherwise buffer their progress output.
print = functools.partial(print, flush=True)

DATA_DIR = '/kaggle/input/csiro-biomass'
CHECKPOINT_DIR = './models_checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Regression targets; this order fixes the column order of the target tensors.
TARGET_COLUMNS = [
    'Dry_Clover_g',
    'Dry_Dead_g',
    'Dry_Green_g',
    'GDM_g',
    'Dry_Total_g',
]

CONFIG = dict(
    model_name='swinv2_base_window12_192.ms_in22k',
    img_h=384,  # model input height
    img_w=768,  # model input width
    batch_size=8,
    lr=1e-5,
    epochs=30,
    mixup_prob=0.5,
    device='cuda',
    n_splits=5,
    # Keep None: every fold would load the same checkpoint, which has already
    # seen that fold's validation images, inflating the reported R2.
    resume_path=None,
)

In [None]:
class WideTiledSwin(nn.Module):
    """Two-view Swin-V2 regressor.

    A single shared backbone encodes both the globally resized image and the
    "wide tile" image. The two feature vectors are concatenated and used to:

    * predict two auxiliary regression values (``meta_reg``);
    * predict species logits (``meta_cls``), whose argmax selects a learned
      32-dim species embedding;
    * feed five independent MLP heads, one output each.

    Args:
        model_name: timm model identifier for the backbone.
        num_species: number of species classes (classifier width and
            embedding-table size).
    """

    def __init__(self, model_name=CONFIG['model_name'], num_species=15):
        super().__init__()
        input_size = (CONFIG['img_h'], CONFIG['img_w'])
        # num_classes=0: the backbone outputs features rather than class logits.
        self.backbone = timm.create_model(
            model_name, pretrained=True, num_classes=0, img_size=input_size
        )
        feat_dim = self.backbone.num_features
        joint_dim = 2 * feat_dim  # global features + tile features

        self.meta_reg = nn.Linear(joint_dim, 2)
        self.meta_cls = nn.Linear(joint_dim, num_species)
        self.species_emb = nn.Embedding(num_species, 32)

        # Heads see the joint visual features plus the auxiliary regression
        # output (2) and the species embedding (32).
        head_in = joint_dim + 2 + 32
        head_list = []
        for _ in range(5):
            head_list.append(
                nn.Sequential(nn.Linear(head_in, 512), nn.GELU(), nn.Linear(512, 1))
            )
        self.heads = nn.ModuleList(head_list)

    def forward(self, x_g, x_wide_tule):
        """Run both views through the shared backbone and all heads.

        Args:
            x_g: global-view image batch.
            x_wide_tule: wide-tile image batch (name kept for compatibility).

        Returns:
            ``(targets, aux_reg, species_logits)`` with shapes ``(B, 5)``,
            ``(B, 2)`` and ``(B, num_species)``.
        """
        global_feat = self.backbone(x_g)
        tile_feat = self.backbone(x_wide_tule)
        joint = torch.cat((global_feat, tile_feat), dim=1)

        aux_reg = self.meta_reg(joint)
        species_logits = self.meta_cls(joint)
        # Hard argmax lookup: no gradient reaches meta_cls through the
        # embedding path, only through whatever loss is applied to its logits.
        chosen_species = species_logits.argmax(dim=1)
        fused = torch.cat((joint, aux_reg, self.species_emb(chosen_species)), dim=1)

        per_target = [head(fused) for head in self.heads]
        return torch.cat(per_target, dim=1), aux_reg, species_logits

In [None]:
class WideTiledDs(Dataset):
    """Dataset yielding a global view and a wide-tile view for each image.

    Each item is ``(g, t, b, r, c)``:

    * ``g`` - the full image passed through ``tf``;
    * ``t`` - the left and right halves of the original image, each passed
      through ``tile_tf`` separately and concatenated along the width axis;
    * ``b`` - float32 tensor with the ``TARGET_COLUMNS`` values;
    * ``r`` - float32 tensor ``[Pre_GSHH_NDVI, Height_Ave_cm]``;
    * ``c`` - long tensor holding the species index from ``sm``.

    Args:
        df: DataFrame indexed positionally; must contain ``image_path``
            (relative to ``DATA_DIR``), the ``TARGET_COLUMNS``,
            ``Pre_GSHH_NDVI``, ``Height_Ave_cm`` and ``Species``.
        tf: albumentations pipeline for the global view (must end in a
            tensor conversion).
        tile_tf: albumentations pipeline applied to each half; both halves
            must come out with the same height so they can be concatenated.
        sm: mapping from species name to integer class index.
    """

    def __init__(self, df, tf, tile_tf, sm):
        self.df = df
        self.tf = tf
        self.tile_tf = tile_tf
        self.sm = sm

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # The context manager releases the file handle even if decoding fails;
        # np.array copies the pixels so nothing refers to the file afterwards.
        with Image.open(os.path.join(DATA_DIR, row['image_path'])) as img:
            img_np = np.array(img.convert('RGB'))
        mid = img_np.shape[1] // 2

        # Global view: the whole image through tf.
        g = self.tf(image=img_np)['image']

        # Wide-tile view: split the original at the horizontal midpoint,
        # transform each half independently, then join them side by side
        # along width (dim 2 of a CHW tensor).
        # NOTE(review): if tile_tf resizes each half to H x H and tf resizes
        # the whole image to H x 2H, both views sample the image at the same
        # pixel density and differ mainly by augmentation - confirm intended.
        left = self.tile_tf(image=img_np[:, :mid, :])['image']
        right = self.tile_tf(image=img_np[:, mid:, :])['image']
        t = torch.cat([left, right], dim=2)

        b = torch.tensor(row[TARGET_COLUMNS].values.astype(np.float32))
        r = torch.tensor([row['Pre_GSHH_NDVI'], row['Height_Ave_cm']], dtype=torch.float32)
        c = torch.tensor(self.sm[row['Species']], dtype=torch.long)
        return g, t, b, r, c

In [None]:
def competition_metric(y_true, y_pred, weights=None):
    """Globally weighted R^2 over every (sample, target) cell.

    Column j of the (N, K) arrays carries weight ``weights[j]``. The weighted
    mean, the residual sum of squares and the total sum of squares are all
    computed over the N*K cells together (not per target and then averaged).

    Args:
        y_true: array-like of shape (N, K) with ground-truth values.
        y_pred: array-like with the same shape as ``y_true``.
        weights: length-K per-target weights. Defaults to
            ``[0.1, 0.1, 0.1, 0.2, 0.5]``, one per column of
            ``TARGET_COLUMNS`` in that order.

    Returns:
        ``1 - SS_res / SS_tot``, or ``0.0`` when ``SS_tot`` is zero
        (constant ground truth).

    Raises:
        ValueError: if the shapes differ or ``weights`` does not match the
            number of columns.
    """
    if weights is None:
        weights = [0.1, 0.1, 0.1, 0.2, 0.5]
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if y_true.shape != y_pred.shape:
        raise ValueError(f"shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
    # Broadcast the per-column weights over all rows instead of tiling/flattening.
    w = np.broadcast_to(np.asarray(weights, dtype=np.float64), y_true.shape)

    y_avg_w = np.sum(w * y_true) / np.sum(w)
    ss_res = np.sum(w * (y_true - y_pred) ** 2)
    ss_tot = np.sum(w * (y_true - y_avg_w) ** 2)
    return 1 - ss_res / ss_tot if ss_tot != 0 else 0.0

def apply_mixup(g, t, b, r, c, ns, prob=None, alpha=1.0):
    """Randomly apply mixup jointly to both image views and all targets.

    With probability ``prob`` a random batch permutation is drawn and every
    input is blended as ``lam * x + (1 - lam) * x[perm]`` with
    ``lam ~ Beta(alpha, alpha)``. Species indices are one-hot encoded first so
    they become soft labels. Otherwise the inputs are returned unchanged
    (species still one-hot encoded) with ``lam = 1.0``.

    Args:
        g: global-view image batch (first dim is the batch).
        t: wide-tile image batch.
        b: biomass targets.
        r: auxiliary regression targets.
        c: integer species labels, one per sample.
        ns: number of species classes for the one-hot encoding.
        prob: mixup probability; ``None`` reads ``CONFIG['mixup_prob']``.
        alpha: Beta concentration; ``1.0`` makes ``lam`` uniform on [0, 1].

    Returns:
        ``(g, t, b, r, soft_species, lam)``.
    """
    if prob is None:
        prob = CONFIG['mixup_prob']
    c_oh = nn.functional.one_hot(c, ns).float()
    if np.random.rand() > prob:
        return g, t, b, r, c_oh, 1.0

    perm = torch.randperm(g.size(0)).to(g.device)
    lam = np.random.beta(alpha, alpha)

    def blend(x):
        # Same permutation and lambda for every tensor keeps pairs consistent.
        return lam * x + (1 - lam) * x[perm]

    return blend(g), blend(t), blend(b), blend(r), blend(c_oh), lam

def _build_fold_loaders(t_df, v_df, sm):
    """Create the training and validation DataLoaders for one fold."""
    h, w = CONFIG['img_h'], CONFIG['img_w']
    # Training: random horizontal flips as augmentation.
    train_tf = A.Compose([A.Resize(h, w), A.HorizontalFlip(), A.Normalize(), ToTensorV2()])
    train_tile_tf = A.Compose([A.Resize(h, h), A.HorizontalFlip(), A.Normalize(), ToTensorV2()])
    # Validation must be deterministic: random flips here would make the R2
    # used for checkpoint selection vary from run to run.
    val_tf = A.Compose([A.Resize(h, w), A.Normalize(), ToTensorV2()])
    val_tile_tf = A.Compose([A.Resize(h, h), A.Normalize(), ToTensorV2()])

    ld_t = DataLoader(WideTiledDs(t_df, train_tf, train_tile_tf, sm), batch_size=CONFIG['batch_size'],
                      shuffle=True, num_workers=4, drop_last=True)
    ld_v = DataLoader(WideTiledDs(v_df, val_tf, val_tile_tf, sm), batch_size=CONFIG['batch_size'],
                      shuffle=False, num_workers=4)
    return ld_t, ld_v


def _build_fold_model(ns):
    """Instantiate the model, optionally load resume weights, wrap for multi-GPU."""
    m = WideTiledSwin(num_species=ns).cuda()
    resume_path = CONFIG['resume_path']
    if resume_path is not None:
        if os.path.exists(resume_path):
            print(f"Resuming from: {resume_path}")
            sd = torch.load(resume_path, map_location=CONFIG['device'])
            # Strip a DataParallel "module." prefix if the checkpoint has one.
            m.load_state_dict({k.replace('module.', ''): v for k, v in sd.items()})
        else:
            print(f"WARNING: resume_path not found, training from scratch: {resume_path}")
    if torch.cuda.device_count() > 1:
        m = nn.DataParallel(m)
    return m


def _train_one_epoch(m, loader, opt, scaler, crits, ns):
    """Run one mixed-precision training epoch and return the mean batch loss."""
    m.train()
    loss_acc = 0.0
    for g, t, b, r, c in loader:
        g, t, b, r, c = g.cuda(), t.cuda(), b.cuda(), r.cuda(), c.cuda()
        gm, tm, bm, rm, cm, _ = apply_mixup(g, t, b, r, c, ns)
        opt.zero_grad()
        with torch.amp.autocast('cuda'):
            pb, pr, pc = m(gm, tm)
            loss = crits[0](pb, bm) + 0.1 * crits[1](pr, rm) + 0.2 * crits[2](pc, cm)
        scaler.scale(loss).backward()
        # Unscale first so the clipping threshold applies to the true gradients.
        scaler.unscale_(opt)
        nn.utils.clip_grad_norm_(m.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()
        loss_acc += loss.item()
    # Guard against an empty loader (drop_last on a fold smaller than a batch).
    return loss_acc / max(len(loader), 1)


def _evaluate(m, loader):
    """Predict the whole loader and return the competition R2."""
    m.eval()
    preds, trues = [], []
    with torch.no_grad():
        for g, t, b, _, _ in loader:
            pb, _, _ = m(g.cuda(), t.cuda())
            preds.append(pb.cpu().numpy())
            trues.append(b.numpy())
    return competition_metric(np.vstack(trues), np.vstack(preds))


def train_fold(fold, t_df, v_df, sm, ckpt_dir='.'):
    """Train one cross-validation fold and return its best validation R2.

    Args:
        fold: fold index, used in log lines and the checkpoint file name.
        t_df: training rows for this fold.
        v_df: validation rows for this fold.
        sm: species name -> class index mapping.
        ckpt_dir: directory for the best checkpoint (default: current dir,
            i.e. the original save location).

    Returns:
        The highest validation R2 seen across epochs. Side effect: the
        weights of that epoch are saved as
        ``<ckpt_dir>/best_swinv2_widetiled_fold_<fold>.pth``.
    """
    ns = len(sm)
    ld_t, ld_v = _build_fold_loaders(t_df, v_df, sm)
    m = _build_fold_model(ns)

    opt = optim.AdamW(m.parameters(), lr=CONFIG['lr'], weight_decay=0.01)
    scaler = torch.amp.GradScaler('cuda')
    # Huber for biomass targets, MSE for the auxiliary regression, and
    # BCE-with-logits for the (possibly mixup-softened) one-hot species labels.
    crits = [nn.HuberLoss(), nn.MSELoss(), nn.BCEWithLogitsLoss()]
    ckpt_path = os.path.join(ckpt_dir, f"best_swinv2_widetiled_fold_{fold}.pth")

    best_r2 = -float('inf')
    for e in range(CONFIG['epochs']):
        mean_loss = _train_one_epoch(m, ld_t, opt, scaler, crits, ns)
        r2 = _evaluate(m, ld_v)
        print(f'Fold {fold} | Ep {e+1} | R2: {r2:.4f} | Loss: {mean_loss:.4f}')
        if r2 > best_r2:
            best_r2 = r2
            # Save the unwrapped module so keys carry no "module." prefix.
            sd = m.module.state_dict() if hasattr(m, "module") else m.state_dict()
            torch.save(sd, ckpt_path)
    return best_r2

In [None]:
# Reshape the long-format training table (one row per image/target pair) into
# one row per image with the targets as columns, then run grouped CV so all
# images from the same sampling date land in the same fold.
df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))
META_COLUMNS = ['image_path', 'Sampling_Date', 'Species', 'Pre_GSHH_NDVI', 'Height_Ave_cm']
df_wide = (
    df.pivot_table(index=META_COLUMNS, columns='target_name', values='target')
    .reset_index()
    .reindex(columns=META_COLUMNS + TARGET_COLUMNS)
)
# pivot_table silently drops groups whose index values contain NaN.
n_images = df['image_path'].nunique()
if len(df_wide) != n_images:
    print(f"WARNING: {n_images - len(df_wide)} image(s) lost in pivot (NaN in an index column?)")
n_missing = int(df_wide[TARGET_COLUMNS].isna().any(axis=1).sum())
if n_missing:
    print(f"WARNING: {n_missing} image(s) have missing target values")

species_map = {s: i for i, s in enumerate(sorted(df_wide['Species'].unique()))}
gkf = GroupKFold(n_splits=CONFIG['n_splits'])
fold_scores = []
for fold, (t, v) in enumerate(gkf.split(df_wide, groups=df_wide['Sampling_Date'])):
    print(f"\n--- Starting Fold {fold} ---")
    score = train_fold(fold, df_wide.iloc[t], df_wide.iloc[v], species_map)
    print(f"Fold {fold} Best R2 Score: {score:.4f}")
    fold_scores.append(score)
print(f"CV R2: mean {np.mean(fold_scores):.4f} | std {np.std(fold_scores):.4f}")