In [None]:
%%writefile inference.py
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
import cv2
import os
import sys
import subprocess
import argparse
import gc
import warnings

# Suppress every Python warning in this process. Each worker re-executes this
# script, so the filter applies to every rank as well. Because warnings.warn()
# output is hidden, diagnostics in this script must be emitted with print().
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================
class CFG:
    """Static settings shared by the launcher process and every worker rank."""

    # ---- Data -------------------------------------------------------------
    test_csv = '/kaggle/input/csiro-biomass/test.csv'
    test_dir = '/kaggle/input/csiro-biomass/test'

    # ---- Stage 1: auxiliary (v7) models predicting the tabular inputs -----
    aux_model_dir = '/kaggle/input/aux-v7/pytorch/default/1/Models_Aux_Only_v7'

    # ---- Stage 2: main biomass models -------------------------------------
    main_model_dir = '/kaggle/input/vithugedinov3-with-manual-data-cleaning/pytorch/default/1'

    # ---- Output / backbone ------------------------------------------------
    output_file = 'submission.csv'
    model_name = 'vit_huge_plus_patch16_dinov3.lvd1689m'
    img_size = 800  # square side length images are resized to

    # Per-fold ensemble weights for Stage 2; keys are fold indices and the
    # weighted sum is normalised by the total weight of the folds actually found.
    fold_weights = {0: 0.264901, 1: 0.174089, 2: 0.157770, 3: 0.231843, 4: 0.171397}

    seeds = [42]      # training seeds of the Stage-2 (main) checkpoints
    aux_seeds = [44]  # training seeds of the Stage-1 (aux) checkpoints

    # ---- Loader / outputs -------------------------------------------------
    batch_size = 4
    num_workers = 2
    # Column order of the Stage-2 model output and of the submission targets.
    targets = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']

    # When True, horizontal- and vertical-flip views are added to the plain view.
    use_tta = True

# ============================================================================
# MODELS
# ============================================================================
class AuxModel(nn.Module):
    """Stage-1 auxiliary regressor mapping an image to [NDVI, Height].

    A timm backbone (classifier stripped via ``num_classes=0``) feeds a small
    MLP head with two outputs. The attribute names (``backbone``,
    ``aux_head``) and the layer order inside ``aux_head`` must stay as they
    are: trained weights are restored with ``load_state_dict``.
    """

    def __init__(self, model_name):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=False, num_classes=0)
        hidden = 256
        self.backbone = backbone
        # Auxiliary head only: two regression outputs (NDVI, Height).
        self.aux_head = nn.Sequential(
            nn.Linear(backbone.num_features, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2),
        )

    def forward(self, image):
        """Return the (batch, 2) auxiliary prediction for a batch of images."""
        return self.aux_head(self.backbone(image))

class BiomassModel(nn.Module):
    """Stage-2 regressor fusing image features with a 2-value tabular input.

    Image features from a timm backbone (classifier stripped) are concatenated
    with a 128-d encoding of the tabular input, passed through a fusion MLP,
    and fed to five single-output heads. Output columns, in order: green,
    dead, clover, GDM, total (the order of ``CFG.targets``).

    Attribute names and layer order must match the training checkpoints,
    because weights are restored with ``load_state_dict``.
    """

    def __init__(self, model_name):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=False, num_classes=0)
        img_features = self._infer_feature_dim(self.backbone)

        self.tabular_encoder = nn.Sequential(
            nn.Linear(2, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        fusion_dim = img_features + 128
        self.fusion = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        self.head_green = nn.Linear(256, 1)
        self.head_dead = nn.Linear(256, 1)
        self.head_clover = nn.Linear(256, 1)
        self.head_gdm = nn.Linear(256, 1)
        self.head_total = nn.Linear(256, 1)

    @staticmethod
    def _infer_feature_dim(backbone):
        """Return the width of ``backbone``'s pooled output features.

        Prefers timm's ``head_hidden_size`` / ``num_features`` metadata; the
        auxiliary model reads ``num_features`` for the same backbone. A dummy
        forward pass at ``CFG.img_size`` is used only when neither attribute
        exists, because for a large ViT at this resolution it is a slow,
        memory-hungry CPU computation repeated for every fold.
        """
        for attr in ('head_hidden_size', 'num_features'):
            dim = getattr(backbone, attr, None)
            if isinstance(dim, int) and dim > 0:
                return dim
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, CFG.img_size, CFG.img_size)
            return backbone(dummy_input).shape[1]

    def forward(self, image, tabular):
        """Return a (batch, 5) tensor of per-target predictions."""
        img_features = self.backbone(image)
        tab_features = self.tabular_encoder(tabular)
        combined = torch.cat([img_features, tab_features], dim=1)
        fused = self.fusion(combined)

        out_green = self.head_green(fused)
        out_dead = self.head_dead(fused)
        out_clover = self.head_clover(fused)
        out_gdm = self.head_gdm(fused)
        out_total = self.head_total(fused)

        # Concatenation order defines the output column order (matches CFG.targets).
        outputs = torch.cat([out_green, out_dead, out_clover, out_gdm, out_total], dim=1)
        return outputs

# ============================================================================
# DATA & TTA UTILS
# ============================================================================
class BiomassInferenceDataset(Dataset):
    """Test-time dataset yielding transformed images and optional tabular rows.

    Args:
        df: DataFrame with an ``image_path`` column. Only the file name (the
            text after the last '/') is used and looked up in ``img_dir``.
        img_dir: Directory holding the image files.
        transform: Callable invoked as ``transform(image=rgb_array)['image']``.
        tabular_data: Optional array indexed positionally alongside ``df``.
            When given, items are ``(image, float32_tensor)`` pairs; otherwise
            just the image.
    """

    def __init__(self, df, img_dir, transform, tabular_data=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.tabular_data = tabular_data

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['image_path'].split('/')[-1]
        img_path = os.path.join(self.img_dir, img_name)
        image = cv2.imread(img_path)
        if image is None:
            # Keep the run alive so a submission is still produced, but make the
            # substitution visible: a black placeholder gives meaningless
            # predictions. print() is used because warnings are silenced globally.
            print(f"Warning: could not read image {img_path}; using a blank placeholder.",
                  file=sys.stderr)
            image = np.zeros((CFG.img_size, CFG.img_size, 3), np.uint8)
        else:
            # OpenCV decodes to BGR; the transform pipeline expects RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = self.transform(image=image)['image']
        if self.tabular_data is not None:
            return image, torch.tensor(self.tabular_data[idx], dtype=torch.float32)
        return image

def get_tta_transforms():
    """Build the test-time-augmentation pipelines.

    Every pipeline ends with the same resize -> normalize -> tensor steps.
    Index 0 is always the plain (un-augmented) view, which Stage 1 relies on.
    When ``CFG.use_tta`` is set, a horizontal-flip and a vertical-flip view follow.
    """
    tail = [A.Resize(CFG.img_size, CFG.img_size), A.Normalize(), ToTensorV2()]
    flips = [None]
    if CFG.use_tta:
        flips.extend([A.HorizontalFlip(p=1.0), A.VerticalFlip(p=1.0)])
    return [A.Compose(tail if flip is None else [flip] + tail) for flip in flips]

# ============================================================================
# WORKER
# ============================================================================
def _resolve_aux_checkpoint(seed, fold):
    """Return the aux checkpoint path for ``(seed, fold)``, or None if none exists.

    The v7 file name is tried first, then the v6 naming convention.
    """
    for template in ('best_aux_only_seed{seed}_fold{fold}.pth',
                     'best_aux_seed{seed}_fold{fold}.pth'):
        path = os.path.join(CFG.aux_model_dir, template.format(seed=seed, fold=fold))
        if os.path.exists(path):
            return path
    return None


def _state_dict_of(ckpt):
    """Return the weights from a checkpoint saved either as a dict wrapping
    ``model_state_dict`` (plus scalers) or as a bare state_dict."""
    return ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt


def _predict_aux(my_df, device, rank):
    """Stage 1: average the aux models' two-column (NDVI, Height) predictions.

    Only the un-augmented pipeline is used. When a checkpoint carries a
    ``tab_scaler``, its ``inverse_transform`` is applied before averaging.

    Returns:
        Array with one row per row of ``my_df`` and two columns, or None when
        no aux checkpoint could be used.
    """
    # The dataset does not depend on the fold, so one loader serves every model.
    loader = DataLoader(BiomassInferenceDataset(my_df, CFG.test_dir, get_tta_transforms()[0]),
                        batch_size=CFG.batch_size, num_workers=CFG.num_workers)
    total, n_models = None, 0
    for seed in CFG.aux_seeds:
        for fold in range(5):
            path = _resolve_aux_checkpoint(seed, fold)
            if path is None:
                continue

            # Build (and move) the model before loading the checkpoint so the
            # randomly initialised CPU weights and the checkpoint never coexist.
            model = AuxModel(CFG.model_name).to(device)
            # weights_only=False unpickles arbitrary objects (needed for the
            # fitted scaler); only load checkpoints from trusted sources.
            ckpt = torch.load(path, map_location='cpu', weights_only=False)
            scaler = ckpt.get('tab_scaler')
            model.load_state_dict(_state_dict_of(ckpt))
            del ckpt  # release the CPU copy of the weights before inference
            model.eval()

            batches = []
            with torch.no_grad():
                for img in loader:
                    batches.append(model(img.to(device)).cpu().numpy())
            del model
            gc.collect()
            torch.cuda.empty_cache()

            if not batches:
                continue
            preds = np.vstack(batches)
            if scaler is not None:
                # Undo the training-time scaling of the aux targets.
                preds = scaler.inverse_transform(preds)
            total = preds if total is None else total + preds
            n_models += 1

    print(f"[Rank {rank}] Stage 1 Complete. Models used: {n_models}")
    return None if total is None else total / n_models


def _predict_main(my_df, predicted_tabular, device):
    """Stage 2: fold-weighted, TTA-averaged BiomassModel predictions.

    ``predicted_tabular`` (the Stage-1 output) is passed through each
    checkpoint's ``tabular_scaler`` when present. Each fold's TTA mean goes
    through its ``target_scaler.inverse_transform`` when present and is then
    weighted by ``CFG.fold_weights``.

    Returns:
        Array with one row per row of ``my_df`` and one column per
        ``CFG.targets`` entry, or None when no usable checkpoint was found.
    """
    tta_transforms = get_tta_transforms()
    weighted_sum, total_w = None, 0.0
    for seed in CFG.seeds:
        for fold, weight in CFG.fold_weights.items():
            path = os.path.join(CFG.main_model_dir, f'best_model_seed{seed}_fold{fold}.pth')
            if not os.path.exists(path):
                continue

            model = BiomassModel(CFG.model_name).to(device)
            # Trusted checkpoints only: weights_only=False is required for the pickled scalers.
            ckpt = torch.load(path, map_location='cpu', weights_only=False)
            tab_scaler = ckpt.get('tabular_scaler')
            target_scaler = ckpt.get('target_scaler')
            model.load_state_dict(_state_dict_of(ckpt))
            del ckpt
            model.eval()

            # Prepare the tabular input (predicted by the aux models).
            tab_input = (predicted_tabular if tab_scaler is None
                         else tab_scaler.transform(predicted_tabular))

            view_sum = None
            for transform in tta_transforms:
                loader = DataLoader(
                    BiomassInferenceDataset(my_df, CFG.test_dir, transform, tabular_data=tab_input),
                    batch_size=CFG.batch_size, num_workers=CFG.num_workers)
                batches = []
                with torch.no_grad():
                    for img, tab in loader:
                        batches.append(model(img.to(device), tab.to(device)).cpu().numpy())
                if batches:
                    view = np.vstack(batches)
                    view_sum = view if view_sum is None else view_sum + view
            del model
            gc.collect()
            torch.cuda.empty_cache()

            if view_sum is None:
                continue
            fold_pred = view_sum / len(tta_transforms)
            if target_scaler is not None:
                fold_pred = target_scaler.inverse_transform(fold_pred)
            contrib = fold_pred * weight
            weighted_sum = contrib if weighted_sum is None else weighted_sum + contrib
            total_w += weight

    if weighted_sum is None or total_w <= 0:
        return None
    return weighted_sum / total_w


def _write_part(my_df, final_preds, rank):
    """Write this rank's predictions to ``temp_part_{rank}.csv`` in long format.

    One row per (image, target) with ``sample_id = '<image stem>__<target>'``.
    Negative predictions are clipped to 0.
    """
    stems = [os.path.splitext(os.path.basename(p))[0] for p in my_df['image_path']]
    rows = [
        {'sample_id': f"{stem}__{target}", 'target': max(0.0, float(final_preds[i, j]))}
        for i, stem in enumerate(stems)
        for j, target in enumerate(CFG.targets)
    ]
    pd.DataFrame(rows).to_csv(f'temp_part_{rank}.csv', index=False)


def run_worker(rank, world_size):
    """Run two-stage inference on this rank's shard of the test images.

    The de-duplicated test images are split into ``world_size`` contiguous
    shards. Rank ``r`` processes shard ``r`` on ``cuda:r`` (CPU when CUDA is
    unavailable) and writes ``temp_part_{rank}.csv``. A rank with an empty
    shard returns without writing anything.

    Raises:
        FileNotFoundError: if no Stage-1 aux checkpoint is usable. Stage 2
            cannot run without the predicted tabular inputs.
    """
    device = torch.device(f'cuda:{rank}') if torch.cuda.is_available() else torch.device('cpu')
    test_df = pd.read_csv(CFG.test_csv).drop_duplicates('image_path').reset_index(drop=True)

    indices = np.array_split(np.arange(len(test_df)), world_size)[rank]
    if len(indices) == 0:
        return
    my_df = test_df.iloc[indices].reset_index(drop=True)

    print(f"[Rank {rank}] Starting Stage 1: Aux Inference...")
    predicted_tabular = _predict_aux(my_df, device, rank)
    if predicted_tabular is None:
        raise FileNotFoundError(
            f"[Rank {rank}] No usable aux checkpoints found in {CFG.aux_model_dir}")

    print(f"[Rank {rank}] Starting Stage 2: Main Inference...")
    final_preds = _predict_main(my_df, predicted_tabular, device)
    if final_preds is None:
        print(f"[Rank {rank}] Warning: No Stage 2 models found!")
        return
    _write_part(my_df, final_preds, rank)

# ============================================================================
# MAIN
# ============================================================================
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', type=int, default=-1)
    parser.add_argument('--world_size', type=int, default=2)
    args = parser.parse_args()

    if args.rank == -1:
        # Launcher: rank r pins itself to cuda:r, so never spawn more workers
        # than there are visible GPUs (always at least one worker).
        world_size = max(1, min(args.world_size, torch.cuda.device_count()))
        part_files = [f'temp_part_{r}.csv' for r in range(world_size)]

        # Remove leftovers from an earlier aborted run so they cannot be merged
        # into this submission.
        for fname in part_files:
            if os.path.exists(fname):
                os.remove(fname)

        # Spawning processes for multi-GPU inference
        processes = [
            subprocess.Popen([sys.executable, __file__, '--rank', str(r), '--world_size', str(world_size)])
            for r in range(world_size)
        ]
        for r, proc in enumerate(processes):
            code = proc.wait()
            if code != 0:
                print(f"Warning: worker rank {r} exited with code {code}; its predictions are missing.")

        dfs = []
        for fname in part_files:
            if os.path.exists(fname):
                dfs.append(pd.read_csv(fname))
                os.remove(fname)

        if dfs:
            sub = pd.concat(dfs).drop_duplicates('sample_id')
            sub.to_csv(CFG.output_file, index=False)
            print(f"Submission saved to {CFG.output_file}")
        else:
            print("Error: No predictions generated.")
            pd.DataFrame(columns=['sample_id', 'target']).to_csv(CFG.output_file, index=False)
    else:
        run_worker(args.rank, args.world_size)

Writing inference.py


In [2]:
!python inference.py

  data = fetch_version_info()
  data = fetch_version_info()
  data = fetch_version_info()
[Rank 0] Starting Stage 1: Aux Inference...
[Rank 0] Stage 1 Complete. Models used: 5
[Rank 0] Starting Stage 2: Main Inference...
Submission saved to submission.csv
