## Baseline Submission (End-to-End ResNet and ViT)
- `Input`: Images
- `Output`: Biomass predictions (grams) for each of the five targets: `Dry_Clover_g`, `Dry_Dead_g`, `Dry_Green_g`, `Dry_Total_g` and `GDM_g`.

In [43]:
import polars as pl
from PIL import Image
from pathlib import Path

import torch
from torchvision.models import (
    resnet34, ResNet34_Weights, 
    resnet50, ResNet50_Weights, 
    vit_b_16, ViT_B_16_Weights,
    vit_b_32, ViT_B_32_Weights
)

# Project root, and the competition folder that holds train.csv (see the loading cell).
ROOT_PATH = Path(".")
DATASET_PATH = ROOT_PATH.joinpath("csiro-biomass")

**Loading the Dataset**

In [30]:
# Load the long-format training table: one row per (image, target_name) pair.
train_csv_path = DATASET_PATH.joinpath("train.csv")
train_df = pl.read_csv(train_csv_path)
train_df.head()

sample_id,image_path,Sampling_Date,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,target_name,target
str,str,str,str,str,f64,f64,str,f64
"""ID1011485656__Dry_Clover_g""","""train/ID1011485656.jpg""","""2015/9/4""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Clover_g""",0.0
"""ID1011485656__Dry_Dead_g""","""train/ID1011485656.jpg""","""2015/9/4""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Dead_g""",31.9984
"""ID1011485656__Dry_Green_g""","""train/ID1011485656.jpg""","""2015/9/4""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Green_g""",16.2751
"""ID1011485656__Dry_Total_g""","""train/ID1011485656.jpg""","""2015/9/4""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Total_g""",48.2735
"""ID1011485656__GDM_g""","""train/ID1011485656.jpg""","""2015/9/4""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""GDM_g""",16.275


**Data Preprocessing**

In [31]:
# Parse the "YYYY/M/D" sampling-date strings into a polars Date column (same name).
train_df_dt = train_df.with_columns(pl.col("Sampling_Date").str.to_date())

# Break the date into separate day / month / year feature columns.
sampling_date = pl.col("Sampling_Date").dt
train_df_dt = train_df_dt.with_columns(
    sampling_date.day().alias("Sample_Day"),
    sampling_date.month().alias("Sample_Month"),
    sampling_date.year().alias("Sample_Year"),
)

# Southern-hemisphere (Australian) seasons: Dec-Feb is summer, Jun-Aug is winter.
# Any month outside 1-12 (or a null month) falls through to "unknown".
month = pl.col("Sample_Month")
train_df_seasons = train_df_dt.with_columns(
    pl.when(month.is_in([12, 1, 2])).then(pl.lit("Summer"))
    .when(month.is_in([3, 4, 5])).then(pl.lit("Autumn"))
    .when(month.is_in([6, 7, 8])).then(pl.lit("Winter"))
    .when(month.is_in([9, 10, 11])).then(pl.lit("Spring"))
    .otherwise(pl.lit("unknown"))
    .alias("Sample_Season")
)

In [32]:
# sample_id looks like "<image_id>__<target_name>"; keep the image part as a join key.
processed_df = (
    train_df_seasons
    .with_columns(
        image_id=pl.col("sample_id").str.split("__").list.get(0),
    )
    # sample_id is now redundant with (image_id, target_name), and the raw date
    # has already been decomposed into day / month / year / season.
    .drop(["sample_id", "Sampling_Date"])
)
processed_df

image_path,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,target_name,target,Sample_Day,Sample_Month,Sample_Year,Sample_Season,image_id
str,str,str,f64,f64,str,f64,i8,i8,i32,str,str
"""train/ID1011485656.jpg""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Clover_g""",0.0,4,9,2015,"""Spring""","""ID1011485656"""
"""train/ID1011485656.jpg""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Dead_g""",31.9984,4,9,2015,"""Spring""","""ID1011485656"""
"""train/ID1011485656.jpg""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Green_g""",16.2751,4,9,2015,"""Spring""","""ID1011485656"""
"""train/ID1011485656.jpg""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""Dry_Total_g""",48.2735,4,9,2015,"""Spring""","""ID1011485656"""
"""train/ID1011485656.jpg""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,"""GDM_g""",16.275,4,9,2015,"""Spring""","""ID1011485656"""
…,…,…,…,…,…,…,…,…,…,…,…
"""train/ID983582017.jpg""","""WA""","""Ryegrass""",0.64,9.0,"""Dry_Clover_g""",0.0,1,9,2015,"""Spring""","""ID983582017"""
"""train/ID983582017.jpg""","""WA""","""Ryegrass""",0.64,9.0,"""Dry_Dead_g""",0.0,1,9,2015,"""Spring""","""ID983582017"""
"""train/ID983582017.jpg""","""WA""","""Ryegrass""",0.64,9.0,"""Dry_Green_g""",40.94,1,9,2015,"""Spring""","""ID983582017"""
"""train/ID983582017.jpg""","""WA""","""Ryegrass""",0.64,9.0,"""Dry_Total_g""",40.94,1,9,2015,"""Spring""","""ID983582017"""


In [35]:
# Pivoting the Table for Unifying the Rows of each Image: the long table has one
# row per (image, target_name); after the pivot each target_name is its own column.
indexes = [
    "image_id", "image_path", "State", "Species", "Pre_GSHH_NDVI", "Height_Ave_cm",
    "Sample_Day", "Sample_Month", "Sample_Year", "Sample_Season",
]

wide_df = processed_df.pivot(on="target_name", values="target", index=indexes)

# Every index column is meant to be per-image metadata. If any of it differs
# between an image's target rows, the pivot silently emits several rows for the
# same image; fail loudly instead of training on duplicated images.
assert wide_df["image_id"].n_unique() == wide_df.height, "image_id is not unique after pivot"
wide_df

image_id,image_path,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,Sample_Day,Sample_Month,Sample_Year,Sample_Season,Dry_Clover_g,Dry_Dead_g,Dry_Green_g,Dry_Total_g,GDM_g
str,str,str,str,f64,f64,i8,i8,i32,str,f64,f64,f64,f64,f64
"""ID1011485656""","""train/ID1011485656.jpg""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,4,9,2015,"""Spring""",0.0,31.9984,16.2751,48.2735,16.275
"""ID1012260530""","""train/ID1012260530.jpg""","""NSW""","""Lucerne""",0.55,16.0,1,4,2015,"""Autumn""",0.0,0.0,7.6,7.6,7.6
"""ID1025234388""","""train/ID1025234388.jpg""","""WA""","""SubcloverDalkeith""",0.38,1.0,1,9,2015,"""Spring""",6.05,0.0,0.0,6.05,6.05
"""ID1028611175""","""train/ID1028611175.jpg""","""Tas""","""Ryegrass""",0.66,5.0,18,5,2015,"""Autumn""",0.0,30.9703,24.2376,55.2079,24.2376
"""ID1035947949""","""train/ID1035947949.jpg""","""Tas""","""Ryegrass""",0.54,3.5,11,9,2015,"""Spring""",0.4343,23.2239,10.5261,34.1844,10.9605
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""ID975115267""","""train/ID975115267.jpg""","""WA""","""Clover""",0.73,3.0,8,7,2015,"""Winter""",40.03,0.0,0.8,40.83,40.83
"""ID978026131""","""train/ID978026131.jpg""","""Tas""","""Clover""",0.83,3.1667,4,9,2015,"""Spring""",24.6445,4.1948,12.0601,40.8994,36.7046
"""ID980538882""","""train/ID980538882.jpg""","""NSW""","""Phalaris""",0.69,29.0,24,2,2015,"""Summer""",0.0,1.1457,91.6543,92.8,91.6543
"""ID980878870""","""train/ID980878870.jpg""","""WA""","""Clover""",0.74,2.0,8,7,2015,"""Winter""",32.3575,0.0,2.0325,34.39,34.39


**Feature Engineering**

In [36]:
# Binary flag: does the species name mention clover in any casing
# (e.g. "Clover", "Ryegrass_Clover", "SubcloverDalkeith")?
wide_df = wide_df.with_columns(
    pl.col("Species")
    .str.contains("(?i)clover")
    # str.contains yields null for a null Species; map that to False so the flag
    # stays strictly boolean (matching a when/otherwise(False) construction).
    .fill_null(False)
    .alias("has_Clover")
)

# Transforming the Categorical Cols (currently disabled)
# wide_df = wide_df.to_dummies(columns=["State", "Sample_Season"])

wide_df

image_id,image_path,State,Species,Pre_GSHH_NDVI,Height_Ave_cm,Sample_Day,Sample_Month,Sample_Year,Sample_Season,Dry_Clover_g,Dry_Dead_g,Dry_Green_g,Dry_Total_g,GDM_g,has_Clover
str,str,str,str,f64,f64,i8,i8,i32,str,f64,f64,f64,f64,f64,bool
"""ID1011485656""","""train/ID1011485656.jpg""","""Tas""","""Ryegrass_Clover""",0.62,4.6667,4,9,2015,"""Spring""",0.0,31.9984,16.2751,48.2735,16.275,true
"""ID1012260530""","""train/ID1012260530.jpg""","""NSW""","""Lucerne""",0.55,16.0,1,4,2015,"""Autumn""",0.0,0.0,7.6,7.6,7.6,false
"""ID1025234388""","""train/ID1025234388.jpg""","""WA""","""SubcloverDalkeith""",0.38,1.0,1,9,2015,"""Spring""",6.05,0.0,0.0,6.05,6.05,true
"""ID1028611175""","""train/ID1028611175.jpg""","""Tas""","""Ryegrass""",0.66,5.0,18,5,2015,"""Autumn""",0.0,30.9703,24.2376,55.2079,24.2376,false
"""ID1035947949""","""train/ID1035947949.jpg""","""Tas""","""Ryegrass""",0.54,3.5,11,9,2015,"""Spring""",0.4343,23.2239,10.5261,34.1844,10.9605,false
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""ID975115267""","""train/ID975115267.jpg""","""WA""","""Clover""",0.73,3.0,8,7,2015,"""Winter""",40.03,0.0,0.8,40.83,40.83,true
"""ID978026131""","""train/ID978026131.jpg""","""Tas""","""Clover""",0.83,3.1667,4,9,2015,"""Spring""",24.6445,4.1948,12.0601,40.8994,36.7046,true
"""ID980538882""","""train/ID980538882.jpg""","""NSW""","""Phalaris""",0.69,29.0,24,2,2015,"""Summer""",0.0,1.1457,91.6543,92.8,91.6543,false
"""ID980878870""","""train/ID980878870.jpg""","""WA""","""Clover""",0.74,2.0,8,7,2015,"""Winter""",32.3575,0.0,2.0325,34.39,34.39,true


**Data Preparation**

In [None]:
class CSIRODataset(torch.utils.data.Dataset):
    """Custom Dataset for building the PyTorch friendly data pipeline.

    Each item pairs one RGB image with its five biomass targets
    (``Dry_Clover_g``, ``Dry_Dead_g``, ``Dry_Green_g``, ``Dry_Total_g``, ``GDM_g``).
    """

    TARGET_COLS = ["Dry_Clover_g", "Dry_Dead_g", "Dry_Green_g", "Dry_Total_g", "GDM_g"]

    def __init__(self, dataframe: pl.DataFrame, transforms=None, image_root=None) -> None:
        """
        args:
            dataframe: the unique wide-format dataframe (one row per image) with
                ``image_id``, ``image_path`` and the five target columns.
            transforms: optional callable applied to the loaded PIL image
                (e.g. a torchvision preprocessing pipeline). It must return a
                tensor if the dataset is batched with the default collate_fn.
            image_root: directory that the ``image_path`` values are relative to.
                Defaults to ``DATASET_PATH``, the folder ``train.csv`` is read from
                (its paths look like ``train/<id>.jpg``).
        """
        self.transforms = transforms
        self.image_root = Path(image_root) if image_root is not None else DATASET_PATH

        # Materialise the columns as plain Python lists once. Indexing a polars
        # DataFrame with an int yields a 1-row DataFrame, which neither pathlib
        # nor torch.tensor can consume.
        self.indexes = dataframe.get_column("image_id").to_list()
        self.image_paths = dataframe.get_column("image_path").to_list()
        # One tuple of floats per image, in TARGET_COLS order.
        self.targets = dataframe.select(self.TARGET_COLS).rows()

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> dict:
        """Return ``{"image": <image or transformed image>, "targets": float32 tensor of shape (5,)}``."""
        img_path = self.image_root / self.image_paths[idx]

        # The context manager closes the underlying file; convert() returns an
        # independent, fully loaded copy that outlives the `with` block.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        targets_tensor = torch.tensor(self.targets[idx], dtype=torch.float32)

        if self.transforms is not None:
            image = self.transforms(image)

        return {"image": image, "targets": targets_tensor}

In [None]:
class DataHandler:
    """Splits the wide-format dataframe into train/validation indices and builds DataLoaders."""

    def __init__(self, dataframe: "pl.DataFrame", split_percent: float = 0.1, seed: "int | None" = None) -> None:
        """
        args:
            dataframe: the unique wide-format dataframe (one row per image).
            split_percent: fraction of rows held out for validation, in [0, 1).
                The validation size is ``int(split_percent * len(dataframe))``.
            seed: optional seed for a reproducible split; ``None`` uses torch's
                global RNG state.

        raises:
            ValueError: if ``split_percent`` is outside [0, 1).
        """
        if not 0.0 <= split_percent < 1.0:
            raise ValueError(f"split_percent must be in [0, 1), got {split_percent}")

        # Accessing the dataframe for construction
        self.dataframe = dataframe

        n_rows = len(dataframe)
        valid_len = int(split_percent * n_rows)
        n_train = n_rows - valid_len

        generator = None if seed is None else torch.Generator().manual_seed(seed)
        full_idx = torch.randperm(n_rows, generator=generator)

        # Slice on an explicit boundary: negative slicing (`[:-valid_len]`) breaks
        # when valid_len == 0, because `[:-0]` is empty and `[-0:]` is everything.
        self.train_idx = full_idx[:n_train]
        self.valid_idx = full_idx[n_train:]

    def construct_dataset(
        self,
        batch_size: int = 64,
        train_transforms=None,
        valid_transforms=None,
        num_workers: int = 0,
    ) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
        """Build the (train, validation) DataLoaders over the stored split.

        args:
            batch_size: batch size for both loaders.
            train_transforms: image transforms for the training split.
            valid_transforms: image transforms for the validation split.
            num_workers: DataLoader worker processes.

        returns:
            ``(train_loader, valid_loader)``. The training loader shuffles and the
            validation loader does not. The transforms must return tensors for
            the default collate_fn to batch the images.
        """
        train_dataset = torch.utils.data.Subset(
            CSIRODataset(self.dataframe, transforms=train_transforms),
            self.train_idx.tolist(),
        )
        valid_dataset = torch.utils.data.Subset(
            CSIRODataset(self.dataframe, transforms=valid_transforms),
            self.valid_idx.tolist(),
        )

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )
        return train_loader, valid_loader


In [None]:
# The end-to-end baseline only needs the image identifiers/paths as inputs and
# the five biomass columns as regression targets.
X = wide_df.select("image_id", "image_path")
y = wide_df.select(["Dry_Clover_g", "Dry_Dead_g", "Dry_Green_g", "Dry_Total_g", "GDM_g"])

**Modelling**

In [None]:
# Model Initialisation: torchvision backbones with their DEFAULT pretrained
# weights, plus the preprocessing pipeline bundled with each weights enum.
resnet_weights = ResNet34_Weights.DEFAULT
resnet_small = resnet34(weights=resnet_weights)
resnet_small_preprocessing = resnet_weights.transforms()

vit_weights = ViT_B_16_Weights.DEFAULT
vit_small = vit_b_16(weights=vit_weights)
vit_small_preprocessing = vit_weights.transforms()

# NOTE(review): the pretrained models presumably still carry their classification
# heads (`resnet_small.fc`, `vit_small.heads.head`). Replace them with 5-output
# regression heads before training on the biomass targets. TODO confirm in the
# training cell.

Dry_Clover_g,Dry_Dead_g,Dry_Green_g,Dry_Total_g,GDM_g
f64,f64,f64,f64,f64
0.0,31.9984,16.2751,48.2735,16.275
0.0,0.0,7.6,7.6,7.6
6.05,0.0,0.0,6.05,6.05
0.0,30.9703,24.2376,55.2079,24.2376
0.4343,23.2239,10.5261,34.1844,10.9605
