# In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
import cv2
from tqdm import tqdm
import gc

# ===============================================================
# 1. ⚙️ CONFIGURATION (MUST MATCH THE TRAINING FILE)
# ===============================================================
class CFG:
    """Central inference settings; model-related values must mirror the training run."""

    # Input locations: competition data root, test images and test manifest.
    BASE_PATH = "/kaggle/input/csiro-biomass"
    TEST_IMAGE_DIR = os.path.join(BASE_PATH, "test")
    TEST_CSV = os.path.join(BASE_PATH, "test.csv")

    # Folder expected to hold one checkpoint (.pth) per fold, and the output file name.
    MODEL_DIR = "/kaggle/input/csiro/"
    SUBMISSION_FILE = "submission.csv"

    # Backbone identifier and square input resolution -- keep in sync with training.
    MODEL_NAME = "convnext_tiny"
    IMG_SIZE = 768

    # Runtime knobs for the inference pass.
    N_FOLDS = 5
    BATCH_SIZE = 1
    NUM_WORKERS = 1
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Targets the network predicts directly -- keep in sync with training.
    TARGET_COLS = ["Dry_Total_g", "GDM_g", "Dry_Green_g"]

    # Every target name the submission file has to contain.
    ALL_TARGET_COLS = [
        "Dry_Green_g",
        "Dry_Dead_g",
        "Dry_Clover_g",
        "GDM_g",
        "Dry_Total_g",
    ]

# Echo the key inference settings once at startup so the run log records them.
print(
    f"Using device: {CFG.DEVICE}",
    f"Model backbone: {CFG.MODEL_NAME}",
    f"Inference image size: {CFG.IMG_SIZE}x{CFG.IMG_SIZE}",
    sep="\n",
)


# ===============================================================
# 2. 🏞️ AUGMENTATIONS (FOR VALIDATION / TTA ONLY)
# ===============================================================
from albumentations import (
    Compose, 
    Resize, 
    Normalize,
    HorizontalFlip, 
    VerticalFlip
)

def get_tta_transforms():
    """
    Build the Test-Time Augmentation (TTA) pipelines.

    Returns:
        list: three albumentations ``Compose`` pipelines, in this order --
        no flip, horizontal flip, vertical flip. Each pipeline resizes to
        ``CFG.IMG_SIZE`` x ``CFG.IMG_SIZE``, normalizes and converts the
        image to a tensor.
    """
    # Normalization + tensor conversion. These two transform objects are
    # created once and reused by every pipeline.
    shared_tail = [
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]

    def build_view(*leading):
        # Optional flip first, then a fresh Resize, then the shared tail.
        return Compose([*leading, Resize(CFG.IMG_SIZE, CFG.IMG_SIZE), *shared_tail])

    return [
        build_view(),                       # View 1: original image
        build_view(HorizontalFlip(p=1.0)),  # View 2: horizontal flip
        build_view(VerticalFlip(p=1.0)),    # View 3: vertical flip
    ]

# Progress marker printed when this part of the script is executed.
print("Defined function get_tta_transforms().")


class TestBiomassDataset(Dataset):
    """
    Test-set dataset that yields the left and right halves of each image.

    Every sample goes through the single transform pipeline handed to the
    constructor, so one dataset instance corresponds to one TTA view.

    Args:
        df: frame with an ``image_path`` column (one row per sample).
        transform_pipeline: callable invoked as ``pipeline(image=array)`` and
            returning a mapping with an ``'image'`` entry.
        image_dir: directory the image file names are resolved against.
    """

    def __init__(self, df, transform_pipeline, image_dir):
        self.image_dir = image_dir
        self.transforms = transform_pipeline
        self.df = df
        self.image_paths = df['image_path'].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Only the file name of the CSV path is used; it is looked up in image_dir.
        full_path = os.path.join(self.image_dir, os.path.basename(self.image_paths[idx]))

        image = cv2.imread(full_path)
        if image is None:
            # Unreadable file: warn and fall back to an all-black placeholder.
            print(f"Warning: Unable to read image: {full_path}. Returning black image.")
            image = np.zeros((1000, 2000, 3), dtype=np.uint8)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Split along the vertical centre line; with an odd width the extra
        # column ends up in the right half.
        half_width = image.shape[1] // 2
        halves = (image[:, :half_width], image[:, half_width:])

        # The same pipeline is applied to the left half first, then the right.
        left_tensor, right_tensor = (
            self.transforms(image=half)['image'] for half in halves
        )
        return left_tensor, right_tensor


# ===============================================================
# 4. 🧠 MODEL ARCHITECTURE (COPY EXACTLY FROM TRAIN FILE)
# ===============================================================
class BiomassModel(nn.Module):
    """
    Two-stream regressor with three outputs.

    One shared backbone encodes the left and the right image half; the two
    feature vectors are concatenated and fed to three independent heads
    (Dry_Total_g, GDM_g, Dry_Green_g).

    Attribute names and layer order determine the checkpoint keys, so they
    must stay identical to the training code.
    """

    def __init__(self, model_name, pretrained, n_targets=3):
        # n_targets is accepted for signature compatibility; it is not used here.
        super().__init__()

        # timm backbone created without a classifier (num_classes=0) and with
        # average global pooling.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg',
        )

        self.n_features = self.backbone.num_features
        self.n_combined_features = 2 * self.n_features

        width = self.n_combined_features

        def make_head():
            # Linear -> ReLU -> Dropout -> Linear, ending in a single output.
            return nn.Sequential(
                nn.Linear(width, width // 2),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(width // 2, 1),
            )

        # Created in the order total, GDM, green.
        self.head_total = make_head()  # Dry_Total_g
        self.head_gdm = make_head()    # GDM_g
        self.head_green = make_head()  # Dry_Green_g

    def forward(self, img_left, img_right):
        """Return ``(out_total, out_gdm, out_green)`` for a pair of image halves."""
        # The same backbone encodes both halves; features are joined on dim 1.
        per_half = [self.backbone(half) for half in (img_left, img_right)]
        combined = torch.cat(per_half, dim=1)

        heads = (self.head_total, self.head_gdm, self.head_green)
        return tuple(head(combined) for head in heads)


def predict_one_view(models_list, test_loader, device):
    """
    Run the fold ensemble over ONE TTA view of the test set.

    Args:
        models_list: non-empty list of models, each called as
            ``model(img_left, img_right)`` and returning the three
            predictions ``(total, gdm, green)``. The caller is expected to
            have put them in eval mode on ``device``.
        test_loader: iterable yielding ``(img_left, img_right)`` batches.
        device: device the input batches are moved to before the forward pass.

    Returns:
        dict: keys ``'total'``, ``'gdm'`` and ``'green'``; each value is a 1-D
        NumPy array of fold-averaged predictions in loader order. When the
        loader yields no batches, every array is empty.

    Raises:
        ValueError: if ``models_list`` is empty.
    """
    if not models_list:
        # Fail early with a clear message instead of a torch.stack error deep
        # inside the batch loop.
        raise ValueError("models_list must contain at least one model")

    target_keys = ('total', 'gdm', 'green')
    view_preds_3 = {key: [] for key in target_keys}

    with torch.no_grad():
        for img_left, img_right in tqdm(test_loader, desc="  Predicting View", leave=False):
            img_left = img_left.to(device)
            img_right = img_right.to(device)

            # One (total, gdm, green) tuple per fold model, moved to the CPU.
            fold_outputs = [
                tuple(out.cpu() for out in model(img_left, img_right))
                for model in models_list
            ]

            # zip(*fold_outputs) regroups the per-fold tuples by target so each
            # target can be averaged across folds.
            for key, per_fold in zip(target_keys, zip(*fold_outputs)):
                view_preds_3[key].append(torch.stack(per_fold).mean(dim=0).numpy())

    # np.concatenate raises on an empty list, so a loader without batches is
    # mapped to empty arrays rather than crashing.
    return {
        key: (
            np.concatenate(chunks).flatten()
            if chunks
            else np.empty(0, dtype=np.float32)
        )
        for key, chunks in view_preds_3.items()
    }


def _load_fold_model(model_path):
    """Build a BiomassModel and load the checkpoint at ``model_path`` into it."""
    from collections import OrderedDict

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # belong in CFG.MODEL_DIR. Consider passing weights_only=True if the
    # installed torch version supports it.
    state_dict = torch.load(model_path, map_location=CFG.DEVICE)

    # Some checkpoints carry a leading 'module.' on their keys (presumably
    # saved from a DataParallel-style wrapper). Strip it only as a prefix so
    # that a key merely containing 'module.' is left untouched.
    prefix = 'module.'
    if any(key.startswith(prefix) for key in state_dict):
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    model = BiomassModel(CFG.MODEL_NAME, pretrained=False)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(CFG.DEVICE)
    return model


def run_inference_with_tta():
    """
    Main inference function performing TTA + fold ensemble.

    Returns:
        tuple: ``(final_ensembled_preds, test_df_long, test_df_unique)``.
        ``final_ensembled_preds`` maps ``'total'``/``'gdm'``/``'green'`` to
        predictions averaged over folds and TTA views, ``test_df_long`` is the
        test CSV as read, and ``test_df_unique`` holds one row per distinct
        ``image_path``. ``(None, None, None)`` is returned when the test CSV
        or any fold checkpoint is missing.
    """
    print(f"\n{'='*50}")
    print(f"üöÄ STARTING INFERENCE (with TTA) üöÄ")
    print(f"{'='*50}")

    # --- 1. Load Test Data ---
    print(f"Loading {CFG.TEST_CSV}...")
    try:
        # Only the read itself can raise FileNotFoundError.
        test_df_long = pd.read_csv(CFG.TEST_CSV)
    except FileNotFoundError:
        print(f"ERROR: Test CSV not found: {CFG.TEST_CSV}")
        return None, None, None
    test_df_unique = test_df_long.drop_duplicates(subset=['image_path']).reset_index(drop=True)
    print(f"Found {len(test_df_unique)} unique test images.")

    # --- 2. Load Trained Models (one per fold) ---
    print(f"\nLoading {CFG.N_FOLDS} trained models...")
    models_list = []
    for fold in range(CFG.N_FOLDS):
        model_path = os.path.join(CFG.MODEL_DIR, f'best_model_fold{fold}.pth')
        if not os.path.exists(model_path):
            print(f"ERROR: Model file not found: {model_path}")
            return None, None, None
        models_list.append(_load_fold_model(model_path))
    print(f"‚úì Successfully loaded {len(models_list)} models.")

    # --- 3. Loop over TTA Views ---
    tta_transforms = get_tta_transforms()
    print(f"\nStarting predictions with {len(tta_transforms)} TTA views...")

    # Pinned host memory only matters for host-to-GPU copies, so it is not
    # requested on a CPU-only run.
    use_pinned_memory = CFG.DEVICE.type == 'cuda'

    all_tta_view_preds = []
    for i, tta_transform in enumerate(tta_transforms):
        print(f"--- Running TTA View {i+1}/{len(tta_transforms)} ---")

        test_dataset = TestBiomassDataset(
            df=test_df_unique,
            transform_pipeline=tta_transform,
            image_dir=CFG.TEST_IMAGE_DIR
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=False,  # keep predictions aligned with test_df_unique rows
            num_workers=CFG.NUM_WORKERS,
            pin_memory=use_pinned_memory
        )

        all_tta_view_preds.append(predict_one_view(models_list, test_loader, CFG.DEVICE))
        print(f"‚úì Completed TTA View {i+1}")

        # Release this view's loader and dataset before building the next one.
        del test_loader, test_dataset

    # --- 4. Average TTA Results ---
    print("\nAveraging predictions across TTA views...")
    final_ensembled_preds = {
        key: np.mean([view[key] for view in all_tta_view_preds], axis=0)
        for key in ('total', 'gdm', 'green')
    }

    print("‚úì Inference complete.")

    # Drop the model references so their memory can be reclaimed.
    del models_list
    gc.collect()
    torch.cuda.empty_cache()

    return final_ensembled_preds, test_df_long, test_df_unique


# ===============================================================
# 6. ✍️ CREATE SUBMISSION FILE
# ===============================================================
def create_submission(preds_np, test_df_long, test_df_unique):
    """
    Derive the two remaining targets and write the submission CSV.

    Args:
        preds_np: dict with ``'total'``, ``'gdm'`` and ``'green'`` prediction
            arrays, or None when inference failed (nothing is written then).
        test_df_long: test rows providing ``sample_id``, ``image_path`` and
            ``target_name``.
        test_df_unique: one row per image; supplies the ``image_path`` column
            that the prediction arrays are paired with.
    """
    if preds_np is None:
        print("Skipping submission creation due to previous error.")
        return

    print("\nPost-processing and creating submission file...")

    total, gdm, green = (preds_np[key] for key in ('total', 'gdm', 'green'))

    # Clover = GDM - green and dead = total - GDM, each floored at zero.
    wide = pd.DataFrame({
        'image_path': test_df_unique['image_path'],
        'Dry_Green_g': green,
        'Dry_Dead_g': np.maximum(0, total - gdm),
        'Dry_Clover_g': np.maximum(0, gdm - green),
        'GDM_g': gdm,
        'Dry_Total_g': total,
    })

    # One row per (image, target) pair.
    long_preds = wide.melt(
        id_vars=['image_path'],
        value_vars=CFG.ALL_TARGET_COLS,
        var_name='target_name',
        value_name='target',
    )

    # A left join keeps every row of the original test CSV and attaches its
    # prediction, which also brings in 'sample_id'.
    join_keys = ['image_path', 'target_name']
    submission_df = test_df_long[['sample_id'] + join_keys].merge(
        long_preds, on=join_keys, how='left'
    )
    submission_df = submission_df[['sample_id', 'target']]
    submission_df.to_csv(CFG.SUBMISSION_FILE, index=False)

    print(f"\nüéâ DONE! Submission saved at: {CFG.SUBMISSION_FILE}")
    print("--- First 5 rows of submission ---")
    print(submission_df.head())
    print("\n--- Last 5 rows of submission ---")
    print(submission_df.tail())


# ===============================================================
# 8. 🏁 RUN THE PROGRAM
# ===============================================================
# Script entry point: run inference, then build the submission from its outputs.
if __name__ == "__main__":
    # If the test CSV or a checkpoint is missing, run_inference_with_tta()
    # returns (None, None, None) and create_submission() skips writing the file.
    all_preds_np, df_long, df_unique = run_inference_with_tta()
    create_submission(all_preds_np, df_long, df_unique)
