# Distillation. Student

In [1]:
machine = "local"
!wandb login

wandb: Currently logged in as: dmykhailov (dmykhailov-kyiv-school-of-economics) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


## Imports

In [16]:
import os
import gc
import logging
import numpy as np
import pandas as pd
from PIL import Image
from typing import cast
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler, OneHotEncoder

import cv2
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from tqdm import tqdm
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import warnings
from notebooks_config import setup_logging, CustomLogger

warnings.simplefilter(action='ignore', category=FutureWarning)
print(f"PyTorch: {torch.__version__}")
print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

PyTorch: 2.9.1+cu128
Device: NVIDIA GeForce RTX 5050 Laptop GPU


In [17]:
logger = setup_logging(level=logging.DEBUG, full_color=True, include_function=False)
logger = cast(CustomLogger, logger)  # Type hinting
logger.success("Logging configuration test completed.")

[38;2;105;254;105m
[2025-12-11 14:59:29]
SUCCESS: Logging configured successfully ✅[0m
[38;2;105;254;105m
[2025-12-11 14:59:29]
SUCCESS: Logging configuration test completed.[0m


In [None]:
cpu_count = os.cpu_count()
NUM_WORKERS = 0 if machine == "local" else cpu_count // 2 if cpu_count else 0

LR = 1e-4
EPOCHS = 25
N_FOLDS = 5
GRAD_ACCUM = 1
BATCH_SIZE = 16
DROPOUT_RATE = 0.3
DISTILL_ALPHA = 0.5  # Weight for distillation loss
WEIGHT_DECAY = 0.05
HIDDEN_RATIO = 0.5
TRAIN_SPLIT_RATIO = 0.02 # Used if N_FOLDS = 0

MODEL = "swinv2_tiny_window8_256"
MODEL_STAGE = "student"  # 'teacher' or 'student'
PROJECT_NAME = "csiro-image2biomass-prediction"
CHECKPOINTS_DIR = f"./kaggle/checkpoints/{MODEL_STAGE}/"
USE_OOF_SOFT_TARGETS = False  # Whether to use OOF soft targets or 100% ensemble soft targets

# Each patch is 1000x1000, resize to 768x768 for vision transformers
SIZE = 768
USE_LOG_TARGET = True   # Whether to use log1p transformation on target variable
FUSION_METHOD = 'mean'  # ('concat', 'mean', 'max')

DESCRIPTION = machine + \
    (f"_train{TRAIN_SPLIT_RATIO}" if N_FOLDS == 0 else f"_train[{N_FOLDS}]Folds") + (
        f"_log" if USE_LOG_TARGET else "") + f"_fusion-{FUSION_METHOD}"
DESCRIPTION_FULL = MODEL + "-" + DESCRIPTION + \
    f"_epochs{EPOCHS}_bs{BATCH_SIZE}_gradacc{GRAD_ACCUM}_lr{LR}_wd{WEIGHT_DECAY}_dr{DROPOUT_RATE}_hr{HIDDEN_RATIO}"
SUBMISSION_NAME = f"{DESCRIPTION_FULL}_submission.csv"
SUBMISSION_ENSEMBLE_NAME = f"{DESCRIPTION_FULL}_ensemble_submission.csv"
SUBMISSION_MSG = DESCRIPTION_FULL.replace("_", " ")

SEED = 1488
torch.manual_seed(SEED)
np.random.seed(SEED)
pl.seed_everything(SEED)

print("DESCRIPTION_FULL:", DESCRIPTION_FULL)
print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")

Seed set to 1488


DESCRIPTION_FULL: swinv2_tiny_window8_256-local_train[5]Folds_log_fusion-mean_epochs20_bs16_gradacc1_lr0.0001_wd0.05_dr0.2_hr0.5
Effective batch size: 16


In [19]:
# setting device on GPU if available, else CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)
print('NUM_WORKERS:', NUM_WORKERS)
print()

# Additional Info when using cuda
if DEVICE.type == 'cuda':
    # clean GPU memory
    torch.cuda.empty_cache()
    gc.collect()

    torch.set_float32_matmul_precision('high') if machine == "local" else None

    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3, 1), 'GB')

Using device: cuda
NUM_WORKERS: 0

NVIDIA GeForce RTX 5050 Laptop GPU
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


## Data Loading and Preprocessing

In [6]:
PATH_DATA = './kaggle/input/csiro-biomass'
PATH_TRAIN_CSV = os.path.join(PATH_DATA, 'train.csv')
PATH_TEST_CSV = os.path.join(PATH_DATA, 'test.csv')
PATH_TRAIN_IMG = os.path.join(PATH_DATA, 'train')
PATH_TEST_IMG = os.path.join(PATH_DATA, 'test')

df = pd.read_csv(PATH_TRAIN_CSV)
# Remove unneeded targets
df = df[~df['target_name'].isin(['Dry_Total_g', 'GDM_g'])]
print(f"Dataset size: {df.shape}")
display(df.head())

Dataset size: (1071, 9)


Unnamed: 0,sample_id,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,target_name,target
0,ID1011485656__Dry_Clover_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Clover_g,0.0
1,ID1011485656__Dry_Dead_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Dead_g,31.9984
2,ID1011485656__Dry_Green_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Green_g,16.275
5,ID1012260530__Dry_Clover_g,train/ID1012260530.jpg,2015/4/1,NSW,Lucerne,0.55,16.0,Dry_Clover_g,0.0
6,ID1012260530__Dry_Dead_g,train/ID1012260530.jpg,2015/4/1,NSW,Lucerne,0.55,16.0,Dry_Dead_g,0.0


In [7]:
# pivot the dataframe to have one row per image with multiple target columns
tabular_df = df.pivot_table(index=['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI'],
                            columns='target_name', values='target', aggfunc='first').reset_index()
tabular_df.columns.name = None  # remove the aggregation name
print(tabular_df.shape)
print(tabular_df.head())

(357, 9)
               image_path Sampling_Date State            Species  \
0  train/ID1011485656.jpg      2015/9/4   Tas    Ryegrass_Clover   
1  train/ID1012260530.jpg      2015/4/1   NSW            Lucerne   
2  train/ID1025234388.jpg      2015/9/1    WA  SubcloverDalkeith   
3  train/ID1028611175.jpg     2015/5/18   Tas           Ryegrass   
4  train/ID1035947949.jpg     2015/9/11   Tas           Ryegrass   

   Height_Ave_cm  Pre_GSHH_NDVI  Dry_Clover_g  Dry_Dead_g  Dry_Green_g  
0         4.6667           0.62        0.0000     31.9984      16.2750  
1        16.0000           0.55        0.0000      0.0000       7.6000  
2         1.0000           0.38        6.0500      0.0000       0.0000  
3         5.0000           0.66        0.0000     30.9703      24.2376  
4         3.5000           0.54        0.4343     23.2239      10.5261  


In [8]:
print(tabular_df.columns.tolist())

['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI', 'Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g']


In [9]:
target_cols  = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g']
num_features = ['Height_Ave_cm', 'Pre_GSHH_NDVI']
cat_features = ['Species', 'State']

In [10]:
# BUG: data leakage, will be fixed later in this code
preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), num_features), # normalizing numeric features
        ('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat_features) # OHE for categorical features
    ]
)

In [11]:
tabular_data = preprocessor.fit_transform(tabular_df)

In [12]:
print(tabular_data.shape)
display(tabular_data[:5])

(357, 21)


array([[-0.28520388, -0.24631873,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ],
       [ 0.81823967, -0.70706013,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
         0.        ],
       [-0.64220462, -1.82600352,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         1.        ],
       [-0.25275281,  0.01696207,  0.        ,  0.        

In [22]:
# Load OOF soft targets from Teacher
path_soft_targets = './kaggle/input/{type}.csv'
path_soft_targets = path_soft_targets.format(type='train_with_oof_soft_targets' if USE_OOF_SOFT_TARGETS else 'train_with_soft_targets')

soft_targets_df = pd.read_csv(path_soft_targets)

logger.success(f"Loaded soft targets: {soft_targets_df.shape}")
logger.info(f"Columns: {soft_targets_df.columns.tolist()}")

[38;2;105;254;105m
[2025-12-11 15:00:09]
SUCCESS: Loaded soft targets: (357, 15)[0m
[38;2;161;247;255m
[2025-12-11 15:00:09]
INFO: Columns: ['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI', 'Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'Season', 'strat_group', 'fold', 'Dry_Clover_g_soft', 'Dry_Dead_g_soft', 'Dry_Green_g_soft'][0m


In [23]:
# Verify we have soft target columns
soft_cols = ['Dry_Clover_g_soft', 'Dry_Dead_g_soft', 'Dry_Green_g_soft']
assert all(col in soft_targets_df.columns for col in soft_cols), "Missing soft target columns!"

logger.info(f"Soft targets preview:")
display(soft_targets_df[soft_cols].head())

[38;2;161;247;255m
[2025-12-11 15:00:15]
INFO: Soft targets preview:[0m


Unnamed: 0,Dry_Clover_g_soft,Dry_Dead_g_soft,Dry_Green_g_soft
0,0.369569,43.185364,20.15381
1,0.051066,0.545608,7.336673
2,6.852712,0.061617,0.0
3,0.561359,32.825825,19.57612
4,1.100665,16.536076,8.233045


## Dataset

In [None]:
class BiomassStudentDataset(Dataset):
    """
    Student dataset for distillation.
    Returns: image + hard targets + soft targets (from Teacher OOF predictions)
    NO tabular features!
    """

    def __init__(
        self,
        df: pd.DataFrame,
        target_cols: list[str],
        soft_target_cols: list[str],
        img_dir: str,
        transform: transforms.Compose | None = None,
        is_test: bool = False,
        use_log_target: bool = True
    ):
        """
        Args:
            df: DataFrame with image_path, hard targets, and soft targets
            target_cols: List of hard target column names
            soft_target_cols: List of soft target column names (from Teacher)
            img_dir: Root directory for images
            transform: torchvision transform pipeline
            is_test: If True, targets are not expected
            use_log_target: If True, apply log1p transform to hard targets
        """
        self.df = df.reset_index(drop=True)
        self.target_cols = target_cols
        self.soft_target_cols = soft_target_cols
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test
        self.use_log_target = use_log_target

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> dict:
        """
        Returns:
            dict with keys:
                - 'left_image': tensor [C, H, W]
                - 'right_image': tensor [C, H, W]
                - 'hard_targets': tensor [3] - real ground truth
                - 'soft_targets': tensor [3] - Teacher predictions
                - 'image_id': str
        """
        row = self.df.iloc[idx]

        # Load image
        img_path = os.path.join(self.img_dir, row['image_path'].replace('train/', '').replace('test/', ''))
        image = cv2.imread(img_path)

        if image is None:
            raise FileNotFoundError(f"Cannot load image: {img_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Split into left and right patches
        h, w, c = image.shape
        mid_w = w // 2

        left_patch = image[:, :mid_w, :]
        right_patch = image[:, mid_w:, :]

        # Convert to PIL
        left_pil = Image.fromarray(left_patch)
        right_pil = Image.fromarray(right_patch)

        # Apply transforms
        if self.transform:
            left_tensor = self.transform(left_pil)
            right_tensor = self.transform(right_pil)
        else:
            left_tensor = transforms.ToTensor()(left_pil)
            right_tensor = transforms.ToTensor()(right_pil)

        output = {
            'left_image': left_tensor,
            'right_image': right_tensor,
            'image_id': row['image_path'].split('/')[-1].replace('.jpg', '')
        }

        # Add targets if not test
        if not self.is_test:
            # Hard targets (ground truth)
            hard_targets = row[self.target_cols].values.astype(np.float32)
            if self.use_log_target:
                hard_targets = np.log1p(hard_targets)
            
            # Soft targets (Teacher predictions - already in original scale!)
            soft_targets = row[self.soft_target_cols].values.astype(np.float32)
            # Apply log if needed to match hard targets space
            if self.use_log_target:
                soft_targets = np.log1p(soft_targets)

            output['hard_targets'] = torch.tensor(hard_targets, dtype=torch.float32)
            output['soft_targets'] = torch.tensor(soft_targets, dtype=torch.float32)

        return output

In [None]:
def calculate_img_data_stat(df: pd.DataFrame):
    """Calculate mean and std of image data for normalization."""
    means = []
    stds = []

    loader = tqdm(df['image_path'], desc="Calculating image stats")

    for img_path in loader:
        full_path = os.path.join(PATH_TRAIN_IMG, img_path.replace('train/', '').replace('test/', ''))
        image = cv2.imread(full_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image / 255.0  # Normalize to [0, 1]

        means.append(np.mean(image, axis=(0, 1)))
        stds.append(np.std(image, axis=(0, 1)))

    mean = np.mean(means, axis=0)
    std = np.mean(stds, axis=0)

    return mean, std

In [15]:
# train_mean, train_std = calculate_img_data_stat(tabular_df)
# print(f"Train Image Mean: {train_mean}, Std: {train_std}")

In [16]:
# Image backbone (processes each patch independently)
temp_backbone = timm.create_model(
    MODEL,
    pretrained=True,
    num_classes=0,  # remove classification head
    global_pool='avg'
)

temp_backbone.default_cfg

{'url': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
 'hf_hub_id': 'timm/swinv2_tiny_window8_256.ms_in1k',
 'architecture': 'swinv2_tiny_window8_256',
 'tag': 'ms_in1k',
 'custom_load': False,
 'input_size': (3, 256, 256),
 'fixed_input_size': True,
 'interpolation': 'bicubic',
 'crop_pct': 0.9,
 'crop_mode': 'center',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'num_classes': 1000,
 'pool_size': (8, 8),
 'first_conv': 'patch_embed.proj',
 'classifier': 'head.fc',
 'license': 'mit'}

In [17]:
inputs_size = temp_backbone.default_cfg['input_size']
mean = temp_backbone.default_cfg['mean']
std = temp_backbone.default_cfg['std']

SIZE = int(inputs_size[1]) if inputs_size is not None and inputs_size[1] == inputs_size[2] else 256
print(f"Backbone expected input size: {inputs_size}, using SIZE={SIZE}")
print(f"Backbone expected mean: {mean}, std: {std}")

# Get backbone output dimension
with torch.no_grad():
    try:
        dummy = torch.randn(1, 3, SIZE, SIZE)
        feat_dim = temp_backbone(dummy).shape[1]
    except Exception as e:
        logger.error(f"Error getting backbone feature dimension: {e}")
        raise e

Backbone expected input size: (3, 256, 256), using SIZE=256
Backbone expected mean: (0.485, 0.456, 0.406), std: (0.229, 0.224, 0.225)


In [None]:
# Student needs STRONGER augmentations (he has no metadata!)
student_train_transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    
    # Geometric augmentations (stronger)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=90),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    
    # Color augmentations (stronger)
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.15)
    ], p=0.6),
    
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
    
    # Stronger occlusion simulation
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3))
])

# Validation transform (same as before)
student_val_transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

logger.info("Student augmentations are MORE AGGRESSIVE than Teacher!")

In [None]:
# Create dataset instance
train_dataset = BiomassStudentDataset(
    df=tabular_df,
    tabular_features=tabular_data,
    target_cols=target_cols,
    img_dir=PATH_TRAIN_IMG,
    transform=train_transform,
    is_test=False
)

# Test it
sample = train_dataset[0]
print(f"Left image shape: {sample['left_image'].shape}")
print(f"Right image shape: {sample['right_image'].shape}")
print(f"Tabular shape: {sample['tabular'].shape}")
print(f"Targets shape: {sample['targets'].shape}")
print(f"Image ID: {sample['image_id']}")
print(f"Target values: {sample['targets']}")
print()

# Test dataset with log transform
print(f"Original targets from df: {tabular_df.iloc[0][target_cols].values}")
print(f"Log-transformed targets: {sample['targets']}")
print(f"Should be close to: {np.log1p(tabular_df.iloc[0][target_cols].values.astype(np.float32))}")

Left image shape: torch.Size([3, 256, 256])
Right image shape: torch.Size([3, 256, 256])
Tabular shape: torch.Size([21])
Targets shape: torch.Size([3])
Image ID: ID1011485656
Target values: tensor([0.0000, 3.4965, 2.8493])

Original targets from df: [np.float64(0.0) np.float64(31.9984) np.float64(16.275)]
Log-transformed targets: tensor([0.0000, 3.4965, 2.8493])
Should be close to: [0.        3.496459  2.8492603]


## Spliting Data (StratifiedKFold)

In [29]:
def get_season(date_str: str) -> str:
    """
    Convert date string to season.

    Args:
        date_str: Date in format 'YYYY/M/D' or 'YYYY/MM/DD'

    Returns:
        Season name: 'Summer', 'Autumn', 'Winter', 'Spring'
    """
    # Parse month from date string
    month = int(date_str.split('/')[1])

    # Australian seasons (Southern Hemisphere)
    if month in [12, 1, 2]:
        return 'Summer'
    elif month in [3, 4, 5]:
        return 'Autumn'
    elif month in [6, 7, 8]:
        return 'Winter'
    else:  # 9, 10, 11
        return 'Spring'

In [30]:
# Add season column
tabular_df['Season'] = tabular_df['Sampling_Date'].apply(get_season)

# Create stratification column combining Season, State, and Species
tabular_df['strat_group'] = (
    tabular_df['Season'].astype(str) + '_' +
    tabular_df['State'].astype(str) + '_' +
    tabular_df['Species'].astype(str)
)

print("Unique stratification groups:")
print(tabular_df['strat_group'].value_counts())
print(f"\nTotal groups: {tabular_df['strat_group'].nunique()}")

Unique stratification groups:
strat_group
Spring_Tas_Ryegrass_Clover                                                41
Winter_Vic_Phalaris_Clover                                                32
Autumn_Tas_Ryegrass                                                       22
Winter_Vic_Ryegrass_Clover                                                21
Spring_Tas_Clover                                                         21
Winter_WA_Clover                                                          20
Winter_Tas_Ryegrass_Clover                                                18
Autumn_NSW_Lucerne                                                        15
Spring_Vic_Ryegrass_Clover                                                11
Spring_Vic_Phalaris_BarleyGrass_SilverGrass_SpearGrass_Clover_Capeweed    11
Spring_NSW_Fescue                                                         11
Summer_NSW_Fescue_CrumbWeed                                               10
Spring_Vic_Phalaris_Clover        

In [31]:
# Initialize StratifiedKFold
n_folds = 5
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

# Get stratification labels
strat_labels = tabular_df['strat_group'].values

# Create fold assignments
tabular_df['fold'] = -1

for fold_idx, (train_idx, val_idx) in enumerate(skf.split(tabular_df, strat_labels)):
    tabular_df.loc[val_idx, 'fold'] = fold_idx

    print(f"\nFold {fold_idx + 1}:")
    print(f"  Train samples: {len(train_idx)}")
    print(f"  Val samples: {len(val_idx)}")


Fold 1:
  Train samples: 285
  Val samples: 72

Fold 2:
  Train samples: 285
  Val samples: 72

Fold 3:
  Train samples: 286
  Val samples: 71

Fold 4:
  Train samples: 286
  Val samples: 71

Fold 5:
  Train samples: 286
  Val samples: 71




In [32]:
# Verify stratification worked
print("Stratification verification:")

for fold in range(n_folds):
    fold_df = tabular_df[tabular_df['fold'] == fold]
    print(f"\nFold {fold + 1}:")
    print(f"  Season distribution:")
    print(fold_df['Season'].value_counts(normalize=True).round(3))
    print(f"  State distribution:")
    print(fold_df['State'].value_counts(normalize=True).round(3))
    print(f"  Species distribution:")
    print(fold_df['Species'].value_counts(normalize=True).round(3))

Stratification verification:

Fold 1:
  Season distribution:
Season
Winter    0.361
Spring    0.347
Autumn    0.167
Summer    0.125
Name: proportion, dtype: float64
  State distribution:
State
Tas    0.389
Vic    0.319
NSW    0.222
WA     0.069
Name: proportion, dtype: float64
  Species distribution:
Species
Ryegrass_Clover                                                0.292
Ryegrass                                                       0.167
Phalaris_Clover                                                0.111
Clover                                                         0.111
Fescue                                                         0.083
Lucerne                                                        0.056
Fescue_CrumbWeed                                               0.028
Phalaris                                                       0.028
WhiteClover                                                    0.028
Phalaris_Clover_Ryegrass_Barleygrass_Bromegrass                0.028


In [33]:
def get_fold_data(df: pd.DataFrame, fold: int):
    """
    Get train/val split for specific fold.

    Args:
        df: DataFrame with 'fold' column
        fold: Fold index to use as validation

    Returns:
        train_df, val_df
    """
    train_df = df[df['fold'] != fold].reset_index(drop=True)
    val_df = df[df['fold'] == fold].reset_index(drop=True)

    return train_df, val_df

In [None]:
def get_loaders(fold: int, bs: int, soft_df: pd.DataFrame) -> tuple[DataLoader, DataLoader]:
    """
    Get dataloaders for Student model (no tabular features needed!)
    
    Args:
        fold: Fold index to use as validation
        bs: Batch size
        soft_df: DataFrame with soft targets from Teacher
    """
    train_df = soft_df[soft_df['fold'] != fold].reset_index(drop=True)
    val_df = soft_df[soft_df['fold'] == fold].reset_index(drop=True)

    logger.info(f"Student training fold {fold}:")
    logger.info(f"  Train size: {len(train_df)}")
    logger.info(f"  Val size: {len(val_df)}")

    # Soft target columns
    soft_target_cols = ['Dry_Clover_g_soft', 'Dry_Dead_g_soft', 'Dry_Green_g_soft']

    # Create datasets
    train_dataset = BiomassStudentDataset(
        df=train_df,
        target_cols=target_cols,
        soft_target_cols=soft_target_cols,
        img_dir=PATH_TRAIN_IMG,
        transform=student_train_transform,
        is_test=False,
        use_log_target=USE_LOG_TARGET
    )

    val_dataset = BiomassStudentDataset(
        df=val_df,
        target_cols=target_cols,
        soft_target_cols=soft_target_cols,
        img_dir=PATH_TRAIN_IMG,
        transform=student_val_transform,
        is_test=False,
        use_log_target=USE_LOG_TARGET
    )

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=bs,
        shuffle=True,
        num_workers=min(NUM_WORKERS, 8),
        pin_memory=True if torch.cuda.is_available() else False,
        persistent_workers=True if NUM_WORKERS > 0 else False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=bs * 2,
        shuffle=False,
        num_workers=min(NUM_WORKERS, 8),
        pin_memory=True if torch.cuda.is_available() else False,
        persistent_workers=True if NUM_WORKERS > 0 else False
    )

    logger.info(f"Train batches: {len(train_loader)}")
    logger.info(f"Val batches: {len(val_loader)}")

    return train_loader, val_loader

In [None]:
# Test student dataset
logger.info("Testing Student Dataset...")

test_train_loader, test_val_loader = get_student_loaders(fold=0, bs=BATCH_SIZE, soft_df=soft_targets_df)

sample_batch = next(iter(test_train_loader))

print(f"Left image shape: {sample_batch['left_image'].shape}")
print(f"Right image shape: {sample_batch['right_image'].shape}")
print(f"Hard targets shape: {sample_batch['hard_targets'].shape}")
print(f"Soft targets shape: {sample_batch['soft_targets'].shape}")

logger.success("Student dataset works perfectly!")

## Ligtning Module

### Custom Loss Function

In [35]:
labels = [
    "Dry_Clover_g",
    "Dry_Dead_g",
    "Dry_Green_g",
    "Dry_Total_g",
    "GDM_g"
]

weights = {
    'Dry_Green_g': 0.1,
    'Dry_Dead_g': 0.1,
    'Dry_Clover_g': 0.1,
    'GDM_g': 0.2,
    'Dry_Total_g': 0.5,
}


def competition_metric(y_true, y_pred) -> float:
    """Function to calculate the competition's official evaluation metric (weighted R2 score)."""
    weights_array = np.array([weights[l] for l in labels])

    # Align with this calculation method
    y_weighted_mean = np.average(y_true, weights=weights_array, axis=1).mean()

    # For ss_res and ss_tot, also take the weighted average on axis=1, then the mean of the result
    ss_res = np.average((y_true - y_pred)**2,
                        weights=weights_array, axis=1).mean()
    ss_tot = np.average((y_true - y_weighted_mean)**2,
                        weights=weights_array, axis=1).mean()

    return 1 - ss_res / ss_tot

In [24]:
class StudentDistillationLoss(nn.Module):
    """
    Custom loss for Student model combining:
    1. Distillation loss (learn from Teacher)
    2. Hard loss (learn from real targets with competition weights)
    """
    
    def __init__(self, alpha: float = 0.5, use_log_space: bool = True):
        """
        Args:
            alpha: Weight for distillation loss (0.5 = equal weight to Teacher and ground truth)
            use_log_space: If True, compute loss in log space
        """
        super().__init__()
        self.alpha = alpha
        self.use_log_space = use_log_space
        
        # Competition weights
        self.w_green = 0.1
        self.w_clover = 0.1
        self.w_dead = 0.1
        self.w_gdm = 0.2
        self.w_total = 0.5
        
    def forward(
        self, 
        student_preds: torch.Tensor,  # [B, 3] predictions in log space
        hard_targets: torch.Tensor,   # [B, 3] ground truth in log space
        soft_targets: torch.Tensor    # [B, 3] Teacher predictions in log space
    ) -> tuple[torch.Tensor, dict]:
        """
        Returns:
            total_loss: Combined loss
            loss_dict: Dictionary with individual loss components
        """
        
        # 1. Distillation Loss (MSE with Teacher's soft targets)
        loss_distill = F.mse_loss(student_preds, soft_targets)
        
        # 2. Hard Loss with competition weights
        # Individual components
        loss_green = F.mse_loss(student_preds[:, 2], hard_targets[:, 2])   # Dry_Green_g
        loss_clover = F.mse_loss(student_preds[:, 0], hard_targets[:, 0])  # Dry_Clover_g
        loss_dead = F.mse_loss(student_preds[:, 1], hard_targets[:, 1])    # Dry_Dead_g
        
        # Derived targets (computed from components)
        # Dry_Total_g = sum of all 3 components
        student_total = student_preds.sum(dim=1)
        hard_total = hard_targets.sum(dim=1)
        loss_total = F.mse_loss(student_total, hard_total)
        
        # GDM_g = Clover + Green
        student_gdm = student_preds[:, 0] + student_preds[:, 2]  # Clover + Green
        hard_gdm = hard_targets[:, 0] + hard_targets[:, 2]
        loss_gdm = F.mse_loss(student_gdm, hard_gdm)
        
        # Weighted hard loss (following competition metric weights)
        loss_hard = (
            self.w_green * loss_green +
            self.w_clover * loss_clover +
            self.w_dead * loss_dead +
            self.w_gdm * loss_gdm +
            self.w_total * loss_total
        )
        
        # 3. Total loss (weighted combination)
        total_loss = self.alpha * loss_distill + (1 - self.alpha) * loss_hard
        
        # Return loss dict for logging
        loss_dict = {
            'loss_distill': loss_distill.item(),
            'loss_hard': loss_hard.item(),
            'loss_green': loss_green.item(),
            'loss_clover': loss_clover.item(),
            'loss_dead': loss_dead.item(),
            'loss_total': loss_total.item(),
            'loss_gdm': loss_gdm.item(),
        }
        
        return total_loss, loss_dict

In [None]:
class BiomassStudentModel(pl.LightningModule):
    """
    Student model for biomass prediction.
    Uses ONLY images (dual-patch), NO tabular features.
    Learns from Teacher's soft targets + ground truth.
    """

    def __init__(
        self,
        backbone_name: str = 'swinv2_tiny_window8_256',
        num_targets: int = 3,
        lr: float = 1e-4,
        weight_decay: float = 1e-5,
        hidden_ratio: float = 0.5,
        dropout: float = 0.3,  # Higher dropout for student
        fusion_method: str = 'mean',
        distill_alpha: float = 0.5,  # Weight for distillation loss
        use_log_target: bool = True
    ):
        super().__init__()
        self.save_hyperparameters()

        # Image backbone (same as Teacher)
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        self.lr = lr
        self.weight_decay = weight_decay
        self.fusion_method = fusion_method
        self.use_log_target = use_log_target

        # Get backbone output dimension
        with torch.no_grad():
            dummy = torch.randn(1, 3, SIZE, SIZE)
            feat_dim = self.backbone(dummy).shape[1]

        self.feat_dim = feat_dim

        # NO tabular features - only image features!
        if self.fusion_method == 'concat':
            self.combined_dim = feat_dim * 2
        else:  # mean or max
            self.combined_dim = feat_dim

        # Regression heads (simpler than Teacher)
        hidden_size = max(32, int(self.combined_dim * hidden_ratio))

        def make_head():
            return nn.Sequential(
                nn.Linear(self.combined_dim, hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, 1)
            )

        self.head_green = make_head()
        self.head_clover = make_head()
        self.head_dead = make_head()

        # Custom distillation loss
        self.criterion = StudentDistillationLoss(
            alpha=distill_alpha,
            use_log_space=use_log_target
        )

        # Storage for validation
        self.validation_step_outputs = []

        logger.info(f"Student model initialized: backbone={backbone_name}, feat_dim={feat_dim}, "
                    f"combined_dim={self.combined_dim}, fusion={fusion_method}, "
                    f"distill_alpha={distill_alpha}")

    def forward(self, batch: dict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            batch: dict with 'left_image', 'right_image'

        Returns:
            (green, clover, dead) predictions
        """
        # Extract features from each patch
        left_feat = self.backbone(batch['left_image'])
        right_feat = self.backbone(batch['right_image'])

        # Fuse image features
        if self.fusion_method == 'concat':
            img_feat = torch.cat([left_feat, right_feat], dim=1)
        elif self.fusion_method == 'mean':
            img_feat = (left_feat + right_feat) / 2
        elif self.fusion_method == 'max':
            img_feat = torch.maximum(left_feat, right_feat)
        else:
            raise ValueError(f"Unknown fusion method: {self.fusion_method}")

        # Predict each target
        green = self.head_green(img_feat).squeeze(1)
        clover = self.head_clover(img_feat).squeeze(1)
        dead = self.head_dead(img_feat).squeeze(1)

        return green, clover, dead

    def compute_all_targets(self, green: torch.Tensor, clover: torch.Tensor, dead: torch.Tensor) -> torch.Tensor:
        """Compute all 5 targets from 3 predicted ones"""
        green = torch.clamp(green, min=0.0)
        clover = torch.clamp(clover, min=0.0)
        dead = torch.clamp(dead, min=0.0)

        total = green + dead + clover
        gdm = clover + green

        all_targets = torch.stack([clover, dead, green, total, gdm], dim=1)
        return all_targets

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        green, clover, dead = self(batch)
        
        # Stack predictions [B, 3] in order: [clover, dead, green]
        preds = torch.stack([clover, dead, green], dim=1)
        
        # Compute distillation loss
        loss, loss_dict = self.criterion(
            preds,
            batch['hard_targets'],
            batch['soft_targets']
        )

        # Log all loss components
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True,
                 batch_size=batch['hard_targets'].size(0))
        self.log('train_loss_distill', loss_dict['loss_distill'], on_step=False, on_epoch=True)
        self.log('train_loss_hard', loss_dict['loss_hard'], on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        green_pred, clover_pred, dead_pred = self(batch)
        
        preds = torch.stack([clover_pred, dead_pred, green_pred], dim=1)
        
        # Compute loss
        loss, loss_dict = self.criterion(
            preds,
            batch['hard_targets'],
            batch['soft_targets']
        )

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True,
                 batch_size=batch['hard_targets'].size(0))
        self.log('val_loss_distill', loss_dict['loss_distill'], on_step=False, on_epoch=True)
        self.log('val_loss_hard', loss_dict['loss_hard'], on_step=False, on_epoch=True)

        # Convert to original scale for metric
        if self.use_log_target:
            green_pred = torch.expm1(green_pred)
            clover_pred = torch.expm1(clover_pred)
            dead_pred = torch.expm1(dead_pred)
            
            hard_targets_original = torch.expm1(batch['hard_targets'])
        else:
            hard_targets_original = batch['hard_targets']

        # Compute all 5 targets
        preds_all = self.compute_all_targets(green_pred, clover_pred, dead_pred)
        
        clover_true = hard_targets_original[:, 0]
        dead_true = hard_targets_original[:, 1]
        green_true = hard_targets_original[:, 2]
        targets_all = self.compute_all_targets(green_true, clover_true, dead_true)

        self.validation_step_outputs.append({
            'preds': preds_all.detach().cpu(),
            'targets': targets_all.detach().cpu()
        })

        return loss

    def on_validation_epoch_end(self):
        if len(self.validation_step_outputs) == 0:
            return

        all_preds = torch.cat([x['preds'] for x in self.validation_step_outputs], dim=0).numpy()
        all_targets = torch.cat([x['targets'] for x in self.validation_step_outputs], dim=0).numpy()

        comp_metric = competition_metric(all_targets, all_preds)
        self.log('val_comp_metric', comp_metric, on_epoch=True, prog_bar=True)

        self.validation_step_outputs.clear()

    def predict_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        green, clover, dead = self(batch)
        preds = torch.stack([clover, dead, green], dim=1)

        if self.use_log_target:
            preds = torch.expm1(preds)

        preds = torch.clamp(preds, min=0.0)
        return preds

    def configure_optimizers(self):
        optimizer = AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay
        )

        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs or 20,
            eta_min=self.lr * 0.01
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch'
            }
        }

## Folds Training

In [None]:
student_fold_results = []

for fold_id in range(N_FOLDS):
    train_loader, val_loader = get_student_loaders(fold=fold_id, bs=BATCH_SIZE, soft_df=soft_targets_df)

    model = BiomassStudentModel(
        backbone_name=MODEL,
        num_targets=len(target_cols),
        lr=STUDENT_LR,
        weight_decay=WEIGHT_DECAY,
        hidden_ratio=HIDDEN_RATIO,
        dropout=STUDENT_DROPOUT,
        fusion_method=FUSION_METHOD,
        distill_alpha=DISTILL_ALPHA,
        use_log_target=USE_LOG_TARGET
    )

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=os.path.join(CHECKPOINTS_DIR, f'fold{fold_id}'),
        filename=f'{DESCRIPTION_FULL}-student-fold{fold_id}' + '-{epoch:02d}-{val_loss:.4f}',
        save_top_k=3,
        mode='min'
    )

    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=7,  # More patience for student
        mode='min',
        verbose=True
    )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Logger
    wandb_logger = WandbLogger(
        project=PROJECT_NAME,
        name=f'{DESCRIPTION_FULL}-student-fold{fold_id}',
        log_model='all'
    )

    # Trainer
    trainer = pl.Trainer(
        max_epochs=STUDENT_EPOCHS,
        accelerator=DEVICE.type,
        precision='16-mixed' if torch.cuda.is_available() else 32,
        accumulate_grad_batches=GRAD_ACCUM,
        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
        logger=wandb_logger,
        log_every_n_steps=1,
        gradient_clip_val=1.0,
        enable_progress_bar=True
    )

    try:
        # Train
        logger.info(f"\nTraining Student on Fold {fold_id}...")
        trainer.fit(model, train_loader, val_loader)

        # Load best checkpoint
        best_model_path = checkpoint_callback.best_model_path
        logger.info(f"Loading best Student model from: {best_model_path}")
        best_model = BiomassStudentModel.load_from_checkpoint(best_model_path)

        # Evaluate
        val_result = trainer.validate(best_model, val_loader, verbose=False)
        student_fold_results.append({
            'fold': fold_id,
            'val_loss': val_result[0]['val_loss'],
            'val_comp_metric': val_result[0].get('val_comp_metric', 0.0)
        })
        
        logger.success(f"Fold {fold_id} completed: val_loss={val_result[0]['val_loss']:.4f}")

    except SystemExit:
        logger.warning(f"Training interrupted during fold {fold_id}. Exiting gracefully.")
        wandb_logger.experiment.finish()
        break

    finally:
        wandb_logger.experiment.finish()

logger.success("STUDENT TRAINING COMPLETED")

In [None]:
print("Training Summary")
print()
results_df = pd.DataFrame(student_fold_results)
print(results_df)
print(f"Mean Val Loss: {results_df['val_loss'].mean():.4f} ± {results_df['val_loss'].std():.4f}")
print(f"Mean Comp Metric: {results_df['val_comp_metric'].mean():.4f} ± {results_df['val_comp_metric'].std():.4f}")

Training Summary

   fold  val_loss
0     0  0.893824
1     1  0.811320
2     2  0.936835
3     3  1.147606
4     4  1.055121
Mean Val Loss: 0.9689 ± 0.1331


## Prepare Data for Student

In [45]:
def load_teacher_models(checkpoints_dir: str, n_folds: int) -> list:
    """
    Load best checkpoints from all folds.
    
    Args:
        checkpoints_dir: Directory with fold checkpoints
        n_folds: Number of folds
    
    Returns:
        List of loaded models
    """
    models = []
    
    for fold_id in range(n_folds):
        fold_dir = os.path.join(checkpoints_dir, f'fold{fold_id}')
        
        # Find best checkpoint file
        ckpt_files = [f for f in os.listdir(fold_dir) if f.endswith('.ckpt')]
        if not ckpt_files:
            logger.warning(f"No checkpoint found in {fold_dir}")
            continue

        # sort by val_loss in filename (val_loss=xxx.ckpt)
        ckpt_files.sort(key=lambda x: float(x.split('=')[-1].replace('.ckpt', '')))
        
        # Assuming the first one is the best
        ckpt_path = os.path.join(fold_dir, ckpt_files[0])
        logger.info(f"Loading checkpoint: {ckpt_path}")
        
        model = BiomassTeacherModel.load_from_checkpoint(ckpt_path)
        model.eval()
        model = model.to(DEVICE)
        models.append(model)
    
    logger.success(f"Loaded {len(models)} teacher models")
    return models

In [46]:
models = load_teacher_models(CHECKPOINTS_DIR, N_FOLDS)

[38;2;161;247;255m
[2025-12-11 14:17:08]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold0\swinv2_tiny_window8_256-local_train[5]Folds_log_fusion-mean_epochs20_bs16_gradacc1_lr0.0001_wd0.05_dr0.2_hr0.5-fold0-epoch=13-val_loss=0.8938.ckpt[0m
[38;2;161;247;255m
[2025-12-11 14:17:09]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m
[38;2;161;247;255m
[2025-12-11 14:17:09]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold1\swinv2_tiny_window8_256-local_train[5]Folds_log_fusion-mean_epochs20_bs16_gradacc1_lr0.0001_wd0.05_dr0.2_hr0.5-fold1-epoch=11-val_loss=0.8113.ckpt[0m
[38;2;161;247;255m
[2025-12-11 14:17:10]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m
[38;2;161;247;255m
[2025-12-11 14:17:10]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold2\swinv2_tiny_window8_256-local_train[5]Folds_log_f

### Direct Ensemble Soft Labels

In [None]:
def create_soft_targets_dataset():
    """
    Create dataset with soft targets from teacher ensemble.
    Predicts on ALL training data (100%).
    
    Returns:
        DataFrame with soft targets appended
    """
    
    # Load all teacher models
    teacher_models = load_teacher_models(CHECKPOINTS_DIR, N_FOLDS)
    
    if not teacher_models:
        logger.error("No teacher models loaded!")
        return None
    
    logger.info(f"Creating soft targets using {len(teacher_models)} teacher models...")
    
    # Prepare full dataset with NO augmentation
    full_dataset = BiomassDataset(
        df=tabular_df,
        tabular_features=tabular_data,
        target_cols=target_cols,
        img_dir=PATH_TRAIN_IMG,
        transform=val_transform,  # Use validation transform (no augmentation)
        is_test=False,
        use_log_target=USE_LOG_TARGET
    )
    
    # Create dataloader for inference
    full_loader = DataLoader(
        full_dataset,
        batch_size=BATCH_SIZE * 2,  # Larger batch for inference
        shuffle=False,
        num_workers=min(NUM_WORKERS, 8),
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    logger.info(f"Total samples for soft targets: {len(full_dataset)}")
    logger.info(f"Inference batches: {len(full_loader)}")
    
    # Store predictions from all models
    all_predictions = []
    
    # For each teacher model
    for model_idx, teacher_model in enumerate(teacher_models):
        logger.info(f"Processing teacher model {model_idx + 1}/{len(teacher_models)}...")
        
        model_predictions = []
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(full_loader):
                # Move batch to device
                batch = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v 
                        for k, v in batch.items()}
                
                # Get predictions
                preds = teacher_model.predict_step(batch, 0)  # [B, 3]
                model_predictions.append(preds.cpu().numpy())
                
                if (batch_idx + 1) % 10 == 0:
                    logger.debug(f"  Batch {batch_idx + 1}/{len(full_loader)}")
        
        # Concatenate all batches
        model_preds_array = np.concatenate(model_predictions, axis=0)  # [N, 3]
        all_predictions.append(model_preds_array)
        
        logger.success(f"Model {model_idx + 1} predictions shape: {model_preds_array.shape}")
    
    # Average predictions across all models
    ensemble_predictions = np.mean(all_predictions, axis=0)  # [N, 3]
    
    logger.success(f"Ensemble predictions shape: {ensemble_predictions.shape}")
    logger.info(f"Ensemble predictions range: min={ensemble_predictions.min():.4f}, max={ensemble_predictions.max():.4f}")
    
    # Create DataFrame with soft targets
    soft_targets_df = tabular_df.copy()
    
    # Add soft target columns
    soft_targets_df['Dry_Clover_g_soft'] = ensemble_predictions[:, 0]
    soft_targets_df['Dry_Dead_g_soft'] = ensemble_predictions[:, 1]
    soft_targets_df['Dry_Green_g_soft'] = ensemble_predictions[:, 2]
    
    # Keep original targets for reference
    # (optional: remove them later for student training)
    
    logger.info("\nSoft targets statistics:")
    logger.info(f"Dry_Clover_g: mean={soft_targets_df['Dry_Clover_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Clover_g_soft'].std():.4f}")
    logger.info(f"Dry_Dead_g: mean={soft_targets_df['Dry_Dead_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Dead_g_soft'].std():.4f}")
    logger.info(f"Dry_Green_g: mean={soft_targets_df['Dry_Green_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Green_g_soft'].std():.4f}")
    
    return soft_targets_df, ensemble_predictions

In [48]:
def save_soft_targets(soft_targets_df: pd.DataFrame, output_path: str = './kaggle/input/'):
    """
    Save soft targets to CSV.
    
    Args:
        soft_targets_df: DataFrame with soft targets
        output_path: Path to save CSV
    """
    os.makedirs(output_path, exist_ok=True)
    
    output_file = os.path.join(output_path, 'train_with_soft_targets.csv')
    soft_targets_df.to_csv(output_file, index=False)
    
    logger.success(f"Saved soft targets to: {output_file}")
    logger.info(f"Shape: {soft_targets_df.shape}")
    logger.info(f"Columns: {soft_targets_df.columns.tolist()}")
    
    return output_file

In [51]:
def validate_soft_targets(soft_targets_df: pd.DataFrame):
    """
    Validate soft targets quality.
    
    Args:
        soft_targets_df: DataFrame with soft targets
    """
    logger.info("\n=== Soft Targets Validation ===")
    
    for col in ['Dry_Clover_g_soft', 'Dry_Dead_g_soft', 'Dry_Green_g_soft']:
        original_col = col.replace('_soft', '')
        
        # Check for NaN
        nan_count = soft_targets_df[col].isna().sum()
        if nan_count > 0:
            logger.warning(f"{col}: {nan_count} NaN values")
        
        # Compare soft vs hard targets
        corr = soft_targets_df[original_col].corr(soft_targets_df[col])
        mse = np.mean((soft_targets_df[original_col].values - soft_targets_df[col].values)**2)

        # logger.info(f"{col}:")
        # logger.info(f"  Correlation with hard target: {corr:.4f}")
        # logger.info(f"  MSE vs hard target: {mse:.4f}")
        # logger.info(f"  Range: [{soft_targets_df[col].min():.4f}, {soft_targets_df[col].max():.4f}]")
        msg = [f"{col}:", f"Correlation with hard target: {corr:.4f}", f"MSE vs hard target: {mse:.4f}", f"Range: [{soft_targets_df[col].min():.4f}, {soft_targets_df[col].max():.4f}]"]
        msg = "\n".join(msg)
        logger.info(msg)

In [None]:
# Create soft targets
soft_targets_df, ensemble_preds = create_soft_targets_dataset()

[38;2;161;247;255m
[2025-12-11 14:20:17]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold0\swinv2_tiny_window8_256-local_train[5]Folds_log_fusion-mean_epochs20_bs16_gradacc1_lr0.0001_wd0.05_dr0.2_hr0.5-fold0-epoch=13-val_loss=0.8938.ckpt[0m
[38;2;161;247;255m
[2025-12-11 14:20:18]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m
[38;2;161;247;255m
[2025-12-11 14:20:18]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold1\swinv2_tiny_window8_256-local_train[5]Folds_log_fusion-mean_epochs20_bs16_gradacc1_lr0.0001_wd0.05_dr0.2_hr0.5-fold1-epoch=11-val_loss=0.8113.ckpt[0m
[38;2;161;247;255m
[2025-12-11 14:20:19]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m
[38;2;161;247;255m
[2025-12-11 14:20:19]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold2\swinv2_tiny_window8_256-local_train[5]Folds_log_f

In [52]:
# Validate soft targets
validate_soft_targets(soft_targets_df)

# Save to CSV
output_file = save_soft_targets(soft_targets_df)

print(f"File: {output_file}")
print(f"Samples: {len(soft_targets_df)}")
print(f"Soft target columns: {[c for c in soft_targets_df.columns if '_soft' in c]}")

[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: 
=== Soft Targets Validation ===[0m
[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: Dry_Clover_g_soft:
Correlation with hard target: 0.9426
MSE vs hard target: 26.7841
Range: [0.0000, 97.1455][0m
[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: Dry_Dead_g_soft:
Correlation with hard target: 0.8593
MSE vs hard target: 40.1800
Range: [0.0000, 56.2889][0m
[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: Dry_Green_g_soft:
Correlation with hard target: 0.9462
MSE vs hard target: 69.3651
Range: [0.0000, 142.3275][0m
[38;2;105;254;105m
[2025-12-11 14:33:55]
SUCCESS: Saved soft targets to: ./kaggle/input/train_with_soft_targets.csv[0m
[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: Shape: (357, 15)[0m
[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: Columns: ['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI', 'Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'Season', 'strat_group', 'fold', 'Dry_Clover_g_soft', 'D

### Out of Fold Predictions

In [63]:
def create_oof_soft_targets():
    """
    Create Out-Of-Fold soft targets from teacher models.
    Each model predicts ONLY on its validation fold (unseen data).
    
    Returns:
        DataFrame with OOF soft targets
    """
    logger.info("CREATING OUT-OF-FOLD (OOF) SOFT TARGETS")
    
    # Initialize array to store OOF predictions
    # Shape: [num_samples, 3] for 3 targets
    oof_predictions = np.zeros((len(tabular_df), 3))
    oof_indices = np.zeros(len(tabular_df), dtype=bool)  # Track which samples got predictions
    
    # For each fold
    for fold_id in range(N_FOLDS):
        logger.info(f"\nProcessing Fold {fold_id}...")
        
        # Load model for this fold
        fold_dir = os.path.join(CHECKPOINTS_DIR, f'fold{fold_id}')
        ckpt_files = [f for f in os.listdir(fold_dir) if f.endswith('.ckpt')]
        
        if not ckpt_files:
            logger.error(f"No checkpoint found for fold {fold_id}")
            continue
        
        # Sort by val_loss to get best checkpoint
        ckpt_files.sort(key=lambda x: float(x.split('=')[-1].replace('.ckpt', '')))
        ckpt_path = os.path.join(fold_dir, ckpt_files[0])
        
        logger.info(f"Loading checkpoint: {ckpt_path}")
        teacher_model = BiomassTeacherModel.load_from_checkpoint(ckpt_path)
        teacher_model.eval()
        teacher_model = teacher_model.to(DEVICE)
        
        # Get validation data for this fold
        val_df = tabular_df[tabular_df['fold'] == fold_id].reset_index(drop=True)
        val_indices = tabular_df[tabular_df['fold'] == fold_id].index.values
        
        logger.info(f"Validation samples for fold {fold_id}: {len(val_df)}")
        
        # Prepare tabular features for validation fold
        fold_preprocessor = ColumnTransformer(
            transformers=[
                ('num', StandardScaler(), num_features),
                ('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat_features)
            ]
        )
        
        # Fit on train, transform on val
        train_df = tabular_df[tabular_df['fold'] != fold_id]
        fold_preprocessor.fit(train_df)
        val_tabular = fold_preprocessor.transform(val_df)
        
        # Create validation dataset
        val_dataset = BiomassDataset(
            df=val_df,
            tabular_features=val_tabular,
            target_cols=target_cols,
            img_dir=PATH_TRAIN_IMG,
            transform=val_transform,  # No augmentation
            is_test=False,
            use_log_target=USE_LOG_TARGET
        )
        
        # Create validation loader
        val_loader = DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE * 2,
            shuffle=False,  # IMPORTANT: keep order!
            num_workers=min(NUM_WORKERS, 8),
            pin_memory=True if torch.cuda.is_available() else False
        )
        
        # Make predictions
        fold_predictions = []
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                # Move batch to device
                batch = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v 
                        for k, v in batch.items()}
                
                # Get predictions
                preds = teacher_model.predict_step(batch, 0)  # [B, 3]
                fold_predictions.append(preds.cpu().numpy())
        
        # Concatenate batch predictions
        fold_preds_array = np.concatenate(fold_predictions, axis=0)  # [N_val, 3]
        
        logger.success(f"Fold {fold_id} predictions shape: {fold_preds_array.shape}")
        
        # Verify indices match
        assert len(fold_preds_array) == len(val_indices), \
            f"Predictions length {len(fold_preds_array)} != indices length {len(val_indices)}"
        
        # Store predictions at correct indices
        oof_predictions[val_indices] = fold_preds_array
        oof_indices[val_indices] = True
        
        logger.info(f"Stored predictions for indices: {val_indices[:5]}... (showing first 5)")
        
        # Clean up
        del teacher_model
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # Verify all samples got predictions
    if not oof_indices.all():
        missing_count = (~oof_indices).sum()
        logger.warning(f"Missing predictions for {missing_count} samples!")
    else:
        logger.success(f"All {len(oof_predictions)} samples have OOF predictions!")
    
    # Create DataFrame with OOF soft targets
    soft_targets_df = tabular_df.copy()
    
    # Add OOF soft target columns
    soft_targets_df['Dry_Clover_g_soft'] = oof_predictions[:, 0]
    soft_targets_df['Dry_Dead_g_soft'] = oof_predictions[:, 1]
    soft_targets_df['Dry_Green_g_soft'] = oof_predictions[:, 2]
    
    logger.info("\n=== OOF Soft Targets Statistics ===")
    logger.info(f"Dry_Clover_g: mean={soft_targets_df['Dry_Clover_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Clover_g_soft'].std():.4f}")
    logger.info(f"Dry_Dead_g: mean={soft_targets_df['Dry_Dead_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Dead_g_soft'].std():.4f}")
    logger.info(f"Dry_Green_g: mean={soft_targets_df['Dry_Green_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Green_g_soft'].std():.4f}")
    
    return soft_targets_df, oof_predictions

In [64]:
def save_oof_soft_targets(soft_targets_df: pd.DataFrame, output_path: str = './kaggle/input/'):
    """
    Save OOF soft targets to CSV.
    
    Args:
        soft_targets_df: DataFrame with OOF soft targets
        output_path: Path to save CSV
    """
    os.makedirs(output_path, exist_ok=True)
    
    output_file = os.path.join(output_path, 'train_with_oof_soft_targets.csv')
    soft_targets_df.to_csv(output_file, index=False)
    
    logger.success(f"Saved OOF soft targets to: {output_file}")
    logger.info(f"Shape: {soft_targets_df.shape}")
    logger.info(f"Columns: {soft_targets_df.columns.tolist()}")
    
    return output_file

In [65]:
def compare_oof_vs_ensemble(oof_df: pd.DataFrame, ensemble_df: pd.DataFrame):
    """
    Compare OOF predictions vs direct ensemble predictions.

    Args:
        oof_df: DataFrame with OOF soft targets
        ensemble_df: DataFrame with ensemble soft targets
    """
    logger.info("COMPARING OOF vs ENSEMBLE PREDICTIONS")

    for col in ['Dry_Clover_g_soft', 'Dry_Dead_g_soft', 'Dry_Green_g_soft']:
        oof_vals = oof_df[col].values
        ens_vals = ensemble_df[col].values

        # Correlation
        corr = np.corrcoef(oof_vals, ens_vals)[0, 1]

        # Mean Absolute Difference
        mad = np.mean(np.abs(oof_vals - ens_vals))

        # RMSE
        rmse = np.sqrt(np.mean((oof_vals - ens_vals)**2))

        msgs = [
            f"\n{col}:",
            f"  Correlation: {corr:.4f}",
            f"  Mean Absolute Diff: {mad:.4f}",
            f"  RMSE: {rmse:.4f}",
            f"  OOF  range: [{oof_vals.min():.2f}, {oof_vals.max():.2f}]",
            f"  Ensemble range: [{ens_vals.min():.2f}, {ens_vals.max():.2f}]"
        ]
        logger.info("\n".join(msgs))

In [66]:
# Create OOF soft targets
oof_soft_targets_df, oof_preds = create_oof_soft_targets()

[38;2;161;247;255m
[2025-12-11 14:42:18]
INFO: CREATING OUT-OF-FOLD (OOF) SOFT TARGETS[0m
[38;2;161;247;255m
[2025-12-11 14:42:18]
INFO: 
Processing Fold 0...[0m
[38;2;161;247;255m
[2025-12-11 14:42:18]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold0\swinv2_tiny_window8_256-local_train[5]Folds_log_fusion-mean_epochs20_bs16_gradacc1_lr0.0001_wd0.05_dr0.2_hr0.5-fold0-epoch=13-val_loss=0.8938.ckpt[0m
[38;2;161;247;255m
[2025-12-11 14:42:19]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m
[38;2;161;247;255m
[2025-12-11 14:42:19]
INFO: Validation samples for fold 0: 72[0m
[38;2;105;254;105m
[2025-12-11 14:42:28]
SUCCESS: Fold 0 predictions shape: (72, 3)[0m
[38;2;161;247;255m
[2025-12-11 14:42:28]
INFO: Stored predictions for indices: [ 1  3  6 11 12]... (showing first 5)[0m
[38;2;161;247;255m
[2025-12-11 14:42:28]
INFO: 
Processing Fold 1...[0m
[38;2;161;247;255m
[2025-12-11 14:42:2

In [67]:
# Validate OOF soft targets
validate_soft_targets(oof_soft_targets_df)

# Save to CSV
oof_output_file = save_oof_soft_targets(oof_soft_targets_df)
print(f"File: {oof_output_file}")
print(f"Samples: {len(oof_soft_targets_df)}")

[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: 
=== Soft Targets Validation ===[0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: Dry_Clover_g_soft:
Correlation with hard target: 0.8748
MSE vs hard target: 45.3763
Range: [0.0000, 83.9534][0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: Dry_Dead_g_soft:
Correlation with hard target: 0.7059
MSE vs hard target: 79.0653
Range: [0.0000, 54.3228][0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: Dry_Green_g_soft:
Correlation with hard target: 0.8542
MSE vs hard target: 184.7793
Range: [0.0000, 140.0474][0m
[38;2;105;254;105m
[2025-12-11 14:43:11]
SUCCESS: Saved OOF soft targets to: ./kaggle/input/train_with_oof_soft_targets.csv[0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: Shape: (357, 15)[0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: Columns: ['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI', 'Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'Season', 'strat_group', 'fold', 'Dry_Clover_g_

In [68]:
# Compare with ensemble predictions (if available)
if 'soft_targets_df' in locals():
    compare_oof_vs_ensemble(oof_soft_targets_df, soft_targets_df)

[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: COMPARING OOF vs ENSEMBLE PREDICTIONS[0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: 
Dry_Clover_g_soft:
  Correlation: 0.9606
  Mean Absolute Diff: 1.5858
  RMSE: 4.0862
  OOF  range: [0.00, 83.95]
  Ensemble range: [0.00, 97.15][0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: 
Dry_Dead_g_soft:
  Correlation: 0.9250
  Mean Absolute Diff: 2.5321
  RMSE: 4.0186
  OOF  range: [0.00, 54.32]
  Ensemble range: [0.00, 56.29][0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: 
Dry_Green_g_soft:
  Correlation: 0.9481
  Mean Absolute Diff: 4.5954
  RMSE: 8.1427
  OOF  range: [0.00, 140.05]
  Ensemble range: [0.00, 142.33][0m
