# Distillation. Teacher

In [8]:
# Execution-environment tag: "local" disables DataLoader workers and is embedded in the run description
machine = "local"
# Authenticate with Weights & Biases (shell command executed through IPython)
!wandb login

wandb: Currently logged in as: dmykhailov (dmykhailov-kyiv-school-of-economics) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


## Imports

In [9]:
import os
import gc
import logging
import numpy as np
import pandas as pd
from PIL import Image
from typing import cast
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler, OneHotEncoder

import cv2
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from tqdm import tqdm
from torch.optim import AdamW
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import warnings
from notebooks_config import setup_logging, CustomLogger

# Silence FutureWarning messages for the rest of the session
warnings.simplefilter(action='ignore', category=FutureWarning)
# Report the PyTorch build and the compute device next to the notebook outputs
print(f"PyTorch: {torch.__version__}")
print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

PyTorch: 2.9.1+cu128
Device: NVIDIA GeForce RTX 5050 Laptop GPU


In [10]:
# Configure notebook logging through the project helper (DEBUG level, colored output)
logger = setup_logging(level=logging.DEBUG, full_color=True, include_function=False)
logger = cast(CustomLogger, logger)  # Typing only: lets type checkers see CustomLogger methods such as .success()
logger.success("Logging configuration test completed.")

[38;2;105;254;105m
[2025-12-11 01:28:28]
SUCCESS: Logging configured successfully ✅[0m
[38;2;105;254;105m
[2025-12-11 01:28:28]
SUCCESS: Logging configuration test completed.[0m


In [41]:
# --- DataLoader parallelism ---------------------------------------------------
# No worker processes when running locally; otherwise half of the visible cores
# (or none if the core count cannot be determined).
cpu_count = os.cpu_count()
NUM_WORKERS = 0 if machine == "local" else (cpu_count // 2 if cpu_count else 0)

# --- Optimisation hyperparameters ---------------------------------------------
LR = 1e-4
EPOCHS = 20
N_FOLDS = 5
GRAD_ACCUM = 1
BATCH_SIZE = 16
DROPOUT_RATE = 0.2
WEIGHT_DECAY = 0.05
HIDDEN_RATIO = 0.5
TRAIN_SPLIT_RATIO = 0.02  # Used if N_FOLDS = 0

# --- Model / experiment identity ----------------------------------------------
MODEL = "swinv2_tiny_window8_256"
PROJECT_NAME = "csiro-image2biomass-prediction"
CHECKPOINTS_DIR = "./kaggle/checkpoints/teacher/"

# Initial patch resize target. NOTE: SIZE is reassigned further down from the
# backbone's default_cfg input size, so this value only applies until that cell runs.
SIZE = 768
USE_LOG_TARGET = True   # Train on log1p(target) instead of the raw target
FUSION_METHOD = 'mean'  # One of ('concat', 'mean', 'max')

# --- Run names derived from the settings above --------------------------------
DESCRIPTION = "".join([
    machine,
    f"_train{TRAIN_SPLIT_RATIO}" if N_FOLDS == 0 else f"_train[{N_FOLDS}]Folds",
    "_log" if USE_LOG_TARGET else "",
    f"_fusion-{FUSION_METHOD}",
])
DESCRIPTION_FULL = (
    f"{MODEL}-{DESCRIPTION}"
    f"_epochs{EPOCHS}_bs{BATCH_SIZE}_gradacc{GRAD_ACCUM}_lr{LR}_wd{WEIGHT_DECAY}_dr{DROPOUT_RATE}_hr{HIDDEN_RATIO}"
)
SUBMISSION_NAME = f"{DESCRIPTION_FULL}_submission.csv"
SUBMISSION_ENSEMBLE_NAME = f"{DESCRIPTION_FULL}_ensemble_submission.csv"
SUBMISSION_MSG = DESCRIPTION_FULL.replace("_", " ")

# --- Reproducibility ----------------------------------------------------------
SEED = 1488
torch.manual_seed(SEED)
np.random.seed(SEED)
pl.seed_everything(SEED)

print("DESCRIPTION_FULL:", DESCRIPTION_FULL)
print(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")

Seed set to 1488


DESCRIPTION_FULL: swinv2_tiny_window8_256-local_train[5]Folds_log_fusion-mean_epochs20_bs16_gradacc1_lr0.0001_wd0.05_dr0.2_hr0.5
Effective batch size: 16


In [12]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)
print('NUM_WORKERS:', NUM_WORKERS)
print()

# Extra housekeeping and diagnostics that only make sense on CUDA
if DEVICE.type == 'cuda':
    # Release cached GPU blocks and collect garbage left over from earlier runs
    torch.cuda.empty_cache()
    gc.collect()

    # Only the local machine opts into the 'high' float32 matmul precision mode
    if machine == "local":
        torch.set_float32_matmul_precision('high')

    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3, 1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3, 1), 'GB')

Using device: cuda
NUM_WORKERS: 20

NVIDIA GeForce RTX 5050 Laptop GPU
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


## Data Loading and Preprocessing

In [13]:
# Competition data layout: two CSV files and two image folders under one root
PATH_DATA = './kaggle/input/csiro-biomass'
PATH_TRAIN_CSV, PATH_TEST_CSV, PATH_TRAIN_IMG, PATH_TEST_IMG = (
    os.path.join(PATH_DATA, entry) for entry in ('train.csv', 'test.csv', 'train', 'test')
)

# Load the long-format training table and drop the Dry_Total_g / GDM_g rows,
# which are not used as targets in this notebook
df = pd.read_csv(PATH_TRAIN_CSV).loc[
    lambda frame: ~frame['target_name'].isin(['Dry_Total_g', 'GDM_g'])
]
print(f"Dataset size: {df.shape}")
display(df.head())

Dataset size: (1071, 9)


Unnamed: 0,sample_id,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,target_name,target
0,ID1011485656__Dry_Clover_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Clover_g,0.0
1,ID1011485656__Dry_Dead_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Dead_g,31.9984
2,ID1011485656__Dry_Green_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Green_g,16.275
5,ID1012260530__Dry_Clover_g,train/ID1012260530.jpg,2015/4/1,NSW,Lucerne,0.55,16.0,Dry_Clover_g,0.0
6,ID1012260530__Dry_Dead_g,train/ID1012260530.jpg,2015/4/1,NSW,Lucerne,0.55,16.0,Dry_Dead_g,0.0


In [14]:
# pivot the dataframe to have one row per image with multiple target columns
tabular_df = df.pivot_table(index=['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI'],
                              columns='target_name', values='target', aggfunc='first').reset_index()
tabular_df.columns.name = None  # remove the aggregation name
print(tabular_df.shape)
print(tabular_df.head())

(357, 9)
               image_path Sampling_Date State            Species  \
0  train/ID1011485656.jpg      2015/9/4   Tas    Ryegrass_Clover   
1  train/ID1012260530.jpg      2015/4/1   NSW            Lucerne   
2  train/ID1025234388.jpg      2015/9/1    WA  SubcloverDalkeith   
3  train/ID1028611175.jpg     2015/5/18   Tas           Ryegrass   
4  train/ID1035947949.jpg     2015/9/11   Tas           Ryegrass   

   Height_Ave_cm  Pre_GSHH_NDVI  Dry_Clover_g  Dry_Dead_g  Dry_Green_g  
0         4.6667           0.62        0.0000     31.9984      16.2750  
1        16.0000           0.55        0.0000      0.0000       7.6000  
2         1.0000           0.38        6.0500      0.0000       0.0000  
3         5.0000           0.66        0.0000     30.9703      24.2376  
4         3.5000           0.54        0.4343     23.2239      10.5261  


In [15]:
# Sanity check: list the columns produced by the pivot (metadata followed by the target columns)
print(tabular_df.columns.tolist())

['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI', 'Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g']


In [16]:
# Order matters: BiomassTeacherModel.compute_loss reads targets by position
# (index 0 = clover, 1 = dead, 2 = green).
target_cols  = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g']
# Numeric features are standardized; categorical features are one-hot encoded
num_features = ['Height_Ave_cm', 'Pre_GSHH_NDVI']
cat_features = ['Species', 'State']

In [17]:
# BUG (known): this preprocessor is fitted on the full dataset, so its scaling statistics and
# category vocabulary include rows that later serve as validation data (data leakage).
# get_loaders() builds a fresh preprocessor per fold that is fitted on the training split only.
preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), num_features), # standardize numeric features
        ('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat_features) # one-hot encode categoricals; unseen categories are ignored at transform time
    ]
)

In [18]:
# Fit on ALL rows and transform them; used only for the dataset smoke test below (see the leakage note on `preprocessor`)
tabular_data = preprocessor.fit_transform(tabular_df)

In [19]:
# Inspect the encoded feature matrix: (n_images, n_numeric + n_one_hot_columns)
print(tabular_data.shape)
display(tabular_data[:5])

(357, 21)


array([[-0.28520388, -0.24631873,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ],
       [ 0.81823967, -0.70706013,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
         0.        ],
       [-0.64220462, -1.82600352,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         1.        ],
       [-0.25275281,  0.01696207,  0.        ,  0.        

## Dataset

In [None]:
class BiomassDataset(Dataset):
    """Image + tabular dataset: each item is one photo split into a left and a right patch."""

    def __init__(
        self,
        df: pd.DataFrame,
        tabular_features: np.ndarray,
        target_cols: list[str],
        img_dir: str,
        transform: transforms.Compose | None = None,
        is_test: bool = False,
        use_log_target: bool = True
    ):
        """
        Args:
            df: One row per image; must contain 'image_path' and, unless is_test, the target columns
            tabular_features: Row-aligned preprocessed tabular features, indexed by position
            target_cols: Names of the target columns to read from df
            img_dir: Directory that contains the image files
            transform: Callable applied to each PIL patch; plain ToTensor() is used when omitted
            is_test: When True, no 'targets' entry is produced
            use_log_target: When True, targets are returned as log1p(target)
        """
        # Positional access is used throughout, so normalise the index first
        self.df = df.reset_index(drop=True)
        self.tabular_features = tabular_features
        self.target_cols = target_cols
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test
        self.use_log_target = use_log_target

        # Rows and feature vectors are matched purely by position
        assert len(self.df) == len(self.tabular_features), \
            f"DataFrame length {len(self.df)} != tabular features length {len(self.tabular_features)}"

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> dict:
        """
        Build one sample.

        Returns:
            dict with keys:
                - 'left_image': transformed left half of the photo
                - 'right_image': transformed right half of the photo
                - 'tabular': float32 tensor with the row's tabular features
                - 'image_id': file name without directory and '.jpg'
                - 'targets': float32 tensor of targets (only when not is_test)

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file
        """
        row = self.df.iloc[idx]
        rel_path = row['image_path']

        # The CSV paths carry a 'train/' or 'test/' prefix that img_dir already accounts for
        img_path = os.path.join(self.img_dir, rel_path.replace('train/', '').replace('test/', ''))
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise FileNotFoundError(f"Cannot load image: {img_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Cut the photo vertically down the middle into two patches
        # (the original notes describe the photos as 1000 x 2000 -> two 1000 x 1000 patches)
        half_width = rgb.shape[1] // 2
        left_pil = Image.fromarray(rgb[:, :half_width, :])
        right_pil = Image.fromarray(rgb[:, half_width:, :])

        # The transform is invoked once per patch, left first, so random
        # augmentations are drawn independently for the two halves
        to_tensor = self.transform if self.transform else transforms.ToTensor()
        left_tensor = to_tensor(left_pil)
        right_tensor = to_tensor(right_pil)

        sample = {
            'left_image': left_tensor,
            'right_image': right_tensor,
            'tabular': torch.tensor(self.tabular_features[idx], dtype=torch.float32),
            'image_id': rel_path.split('/')[-1].replace('.jpg', '')
        }

        if self.is_test:
            return sample

        targets = row[self.target_cols].values.astype(np.float32)
        if self.use_log_target:
            # log1p keeps zero targets at zero: log(1 + 0) = 0
            targets = np.log1p(targets)
        sample['targets'] = torch.tensor(targets, dtype=torch.float32)

        return sample

In [21]:
def calculate_img_data_stat(df: pd.DataFrame):
    """Calculate mean and std of image data for normalization."""
    means = []
    stds = []

    loader = tqdm(df['image_path'], desc="Calculating image stats")
    
    for img_path in loader:
        full_path = os.path.join(PATH_TRAIN_IMG, img_path.replace('train/', '').replace('test/', ''))
        image = cv2.imread(full_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image / 255.0  # Normalize to [0, 1]
        
        means.append(np.mean(image, axis=(0, 1)))
        stds.append(np.std(image, axis=(0, 1)))
    
    mean = np.mean(means, axis=0)
    std = np.mean(stds, axis=0)
    
    return mean, std

In [None]:
# train_mean, train_std = calculate_img_data_stat(tabular_df)
# print(f"Train Image Mean: {train_mean}, Std: {train_std}")

Calculating image stats: 100%|██████████| 357/357 [00:48<00:00,  7.29it/s]

Train Image Mean: [0.44173591 0.50362967 0.30579783], Std: [0.2364247  0.23557117 0.22199257]





In [22]:
# Image backbone (processes each patch independently)
temp_backbone = timm.create_model(
    MODEL,
    pretrained=True,
    num_classes=0,  # remove classification head
    global_pool='avg'
)

temp_backbone.default_cfg


{'url': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
 'hf_hub_id': 'timm/swinv2_tiny_window8_256.ms_in1k',
 'architecture': 'swinv2_tiny_window8_256',
 'tag': 'ms_in1k',
 'custom_load': False,
 'input_size': (3, 256, 256),
 'fixed_input_size': True,
 'interpolation': 'bicubic',
 'crop_pct': 0.9,
 'crop_mode': 'center',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'num_classes': 1000,
 'pool_size': (8, 8),
 'first_conv': 'patch_embed.proj',
 'classifier': 'head.fc',
 'license': 'mit'}

In [42]:
# Pull the preprocessing contract out of the pretrained configuration
inputs_size = temp_backbone.default_cfg['input_size']
mean = temp_backbone.default_cfg['mean']
std = temp_backbone.default_cfg['std']

# Use the backbone's native square resolution; fall back to 256 when the
# configuration has no input size or the input is not square.
# This overrides the SIZE set in the configuration cell.
if inputs_size is not None and inputs_size[1] == inputs_size[2]:
    SIZE = int(inputs_size[1])
else:
    SIZE = 256
print(f"Backbone expected input size: {inputs_size}, using SIZE={SIZE}")
print(f"Backbone expected mean: {mean}, std: {std}")

# Probe the pooled feature width with a single random image
try:
    with torch.no_grad():
        dummy = torch.randn(1, 3, SIZE, SIZE)
        feat_dim = temp_backbone(dummy).shape[1]
except Exception as e:
    logger.error(f"Error getting backbone feature dimension: {e}")
    raise

Backbone expected input size: (3, 256, 256), using SIZE=256
Backbone expected mean: (0.485, 0.456, 0.406), std: (0.229, 0.224, 0.225)


In [39]:
# Training pipeline: resize to the backbone resolution, then geometric and color augmentation.
# BiomassDataset applies this pipeline separately to the left and right patch, so the two
# halves of one photo receive independently sampled augmentations.
train_transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=90),  # wide rotation range; the author's note says the photos are top-down views
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)
    ], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),  # mean/std taken from the backbone's pretrained config
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.1))  # simulate occlusions; operates on the tensor, hence after ToTensor
])

# Validation pipeline: deterministic resize and normalization only
val_transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

In [25]:
# Smoke test: build a dataset over ALL rows with the globally fitted features.
# The training loop does not use this instance; get_loaders() creates per-fold datasets.
train_dataset = BiomassDataset(
    df=tabular_df,
    tabular_features=tabular_data,
    target_cols=target_cols,
    img_dir=PATH_TRAIN_IMG,
    transform=train_transform,
    is_test=False
)

# Fetch one sample and check shapes and values by eye
sample = train_dataset[0]
print(f"Left image shape: {sample['left_image'].shape}")
print(f"Right image shape: {sample['right_image'].shape}")
print(f"Tabular shape: {sample['tabular'].shape}")
print(f"Targets shape: {sample['targets'].shape}")
print(f"Image ID: {sample['image_id']}")
print(f"Target values: {sample['targets']}")  # log1p-scaled, because use_log_target defaults to True

Left image shape: torch.Size([3, 256, 256])
Right image shape: torch.Size([3, 256, 256])
Tabular shape: torch.Size([21])
Targets shape: torch.Size([3])
Image ID: ID1011485656
Target values: tensor([0.0000, 3.4965, 2.8493])


## Splitting Data (StratifiedKFold)

In [26]:
def get_season(date_str: str) -> str:
    """
    Map a date string to its Southern Hemisphere (Australian) season.

    Only the month field is inspected: the text between the first and
    second '/' is parsed as an integer.

    Args:
        date_str: Date in format 'YYYY/M/D' or 'YYYY/MM/DD'

    Returns:
        Season name: 'Summer', 'Autumn', 'Winter', 'Spring'
    """
    month = int(date_str.split('/')[1])

    # Explicit table for the three listed seasons; every other month value
    # (September to November included) falls through to 'Spring'.
    season_by_month = {
        12: 'Summer', 1: 'Summer', 2: 'Summer',
        3: 'Autumn', 4: 'Autumn', 5: 'Autumn',
        6: 'Winter', 7: 'Winter', 8: 'Winter',
    }
    return season_by_month.get(month, 'Spring')

In [27]:
# Derive the season from the sampling date (Southern Hemisphere mapping, see get_season)
tabular_df['Season'] = tabular_df['Sampling_Date'].apply(get_season)

# Build a single stratification key per image by joining Season, State and Species,
# so folds can be balanced on all three at once
tabular_df['strat_group'] = (
    tabular_df['Season'].astype(str) + '_' +
    tabular_df['State'].astype(str) + '_' +
    tabular_df['Species'].astype(str)
)

# Show the group sizes so that very small groups are visible before splitting
print("Unique stratification groups:")
print(tabular_df['strat_group'].value_counts())
print(f"\nTotal groups: {tabular_df['strat_group'].nunique()}")

Unique stratification groups:
strat_group
Spring_Tas_Ryegrass_Clover                                                41
Winter_Vic_Phalaris_Clover                                                32
Autumn_Tas_Ryegrass                                                       22
Winter_Vic_Ryegrass_Clover                                                21
Spring_Tas_Clover                                                         21
Winter_WA_Clover                                                          20
Winter_Tas_Ryegrass_Clover                                                18
Autumn_NSW_Lucerne                                                        15
Spring_Vic_Ryegrass_Clover                                                11
Spring_Vic_Phalaris_BarleyGrass_SilverGrass_SpearGrass_Clover_Capeweed    11
Spring_NSW_Fescue                                                         11
Summer_NSW_Fescue_CrumbWeed                                               10
Spring_Vic_Phalaris_Clover        

In [28]:
# Number of folds follows the global N_FOLDS so the 'fold' column matches the
# training loop, which iterates over range(N_FOLDS). StratifiedKFold needs at
# least 2 splits, so fall back to the previous fixed value of 5 otherwise.
n_folds = N_FOLDS if N_FOLDS >= 2 else 5
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

# One stratification label per row (Season_State_Species)
strat_labels = tabular_df['strat_group'].values

# -1 marks "not assigned yet"; every row must receive a fold in the loop below
tabular_df['fold'] = -1

for fold_idx, (train_idx, val_idx) in enumerate(skf.split(tabular_df, strat_labels)):
    # split() yields POSITIONAL indices, so write by position (iloc) rather than
    # by label (loc); this stays correct even if the index is not 0..n-1
    tabular_df.iloc[val_idx, tabular_df.columns.get_loc('fold')] = fold_idx

    print(f"\nFold {fold_idx + 1}:")
    print(f"  Train samples: {len(train_idx)}")
    print(f"  Val samples: {len(val_idx)}")

# Each row is the validation sample of exactly one fold
assert (tabular_df['fold'] >= 0).all(), "Some rows were not assigned to a fold"



Fold 1:
  Train samples: 285
  Val samples: 72

Fold 2:
  Train samples: 285
  Val samples: 72

Fold 3:
  Train samples: 286
  Val samples: 71

Fold 4:
  Train samples: 286
  Val samples: 71

Fold 5:
  Train samples: 286
  Val samples: 71




In [29]:
# Print the per-fold share of each stratification component; the proportions
# should look similar from fold to fold if stratification worked
print("Stratification verification:")

for fold in range(n_folds):
    fold_df = tabular_df[tabular_df['fold'] == fold]
    print(f"\nFold {fold + 1}:")
    for strat_column in ('Season', 'State', 'Species'):
        print(f"  {strat_column} distribution:")
        print(fold_df[strat_column].value_counts(normalize=True).round(3))

Stratification verification:

Fold 1:
  Season distribution:
Season
Winter    0.361
Spring    0.347
Autumn    0.167
Summer    0.125
Name: proportion, dtype: float64
  State distribution:
State
Tas    0.389
Vic    0.319
NSW    0.222
WA     0.069
Name: proportion, dtype: float64
  Species distribution:
Species
Ryegrass_Clover                                                0.292
Ryegrass                                                       0.167
Phalaris_Clover                                                0.111
Clover                                                         0.111
Fescue                                                         0.083
Lucerne                                                        0.056
Fescue_CrumbWeed                                               0.028
Phalaris                                                       0.028
WhiteClover                                                    0.028
Phalaris_Clover_Ryegrass_Barleygrass_Bromegrass                0.028


In [30]:
def get_fold_data(df: pd.DataFrame, fold: int):
    """
    Split a fold-annotated frame into training and validation parts.

    Args:
        df: DataFrame with a 'fold' column
        fold: Fold index whose rows become the validation set

    Returns:
        (train_df, val_df), each re-indexed from 0
    """
    in_val_fold = df['fold'] == fold

    # Everything outside the chosen fold is training data
    return (
        df[~in_val_fold].reset_index(drop=True),
        df[in_val_fold].reset_index(drop=True),
    )

In [35]:
def get_loaders(fold: int, bs: int) -> tuple[DataLoader, DataLoader]:
    """Build the train and validation DataLoaders for one cross-validation fold.

    The tabular preprocessor is created fresh and fitted on the training split
    only, then applied to the validation split, so no validation statistics
    leak into training.

    Args:
        fold: Fold index used as the validation set
        bs: Training batch size; validation uses twice this value

    Returns:
        (train_loader, val_loader)
    """
    train_df, val_df = get_fold_data(tabular_df, fold)

    print(f"Training fold {fold}:")
    print(f"  Train size: {len(train_df)}")
    print(f"  Val size: {len(val_df)}")

    # New preprocessor per fold: scaler statistics and category vocabulary
    # come from the training rows alone
    fold_preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), num_features),
            ('cat', OneHotEncoder(sparse_output=False, handle_unknown='ignore'), cat_features)
        ]
    )
    train_tabular = fold_preprocessor.fit_transform(train_df)
    val_tabular = fold_preprocessor.transform(val_df)

    def build_dataset(frame, features, transform):
        # Both splits read from the training image folder and carry targets
        return BiomassDataset(
            df=frame,
            tabular_features=features,
            target_cols=target_cols,
            img_dir=PATH_TRAIN_IMG,
            transform=transform,
            is_test=False
        )

    train_dataset = build_dataset(train_df, train_tabular, train_transform)
    val_dataset = build_dataset(val_df, val_tabular, val_transform)

    # Settings shared by both loaders: at most 8 workers, pinned memory only
    # with CUDA, and persistent workers only when workers are enabled
    shared_kwargs = dict(
        num_workers=min(NUM_WORKERS, 8),
        pin_memory=torch.cuda.is_available(),
        persistent_workers=NUM_WORKERS > 0
    )

    train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, **shared_kwargs)
    # Validation stores no gradients, so a larger batch fits
    val_loader = DataLoader(val_dataset, batch_size=bs * 2, shuffle=False, **shared_kwargs)

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Tabular features dimension: {train_tabular.shape[1]}")

    return train_loader, val_loader

## Lightning Module

In [None]:
class BiomassTeacherModel(pl.LightningModule):
    """Teacher model for biomass prediction with dual-patch image + tabular features.

    One shared timm backbone encodes the left and the right patch; the two
    feature vectors are fused, concatenated with the tabular features, and fed
    to three independent regression heads (green, clover, dead).
    """

    # Fusion strategies implemented by forward(); validated at construction time
    _FUSION_METHODS = ('concat', 'mean', 'max')

    def __init__(
        self,
        backbone_name: str = 'swinv2_tiny_window8_256',
        tabular_dim: int = 10,
        num_targets: int = 3,
        lr: float = 1e-4,
        weight_decay: float = 1e-5,
        hidden_ratio: float = 0.5,
        dropout: float = 0.2,
        fusion_method: str = 'concat',
        use_log_target: bool = True
    ):
        """
        Args:
            backbone_name: timm model name
            tabular_dim: dimension of tabular features
            num_targets: number of regression targets. Stored as a hyperparameter
                only; the model always builds exactly three heads.
            lr: learning rate
            weight_decay: optimizer weight decay
            hidden_ratio: ratio of hidden layer size relative to the combined feature dim
            dropout: dropout probability
            fusion_method: how to combine left/right features ('concat', 'mean', 'max')
            use_log_target: if True, model predicts log-transformed targets

        Raises:
            ValueError: if fusion_method is not one of the supported strategies
        """
        super().__init__()
        self.save_hyperparameters()

        # Fail fast: previously an unknown method was sized like 'mean'/'max' here
        # and only rejected at the first forward pass
        if fusion_method not in self._FUSION_METHODS:
            raise ValueError(f"Unknown fusion method: {fusion_method}")

        # Image backbone (processes each patch independently); num_classes=0
        # removes the classifier so the model returns pooled features
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=True,
            num_classes=0,
            global_pool='avg'
        )

        # Measure the backbone output width with one dummy forward pass.
        # Depends on the notebook-level SIZE constant for the input resolution.
        with torch.no_grad():
            dummy = torch.randn(1, 3, SIZE, SIZE)
            feat_dim = self.backbone(dummy).shape[1]

        self.feat_dim = feat_dim
        self.fusion_method = fusion_method
        self.use_log_target = use_log_target

        # 'concat' keeps both patch vectors side by side; 'mean'/'max' reduce them to one
        if self.fusion_method == 'concat':
            self.combined_dim = feat_dim * 2 + tabular_dim
        else:  # mean or max
            self.combined_dim = feat_dim + tabular_dim

        # Hidden width scales with the input width but never drops below 32
        hidden_size = max(32, int(self.combined_dim * hidden_ratio))

        def make_head():
            # Two-layer MLP with a linear output: raw (possibly log-space) regression value
            return nn.Sequential(
                nn.Linear(self.combined_dim, hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
                nn.Linear(hidden_size, 1)
                # Deliberately no Softplus or other output activation
            )

        # Attribute names and creation order are kept stable: they define the
        # checkpoint keys and the order in which seeded initial weights are drawn
        self.head_green = make_head()
        self.head_clover = make_head()
        self.head_dead = make_head()

        logger.info(f"Model initialized: backbone={backbone_name}, feat_dim={feat_dim}, "
                   f"combined_dim={self.combined_dim}, fusion={fusion_method}, "
                   f"use_log_target={use_log_target}")

    def forward(self, batch: dict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            batch: dict with 'left_image', 'right_image', 'tabular'

        Returns:
            (green, clover, dead) predictions as separate tensors of shape [B]
            Note: predictions are in log-space if use_log_target=True

        Raises:
            ValueError: if self.fusion_method was changed to an unsupported value
        """
        # Extract features from each patch with the shared backbone
        left_feat = self.backbone(batch['left_image'])   # [B, D]
        right_feat = self.backbone(batch['right_image'])  # [B, D]

        # Fuse image features based on method
        if self.fusion_method == 'concat':
            img_feat = torch.cat([left_feat, right_feat], dim=1)  # [B, 2*D]
        elif self.fusion_method == 'mean':
            img_feat = (left_feat + right_feat) / 2  # [B, D]
        elif self.fusion_method == 'max':
            img_feat = torch.maximum(left_feat, right_feat)  # [B, D]
        else:
            raise ValueError(f"Unknown fusion method: {self.fusion_method}")

        # Combine with tabular features
        combined = torch.cat([img_feat, batch['tabular']], dim=1)  # [B, combined_dim]

        # Predict each target (raw predictions, no activation); squeeze [B, 1] -> [B]
        green = self.head_green(combined).squeeze(1)
        clover = self.head_clover(combined).squeeze(1)
        dead = self.head_dead(combined).squeeze(1)

        return green, clover, dead

    def compute_loss(self, pred: tuple, target: torch.Tensor) -> torch.Tensor:
        """
        Sum of per-target MSE losses.

        Args:
            pred: (green, clover, dead) predictions
            target: [B, 3] ground truth targets in column order [clover, dead, green]

        Returns:
            Scalar loss: MSE(green) + MSE(clover) + MSE(dead)
        """
        green_pred, clover_pred, dead_pred = pred
        # Target columns are read by position, and their order differs from the
        # order of the prediction tuple
        clover_true = target[:, 0]  # Dry_Clover_g
        dead_true = target[:, 1]    # Dry_Dead_g
        green_true = target[:, 2]   # Dry_Green_g

        loss_green = F.mse_loss(green_pred, green_true)
        loss_clover = F.mse_loss(clover_pred, clover_true)
        loss_dead = F.mse_loss(dead_pred, dead_true)

        total_loss = loss_green + loss_clover + loss_dead

        return total_loss

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        """Compute and log the training loss for one batch."""
        pred = self(batch)
        loss = self.compute_loss(pred, batch['targets'])

        # batch_size is passed explicitly because the batch is a dict
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True,
                 batch_size=batch['targets'].size(0))

        return loss

    def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        """Compute the validation loss for one batch; logged as an epoch aggregate ('val_loss')."""
        pred = self(batch)
        loss = self.compute_loss(pred, batch['targets'])

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True,
                 batch_size=batch['targets'].size(0))

        return loss

    def predict_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        """
        Prediction with proper post-processing

        Returns:
            [B, 3] predictions in column order [clover, dead, green], in the
            original scale (not log-transformed), clamped to [0, inf)
        """
        green, clover, dead = self(batch)

        # Stack in the same column order as the targets
        preds = torch.stack([clover, dead, green], dim=1)  # [B, 3]

        # If using log targets, convert back to original scale
        if self.use_log_target:
            preds = torch.expm1(preds)  # expm1(x) = exp(x) - 1, inverse of log1p

        # Clamp to ensure non-negative values
        preds = torch.clamp(preds, min=0.0)

        return preds

    def configure_optimizers(self):
        """AdamW with cosine annealing over the trainer's max_epochs, down to 1% of the base LR."""
        optimizer = AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )

        scheduler = CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=self.hparams.lr * 0.01
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch'  # FIXME: 'step' if needed
            }
        }

## Folds Training

In [None]:
# Train one model per cross-validation fold and collect each fold's best validation loss
fold_results = []

for fold_id in range(N_FOLDS):
    train_loader, val_loader = get_loaders(fold=fold_id, bs=BATCH_SIZE)

    model = BiomassTeacherModel(
        backbone_name=MODEL,
        # The one-hot width depends on the categories seen in this fold's training split
        tabular_dim=train_loader.dataset.tabular_features.shape[1],
        num_targets=len(target_cols),
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        hidden_ratio=HIDDEN_RATIO,
        dropout=DROPOUT_RATE,
        fusion_method=FUSION_METHOD,
        use_log_target=USE_LOG_TARGET
    )

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=os.path.join(CHECKPOINTS_DIR, f'fold{fold_id}'),
        # The '{epoch...}' / '{val_loss...}' part is a template filled in by Lightning,
        # so it is deliberately NOT an f-string
        filename=f'{DESCRIPTION_FULL}-fold{fold_id}' + '-{epoch:02d}-{val_loss:.4f}',
        save_top_k=3,  # Save top 3 instead of 1
        mode='min'
    )

    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=5,
        mode='min',
        verbose=True
    )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Logger
    wandb_logger = WandbLogger(
        project=PROJECT_NAME,
        name=f'{DESCRIPTION_FULL}-fold{fold_id}',
        log_model='all'
    )

    # Trainer
    trainer = pl.Trainer(
        max_epochs=EPOCHS,
        accelerator=DEVICE.type,
        precision='16-mixed' if torch.cuda.is_available() else 32,
        accumulate_grad_batches=GRAD_ACCUM,
        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor],
        logger=wandb_logger,
        log_every_n_steps=1,
        gradient_clip_val=1.0,
        enable_progress_bar=True
    )

    try:
        # Train
        trainer.fit(model, train_loader, val_loader)

        # Load best checkpoint (lowest val_loss seen during this fold)
        best_model_path = checkpoint_callback.best_model_path
        logger.info(f"Loading best model from: {best_model_path}")
        best_model = BiomassTeacherModel.load_from_checkpoint(best_model_path)

        # Re-evaluate the best checkpoint on the validation set
        val_result = trainer.validate(best_model, val_loader, verbose=False)
        fold_results.append({
            'fold': fold_id,
            'val_loss': val_result[0]['val_loss']
        })
    finally:
        # Always close the W&B run, even when the fold is interrupted or fails;
        # otherwise the next WandbLogger silently reuses the still-open run
        wandb_logger.experiment.finish()

Training fold 0:
  Train size: 285
  Val size: 72
Train batches: 18
Val batches: 3
Tabular features dimension: 21
[38;2;161;247;255m
[2025-12-11 01:31:15]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\loggers\wandb.py:400: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:881: Checkpoint directory C:\_GitHub\CSIRO-Image2Biomass-Prediction\notebooks\kaggle\checkpoints\teacher\fold0 exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\utilities\model_summary\model_summary.py:242: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name        | Type     

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 3.395


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 1.404 >= min_delta = 0.0. New best score: 1.991


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.270 >= min_delta = 0.0. New best score: 1.721


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.187 >= min_delta = 0.0. New best score: 1.534


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.130 >= min_delta = 0.0. New best score: 1.404


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.245 >= min_delta = 0.0. New best score: 1.159


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.125 >= min_delta = 0.0. New best score: 1.034


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.021 >= min_delta = 0.0. New best score: 1.013


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.046 >= min_delta = 0.0. New best score: 0.967


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


Training fold 1:
  Train size: 285
  Val size: 72
Train batches: 18
Val batches: 3
Tabular features dimension: 21
[38;2;161;247;255m
[2025-12-11 01:48:34]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\loggers\wandb.py:400: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\utilities\model_summary\model_summary.py:242: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name        | Type              | Params | Mode  | FLOPs
------------------------------------------------------------------
0 | backbone    | SwinTransformerV2 | 27.6 M | train | 0    
1 | head_green  | Sequential        | 311 K  | train | 0    
2 | head_clover | Sequentia

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 2.524


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.404 >= min_delta = 0.0. New best score: 2.120


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.477 >= min_delta = 0.0. New best score: 1.644


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.113 >= min_delta = 0.0. New best score: 1.530


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.411 >= min_delta = 0.0. New best score: 1.119


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.046 >= min_delta = 0.0. New best score: 1.073


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 1.053


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.053 >= min_delta = 0.0. New best score: 1.000


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 0.987


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.061 >= min_delta = 0.0. New best score: 0.926


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 0.926. Signaling Trainer to stop.
`Trainer.fit` stopped: `max_epochs=20` reached.


Training fold 2:
  Train size: 286
  Val size: 71
Train batches: 18
Val batches: 3
Tabular features dimension: 21
[38;2;161;247;255m
[2025-12-11 02:03:31]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\loggers\wandb.py:400: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\utilities\model_summary\model_summary.py:242: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name        | Type              | Params | Mode  | FLOPs
------------------------------------------------------------------
0 | backbone    | SwinTransformerV2 | 27.6 M | train | 0    
1 | head_green  | Sequential        | 311 K  | train | 0    
2 | head_clover | Sequentia

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 3.244


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.939 >= min_delta = 0.0. New best score: 2.305


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.465 >= min_delta = 0.0. New best score: 1.840


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.166 >= min_delta = 0.0. New best score: 1.675


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.219 >= min_delta = 0.0. New best score: 1.456


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.141 >= min_delta = 0.0. New best score: 1.315


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.022 >= min_delta = 0.0. New best score: 1.293


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.170 >= min_delta = 0.0. New best score: 1.124


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 1.113


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 1.094


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.051 >= min_delta = 0.0. New best score: 1.042


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


Training fold 3:
  Train size: 286
  Val size: 71
Train batches: 18
Val batches: 3
Tabular features dimension: 21
[38;2;161;247;255m
[2025-12-11 02:18:18]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\loggers\wandb.py:400: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\utilities\model_summary\model_summary.py:242: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name        | Type              | Params | Mode  | FLOPs
------------------------------------------------------------------
0 | backbone    | SwinTransformerV2 | 27.6 M | train | 0    
1 | head_green  | Sequential        | 311 K  | train | 0    
2 | head_clover | Sequentia

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 4.125


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 1.392 >= min_delta = 0.0. New best score: 2.733


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.394 >= min_delta = 0.0. New best score: 2.339


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.581 >= min_delta = 0.0. New best score: 1.758


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.067 >= min_delta = 0.0. New best score: 1.691


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.087 >= min_delta = 0.0. New best score: 1.604


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.047 >= min_delta = 0.0. New best score: 1.557


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.152 >= min_delta = 0.0. New best score: 1.405


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.072 >= min_delta = 0.0. New best score: 1.333


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.032 >= min_delta = 0.0. New best score: 1.301


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.078 >= min_delta = 0.0. New best score: 1.223


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 1.221
`Trainer.fit` stopped: `max_epochs=20` reached.


Training fold 4:
  Train size: 286
  Val size: 71
Train batches: 18
Val batches: 3
Tabular features dimension: 21
[38;2;161;247;255m
[2025-12-11 02:33:13]
INFO: Model initialized: backbone=swinv2_tiny_window8_256, feat_dim=768, combined_dim=789, fusion=mean, use_log_target=True[0m


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\loggers\wandb.py:400: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\utilities\model_summary\model_summary.py:242: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name        | Type              | Params | Mode  | FLOPs
------------------------------------------------------------------
0 | backbone    | SwinTransformerV2 | 27.6 M | train | 0    
1 | head_green  | Sequential        | 311 K  | train | 0    
2 | head_clover | Sequentia

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
c:\_GitHub\CSIRO-Image2Biomass-Prediction\.venv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 2.951


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 1.061 >= min_delta = 0.0. New best score: 1.890


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.219 >= min_delta = 0.0. New best score: 1.671


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.165 >= min_delta = 0.0. New best score: 1.506


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.026 >= min_delta = 0.0. New best score: 1.480


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 1.474


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.080 >= min_delta = 0.0. New best score: 1.394


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.108 >= min_delta = 0.0. New best score: 1.286


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 1.269


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 1.251


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.180 >= min_delta = 0.0. New best score: 1.071


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.020 >= min_delta = 0.0. New best score: 1.051


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 1.046
`Trainer.fit` stopped: `max_epochs=20` reached.


In [None]:
# Summarize the per-fold validation losses collected in `fold_results`.
print("Training Summary")
print()
results_df = pd.DataFrame(fold_results)
if results_df.empty:
    # Nothing was appended to fold_results, so there is no 'val_loss' column
    # to aggregate; say so instead of failing with a KeyError.
    print("No fold results recorded.")
else:
    print(results_df)
    mean_val_loss = results_df['val_loss'].mean()
    if len(results_df) > 1:
        print(f"Mean Val Loss: {mean_val_loss:.4f} ± {results_df['val_loss'].std():.4f}")
    else:
        # The sample standard deviation of a single value is NaN, so report
        # the lone validation loss without a spread.
        print(f"Mean Val Loss: {mean_val_loss:.4f}")