In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm  # Th∆∞ vi·ªán tuy·ªát v·ªùi cho c√°c backbone
import cv2  # OpenCV ƒë·ªÉ ƒë·ªçc ·∫£nh
from tqdm import tqdm # Thanh ti·∫øn tr√¨nh
import matplotlib.pyplot as plt

# --- CENTRAL CONFIGURATION CLASS ---
# All hyperparameters are managed here
class CFG:
    """Single namespace holding every path and hyperparameter of the notebook."""

    # --- Paths (adjust these to match your environment) ---
    BASE_PATH = "/kaggle/input/csiro-biomass"
    TRAIN_CSV = os.path.join(BASE_PATH, "train.csv")
    IMAGE_DIR = os.path.join(BASE_PATH, "train")

    # --- Model settings ---
    MODEL_NAME = "convnext_tiny"  # backbone name handed to timm (e.g. 'resnet50')
    PRETRAINED = True
    IMG_SIZE = 768  # each half-image is resized to IMG_SIZE x IMG_SIZE

    # --- Training settings ---
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    BATCH_SIZE = 2  # kept small: two 768x768 streams per sample are VRAM-hungry
    EPOCHS = 30  # total epochs across both training phases
    LEARNING_RATE = 1e-4  # phase 1 (heads only)

    # --- Phase 2 (fine-tuning) settings ---
    FREEZE_EPOCHS = 15  # number of epochs during which only the heads train
    FINETUNE_LR = 1e-5  # lower LR once the whole model is trainable

    NUM_WORKERS = 2  # DataLoader worker processes

    # --- Targets & loss ---
    # The three targets the model predicts directly, in head order.
    TARGET_COLS = ["Dry_Total_g", "GDM_g", "Dry_Green_g"]
    LOSS_WEIGHTS = dict(total_loss=0.5, gdm_loss=0.2, green_loss=0.1)

    # The five targets used for scoring (order matters, see R2_WEIGHTS).
    ALL_TARGET_COLS = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]

    # Per-target R^2 weights, aligned with ALL_TARGET_COLS.
    R2_WEIGHTS = [0.1, 0.1, 0.1, 0.2, 0.5]

# --- Echo the key settings so the run log records them ---
print(
    f"S·ª≠ d·ª•ng thi·∫øt b·ªã: {CFG.DEVICE}",
    f"Backbone m√¥ h√¨nh: {CFG.MODEL_NAME}",
    f"K√≠ch th∆∞·ªõc ·∫£nh hu·∫•n luy·ªán: {CFG.IMG_SIZE}x{CFG.IMG_SIZE}",
    sep="\n",
)

In [None]:
# --- Load the training CSV and pivot it from long to wide ---
# Produces the module-level `df_wide`: one row per image, one column per target.

print(f"ƒêang t·∫£i {CFG.TRAIN_CSV}...")
try:
    # Only the read itself can raise FileNotFoundError, so keep the try narrow.
    df_long = pd.read_csv(CFG.TRAIN_CSV)
except FileNotFoundError:
    print(f"L·ªñI: Kh√¥ng t√¨m th·∫•y file {CFG.TRAIN_CSV}")
    print("Vui l√≤ng ki·ªÉm tra l·∫°i CFG.TRAIN_CSV")
    # Empty stand-in so later cells can still reference `df_wide`.
    # It carries ALL_TARGET_COLS (a superset of TARGET_COLS) because
    # BiomassDataset indexes both column lists.
    df_wide = pd.DataFrame(columns=['image_path'] + CFG.ALL_TARGET_COLS)
else:
    print(f"B·∫£ng 'long' g·ªëc c√≥ {len(df_long)} h√†ng.")

    # Pivot: 'image_path' identifies a row, each distinct 'target_name'
    # becomes a column, and the cell values come from 'target'.
    df_wide = df_long.pivot(
        index='image_path',
        columns='target_name',
        values='target'
    )

    # After the pivot 'image_path' is the index; turn it back into a regular
    # column and drop the leftover 'target_name' label on the column axis.
    df_wide = df_wide.reset_index()
    df_wide.columns.name = None

    print(f"ƒê√£ pivot! B·∫£ng 'wide' m·ªõi c√≥ {len(df_wide)} h√†ng (m·ªói h√†ng 1 ·∫£nh).")

    # Show the first five rows of the wide table.
    print("\n--- 5 h√†ng ƒë·∫ßu c·ªßa df_wide ---")
    print(df_wide.head())

    # List the columns so the expected target columns can be checked by eye.
    print("\n--- C√°c c·ªôt c√≥ trong df_wide ---")
    print(df_wide.columns.tolist())

In [None]:
from albumentations import (
    Compose, 
    Resize, 
    Normalize,
    HorizontalFlip, 
    VerticalFlip,
    RandomRotate90,  # Ch·ªâ xoay 90, 180, 270
    ColorJitter
)

# --- Augmentation definitions ---

# The SAME pipeline is used for BOTH images (left and right), applied INDEPENDENTLY
def get_train_transforms():
    """
    Build the training-time augmentation pipeline.

    A fresh ``Compose`` is returned on every call; the dataset calls this once
    per half-image so the left and right crops are augmented independently.
    """
    # 1. Geometric augmentation: flips and multiples-of-90-degree rotations.
    geometric = [
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        RandomRotate90(p=0.5),
    ]

    # 2. Photometric augmentation.
    photometric = [
        ColorJitter(
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
            hue=0.1,
            p=0.75,
        ),
    ]

    # 3. Normalise with the standard ImageNet mean/std, then
    # 4. resize (done after augmenting) and convert to a tensor.
    finishing = [
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        Resize(CFG.IMG_SIZE, CFG.IMG_SIZE),
        ToTensorV2(),
    ]

    return Compose(geometric + photometric + finishing)

def get_valid_transforms():
    """
    Build the validation pipeline: normalise, resize, tensorise (no augmentation).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [
        Normalize(mean=imagenet_mean, std=imagenet_std),
        Resize(CFG.IMG_SIZE, CFG.IMG_SIZE),
        ToTensorV2(),
    ]
    return Compose(steps)

# --- Confirm the augmentation helpers exist and report the target size ---
print(
    "ƒê√£ ƒë·ªãnh nghƒ©a c√°c h√†m Augmentation.",
    f"·∫¢nh hu·∫•n luy·ªán s·∫Ω ƒë∆∞·ª£c augment v√† n√©n xu·ªëng: {CFG.IMG_SIZE}x{CFG.IMG_SIZE}",
    f"·∫¢nh validation s·∫Ω ch·ªâ ƒë∆∞·ª£c n√©n xu·ªëng: {CFG.IMG_SIZE}x{CFG.IMG_SIZE}",
    sep="\n",
)

In [None]:
# (This is the NEW version of the BiomassDataset class)

class BiomassDataset(Dataset):
    """
    Two-stream dataset: every source image is cut into a left and a right half.

    Each item is the 4-tuple
    ``(img_left, img_right, train_targets, all_targets)`` where the two target
    tensors hold the ``train_target_cols`` and ``all_target_cols`` values of
    the row, respectively.
    """

    def __init__(self, df, transforms_fn, image_dir, train_target_cols, all_target_cols):
        """
        Args:
            df: DataFrame with an 'image_path' column plus the target columns.
            transforms_fn: Zero-argument callable returning a transform
                pipeline that is invoked as ``pipeline(image=...)['image']``.
            image_dir: Directory that contains the image files.
            train_target_cols: Columns fed to the loss.
            all_target_cols: Columns used for scoring.
        """
        self.df = df
        self.transforms_fn = transforms_fn
        self.image_dir = image_dir

        # Pull everything out as arrays once so __getitem__ is a plain index.
        self.image_paths = df['image_path'].values
        self.train_targets = df[train_target_cols].values
        self.all_targets = df[all_target_cols].values

    def __len__(self):
        return len(self.df)

    def _read_rgb(self, idx):
        """Load the image for row ``idx`` as an RGB array."""
        # Only the file name of the stored path is used; it is looked up
        # inside ``image_dir``.
        file_name = os.path.basename(self.image_paths[idx])
        full_path = os.path.join(self.image_dir, file_name)
        bgr = cv2.imread(full_path)
        if bgr is None:
            raise FileNotFoundError(f"Kh√¥ng th·ªÉ ƒë·ªçc ·∫£nh: {full_path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        loss_values = self.train_targets[idx]
        score_values = self.all_targets[idx]

        rgb = self._read_rgb(idx)

        # Split down the vertical centre line into left / right halves.
        _, full_width, _ = rgb.shape
        split_at = full_width // 2
        left_half, right_half = rgb[:, :split_at], rgb[:, split_at:]

        # A new pipeline is requested for each half so that the two halves
        # are transformed independently of one another.
        left_tensor = self.transforms_fn()(image=left_half)['image']
        right_tensor = self.transforms_fn()(image=right_half)['image']

        return (
            left_tensor,
            right_tensor,
            torch.tensor(loss_values, dtype=torch.float32),
            torch.tensor(score_values, dtype=torch.float32),
        )

In [None]:
# --- Smoke-test the Dataset and DataLoader ---
# Builds a tiny dataset from the first rows of `df_wide`, pulls one batch and
# prints its shapes next to the expected ones.

print("\nƒêang ki·ªÉm tra Dataset & DataLoader...")

try:
    # 1. Build a dataset (with the training augmentations) from the first
    #    10 rows of df_wide.
    test_df_subset = df_wide.head(10)

    # BiomassDataset takes two column lists: the loss targets and the
    # scoring targets.
    train_dataset = BiomassDataset(
        df=test_df_subset,
        transforms_fn=get_train_transforms,  # pass the FUNCTION, not a pipeline
        image_dir=CFG.IMAGE_DIR,
        train_target_cols=CFG.TARGET_COLS,
        all_target_cols=CFG.ALL_TARGET_COLS
    )

    # 2. Wrap it in a DataLoader.
    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG.BATCH_SIZE,
        shuffle=True,
        num_workers=CFG.NUM_WORKERS
    )

    # 3. Fetch a single batch. Every dataset item is a 4-tuple, so the batch
    #    has four parts as well.
    print(f"ƒêang t·∫£i 1 batch (batch_size={CFG.BATCH_SIZE})...")
    img_left_batch, img_right_batch, targets_batch, all_targets_batch = next(iter(train_loader))

    # 4. Print the actual shapes.
    print("\n--- K√≠ch th∆∞·ªõc (Shape) c·ªßa Batch ƒë·∫ßu ra ---")
    print(f"  ·∫¢nh tr√°i (Left): {img_left_batch.shape}")
    print(f"  ·∫¢nh ph·∫£i (Right): {img_right_batch.shape}")
    print(f"  M·ª•c ti√™u (Targets): {targets_batch.shape}")

    print("\n--- K√≠ch th∆∞·ªõc mong ƒë·ª£i ---")
    print(f"  ·∫¢nh (K·ª≥ v·ªçng):   [Batch, Channels, Height, Width] -> [{CFG.BATCH_SIZE}, 3, {CFG.IMG_SIZE}, {CFG.IMG_SIZE}]")
    print(f"  M·ª•c ti√™u (K·ª≥ v·ªçng): [Batch, Num_Targets] -> [{CFG.BATCH_SIZE}, {len(CFG.TARGET_COLS)}]")

    # (Optional) Show one image. Undo the normalisation per channel with the
    # same mean/std the transforms use, and clamp into the [0, 1] range that
    # imshow expects for float images.
    channel_mean = torch.tensor([0.485, 0.456, 0.406])
    channel_std = torch.tensor([0.229, 0.224, 0.225])
    preview = img_left_batch[0].permute(1, 2, 0) * channel_std + channel_mean
    plt.imshow(preview.clamp(0, 1).numpy())
    plt.title("M·ªôt ·∫£nh m·∫´u t·ª´ Batch (Tr√°i)")
    plt.show()

except Exception as e:
    # Best-effort check: report the problem instead of stopping the notebook.
    print(f"\nL·ªñI khi ki·ªÉm tra DataLoader: {e}")
    print("H√£y ki·ªÉm tra l·∫°i ƒë∆∞·ªùng d·∫´n trong CFG v√† c·∫•u tr√∫c file ·∫£nh.")

In [None]:
class BiomassModel(nn.Module):
    """
    Two-stream network with three dedicated regression heads.

    * One backbone is shared by both streams.
    * It encodes ``img_left`` and ``img_right`` separately.
    * The two feature vectors are concatenated.
    * The fused vector feeds three independent MLP heads
      (total, GDM, green), each producing a single value.
    """

    def __init__(self, model_name, pretrained, n_targets=3):
        super().__init__()

        # Shared backbone. ``num_classes=0`` removes the classifier and
        # ``global_pool='avg'`` leaves one pooled feature vector per image.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool='avg',
        )

        # Width of one stream's feature vector, and of the fused vector
        # (left features followed by right features).
        self.n_features = self.backbone.num_features
        self.n_combined_features = 2 * self.n_features

        # Three structurally identical heads, created in a fixed order.
        self.head_total = self._build_head(self.n_combined_features)  # Dry_Total_g
        self.head_gdm = self._build_head(self.n_combined_features)    # GDM_g
        self.head_green = self._build_head(self.n_combined_features)  # Dry_Green_g

    @staticmethod
    def _build_head(in_features):
        """Return one MLP head: Linear -> ReLU -> Dropout(0.3) -> Linear(1)."""
        hidden = in_features // 2
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden, 1),
        )

    def forward(self, img_left, img_right):
        """Return ``(out_total, out_gdm, out_green)`` for a pair of image batches."""
        # Encode each stream with the shared backbone (left first, then right)
        # and fuse the results along the feature axis.
        left_features = self.backbone(img_left)
        right_features = self.backbone(img_right)
        fused = torch.cat([left_features, right_features], dim=1)

        # The heads are returned as separate tensors so the loss can weight
        # them individually.
        total = self.head_total(fused)
        gdm = self.head_gdm(fused)
        green = self.head_green(fused)
        return total, gdm, green

In [None]:
# --- Smoke-test the model architecture ---
print("\nƒêang ki·ªÉm tra ki·∫øn tr√∫c m√¥ h√¨nh...")

try:
    # 1. Instantiate the model. Pretrained weights are not needed because
    #    only the output shapes are inspected.
    model = BiomassModel(model_name=CFG.MODEL_NAME, pretrained=False)
    model = model.to(CFG.DEVICE)

    # 2. Random stand-in inputs shaped (batch, channels, height, width).
    dummy_shape = (CFG.BATCH_SIZE, 3, CFG.IMG_SIZE, CFG.IMG_SIZE)
    dummy_left = torch.randn(*dummy_shape).to(CFG.DEVICE)
    dummy_right = torch.randn(*dummy_shape).to(CFG.DEVICE)

    print(f"ƒêang ƒë∆∞a 2 batch {dummy_left.shape} v√†o m√¥ h√¨nh...")

    # 3. Forward pass through both streams.
    out_total, out_gdm, out_green = model(dummy_left, dummy_right)

    # 4. Report the shape of each head's output.
    print("\n--- K√≠ch th∆∞·ªõc (Shape) c·ªßa ƒê·∫ßu ra M√¥ h√¨nh ---")
    print(f"  ƒê·∫ßu ra Total: {out_total.shape}")
    print(f"  ƒê·∫ßu ra GDM:   {out_gdm.shape}")
    print(f"  ƒê·∫ßu ra Green: {out_green.shape}")

    print("\n--- K√≠ch th∆∞·ªõc mong ƒë·ª£i (m·ªói ƒë·∫ßu ra) ---")
    print(f"  K·ª≥ v·ªçng: [Batch, 1] -> [{CFG.BATCH_SIZE}, 1]")

except Exception as e:
    # Best-effort check: report and carry on.
    print(f"\nL·ªñI khi ki·ªÉm tra m√¥ h√¨nh: {e}")

In [None]:
class WeightedBiomassLoss(nn.Module):
    """
    Weighted sum of three per-head regression losses.

    One shared ``nn.SmoothL1Loss`` criterion is evaluated once per head
    (total, GDM, green); the three results are combined using the weights
    stored under the keys 'total_loss', 'gdm_loss' and 'green_loss'.
    """

    def __init__(self, loss_weights_dict):
        super().__init__()
        # A single criterion instance is reused for all three heads.
        self.criterion = nn.SmoothL1Loss()
        self.weights = loss_weights_dict

    def forward(self, predictions, targets):
        """
        Args:
            predictions: Tuple ``(out_total, out_gdm, out_green)`` from the model.
            targets: Tensor whose columns 0, 1, 2 hold the ground truth for
                Dry_Total_g, GDM_g and Dry_Green_g respectively.

        Returns:
            The scalar weighted loss.
        """
        pred_total, pred_gdm, pred_green = predictions

        # Pair each head with its weight key; position in the list is also
        # the target column that belongs to that head.
        heads = [
            ('total_loss', pred_total),
            ('gdm_loss', pred_gdm),
            ('green_loss', pred_green),
        ]

        per_head = []
        for column, (_, head_pred) in enumerate(heads):
            # Slicing with column:column + 1 keeps a trailing axis of size 1,
            # so the target lines up with the head's output.
            head_true = targets[:, column:column + 1]
            per_head.append(self.criterion(head_pred, head_true))

        weighted = [
            self.weights[key] * head_loss
            for (key, _), head_loss in zip(heads, per_head)
        ]
        return weighted[0] + weighted[1] + weighted[2]

In [None]:
# --- Create the loss and smoke-test it ---

print("\nƒêang kh·ªüi t·∫°o h√†m loss...")

# 1. Build the criterion and move it to the training device.
criterion = WeightedBiomassLoss(loss_weights_dict=CFG.LOSS_WEIGHTS).to(CFG.DEVICE)

print(f"H√†m loss ƒë√£ ƒë∆∞·ª£c t·∫°o v·ªõi c√°c tr·ªçng s·ªë: {CFG.LOSS_WEIGHTS}")

# 2. Run it once on random data.
print("\nƒêang ki·ªÉm tra h√†m loss...")
try:
    # Fake model output: three [batch, 1] tensors, one per head.
    dummy_preds = tuple(
        torch.randn(CFG.BATCH_SIZE, 1).to(CFG.DEVICE) for _ in range(3)
    )

    # Fake dataloader targets: one column per training target.
    dummy_targets = torch.randn(CFG.BATCH_SIZE, len(CFG.TARGET_COLS)).to(CFG.DEVICE)

    loss_value = criterion(dummy_preds, dummy_targets)

    print(f"  D·ªØ li·ªáu gi·∫£: {CFG.BATCH_SIZE} m·∫´u")
    print(f"  Gi√° tr·ªã loss t√≠nh ƒë∆∞·ª£c: {loss_value.item():.4f}")
    print("‚úì H√†m loss ho·∫°t ƒë·ªông ch√≠nh x√°c.")

except Exception as e:
    # Best-effort check: report and carry on.
    print(f"\nL·ªñI khi ki·ªÉm tra h√†m loss: {e}")

In [None]:
from sklearn.model_selection import StratifiedKFold
import numpy as np

# --- Extend CFG with the cross-validation settings ---
CFG.N_FOLDS = 5  # 5-fold cross-validation
CFG.RANDOM_STATE = 42  # fixed seed so the split is reproducible

print(f"\nƒêang chu·∫©n b·ªã {CFG.N_FOLDS}-Fold Cross-Validation...")

# 1. Add a 'fold' column to df_wide; -1 means "not assigned yet".
df_wide['fold'] = -1

# 2. Bin the most heavily weighted target (Dry_Total_g) so a continuous value
#    can be used for stratification. Sturges' rule gives the bin count for
#    small tables; with more than 100 rows a fixed 10 bins is used.
num_bins = int(np.floor(1 + np.log2(len(df_wide))))
if len(df_wide) > 100:
    num_bins = 10

print(f"S·ª≠ d·ª•ng {num_bins} bins ƒë·ªÉ ph√¢n t·∫ßng (stratify) tr√™n 'Dry_Total_g'")

# Quantile-based bins (pd.qcut) hold roughly the same number of samples each.
# Equal-width bins (pd.cut) can leave strata with fewer members than there
# are folds when the target is skewed. duplicates='drop' merges bins whose
# edges coincide (e.g. many identical values) instead of raising.
df_wide['total_bin'] = pd.qcut(
    df_wide['Dry_Total_g'],
    q=num_bins,
    labels=False,  # integer bin codes are all that is needed
    duplicates='drop'
)

# 3. Stratified K-fold over the bin codes.
skf = StratifiedKFold(
    n_splits=CFG.N_FOLDS,
    shuffle=True,
    random_state=CFG.RANDOM_STATE
)

# 4. Write the fold number into each row.
#    skf.split yields POSITIONAL (train, valid) indices, so assign with
#    .iloc; label-based .loc would only be right for a default RangeIndex.
fold_col_pos = df_wide.columns.get_loc('fold')
for fold_num, (_train_idx, valid_idx) in enumerate(skf.split(df_wide, df_wide['total_bin'])):
    # Rows that form the validation part of split `fold_num` get that number.
    df_wide.iloc[valid_idx, fold_col_pos] = fold_num

# 5. Inspect the result.
print("\n--- Ph√¢n ph·ªëi s·ªë l∆∞·ª£ng m·∫´u trong m·ªói Fold ---")
print(df_wide['fold'].value_counts().sort_index())

print("\n--- df_wide sau khi th√™m c·ªôt 'fold' (5 h√†ng ƒë·∫ßu) ---")
print(df_wide.head())

In [None]:
from sklearn.metrics import r2_score

def calculate_competition_score(all_preds_3, all_targets_5):
    """
    Compute the weighted R^2 score from the three model outputs.

    The two targets the model does not predict are derived by subtraction
    (dead = total - GDM, clover = GDM - green), floored at zero.

    Args:
        all_preds_3 (dict): Arrays under the keys 'total', 'gdm' and 'green'.
        all_targets_5 (np.array): Ground truth, one column per entry of
            CFG.ALL_TARGET_COLS, in that order.

    Returns:
        float: Sum over targets of (per-target R^2) * (CFG.R2_WEIGHTS entry).
    """
    total = all_preds_3['total']
    gdm = all_preds_3['gdm']
    green = all_preds_3['green']

    # Derived components; negative differences are clipped up to zero.
    clover = np.maximum(0, gdm - green)
    dead = np.maximum(0, total - gdm)

    # Column order must follow CFG.ALL_TARGET_COLS:
    # ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']
    columns = [green, dead, clover, gdm, total]
    y_preds = np.stack(columns, axis=1)
    y_true = all_targets_5

    # One R^2 value per target column.
    per_target_r2 = r2_score(y_true, y_preds, multioutput='raw_values')

    # Weighted sum, accumulated left to right starting from 0.0.
    return sum(
        (per_target_r2[pos] * weight for pos, weight in enumerate(CFG.R2_WEIGHTS)),
        0.0,
    )

In [None]:
import torch.optim as optim
from tqdm import tqdm
import time
import gc

def train_one_epoch(model, loader, criterion, optimizer, device):
    """
    Run one training epoch and return the mean per-batch loss.

    Each batch from ``loader`` is a 4-tuple
    ``(img_left, img_right, train_targets, all_targets)``; the fourth element
    is only needed for scoring and is ignored here.
    """
    model.train()
    running_loss = 0.0

    progress = tqdm(loader, desc="Training", leave=False)
    for batch in progress:
        left, right, loss_targets, _unused_score_targets = batch

        # Move just the tensors the step needs onto the device.
        left = left.to(device)
        right = right.to(device)
        loss_targets = loss_targets.to(device)

        # Standard step: clear grads, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = model(left, right)
        batch_loss = criterion(outputs, loss_targets)
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        running_loss += batch_loss_value
        progress.set_postfix(loss=f'{batch_loss_value:.4f}')  # live readout

    return running_loss / len(loader)

def validate_one_epoch(model, loader, criterion, device):
    """
    Run one evaluation pass and return ``(mean_batch_loss, competition_score)``.

    The loss uses the training targets of each batch; the score is computed
    once at the end from all predictions and all scoring targets.
    """
    model.eval()
    running_loss = 0.0

    head_names = ('total', 'gdm', 'green')
    pred_chunks = {name: [] for name in head_names}
    truth_chunks = []

    with torch.no_grad():
        progress = tqdm(loader, desc="Validating", leave=False)
        for left, right, loss_targets, score_targets in progress:
            left = left.to(device)
            right = right.to(device)
            loss_targets = loss_targets.to(device)
            # score_targets is not moved to the device; it is only collected.

            pred_total, pred_gdm, pred_green = model(left, right)
            head_outputs = (pred_total, pred_gdm, pred_green)

            running_loss += criterion(head_outputs, loss_targets).item()

            # Keep CPU/numpy copies of this batch for the final score.
            for name, output in zip(head_names, head_outputs):
                pred_chunks[name].append(output.cpu().numpy())
            truth_chunks.append(score_targets.cpu().numpy())

    # Stitch the per-batch pieces together; predictions are flattened to 1-D.
    preds_flat = {
        name: np.concatenate(chunks).flatten()
        for name, chunks in pred_chunks.items()
    }
    truth = np.concatenate(truth_chunks)

    score = calculate_competition_score(preds_flat, truth)
    mean_loss = running_loss / len(loader)
    return mean_loss, score

In [None]:
import time
import torch.optim as optim
import gc

def _unwrap_model(model):
    """Return the bare model, stripping an ``nn.DataParallel`` wrapper if present."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _run_phase(model, train_loader, valid_loader, criterion, optimizer,
               scheduler, epochs, phase, fold_to_run, best_score):
    """Train and validate over ``epochs``, checkpointing on every R^2 improvement.

    Args:
        epochs: Iterable of epoch numbers to run (used for logging too).
        phase: Phase number shown in the per-epoch log line.
        best_score: Best R^2 seen before this phase.

    Returns:
        The best R^2 seen so far, including this phase.
    """
    for epoch in epochs:
        print(f"\n--- Epoch {epoch}/{CFG.EPOCHS} (Giai ƒëo·∫°n {phase}) ---")

        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, CFG.DEVICE)
        valid_loss, competition_score = validate_one_epoch(model, valid_loader, criterion, CFG.DEVICE)

        # The scheduler watches the validation loss; checkpointing watches R^2.
        scheduler.step(valid_loss)

        print(f"Epoch {epoch} - Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f} | Score (R^2): {competition_score:.4f}")

        if competition_score > best_score:
            best_score = competition_score
            print(f"‚ú® Score R^2 c·∫£i thi·ªán! ƒêang l∆∞u m√¥ h√¨nh 'best_model_fold{fold_to_run}.pth'...")
            # Always save the unwrapped weights so the checkpoint loads with
            # or without DataParallel.
            torch.save(_unwrap_model(model).state_dict(),
                       f'best_model_fold{fold_to_run}.pth')

    return best_score


def run_training(fold_to_run):
    """
    Train one cross-validation fold with a two-phase (freeze / unfreeze) schedule.

    Phase 1 (epochs 1..CFG.FREEZE_EPOCHS) freezes the backbone and trains only
    the heads at CFG.LEARNING_RATE. Phase 2 (remaining epochs up to
    CFG.EPOCHS) unfreezes the backbone and trains everything at
    CFG.FINETUNE_LR with a fresh optimizer. The best-scoring weights across
    both phases are written to ``best_model_fold{fold_to_run}.pth``.

    Args:
        fold_to_run: Value of ``df_wide['fold']`` to hold out for validation.
    """
    print(f"\n{'='*50}")
    print(f"üöÄ B·∫ÆT ƒê·∫¶U HU·∫§N LUY·ªÜN FOLD {fold_to_run} (Chi·∫øn l∆∞·ª£c 2 giai ƒëo·∫°n) üöÄ")
    print(f"{'='*50}")

    start_time = time.time()

    # 1. Split the data: the chosen fold validates, the rest trains.
    print(f"ƒêang chia d·ªØ li·ªáu cho Fold {fold_to_run}...")
    train_df = df_wide[df_wide['fold'] != fold_to_run].reset_index(drop=True)
    valid_df = df_wide[df_wide['fold'] == fold_to_run].reset_index(drop=True)

    # 2. Datasets (each item is a 4-tuple, see BiomassDataset).
    train_dataset = BiomassDataset(
        df=train_df, transforms_fn=get_train_transforms, image_dir=CFG.IMAGE_DIR,
        train_target_cols=CFG.TARGET_COLS, all_target_cols=CFG.ALL_TARGET_COLS
    )
    valid_dataset = BiomassDataset(
        df=valid_df, transforms_fn=get_valid_transforms, image_dir=CFG.IMAGE_DIR,
        train_target_cols=CFG.TARGET_COLS, all_target_cols=CFG.ALL_TARGET_COLS
    )

    # 3. DataLoaders. Validation uses a doubled batch size (no gradients).
    train_loader = DataLoader(
        train_dataset, batch_size=CFG.BATCH_SIZE, shuffle=True,
        num_workers=CFG.NUM_WORKERS, pin_memory=True
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=CFG.BATCH_SIZE * 2, shuffle=False,
        num_workers=CFG.NUM_WORKERS, pin_memory=True
    )

    # 4. Model and loss; wrap in DataParallel only when several GPUs exist.
    print(f"ƒêang t·∫£i backbone '{CFG.MODEL_NAME}'...")
    model_base = BiomassModel(CFG.MODEL_NAME, CFG.PRETRAINED)

    if torch.cuda.device_count() > 1:
        print(f"S·ª≠ d·ª•ng {torch.cuda.device_count()} GPU v·ªõi nn.DataParallel.")
        model = nn.DataParallel(model_base)
    else:
        model = model_base
    model.to(CFG.DEVICE)

    # `.module` only exists on the DataParallel wrapper, so reach the
    # backbone through the unwrapped model; this also works on a single
    # GPU or on CPU.
    backbone = _unwrap_model(model).backbone

    criterion = WeightedBiomassLoss(CFG.LOSS_WEIGHTS).to(CFG.DEVICE)

    # =================================================================
    # PHASE 1: train the heads only (backbone frozen)
    # =================================================================
    print(f"\n--- GIAI ƒêO·∫†N 1: ƒê√≥ng bƒÉng Backbone (Training Heads) ---")
    print(f"Epochs: 1 ƒë·∫øn {CFG.FREEZE_EPOCHS} | LR: {CFG.LEARNING_RATE}")

    for param in backbone.parameters():
        param.requires_grad = False

    # Hand the optimizer only the parameters that are still trainable.
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=CFG.LEARNING_RATE
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=2
    )

    # Best R^2 is tracked across BOTH phases.
    best_score = -float('inf')
    best_score = _run_phase(
        model, train_loader, valid_loader, criterion, optimizer, scheduler,
        epochs=range(1, CFG.FREEZE_EPOCHS + 1), phase=1,
        fold_to_run=fold_to_run, best_score=best_score
    )

    # =================================================================
    # PHASE 2: fine-tune the whole model (backbone unfrozen)
    # =================================================================
    print(f"\n--- GIAI ƒêO·∫†N 2: R√£ ƒë√¥ng Backbone (Fine-tuning) ---")
    print(f"Epochs: {CFG.FREEZE_EPOCHS + 1} ƒë·∫øn {CFG.EPOCHS} | LR: {CFG.FINETUNE_LR}")

    for param in backbone.parameters():
        param.requires_grad = True

    # A NEW optimizer over ALL parameters, at the lower fine-tuning LR.
    optimizer = optim.Adam(
        model.parameters(),
        lr=CFG.FINETUNE_LR
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.2, patience=3
    )

    best_score = _run_phase(
        model, train_loader, valid_loader, criterion, optimizer, scheduler,
        epochs=range(CFG.FREEZE_EPOCHS + 1, CFG.EPOCHS + 1), phase=2,
        fold_to_run=fold_to_run, best_score=best_score
    )

    # --- END OF TRAINING ---
    end_time = time.time()
    print(f"\nüéâ Ho√†n th√†nh Fold {fold_to_run} sau {(end_time - start_time)/60:.2f} ph√∫t.")
    print(f"ƒêi·ªÉm R^2 t·ªët nh·∫•t: {best_score:.4f}")

    # Cleanup: drop every local that still references the model or its
    # optimizer state, so the memory can be released before the next fold.
    del model, model_base, backbone, optimizer, scheduler, criterion
    del train_loader, valid_loader, train_dataset, valid_dataset
    gc.collect()
    torch.cuda.empty_cache()

# --- B·∫ÆT ƒê·∫¶U CH·∫†Y HU·∫§N LUY·ªÜN ---
# (Gi·ªØ nguy√™n)
# Train every fold in turn. If a fold fails, release cached memory before
# letting the error propagate.
try:
    for i in range(CFG.N_FOLDS):
        run_training(fold_to_run=i)
except Exception:
    gc.collect()
    torch.cuda.empty_cache()
    # Bare `raise` re-raises the active exception with its original traceback.
    raise