# CSIRO Image2Biomass - Advanced Hydra Inference (STRICT OFFLINE)

This notebook runs inference with the `AdvancedHydraModel`. It is configured for **Internet Off** environments: the backbone is created with `pretrained=False` and all weights are loaded from a local checkpoint file.

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ------------------------------------------
# Configuration
# ------------------------------------------
# Root folder of the competition data (contains test.csv and the images).
DATA_DIR = "/kaggle/input/csiro-biomass"

# Local checkpoint holding the trained weights.
# NOTE(review): the dataset folder says "convenextbase" while the file name and
# MODEL_NAME say "small" - confirm this is the intended checkpoint.
CHECKPOINT_PATH = (
    "/kaggle/input/advanced-convenextbase/pytorch/default/1/"
    "best_advanced_ConvNextSmall_fold1.pth"
)

# Bare timm architecture name (no pretrained tag); the model is always built
# with pretrained=False so no weights are fetched from the network.
MODEL_NAME = "convnext_small"

# (height, width) passed to the resize transform.
IMAGE_SIZE = (224, 448)

# Output columns, in the order the model's heads are read at inference time.
TARGET_COLUMNS = [
    'Dry_Clover_g',
    'Dry_Dead_g',
    'Dry_Green_g',
    'GDM_g',
    'Dry_Total_g',
]

# Number of species classes used by the classifier and embedding table.
NUM_SPECIES = 15


In [None]:
class AdvancedHydraModel(nn.Module):
    """Multi-head regressor on top of a timm backbone.

    The backbone features feed two auxiliary branches (a 2-value regressor and
    a species classifier). The auxiliary regression output and the embedding
    of the most likely species are concatenated with the backbone features and
    passed to one small MLP head per target.

    Args:
        model_name: timm architecture name for the backbone.
        num_species: number of classes for the species classifier and the
            size of the species embedding table.
        num_targets: number of regression heads (one output column each).
            Defaults to the number of entries in ``TARGET_COLUMNS`` so the
            output width always matches the columns read at inference time.
    """

    def __init__(self, model_name=MODEL_NAME, num_species=NUM_SPECIES,
                 num_targets=len(TARGET_COLUMNS)):
        super().__init__()
        # pretrained=False keeps model creation offline: weights are expected
        # to come from a checkpoint loaded by the caller.
        # num_classes=0 asks timm for the feature extractor without its
        # classification layer.
        self.backbone = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=0,
        )
        embed_dim = self.backbone.num_features

        meta_dim = 2             # width of the auxiliary regression output
        species_embed_dim = 32   # width of the learned species embedding

        # NOTE: attribute names and the Linear/GELU/Linear layout determine
        # the state_dict keys; keep them stable so checkpoints still load.
        self.meta_regressor = nn.Sequential(
            nn.Linear(embed_dim, 128), nn.GELU(), nn.Linear(128, meta_dim)
        )
        self.species_classifier = nn.Sequential(
            nn.Linear(embed_dim, 128), nn.GELU(), nn.Linear(128, num_species)
        )
        self.species_embedding = nn.Embedding(num_species, species_embed_dim)

        fusion_dim = embed_dim + meta_dim + species_embed_dim
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Linear(fusion_dim, 256), nn.GELU(), nn.Linear(256, 1))
            for _ in range(num_targets)
        ])

    def forward(self, image):
        """Return the per-target predictions for a batch of images.

        Args:
            image: batch tensor accepted by the backbone
                (presumably (batch, 3, H, W) float - confirm against caller).

        Returns:
            Tensor with one column per head, concatenated along dim 1.
        """
        vis_features = self.backbone(image)
        pred_reg = self.meta_regressor(vis_features)
        pred_species_logits = self.species_classifier(vis_features)
        # Hard selection: only the top-scoring species drives the embedding.
        best_species = torch.argmax(pred_species_logits, dim=1)
        spec_feat = self.species_embedding(best_species)

        combined_input = torch.cat([vis_features, pred_reg, spec_feat], dim=1)
        return torch.cat([head(combined_input) for head in self.heads], dim=1)

In [None]:
class SimpleDataset(Dataset):
    """Image-only dataset that reads files listed in a dataframe.

    Args:
        df: table with an ``image_path`` column; rows are accessed by position.
        img_dir: directory that ``image_path`` values are joined onto.
        transform: optional callable invoked as ``transform(image=array)`` and
            expected to return a mapping with an ``"image"`` entry.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` as RGB and apply the transform."""
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image_path'])
        # Use a context manager so the underlying file handle is released
        # deterministically instead of depending on garbage collection.
        with Image.open(img_path) as handle:
            img = np.array(handle.convert("RGB"))
        if self.transform:
            img = self.transform(image=img)["image"]
        return img

In [None]:
def _strip_dataparallel_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading 'module.' removed from keys."""
    prefix = 'module.'
    # Only strip at the start of the key: a blanket str.replace would also
    # mangle any parameter path that merely contains "...module." inside it.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def run_inference():
    """Predict every target for each unique test image and write submission.csv.

    Reads ``test.csv`` from ``DATA_DIR``, loads the weights at
    ``CHECKPOINT_PATH`` into an ``AdvancedHydraModel``, runs the model on each
    distinct ``image_path`` and writes ``submission.csv`` (columns
    ``sample_id`` and ``target``) to the current working directory.

    Returns:
        None. If the checkpoint file is missing, an error is printed and the
        function returns without writing a submission.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_df = pd.read_csv(os.path.join(DATA_DIR, "test.csv"))
    # test.csv has one row per (image, target); each image is predicted once.
    unique_test = test_df.drop_duplicates(subset=['image_path']).copy()

    # INSTANTIATION
    model = AdvancedHydraModel().to(device)

    if not os.path.exists(CHECKPOINT_PATH):
        print(f"CRITICAL ERROR: Weight file not found at {CHECKPOINT_PATH}")
        return

    print(f"Loading weights from: {CHECKPOINT_PATH}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(_strip_dataparallel_prefix(state_dict))
    print("Weights loaded successfully.")

    model.eval()
    transform = A.Compose([
        A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1]),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    dataset = SimpleDataset(unique_test, DATA_DIR, transform)
    # batch_size=1 and shuffle=False keep batches aligned one-to-one with the
    # rows of unique_test, which the path lookup below relies on.
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    results = []
    with torch.no_grad():
        for img_path, image in zip(unique_test['image_path'], loader):
            # Cast to float32 so the input dtype matches the model weights.
            image = image.to(device).float()
            preds = model(image).cpu().numpy()[0]
            for j, col in enumerate(TARGET_COLUMNS):
                results.append({
                    'image_path': img_path,
                    'target_name': col,
                    # Predictions are clamped at zero before being written.
                    'target': max(0.0, float(preds[j])),
                })

    # Explicit columns keep the merge keys present even when there are no rows.
    pred_df = pd.DataFrame(results, columns=['image_path', 'target_name', 'target'])
    submission = test_df[['sample_id', 'image_path', 'target_name']].merge(
        pred_df, on=['image_path', 'target_name'], how='left'
    )
    submission = submission[['sample_id', 'target']]
    submission.to_csv("submission.csv", index=False)
    print("Submission saved to submission.csv")

# Script entry point: run the inference pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    run_inference()