# Introduction
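This notebook loads the trained model weights for each cross-validation fold and runs inference with Test Time Augmentation (TTA): predictions for each image and its horizontally and/or vertically flipped copies are averaged, and the per-fold predictions are then averaged into the final submission.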

The model training is done in [this notebook](https://www.kaggle.com/code/takaito/csiro-img2bio-training-notebook/notebook). (You can train it directly in the Kaggle environment!)

Since I’m more experienced in tabular and NLP competitions, I’ve been looking forward to a chance to learn through a simple image competition like this one.

I hope to pick up various image competition techniques through this challenge!

In [None]:
import timm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import random
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

In [None]:
# ======== Seed ========
def seed_everything(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Calls Lightning's ``seed_everything`` (with ``workers=True``), then seeds
    Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) with the
    same value, exports it as ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Value applied to every generator.
    """
    pl.seed_everything(seed, workers=True)
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Request deterministic cuDNN algorithms and turn off its benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything(42)

In [None]:
# ======== Dataset ========
class InferenceDataset(Dataset):
    """Image-only dataset used at inference time (no targets are returned).

    Args:
        df: DataFrame with an ``image_path`` column; each path is resolved
            relative to ``image_root``.
        transform: Optional callable applied to the RGB PIL image before it is
            returned.
        image_root: Directory the ``image_path`` values are joined onto.
            Defaults to the Kaggle competition input directory.
    """

    def __init__(self, df, transform=None, image_root='/kaggle/input/csiro-biomass/'):
        self.df = df
        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at positional index ``idx`` and return it (transformed if set)."""
        row = self.df.iloc[idx]
        path = os.path.join(self.image_root, row["image_path"])
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent in-memory image that stays valid.
        with Image.open(path) as raw:
            image = raw.convert("RGB")
        # Explicit None check: a callable transform that happens to be falsy
        # (e.g. an empty container) must still be applied.
        if self.transform is not None:
            image = self.transform(image)
        return image

In [None]:
# ======== Model ========
class MultiRegressionModel(pl.LightningModule):
    """LightningModule wrapping a timm backbone that regresses several targets per image.

    Args:
        model_name: Architecture identifier forwarded to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
        lr: Learning rate handed to AdamW.
        output_dim: Forwarded as ``num_classes``, i.e. the number of values
            predicted per image.
    """

    # Target names, in the order used to index the per-target R2 scores.
    _TARGET_NAMES = ("Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g")

    def __init__(self, model_name="tf_efficientnetv2_s.in1k", pretrained=False, lr=1e-4, output_dim=5):
        super().__init__()
        self.save_hyperparameters()
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=output_dim,
        )
        self.criterion = nn.MSELoss()
        # (prediction, target) CPU tensor pairs collected during validation.
        self.val_outputs = []

    def forward(self, x):
        """Run the backbone on a batch of images; output is (B, output_dim)."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the epoch-level MSE training loss for one batch."""
        inputs, targets = batch
        loss = self.criterion(self(inputs), targets)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the MSE validation loss and stash predictions for the epoch-end metric."""
        inputs, targets = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, targets)
        self.val_outputs.append((outputs.detach().cpu(), targets.detach().cpu()))
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        return loss

    def on_validation_epoch_end(self):
        """Log the weighted R2 and per-target R2 over the collected batches.

        When no validation batches were collected, every metric is logged as 0.0.
        The collected batches are cleared afterwards.
        """
        if self.val_outputs:
            batch_preds, batch_trues = zip(*self.val_outputs)
            preds = torch.cat(batch_preds).numpy()
            trues = torch.cat(batch_trues).numpy()
            # NOTE(review): weighted_r2_score is not defined in the visible part
            # of this file; it must be in scope before validation is run.
            weighted_r2, r2s = weighted_r2_score(trues, preds)
        else:
            weighted_r2 = 0.0
            r2s = [0.0] * len(self._TARGET_NAMES)

        self.log("val_weighted_r2", weighted_r2, prog_bar=True, on_epoch=True)
        for i, name in enumerate(self._TARGET_NAMES):
            self.log(f"val_r2_{name}", r2s[i], on_epoch=True)
        self.val_outputs.clear()

    def configure_optimizers(self):
        """Return AdamW paired with a cosine-annealing schedule (``T_max=10``)."""
        opt = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        return {
            "optimizer": opt,
            "lr_scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10),
        }

In [None]:
def tta_inference(model, images):
    """Average the model's predictions over flip-based test-time augmentations.

    The model is evaluated on the unmodified batch and on copies flipped along
    dim 3, dim 2, and both (presumably width and height of an NCHW batch), and
    the four outputs are averaged.

    Args:
        model: Callable mapping an image batch to a prediction tensor.
        images: Tensor with at least four dimensions.

    Returns:
        The element-wise mean of the four prediction tensors.
    """
    flip_axes = ([3], [2], [2, 3])
    total = model(images)
    for axes in flip_axes:
        total = total + model(torch.flip(images, dims=axes))
    return total / 4.0

def get_id(x):
    """Return the part of ``x`` before its first underscore (all of ``x`` if there is none)."""
    prefix, _, _ = x.partition('_')
    return prefix

In [None]:
# Transform: resize to a fixed (height, width), convert to a tensor, then
# normalize each channel with the mean/std constants below.
height, width = 512, 512
infer_transform = T.Compose(
    [
        T.Resize(size=(height, width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# DataLoader: keep the first row for every distinct image_path, reduce each
# sample_id to the part before its first underscore, and iterate in file order.
test_df = pd.read_csv('/kaggle/input/csiro-biomass/test.csv')
test_df = (
    test_df.drop_duplicates(subset='image_path')[['sample_id', 'image_path']]
    .reset_index(drop=True)
)
test_df['sample_id'] = test_df['sample_id'].map(get_id)
dataset = InferenceDataset(test_df, transform=infer_transform)
dataloader = DataLoader(dataset, shuffle=False, batch_size=8, num_workers=4)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-fold predictions: fold index -> array with one row per batch item seen.
results_dict = {}
for fold in range(5):
    model = MultiRegressionModel(model_name="tf_efficientnetv2_s.in1k", pretrained=False)
    # Load onto CPU first so a checkpoint saved from a GPU run still
    # deserializes when CUDA is unavailable; the model is moved to `device`
    # right after.
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # if the installed PyTorch version supports it.
    state_dict = torch.load(
        f"/kaggle/input/k/kunalinfinite/csiro-img2bio-training-notebook/model_fold{fold}.pth",
        map_location="cpu",
    )
    model.load_state_dict(state_dict)

    model.eval()
    model.to(device)
    results = []
    with torch.no_grad():
        for images in dataloader:
            preds = tta_inference(model, images.to(device))
            results.append(preds.cpu().numpy())
    results_dict[fold] = np.concatenate(results)

In [None]:
# Column order of the model output.
target_names = ["Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g"]

# Ensemble: average over every fold stored in results_dict rather than a
# hard-coded subset, so all fold models that were run contribute.
fold_preds = [results_dict[fold] for fold in sorted(results_dict)]
result_df = pd.DataFrame(np.mean(fold_preds, axis=0), columns=target_names)
result_df['sample_id'] = test_df['sample_id']

# Reshape to long format: one row per (sample, target) with id "<sample>__<target>".
result_df = pd.melt(result_df, id_vars='sample_id', value_vars=target_names, value_name='target')
result_df['sample_id'] = result_df['sample_id'] + '__' + result_df['variable']
# Clamp predictions to the range [0, 200] before writing.
result_df['target'] = result_df['target'].clip(0, 200)
result_df[['sample_id', 'target']].to_csv('submission.csv', index=False)