## Load Packages

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v3 as imageio
import albumentations as A

from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch import nn
from tqdm.notebook import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, KFold, cross_val_score

import torch
import timm
import glob
import torchmetrics
import time
import psutil
import os
import time
import pickle

tqdm.pandas()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
class Config():
    IMAGE_SIZE = 384
#     BACKBONE = 'swin_large_patch4_window12_384.ms_in22k_ft_in1k'
    BACKBONE = 'tf_efficientnet_b0'
    TARGET_COLUMNS = ['X4_mean', 'X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean']
    N_CLASSES = 17535
    
    N_TARGETS = len(TARGET_COLUMNS)
    BATCH_SIZE = 32
    LR_MAX = 1e-4
    WEIGHT_DECAY = 0.01
    N_EPOCHS = 8
    TRAIN_MODEL = True
    IS_INTERACTIVE = os.environ['KAGGLE_KERNEL_RUN_TYPE'] == 'Interactive'
    
#     MODEL_PATH = '/kaggle/input/plainttraits2024-swintransformer/model.pth'
    MODEL_PATH = '/kaggle/input/planttraits2024-swintransformer-tabular/model.pth'
        
CONFIG = Config()

## Load Data

In [3]:
%%time
read_images = False

if not read_images:
    train = pd.read_pickle('/kaggle/input/plainttraits2024-swintransformer/train.pkl')
    test = pd.read_pickle('/kaggle/input/plainttraits2024-swintransformer/test.pkl')
else: 
    # if CONFIG.TRAIN_MODEL:
    train = pd.read_csv('/kaggle/input/planttraits2024/train.csv')
    train['file_path'] = train['id'].apply(lambda s: f'/kaggle/input/planttraits2024/train_images/{s}.jpeg')
    train['jpeg_bytes'] = train['file_path'].progress_apply(lambda fp: open(fp, 'rb').read())
    train.to_pickle('train.pkl')

    test = pd.read_csv('/kaggle/input/planttraits2024/test.csv')
    test['file_path'] = test['id'].apply(lambda s: f'/kaggle/input/planttraits2024/test_images/{s}.jpeg')
    test['jpeg_bytes'] = test['file_path'].progress_apply(lambda fp: open(fp, 'rb').read())
    test.to_pickle('test.pkl')

for column in CONFIG.TARGET_COLUMNS:
    lower_quantile = train[column].quantile(0.005)
    upper_quantile = train[column].quantile(0.985)  
    train = train[(train[column] >= lower_quantile) & (train[column] <= upper_quantile)]    
    
sd_columns = [col for col in train.columns if col.endswith('_sd')]
train = train.drop(columns=sd_columns)
    
CONFIG.N_TRAIN_SAMPLES = len(train)
CONFIG.N_STEPS_PER_EPOCH = (CONFIG.N_TRAIN_SAMPLES // CONFIG.BATCH_SIZE)
CONFIG.N_STEPS = CONFIG.N_STEPS_PER_EPOCH * CONFIG.N_EPOCHS + 1    
CONFIG.TABULAR_COLUMNS = train.filter(regex='^(WORLDCLIM_BIO|SOIL|MODIS_2000|VOD)').columns
    
if CONFIG.TRAIN_MODEL:
    print('N_TRAIN_SAMPLES:', len(train), 'N_TEST_SAMPLES:', len(test))
else:
    print('N_TEST_SAMPLES:', len(test))

N_TRAIN_SAMPLES: 49168 N_TEST_SAMPLES: 6545
CPU times: user 2.1 s, sys: 3.11 s, total: 5.21 s
Wall time: 41.3 s


## Create Species

In [4]:
train['class'] = train[CONFIG.TARGET_COLUMNS].apply(lambda row: '_'.join(map(str, row)), axis=1)
class_map = {class_: i for i, class_ in enumerate(train['class'].unique())}
train['class'] = train['class'].map(class_map)
train

Unnamed: 0,id,WORLDCLIM_BIO1_annual_mean_temperature,WORLDCLIM_BIO12_annual_precipitation,WORLDCLIM_BIO13.BIO14_delta_precipitation_of_wettest_and_dryest_month,WORLDCLIM_BIO15_precipitation_seasonality,WORLDCLIM_BIO4_temperature_seasonality,WORLDCLIM_BIO7_temperature_annual_range,SOIL_bdod_0.5cm_mean_0.01_deg,SOIL_bdod_100.200cm_mean_0.01_deg,SOIL_bdod_15.30cm_mean_0.01_deg,...,VOD_X_1997_2018_multiyear_mean_m12,X4_mean,X11_mean,X18_mean,X26_mean,X50_mean,X3112_mean,file_path,jpeg_bytes,class
0,192027691,12.235703,374.466675,62.524445,72.256844,773.592041,33.277779,125,149,136,...,0.403038,0.401753,11.758108,0.117484,1.243779,1.849375,50.216034,/kaggle/input/planttraits2024/train_images/192...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,0
1,195542235,17.270555,90.239998,10.351111,38.220940,859.193298,40.009777,124,144,138,...,0.311158,0.480334,15.748846,0.389315,0.642940,1.353468,574.098472,/kaggle/input/planttraits2024/train_images/195...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,1
2,196639184,14.254504,902.071411,49.642857,17.873655,387.977753,22.807142,107,133,119,...,0.455440,0.796917,5.291251,8.552908,0.395241,2.343153,1130.096731,/kaggle/input/planttraits2024/train_images/196...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,2
3,195728812,18.680834,1473.933350,163.100006,45.009758,381.053986,20.436666,120,131,125,...,0.348838,0.525236,9.568305,1.083629,0.154200,1.155308,1042.686546,/kaggle/input/planttraits2024/train_images/195...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,3
4,195251545,0.673204,530.088867,50.857777,38.230709,1323.526855,45.891998,91,146,120,...,0.448166,0.411821,14.528877,0.657585,10.919966,2.246226,2386.467180,/kaggle/input/planttraits2024/train_images/195...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
55484,190558785,19.472172,244.795914,39.127552,67.074493,472.710358,27.758673,118,140,131,...,0.292813,0.337243,11.572778,0.233690,1.783193,1.608341,969.547831,/kaggle/input/planttraits2024/train_images/190...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,9614
55485,194523231,13.724150,1450.000000,162.260208,43.139324,652.716858,26.694387,125,144,135,...,0.294559,0.424371,6.114448,1.017099,12.713048,2.418300,1630.015480,/kaggle/input/planttraits2024/train_images/194...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,5980
55486,195888987,14.741204,581.866638,109.231110,89.272148,507.273010,26.874668,118,155,136,...,0.147268,0.639659,5.549596,2.717395,10.206478,2.722599,602.229880,/kaggle/input/planttraits2024/train_images/195...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,2844
55487,135487319,16.094763,1180.838135,80.176193,22.909716,342.184021,17.346189,109,130,117,...,0.341447,0.774642,7.024218,4.429659,9.372170,3.251739,244.387170,/kaggle/input/planttraits2024/train_images/135...,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,3811


In [6]:
lookup_table = train.groupby('class')[CONFIG.TARGET_COLUMNS].agg(lambda x: x.iloc[0])
lookup_table = lookup_table.reset_index()
print(lookup_table)

       class   X4_mean   X11_mean   X18_mean  X50_mean   X26_mean   X3112_mean
0          0  0.401753  11.758108   0.117484  1.849375   1.243779    50.216034
1          1  0.480334  15.748846   0.389315  1.353468   0.642940   574.098472
2          2  0.796917   5.291251   8.552908  2.343153   0.395241  1130.096731
3          3  0.525236   9.568305   1.083629  1.155308   0.154200  1042.686546
4          4  0.411821  14.528877   0.657585  2.246226  10.919966  2386.467180
...      ...       ...        ...        ...       ...        ...          ...
17530  17530  0.356848  25.458940   0.057045  0.868746   1.239993   394.275778
17531  17531  0.378788  11.007832   0.120857  1.945084   2.102238   499.885284
17532  17532  0.516764  14.608856   1.135980  1.659193  12.567676  3424.308528
17533  17533  0.420978  14.142705  11.404024  1.769373  48.640313  2610.927812
17534  17534  0.581193   8.214060   0.235906  3.182487   0.903203    91.068073

[17535 rows x 7 columns]


In [7]:
# Splitting the data into training and validation sets
train_df, val_df, y_train, y_val = train_test_split(train, train['class'], test_size=0.2, random_state=42)

In [8]:
SCALER_tabular = StandardScaler()
tabular_df_train = SCALER_tabular.fit_transform(train_df[CONFIG.TABULAR_COLUMNS])
tabular_df_val = SCALER_tabular.transform(val_df[CONFIG.TABULAR_COLUMNS])
tabular_df_test = SCALER_tabular.fit_transform(test[CONFIG.TABULAR_COLUMNS])

## Data Loader

In [9]:
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

TRAIN_TRANSFORMS = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.RandomSizedCrop(
            [int(0.85*CONFIG.IMAGE_SIZE), CONFIG.IMAGE_SIZE],
            CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE, w2h_ratio=1.0, p=0.75),
        A.Resize(CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.25),
        A.ImageCompression(quality_lower=85, quality_upper=100, p=0.25),
        A.ToFloat(),
        A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
        ToTensorV2(),
    ])

TEST_TRANSFORMS = A.Compose([
        A.Resize(CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE),
        A.ToFloat(),
        A.Normalize(mean=MEAN, std=STD, max_pixel_value=1),
        ToTensorV2(),
    ])

class Dataset(Dataset):
    def __init__(self, X_jpeg_bytes, X_tabular, y, transforms=None):
        self.X_jpeg_bytes = X_jpeg_bytes
        self.X_tabular = X_tabular
        self.y = y
        self.transforms = transforms

    def __len__(self):
        return len(self.X_jpeg_bytes)

    def __getitem__(self, index):
        X_sample = self.transforms(
            image=imageio.imread(self.X_jpeg_bytes[index]),
        )
        X_sample['tabular'] = self.X_tabular[index].astype('float32')
        y_sample = torch.tensor(self.y[index])
        
        return X_sample, y_sample

if CONFIG.TRAIN_MODEL:
    # Creating datasets for training and validation
    train_dataset = Dataset(
        train_df['jpeg_bytes'].values,
        tabular_df_train,
        np.array(y_train),
        TRAIN_TRANSFORMS,
    )

    val_dataset = Dataset(
        val_df['jpeg_bytes'].values,
        tabular_df_val,
        np.array(y_val),
        TEST_TRANSFORMS,
    )

    # Creating dataloaders for training and validation
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=CONFIG.BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=psutil.cpu_count(),
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=CONFIG.BATCH_SIZE,
        shuffle=False,  # No need to shuffle validation data
        num_workers=psutil.cpu_count(),
    )



test_dataset = Dataset(
    test['jpeg_bytes'].values,
    tabular_df_test,
    test['id'].values,
    TEST_TRANSFORMS,
)

## Model

In [10]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model(
                CONFIG.BACKBONE,
                pretrained=True,
#                 num_classes=CONFIG.N_TARGETS,
                num_classes=0,
        )
        
        # Freeze the backbone weights
#         for param in self.backbone.parameters():
#             param.requires_grad = False
        
        # EfficientNet = 1280, SwinTrans = 1536, Tabular = 163
        self.custom_layers = nn.Sequential(
            nn.Linear(1280+163, 2024),
            nn.ReLU(), 
            nn.Linear(2024, CONFIG.N_CLASSES)  
        )
        
    def forward(self, inputs):
        image = inputs['image']
        tabular = inputs['tabular']

        x = self.backbone(image)
        x = torch.cat((tabular, x), dim=1)
        x = self.custom_layers(x)
        
        return x


model = Model()
model = model.to(device)
# print(model.backbone.head)

model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

## Util

In [11]:
def get_lr_scheduler(optimizer):
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=CONFIG.LR_MAX,
        total_steps=CONFIG.N_STEPS,
        pct_start=0.1,
        anneal_strategy='cos',
        div_factor=1e1,
        final_div_factor=1e1,
    )

class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val):
        self.sum += val.sum()
        self.count += val.numel()
        self.avg = self.sum / self.count

if CONFIG.TRAIN_MODEL:
    LOSS = AverageMeter()

    LOSS_FN = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(
        params=model.parameters(),
        lr=CONFIG.LR_MAX,
        weight_decay=CONFIG.WEIGHT_DECAY,
    )

    LR_SCHEDULER = get_lr_scheduler(optimizer)

## Train

In [12]:
if CONFIG.TRAIN_MODEL:
    print("Start Training:")
    
    best = float('inf')
    for epoch in range(CONFIG.N_EPOCHS):
        LOSS.reset()
        model.train()

        for step, (X_batch, y_true) in enumerate(train_dataloader):
            X_batch['image'] = X_batch['image'].to(device)
            X_batch['tabular'] = X_batch['tabular'].to(device)
            y_true = y_true.to(device)
            t_start = time.perf_counter_ns()
            y_pred = model(X_batch)
            loss = LOSS_FN(y_pred, y_true)
            LOSS.update(loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            LR_SCHEDULER.step()

            if not CONFIG.IS_INTERACTIVE and (step+1) == len(train_dataloader):
                print(
                    f'EPOCH {epoch+1:02d}, {step+1:04d}/{len(train_dataloader)} | ' + 
                    f'loss: {LOSS.avg:.4f}, ' +
                    f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                )
            elif CONFIG.IS_INTERACTIVE:
                print(
                    f'\rEPOCH {epoch+1:02d}, {step+1:04d}/{len(train_dataloader)} | ' + 
                    f'loss: {LOSS.avg:.4f}, ' +
                    f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                    end='\n' if (step + 1) == CONFIG.N_STEPS_PER_EPOCH else '', flush=True,
                )

        LOSS.reset()
        model.eval()
        print()
        
        for step, (X_batch, y_true) in enumerate(val_dataloader):
            X_batch['image'] = X_batch['image'].to(device)
            X_batch['tabular'] = X_batch['tabular'].to(device)
            y_true = y_true.to(device)
            t_start = time.perf_counter_ns()
            with torch.no_grad():
                y_pred = model(X_batch)
                loss = LOSS_FN(y_pred, y_true)
                
            LOSS.update(loss)

            if not CONFIG.IS_INTERACTIVE and (step+1) == len(val_dataloader):
                print(
                    f'EPOCH VAL {epoch+1:02d}, {step+1:04d}/{len(val_dataloader)} | ' + 
                    f'loss: {LOSS.avg:.4f}, ' +
                    f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                )
            elif CONFIG.IS_INTERACTIVE:
                print(
                    f'\rEPOCH VAL {epoch+1:02d}, {step+1:04d}/{len(val_dataloader)} | ' + 
                    f'loss: {LOSS.avg:.4f}, ' +
                    f'step: {(time.perf_counter_ns()-t_start)*1e-9:.3f}s, lr: {LR_SCHEDULER.get_last_lr()[0]:.2e}',
                    end='\n' if (step + 1) == CONFIG.N_STEPS_PER_EPOCH else '', flush=True,
                )
        print()
        if LOSS.avg < best:
            best = LOSS.avg
            torch.save(model, 'model.pth')
            
            
else:
    model = torch.load(CONFIG.MODEL_PATH)
    model.to(device)

Start Training:
EPOCH 01, 1229/1229 | loss: 9.2940, step: 0.293s, lr: 1.00e-04
EPOCH VAL 01, 0308/308 | loss: 8.3416, step: 0.090s, lr: 1.00e-04
EPOCH 02, 1229/1229 | loss: 3.9886, step: 0.294s, lr: 9.70e-05
EPOCH VAL 02, 0308/308 | loss: 7.2325, step: 0.024s, lr: 9.70e-05
EPOCH 03, 1229/1229 | loss: 1.0800, step: 0.294s, lr: 8.84e-05
EPOCH VAL 03, 0308/308 | loss: 7.6543, step: 0.024s, lr: 8.84e-05
EPOCH 04, 0065/1229 | loss: 0.3769, step: 0.294s, lr: 8.78e-05

KeyboardInterrupt: 

In [13]:
model = torch.load('/kaggle/working/model.pth')
model.to(device);

In [34]:
TARGET_COLUMNS = ['X4', 'X11', 'X18', 'X50', 'X26', 'X3112']

SUBMISSION_ROWS = []
model.eval()

for X_sample_test, test_id in tqdm(test_dataset):
    X_sample_test['image'] = torch.Tensor(X_sample_test['image']).unsqueeze(0).to(device)
    X_sample_test['tabular'] = torch.Tensor(X_sample_test['tabular']).unsqueeze(0).to(device)
    with torch.no_grad():
        y_pred = model(X_sample_test).detach().cpu().numpy()
    
    predicted_class = np.argmax(y_pred)
    predicted_attributes = lookup_table.loc[lookup_table['class'] == predicted_class, CONFIG.TARGET_COLUMNS].values.tolist()
    row = [test_id.item()] + predicted_attributes[0] 
    
    SUBMISSION_ROWS.append(row)

submission_df = pd.DataFrame(SUBMISSION_ROWS, columns=['id'] + TARGET_COLUMNS)
submission_df.to_csv('submission.csv', index=False)
print("Submit!")

  0%|          | 0/6545 [00:00<?, ?it/s]

Submit!
