---

# Methodology: CSIRO Pasture Biomass Prediction

## 1. Core Strategy: Predicting Key Components

The primary goal is to predict five biomass targets. Based on exploratory data analysis (EDA), we identified linear dependencies:
* `Dry_Total_g` $\approx$ `Dry_Green_g` + `Dry_Dead_g` + `Dry_Clover_g`
* `GDM_g` $\approx$ `Dry_Green_g` + `Dry_Clover_g`

To avoid redundancy, the model is trained to predict only the **three most visually distinct and/or highest-weighted targets**:
* `Dry_Total_g` (50% of the score)
* `GDM_g` (20% of the score)
* `Dry_Green_g` (10% of the score)

The remaining two targets (`Dry_Dead_g` and `Dry_Clover_g`) are then **calculated during validation and inference** by subtraction, clipped at zero: `pred_Dead = max(0, pred_Total - pred_GDM)` and `pred_Clover = max(0, pred_GDM - pred_Green)`.
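A minimal NumPy sketch of this post-processing step. The function name `derive_remaining_targets` is hypothetical; the notebook's `create_submission` performs the same arithmetic inline with `np.maximum`. Clipping matters because the three heads are trained independently, so nothing forces `Total >= GDM >= Green`:

```python
import numpy as np

def derive_remaining_targets(pred_total, pred_gdm, pred_green):
    """Hypothetical helper: recover Dry_Dead_g and Dry_Clover_g from the 3 predicted targets.

    Uses the EDA identities Dry_Total ~ Green + Dead + Clover and GDM ~ Green + Clover,
    clipping at zero because independently trained heads can break Total >= GDM >= Green.
    """
    pred_total = np.asarray(pred_total, dtype=float)
    pred_gdm = np.asarray(pred_gdm, dtype=float)
    pred_green = np.asarray(pred_green, dtype=float)
    pred_dead = np.maximum(0.0, pred_total - pred_gdm)     # Dead = Total - GDM
    pred_clover = np.maximum(0.0, pred_gdm - pred_green)   # Clover = GDM - Green
    return pred_dead, pred_clover

# Second sample: GDM < Green, so the raw Clover difference (-2) is clipped to 0.
dead, clover = derive_remaining_targets([50.0, 30.0], [35.0, 20.0], [30.0, 22.0])
print(dead)    # [15. 10.]
print(clover)  # [5. 0.]
```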

---

## 2. Data Handling & K-Fold Strategy

* **Image Input:** All source images are high-resolution (`2000x1000` pixels).
* **Two-Stream Processing:** To preserve fine-grained details (like clover leaves) that would be lost by resizing the entire image, the `Dataset` class crops each image into two `1000x1000` patches (a "left" and "right" half).
* **High-Resolution Input:** Each `1000x1000` patch is then resized to **`768x768`**, maintaining a high level of detail.
* **K-Fold Strategy:** We use a **5-Fold Cross-Validation** strategy due to the small dataset (357 images).
* **Robust Splitting (GroupKFold):** To prevent data leakage (where similar images from the same day are in both train and validation), we use `GroupKFold` grouped by `Sampling_Date`. This ensures the model is validated on dates it has never seen.
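The date-grouped split can be sketched with scikit-learn's `GroupKFold`. The toy table below is made-up data (20 images over 8 dates), and it assumes one row per image; the competition CSVs are in long format (one row per target), so duplicates on `image_path` would need to be dropped first. The assertion inside the loop checks the property the split is meant to guarantee:

```python
import pandas as pd
from sklearn.model_selection import GroupKFold

# Made-up stand-in for the training table: 20 images spread over 8 sampling dates.
df = pd.DataFrame({
    "image_path": [f"img_{i}.jpg" for i in range(20)],
    "Sampling_Date": [f"2015-0{1 + i % 8}-01" for i in range(20)],
})

gkf = GroupKFold(n_splits=5)
df["fold"] = -1
for fold, (train_idx, val_idx) in enumerate(gkf.split(df, groups=df["Sampling_Date"])):
    df.loc[val_idx, "fold"] = fold
    train_dates = set(df.loc[train_idx, "Sampling_Date"])
    val_dates = set(df.loc[val_idx, "Sampling_Date"])
    # No sampling date may appear on both the train and validation side.
    assert train_dates.isdisjoint(val_dates)

# Number of distinct dates held out in each fold.
print(df.groupby("fold")["Sampling_Date"].nunique())
```

Because whole dates move together, fold sizes can be uneven when some dates have many more images than others.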

---

## 3. Model Architecture: Two-Stream, Multi-Head

The model uses a "Two-Stream, Multi-Head" architecture.
* **Shared Backbone:** A single `timm` backbone (e.g., `convnext_tiny`) with pre-trained ImageNet weights is used.
* **Two-Stream Input:**
    * `img_left` $\rightarrow$ `backbone` $\rightarrow$ `features_left`
    * `img_right` $\rightarrow$ (same) `backbone` $\rightarrow$ `features_right`
* **Fusion:** The two feature vectors are concatenated along the feature dimension: `combined_features = torch.cat([features_left, features_right], dim=1)`.
* **Multi-Head Output:** This combined vector is fed into **three separate, specialized MLP heads** (one for each target: `head_total`, `head_gdm`, `head_green`) to allow for task specialization.
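The data flow can be illustrated with a small PyTorch sketch. `TinyBackbone` is a hypothetical stand-in for the `timm` model (so nothing is downloaded); with `convnext_tiny`, `num_features` is 768 and the fused vector has 1536 features. The key points are that one backbone instance serves both halves (shared weights) and that `dim=1` concatenates features rather than stacking the two halves as extra batch rows:

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Hypothetical stand-in for the timm backbone: image -> pooled feature vector."""
    def __init__(self, num_features=16):
        super().__init__()
        self.num_features = num_features
        self.body = nn.Sequential(
            nn.Conv2d(3, num_features, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),  # -> (B, num_features)
        )

    def forward(self, x):
        return self.body(x)

def make_head(in_features):
    # Same MLP shape as the notebook's heads: in -> in/2 -> 1
    return nn.Sequential(nn.Linear(in_features, in_features // 2), nn.ReLU(),
                         nn.Dropout(0.3), nn.Linear(in_features // 2, 1))

class TwoStreamModel(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone  # ONE module reused for both halves (shared weights)
        n_combined = backbone.num_features * 2
        self.head_total = make_head(n_combined)
        self.head_gdm = make_head(n_combined)
        self.head_green = make_head(n_combined)

    def forward(self, img_left, img_right):
        f_left = self.backbone(img_left)
        f_right = self.backbone(img_right)
        combined = torch.cat([f_left, f_right], dim=1)  # (B, 2 * num_features)
        return self.head_total(combined), self.head_gdm(combined), self.head_green(combined)

model = TwoStreamModel(TinyBackbone()).eval()
left, right = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)
with torch.no_grad():
    out_total, out_gdm, out_green = model(left, right)
    wrong = torch.cat([model.backbone(left), model.backbone(right)])  # default dim=0
print(out_total.shape)  # torch.Size([2, 1])
print(wrong.shape)      # torch.Size([4, 16]): halves stacked as extra samples
```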

---

## 4. Data Augmentation

To compensate for the small dataset, augmentations are applied **independently** to the `img_left` and `img_right` patches.
* `HorizontalFlip (p=0.5)`
* `VerticalFlip (p=0.5)`
* `RandomRotate90 (p=0.5)` (rotations by multiples of 90 degrees only)
* `ColorJitter`

Because each half is transformed independently, a single image yields many more distinct (left, right) combinations than applying one shared transform to both halves.
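A short NumPy calculation makes "many more combinations" concrete for the geometric transforms alone, assuming the order HorizontalFlip, VerticalFlip, RandomRotate90 and that the rotation picks a multiple of 90 degrees (as albumentations' `RandomRotate90` does). ColorJitter is ignored here:

```python
import itertools
import numpy as np

# Asymmetric 4x4 "patch": every value is unique, so different symmetries give different arrays.
img = np.arange(16).reshape(4, 4)

variants = set()
for hflip, vflip, k in itertools.product([False, True], [False, True], range(4)):
    out = img
    if hflip:
        out = np.fliplr(out)   # HorizontalFlip
    if vflip:
        out = np.flipud(out)   # VerticalFlip
    out = np.rot90(out, k)     # rotate by k * 90 degrees
    variants.add(out.tobytes())

n_per_patch = len(variants)
print(n_per_patch)       # 8: the 8 symmetries of a square
print(n_per_patch ** 2)  # 64 (left, right) pairs when each half is transformed independently
```

With one shared transform the pair count would stay at 8, so independent sampling multiplies the geometric variety of each training image by 8.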

---

## 5. Loss Function: Weighted SmoothL1Loss

The model is optimized using a custom weighted loss function that aligns with the competition's scoring metric.
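* **Base Loss:** `nn.SmoothL1Loss` (identical to Huber loss with δ = 1 at its default `beta=1.0`) is used instead of `MSELoss`. It is quadratic for small errors and linear for large ones, which makes training more stable and less sensitive to outliers.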
* **Weighted Sum:** The final loss is a weighted sum of the individual losses, using the competition's scoring weights:
    $$Loss = (0.5 \cdot Loss_{Total}) + (0.2 \cdot Loss_{GDM}) + (0.1 \cdot Loss_{Green})$$
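A minimal PyTorch sketch of this weighted loss, assuming predictions and targets arrive as tuples of `(B, 1)` tensors ordered (Total, GDM, Green); the class name `WeightedSmoothL1Loss` is hypothetical. The comment traces the arithmetic for one sample:

```python
import torch
import torch.nn as nn

class WeightedSmoothL1Loss(nn.Module):
    """Hypothetical: weighted sum of per-target SmoothL1 losses with weights (0.5, 0.2, 0.1)."""
    def __init__(self, weights=(0.5, 0.2, 0.1)):
        super().__init__()
        self.weights = weights
        self.base = nn.SmoothL1Loss()  # beta=1.0, mean over the batch

    def forward(self, preds, targets):
        # preds / targets: tuples of (B, 1) tensors ordered (Total, GDM, Green)
        return sum(w * self.base(p, t) for w, p, t in zip(self.weights, preds, targets))

criterion = WeightedSmoothL1Loss()
preds = (torch.tensor([[10.0]]), torch.tensor([[5.0]]), torch.tensor([[3.0]]))
targets = (torch.tensor([[12.0]]), torch.tensor([[5.5]]), torch.tensor([[3.0]]))
loss = criterion(preds, targets)
# Total: |err| = 2 >= 1  -> 2 - 0.5 = 1.5
# GDM:   |err| = 0.5 < 1 -> 0.5 * 0.5**2 = 0.125
# Green: |err| = 0       -> 0
# 0.5 * 1.5 + 0.2 * 0.125 + 0.1 * 0 = 0.775
print(round(loss.item(), 4))  # 0.775
```

Note that the weights sum to 0.8: the derived targets (`Dry_Dead_g`, `Dry_Clover_g`) receive no direct gradient and are only supervised indirectly through the three predicted ones.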

---

## 6. Training Strategy: Two-Stage Fine-Tuning

A two-stage "Freeze/Unfreeze" strategy is used to stabilize training on the small dataset.
* **Stage 1 (Freeze):**
    * **Epochs:** 1-5
    * **Action:** The entire `backbone` is frozen. Only the three MLP heads are trained.
    * **LR:** `1e-4`
* **Stage 2 (Unfreeze/Fine-Tuning):**
    * **Epochs:** 6-20
    * **Action:** The `backbone` is "unfrozen," and the entire model is trained.
    * **LR:** A very low learning rate (`1e-5`) is used to slowly adapt the backbone features.
* **Model Saving:** A `ModelCheckpoint` saves the model based on the **highest `Score (R^2)`** on the validation set, *not* the lowest loss. This is critical for capturing the model's peak performance (like the `R^2=0.64` spike at Epoch 11) and ignoring the unstable, overfitted epochs.
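The freeze/unfreeze schedule and the "keep the best score" rule can be sketched in plain PyTorch. `ToyModel` is a hypothetical stand-in that mirrors the notebook's attribute names, AdamW is an assumed optimizer, and the validation scores are placeholder values (not results) so the selection logic can run; in PyTorch Lightning the same rule corresponds to a `ModelCheckpoint` with `mode="max"` on the monitored score:

```python
import copy
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Hypothetical stand-in mirroring the notebook's attribute names (backbone + 3 heads)."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.head_total = nn.Linear(4, 1)
        self.head_gdm = nn.Linear(4, 1)
        self.head_green = nn.Linear(4, 1)

def set_backbone_trainable(model, trainable):
    for p in model.backbone.parameters():
        p.requires_grad = trainable

def make_optimizer(model, lr):
    # AdamW is an assumption; only parameters that currently require gradients are optimized.
    return torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=lr)

def count_trainable(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Placeholder validation scores (NOT real results), peaking at epoch 9.
placeholder_scores = {epoch: 1 - abs(epoch - 9) / 10 for epoch in range(1, 21)}

model = ToyModel()
best_score, best_epoch, best_state = float("-inf"), None, None
for epoch in range(1, 21):
    if epoch == 1:    # Stage 1 (epochs 1-5): frozen backbone, heads only, LR 1e-4
        set_backbone_trainable(model, False)
        optimizer = make_optimizer(model, lr=1e-4)
        stage1_trainable = count_trainable(model)
    elif epoch == 6:  # Stage 2 (epochs 6-20): unfreeze everything, LR 1e-5
        set_backbone_trainable(model, True)
        optimizer = make_optimizer(model, lr=1e-5)
        stage2_trainable = count_trainable(model)

    # ... one training epoch with `optimizer` would run here ...
    val_score = placeholder_scores[epoch]
    if val_score > best_score:  # keep the highest validation score, not the lowest loss
        best_score, best_epoch = val_score, epoch
        best_state = copy.deepcopy(model.state_dict())

print(stage1_trainable, stage2_trainable, best_epoch)  # 15 51 9
```

Rebuilding the optimizer at the stage boundary is one simple way to include the newly unfrozen parameters; adding a parameter group to the existing optimizer would also work.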

import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
import cv2
from tqdm import tqdm
import gc

# ===============================================================
# 1. ⚙️ CONFIGURATION (MUST MATCH THE TRAINING FILE EXACTLY)
# ===============================================================
class CFG:
    # --- Paths ---
    # (Adjust these paths to match your environment)
    BASE_PATH = '/kaggle/input/csiro-biomass'
    TEST_CSV = os.path.join(BASE_PATH, 'test.csv')
    TEST_IMAGE_DIR = os.path.join(BASE_PATH, 'test')
    
    # Directory containing the 5 .pth files
    MODEL_DIR = '/kaggle/input/csiro/' # Assumes all 5 .pth files are in the same directory
    SUBMISSION_FILE = 'submission.csv'
    
    # --- Model settings (MUST MATCH TRAINING) ---
    MODEL_NAME = 'convnext_tiny' # MUST BE IDENTICAL TO TRAINING
    IMG_SIZE = 768               # MUST BE IDENTICAL TO TRAINING
    
    # --- Inference settings ---
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    BATCH_SIZE = 1 # Batch size can be increased for inference
    NUM_WORKERS = 1
    N_FOLDS = 5
    
    # --- Targets & loss (MUST MATCH TRAINING) ---
    # The 3 targets the model predicts
    TARGET_COLS = ['Dry_Total_g', 'GDM_g', 'Dry_Green_g']
    
    # The 5 targets required for submission
    ALL_TARGET_COLS = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']

print(f"Using device: {CFG.DEVICE}")
print(f"Model backbone: {CFG.MODEL_NAME}")
print(f"Inference image size: {CFG.IMG_SIZE}x{CFG.IMG_SIZE}")


# ===============================================================
# 2. 🏞️ AUGMENTATIONS (VALIDATION / TTA ONLY)
# ===============================================================
from albumentations import (
    Compose, 
    Resize, 
    Normalize,
    HorizontalFlip, 
    VerticalFlip
)

def get_tta_transforms():
    """
    Returns a LIST of transform pipelines for TTA.
    Each pipeline is a different "view" of the image.
    """
    
    # Basic normalization steps (ImageNet mean/std)
    base_transforms = [
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
        ToTensorV2()
    ]
    
    # -----------------
    # View 1: Original image (Resize + Normalize only)
    # -----------------
    original_view = Compose([
        Resize(CFG.IMG_SIZE, CFG.IMG_SIZE),
        *base_transforms
    ])
    
    # -----------------
    # View 2: Horizontal flip (HFlip)
    # -----------------
    hflip_view = Compose([
        HorizontalFlip(p=1.0), # Always flip
        Resize(CFG.IMG_SIZE, CFG.IMG_SIZE),
        *base_transforms
    ])
    
    # -----------------
    # View 3: L·∫≠t d·ªçc (VFlip)
    # -----------------
    vflip_view = Compose([
        VerticalFlip(p=1.0), # Always flip
        Resize(CFG.IMG_SIZE, CFG.IMG_SIZE),
        *base_transforms
    ])
    
    return [original_view, hflip_view, vflip_view]

print("Defined get_tta_transforms().")


class TestBiomassDataset(Dataset):
    """
    Custom Dataset for test images ("two-stream" strategy).
    Modified to accept a specific transform pipeline for TTA.
    """
    def __init__(self, df, transform_pipeline, image_dir):
        self.df = df
        # (MODIFIED) Accept an already-instantiated pipeline
        self.transforms = transform_pipeline 
        self.image_dir = image_dir
        self.image_paths = df['image_path'].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # 1. Get the image path
        img_path_suffix = self.image_paths[idx]
        
        # 2. Read the original image (2000x1000)
        filename = os.path.basename(img_path_suffix)
        full_path = os.path.join(self.image_dir, filename)
        
        image = cv2.imread(full_path)
        if image is None:
            print(f"Warning: could not read image: {full_path}. Returning a black image.")
            image = np.zeros((1000, 2000, 3), dtype=np.uint8)
        
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # 3. Crop into 2 images (left and right halves)
        height, width, _ = image.shape
        mid_point = width // 2
        img_left = image[:, :mid_point]
        img_right = image[:, mid_point:]
        
        # 4. Apply the TTA transform (the SAME transform to both halves)
        # (e.g., both halves are flipped horizontally)
        img_left_tensor = self.transforms(image=img_left)['image']
        img_right_tensor = self.transforms(image=img_right)['image']
        
        # 5. Return both tensors
        return img_left_tensor, img_right_tensor

# ===============================================================
# 4. 🧠 MODEL ARCHITECTURE (COPIED FROM THE TRAINING FILE)
# ===============================================================
class BiomassModel(nn.Module):
    """
    Model architecture (two streams, three heads).
    MUST be identical to the training file.
    """
    def __init__(self, model_name, pretrained, n_targets=3):
        super(BiomassModel, self).__init__()
        
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained, # False at inference time
            num_classes=0,
            global_pool='avg'
        )
        
        self.n_features = self.backbone.num_features
        self.n_combined_features = self.n_features * 2
        
        # --- Head for Dry_Total_g ---
        self.head_total = nn.Sequential(
            nn.Linear(self.n_combined_features, self.n_combined_features // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.n_combined_features // 2, 1)
        )
        
        # --- Head for GDM_g ---
        self.head_gdm = nn.Sequential(
            nn.Linear(self.n_combined_features, self.n_combined_features // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.n_combined_features // 2, 1)
        )
        
        # --- Head for Dry_Green_g ---
        self.head_green = nn.Sequential(
            nn.Linear(self.n_combined_features, self.n_combined_features // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.n_combined_features // 2, 1)
        )

    def forward(self, img_left, img_right):
        features_left = self.backbone(img_left)
        features_right = self.backbone(img_right)
        combined = torch.cat([features_left, features_right], dim=1)
        
        out_total = self.head_total(combined)
        out_gdm = self.head_gdm(combined)
        out_green = self.head_green(combined)
        
        return out_total, out_gdm, out_green


def predict_one_view(models_list, test_loader, device):
    """
    Helper: run the 5-fold ensemble prediction for ONE TTA view.
    """
    view_preds_3 = {'total': [], 'gdm': [], 'green': []}
    
    with torch.no_grad():
        for (img_left, img_right) in tqdm(test_loader, desc="  Predicting View", leave=False):
            img_left = img_left.to(device)
            img_right = img_right.to(device)
            
            batch_preds_3_folds = {'total': [], 'gdm': [], 'green': []}
            
            # 1. 5-fold ensemble loop
            for model in models_list:
                pred_total, pred_gdm, pred_green = model(img_left, img_right)
                batch_preds_3_folds['total'].append(pred_total.cpu())
                batch_preds_3_folds['gdm'].append(pred_gdm.cpu())
                batch_preds_3_folds['green'].append(pred_green.cpu())
            
            # 2. Average over the 5 folds
            avg_pred_total = torch.mean(torch.stack(batch_preds_3_folds['total']), dim=0)
            avg_pred_gdm = torch.mean(torch.stack(batch_preds_3_folds['gdm']), dim=0)
            avg_pred_green = torch.mean(torch.stack(batch_preds_3_folds['green']), dim=0)
            
            view_preds_3['total'].append(avg_pred_total.numpy())
            view_preds_3['gdm'].append(avg_pred_gdm.numpy())
            view_preds_3['green'].append(avg_pred_green.numpy())

    # 3. Concatenate the batch results for this view
    preds_np = {
        'total': np.concatenate(view_preds_3['total']).flatten(),
        'gdm':   np.concatenate(view_preds_3['gdm']).flatten(),
        'green': np.concatenate(view_preds_3['green']).flatten()
    }
    return preds_np


def run_inference_with_tta():
    """
    Main inference function: TTA x 5-fold ensemble.
    """
    print(f"\n{'='*50}")
    print(f"🚀 STARTING INFERENCE (with TTA) 🚀")
    print(f"{'='*50}")

    # --- 1. Load test data ---
    print(f"Loading {CFG.TEST_CSV}...")
    try:
        test_df_long = pd.read_csv(CFG.TEST_CSV)
        test_df_unique = test_df_long.drop_duplicates(subset=['image_path']).reset_index(drop=True)
        print(f"Found {len(test_df_unique)} unique test images.")
    except FileNotFoundError:
        print(f"ERROR: could not find {CFG.TEST_CSV}")
        return None, None, None

    # --- 2. Load the 5 models (ensemble) ---
    print("\nLoading the 5 trained models...")
    models_list = []
    # (Model-loading code, identical to step 16 of the previous file)
    for fold in range(CFG.N_FOLDS):
        model_path = os.path.join(CFG.MODEL_DIR, f'best_model_fold{fold}.pth')
        if not os.path.exists(model_path):
            print(f"ERROR: model file not found: {model_path}")
            return None, None, None
        model = BiomassModel(CFG.MODEL_NAME, pretrained=False)
        try:
            model.load_state_dict(torch.load(model_path, map_location=CFG.DEVICE))
        except RuntimeError:
            state_dict = torch.load(model_path, map_location=CFG.DEVICE)
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k.replace('module.', '')
                new_state_dict[name] = v
            model.load_state_dict(new_state_dict)
        model.eval()
        model.to(CFG.DEVICE)
        models_list.append(model)
    print(f"✓ Successfully loaded {len(models_list)} models.")

    # --- 3. TTA loop (outer loop) ---
    tta_transforms = get_tta_transforms()
    print(f"\nStarting prediction with {len(tta_transforms)} TTA views...")
    
    all_tta_view_preds = [] # Stores the predictions of each TTA view

    for i, tta_transform in enumerate(tta_transforms):
        print(f"--- Running TTA view {i+1}/{len(tta_transforms)} ---")
        
        # Create a NEW Dataset/DataLoader for this TTA view
        test_dataset = TestBiomassDataset(
            df=test_df_unique,
            transform_pipeline=tta_transform, # Pass the TTA pipeline
            image_dir=CFG.TEST_IMAGE_DIR
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=False,
            num_workers=CFG.NUM_WORKERS,
            pin_memory=True
        )
        
        # Run the 5-fold ensemble for this view
        view_preds_np = predict_one_view(models_list, test_loader, CFG.DEVICE)
        all_tta_view_preds.append(view_preds_np)
        print(f"✓ Finished TTA view {i+1}")

    # --- 4. Ensemble (average) the TTA results ---
    print("\nEnsembling the TTA view results...")
    final_ensembled_preds = {
        'total': np.mean([d['total'] for d in all_tta_view_preds], axis=0),
        'gdm':   np.mean([d['gdm'] for d in all_tta_view_preds], axis=0),
        'green': np.mean([d['green'] for d in all_tta_view_preds], axis=0)
    }
    
    print("✓ Prediction complete.")
    
    del models_list, test_loader, test_dataset
    gc.collect()
    torch.cuda.empty_cache()
    
    return final_ensembled_preds, test_df_long, test_df_unique
# ===============================================================
# 6. ✍️ SUBMISSION FILE CREATION
# ===============================================================
def create_submission(preds_np, test_df_long, test_df_unique):
    """
    Takes the 3 ensembled predictions,
    computes the 2 remaining targets,
    and formats the submission file.
    """
    if preds_np is None:
        print("Skipping submission creation due to an earlier error.")
        return

    print("\nPost-processing and creating the submission file...")

    # 1. Get the 3 ensembled predictions
    pred_total_final = preds_np['total']
    pred_gdm_final = preds_np['gdm']
    pred_green_final = preds_np['green']

    # 2. Compute the 2 remaining targets (post-processing)
    # np.maximum(0, ...) guarantees there are no negative values
    pred_clover_final = np.maximum(0, pred_gdm_final - pred_green_final)
    pred_dead_final = np.maximum(0, pred_total_final - pred_gdm_final)

    # 3. Build a "wide" DataFrame with the 5 predictions
    # (Keep the 5 columns in the same order as CFG.ALL_TARGET_COLS)
    preds_wide_df = pd.DataFrame({
        'image_path': test_df_unique['image_path'],
        'Dry_Green_g': pred_green_final,
        'Dry_Dead_g': pred_dead_final,
        'Dry_Clover_g': pred_clover_final,
        'GDM_g': pred_gdm_final,
        'Dry_Total_g': pred_total_final
    })

    # 4. Un-pivot the DataFrame (convert to "long" format)
    # Turns the 5 target columns into one row per (image, target), like sample_submission
    preds_long_df = preds_wide_df.melt(
        id_vars=['image_path'],
        value_vars=CFG.ALL_TARGET_COLS, # the 5 target columns
        var_name='target_name',        # target-name column
        value_name='target'            # predicted-value column
    )

    # 5. Merge with the original test.csv (test_df_long)
    # This step is essential to get the correct 'sample_id'
    # (e.g., 'ID1001187975__Dry_Clover_g')
    submission_df = pd.merge(
        test_df_long[['sample_id', 'image_path', 'target_name']],
        preds_long_df,
        on=['image_path', 'target_name'],
        how='left'
    )

    # 6. Clean up and save
    # Keep only the 2 required columns
    submission_df = submission_df[['sample_id', 'target']]
    
    # Save the file
    submission_df.to_csv(CFG.SUBMISSION_FILE, index=False)

    print(f"\n🎉 DONE! Submission file saved to: {CFG.SUBMISSION_FILE}")
    print("--- First 5 rows of the submission file ---")
    print(submission_df.head())
    print("\n--- Last 5 rows of the submission file ---")
    print(submission_df.tail())
    
# ===============================================================
# 8. 🏁 RUN THE PROGRAM (fixed)
# ===============================================================
if __name__ == "__main__":
    # 1. Run prediction (including TTA)
    all_preds_np, df_long, df_unique = run_inference_with_tta()
    
    # 2. Create the submission file (create_submission is unchanged)
    create_submission(all_preds_np, df_long, df_unique)