# CSIRO Image2Biomass - DINOv2 Tiling Inference (Kaggle Offline Ready)

**Improvements:**
1. **Fixed Image Mapping**: Each prediction is matched to its image via the path returned with its batch, fixing the earlier batch-to-image index math.
2. **Unique Image Processing**: Each test image is processed only once, even though `test.csv` lists it once per target (≈5x fewer forward passes).
3. **Memory Safety**: Uses a batch size of 8 so inference fits in T4 GPU memory.
4. **Strict Offline**: No network calls during model creation.

In [1]:
import os, sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

  data = fetch_version_info()


In [2]:
# --- Paths ---------------------------------------------------------------
# Competition data root: test.csv and the image folders it references.
DATA_DIR = "/kaggle/input/csiro-biomass"
# Trained weights; point this at wherever the checkpoint dataset is attached.
CHECKPOINT_PATH = "/kaggle/input/dinov2-mt/pytorch/default/1/models_checkpoints/best_dino_mixup_tiled.pth"

# --- Inference settings --------------------------------------------------
BATCH_SIZE = 8
# Column i of the model output is reported under TARGET_COLUMNS[i].
TARGET_COLUMNS = [
    "Dry_Clover_g",
    "Dry_Dead_g",
    "Dry_Green_g",
    "GDM_g",
    "Dry_Total_g",
]

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


In [3]:
class GlobalLocalDinoHydra(nn.Module):
    """ViT regressor that fuses a global view with mean-pooled tile views.

    A single shared backbone embeds the full image and every tile. Tile
    embeddings are averaged and concatenated with the global embedding. Two
    auxiliary linear heads read the fused features:
    ``meta_reg`` outputs 2 values and ``meta_cls`` outputs species logits.
    The argmax species is embedded (32-d), and the visual features,
    ``meta_reg`` outputs and species embedding are concatenated and fed to
    five independent MLP heads, one output value each.

    Args:
        model_name: timm model identifier for the backbone. It is created with
            ``pretrained=False`` (no download); weights are expected to be
            loaded from a checkpoint afterwards.
        num_species: number of classes for ``meta_cls`` and ``species_emb``.
    """

    def __init__(self, model_name='vit_base_patch14_dinov2.lvd142m', num_species=15):
        super().__init__()
        # num_classes=0 drops timm's classifier so the backbone returns pooled
        # features; dynamic_img_size=True lets it accept inputs whose size
        # differs from the model's default (global view and tiles differ).
        self.backbone = timm.create_model(model_name, pretrained=False, num_classes=0, dynamic_img_size=True)
        embed_dim = self.backbone.num_features
        # Input is global + tile features concatenated, hence embed_dim * 2.
        self.meta_reg = nn.Linear(embed_dim * 2, 2)
        self.meta_cls = nn.Linear(embed_dim * 2, num_species)
        self.species_emb = nn.Embedding(num_species, 32)
        # visual features + 2 meta_reg outputs + 32-d species embedding
        fusion_dim = (embed_dim * 2) + 2 + 32
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Linear(fusion_dim, 512), nn.GELU(), nn.Dropout(0.1), nn.Linear(512, 1))
            for _ in range(5)
        ])

    def forward(self, x_global, x_tiles):
        """Predict one value per head.

        Args:
            x_global: batch of full images, presumably (B, C, H, W).
            x_tiles: batch of tile stacks, shape (B, N, C, H_t, W_t).

        Returns:
            Tensor of shape (B, 5): column i is the output of ``self.heads[i]``.
        """
        feat_global = self.backbone(x_global)
        B, N, C, H, W = x_tiles.shape
        # Fold the tile axis into the batch for one backbone pass. reshape (not
        # view) also works when the incoming tensor is non-contiguous.
        feat_tiles = self.backbone(x_tiles.reshape(B * N, C, H, W))
        feat_tiles = feat_tiles.reshape(B, N, -1).mean(dim=1)
        fused_vis = torch.cat([feat_global, feat_tiles], dim=1)
        p_reg = self.meta_reg(fused_vis)
        p_cls = self.meta_cls(fused_vis)
        # Hard species decision: embed the single most likely class.
        s_emb = self.species_emb(torch.argmax(p_cls, dim=1))
        f_all = torch.cat([fused_vis, p_reg, s_emb], dim=1)
        return torch.cat([head(f_all) for head in self.heads], dim=1)

In [4]:
class InferenceDataset(Dataset):
    """Yield a transformed global view plus left/right half tiles per image.

    Args:
        df: DataFrame with an ``image_path`` column, relative to ``img_dir``.
        img_dir: directory that each ``image_path`` is joined onto.
        tf_global: albumentations-style callable invoked as
            ``tf_global(image=array)``; its ``'image'`` entry is used.
        tf_tile: same kind of callable, applied to each half tile.

    Each item is ``(img_global, img_tiles, image_path)``. ``img_tiles`` stacks
    the transformed left and right halves along a new leading dimension.
    """

    def __init__(self, df, img_dir, tf_global, tf_tile):
        self.df = df
        self.img_dir = img_dir
        self.tf_global = tf_global
        self.tf_tile = tf_tile

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # iloc is positional, so a non-contiguous index (e.g. after
        # drop_duplicates) is fine.
        row = self.df.iloc[idx]
        rel_path = row['image_path']
        img_path = os.path.join(self.img_dir, rel_path)
        # Context manager releases the file handle immediately instead of
        # leaving it open until garbage collection in the loader worker.
        with Image.open(img_path) as img:
            img_np = np.array(img.convert('RGB'))
        width = img_np.shape[1]
        img_global = self.tf_global(image=img_np)['image']
        # Split at the horizontal midpoint; for odd widths the right tile
        # receives the extra column.
        mid_w = width // 2
        halves = (img_np[:, :mid_w], img_np[:, mid_w:])
        img_tiles = torch.stack([self.tf_tile(image=half)['image'] for half in halves])
        return img_global, img_tiles, rel_path

In [5]:
def _strip_dataparallel_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with a leading *prefix* removed from each key.

    Only a leading occurrence is removed, so keys that merely contain the
    prefix text elsewhere are left intact.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_model():
    """Build the model, load CHECKPOINT_PATH when present, and set eval mode."""
    model = GlobalLocalDinoHydra().to(DEVICE)
    if os.path.exists(CHECKPOINT_PATH):
        print(f'Loading weights: {CHECKPOINT_PATH}')
        sd = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
        model.load_state_dict(_strip_dataparallel_prefix(sd))
    else:
        # Non-fatal, as before: inference proceeds with the freshly
        # initialised (untrained) weights, so predictions are meaningless.
        print('CRITICAL: Checkpoint not found!')
    model.eval()
    return model


def _predict(model, loader, n_images):
    """Run *model* over *loader*; return one row dict per (image, target)."""
    rows = []
    with torch.no_grad():
        for batch_idx, (g_imgs, t_imgs, paths) in enumerate(loader):
            preds = model(g_imgs.to(DEVICE), t_imgs.to(DEVICE)).cpu().numpy()
            for img_path, pred in zip(paths, preds):
                for target_name, value in zip(TARGET_COLUMNS, pred):
                    rows.append({
                        'image_path': img_path,
                        'target_name': target_name,
                        # Clamp negatives to 0; max() also maps NaN to 0.0.
                        'target': max(0.0, float(value)),
                    })
            if (batch_idx + 1) % 10 == 0:
                # The last batch may be partial, so never report more than exist.
                done = min((batch_idx + 1) * BATCH_SIZE, n_images)
                print(f'Processed {done} images...')
    return rows


def run_inference():
    """Predict all targets for every test image and write ``submission.csv``.

    ``test.csv`` has one row per (image_path, target_name). The model runs
    once per unique image. Its predictions are then left-joined back onto
    the test rows by (image_path, target_name), and the ``sample_id`` and
    ``target`` columns are written out.
    """
    model = _load_model()

    test_df = pd.read_csv(os.path.join(DATA_DIR, 'test.csv'))
    unique_images = test_df[['image_path']].drop_duplicates()

    # 392 is a multiple of the 14-px ViT patch; the global view is 1:2 (H:W)
    # and each half-width tile is resized to a square.
    tf_g = A.Compose([A.Resize(392, 784), A.Normalize(), ToTensorV2()])
    tf_t = A.Compose([A.Resize(392, 392), A.Normalize(), ToTensorV2()])

    ds = InferenceDataset(unique_images, DATA_DIR, tf_g, tf_t)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Explicit columns keep the merge keys present even with zero predictions.
    pred_df = pd.DataFrame(
        _predict(model, loader, len(ds)),
        columns=['image_path', 'target_name', 'target'],
    )
    submission = test_df.merge(pred_df, on=['image_path', 'target_name'], how='left')
    submission[['sample_id', 'target']].to_csv('submission.csv', index=False)
    print('Submission file saved successfully!')
    print(submission.head())

if __name__ == "__main__":
    # Executing the notebook/script directly runs the whole pipeline:
    # load weights, predict, and write submission.csv.
    run_inference()

Loading weights: /kaggle/input/dinov2-mt/pytorch/default/1/models_checkpoints/best_dino_mixup_tiled.pth
Submission file saved successfully!
                    sample_id             image_path   target_name     target
0  ID1001187975__Dry_Clover_g  test/ID1001187975.jpg  Dry_Clover_g   0.904442
1    ID1001187975__Dry_Dead_g  test/ID1001187975.jpg    Dry_Dead_g  24.447477
2   ID1001187975__Dry_Green_g  test/ID1001187975.jpg   Dry_Green_g  28.540804
3   ID1001187975__Dry_Total_g  test/ID1001187975.jpg   Dry_Total_g  57.467464
4         ID1001187975__GDM_g  test/ID1001187975.jpg         GDM_g  29.460234
