# CSIRO Image2Biomass - Contextual Model Inference Notebook

This notebook performs inference using the **Environmental Context Model**.
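It encodes the sampling date cyclically (sine and cosine of the day of year) and feeds the tabular metadata, including one-hot encoded categories, through a separate environment-encoder branch that is fused with the image features.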

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import joblib

In [None]:
# Input/output locations. Each can be overridden with an environment variable
# so the notebook runs outside the original author's machine; the defaults are
# the original hard-coded values.
DATA_DIR = os.environ.get(
    "CSIRO_DATA_DIR",
    r"d:\personalProject\CSIRO-Image2Biomass_Prediction\csiro-biomass",
)
# The relative defaults below resolve against the current working directory.
CHECKPOINT_PATH = os.environ.get(
    "CSIRO_CHECKPOINT_PATH",
    "../models_checkpoints/best_local_model_contextual.pth",
)
META_INFO_PATH = os.environ.get(
    "CSIRO_META_INFO_PATH",
    "../models_checkpoints/metadata_info_local_contextual.pkl",
)
# Prediction targets. Model outputs are matched to these names by position,
# so the order must match the order used at training time.
TARGET_COLUMNS = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'GDM_g', 'Dry_Total_g']
IMAGE_SIZE = (224, 448)  # (height, width) passed to the resize transform

In [None]:
class BiomassDataset(Dataset):
    """Pair each image listed in ``df`` with its tabular features (and targets).

    Args:
        df: One row per sample. Must contain an ``image_path`` column (joined
            onto ``img_dir``) and every column in ``tabular_columns``; when
            ``is_test`` is False it must also contain the target columns.
        img_dir: Directory that ``image_path`` values are relative to.
        tabular_columns: Feature columns returned as the ``tabular`` tensor.
        transform: Optional callable invoked as ``transform(image=array)``
            whose result's ``'image'`` entry replaces the image.
        is_test: If True, items are ``(image, tabular)``; otherwise
            ``(image, tabular, targets)``.
        target_columns: Target columns used when ``is_test`` is False.
            Defaults to the module-level ``TARGET_COLUMNS``.
    """

    def __init__(self, df, img_dir, tabular_columns, transform=None, is_test=False,
                 target_columns=None):
        self.df = df
        self.img_dir = img_dir
        self.tabular_columns = tabular_columns
        self.transform = transform
        self.is_test = is_test
        # Convert once up front. Selecting columns from ``df.iloc[idx]`` on a
        # mixed-dtype frame builds an object Series on every item access.
        self._image_paths = df['image_path'].tolist()
        self._tabular = df[tabular_columns].to_numpy(dtype=np.float32)
        if is_test:
            self._targets = None
        else:
            cols = TARGET_COLUMNS if target_columns is None else target_columns
            self._targets = df[list(cols)].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self._image_paths)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self._image_paths[idx])
        # Use a context manager so the underlying file handle is released
        # deterministically rather than whenever the Image is collected.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transform:
            image = self.transform(image=image)['image']

        # torch.tensor copies, so callers never alias the cached arrays.
        tabular = torch.tensor(self._tabular[idx], dtype=torch.float32)

        if self.is_test:
            return image, tabular

        targets = torch.tensor(self._targets[idx], dtype=torch.float32)
        return image, tabular, targets

def get_val_transform(img_h, img_w):
    """Build the deterministic evaluation pipeline.

    Resizes to ``(img_h, img_w)``, normalizes with the standard ImageNet channel
    statistics, and converts the array to a torch tensor.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        A.Resize(img_h, img_w),
        A.Normalize(mean=channel_mean, std=channel_std),
        ToTensorV2(),
    ]
    return A.Compose(pipeline)

In [None]:
class EnvironmentalContextModel(nn.Module):
    """Late-fusion regressor over an image and a tabular metadata vector.

    A timm backbone (created with ``num_classes=0``; its output width is read
    from ``backbone.num_features``) encodes the image. A small MLP
    (``env_encoder``) maps the tabular vector to 64 features. The two feature
    vectors are concatenated and passed through ``fusion_head``, which emits 5
    outputs.

    NOTE: attribute names and the layer order inside each ``nn.Sequential``
    define the ``state_dict`` keys, so they must stay in sync with any
    checkpoint loaded into this model.
    """
    def __init__(self, model_name='convnext_nano.in12k_ft_in1k', pretrained=False, tabular_dim=22):
        """Build the backbone, the environment encoder and the fusion head.

        Args:
            model_name: timm model identifier for the image backbone.
            pretrained: Forwarded to ``timm.create_model``.
            tabular_dim: Length of the tabular feature vector; must match the
                feature count used when the checkpoint was trained.
        """
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        vis_dim = self.backbone.num_features
        
        # Tabular branch: tabular_dim -> 128 -> 128 -> 64.
        self.env_encoder = nn.Sequential(
            nn.Linear(tabular_dim, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(128, 128),
            nn.GELU(),
            nn.Linear(128, 64)
        )
        
        # Fusion head over [image features | 64 tabular features] -> 5 outputs.
        self.fusion_head = nn.Sequential(
            nn.Linear(vis_dim + 64, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Linear(256, 5)
        )
        
    def forward(self, image, tabular):
        """Return raw (unclipped) predictions of shape ``(batch, 5)``.

        Args:
            image: Image batch accepted by the backbone (presumably
                ``(batch, 3, H, W)`` — TODO confirm against the transform).
            tabular: ``(batch, tabular_dim)`` float tensor.
        """
        vis_feats = self.backbone(image)
        ctx_feats = self.env_encoder(tabular)
        # Concatenate along the feature axis before the fusion head.
        combined = torch.cat([vis_feats, ctx_feats], dim=1)
        return self.fusion_head(combined)

In [None]:
def _prepare_tabular(df, tab_cols, scaler):
    """Add date features, fill missing columns, and scale ``df[tab_cols]`` in place.

    Adds ``Sampling_Date``, ``DayOfYear``, ``sin_day`` and ``cos_day``
    columns. Any column in ``tab_cols`` that ``df`` lacks is filled with 0.0
    (and reported). ``df[tab_cols]`` is then overwritten with
    ``scaler.transform`` output.

    Args:
        df: Test rows, one per unique image; modified in place.
        tab_cols: Feature column names in the order used at training time.
        scaler: Fitted object exposing ``transform``.

    Returns:
        The same ``df`` object.
    """
    # Fall back to a fixed date when the test metadata has no Sampling_Date.
    df['Sampling_Date'] = pd.to_datetime(df.get('Sampling_Date', '2015-01-01'))
    df['DayOfYear'] = df['Sampling_Date'].dt.dayofyear

    # Cyclical encoding: places late December next to early January.
    angle = 2 * np.pi * df['DayOfYear'] / 365.0
    df['sin_day'] = np.sin(angle)
    df['cos_day'] = np.cos(angle)

    # Training-time columns (e.g. one-hot categories) absent from the test
    # data are zero-filled. Report them so the substitution is not silent.
    missing = [col for col in tab_cols if col not in df.columns]
    if missing:
        print(f"Zero-filling {len(missing)} tabular column(s) missing from test data: {missing}")
        for col in missing:
            df[col] = 0.0

    df[tab_cols] = scaler.transform(df[tab_cols])
    return df


def _predict(model, loader, device, num_outputs):
    """Run ``model`` over every ``(image, tabular)`` batch in ``loader``.

    Returns:
        ``np.ndarray`` of shape ``(n_items, num_outputs)``, with rows in loader
        order. The array is empty, with ``num_outputs`` columns, if the loader
        yields nothing.
    """
    batches = []
    with torch.no_grad():
        for image, tabular in loader:
            outputs = model(image.to(device), tabular.to(device))
            batches.append(outputs.cpu().numpy())
    if not batches:
        return np.empty((0, num_outputs), dtype=np.float32)
    return np.concatenate(batches, axis=0)


def run_inference(batch_size=16, num_workers=0):
    """Predict every target for each test image and write ``submission_contextual.csv``.

    Reads ``test.csv`` from ``DATA_DIR`` and the feature list and scaler from
    ``META_INFO_PATH``. Loads model weights from ``CHECKPOINT_PATH``. Writes a
    ``sample_id,target`` CSV to the current working directory.

    Args:
        batch_size: Images per forward pass. Predictions are matched back to
            images by position, so any value is safe.
        num_workers: Worker processes for the ``DataLoader`` (0 = main process).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1. Load data. test.csv can list the same image several times (once per
    #    target_name); the model only needs each image once.
    test_df = pd.read_csv(os.path.join(DATA_DIR, "test.csv"))
    unique_test_images = test_df.drop_duplicates(subset=['image_path']).copy()

    # 2. Load metadata saved at training time.
    # NOTE(review): joblib.load unpickles arbitrary objects; only load
    # META_INFO_PATH files you produced yourself.
    meta_info = joblib.load(META_INFO_PATH)
    tab_cols = meta_info['tab_cols']
    scaler = meta_info['scaler']

    # 3. Build the scaled tabular features in the training layout.
    _prepare_tabular(unique_test_images, tab_cols, scaler)

    # 4. Model.
    model = EnvironmentalContextModel(tabular_dim=len(tab_cols)).to(device)
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    model.eval()

    # 5. Data loader. shuffle must stay False because predictions are mapped
    #    back to image paths by position.
    transform = get_val_transform(IMAGE_SIZE[0], IMAGE_SIZE[1])
    dataset = BiomassDataset(unique_test_images, DATA_DIR, tab_cols, transform, is_test=True)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # 6. Predict; negative outputs are clipped to zero.
    preds = np.clip(_predict(model, loader, device, len(TARGET_COLUMNS)), 0.0, None)

    # 7. Reshape to long format (one row per image_path/target_name) and
    #    align with the sample_ids in test.csv.
    pred_df = pd.DataFrame(preds, columns=TARGET_COLUMNS)
    pred_df.insert(0, 'image_path', unique_test_images['image_path'].to_numpy())
    pred_df = pred_df.melt(id_vars='image_path', var_name='target_name', value_name='target')

    submission = test_df[['sample_id', 'image_path', 'target_name']].merge(
        pred_df, on=['image_path', 'target_name'], how='left')
    n_missing = int(submission['target'].isna().sum())
    if n_missing:
        print(f"Warning: {n_missing} submission row(s) have no prediction "
              f"(target_name not in TARGET_COLUMNS?)")
    submission = submission[['sample_id', 'target']]
    submission.to_csv("submission_contextual.csv", index=False)
    print("Submission saved to submission_contextual.csv")

# Run only when executed as a script or notebook cell (in a Jupyter kernel
# __name__ is "__main__"), not when this module is imported.
if __name__ == "__main__":
    run_inference()