# CSIRO Image2Biomass - Self-Augmented Hydra Inference

This notebook uses the **Self-Augmented Hydra Model**. 
It internally predicts environmental signals (NDVI, Height, Species) to improve biomass estimation without requiring any metadata in the test set.

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# --- Inference configuration -------------------------------------------------
# Root folder of the competition data; `test.csv` is read from here and the
# `image_path` values in that CSV are joined onto this folder.
DATA_DIR = r"d:\personalProject\CSIRO-Image2Biomass_Prediction\csiro-biomass"

# Weights file loaded into the model before inference.
CHECKPOINT_PATH = "../models_checkpoints/best_self_augmented_fold1.pth"

# Resize target applied to every image, passed to the resize transform as
# (IMAGE_SIZE[0], IMAGE_SIZE[1]).
IMAGE_SIZE = (384, 384)

# Target names in the same order as the columns of the model output.
TARGET_COLUMNS = [
    'Dry_Clover_g',
    'Dry_Dead_g',
    'Dry_Green_g',
    'GDM_g',
    'Dry_Total_g',
]

In [None]:
class SelfAugmentedHydraModel(nn.Module):
    """Image-only biomass regressor that feeds its own metadata estimates to its heads.

    Data flow:
      * ``backbone`` maps the image batch to embeddings.
      * ``meta_predictor`` regresses three auxiliary values from those
        embeddings: NDVI, height, and species (encoded as one numeric scalar).
      * Five independent heads each receive the embeddings concatenated with
        the three auxiliary values and emit a single biomass value.

    The forward pass returns the five head outputs concatenated along dim 1 in
    the order clover, dead, green, GDM, total.
    """

    @staticmethod
    def _regression_head(in_features):
        """Build one ``Linear(in_features, 256) -> GELU -> Linear(256, 1)`` head."""
        return nn.Sequential(
            nn.Linear(in_features, 256),
            nn.GELU(),
            nn.Linear(256, 1),
        )

    def __init__(self, model_name='convnext_tiny.in12k_ft_in1k', pretrained=False):
        super().__init__()

        # Feature extractor; num_classes=0 asks timm for a model without a
        # classification layer.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        feature_width = self.backbone.num_features

        # Auxiliary regressor: embeddings -> [NDVI, Height, Species_Index].
        self.meta_predictor = nn.Sequential(
            nn.Linear(feature_width, 128),
            nn.GELU(),
            nn.Linear(128, 3),
        )

        # Every head consumes the embeddings plus the 3 auxiliary predictions.
        # Attribute names and creation order are kept stable so that saved
        # state dicts keep matching.
        head_input_width = feature_width + 3
        self.head_clover = self._regression_head(head_input_width)
        self.head_dead = self._regression_head(head_input_width)
        self.head_green = self._regression_head(head_input_width)
        self.head_gdm = self._regression_head(head_input_width)
        self.head_total = self._regression_head(head_input_width)

    def forward(self, image):
        """Return the five biomass predictions for ``image``, one column per head."""
        features = self.backbone(image)
        meta_estimates = self.meta_predictor(features)
        fused = torch.cat((features, meta_estimates), dim=1)

        # Order defines the output columns: clover, dead, green, gdm, total.
        heads = (
            self.head_clover,
            self.head_dead,
            self.head_green,
            self.head_gdm,
            self.head_total,
        )
        return torch.cat([head(fused) for head in heads], dim=1)

class SimpleDataset(Dataset):
    """Dataset that yields one transformed image per row of a DataFrame.

    Args:
        df: DataFrame with an ``image_path`` column holding paths relative to
            ``img_dir``.
        img_dir: Directory that the ``image_path`` values are joined onto.
        transform: Optional callable invoked as ``transform(image=array)`` and
            expected to return a mapping with an ``'image'`` entry. When it is
            omitted (or falsy) the RGB numpy array is returned as loaded.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image for positional row ``idx`` and apply the transform."""
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image_path'])

        # Close the file handle deterministically instead of leaving it to
        # Pillow's lazy loading / garbage collection. `convert` returns a new
        # image, and the array is built before the file is closed.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transform:
            image = self.transform(image=image)['image']
        return image

In [None]:
def run_inference(batch_size: int = 1) -> None:
    """Predict all biomass targets for the test images and write a submission CSV.

    Steps:
      1. Read ``test.csv`` from ``DATA_DIR`` and keep one row per unique
         ``image_path`` so each image is run through the model once.
      2. Build the model and load weights from ``CHECKPOINT_PATH`` if the file
         exists; otherwise emit a ``RuntimeWarning`` and continue with the
         freshly constructed model.
      3. Predict every image, clamp each value at zero, and store one record
         per (image, target name).
      4. Left-merge the records back onto the original test rows and write
         ``sample_id, target`` to ``submission_self_augmented.csv``.

    Args:
        batch_size: Number of images per forward pass. Defaults to 1, which is
            what this function originally used.
    """
    import warnings

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_df = pd.read_csv(os.path.join(DATA_DIR, "test.csv"))
    unique_test = test_df.drop_duplicates(subset=['image_path']).copy()

    model = SelfAugmentedHydraModel().to(device)
    if os.path.exists(CHECKPOINT_PATH):
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider `weights_only=True` where the
        # installed torch version supports it).
        model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    else:
        # Keep running (best-effort), but do not fail silently: without the
        # checkpoint the submission is produced by an untrained model.
        warnings.warn(
            f"Checkpoint not found at {CHECKPOINT_PATH!r}; "
            "predictions will come from a model without loaded weights.",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()

    transform = A.Compose([
        A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1]),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    dataset = SimpleDataset(unique_test, DATA_DIR, transform)
    # shuffle=False keeps loader order aligned with `unique_test` row order,
    # which is what lets predictions be paired with paths by position below.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    image_paths = unique_test['image_path'].tolist()
    results = []
    offset = 0  # position of the current batch's first image in `image_paths`
    with torch.no_grad():
        for images in loader:
            batch_preds = model(images.to(device)).cpu().numpy()
            batch_paths = image_paths[offset:offset + len(batch_preds)]
            offset += len(batch_preds)
            for img_path, preds in zip(batch_paths, batch_preds):
                for col, value in zip(TARGET_COLUMNS, preds):
                    results.append({
                        'image_path': img_path,
                        'target_name': col,
                        # Clamp negatives to zero; 0.0 keeps the column float.
                        'target': max(0.0, float(value)),
                    })

    # Explicit columns keep the merge below valid even with zero predictions.
    pred_df = pd.DataFrame(results, columns=['image_path', 'target_name', 'target'])
    submission = test_df[['sample_id', 'image_path', 'target_name']].merge(
        pred_df, on=['image_path', 'target_name'], how='left'
    )
    submission = submission[['sample_id', 'target']]
    submission.to_csv("submission_self_augmented.csv", index=False)
    print("Submission saved to submission_self_augmented.csv")

# Run inference only when executed as a script (or as a notebook cell, where
# __name__ is also "__main__"), not when this module is imported.
if __name__ == "__main__":
    run_inference()