In [None]:
# I would like to thank @takaito for his generous sharing (https://www.kaggle.com/code/takaito/csiro-img2bio-training-notebook)

import timm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import random
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

In [None]:
# ======== Load Datafile ========
# train.csv is read in "long" form: the pivot below relies on each
# (image_path, target_name) pair appearing once, with its value in `target`.
train_df = pd.read_csv('/kaggle/input/csiro-biomass/train.csv')

# Per-image metadata: keep the first row seen for every image and index it by
# image_path so it can be joined back after the reshape.
meta_cols = ["image_path", "State", "Pre_GSHH_NDVI", "Height_Ave_cm"]
other_cols = (
    train_df[meta_cols]
    .drop_duplicates('image_path')
    .set_index('image_path')
)

# Long -> wide: one row per image, one column per target_name.
pivot_df = train_df.pivot(index='image_path', columns='target_name', values='target')
pivot_df = pivot_df.reset_index()

# Re-attach the metadata, then replace the categorical State column with
# integer one-hot indicator columns (State_<value>).
train_df = pivot_df.merge(other_cols, on='image_path', how='left')
dummies = pd.get_dummies(train_df["State"], prefix="State", dummy_na=False).astype(int)
train_df = pd.concat([train_df.drop(columns=["State"]), dummies], axis=1)

train_df.head()

In [None]:
# ======== Seed ========
def seed_everything(seed: int = 114514):
    """Seed every random number generator used in this notebook.

    Args:
        seed: Value handed to Lightning, Python's ``random``, NumPy and
            PyTorch (CPU and all CUDA devices), and written to the
            ``PYTHONHASHSEED`` environment variable.
    """
    pl.seed_everything(seed, workers=True)

    # Seed each library explicitly as well, in the same order as before.
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        apply_seed(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    # Ask cuDNN for deterministic kernels and switch off its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(114514)

In [None]:
# ======== Dataset ========
class ImageRegressionDataset(Dataset):
    """Serves each image as a (left half, right half) pair plus 5 targets.

    Every item is built by opening ``image_root/<image_path>``, optionally
    augmenting the whole picture, cropping two fixed 1000x1000 windows and
    converting each window to a tensor.

    Args:
        df: DataFrame with an ``image_path`` column and the five target
            columns listed in ``TARGET_COLUMNS``.
        full_image_aug: Optional callable applied to the full PIL image
            before it is split, so both halves share the same augmentation.
        per_half_transform: Optional callable applied to each half; it must
            return a tensor. When omitted, the halves are converted with
            ``torchvision.transforms.functional.to_tensor``.
        image_root: Directory that ``image_path`` values are relative to.
    """

    # Order of the values in the returned target vector.
    TARGET_COLUMNS = (
        "Dry_Green_g",
        "Dry_Dead_g",
        "Dry_Clover_g",
        "GDM_g",
        "Dry_Total_g",
    )
    # Side length of each cropped half.
    # NOTE(review): the fixed crop boxes presume 2000x1000 (W x H) inputs — confirm.
    HALF_SIZE = 1000

    def __init__(self, df, full_image_aug=None, per_half_transform=None,
                 image_root='/kaggle/input/csiro-biomass/'):
        self.df = df
        self.full_image_aug = full_image_aug
        self.per_half_transform = per_half_transform
        self.image_root = image_root

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(imgs, targets)`` for row ``idx``.

        ``imgs`` stacks the two halves as ``(2, C, H, W)``; ``targets`` is a
        float32 tensor holding ``TARGET_COLUMNS`` in order.
        """
        row = self.df.iloc[idx]

        # Open inside a context manager so the file handle is released as
        # soon as the pixels have been decoded into the RGB copy.
        with Image.open(os.path.join(self.image_root, row["image_path"])) as raw:
            image = raw.convert("RGB")

        base = self.full_image_aug(image) if self.full_image_aug is not None else image

        half = self.HALF_SIZE
        left = base.crop((0, 0, half, half))
        right = base.crop((half, 0, 2 * half, half))

        # The fallback previously referenced an undefined name `F`; use the
        # functional API reachable through the existing `T` import instead.
        to_tensor = (
            self.per_half_transform
            if self.per_half_transform is not None
            else T.functional.to_tensor
        )
        imgs = torch.stack([to_tensor(left), to_tensor(right)], dim=0)   # (2, C, H, W)

        targets = torch.tensor(
            [row[col] for col in self.TARGET_COLUMNS],
            dtype=torch.float32,
        )

        return imgs, targets

In [None]:
# ======== Loss Function ========
class WeightedR2Loss(nn.Module):
    """Loss equal to ``1 - sum_k(w_k * R2_k)`` over five targets.

    The network predicts three quantities (DG, GDM, DT); the other two are
    derived as ``DD = DT - GDM`` and ``DC = GDM - DG`` so the five outputs are
    consistent with each other by construction. R2 is computed per target
    over the batch dimension.

    Args:
        weights: Per-target weights in the order (DG, DD, DC, GDM, DT).
            A tensor is used as-is; any other sequence is converted to a
            float32 tensor. Defaults to ``[0.1, 0.1, 0.1, 0.2, 0.5]``.
        eps: Added to the total sum of squares to avoid division by zero.
    """

    def __init__(self, weights=None, eps=1e-8):
        super().__init__()
        if weights is None:
            weights = torch.tensor([0.1, 0.1, 0.1, 0.2, 0.5])
        elif not torch.is_tensor(weights):
            # register_buffer only accepts tensors; allow lists/tuples/arrays.
            weights = torch.as_tensor(weights, dtype=torch.float32)
        self.register_buffer('weights', weights)
        self.eps = eps

    def forward(self, y_pred, y_true):
        """
        y_pred: (B, 3) -> columns are DG, GDM, DT
        y_true: (B, 5) -> columns are DG, DD, DC, GDM, DT

        Returns a scalar tensor. Raises ValueError when ``y_true`` does not
        have the same shape as the expanded five-column prediction.
        """
        # Half-precision predictions are upcast so that the derived
        # differences below are not computed in reduced precision.
        if y_pred.dtype in (torch.float16, torch.bfloat16):
            y_pred = y_pred.float()

        DG = y_pred[:, 0]
        GDM = y_pred[:, 1]
        DT = y_pred[:, 2]
        DD = DT - GDM
        DC = GDM - DG
        y_hat = torch.stack([DG, DD, DC, GDM, DT], dim=1)

        # Refuse shapes that would otherwise broadcast silently.
        if y_true.shape != y_hat.shape:
            raise ValueError(
                f"y_true must have shape {tuple(y_hat.shape)}, got {tuple(y_true.shape)}"
            )

        y_mean = torch.mean(y_true, dim=0, keepdim=True)
        ss_res = torch.sum((y_true - y_hat) ** 2, dim=0)
        ss_tot = torch.sum((y_true - y_mean) ** 2, dim=0)
        # NOTE: when a target is constant within the batch, ss_tot is 0 and
        # the ratio is divided by eps alone, which makes the loss very large.
        r2 = 1 - ss_res / (ss_tot + self.eps)

        weighted_r2 = torch.sum(self.weights * r2)
        loss = 1 - weighted_r2

        return loss

In [None]:
# ======== Model ========
class MultiRegressionModel(pl.LightningModule):
    """Two-view image regressor trained with a weighted R2 loss.

    A timm backbone (created without a classifier, with average pooling)
    embeds the two image halves independently; the two embeddings are
    concatenated and passed through a small MLP head that predicts
    ``output_dim`` values. With the default ``output_dim=3`` these are read as
    (DG, GDM, DT), and DD / DC are derived from them for the metrics.

    Args:
        model_name: timm model identifier for the backbone.
        pretrained: Whether timm should load pretrained backbone weights.
        lr: Learning rate for AdamW.
        output_dim: Number of values produced by the head.
        hidden_dim: Width of the head's hidden layer.
        open_last_n_blocks: How many trailing entries of ``backbone.blocks``
            stay trainable (only applies when the backbone has ``blocks``).
        train_bn: When False, BatchNorm layers whose parameters are all
            frozen are kept in eval mode.
    """

    # Order of the five targets in the expanded prediction / ground truth.
    TARGET_NAMES = ("DG", "DD", "DC", "GDM", "DT")

    def __init__(
        self,
        model_name="efficientnet_b2",
        pretrained=False,
        lr=5e-3,
        output_dim=3,
        hidden_dim=1536,
        open_last_n_blocks=1,
        train_bn=False,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg"
        )
        feat_dim = getattr(self.backbone, "num_features", None)
        if feat_dim is None:
            # Fall back to probing the embedding size with a dummy forward pass.
            x_dummy = torch.zeros(1, 3, 1000, 1000)
            with torch.no_grad():
                feat_dim = self.backbone(x_dummy).shape[-1]
        self.head = nn.Sequential(
            nn.Linear(feat_dim * 2, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )

        self.criterion = WeightedR2Loss()
        # Per-batch (prediction, target) pairs collected during validation.
        self.val_outputs = []
        self._freeze_backbone(open_last_n_blocks=open_last_n_blocks, train_bn=train_bn)

    # Helpers
    def _set_requires_grad(self, module, flag: bool):
        """Set ``requires_grad`` on every parameter of ``module``."""
        for p in module.parameters():
            p.requires_grad = flag

    def _eval_frozen_batchnorm(self):
        """Put backbone BatchNorm layers with no trainable parameters in eval mode."""
        for m in self.backbone.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                if not any(p.requires_grad for p in m.parameters()):
                    m.eval()

    def _freeze_backbone(self, open_last_n_blocks=1, train_bn=False):
        """Freeze the backbone except its last ``open_last_n_blocks`` blocks.

        The head is always left trainable. Backbones without a ``blocks``
        attribute end up fully frozen.
        """
        self._set_requires_grad(self.backbone, False)

        if hasattr(self.backbone, "blocks"):
            blocks = self.backbone.blocks
            n = len(blocks)
            assert open_last_n_blocks >= 1 and open_last_n_blocks <= n, \
                f"Expect open_last_n_blocks between [1, {n}], get {open_last_n_blocks}"
            for blk in blocks[n - open_last_n_blocks:]:
                self._set_requires_grad(blk, True)

        if not train_bn:
            self._eval_frozen_batchnorm()

        self._set_requires_grad(self.head, True)

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping frozen BatchNorm layers in eval.

        ``nn.Module.train`` flips every submodule recursively, so any
        ``.train()`` call made after construction would otherwise undo the
        eval mode applied in ``_freeze_backbone`` and let frozen BatchNorm
        layers update their running statistics.
        """
        super().train(mode)
        if mode and not self.hparams.train_bn:
            self._eval_frozen_batchnorm()
        return self

    @staticmethod
    def _expand_predictions(y_hat):
        """Turn (N, 3) predictions [DG, GDM, DT] into (N, 5) [DG, DD, DC, GDM, DT]."""
        DG = y_hat[:, 0]
        GDM = y_hat[:, 1]
        DT = y_hat[:, 2]
        DD = DT - GDM
        DC = GDM - DG
        return torch.stack([DG, DD, DC, GDM, DT], dim=1)

    def forward(self, x_imgs):
        """Embed both halves and regress the outputs.

        Args:
            x_imgs: ``(B, 2, C, H, W)`` batch, or a single ``(2, C, H, W)`` sample.

        Returns:
            Tensor of shape ``(B, output_dim)`` (``B == 1`` for a single sample).
        """
        if x_imgs.ndim == 5:
            batch, n_views, chans, height, width = x_imgs.shape
            assert n_views == 2, f"Expect 2 pics, get T={n_views}"
            feats = self.backbone(x_imgs.reshape(batch * n_views, chans, height, width))
            # Rows are ordered (sample, view), so this yields [left | right] per sample.
            fused = feats.reshape(batch, n_views * feats.shape[-1])
        elif x_imgs.ndim == 4:
            n_views = x_imgs.shape[0]
            assert n_views == 2, f"Expect 2 pics, get T={n_views}"
            feats = self.backbone(x_imgs)
            fused = feats.reshape(1, -1)
        else:
            raise ValueError("SHAPE should be train/val-(B, 2, C, H, W) or single-sample val-(2, C, H, W)")

        y_hat = self.head(fused)
        return y_hat

    def training_step(self, batch, batch_idx):
        x_img, y = batch
        y_hat = self(x_img)
        loss = self.criterion(y_hat, y)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x_img, y = batch
        y_hat = self(x_img)
        loss = self.criterion(y_hat, y)
        # Cache in float32 so the epoch-level metrics are not computed in
        # half precision when mixed-precision training is enabled.
        self.val_outputs.append((y_hat.detach().float().cpu(), y.detach().float().cpu()))
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        return loss

    def on_validation_epoch_end(self):
        """Log per-target and weighted R2 over the whole validation set."""
        if not self.val_outputs:
            # Nothing was validated; torch.cat on an empty list would raise.
            return

        y_hats = torch.cat([x[0] for x in self.val_outputs], dim=0)
        y_trues = torch.cat([x[1] for x in self.val_outputs], dim=0)
        self.val_outputs.clear()

        y_hat_full = self._expand_predictions(y_hats)

        y_mean = torch.mean(y_trues, dim=0, keepdim=True)
        ss_res = torch.sum((y_trues - y_hat_full) ** 2, dim=0)
        ss_tot = torch.sum((y_trues - y_mean) ** 2, dim=0)
        r2 = 1 - ss_res / (ss_tot + self.criterion.eps)

        weights = self.criterion.weights.cpu()
        weighted_r2 = torch.sum(weights * r2)

        for i, name in enumerate(self.TARGET_NAMES):
            self.log(f"val_r2_{name}", r2[i], prog_bar=True, on_epoch=True)
        self.log("val_weighted_r2", weighted_r2, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        """AdamW over the trainable parameters, halving the LR when val_loss plateaus."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable, lr=self.hparams.lr)
        # `verbose` is omitted: it is deprecated in recent PyTorch releases
        # and the scheduler is silent by default.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.5,
            patience=2,
            threshold=0.001,
            min_lr=1e-7,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }

In [None]:
# ======== DataModule ========
class ImageRegressionDataModule(pl.LightningDataModule):
    """Wraps the train/validation frames of one fold into datasets and loaders.

    Args:
        train_df: Rows used for training.
        valid_df: Rows used for validation.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes for both loaders.
        img_size: Side length each image half is resized to.
    """

    def __init__(self, train_df, valid_df, batch_size=8, num_workers=4, img_size=1000):
        super().__init__()
        self.train_df, self.valid_df = train_df, valid_df
        self.batch_size, self.num_workers = batch_size, num_workers
        self.img_size = img_size

    def setup(self, stage=None):
        """Build the training and validation datasets."""
        side = self.img_size

        def build_half_pipeline():
            # Resize -> tensor -> normalise, applied to each image half.
            return T.Compose([
                T.Resize((side, side)),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406],
                            [0.229, 0.224, 0.225])
            ])

        # Whole-image augmentation is used for training only, so both halves
        # of a picture receive the same random flips and colour jitter.
        train_aug = T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
            T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        ])

        self.train_dataset = ImageRegressionDataset(
            self.train_df,
            full_image_aug=train_aug,
            per_half_transform=build_half_pipeline(),
        )
        self.val_dataset = ImageRegressionDataset(
            self.valid_df,
            full_image_aug=None,
            per_half_transform=build_half_pipeline(),
        )

    def _make_loader(self, dataset, shuffle):
        """Create a DataLoader with the module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

In [None]:
# ======== Train ========
# 3-fold cross-validation: each fold trains a fresh model, tracks the best
# epoch by validation weighted R2, and exports that epoch's weights.
kf = KFold(n_splits=3, random_state=114514, shuffle=True)
for fold, (train_index, valid_index) in enumerate(kf.split(train_df)):
    datamodule = ImageRegressionDataModule(train_df.iloc[train_index], train_df.iloc[valid_index])
    model = MultiRegressionModel(model_name="efficientnet_b2", pretrained=True, lr=5e-3)

    checkpoint_callback = ModelCheckpoint(
        monitor="val_weighted_r2",
        save_top_k=1,
        mode="max",
        filename=f"best_model_fold{fold}"
    )
    lr_monitor_callback = LearningRateMonitor(logging_interval="epoch")
    early_stop_callback = EarlyStopping(
        monitor="val_weighted_r2",
        patience=5,
        mode="max",
        min_delta=0.0002,
        verbose=True
    )

    trainer = pl.Trainer(
        max_epochs=64,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        callbacks=[checkpoint_callback, lr_monitor_callback, early_stop_callback],
        precision="16-mixed",
        log_every_n_steps=15,
    )
    trainer.fit(model, datamodule=datamodule)

    # After fit() the in-memory model holds the LAST epoch's weights, which
    # with early stopping (patience=5) can be several epochs past the best
    # validated one. Restore the best checkpoint before exporting; if no
    # checkpoint was written, fall back to the last weights.
    best_ckpt_path = checkpoint_callback.best_model_path
    if best_ckpt_path:
        # The checkpoint was produced by this run, so full unpickling is safe.
        best_ckpt = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
        model.load_state_dict(best_ckpt["state_dict"])
    torch.save(model.state_dict(), f"model_fold{fold}.pth")