- Notebook modified from https://www.kaggle.com/code/markwijkhuizen/planttraits2024-eda-training-pub.
- Training only, EDA part not included.
- Image model only, tabular data not used.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v3 as imageio
import albumentations as A

from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch import nn
from tqdm.notebook import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

import torch
import timm
import glob
import torchmetrics
import time
import psutil
import os
import time

tqdm.pandas()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
class Config():
    IMAGE_SIZE = 384
    BACKBONE = 'swin_large_patch4_window12_384.ms_in22k_ft_in1k'
#     BACKBONE = 'tf_efficientnet_b0'
    TARGET_COLUMNS = ['X4_mean', 'X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean']
    N_TARGETS = len(TARGET_COLUMNS)
    BATCH_SIZE = 10
    LR_MAX = 1e-4
    WEIGHT_DECAY = 0.01
    N_EPOCHS = 6
    TRAIN_MODEL = True
    IS_INTERACTIVE = os.environ['KAGGLE_KERNEL_RUN_TYPE'] == 'Interactive'
    
    MODEL_PATH = '/kaggle/input/plainttraits2024-swintransformer/model.pth'
        
CONFIG = Config()

In [3]:
%%time
read_images = False

if not read_images:
    train = pd.read_pickle('/kaggle/input/plainttraits2024-swintransformer/train.pkl')
    test = pd.read_pickle('/kaggle/input/plainttraits2024-swintransformer/test.pkl')
else: 
    # if CONFIG.TRAIN_MODEL:
    train = pd.read_csv('/kaggle/input/planttraits2024/train.csv')
    train['file_path'] = train['id'].apply(lambda s: f'/kaggle/input/planttraits2024/train_images/{s}.jpeg')
    train['jpeg_bytes'] = train['file_path'].progress_apply(lambda fp: open(fp, 'rb').read())
    train.to_pickle('train.pkl')

    test = pd.read_csv('/kaggle/input/planttraits2024/test.csv')
    test['file_path'] = test['id'].apply(lambda s: f'/kaggle/input/planttraits2024/test_images/{s}.jpeg')
    test['jpeg_bytes'] = test['file_path'].progress_apply(lambda fp: open(fp, 'rb').read())
    test.to_pickle('test.pkl')

for column in CONFIG.TARGET_COLUMNS:
    lower_quantile = train[column].quantile(0.005)
    upper_quantile = train[column].quantile(0.985)  
    train = train[(train[column] >= lower_quantile) & (train[column] <= upper_quantile)]    
    
CONFIG.N_TRAIN_SAMPLES = len(train)
CONFIG.N_STEPS_PER_EPOCH = (CONFIG.N_TRAIN_SAMPLES // CONFIG.BATCH_SIZE)
CONFIG.N_STEPS = CONFIG.N_STEPS_PER_EPOCH * CONFIG.N_EPOCHS + 1    
    
if CONFIG.TRAIN_MODEL:
    print('N_TRAIN_SAMPLES:', len(train), 'N_TEST_SAMPLES:', len(test))
else:
    print('N_TEST_SAMPLES:', len(test))

N_TRAIN_SAMPLES: 49168 N_TEST_SAMPLES: 6545
CPU times: user 1.57 s, sys: 2.88 s, total: 4.45 s
Wall time: 26.7 s


In [4]:
# if CONFIG.TRAIN_MODEL:
LOG_FEATURES = ['X4_mean', 'X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean']

y_df = np.zeros_like(train[CONFIG.TARGET_COLUMNS], dtype=np.float32)
for target_idx, target in enumerate(CONFIG.TARGET_COLUMNS):
    v = train[target].values
    if target in LOG_FEATURES:
        v = np.log10(v)
    y_df[:, target_idx] = v

SCALER = StandardScaler()
y_df = SCALER.fit_transform(y_df)

In [5]:
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

TRAIN_TRANSFORMS = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.RandomSizedCrop(
            [int(0.85*CONFIG.IMAGE_SIZE), CONFIG.IMAGE_SIZE],
            CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE, w2h_ratio=1.0, p=0.75),
        A.Resize(CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.25),
        A.ImageCompression(quality_lower=85, quality_upper=100, p=0.25),
        A.ToFloat(),
        A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
        ToTensorV2(),
    ])

TEST_TRANSFORMS = A.Compose([
        A.Resize(CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE),
        A.ToFloat(),
        A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
        ToTensorV2(),
    ])

class Dataset(Dataset):
    def __init__(self, X_jpeg_bytes, y, transforms=None):
        self.X_jpeg_bytes = X_jpeg_bytes
        self.y = y
        self.transforms = transforms

    def __len__(self):
        return len(self.X_jpeg_bytes)

    def __getitem__(self, index):
        X_sample = self.transforms(
            image=imageio.imread(self.X_jpeg_bytes[index]),
        )['image']
        y_sample = self.y[index]
        
        return X_sample, y_sample

if CONFIG.TRAIN_MODEL:
    # Splitting the data into training and validation sets
    train_df, val_df, y_train, y_val = train_test_split(train, y_df, test_size=0.2, random_state=42)

    # Creating datasets for training and validation
    train_dataset = Dataset(
        train_df['jpeg_bytes'].values,
        y_train,
        TRAIN_TRANSFORMS,
    )

    val_dataset = Dataset(
        val_df['jpeg_bytes'].values,
        y_val,
        TEST_TRANSFORMS,
    )

    # Creating dataloaders for training and validation
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=CONFIG.BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=psutil.cpu_count(),
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=CONFIG.BATCH_SIZE,
        shuffle=False,  # No need to shuffle validation data
        num_workers=psutil.cpu_count(),
    )


test_dataset = Dataset(
    test['jpeg_bytes'].values,
    test['id'].values,
    TEST_TRANSFORMS,
)

In [6]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(
                CONFIG.BACKBONE,
                num_classes=CONFIG.N_TARGETS,
                pretrained=True)
        
    def forward(self, inputs):
        return self.backbone(inputs)

model = Model()
model = model.to(device)
# print(model)

model.safetensors:   0%|          | 0.00/801M [00:00<?, ?B/s]

In [7]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Count parameters
num_parameters = count_parameters(model)
print("Number of parameters in the model:", num_parameters)

Number of parameters in the model: 195207738


In [8]:
def get_lr_scheduler(optimizer):
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=CONFIG.LR_MAX,
        total_steps=CONFIG.N_STEPS,
        pct_start=0.1,
        anneal_strategy='cos',
        div_factor=1e1,
        final_div_factor=1e1,
    )

class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val):
        self.sum += val.sum()
        self.count += val.numel()
        self.avg = self.sum / self.count

if CONFIG.TRAIN_MODEL:
    MAE = torchmetrics.regression.MeanAbsoluteError().to(device)
    R2 = torchmetrics.regression.R2Score(num_outputs=CONFIG.N_TARGETS, multioutput='uniform_average').to(device)
    LOSS = AverageMeter()

    Y_MEAN = torch.tensor(y_train).mean(dim=0).to(device)
    EPS = torch.tensor([1e-6]).to(device)

    def r2_loss(y_pred, y_true):
        ss_res = torch.sum((y_true - y_pred)**2, dim=0)
        ss_total = torch.sum((y_true - Y_MEAN)**2, dim=0)
        ss_total = torch.maximum(ss_total, EPS)
        r2 = torch.mean(ss_res / ss_total)
        return r2

    LOSS_FN = r2_loss

    optimizer = torch.optim.AdamW(
        params=model.parameters(),
        lr=CONFIG.LR_MAX,
        weight_decay=CONFIG.WEIGHT_DECAY,
    )

    LR_SCHEDULER = get_lr_scheduler(optimizer)

In [9]:
if CONFIG.TRAIN_MODEL:
    print("Start Training:")
    for epoch in range(CONFIG.N_EPOCHS):
        MAE.reset()
        R2.reset()
        LOSS.reset()
        model.train()

        for step, (X_batch, y_true) in enumerate(train_dataloader):
            X_batch = X_batch.to(device)
            y_true = y_true.to(device)
            t_start = time.perf_counter_ns()
            y_pred = model(X_batch)
            loss = LOSS_FN(y_pred, y_true)
            LOSS.update(loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            LR_SCHEDULER.step()
            MAE.update(y_pred, y_true)
            R2.update(y_pred, y_true)

            if not CONFIG.IS_INTERACTIVE and (step+1) == len(train_dataloader):
                print(
                    f'EPOCH {epoch+1:02d}, {step+1:04d}/{len(train_dataloader)} | ' + 
                    f'loss: {LOSS.avg:.4f}, mae: {MAE.compute().item():.4f}, r2: {R2.compute().item():.4f}, ' +
                    f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                )
            elif CONFIG.IS_INTERACTIVE:
                print(
                    f'\rEPOCH {epoch+1:02d}, {step+1:04d}/{len(train_dataloader)} | ' + 
                    f'loss: {LOSS.avg:.4f}, mae: {MAE.compute().item():.4f}, r2: {R2.compute().item():.4f}, ' +
                    f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                    end='\n' if (step + 1) == CONFIG.N_STEPS_PER_EPOCH else '', flush=True,
                )
            
        MAE.reset()
        R2.reset()
        LOSS.reset()
        model.eval()
        
        print()
        
        for step, (X_batch, y_true) in enumerate(val_dataloader):
            X_batch = X_batch.to(device)
            y_true = y_true.to(device)
            t_start = time.perf_counter_ns()
            with torch.no_grad():
                y_pred = model(X_batch)
                loss = LOSS_FN(y_pred, y_true)
                
            LOSS.update(loss)
            MAE.update(y_pred, y_true)
            R2.update(y_pred, y_true)
            if not CONFIG.IS_INTERACTIVE and (step+1) == len(val_dataloader):
                print(
                    f'EPOCH VAL, {epoch+1:02d}, {step+1:04d}/{len(val_dataloader)} | ' + 
                    f'loss: {LOSS.avg:.4f}, mae: {MAE.compute().item():.4f}, r2: {R2.compute().item():.4f}, ' +
                    f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                )
            elif CONFIG.IS_INTERACTIVE:
                print(
                    f'\rEPOCH {epoch+1:02d}, {step+1:04d}/{len(val_dataloader)} | ' + 
                    f'loss: {LOSS.avg:.4f}, mae: {MAE.compute().item():.4f}, r2: {R2.compute().item():.4f}, ' +
                    f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                    end='\n' if (step + 1) == CONFIG.N_STEPS_PER_EPOCH else '', flush=True,
                )
        print()

    torch.save(model, 'model.pth')
else:
    model = torch.load(CONFIG.MODEL_PATH)
    model.to(device)

Start Training:
EPOCH 01, 3933/3933 | loss: 0.6692, mae: 0.6189, r2: 0.3598, step: 1.355s, lr: 9.97e-05

EPOCH VAL, 01, 0984/984 | loss: 0.6436, mae: 0.6039, r2: 0.3887, step: 0.178s, lr: 9.97e-05

EPOCH 02, 3933/3933 | loss: 0.6088, mae: 0.5812, r2: 0.4236, step: 1.237s, lr: 9.19e-05

EPOCH VAL, 02, 0984/984 | loss: 0.6008, mae: 0.5720, r2: 0.4378, step: 0.174s, lr: 9.19e-05

EPOCH 03, 3933/3933 | loss: 0.5312, mae: 0.5353, r2: 0.5007, step: 1.237s, lr: 7.52e-05

EPOCH VAL, 03, 0984/984 | loss: 0.5706, mae: 0.5492, r2: 0.4735, step: 0.174s, lr: 7.52e-05

EPOCH 04, 3933/3933 | loss: 0.4477, mae: 0.4832, r2: 0.5842, step: 1.239s, lr: 5.34e-05

EPOCH VAL, 04, 0984/984 | loss: 0.5488, mae: 0.5301, r2: 0.4984, step: 0.174s, lr: 5.34e-05

EPOCH 05, 3933/3933 | loss: 0.3583, mae: 0.4254, r2: 0.6712, step: 1.244s, lr: 3.09e-05

EPOCH VAL, 05, 0984/984 | loss: 0.5056, mae: 0.5031, r2: 0.5407, step: 0.174s, lr: 3.09e-05

EPOCH 06, 3933/3933 | loss: 0.2826, mae: 0.3730, r2: 0.7440, step: 1.237s,

In [10]:
SUBMISSION_ROWS = []
model.eval()

for X_sample_test, test_id in tqdm(test_dataset):
    with torch.no_grad():
        y_pred = model(X_sample_test.unsqueeze(0).to(device)).detach().cpu().numpy()
    
    y_pred = SCALER.inverse_transform(y_pred).squeeze()
    row = {'id': test_id}
    
    for k, v in zip(CONFIG.TARGET_COLUMNS, y_pred):
        if k in LOG_FEATURES:
            row[k.replace('_mean', '')] = 10 ** v
        else:
            row[k.replace('_mean', '')] = v

    SUBMISSION_ROWS.append(row)
    
submission_df = pd.DataFrame(SUBMISSION_ROWS)
submission_df.to_csv('submission.csv', index=False)
print("Submit!")

  0%|          | 0/6545 [00:00<?, ?it/s]

Submit!
