# CSIRO DINOv2 Tiling + Fold Ensemble Inference

In [None]:
import os

import albumentations as A
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# Competition data root; test.csv and the images named in its image_path column live here.
DATA_DIR = '/kaggle/input/csiro-biomass'

# Fold checkpoints to ensemble. Paths that do not exist are skipped at load time.
PATHS = [
    '/kaggle/input/your-dino-weights/dino_fold0.pth',
    '/kaggle/input/your-dino-weights/dino_fold1.pth',
]

# Relative ensemble weights, one per model; they are normalised to sum to 1
# before averaging. Tweak weights here! Must match number of models found.
ENSEMBLE_WEIGHTS = [0.6, 0.4]

# target_name values, in the same order as the model's output columns.
TARGET_COLUMNS = [
    'Dry_Clover_g',
    'Dry_Dead_g',
    'Dry_Green_g',
    'Dry_Total_g',
    'GDM_g',
]

In [None]:
class GlobalLocalDinoHydra(nn.Module):
    """Multi-target regressor that fuses a global image view with averaged tile views.

    A single timm backbone embeds the global image and every tile. The tile
    embeddings are mean-pooled and concatenated with the global embedding.
    Two auxiliary linear heads run on these fused features:
    - a 2-dim regression head;
    - a species classifier.
    The arg-max species is looked up in an embedding table. The fused
    features, the 2-dim regression output and the species embedding are then
    fed to one small MLP per target.

    Attribute names and layer shapes must match the training checkpoint,
    because weights are restored with ``load_state_dict``.
    """

    def __init__(self, model_name='vit_base_patch14_dinov2.lvd142m', num_species=15, num_targets=5):
        """Build the backbone and heads.

        Args:
            model_name: timm model identifier for the backbone.
            num_species: number of species classes (classifier outputs and
                embedding rows).
            num_targets: number of per-target regression heads (one output
                column each). Defaults to 5, as before.
        """
        super().__init__()
        # pretrained=False: weights are presumably restored from a checkpoint
        # after construction. num_classes=0 drops timm's classifier so the
        # backbone returns one feature vector per image. dynamic_img_size lets
        # it accept inputs whose size differs from the model default.
        self.backbone = timm.create_model(model_name, pretrained=False, num_classes=0, dynamic_img_size=True)
        d = self.backbone.num_features
        fused_dim = d * 2  # global embedding + mean tile embedding
        self.meta_reg = nn.Linear(fused_dim, 2)
        self.meta_cls = nn.Linear(fused_dim, num_species)
        self.species_emb = nn.Embedding(num_species, 32)
        self.heads = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(fused_dim + 2 + 32, 512), nn.GELU(), nn.Linear(512, 1))
                for _ in range(num_targets)
            ]
        )

    def forward(self, x_g, x_t):
        """Predict every target for a batch.

        Args:
            x_g: global-view images, expected as (B, C, H, W).
            x_t: tiles for each image, shape (B, N, C, H_t, W_t).

        Returns:
            Tensor of shape (B, num_targets), one column per head.
        """
        fg = self.backbone(x_g)
        B, N, C, H, W = x_t.shape
        # Use reshape rather than view so non-contiguous tile tensors (e.g.
        # permuted or sliced inputs) work. For contiguous input it is a no-copy view.
        ft = self.backbone(x_t.reshape(B * N, C, H, W)).reshape(B, N, -1).mean(1)
        v = torch.cat([fg, ft], dim=1)
        pr, pc = self.meta_reg(v), self.meta_cls(v)
        # Hard species decision: the arg-max class selects the embedding row.
        se = self.species_emb(pc.argmax(1))
        fus = torch.cat([v, pr, se], dim=1)
        return torch.cat([h(fus) for h in self.heads], dim=1)

In [None]:
def _load_models():
    """Load every existing checkpoint in PATHS together with its ensemble weight.

    A checkpoint's weight is taken from ENSEMBLE_WEIGHTS at the same index the
    checkpoint has in PATHS, not from its position among the checkpoints that
    happened to load. A missing fold therefore cannot shift the remaining
    weights onto the wrong models.

    Returns:
        (models, weights): a list of eval-mode GlobalLocalDinoHydra models on
        CUDA, and a float64 numpy array with the raw (unnormalised) weight of
        each model, in the same order.

    Raises:
        ValueError: if an existing checkpoint has no ENSEMBLE_WEIGHTS entry.
    """
    models, weights = [], []
    for idx, path in enumerate(PATHS):
        if not os.path.exists(path):
            continue
        if idx >= len(ENSEMBLE_WEIGHTS):
            raise ValueError(
                f'No ENSEMBLE_WEIGHTS entry for {path} (index {idx}); '
                f'got {len(ENSEMBLE_WEIGHTS)} weights for {len(PATHS)} paths.'
            )
        model = GlobalLocalDinoHydra().cuda().eval()
        state = torch.load(path, map_location='cuda')
        # Keys may carry a leading 'module.' (presumably saved from a
        # DataParallel/DDP-wrapped model). Strip only that prefix, not
        # every occurrence inside a key.
        prefix = 'module.'
        state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}
        model.load_state_dict(state)
        models.append(model)
        weights.append(ENSEMBLE_WEIGHTS[idx])
        print(f'Loaded: {path}')
    return models, np.asarray(weights, dtype=np.float64)


def run():
    """Predict every test image with the weighted fold ensemble and write submission.csv.

    Steps:
    - Read ``{DATA_DIR}/test.csv``.
    - Run each unique ``image_path`` once through every loaded model. Each
      model sees the image resized to 392x784 plus its left and right halves
      each resized to 392x392.
    - Average the outputs with the normalised ensemble weights and clip at 0.
    - Write one ``target`` per (``image_path``, ``target_name``) row, keyed by
      ``sample_id``.

    If no checkpoint exists, print a message and return without writing anything.

    Raises:
        ValueError: if a loaded checkpoint has no weight, or the loaded
            models' weights sum to zero.
    """
    models, raw_weights = _load_models()
    if not models:
        print('No models found!')
        return

    weight_sum = raw_weights.sum()
    if weight_sum == 0:
        raise ValueError('ENSEMBLE_WEIGHTS of the loaded models sum to zero; cannot normalise.')
    # Normalise once, outside the batch loop.
    w = raw_weights / weight_sum

    test = pd.read_csv(f'{DATA_DIR}/test.csv')
    uni = test[['image_path']].drop_duplicates()
    tf_g = A.Compose([A.Resize(392, 784), A.Normalize(), ToTensorV2()])
    tf_t = A.Compose([A.Resize(392, 392), A.Normalize(), ToTensorV2()])

    class InfDs(Dataset):
        """Yields (global view, stacked left/right tiles, image_path) for each unique image."""

        def __len__(self):
            return len(uni)

        def __getitem__(self, i):
            p = uni.iloc[i]['image_path']
            img = np.array(Image.open(f'{DATA_DIR}/{p}').convert('RGB'))
            mid = img.shape[1] // 2  # split the image width into two halves
            g = tf_g(image=img)['image']
            t = torch.stack([tf_t(image=img[:, :mid])['image'], tf_t(image=img[:, mid:])['image']])
            return g, t, p

    ld = DataLoader(InfDs(), batch_size=8)
    res = []
    with torch.no_grad():
        for g, t, ps in ld:
            # Move the batch to the GPU once and reuse it for every model.
            g, t = g.cuda(), t.cuda()
            preds = np.stack([m(g, t).cpu().numpy() for m in models])  # (models, batch, targets)
            # Weighted sum over the model axis -> (batch, targets).
            avg = np.tensordot(w, preds, axes=1)

            for b, path in enumerate(ps):
                for i, col in enumerate(TARGET_COLUMNS):
                    res.append({'image_path': path, 'target_name': col, 'target': max(0.0, float(avg[b, i]))})

    # Explicit columns keep the merge valid even when no predictions were produced.
    out = pd.DataFrame(res, columns=['image_path', 'target_name', 'target'])
    sub = test[['sample_id', 'image_path', 'target_name']].merge(out, on=['image_path', 'target_name'], how='left')
    sub[['sample_id', 'target']].to_csv('submission.csv', index=False)
    print('Submission saved with weighted ensemble.')


if __name__ == '__main__':
    run()