# CSIRO SigLIP Inference
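Predict the five biomass targets for each unique test image and write them to `submission.csv`. Note: although the intent was to average predictions across the original high-res tiles, the code below resizes each full image to 512×512 and makes a single prediction per image — no tiling or tile-averaging is performed.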

In [4]:
import os
import pandas as pd
import numpy as np
import torch, timm
import albumentations as A
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
from PIL import Image
import torch.nn as nn

# Competition data root: test.csv lives here, and its image_path values are joined onto it.
DATA_DIR = '/kaggle/input/csiro-biomass'

# Model checkpoint to load (the filename suggests the fold-4 best model).
# If this path does not exist, inference proceeds with randomly initialised weights.
WEIGHT_PATH = '/kaggle/input/siglip-512/pytorch/default/1/models_checkpoints/siglip_best_fold4.pth'

# Target names in model-output order: output column j is labelled TARGET_COLS[j].
# NOTE(review): this order must match the head order used during training.
TARGET_COLS = [
    'Dry_Clover_g',
    'Dry_Dead_g',
    'Dry_Green_g',
    'GDM_g',
    'Dry_Total_g',
]

class SigLIPBiomassModel(nn.Module):
    """SigLIP ViT backbone with auxiliary metadata outputs and five biomass heads.

    Pooled backbone features feed two auxiliary linear layers. One regresses
    two metadata values (predicted NDVI/Height); the other scores
    ``num_species`` species classes. The arg-max species is embedded and
    concatenated with the features and the metadata regression. The result
    goes to five independent MLP heads, each producing one value (log-scale
    biomass predictions, per the original design notes).

    The attribute names ``backbone``, ``meta_reg``, ``meta_cls``,
    ``species_emb`` and ``heads`` define the state-dict keys, so they must
    stay in sync with saved checkpoints.
    """

    def __init__(self, model_name='vit_base_patch16_siglip_gap_512.webli', num_species=15):
        super().__init__()
        # num_classes=0 requests a headless timm model that returns pooled features.
        self.backbone = timm.create_model(model_name, pretrained=False, num_classes=0)
        feat_dim = self.backbone.num_features
        meta_dim, emb_dim, hidden_dim, n_heads = 2, 32, 256, 5

        # Auxiliary tasks: metadata regression (NDVI/Height) and species classification.
        self.meta_reg = nn.Linear(feat_dim, meta_dim)
        self.meta_cls = nn.Linear(feat_dim, num_species)
        self.species_emb = nn.Embedding(num_species, emb_dim)

        # Each head sees: visual features + metadata regression + species embedding.
        in_dim = feat_dim + meta_dim + emb_dim
        head_list = []
        for _ in range(n_heads):
            # Sequential positions 0/1/2 appear in checkpoint keys (heads.<i>.0 / heads.<i>.2).
            head_list.append(
                nn.Sequential(
                    nn.Linear(in_dim, hidden_dim),
                    nn.GELU(),
                    nn.Linear(hidden_dim, 1),
                )
            )
        self.heads = nn.ModuleList(head_list)

    def forward(self, x):
        """Return one column per head, concatenated along dim 1 (five columns)."""
        features = self.backbone(x)
        meta_pred = self.meta_reg(features)
        # The hard arg-max species picks the embedding row; this step is not differentiable.
        species_idx = self.meta_cls(features).argmax(dim=1)
        fused = torch.cat((features, meta_pred, self.species_emb(species_idx)), dim=1)
        return torch.cat([head(fused) for head in self.heads], dim=1)

class SimpleDataset(Dataset):
    """Map-style dataset yielding one (optionally transformed) image per DataFrame row.

    Args:
        df: DataFrame with an ``image_path`` column, relative to ``img_dir``.
        img_dir: directory that each ``image_path`` value is joined onto.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=array)["image"]``. When ``None``, the raw
            ``(H, W, 3)`` RGB ``np.ndarray`` is returned.

    Only the image is returned (no index or label). Callers map outputs back
    to rows by relying on the loader preserving row order (``shuffle=False``).
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image_path'])
        # Use a context manager so the file handle is released deterministically.
        # Pillow may keep the file open after load() (e.g. multi-frame images).
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        # Compare against None explicitly: a Compose object's truthiness depends on its length.
        if self.transform is not None:
            img = self.transform(image=img)["image"]
        return img

def run_inference():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_df = pd.read_csv(os.path.join(DATA_DIR, "test.csv"))
    unique_test = test_df.drop_duplicates(subset=['image_path']).copy()
    
    model = SigLIPBiomassModel().to(device)
    if os.path.exists(WEIGHT_PATH):
        print(f"Loading weights from: {WEIGHT_PATH}")
        state_dict = torch.load(WEIGHT_PATH, map_location=device)
        # Handle DataParallel prefix
        state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
    else:
        print("WARNING: Checkpoint path not found. Running with random weights!")
        
    model.eval()
    
    transform = A.Compose([
        A.Resize(512, 512),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), #mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ToTensorV2(),
    ])
    
    dataset = SimpleDataset(unique_test, DATA_DIR, transform)
    loader = DataLoader(dataset, batch_size=8, shuffle=False)
    
    results = []
    with torch.no_grad():
        for i, image in enumerate(loader):
            image = image.to(device).float()
            # Change 2: Predict for the whole batch
            batch_preds = torch.expm1(model(image)).cpu().numpy() 
            
            # Change 3: Iterate THROUGH the batch
            for b in range(batch_preds.shape[0]):
                # Calculate the exact row in unique_test
                global_idx = i * loader.batch_size + b
                img_path = unique_test.iloc[global_idx]['image_path']
                
                preds = batch_preds[b]
                for j, col in enumerate(TARGET_COLS):
                    results.append({
                        'image_path': img_path,
                        'target_name': col,
                        'target': max(0.0, float(preds[j]))
                    })
    
    pred_df = pd.DataFrame(results)
    submission = test_df[['sample_id', 'image_path', 'target_name']].merge(pred_df, on=['image_path', 'target_name'], how='left')
    submission = submission[['sample_id', 'target']]
    submission.to_csv("submission.csv", index=False)
    print("Preview of submission.csv:")
    print(submission.head())
    print("Submission saved to submission.csv")

# Script/notebook entry point. In a notebook cell __name__ is also "__main__",
# so inference runs when the cell executes.
if __name__ == "__main__":
    run_inference()

Loading weights from: /kaggle/input/siglip-512/pytorch/default/1/models_checkpoints/siglip_best_fold4.pth
Preview of submission.csv:
                    sample_id    target
0  ID1001187975__Dry_Clover_g  0.000000
1    ID1001187975__Dry_Dead_g  3.240776
2   ID1001187975__Dry_Green_g  3.473956
3   ID1001187975__Dry_Total_g  4.024693
4         ID1001187975__GDM_g  3.590855
Submission saved to submission.csv
