In [None]:
import timm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import random
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

In [None]:
# ======== Seed ========
def seed_everything(seed: int = 42):
    """Seed the random number generators used in this notebook.

    Calls ``pl.seed_everything`` with ``workers=True`` and then explicitly
    seeds ``random``, NumPy and PyTorch (CPU plus all CUDA devices), exports
    ``PYTHONHASHSEED``, and finally switches cuDNN to deterministic mode with
    benchmarking turned off.

    Args:
        seed: Value handed to every seeding call.
    """
    pl.seed_everything(seed, workers=True)

    # Explicit per-library seeding on top of the Lightning call above.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    # Deterministic cuDNN kernels, no benchmark-based algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)

In [None]:
# ======== Weighted R² ========
def weighted_r2_score(y_true: np.ndarray, y_pred: np.ndarray, weights=None):
    """Compute one R² score per target column and their weighted average.

    For every column, R² is ``1 - SS_res / SS_tot``. A column whose true
    values have no variance (``SS_tot`` is not greater than 0) gets a score
    of 0.0 instead of a division by zero.

    Args:
        y_true: Ground-truth values, shape ``(N, T)``.
        y_pred: Predicted values, same shape as ``y_true``.
        weights: Optional sequence of ``T`` per-target weights. When omitted,
            the five-target weights ``[0.1, 0.1, 0.1, 0.2, 0.5]`` are used,
            which requires ``T == 5``.

    Returns:
        A tuple ``(weighted_r2, r2_scores)``: the weighted mean of the
        per-target scores (normalised by the sum of the weights) and a
        float64 array of shape ``(T,)`` holding the per-target scores.

    Raises:
        ValueError: If the inputs are not 2-D, their shapes differ, or the
            number of weights does not match the number of target columns.
    """
    # Promote to float64 so float32 model outputs are not accumulated in
    # single precision.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if y_true.ndim != 2:
        raise ValueError(f"y_true must be 2-D (N, T); got shape {y_true.shape}")
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same shape; "
            f"got {y_true.shape} and {y_pred.shape}"
        )

    if weights is None:
        weights = np.array([0.1, 0.1, 0.1, 0.2, 0.5])
    else:
        weights = np.asarray(weights, dtype=np.float64)
    n_targets = y_true.shape[1]
    if weights.shape != (n_targets,):
        raise ValueError(
            f"expected {n_targets} weights (one per target column); "
            f"got shape {weights.shape}"
        )

    # Column-wise residual and total sums of squares.
    ss_res = np.sum((y_true - y_pred) ** 2, axis=0)
    ss_tot = np.sum((y_true - np.mean(y_true, axis=0)) ** 2, axis=0)

    # Columns without variance keep the fallback score of 0.0.
    has_variance = ss_tot > 0
    r2_scores = np.zeros(n_targets, dtype=np.float64)
    r2_scores[has_variance] = 1.0 - ss_res[has_variance] / ss_tot[has_variance]

    weighted_r2 = np.sum(r2_scores * weights) / np.sum(weights)
    return weighted_r2, r2_scores

In [None]:
# ======== Dataset ========
class ImageRegressionDataset(Dataset):
    """Dataset yielding ``(image, targets)`` pairs from a DataFrame.

    Each row of ``df`` must provide an ``image_path`` column (joined onto
    ``root`` to locate the file) and the five numeric columns listed in
    ``TARGET_COLUMNS``.

    Args:
        df: DataFrame with one row per image.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        root: Directory that ``image_path`` values are relative to.
            Defaults to the Kaggle input directory used by this notebook.
    """

    # Order of these columns defines the order of the returned target vector.
    TARGET_COLUMNS = (
        "Dry_Green_g",
        "Dry_Dead_g",
        "Dry_Clover_g",
        "GDM_g",
        "Dry_Total_g",
    )

    def __init__(self, df, transform=None, root="/kaggle/input/csiro-biomass/"):
        self.df = df
        self.transform = transform
        self.root = root

    def __len__(self):
        """Return the number of rows (images) in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and its five target values.

        Returns:
            ``(image, targets)`` where ``image`` is the (optionally
            transformed) RGB image and ``targets`` is a float32 tensor of
            shape ``(5,)`` in ``TARGET_COLUMNS`` order.
        """
        row = self.df.iloc[idx]
        image_path = os.path.join(self.root, row["image_path"])
        # Open inside a context manager so the file handle is released right
        # away; convert() returns an independent, fully loaded image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        # Assumes the five target columns are present in the row.
        targets = torch.tensor(
            [float(row[column]) for column in self.TARGET_COLUMNS],
            dtype=torch.float32,
        )
        return image, targets

In [None]:
# ======== Model ========
class MultiRegressionModel(pl.LightningModule):
    """LightningModule wrapping a timm backbone for multi-target regression.

    Trains with MSE loss and, at the end of every validation epoch, logs the
    weighted R² (``val_weighted_r2``) and one R² per target
    (``val_r2_<name>``) computed by ``weighted_r2_score`` over all
    validation batches of that epoch.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        pretrained: Whether timm should load pretrained weights.
        lr: Learning rate for AdamW.
        output_dim: Number of regression outputs (passed as ``num_classes``).
        t_max: ``T_max`` of the cosine annealing schedule. The default of 10
            matches the previous hard-coded value; set it to the trainer's
            ``max_epochs`` when training for a different number of epochs.
    """

    # Names used for the per-target validation metrics, in output order.
    TARGET_NAMES = ("Dry_Green_g", "Dry_Dead_g", "Dry_Clover_g", "GDM_g", "Dry_Total_g")

    def __init__(self, model_name="tf_efficientnetv2_s.in1k", pretrained=True, lr=1e-4, output_dim=5, t_max=10):
        super().__init__()
        self.save_hyperparameters()
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=output_dim)
        self.criterion = nn.MSELoss()
        # (prediction, target) CPU tensors collected during one validation epoch.
        self.val_outputs = []

    def forward(self, x):
        """Run the backbone; returns predictions of shape ``(B, output_dim)``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE loss for one training batch."""
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and stash predictions for the epoch-end metrics."""
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        # Cast to float32 before moving to CPU: the buffers are later turned
        # into NumPy arrays, which is not possible for bfloat16 tensors.
        self.val_outputs.append((y_hat.detach().float().cpu(), y.detach().float().cpu()))
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        return loss

    def _log_r2(self, weighted_r2, r2s):
        """Log the weighted R² and the per-target R² values."""
        self.log("val_weighted_r2", float(weighted_r2), prog_bar=True, on_epoch=True)
        for name, score in zip(self.TARGET_NAMES, r2s):
            self.log(f"val_r2_{name}", float(score), on_epoch=True)

    def on_validation_epoch_end(self):
        """Aggregate the epoch's validation batches into R² metrics and log them."""
        if not self.val_outputs:
            # Nothing was validated: log zeros so monitored keys always exist.
            self._log_r2(0.0, [0.0] * len(self.TARGET_NAMES))
            return

        preds, trues = zip(*self.val_outputs)
        # Clear right away so a failure below cannot leak stale batches into
        # the next validation epoch.
        self.val_outputs.clear()
        preds = torch.cat(preds).numpy()
        trues = torch.cat(trues).numpy()
        weighted_r2, r2s = weighted_r2_score(trues, preds)
        self._log_r2(weighted_r2, r2s)

    def configure_optimizers(self):
        """Return AdamW with a cosine annealing learning-rate schedule."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.hparams.t_max)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

In [None]:
# ======== DataModule ========
class ImageRegressionDataModule(pl.LightningDataModule):
    """LightningDataModule serving train/validation ``ImageRegressionDataset`` objects.

    Args:
        train_df: DataFrame backing the training dataset.
        valid_df: DataFrame backing the validation dataset.
        batch_size: Batch size for both loaders.
        num_workers: Worker process count for both loaders.
        width: Width images are resized to.
        height: Height images are resized to.
    """

    def __init__(self, train_df, valid_df, batch_size=8, num_workers=4, width=512,height=512):
        super().__init__()
        self.train_df, self.valid_df = train_df, valid_df
        self.batch_size, self.num_workers = batch_size, num_workers
        self.width, self.height = width, height

    def setup(self, stage=None):
        """Build the transform pipelines and create both datasets."""
        target_size = (self.height, self.width)
        norm_mean = [0.485, 0.456, 0.406]
        norm_std = [0.229, 0.224, 0.225]

        def compose(extra_steps):
            # Resize -> optional augmentations -> tensor -> normalisation.
            return T.Compose([
                T.Resize(target_size),
                *extra_steps,
                T.ToTensor(),
                T.Normalize(norm_mean, norm_std),
            ])

        # Augmentations are applied to the training split only.
        train_pipeline = compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
            T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        ])
        val_pipeline = compose([])

        self.train_dataset = ImageRegressionDataset(self.train_df, train_pipeline)
        self.val_dataset = ImageRegressionDataset(self.valid_df, val_pipeline)

    def _make_loader(self, dataset, shuffle):
        """Create a DataLoader with this module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Sequential (unshuffled) loader over the validation dataset."""
        return self._make_loader(self.val_dataset, shuffle=False)

In [None]:
# Read the training table and pivot it so every image_path becomes one row
# with one column per target_name value.
train_df = (
    pd.read_csv('/kaggle/input/csiro-biomass/train.csv')
    .pivot_table(index='image_path', columns=['target_name'], values='target')
    .reset_index()
)

# 5-fold cross-validation: one freshly initialised model is trained per fold.
kf = KFold(n_splits=5, shuffle=True, random_state=0)
for fold, (train_index, valid_index) in enumerate(kf.split(train_df)):
    datamodule = ImageRegressionDataModule(
        train_df.iloc[train_index],
        train_df.iloc[valid_index],
    )
    model = MultiRegressionModel(
        model_name="tf_efficientnetv2_s.in1k",
        pretrained=True,
        lr=1e-4,
    )

    # Keep only the single checkpoint with the highest validation weighted R².
    checkpoint_callback = ModelCheckpoint(
        filename=f"best_model_fold{fold}",
        monitor="val_weighted_r2",
        mode="max",
        save_top_k=1,
    )

    # Mixed precision (precision="16-mixed") is an optional speed-up that is
    # left disabled here.
    trainer = pl.Trainer(
        max_epochs=10,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model, datamodule=datamodule)

    # Saves the weights of the in-memory model once fit() has returned; this
    # file is separate from the checkpoint written by the callback above.
    torch.save(model.state_dict(), f"model_fold{fold}.pth")