# Distillation. Teacher

## Imports

In [92]:
import os
import pandas as pd
import numpy as np
from typing import cast
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from PIL import Image

import torch
import pytorch_lightning as pl
import cv2
import albumentations as A
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2

import warnings
from notebooks_config import setup_logging, CustomLogger

# Silence FutureWarning messages so they do not clutter cell output.
warnings.simplefilter(action='ignore', category=FutureWarning)
# Report the PyTorch version and the name of CUDA device 0 ('CPU' when CUDA is unavailable).
print(f"PyTorch: {torch.__version__}")
print(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

PyTorch: 2.9.1+cu128
Device: NVIDIA GeForce RTX 5050 Laptop GPU


In [42]:
# Build the notebook logger. `cast` is a no-op at runtime; it only tells the
# type checker that the returned object is a CustomLogger (which has `.success`).
logger = cast(
    CustomLogger,
    setup_logging(include_function=False, full_color=True),
)
logger.success("Logging configuration test completed.")

[38;2;105;254;105m
[2025-12-10 23:16:41]
SUCCESS: Logging configured successfully âœ…[0m
[38;2;105;254;105m
[2025-12-10 23:16:41]
SUCCESS: Logging configuration test completed.[0m


## Data Loading and Preprocessing

In [16]:
# Dataset root plus the CSV files and image folders beneath it.
PATH_DATA = './kaggle/input/csiro-biomass'
PATH_TRAIN_CSV, PATH_TEST_CSV = (
    os.path.join(PATH_DATA, file_name) for file_name in ('train.csv', 'test.csv')
)
PATH_TRAIN_IMG, PATH_TEST_IMG = (
    os.path.join(PATH_DATA, folder_name) for folder_name in ('train', 'test')
)

# Read the long-format training table and remove the unneeded targets.
df = (
    pd.read_csv(PATH_TRAIN_CSV)
    .loc[lambda frame: ~frame['target_name'].isin(['Dry_Total_g', 'GDM_g'])]
)
print(f"Dataset size: {df.shape}")
display(df.head())

Dataset size: (1071, 9)


Unnamed: 0,sample_id,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,target_name,target
0,ID1011485656__Dry_Clover_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Clover_g,0.0
1,ID1011485656__Dry_Dead_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Dead_g,31.9984
2,ID1011485656__Dry_Green_g,train/ID1011485656.jpg,2015/9/4,Tas,Ryegrass_Clover,0.62,4.6667,Dry_Green_g,16.275
5,ID1012260530__Dry_Clover_g,train/ID1012260530.jpg,2015/4/1,NSW,Lucerne,0.55,16.0,Dry_Clover_g,0.0
6,ID1012260530__Dry_Dead_g,train/ID1012260530.jpg,2015/4/1,NSW,Lucerne,0.55,16.0,Dry_Dead_g,0.0


In [37]:
# Pivot the long table (one row per image/target pair) into a wide table:
# one row per image, with one column per target name.
tabular_df = df.pivot_table(
    index=['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI'],
    columns='target_name',
    values='target',
    aggfunc='first',
).reset_index()
tabular_df.columns.name = None  # drop the leftover 'target_name' label on the columns axis

# pivot_table groups on the index keys: rows whose keys contain NaN are dropped,
# and an image whose metadata differs between its target rows would produce
# several output rows. Fail loudly here rather than train on a silently
# truncated or duplicated table.
assert tabular_df['image_path'].is_unique, "pivot produced more than one row per image"
assert len(tabular_df) == df['image_path'].nunique(), \
    "pivot dropped images (NaN in one of the metadata columns?)"

print(tabular_df.shape)
display(tabular_df.head())

(357, 9)
               image_path Sampling_Date State            Species  \
0  train/ID1011485656.jpg      2015/9/4   Tas    Ryegrass_Clover   
1  train/ID1012260530.jpg      2015/4/1   NSW            Lucerne   
2  train/ID1025234388.jpg      2015/9/1    WA  SubcloverDalkeith   
3  train/ID1028611175.jpg     2015/5/18   Tas           Ryegrass   
4  train/ID1035947949.jpg     2015/9/11   Tas           Ryegrass   

   Height_Ave_cm  Pre_GSHH_NDVI  Dry_Clover_g  Dry_Dead_g  Dry_Green_g  
0         4.6667           0.62        0.0000     31.9984      16.2750  
1        16.0000           0.55        0.0000      0.0000       7.6000  
2         1.0000           0.38        6.0500      0.0000       0.0000  
3         5.0000           0.66        0.0000     30.9703      24.2376  
4         3.5000           0.54        0.4343     23.2239      10.5261  


In [30]:
# List the column names of the pivoted table as a plain Python list.
print(tabular_df.columns.tolist())

['image_path', 'Sampling_Date', 'State', 'Species', 'Height_Ave_cm', 'Pre_GSHH_NDVI', 'Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g']


In [31]:
# Column roles: prediction targets, numeric inputs, categorical inputs.
target_cols = [
    'Dry_Clover_g',
    'Dry_Dead_g',
    'Dry_Green_g',
]
num_features = [
    'Height_Ave_cm',
    'Pre_GSHH_NDVI',
]
cat_features = [
    'Species',
    'State',
]

In [39]:
# Tabular preprocessing:
#   'num' - standardise the numeric features
#   'cat' - one-hot encode the categorical features into a dense array,
#           ignoring (rather than raising on) unknown categories
preprocessor = ColumnTransformer([
    ('num', StandardScaler(), num_features),
    ('cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False), cat_features),
])

In [33]:
# Fit the preprocessor on the pivoted table and transform it in a single step.
# NOTE(review): scaler/encoder statistics are learned from every row of tabular_df;
# if a train/validation split happens later, fit on the training rows only — confirm.
tabular_data = preprocessor.fit_transform(tabular_df)

In [36]:
# Sanity check: shape of the encoded feature matrix and its first five rows.
print(tabular_data.shape)
display(tabular_data[:5])

(357, 21)


array([[-0.28520388, -0.24631873,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
         0.        ],
       [ 0.81823967, -0.70706013,  0.        ,  0.        ,  0.        ,
         1.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  1.        ,  0.        ,  0.        ,
         0.        ],
       [-0.64220462, -1.82600352,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  1.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         1.        ],
       [-0.25275281,  0.01696207,  0.        ,  0.        

## Dataset

In [91]:
class BiomassDataset(Dataset):
    """Dataset for biomass prediction with image + tabular features.

    Each item reads one image from disk, splits it down the middle of its width
    into a left and a right patch, runs both patches through the transform
    pipeline, and pairs them with the row's precomputed tabular feature vector
    (plus the target values unless ``is_test`` is set).
    """

    def __init__(
        self,
        df: pd.DataFrame,
        tabular_features: np.ndarray,
        target_cols: list[str],
        img_dir: str,
        transform: transforms.Compose | None = None,
        is_test: bool = False
    ):
        """
        Args:
            df: DataFrame with one row per image. Must contain an ``image_path``
                column and, unless ``is_test`` is True, every column listed in
                ``target_cols``.
            tabular_features: Preprocessed tabular features, row-aligned with
                ``df`` (shape: [n_samples, n_features])
            target_cols: List of target column names
            img_dir: Root directory for images
            transform: torchvision transform pipeline applied to each PIL patch;
                when None, patches are only converted with ``ToTensor``
            is_test: If True, targets are not expected in df

        Raises:
            AssertionError: If ``df`` and ``tabular_features`` differ in length.
        """
        # Positional indexing is used in __getitem__, so make the index 0..n-1
        # to keep DataFrame rows aligned with tabular_features rows.
        self.df = df.reset_index(drop=True)
        self.tabular_features = tabular_features
        self.target_cols = target_cols
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test

        assert len(self.df) == len(self.tabular_features), \
            f"DataFrame length {len(self.df)} != tabular features length {len(self.tabular_features)}"

    def __len__(self) -> int:
        """Number of images (rows of ``df``)."""
        return len(self.df)

    def __getitem__(self, idx: int) -> dict:
        """
        Returns:
            dict with keys:
                - 'left_image': tensor [C, H, W]
                - 'right_image': tensor [C, H, W]
                - 'tabular': tensor [n_features]
                - 'targets': tensor [n_targets] (if not test)
                - 'image_id': str (image file name without its extension)

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        row = self.df.iloc[idx]
        rel_path = row['image_path']

        # The CSV stores paths such as 'train/<name>.jpg'; img_dir already points
        # at the split folder, so the split prefix is removed before joining.
        img_path = os.path.join(self.img_dir, rel_path.replace('train/', '').replace('test/', ''))
        image = cv2.imread(img_path)

        # cv2.imread reports failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Cannot load image: {img_path}")

        # OpenCV decodes to BGR; PIL/torchvision expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Split into left and right patches along the width axis
        # (the original author notes source images are [1000, 2000, 3],
        # giving two [1000, 1000, 3] patches).
        mid_w = image.shape[1] // 2
        left_patch = image[:, :mid_w, :]
        right_patch = image[:, mid_w:, :]

        # Explicit None check: a configured pipeline must never be skipped
        # just because the transform object happens to evaluate as falsy.
        transform = self.transform if self.transform is not None else transforms.ToTensor()

        # torchvision transforms operate on PIL images. The pipeline is called
        # once per patch, left first, so random augmentations are drawn
        # independently for each side.
        left_tensor = transform(Image.fromarray(left_patch))
        right_tensor = transform(Image.fromarray(right_patch))

        # Get tabular features
        tabular = torch.tensor(self.tabular_features[idx], dtype=torch.float32)

        # Prepare output
        output = {
            'left_image': left_tensor,
            'right_image': right_tensor,
            'tabular': tabular,
            # File name without directory or extension (works for any extension,
            # not only '.jpg').
            'image_id': os.path.splitext(os.path.basename(rel_path))[0]
        }

        # Add targets if not test
        if not self.is_test:
            output['targets'] = torch.tensor(
                row[self.target_cols].to_numpy(dtype=np.float32),
                dtype=torch.float32
            )

        return output

In [83]:
def calculate_img_data_stat(df: pd.DataFrame, img_dir: str | None = None):
    """Calculate per-channel mean and std of image data for normalization.

    Pixel values are scaled to [0, 1] and channels are in RGB order. Both
    statistics are accumulated over every pixel of every image, so the result
    is the true dataset-level mean and standard deviation. (Averaging the
    per-image standard deviations, as is often done, underestimates the spread
    whenever image brightness varies between images.)

    Args:
        df: DataFrame with an ``image_path`` column.
        img_dir: Directory that contains the image files. Defaults to the
            notebook-level ``PATH_TRAIN_IMG``.

    Returns:
        Tuple ``(mean, std)`` of two float arrays of shape [3].

    Raises:
        FileNotFoundError: If an image cannot be read.
        ValueError: If ``df`` lists no images.
    """
    if img_dir is None:
        img_dir = PATH_TRAIN_IMG

    n_pixels = 0
    channel_sum = np.zeros(3, dtype=np.float64)
    channel_sq_sum = np.zeros(3, dtype=np.float64)

    loader = tqdm(df['image_path'], desc="Calculating image stats")

    for img_path in loader:
        full_path = os.path.join(img_dir, img_path.replace('train/', '').replace('test/', ''))
        image = cv2.imread(full_path)

        # cv2.imread returns None on failure; without this check cvtColor
        # would fail with an error that does not mention the file.
        if image is None:
            raise FileNotFoundError(f"Cannot load image: {full_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pixels = image.reshape(-1, 3).astype(np.float64) / 255.0  # Normalize to [0, 1]

        n_pixels += pixels.shape[0]
        channel_sum += pixels.sum(axis=0)
        channel_sq_sum += np.square(pixels).sum(axis=0)

    if n_pixels == 0:
        raise ValueError("Cannot compute image statistics: no images were provided")

    mean = channel_sum / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    variance = np.maximum(channel_sq_sum / n_pixels - np.square(mean), 0.0)
    std = np.sqrt(variance)

    return mean, std

In [None]:
# Per-channel image statistics over the images listed in tabular_df.
# NOTE(review): the transform pipelines in this notebook normalize with hard-coded
# constants, not with train_mean/train_std — confirm which set is intended.
train_mean, train_std = calculate_img_data_stat(tabular_df)
print(f"Train Image Mean: {train_mean}, Std: {train_std}")

In [93]:
# Each patch is 1000x1000, resize to 768x768 for vision transformers
size = 768

# Standard ImageNet channel statistics (RGB, inputs scaled to [0, 1]).
# Defined once so the train and validation pipelines cannot drift apart.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, random geometric and colour augmentation,
# then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

# Validation pipeline: deterministic — resize, tensor conversion, normalization.
val_transform = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])

In [94]:
# Create the training dataset: pivoted table + encoded tabular features + training images,
# with the augmenting train_transform applied to each patch.
train_dataset = BiomassDataset(
    df=tabular_df,
    tabular_features=tabular_data,
    target_cols=target_cols,
    img_dir=PATH_TRAIN_IMG,
    transform=train_transform,
    is_test=False
)

# Smoke test: fetch the first item and print the shape/value of every field.
# train_transform contains random augmentations, so the image tensors differ between runs.
sample = train_dataset[0]
print(f"Left image shape: {sample['left_image'].shape}")
print(f"Right image shape: {sample['right_image'].shape}")
print(f"Tabular shape: {sample['tabular'].shape}")
print(f"Targets shape: {sample['targets'].shape}")
print(f"Image ID: {sample['image_id']}")
print(f"Target values: {sample['targets']}")

Left image shape: torch.Size([3, 768, 768])
Right image shape: torch.Size([3, 768, 768])
Tabular shape: torch.Size([21])
Targets shape: torch.Size([3])
Image ID: ID1011485656
Target values: tensor([ 0.0000, 31.9984, 16.2750])
