# Distillation. Teacher

In [2]:
machine = "local"
!wandb login

wandb: Currently logged in as: dmykhailov (dmykhailov-kyiv-school-of-economics) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


## Imports

In [3]:
import os
import gc
import logging
import numpy as np
import pandas as pd
from PIL import Image
from typing import cast
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler, OneHotEncoder

import cv2
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from sam import SAM
from tqdm import tqdm
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import Dinov2Model, Dinov2Config
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import warnings
from notebooks_config import setup_logging, CustomLogger

warnings.simplefilter(action='ignore', category=FutureWarning)
print(f"PyTorch: {torch.__version__}")
print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

Available variables: ['BASE_DIR', 'DATA_DIR', 'Path', 'directory', 'find_project_root', 'project_root', 'sys']
PyTorch: 2.9.1+cu128
Device: NVIDIA GeForce RTX 5050 Laptop GPU


In [4]:
logger = setup_logging(level=logging.DEBUG, full_color=True, include_function=False)
logger = cast(CustomLogger, logger)  # Type hinting
logger.success("Logging configuration test completed.")

[38;2;105;254;105m
[2025-12-15 09:19:30]
SUCCESS: Logging configured successfully ✅[0m
[38;2;105;254;105m
[2025-12-15 09:19:30]
SUCCESS: Logging configuration test completed.[0m


In [5]:
cpu_count = os.cpu_count()
NUM_WORKERS = 0

LR = 3e-5
EPOCHS = 20
N_FOLDS = 5
GRAD_ACCUM = 4
BATCH_SIZE = 4
DROPOUT_RATE = 0.2
WEIGHT_DECAY = 0.05
HIDDEN_RATIO = 0.5
TRAIN_SPLIT_RATIO = 0.02 # Used if N_FOLDS = 0

MODEL = 'facebook/dinov2-base'  # 'swinv2_tiny_window8_256'
PROJECT_NAME = "csiro-image2biomass-prediction"
CHECKPOINTS_DIR = "./kaggle/checkpoints/teacher/"

# Each patch is 1000x1000, resize to 768x768 for vision transformers
SIZE = 768
USE_LOG_TARGET = True     # Whether to use log1p transformation on target variable
FUSION_METHOD = 'gating'  # ('concat', 'mean', 'max') OR 'gating'

DESCRIPTION = machine + \
    (f"_train{TRAIN_SPLIT_RATIO}" if N_FOLDS == 0 else f"_train[{N_FOLDS}]Folds") + (
        f"_log" if USE_LOG_TARGET else "") + f"_fusion-{FUSION_METHOD}"
DESCRIPTION_FULL = MODEL + "-" + DESCRIPTION + \
    f"_epochs{EPOCHS}_bs{BATCH_SIZE}_gradacc{GRAD_ACCUM}_lr{LR}_wd{WEIGHT_DECAY}_dr{DROPOUT_RATE}_hr{HIDDEN_RATIO}"
SUBMISSION_NAME = f"{DESCRIPTION_FULL}_submission.csv"
SUBMISSION_ENSEMBLE_NAME = f"{DESCRIPTION_FULL}_ensemble_submission.csv"
SUBMISSION_MSG = DESCRIPTION_FULL.replace("_", " ")

SEED = 1488
torch.manual_seed(SEED)
np.random.seed(SEED)
pl.seed_everything(SEED)

print("DESCRIPTION_FULL:", DESCRIPTION_FULL)
print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")

Seed set to 1488


DESCRIPTION_FULL: facebook/dinov2-base-local_train[5]Folds_log_fusion-gating_epochs20_bs4_gradacc4_lr3e-05_wd0.05_dr0.2_hr0.5
Effective batch size: 16


In [6]:
# setting device on GPU if available, else CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)
print('NUM_WORKERS:', NUM_WORKERS)
print()

# Additional Info when using cuda
if DEVICE.type == 'cuda':
    # clean GPU memory
    torch.cuda.empty_cache()
    gc.collect()

    torch.set_float32_matmul_precision('high') if machine == "local" else None

    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3, 1), 'GB')

Using device: cuda
NUM_WORKERS: 0

NVIDIA GeForce RTX 5050 Laptop GPU
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


## Data Loading and Preprocessing

In [7]:
PATH_DATA = './kaggle/input/csiro-biomass'
PATH_TRAIN_CSV = os.path.join(PATH_DATA, 'train.csv')
PATH_TEST_CSV = os.path.join(PATH_DATA, 'test.csv')
PATH_TRAIN_IMG = os.path.join(PATH_DATA, 'train')
PATH_TEST_IMG = os.path.join(PATH_DATA, 'test')

df = pd.read_csv(PATH_TRAIN_CSV)
# Remove unneeded targets
df = df[~df['target_name'].isin(['Dry_Total_g', 'GDM_g'])]
print(f"Dataset size: {df.shape}")
display(df.head())

Dataset size: (1071, 9)


Unnamed: 0,sample_id,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,target_name,target
0,ID1011485656__Dry_Clover_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Clover_g,0.0
1,ID1011485656__Dry_Dead_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Dead_g,31.9984
2,ID1011485656__Dry_Green_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Green_g,16.275
5,ID1012260530__Dry_Clover_g,train/ID1012260530.jpg,2015/4/1,NSW,Lucerne,0.55,16.0,Dry_Clover_g,0.0
6,ID1012260530__Dry_Dead_g,train/ID1012260530.jpg,2015/4/1,NSW,Lucerne,0.55,16.0,Dry_Dead_g,0.0


In [8]:
# pivot the dataframe to have one row per image with multiple target columns
tabular_df = df.pivot_table(index=['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI'],
                            columns='target_name', values='target', aggfunc='first').reset_index()
tabular_df.columns.name = None  # remove the aggregation name
print(tabular_df.shape)
print(tabular_df.head())

(357, 9)
               image_path Sampling_Date State            Species  \
0  train/ID1011485656.jpg      2015/9/4   Tas    Ryegrass_Clover   
1  train/ID1012260530.jpg      2015/4/1   NSW            Lucerne   
2  train/ID1025234388.jpg      2015/9/1    WA  SubcloverDalkeith   
3  train/ID1028611175.jpg     2015/5/18   Tas           Ryegrass   
4  train/ID1035947949.jpg     2015/9/11   Tas           Ryegrass   

   Height_Ave_cm  Pre_GSHH_NDVI  Dry_Clover_g  Dry_Dead_g  Dry_Green_g  
0         4.6667           0.62        0.0000     31.9984      16.2750  
1        16.0000           0.55        0.0000      0.0000       7.6000  
2         1.0000           0.38        6.0500      0.0000       0.0000  
3         5.0000           0.66        0.0000     30.9703      24.2376  
4         3.5000           0.54        0.4343     23.2239      10.5261  


In [9]:
print(tabular_df.columns.tolist())

['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI', 'Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g']


In [10]:
target_cols  = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g']
num_features = ['Height_Ave_cm', 'Pre_GSHH_NDVI']
cat_features = ['Species', 'State']

In [11]:
# BUG: data leakage, will be fixed later in this code
preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), num_features), # normalizing numeric features
        ('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat_features) # OHE for categorical features
    ]
)

In [12]:
tabular_data = preprocessor.fit_transform(tabular_df)

In [13]:
print(tabular_data.shape)
display(tabular_data[:5])

(357, 21)


array([[-0.28520388, -0.24631873,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ],
       [ 0.81823967, -0.70706013,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
         0.        ],
       [-0.64220462, -1.82600352,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         1.        ],
       [-0.25275281,  0.01696207,  0.        ,  0.        

In [1]:
# check the distribution of numeric features (1 and 2 columns)
height_ave_cm = tabular_data[:, 0]
pre_gshh_ndvi = tabular_data[:, 1]

print("Height_Ave_cm - mean:", np.mean(height_ave_cm), " std:", np.std(height_ave_cm))
print("Pre_GSHH_NDVI - mean:", np.mean(pre_gshh_ndvi), " std:", np.std(pre_gshh_ndvi))

NameError: name 'tabular_data' is not defined

## Dataset

In [46]:
class BiomassDataset(Dataset):
    """Dataset for biomass prediction with image + tabular features"""

    def __init__(
        self,
        df: pd.DataFrame,
        tabular_features: np.ndarray,
        target_cols: list[str],
        img_dir: str,
        transform: transforms.Compose | None = None,
        is_test: bool = False,
        use_log_target: bool = True
    ):
        """
        Args:
            df: DataFrame with image_id, image_path, and target columns
            tabular_features: Preprocessed tabular features (shape: [n_samples, n_features])
            target_cols: List of target column names
            img_dir: Root directory for images
            transform: torchvision transform pipeline
            is_test: If True, targets are not expected in df
            use_log_target: If True, apply log1p transform to targets
        """
        self.df = df.reset_index(drop=True)
        self.tabular_features = tabular_features
        self.target_cols = target_cols
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test
        self.use_log_target = use_log_target

        assert len(self.df) == len(self.tabular_features), \
            f"DataFrame length {len(self.df)} != tabular features length {len(self.tabular_features)}"

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> dict:
        """
        Returns:
            dict with keys:
                - 'left_image': tensor [C, H, W]
                - 'right_image': tensor [C, H, W]
                - 'tabular': tensor [n_features]
                - 'targets': tensor [n_targets] (if not test)
                - 'image_id': str
        """
        row = self.df.iloc[idx]

        # Load image
        img_path = os.path.join(self.img_dir, row['image_path'].replace('train/', '').replace('test/', ''))
        image = cv2.imread(img_path)

        if image is None:
            raise FileNotFoundError(f"Cannot load image: {img_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Split into left and right patches
        # Original image shape: [H, W, C] = [1000, 2000, 3]
        h, w, c = image.shape
        mid_w = w // 2

        left_patch = image[:, :mid_w, :]   # [1000, 1000, 3]
        right_patch = image[:, mid_w:, :]  # [1000, 1000, 3]

        # Convert to PIL Image for torchvision transforms
        left_pil = Image.fromarray(left_patch)
        right_pil = Image.fromarray(right_patch)

        # Apply transforms
        if self.transform:
            left_tensor = self.transform(left_pil)
            right_tensor = self.transform(right_pil)
        else:
            left_tensor = transforms.ToTensor()(left_pil)
            right_tensor = transforms.ToTensor()(right_pil)

        # Get tabular features
        tabular = torch.tensor(self.tabular_features[idx], dtype=torch.float32)

        # Prepare output
        output = {
            'left_image': left_tensor,
            'right_image': right_tensor,
            'tabular': tabular,
            'image_id': row['image_path'].split('/')[-1].replace('.jpg', '')
        }

        # Add targets if not test
        if not self.is_test:
            targets = row[self.target_cols].values.astype(np.float32)

            # Apply log transform if enabled
            if self.use_log_target:
                # log1p handles zeros: log(1+0) = 0
                targets = np.log1p(targets)

            output['targets'] = torch.tensor(targets, dtype=torch.float32)

        return output

In [47]:
def calculate_img_data_stat(df: pd.DataFrame):
    """Calculate mean and std of image data for normalization."""
    means = []
    stds = []

    loader = tqdm(df['image_path'], desc="Calculating image stats")

    for img_path in loader:
        full_path = os.path.join(PATH_TRAIN_IMG, img_path.replace('train/', '').replace('test/', ''))
        image = cv2.imread(full_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image / 255.0  # Normalize to [0, 1]

        means.append(np.mean(image, axis=(0, 1)))
        stds.append(np.std(image, axis=(0, 1)))

    mean = np.mean(means, axis=0)
    std = np.mean(stds, axis=0)

    return mean, std

In [48]:
# train_mean, train_std = calculate_img_data_stat(tabular_df)
# print(f"Train Image Mean: {train_mean}, Std: {train_std}")

Train Image Mean: [0.44173591 0.50362967 0.30579783]

Std: [0.2364247  0.23557117 0.22199257]

In [49]:
# Image backbone (processes each patch independently)
try:
    temp_backbone = timm.create_model(
        MODEL,
        pretrained=True,
        num_classes=0,  # remove classification head
        global_pool='avg'
    )
    temp_backbone.to(DEVICE)
    temp_backbone.eval()
    config = temp_backbone.default_cfg

except RuntimeError as e:
    print(f"Error loading model {MODEL} with timm: {e}")
    # Try loading from Hugging Face transformer
    print("Trying Hugging Face transformers...")
    config = Dinov2Config.from_pretrained(MODEL)
    temp_backbone = Dinov2Model.from_pretrained(MODEL, config=config)

config


Error loading model facebook/dinov2-base with timm: Unknown model (dinov2-base)
Trying Hugging Face transformers...


Dinov2Config {
  "apply_layernorm": true,
  "architectures": [
    "Dinov2Model"
  ],
  "attention_probs_dropout_prob": 0.0,
  "drop_path_rate": 0.0,
  "dtype": "float32",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 518,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-06,
  "layerscale_value": 1.0,
  "mlp_ratio": 4,
  "model_type": "dinov2",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "out_features": [
    "stage12"
  ],
  "out_indices": [
    12
  ],
  "patch_size": 14,
  "qkv_bias": true,
  "reshape_hidden_states": true,
  "stage_names": [
    "stem",
    "stage1",
    "stage2",
    "stage3",
    "stage4",
    "stage5",
    "stage6",
    "stage7",
    "stage8",
    "stage9",
    "stage10",
    "stage11",
    "stage12"
  ],
  "transformers_version": "4.57.3",
  "use_mask_token": true,
  "use_swiglu_ffn": false
}

In [50]:
try:
    inputs_size = config['input_size']
    inputs_size = int(inputs_size[1]) if inputs_size is not None and inputs_size[1] == inputs_size[2] else 256
    mean = config['mean']
    std = config['std']
except TypeError as e:
    print(f"Error accessing model config: {e}")
    inputs_size = config.image_size
    mean = [0.485, 0.456, 0.406]
    std  = [0.229, 0.224, 0.225]

Error accessing model config: 'Dinov2Config' object is not subscriptable


In [51]:
temp_backbone

Dinov2Model(
  (embeddings): Dinov2Embeddings(
    (patch_embeddings): Dinov2PatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): Dinov2Encoder(
    (layer): ModuleList(
      (0-11): 12 x Dinov2Layer(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attention): Dinov2Attention(
          (attention): Dinov2SelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): Dinov2SelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (layer_scale1): Dinov2LayerScale()
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06,

In [52]:
SIZE = inputs_size
print(f"Backbone expected input size: {inputs_size}, using SIZE={SIZE}")
print(f"Backbone expected mean: {mean}, std: {std}")

# Get backbone output dimension
with torch.no_grad():
    try:
        dummy = torch.randn(1, 3, SIZE, SIZE)
        feat_dim = temp_backbone(dummy).shape[1]
    except AttributeError as ae:
        print(f"AttributeError: {ae}. Trying Hugging Face model forward pass.")
        dummy = torch.randn(1, 3, SIZE, SIZE)
        outputs = temp_backbone(dummy)
        feat_dim = outputs.last_hidden_state.sum(dim=1).shape[1]  # Average pooling
        print(feat_dim)
        
    except Exception as e:
        logger.error(f"Error getting backbone feature dimension: {e}")
        raise e

Backbone expected input size: 518, using SIZE=518
Backbone expected mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225]
AttributeError: 'BaseModelOutputWithPooling' object has no attribute 'shape'. Trying Hugging Face model forward pass.
768


In [53]:
outputs.last_hidden_state.shape  # [1, 1370, 768] where 1370 is number of patches -> (518 / 14) **2

torch.Size([1, 1370, 768])

In [54]:
train_transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=90),  # Increase to 90 for top-down view
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)
    ], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.1))  # Simulate occlusions
])

val_transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

In [55]:
# Create dataset instance
train_dataset = BiomassDataset(
    df=tabular_df,
    tabular_features=tabular_data,
    target_cols=target_cols,
    img_dir=PATH_TRAIN_IMG,
    transform=train_transform,
    is_test=False
)

# Test it
sample = train_dataset[0]
print(f"Left image shape: {sample['left_image'].shape}")
print(f"Right image shape: {sample['right_image'].shape}")
print(f"Tabular shape: {sample['tabular'].shape}")
print(f"Targets shape: {sample['targets'].shape}")
print(f"Image ID: {sample['image_id']}")
print(f"Target values: {sample['targets']}")
print()

# Test dataset with log transform
print(f"Original targets from df: {tabular_df.iloc[0][target_cols].values}")
print(f"Log-transformed targets: {sample['targets']}")
print(f"Should be close to: {np.log1p(tabular_df.iloc[0][target_cols].values.astype(np.float32))}")

Left image shape: torch.Size([3, 518, 518])
Right image shape: torch.Size([3, 518, 518])
Tabular shape: torch.Size([21])
Targets shape: torch.Size([3])
Image ID: ID1011485656
Target values: tensor([0.0000, 3.4965, 2.8493])

Original targets from df: [np.float64(0.0) np.float64(31.9984) np.float64(16.275)]
Log-transformed targets: tensor([0.0000, 3.4965, 2.8493])
Should be close to: [0.        3.496459  2.8492603]


## Spliting Data (StratifiedKFold)

In [56]:
def get_season(date_str: str) -> str:
    """
    Convert date string to season.

    Args:
        date_str: Date in format 'YYYY/M/D' or 'YYYY/MM/DD'

    Returns:
        Season name: 'Summer', 'Autumn', 'Winter', 'Spring'
    """
    # Parse month from date string
    month = int(date_str.split('/')[1])

    # Australian seasons (Southern Hemisphere)
    if month in [12, 1, 2]:
        return 'Summer'
    elif month in [3, 4, 5]:
        return 'Autumn'
    elif month in [6, 7, 8]:
        return 'Winter'
    else:  # 9, 10, 11
        return 'Spring'

In [57]:
# Add season column
tabular_df['Season'] = tabular_df['Sampling_Date'].apply(get_season)

# Create stratification column combining Season, State, and Species
tabular_df['strat_group'] = (
    tabular_df['Season'].astype(str) + '_' +
    tabular_df['State'].astype(str) + '_' +
    tabular_df['Species'].astype(str)
)

print("Unique stratification groups:")
print(tabular_df['strat_group'].value_counts())
print(f"\nTotal groups: {tabular_df['strat_group'].nunique()}")

Unique stratification groups:
strat_group
Spring_Tas_Ryegrass_Clover                                                41
Winter_Vic_Phalaris_Clover                                                32
Autumn_Tas_Ryegrass                                                       22
Winter_Vic_Ryegrass_Clover                                                21
Spring_Tas_Clover                                                         21
Winter_WA_Clover                                                          20
Winter_Tas_Ryegrass_Clover                                                18
Autumn_NSW_Lucerne                                                        15
Spring_Vic_Ryegrass_Clover                                                11
Spring_Vic_Phalaris_BarleyGrass_SilverGrass_SpearGrass_Clover_Capeweed    11
Spring_NSW_Fescue                                                         11
Summer_NSW_Fescue_CrumbWeed                                               10
Spring_Vic_Phalaris_Clover        

In [58]:
# Initialize StratifiedKFold
n_folds = 5
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

# Get stratification labels
strat_labels = tabular_df['strat_group'].values

# Create fold assignments
tabular_df['fold'] = -1

for fold_idx, (train_idx, val_idx) in enumerate(skf.split(tabular_df, strat_labels)):
    tabular_df.loc[val_idx, 'fold'] = fold_idx

    print(f"\nFold {fold_idx + 1}:")
    print(f"  Train samples: {len(train_idx)}")
    print(f"  Val samples: {len(val_idx)}")


Fold 1:
  Train samples: 285
  Val samples: 72

Fold 2:
  Train samples: 285
  Val samples: 72

Fold 3:
  Train samples: 286
  Val samples: 71

Fold 4:
  Train samples: 286
  Val samples: 71

Fold 5:
  Train samples: 286
  Val samples: 71




In [59]:
# Verify stratification worked
print("Stratification verification:")

for fold in range(n_folds):
    fold_df = tabular_df[tabular_df['fold'] == fold]
    print(f"\nFold {fold + 1}:")
    print(f"  Season distribution:")
    print(fold_df['Season'].value_counts(normalize=True).round(3))
    print(f"  State distribution:")
    print(fold_df['State'].value_counts(normalize=True).round(3))
    print(f"  Species distribution:")
    print(fold_df['Species'].value_counts(normalize=True).round(3))

Stratification verification:

Fold 1:
  Season distribution:
Season
Winter    0.361
Spring    0.347
Autumn    0.167
Summer    0.125
Name: proportion, dtype: float64
  State distribution:
State
Tas    0.389
Vic    0.319
NSW    0.222
WA     0.069
Name: proportion, dtype: float64
  Species distribution:
Species
Ryegrass_Clover                                                0.292
Ryegrass                                                       0.167
Phalaris_Clover                                                0.111
Clover                                                         0.111
Fescue                                                         0.083
Lucerne                                                        0.056
Fescue_CrumbWeed                                               0.028
Phalaris                                                       0.028
WhiteClover                                                    0.028
Phalaris_Clover_Ryegrass_Barleygrass_Bromegrass                0.028


In [60]:
def get_fold_data(df: pd.DataFrame, fold: int):
    """
    Get train/val split for specific fold.

    Args:
        df: DataFrame with 'fold' column
        fold: Fold index to use as validation

    Returns:
        train_df, val_df
    """
    train_df = df[df['fold'] != fold].reset_index(drop=True)
    val_df = df[df['fold'] == fold].reset_index(drop=True)

    return train_df, val_df

In [62]:
def get_loaders_patches(fold: int, bs: int) -> tuple[DataLoader, DataLoader, int]:
    """Get dataloaders for specific fold with proper preprocessing."""
    train_df, val_df = get_fold_data(tabular_df, fold)

    print(f"Training fold {fold}:")
    print(f"  Train size: {len(train_df)}")
    print(f"  Val size: {len(val_df)}")

    # Create NEW preprocessor for each fold
    fold_preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), num_features),
            ('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat_features)
        ]
    )

    # FIT only on train, TRANSFORM both
    train_tabular = fold_preprocessor.fit_transform(train_df)
    val_tabular = fold_preprocessor.transform(val_df)  # Only transform!

    # Create datasets
    train_dataset = BiomassDataset(
        df=train_df,
        tabular_features=train_tabular,
        target_cols=target_cols,
        img_dir=PATH_TRAIN_IMG,
        transform=train_transform,
        is_test=False,
        use_log_target=USE_LOG_TARGET
    )

    val_dataset = BiomassDataset(
        df=val_df,
        tabular_features=val_tabular,
        target_cols=target_cols,
        img_dir=PATH_TRAIN_IMG,
        transform=val_transform,
        is_test=False,
        use_log_target=USE_LOG_TARGET
    )

    # Create dataloaders with num_workers
    train_loader = DataLoader(
        train_dataset,
        batch_size=bs,
        shuffle=True,
        num_workers=min(NUM_WORKERS, 8),  # Limit to avoid overhead
        pin_memory=True if torch.cuda.is_available() else False,
        persistent_workers=True if NUM_WORKERS > 0 else False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=bs * 2,  # Larger batch for validation
        shuffle=False,
        num_workers=min(NUM_WORKERS, 8),
        pin_memory=True if torch.cuda.is_available() else False,
        persistent_workers=True if NUM_WORKERS > 0 else False
    )

    tabular_dim = train_tabular.shape[1]

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Tabular features dimension: {tabular_dim}")

    return train_loader, val_loader, tabular_dim

In [None]:
train_loader, val_loader, tabular_dim = get_loaders_patches(fold=0, bs=BATCH_SIZE)
print(tabular_dim)

Training fold 0:
  Train size: 285
  Val size: 72
Train batches: 72
Val batches: 9
Tabular features dimension: 21
21
Training fold 0:
  Train size: 285
  Val size: 72
Train batches: 72
Val batches: 9
Tabular features dimension: 21
21


## Ligtning Module

In [64]:
labels = [
    "Dry_Clover_g",
    "Dry_Dead_g",
    "Dry_Green_g",
    "Dry_Total_g",
    "GDM_g"
]

weights = {
    'Dry_Green_g': 0.1,
    'Dry_Dead_g': 0.1,
    'Dry_Clover_g': 0.1,
    'GDM_g': 0.2,
    'Dry_Total_g': 0.5,
}


def competition_metric(y_true, y_pred) -> float:
    """Function to calculate the competition's official evaluation metric (weighted R2 score)."""
    weights_array = np.array([weights[l] for l in labels])

    # Align with this calculation method
    y_weighted_mean = np.average(y_true, weights=weights_array, axis=1).mean()

    # For ss_res and ss_tot, also take the weighted average on axis=1, then the mean of the result
    ss_res = np.average((y_true - y_pred)**2,
                        weights=weights_array, axis=1).mean()
    ss_tot = np.average((y_true - y_weighted_mean)**2,
                        weights=weights_array, axis=1).mean()

    return 1 - ss_res / ss_tot

In [66]:
class BiomassTeacherModelPatches(pl.LightningModule):
    """
    Teacher model with patch-level predictions using DINOv2.
    - Extracts dense features from DINOv2 for each patch
    - Applies shared MLP to each patch independently
    - Averages predictions across all patches
    - Also incorporates tabular features via gating mechanism
    """

    def __init__(
        self,
        backbone_name: str = 'facebook/dinov2-base',
        tabular_dim: int = 10,
        num_targets: int = 3,
        lr: float = 1e-4,
        weight_decay: float = 1e-5,
        hidden_ratio: float = 0.5,
        dropout: float = 0.2,
        fusion_method: str = 'gating',  # 'gating' or 'concat'
        use_log_target: bool = True
    ):
        """
        Args:
            backbone_name: DINOv2 model name
            tabular_dim: dimension of tabular features
            num_targets: number of regression targets (3)
            lr: learning rate
            weight_decay: weight decay for optimizer
            hidden_ratio: ratio for hidden layer size
            dropout: dropout probability
            fusion_method: how to use tabular features ('gating' or 'concat')
            use_log_target: if True, predict log1p transformed targets
        """
        super().__init__()
        self.save_hyperparameters()

        # Load DINOv2 backbone
        self.backbone = Dinov2Model.from_pretrained(backbone_name)
        self.backbone.train()

        self.hidden_dim = self.backbone.config.hidden_size
        self.patch_size = self.backbone.config.patch_size

        self.lr = lr
        self.weight_decay = weight_decay
        self.fusion_method = fusion_method
        self.use_log_target = use_log_target

        # Patch-level MLPs (shared across all patches)
        hidden_size = max(32, int(self.hidden_dim * hidden_ratio))

        def make_patch_head():
            """MLP for patch-level prediction"""
            return nn.Sequential(
                nn.Linear(self.hidden_dim, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, 1)
            )

        # Separate heads for each target
        self.patch_head_green = make_patch_head()
        self.patch_head_clover = make_patch_head()
        self.patch_head_dead = make_patch_head()

        # Tabular features used for modulation
        if self.fusion_method == 'gating':
            # Gating network - features control the strength of predictions
            self.tabular_gate = nn.Sequential(
                nn.Linear(tabular_dim, hidden_size),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_size, 1),
                nn.Sigmoid()  # Output [0, 1]
            )
        elif self.fusion_method == 'concat':
            # Concatenate after patch-level averaging
            self.fusion_layer = nn.Sequential(
                nn.Linear(3 + tabular_dim, hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, 3)
            )

        self.validation_step_outputs = []

        msg = (
            "Patch-level Teacher Model initialized:\n"
            f"backbone={backbone_name}, hidden_dim={self.hidden_dim}, patch_size={self.patch_size},\n"
            f"fusion_method={fusion_method}, use_log_target={use_log_target}"
        )
        logger.info(msg)

    def forward(self, batch: dict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass with patch-level processing.

        Args:
            batch: dict with keys:
                - 'left_image': [B, 3, H, W]
                - 'right_image': [B, 3, H, W]
                - 'tabular': [B, tabular_dim]

        Returns:
            (green, clover, dead) predictions [B]
        """
        # Forward through DINOv2 for left patch
        left_outputs = self.backbone(batch['left_image'])
        left_features = left_outputs.last_hidden_state  # [B, num_patches+1, hidden_dim]

        # Remove CLS token, keep only patch tokens
        left_patches = left_features[:, 1:, :]  # [B, num_patches, hidden_dim]

        # Forward through DINOv2 for right patch
        right_outputs = self.backbone(batch['right_image'])
        right_features = right_outputs.last_hidden_state
        right_patches = right_features[:, 1:, :]

        # Concatenate patches from both images
        all_patches = torch.cat([left_patches, right_patches], dim=1)  # [B, 2*num_patches, hidden_dim]

        # Apply MLP to each patch
        patch_preds_green = self.patch_head_green(all_patches)  # [B, 2*num_patches, 1]
        patch_preds_clover = self.patch_head_clover(all_patches)
        patch_preds_dead = self.patch_head_dead(all_patches)

        # Average predictions across patches
        green = patch_preds_green.mean(dim=1).squeeze(1)  # [B]
        clover = patch_preds_clover.mean(dim=1).squeeze(1)
        dead = patch_preds_dead.mean(dim=1).squeeze(1)

        # Use tabular features for modulation
        if self.fusion_method == 'gating':
            # Gating: modulate final prediction
            gate = self.tabular_gate(batch['tabular'])  # [B, 1]
            green = green * gate.squeeze(1)
            clover = clover * gate.squeeze(1)
            dead = dead * gate.squeeze(1)

        elif self.fusion_method == 'concat':
            # Concatenate: final layer to combine
            combined = torch.cat([
                green.unsqueeze(1),
                clover.unsqueeze(1),
                dead.unsqueeze(1),
                batch['tabular']
            ], dim=1)
            output = self.fusion_layer(combined)  # [B, 3]
            green = output[:, 0]
            clover = output[:, 1]
            dead = output[:, 2]

        return green, clover, dead

    def compute_all_targets(self, green: torch.Tensor, clover: torch.Tensor, dead: torch.Tensor) -> torch.Tensor:
        """
        Compute all 5 targets from 3 predicted ones using linear dependencies.

        Args:
            green: Dry_Green_g predictions [B]
            clover: Dry_Clover_g predictions [B]
            dead: Dry_Dead_g predictions [B]

        Returns:
            All 5 targets [B, 5] in order: [Clover, Dead, Green, Total, GDM]
        """
        # Clamp to ensure non-negative after conversion from log space
        green = torch.clamp(green, min=0.0)
        clover = torch.clamp(clover, min=0.0)
        dead = torch.clamp(dead, min=0.0)

        # Calculate derived targets using linear dependencies
        # Dry_Total_g = Dry_Green_g + Dry_Dead_g + Dry_Clover_g
        total = green + dead + clover

        # GDM_g = Dry_Clover_g + Dry_Green_g
        gdm = clover + green

        # Stack in the order expected by competition_metric:
        # ["Dry_Clover_g", "Dry_Dead_g", "Dry_Green_g", "Dry_Total_g", "GDM_g"]
        all_targets = torch.stack([clover, dead, green, total, gdm], dim=1)
        
        return all_targets

    def compute_loss(self, pred: tuple, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: (green, clover, dead) predictions
            target: [B, 3] ground truth targets [clover, dead, green]

        Returns:
            MSE loss
        """
        green_pred, clover_pred, dead_pred = pred
        clover_true = target[:, 0]  # Dry_Clover_g
        dead_true = target[:, 1]    # Dry_Dead_g
        green_true = target[:, 2]   # Dry_Green_g

        loss_green = F.mse_loss(green_pred, green_true)
        loss_clover = F.mse_loss(clover_pred, clover_true)
        loss_dead = F.mse_loss(dead_pred, dead_true)

        total_loss = loss_green + loss_clover + loss_dead

        return total_loss

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        pred = self(batch)
        loss = self.compute_loss(pred, batch['targets'])
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True,
                batch_size=batch['targets'].size(0))
        return loss

    def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        green_pred, clover_pred, dead_pred = self(batch)
        loss = self.compute_loss((green_pred, clover_pred, dead_pred), batch['targets'])

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True,
                 batch_size=batch['targets'].size(0))

        if self.use_log_target:
            green_pred = torch.expm1(green_pred)
            clover_pred = torch.expm1(clover_pred)
            dead_pred = torch.expm1(dead_pred)
            targets_original = torch.expm1(batch['targets'])
        else:
            targets_original = batch['targets']

        preds_all = self.compute_all_targets(green_pred, clover_pred, dead_pred)

        clover_true = targets_original[:, 0]
        dead_true = targets_original[:, 1]
        green_true = targets_original[:, 2]
        targets_all = self.compute_all_targets(green_true, clover_true, dead_true)

        self.validation_step_outputs.append({
            'preds': preds_all.detach().cpu(),
            'targets': targets_all.detach().cpu()
        })

        return loss

    def on_validation_epoch_end(self):
        if len(self.validation_step_outputs) == 0:
            return

        all_preds = torch.cat(
            [x['preds'] for x in self.validation_step_outputs], dim=0).numpy()
        all_targets = torch.cat(
            [x['targets'] for x in self.validation_step_outputs], dim=0).numpy()

        comp_metric = competition_metric(all_targets, all_preds)
        self.log('val_comp_metric', comp_metric, on_epoch=True, prog_bar=True)
        self.validation_step_outputs.clear()

    def predict_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        green, clover, dead = self(batch)
        preds = torch.stack([clover, dead, green], dim=1)

        if self.use_log_target:
            preds = torch.expm1(preds)

        preds = torch.clamp(preds, min=0.0)
        return preds

    def configure_optimizers(self):
        optimizer = AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay
        )
        
        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs or 20,
            eta_min=self.lr * 0.01
        )
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch'  # FIXME: 'step' if needed
            }
        }

## Folds Training

In [None]:
# Train on all folds
fold_results = []

for fold_id in range(N_FOLDS):
    train_loader, val_loader, tabular_dim = get_loaders_patches(
        fold=fold_id, bs=BATCH_SIZE)

    model = BiomassTeacherModelPatches(
        backbone_name=MODEL,
        tabular_dim=tabular_dim,
        num_targets=len(target_cols),
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        hidden_ratio=HIDDEN_RATIO,
        dropout=DROPOUT_RATE,
        fusion_method=FUSION_METHOD,
        use_log_target=USE_LOG_TARGET
    )

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        monitor='val_comp_metric',  # 'val_loss'
        dirpath=os.path.join(CHECKPOINTS_DIR, f'fold{fold_id}'),
        filename=f'{DESCRIPTION_FULL}-fold{fold_id}' +
        '-{epoch:02d}-{val_loss:.4f}-{val_comp_metric:.4f}',
        save_top_k=3,  # Save top 3 instead of 1
        mode='max'  # or 'min' for val_loss
    )

    early_stopping_callback = EarlyStopping(
        monitor='val_comp_metric',  # 'val_loss'
        patience=7,
        mode='max',
        verbose=True,
        min_delta=1e-3
    )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Logger
    wandb_logger = WandbLogger(
        project=PROJECT_NAME,
        name=f'{DESCRIPTION_FULL}-fold{fold_id}',
        log_model='all',
        tags=['patch-level', f'fold{fold_id}', FUSION_METHOD]
    )

    # Trainer
    trainer = pl.Trainer(
        max_epochs=EPOCHS,
        accelerator=DEVICE.type,
        precision='16-mixed' if torch.cuda.is_available() else 32,
        accumulate_grad_batches=GRAD_ACCUM,
        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
        logger=wandb_logger,
        log_every_n_steps=1,
        gradient_clip_val=1.0,
        enable_progress_bar=True,
        deterministic=True
    )

    try:
        # Train
        trainer.fit(model, train_loader, val_loader)

        # Load best checkpoint
        best_model_path = checkpoint_callback.best_model_path
        logger.info(f"Loading best model from: {best_model_path}")
        try:
            best_model = BiomassTeacherModel.load_from_checkpoint(
                best_model_path)
        except Exception as e:
            best_model = BiomassTeacherModelPatches.load_from_checkpoint(
                best_model_path)

        # Evaluate on validation set
        val_result = trainer.validate(best_model, val_loader, verbose=False)
        fold_results.append({
            'fold': fold_id,
            'val_loss': val_result[0]['val_loss'],
            'val_comp_metric': val_result[0]['val_comp_metric']
        })

    except SystemExit:
        logger.warning(
            f"Training interrupted during fold {fold_id}. Exiting gracefully.")
        wandb_logger.experiment.finish()

        # Clean up memory after each fold
        del model, trainer, wandb_logger
        del train_loader, val_loader

        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.synchronize()

        break

    finally:
        wandb_logger.experiment.finish()

        # Clean up memory after each fold
        del model, trainer, wandb_logger
        del train_loader, val_loader

        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.synchronize()

In [68]:
print("Training Summary")
print()
results_df = pd.DataFrame(fold_results)
print(results_df)
print(f"Mean Val Loss: {results_df['val_loss'].mean():.4f} ± {results_df['val_loss'].std():.4f}")

Training Summary

   fold  val_loss  val_comp_metric
0     0  0.894299         0.566824
1     1  0.820352         0.578140
2     2  0.892732         0.575059
3     3  1.365784         0.543001
4     4  1.096217         0.456924
Mean Val Loss: 1.0139 ± 0.2219


## Prepare Data for Student

In [70]:
def load_teacher_models_facebook(checkpoints_dir: str, n_folds: int) -> list:
    """
    Load best checkpoints from all folds.
    
    Args:
        checkpoints_dir: Directory with fold checkpoints
        n_folds: Number of folds
    
    Returns:
        List of loaded models
    """
    models = []
    
    for fold_id in range(n_folds):
        fold_dir = os.path.join(checkpoints_dir, f'fold{fold_id}/facebook')
        
        # Find best checkpoint file
        ckpt_files = [f for f in os.listdir(fold_dir) if f.endswith('.ckpt')]
        if not ckpt_files:
            logger.warning(f"No checkpoint found in {fold_dir}")
            continue

        # sort by comp_val_metric in filename (comp_val_metric=xxx.ckpt)
        ckpt_files.sort(key=lambda x: float(x.split('=')[-1].replace('.ckpt', '')))
        
        # Assuming the last one is the best
        ckpt_path = os.path.join(fold_dir, ckpt_files[-1])
        
        model = BiomassTeacherModelPatches.load_from_checkpoint(ckpt_path)
        logger.info(f"Loading checkpoint: {ckpt_path}")
        model.eval()
        model = model.to(DEVICE)
        models.append(model)
    
    logger.success(f"Loaded {len(models)} teacher models")
    return models

In [None]:
models = load_teacher_models_facebook(CHECKPOINTS_DIR, N_FOLDS)

[38;2;161;247;255m
[2025-12-13 18:41:53]
INFO: Patch-level Teacher Model initialized:
backbone=facebook/dinov2-base, hidden_dim=768, patch_size=14,
fusion_method=gating, use_log_target=True[0m
[38;2;161;247;255m
[2025-12-13 18:41:54]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold0/facebook\dinov2-base-local_train[5]Folds_log_fusion-gating_epochs20_bs4_gradacc4_lr3e-05_wd0.05_dr0.2_hr0.5-fold0-epoch=17-val_loss=0.8943-val_comp_metric=0.5668.ckpt[0m
[38;2;161;247;255m
[2025-12-13 18:41:55]
INFO: Patch-level Teacher Model initialized:
backbone=facebook/dinov2-base, hidden_dim=768, patch_size=14,
fusion_method=gating, use_log_target=True[0m
[38;2;161;247;255m
[2025-12-13 18:41:56]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold1/facebook\dinov2-base-local_train[5]Folds_log_fusion-gating_epochs20_bs4_gradacc4_lr3e-05_wd0.05_dr0.2_hr0.5-fold1-epoch=16-val_loss=0.8204-val_comp_metric=0.5781.ckpt[0m
[38;2;161;247;255m
[2025-12-13 18:41:57]
INFO: Patch-level Te

RuntimeError: Error(s) in loading state_dict for BiomassTeacherModelPatches:
	Missing key(s) in state_dict: "patch_head_green.1.weight", "patch_head_green.1.bias", "patch_head_green.4.weight", "patch_head_green.4.bias", "patch_head_clover.1.weight", "patch_head_clover.1.bias", "patch_head_clover.4.weight", "patch_head_clover.4.bias", "patch_head_dead.1.weight", "patch_head_dead.1.bias", "patch_head_dead.4.weight", "patch_head_dead.4.bias". 
	Unexpected key(s) in state_dict: "patch_head_green.3.weight", "patch_head_green.3.bias", "patch_head_clover.3.weight", "patch_head_clover.3.bias", "patch_head_dead.3.weight", "patch_head_dead.3.bias". 

### Direct Ensemble Soft Labels

In [None]:
def create_soft_targets_dataset():
    """
    Create dataset with soft targets from teacher ensemble.
    Predicts on ALL training data (100%).
    
    Returns:
        DataFrame with soft targets appended
    """
    
    # Load all teacher models
    teacher_models = load_teacher_models_facebook(CHECKPOINTS_DIR, N_FOLDS)

    if not teacher_models:
        logger.error("No teacher models loaded!")
        return None
    
    logger.info(f"Creating soft targets using {len(teacher_models)} teacher models...")
    
    # Prepare full dataset with NO augmentation
    full_dataset = BiomassDataset(
        df=tabular_df,
        tabular_features=tabular_data,
        target_cols=target_cols,
        img_dir=PATH_TRAIN_IMG,
        transform=val_transform,  # Use validation transform (no augmentation)
        is_test=False,
        use_log_target=USE_LOG_TARGET
    )
    
    # Create dataloader for inference
    full_loader = DataLoader(
        full_dataset,
        batch_size=BATCH_SIZE,  # Larger batch for inference
        shuffle=False,
    )
    
    logger.info(f"Total samples for soft targets: {len(full_dataset)}")
    logger.info(f"Inference batches: {len(full_loader)}")
    
    # Store predictions from all models
    all_predictions = []
    
    # For each teacher model
    for model_idx, teacher_model in enumerate(teacher_models):
        logger.info(f"Processing teacher model {model_idx + 1}/{len(teacher_models)}...")
        
        model_predictions = []
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(full_loader):
                # Move batch to device
                batch = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v 
                        for k, v in batch.items()}
                
                # Get predictions
                preds = teacher_model.predict_step(batch, 0)  # [B, 3]
                model_predictions.append(preds.cpu().numpy())
                
                if (batch_idx + 1) % 10 == 0:
                    logger.debug(f"  Batch {batch_idx + 1}/{len(full_loader)}")
        
        # Concatenate all batches
        model_preds_array = np.concatenate(model_predictions, axis=0)  # [N, 3]
        all_predictions.append(model_preds_array)
        
        logger.success(f"Model {model_idx + 1} predictions shape: {model_preds_array.shape}")
    
    # Average predictions across all models
    ensemble_predictions = np.mean(all_predictions, axis=0)  # [N, 3]
    
    logger.success(f"Ensemble predictions shape: {ensemble_predictions.shape}")
    logger.info(f"Ensemble predictions range: min={ensemble_predictions.min():.4f}, max={ensemble_predictions.max():.4f}")
    
    # Create DataFrame with soft targets
    soft_targets_df = tabular_df.copy()
    
    # Add soft target columns
    soft_targets_df['Dry_Clover_g_soft'] = ensemble_predictions[:, 0]
    soft_targets_df['Dry_Dead_g_soft'] = ensemble_predictions[:, 1]
    soft_targets_df['Dry_Green_g_soft'] = ensemble_predictions[:, 2]
    
    # Keep original targets for reference
    # (optional: remove them later for student training)
    
    logger.info("\nSoft targets statistics:")
    logger.info(f"Dry_Clover_g: mean={soft_targets_df['Dry_Clover_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Clover_g_soft'].std():.4f}")
    logger.info(f"Dry_Dead_g: mean={soft_targets_df['Dry_Dead_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Dead_g_soft'].std():.4f}")
    logger.info(f"Dry_Green_g: mean={soft_targets_df['Dry_Green_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Green_g_soft'].std():.4f}")
    
    return soft_targets_df, ensemble_predictions

In [None]:
def save_soft_targets(soft_targets_df: pd.DataFrame, output_path: str = './kaggle/input/'):
    """
    Save soft targets to CSV.
    
    Args:
        soft_targets_df: DataFrame with soft targets
        output_path: Path to save CSV
    """
    os.makedirs(output_path, exist_ok=True)
    
    output_file = os.path.join(output_path, 'train_with_soft_targets.csv')
    soft_targets_df.to_csv(output_file, index=False)
    
    logger.success(f"Saved soft targets to: {output_file}")
    logger.info(f"Shape: {soft_targets_df.shape}")
    logger.info(f"Columns: {soft_targets_df.columns.tolist()}")
    
    return output_file

In [None]:
def validate_soft_targets(soft_targets_df: pd.DataFrame):
    """
    Validate soft targets quality.
    
    Args:
        soft_targets_df: DataFrame with soft targets
    """
    logger.info("\n=== Soft Targets Validation ===")
    
    for col in ['Dry_Clover_g_soft', 'Dry_Dead_g_soft', 'Dry_Green_g_soft']:
        original_col = col.replace('_soft', '')
        
        # Check for NaN
        nan_count = soft_targets_df[col].isna().sum()
        if nan_count > 0:
            logger.warning(f"{col}: {nan_count} NaN values")
        
        # Compare soft vs hard targets
        corr = soft_targets_df[original_col].corr(soft_targets_df[col])
        mse = np.mean((soft_targets_df[original_col].values - soft_targets_df[col].values)**2)

        # logger.info(f"{col}:")
        # logger.info(f"  Correlation with hard target: {corr:.4f}")
        # logger.info(f"  MSE vs hard target: {mse:.4f}")
        # logger.info(f"  Range: [{soft_targets_df[col].min():.4f}, {soft_targets_df[col].max():.4f}]")
        msg = [f"{col}:", f"Correlation with hard target: {corr:.4f}", f"MSE vs hard target: {mse:.4f}", f"Range: [{soft_targets_df[col].min():.4f}, {soft_targets_df[col].max():.4f}]"]
        msg = "\n".join(msg)
        logger.info(msg)

In [None]:
# Create soft targets
soft_targets_df, ensemble_preds = create_soft_targets_dataset()

[38;2;161;247;255m
[2025-12-13 13:46:13]
INFO: Patch-level Teacher Model initialized:
backbone=facebook/dinov2-base, hidden_dim=768, patch_size=14,
fusion_method=gating, use_log_target=False[0m
[38;2;161;247;255m
[2025-12-13 13:46:13]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold0/facebook\dinov2-base-local_train[5]Folds_fusion-gating_epochs20_bs4_gradacc4_lr0.0001_wd0.05_dr0.2_hr0.5-fold0-epoch=13-val_loss=607.6509-val_comp_metric=0.5166.ckpt[0m
[38;2;161;247;255m
[2025-12-13 13:46:16]
INFO: Patch-level Teacher Model initialized:
backbone=facebook/dinov2-base, hidden_dim=768, patch_size=14,
fusion_method=gating, use_log_target=False[0m
[38;2;161;247;255m
[2025-12-13 13:46:16]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold1/facebook\dinov2-base-local_train[5]Folds_fusion-gating_epochs20_bs4_gradacc4_lr0.0001_wd0.05_dr0.2_hr0.5-fold1-epoch=19-val_loss=721.7338-val_comp_metric=0.5238.ckpt[0m
[38;2;161;247;255m
[2025-12-13 13:46:18]
INFO: Patch-level Te

KeyboardInterrupt: 

In [None]:
# Validate soft targets
validate_soft_targets(soft_targets_df)

# Save to CSV
output_file = save_soft_targets(soft_targets_df)

print(f"File: {output_file}")
print(f"Samples: {len(soft_targets_df)}")
print(f"Soft target columns: {[c for c in soft_targets_df.columns if '_soft' in c]}")

[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: 
=== Soft Targets Validation ===[0m
[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: Dry_Clover_g_soft:
Correlation with hard target: 0.9426
MSE vs hard target: 26.7841
Range: [0.0000, 97.1455][0m
[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: Dry_Dead_g_soft:
Correlation with hard target: 0.8593
MSE vs hard target: 40.1800
Range: [0.0000, 56.2889][0m
[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: Dry_Green_g_soft:
Correlation with hard target: 0.9462
MSE vs hard target: 69.3651
Range: [0.0000, 142.3275][0m
[38;2;105;254;105m
[2025-12-11 14:33:55]
SUCCESS: Saved soft targets to: ./kaggle/input/train_with_soft_targets.csv[0m
[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: Shape: (357, 15)[0m
[38;2;161;247;255m
[2025-12-11 14:33:55]
INFO: Columns: ['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI', 'Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'Season', 'strat_group', 'fold', 'Dry_Clover_g_soft', 'D

### Out of Fold Predictions

In [None]:
def create_oof_soft_targets():
    """
    Create Out-Of-Fold soft targets from teacher models.
    Each model predicts ONLY on its validation fold (unseen data).
    
    Returns:
        DataFrame with OOF soft targets
    """
    logger.info("CREATING OUT-OF-FOLD (OOF) SOFT TARGETS")
    
    # Initialize array to store OOF predictions
    # Shape: [num_samples, 3] for 3 targets
    oof_predictions = np.zeros((len(tabular_df), 3))
    oof_indices = np.zeros(len(tabular_df), dtype=bool)  # Track which samples got predictions
    
    # For each fold
    for fold_id in range(N_FOLDS):
        logger.info(f"\nProcessing Fold {fold_id}...")
        
        # Load model for this fold
        fold_dir = os.path.join(CHECKPOINTS_DIR, f'fold{fold_id}')
        ckpt_files = [f for f in os.listdir(fold_dir) if f.endswith('.ckpt')]
        
        if not ckpt_files:
            logger.error(f"No checkpoint found for fold {fold_id}")
            continue
        
        # Sort by val_loss to get best checkpoint
        ckpt_files.sort(key=lambda x: float(x.split('=')[-1].replace('.ckpt', '')))
        ckpt_path = os.path.join(fold_dir, ckpt_files[0])
        
        logger.info(f"Loading checkpoint: {ckpt_path}")
        teacher_model = BiomassTeacherModel.load_from_checkpoint(ckpt_path)
        teacher_model.eval()
        teacher_model = teacher_model.to(DEVICE)
        
        # Get validation data for this fold
        val_df = tabular_df[tabular_df['fold'] == fold_id].reset_index(drop=True)
        val_indices = tabular_df[tabular_df['fold'] == fold_id].index.values
        
        logger.info(f"Validation samples for fold {fold_id}: {len(val_df)}")
        
        # Prepare tabular features for validation fold
        fold_preprocessor = ColumnTransformer(
            transformers=[
                ('num', StandardScaler(), num_features),
                ('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat_features)
            ]
        )
        
        # Fit on train, transform on val
        train_df = tabular_df[tabular_df['fold'] != fold_id]
        fold_preprocessor.fit(train_df)
        val_tabular = fold_preprocessor.transform(val_df)
        
        # Create validation dataset
        val_dataset = BiomassDataset(
            df=val_df,
            tabular_features=val_tabular,
            target_cols=target_cols,
            img_dir=PATH_TRAIN_IMG,
            transform=val_transform,  # No augmentation
            is_test=False,
            use_log_target=USE_LOG_TARGET
        )
        
        # Create validation loader
        val_loader = DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE * 2,
            shuffle=False,  # IMPORTANT: keep order!
            num_workers=min(NUM_WORKERS, 8),
            pin_memory=True if torch.cuda.is_available() else False
        )
        
        # Make predictions
        fold_predictions = []
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                # Move batch to device
                batch = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v 
                        for k, v in batch.items()}
                
                # Get predictions
                preds = teacher_model.predict_step(batch, 0)  # [B, 3]
                fold_predictions.append(preds.cpu().numpy())
        
        # Concatenate batch predictions
        fold_preds_array = np.concatenate(fold_predictions, axis=0)  # [N_val, 3]
        
        logger.success(f"Fold {fold_id} predictions shape: {fold_preds_array.shape}")
        
        # Verify indices match
        assert len(fold_preds_array) == len(val_indices), \
            f"Predictions length {len(fold_preds_array)} != indices length {len(val_indices)}"
        
        # Store predictions at correct indices
        oof_predictions[val_indices] = fold_preds_array
        oof_indices[val_indices] = True
        
        logger.info(f"Stored predictions for indices: {val_indices[:5]}... (showing first 5)")
        
        # Clean up
        del teacher_model
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
    
    # Verify all samples got predictions
    if not oof_indices.all():
        missing_count = (~oof_indices).sum()
        logger.warning(f"Missing predictions for {missing_count} samples!")
    else:
        logger.success(f"All {len(oof_predictions)} samples have OOF predictions!")
    
    # Create DataFrame with OOF soft targets
    soft_targets_df = tabular_df.copy()
    
    # Add OOF soft target columns
    soft_targets_df['Dry_Clover_g_soft'] = oof_predictions[:, 0]
    soft_targets_df['Dry_Dead_g_soft'] = oof_predictions[:, 1]
    soft_targets_df['Dry_Green_g_soft'] = oof_predictions[:, 2]
    
    logger.info("\n=== OOF Soft Targets Statistics ===")
    logger.info(f"Dry_Clover_g: mean={soft_targets_df['Dry_Clover_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Clover_g_soft'].std():.4f}")
    logger.info(f"Dry_Dead_g: mean={soft_targets_df['Dry_Dead_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Dead_g_soft'].std():.4f}")
    logger.info(f"Dry_Green_g: mean={soft_targets_df['Dry_Green_g_soft'].mean():.4f}, "
               f"std={soft_targets_df['Dry_Green_g_soft'].std():.4f}")
    
    return soft_targets_df, oof_predictions

In [None]:
def save_oof_soft_targets(soft_targets_df: pd.DataFrame, output_path: str = './kaggle/input/'):
    """
    Save OOF soft targets to CSV.
    
    Args:
        soft_targets_df: DataFrame with OOF soft targets
        output_path: Path to save CSV
    """
    os.makedirs(output_path, exist_ok=True)
    
    output_file = os.path.join(output_path, 'train_with_oof_soft_targets.csv')
    soft_targets_df.to_csv(output_file, index=False)
    
    logger.success(f"Saved OOF soft targets to: {output_file}")
    logger.info(f"Shape: {soft_targets_df.shape}")
    logger.info(f"Columns: {soft_targets_df.columns.tolist()}")
    
    return output_file

In [None]:
def compare_oof_vs_ensemble(oof_df: pd.DataFrame, ensemble_df: pd.DataFrame):
    """
    Compare OOF predictions vs direct ensemble predictions.

    Args:
        oof_df: DataFrame with OOF soft targets
        ensemble_df: DataFrame with ensemble soft targets
    """
    logger.info("COMPARING OOF vs ENSEMBLE PREDICTIONS")

    for col in ['Dry_Clover_g_soft', 'Dry_Dead_g_soft', 'Dry_Green_g_soft']:
        oof_vals = oof_df[col].values
        ens_vals = ensemble_df[col].values

        # Correlation
        corr = np.corrcoef(oof_vals, ens_vals)[0, 1]

        # Mean Absolute Difference
        mad = np.mean(np.abs(oof_vals - ens_vals))

        # RMSE
        rmse = np.sqrt(np.mean((oof_vals - ens_vals)**2))

        msgs = [
            f"\n{col}:",
            f"  Correlation: {corr:.4f}",
            f"  Mean Absolute Diff: {mad:.4f}",
            f"  RMSE: {rmse:.4f}",
            f"  OOF  range: [{oof_vals.min():.2f}, {oof_vals.max():.2f}]",
            f"  Ensemble range: [{ens_vals.min():.2f}, {ens_vals.max():.2f}]"
        ]
        logger.info("\n".join(msgs))

In [None]:
# Create OOF soft targets
oof_soft_targets_df, oof_preds = create_oof_soft_targets()

[38;2;161;247;255m
[2025-12-11 14:42:18]
INFO: CREATING OUT-OF-FOLD (OOF) SOFT TARGETS[0m
[38;2;161;247;255m
[2025-12-11 14:42:18]
INFO: 
Processing Fold 0...[0m
[38;2;161;247;255m
[2025-12-11 14:42:18]
INFO: Loading checkpoint: ./kaggle/checkpoints/teacher/fold0\swinv2_tiny_window8_256-local_train[5]Folds_log_fusion-mean_epochs20_bs16_gradacc1_lr0.0001_wd0.05_dr0.2_hr0.5-fold0-epoch=13-val_loss=0.8938.ckpt[0m
[38;2;161;247;255m
[2025-12-11 14:42:19]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m
[38;2;161;247;255m
[2025-12-11 14:42:19]
INFO: Validation samples for fold 0: 72[0m
[38;2;105;254;105m
[2025-12-11 14:42:28]
SUCCESS: Fold 0 predictions shape: (72, 3)[0m
[38;2;161;247;255m
[2025-12-11 14:42:28]
INFO: Stored predictions for indices: [ 1  3  6 11 12]... (showing first 5)[0m
[38;2;161;247;255m
[2025-12-11 14:42:28]
INFO: 
Processing Fold 1...[0m
[38;2;161;247;255m
[2025-12-11 14:42:2

In [None]:
# Validate OOF soft targets
validate_soft_targets(oof_soft_targets_df)

# Save to CSV
oof_output_file = save_oof_soft_targets(oof_soft_targets_df)
print(f"File: {oof_output_file}")
print(f"Samples: {len(oof_soft_targets_df)}")

[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: 
=== Soft Targets Validation ===[0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: Dry_Clover_g_soft:
Correlation with hard target: 0.8748
MSE vs hard target: 45.3763
Range: [0.0000, 83.9534][0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: Dry_Dead_g_soft:
Correlation with hard target: 0.7059
MSE vs hard target: 79.0653
Range: [0.0000, 54.3228][0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: Dry_Green_g_soft:
Correlation with hard target: 0.8542
MSE vs hard target: 184.7793
Range: [0.0000, 140.0474][0m
[38;2;105;254;105m
[2025-12-11 14:43:11]
SUCCESS: Saved OOF soft targets to: ./kaggle/input/train_with_oof_soft_targets.csv[0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: Shape: (357, 15)[0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: Columns: ['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI', 'Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'Season', 'strat_group', 'fold', 'Dry_Clover_g_

In [None]:
# Compare with ensemble predictions (if available)
if 'soft_targets_df' in locals():
    compare_oof_vs_ensemble(oof_soft_targets_df, soft_targets_df)

[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: COMPARING OOF vs ENSEMBLE PREDICTIONS[0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: 
Dry_Clover_g_soft:
  Correlation: 0.9606
  Mean Absolute Diff: 1.5858
  RMSE: 4.0862
  OOF  range: [0.00, 83.95]
  Ensemble range: [0.00, 97.15][0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: 
Dry_Dead_g_soft:
  Correlation: 0.9250
  Mean Absolute Diff: 2.5321
  RMSE: 4.0186
  OOF  range: [0.00, 54.32]
  Ensemble range: [0.00, 56.29][0m
[38;2;161;247;255m
[2025-12-11 14:43:11]
INFO: 
Dry_Green_g_soft:
  Correlation: 0.9481
  Mean Absolute Diff: 4.5954
  RMSE: 8.1427
  OOF  range: [0.00, 140.05]
  Ensemble range: [0.00, 142.33][0m
