# CSIRO Swin-V2 Wide-Tiled Ensemble Inference

# In [1]:
import os
import pandas as pd
import numpy as np
import torch, timm
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
from PIL import Image
import torch.nn as nn

# Root of the competition data: holds test.csv and the images it references.
DATA_DIR = '/kaggle/input/csiro-biomass'
# One checkpoint per CV fold. Checkpoints that are absent on disk are skipped at inference time.
PATHS = [
    '/kaggle/input/swinv2-base-tiled/pytorch/default/2/best_swinv2_widetiled_fold_0.pth',
    '/kaggle/input/swinv2-base-tiled/pytorch/default/2/best_swinv2_widetiled_fold_1.pth',
    '/kaggle/input/swinv2-base-tiled/pytorch/default/2/best_swinv2_widetiled_fold_2.pth',
    '/kaggle/input/swinv2-base-tiled/pytorch/default/2/best_swinv2_widetiled_fold_3.pth'
]
# One blend weight per checkpoint in PATHS, listed in the same order.
ENSEMBLE_WEIGHTS = [0.15, 0.4, 0.05, 0.4]
# Column i of the model output is reported under TARGET_COLUMNS[i].
TARGET_COLUMNS = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'GDM_g', 'Dry_Total_g']
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# In [None]:
class WideTiledSwin(nn.Module):
    """Two-view Swin-V2 regressor sharing one backbone between a global and a tiled view.

    Both inputs pass through the same timm backbone, built for a 384x768 input
    with no classifier head. The two feature vectors are concatenated into a
    "visual" vector that feeds:

    * ``meta_reg``: a linear layer with 2 outputs, appended to the fusion vector;
    * ``meta_cls``: a linear layer scoring ``num_species`` classes, whose argmax
      class indexes a 32-dim ``species_emb`` embedding.

    The fusion vector (visual features, ``meta_reg`` output and species
    embedding) is fed to five independent MLP heads with one output each.

    Args:
        model_name: timm identifier of the backbone architecture.
        num_species: number of classes for ``meta_cls`` / ``species_emb``.
    """

    def __init__(self, model_name='swinv2_base_window12_192.ms_in22k', num_species=15):
        super().__init__()
        # Weights come from the fold checkpoints, so nothing pretrained is fetched.
        # The backbone output width is num_features, as relied on by the layers below.
        self.backbone = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=0,
            img_size=(384, 768),
        )
        feat_dim = self.backbone.num_features
        vis_dim = 2 * feat_dim  # global-view and tile-view features side by side

        self.meta_reg = nn.Linear(vis_dim, 2)
        self.meta_cls = nn.Linear(vis_dim, num_species)
        self.species_emb = nn.Embedding(num_species, 32)

        # visual features + 2 meta_reg outputs + 32-dim species embedding
        fusion_dim = vis_dim + 2 + 32
        head_list = []
        for _ in range(5):
            head_list.append(
                nn.Sequential(nn.Linear(fusion_dim, 512), nn.GELU(), nn.Linear(512, 1))
            )
        self.heads = nn.ModuleList(head_list)

    def forward(self, x_g, x_wide_tule):
        """Predict the five targets for a batch.

        Args:
            x_g: global resized view, (B, 3, 384, 768).
            x_wide_tule: two high-res tiles concatenated horizontally, (B, 3, 384, 768).

        Returns:
            Tensor of shape (B, 5): one column per head, in head order.
        """
        global_feat = self.backbone(x_g)
        tile_feat = self.backbone(x_wide_tule)
        visual = torch.cat([global_feat, tile_feat], dim=1)

        reg_out = self.meta_reg(visual)
        species_idx = torch.argmax(self.meta_cls(visual), dim=1)
        species_vec = self.species_emb(species_idx)

        fused = torch.cat([visual, reg_out, species_vec], dim=1)
        per_head = [head(fused) for head in self.heads]
        return torch.cat(per_head, dim=1)

# In [None]:
def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from each key.

    That prefix is typically added when a model is saved while wrapped in
    ``nn.DataParallel`` / DDP. Only a *leading* prefix is stripped, so keys that
    merely contain ``'module.'`` (e.g. ``'x.submodule.weight'``) are left intact.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def run_inference():
    """Blend the fold checkpoints on the test images and write ``submission.csv``.

    Every checkpoint in ``PATHS`` that exists on disk is loaded into a
    ``WideTiledSwin``; missing ones are skipped with a message. Each unique test
    image is fed to the models as two inputs:

    * a 384x768 global view;
    * its left and right halves, each resized to 384x384 and joined side by side.

    Per-model outputs are averaged using the ``ENSEMBLE_WEIGHTS`` entry that
    belongs to the same checkpoint, renormalised over the checkpoints actually
    loaded. Negative predictions are clipped to 0. The five output columns are
    named by ``TARGET_COLUMNS`` and joined back onto ``test.csv`` by
    ``image_path`` / ``target_name``.

    Returns:
        None. Prints ``'No models found!'`` and returns early when no checkpoint exists.

    Raises:
        ValueError: if ``ENSEMBLE_WEIGHTS`` has fewer entries than ``PATHS``.
    """
    if len(ENSEMBLE_WEIGHTS) < len(PATHS):
        raise ValueError(
            f'ENSEMBLE_WEIGHTS has {len(ENSEMBLE_WEIGHTS)} entries but PATHS has {len(PATHS)}'
        )

    # Collect each model together with its own weight. Slicing ENSEMBLE_WEIGHTS by
    # len(models) would shift weights onto the wrong folds when a checkpoint is missing.
    models, weights = [], []
    for path, weight in zip(PATHS, ENSEMBLE_WEIGHTS):
        if not os.path.exists(path):
            print(f'Skipped (not found): {path}')
            continue
        model = WideTiledSwin().to(DEVICE).eval()
        # NOTE(review): torch.load unpickles the file; only point PATHS at trusted checkpoints.
        state_dict = torch.load(path, map_location=DEVICE)
        model.load_state_dict(_strip_module_prefix(state_dict))
        models.append(model)
        weights.append(weight)
        print(f'Loaded: {path}')

    if not models:
        print('No models found!')
        return None

    # Normalise once, over the checkpoints that actually loaded.
    blend = np.asarray(weights, dtype=np.float64)
    blend = blend / blend.sum()

    test = pd.read_csv(f'{DATA_DIR}/test.csv')
    image_paths = test['image_path'].drop_duplicates().tolist()
    global_tf = A.Compose([A.Resize(384, 768), A.Normalize(), ToTensorV2()])
    # Each half-width tile is resized 1:1 to 384x384. Built once, not per sample.
    tile_tf = A.Compose([A.Resize(384, 384), A.Normalize(), ToTensorV2()])

    class InfDs(Dataset):
        """Yield (global view, side-by-side tiles, relative image path) per unique image."""

        def __len__(self):
            return len(image_paths)

        def __getitem__(self, i):
            rel_path = image_paths[i]
            # Context manager closes the file handle after decoding.
            with Image.open(f'{DATA_DIR}/{rel_path}') as im:
                img = np.array(im.convert('RGB'))
            mid = img.shape[1] // 2
            g = global_tf(image=img)['image']
            left = tile_tf(image=img[:, :mid])['image']
            right = tile_tf(image=img[:, mid:])['image']
            # Tiles are concatenated horizontally (dim=2) into the model's wide input.
            t = torch.cat([left, right], dim=2)
            return g, t, rel_path

    loader = DataLoader(InfDs(), batch_size=8)
    rows = []
    with torch.no_grad():
        for g, t, batch_paths in loader:
            g, t = g.to(DEVICE), t.to(DEVICE)
            # (n_models, B, 5) -> weighted sum over the model axis -> (B, 5)
            preds = np.stack([m(g, t).cpu().numpy() for m in models])
            blended = np.tensordot(blend, preds, axes=1)

            for path, row in zip(batch_paths, blended):
                for col, value in zip(TARGET_COLUMNS, row):
                    rows.append({'image_path': path, 'target_name': col,
                                 'target': max(0.0, float(value))})

    # Explicit columns keep the merge valid even if no image produced a row.
    out = pd.DataFrame(rows, columns=['image_path', 'target_name', 'target'])
    sub = test[['sample_id', 'image_path', 'target_name']].merge(
        out, on=['image_path', 'target_name'], how='left'
    )
    sub[['sample_id', 'target']].to_csv('submission.csv', index=False)
    print('Submission saved.')

# Run the whole pipeline when this cell executes: load the checkpoints, predict, and write submission.csv.
run_inference()