# CSIRO Image2Biomass - Advanced Swin Inference (STRICT OFFLINE)

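This notebook runs test-set inference with the fine-tuned `AdvancedSwinHydra` model (a timm SwinV2-Large backbone with one regression head per biomass target) and writes `submission.csv`. It is configured for **Internet Off** environments: the backbone is built with `pretrained=False`, and all weights are loaded from a local checkpoint file.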

In [2]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ==========================================
# CONFIGURATION
# ==========================================
# Paths default to the author's local layout. Set the CSIRO_DATA_DIR /
# CSIRO_CHECKPOINT_PATH environment variables to run elsewhere (e.g. on Kaggle)
# without editing the notebook.
DATA_DIR = os.environ.get(
    "CSIRO_DATA_DIR",
    r"D:\personalProject\CSIRO-Image2Biomass_Prediction\csiro-biomass",
)
CHECKPOINT_PATH = os.environ.get(
    "CSIRO_CHECKPOINT_PATH",
    r"D:\personalProject\CSIRO-Image2Biomass_Prediction\models_checkpoints\best_swinv2_large_ft_fold2.pth",
)

# timm architecture name. The ".ms_in22k" suffix is only a pretrained-weight tag.
# Offline safety comes from building the backbone with pretrained=False (in
# AdvancedSwinHydra), not from this name; the fine-tuned weights are loaded
# from CHECKPOINT_PATH.
MODEL_NAME = "swinv2_large_window12_192.ms_in22k"
# (height, width) passed to the backbone and to A.Resize. Presumably must match
# the resolution used when the checkpoint was trained -- confirm before changing.
IMAGE_SIZE = (384, 768)
# Output column j of the model is written under TARGET_COLUMNS[j].
TARGET_COLUMNS = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'GDM_g', 'Dry_Total_g']
# Width of the species classifier / embedding table; must match the checkpoint.
NUM_SPECIES = 15


In [3]:
class AdvancedSwinHydra(nn.Module):
    """SwinV2 backbone with species-conditioned, per-target regression heads.

    The backbone is created with ``num_classes=0`` so timm returns features
    instead of class logits. Two auxiliary linear layers produce a 2-dim
    regression (``meta_reg``) and species logits (``meta_cls``). The arg-max
    species is embedded, concatenated with the features and the ``meta_reg``
    output, and fed to one small MLP per target.

    Args:
        model_name: timm architecture name. Always built with
            ``pretrained=False``, so no network access happens here.
        num_species: number of species classes (classifier width and
            embedding-table size).
        num_targets: number of regression heads / output columns. Defaults to
            ``len(TARGET_COLUMNS)``, which is how outputs are mapped to target
            names downstream.
        img_size: (height, width) input resolution given to the backbone.

    The attribute names ``backbone``, ``meta_reg``, ``meta_cls``,
    ``species_emb`` and ``heads`` determine the state-dict keys and must not be
    renamed, or existing checkpoints will no longer load.
    """

    def __init__(self, model_name=MODEL_NAME, num_species=NUM_SPECIES,
                 num_targets=len(TARGET_COLUMNS), img_size=IMAGE_SIZE):
        super().__init__()
        # FORCED OFFLINE: pretrained=False is hardcoded here to block all network calls
        self.backbone = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=0,
            img_size=img_size,
        )
        embed_dim = self.backbone.num_features

        self.meta_reg = nn.Linear(embed_dim, 2)
        self.meta_cls = nn.Linear(embed_dim, num_species)
        self.species_emb = nn.Embedding(num_species, 32)

        # features + meta_reg output (2) + species embedding (32)
        fusion_dim = embed_dim + 2 + 32
        self.heads = nn.ModuleList([
            nn.Sequential(nn.Linear(fusion_dim, 256), nn.GELU(), nn.Linear(256, 1))
            for _ in range(num_targets)
        ])

    def forward(self, x):
        """Predict every target for a batch of images.

        Args:
            x: image batch, presumably a normalised float tensor of shape
               (B, 3, H, W) with (H, W) == img_size -- TODO confirm.

        Returns:
            Tensor of shape (B, num_targets); column j is head j's output.
        """
        feat = self.backbone(x)          # (B, embed_dim)
        p_reg = self.meta_reg(feat)      # (B, 2)
        p_cls = self.meta_cls(feat)      # (B, num_species)

        # Hard species choice: argmax is non-differentiable, so meta_cls gets
        # no gradient through this fusion path.
        spec_idx = torch.argmax(p_cls, dim=1)
        s_emb = self.species_emb(spec_idx)

        fusion = torch.cat([feat, p_reg, s_emb], dim=1)
        return torch.cat([head(fusion) for head in self.heads], dim=1)

In [4]:
class SimpleDataset(Dataset):
    """Image-only dataset over the ``image_path`` column of a DataFrame.

    Args:
        df: DataFrame with an ``image_path`` column; paths are relative to
            ``img_dir``.
        img_dir: directory that each ``image_path`` is joined to.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=array)["image"]``.

    Items are produced in positional (``iloc``) row order of ``df``, so an
    unshuffled DataLoader yields images in the same order as ``df``.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load row ``idx`` as RGB and return the transformed image.

        Without a transform, this returns the raw HxWx3 uint8 numpy array.
        """
        path = os.path.join(self.img_dir, self.df.iloc[idx]['image_path'])
        # Context manager closes the file handle deterministically (Pillow's
        # recommended pattern) instead of leaving it to load()/GC.
        with Image.open(path) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        # `is not None` rather than truthiness: A.Compose defines __len__, so an
        # empty Compose would otherwise be silently skipped.
        if self.transform is not None:
            img = self.transform(image=img)["image"]
        return img

In [10]:
def _load_model(device):
    """Build AdvancedSwinHydra on ``device``, load CHECKPOINT_PATH, set eval mode.

    Raises:
        FileNotFoundError: if CHECKPOINT_PATH does not exist.
    """
    # Check first: building the large backbone is wasted work if weights are missing,
    # and failing loudly beats silently producing no submission.
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(f"CRITICAL ERROR: Weight file not found at {CHECKPOINT_PATH}")

    model = AdvancedSwinHydra().to(device)
    print(f"Loading weights from: {CHECKPOINT_PATH}")
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    # Since this is a plain state dict, weights_only=True is an option on torch >= 1.13.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
    # Drop a leading "module." (typically added by nn.DataParallel/DDP). Only the
    # prefix is stripped, so keys containing "module." elsewhere stay intact.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    print("Weights loaded successfully.")
    model.eval()
    return model


def _predict(model, loader, device):
    """Run ``model`` over ``loader``; return an (N, n_targets) float64 array in loader order."""
    batches = []
    with torch.no_grad():
        for images in loader:
            batches.append(model(images.to(device).float()).cpu().numpy())
    if not batches:
        return np.empty((0, len(TARGET_COLUMNS)), dtype=np.float64)
    return np.concatenate(batches, axis=0).astype(np.float64)


def run_inference(batch_size=8, output_path="submission.csv"):
    """Predict every target for each unique test image and write a submission CSV.

    Reads ``DATA_DIR/test.csv`` (one row per sample_id / image_path /
    target_name) and runs the model once per unique image. Negative predictions
    are clamped to 0, and ``sample_id,target`` is written to ``output_path``.

    Args:
        batch_size: DataLoader batch size.
        output_path: destination CSV path.

    Returns:
        The submission DataFrame with columns ``sample_id`` and ``target``.

    Raises:
        FileNotFoundError: if the checkpoint file is missing.
        ValueError: if any test row did not receive a prediction.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    test_df = pd.read_csv(os.path.join(DATA_DIR, "test.csv"))
    # test.csv has one row per (image, target); predict each image only once.
    unique_test = test_df.drop_duplicates(subset=['image_path']).reset_index(drop=True)

    model = _load_model(device)

    transform = A.Compose([
        A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1]),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])
    dataset = SimpleDataset(unique_test, DATA_DIR, transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Rows of `preds` follow unique_test order because shuffle=False.
    preds = _predict(model, loader, device)
    assert len(preds) == len(unique_test), "prediction count does not match unique images"

    # Element-wise equivalent of max(0.0, x): positives kept; negatives, -0.0 and NaN -> 0.0.
    clamped = np.where(preds > 0.0, preds, 0.0)

    pred_wide = pd.DataFrame(clamped, columns=TARGET_COLUMNS)
    pred_wide.insert(0, 'image_path', unique_test['image_path'].to_numpy())
    pred_df = pred_wide.melt(id_vars='image_path', var_name='target_name', value_name='target')

    submission = (
        test_df[['sample_id', 'image_path', 'target_name']]
        .merge(pred_df, on=['image_path', 'target_name'], how='left')
        [['sample_id', 'target']]
    )
    missing = submission['target'].isna()
    if missing.any():
        raise ValueError(
            f"{int(missing.sum())} test rows have no prediction "
            f"(target_name not in TARGET_COLUMNS?)"
        )

    submission.to_csv(output_path, index=False)
    print(submission)
    print(f"Submission saved to {output_path}")
    return submission

if __name__ == "__main__":
    run_inference()

Loading weights from: D:\personalProject\CSIRO-Image2Biomass_Prediction\models_checkpoints\best_swinv2_large_ft_fold2.pth
Weights loaded successfully.
                     sample_id     target
0   ID1001187975__Dry_Clover_g   0.411185
1     ID1001187975__Dry_Dead_g  28.347141
2    ID1001187975__Dry_Green_g  27.112358
3    ID1001187975__Dry_Total_g  58.057373
4          ID1001187975__GDM_g  29.188145
5      ID4464212__Dry_Clover_g   0.000000
6        ID4464212__Dry_Dead_g  13.171807
7       ID4464212__Dry_Green_g  32.280479
8       ID4464212__Dry_Total_g  46.282780
9             ID4464212__GDM_g  31.856617
10     ID6269659__Dry_Clover_g  14.796210
11       ID6269659__Dry_Dead_g   4.284080
12      ID6269659__Dry_Green_g  33.170723
13      ID6269659__Dry_Total_g  57.305706
14            ID6269659__GDM_g  51.166885
15     ID7850481__Dry_Clover_g  13.952002
16       ID7850481__Dry_Dead_g   0.745448
17      ID7850481__Dry_Green_g   8.568916
18      ID7850481__Dry_Total_g  28.235245
19       