In [1]:
import torch
import cv2
import numpy as np
import pandas as pd
import joblib
import imageio.v3 as imageio

from torch import nn
from tqdm import tqdm
from move import move_to
from torchmetrics.regression import R2Score, MeanAbsoluteError
from train import Compile
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset
from dataloader.transformers import TRANSFORMER, TEST_TRANSFORMER
from models.effnet import CustomEffnet

In [2]:
class TestDataset(Dataset):
    """Dataset over the test images that yields ``(image_tensor, sample_id)``.

    Each image is decoded from its encoded bytes, cropped to the bounding box
    registered for its id when that box is usable, passed through
    ``TEST_TRANSFORMER`` and returned with a leading batch dimension.

    Args:
        X_jpeg_bytes: Indexable collection of encoded images, one per sample.
        x_features: Frame whose column labels are kept on ``self.xs_cols``.
        y: Indexable collection of sample ids, aligned with ``X_jpeg_bytes``.
        boxes_path: CSV with an ``id`` index column and a ``box`` column that
            stores four numbers as text. The crop uses them as
            ``(col_start, row_start, col_stop, row_stop)``.
    """

    def __init__(self, X_jpeg_bytes, x_features, y, boxes_path='../data/2024/boxes_test.csv'):
        self.X_jpeg_bytes = X_jpeg_bytes
        self.y = y
        self.transforms = TEST_TRANSFORMER
        self.xs_cols = x_features.columns
        self.boxes = pd.read_csv(boxes_path, index_col='id')

        self.boxes['box'] = self.boxes['box'].apply(self._parse_box)

    @staticmethod
    def _parse_box(text):
        """Turn the text form of a box (e.g. ``'[1. 2.\\n 3. 4.]'``) into a float array."""
        # Newlines and brackets become spaces so that a line break between two
        # numbers can never glue them together; runs of whitespace are
        # tolerated by the whitespace separator.
        cleaned = text.replace('\n', ' ').replace('[', ' ').replace(']', ' ').strip()
        return np.fromstring(cleaned, sep=' ')

    def __len__(self):
        """Number of samples."""
        return len(self.X_jpeg_bytes)

    def __getitem__(self, index):
        """Return ``(tensor with a leading batch axis, sample id)`` for ``index``.

        Raises:
            IndexError: when ``index`` is past the end of the data. Plain
                ``for`` iteration over the dataset relies on this to stop.
        """
        # Both lookups stay outside the try block: an out-of-range index must
        # surface as IndexError instead of being treated as a cropping problem.
        y_sample = self.y[index]
        image = imageio.imread(self.X_jpeg_bytes[index])

        try:
            box = self.boxes.loc[y_sample, 'box']
            X_sample = self.transforms(
                image=image[int(box[1]):int(box[3]), int(box[0]):int(box[2])])['image']
        except Exception:
            # Best-effort fallback (missing id, malformed box, transform failing
            # on the crop): use the whole image. Catching Exception rather than
            # everything lets KeyboardInterrupt / SystemExit propagate.
            X_sample = self.transforms(image=image)['image']

        # move_to is a project helper; presumably it transfers the tensor to the
        # named device - TODO confirm.
        return move_to(X_sample, 'cuda').unsqueeze(0), y_sample

In [3]:
def predict_test(checkpoint):
    """Predict the six traits for every test sample with a saved model.

    Args:
        checkpoint: Path handed to ``torch.load``; the loaded object must
            contain the weights under ``'model_state_dict'``.

    Returns:
        pandas.DataFrame with one row per test sample: an ``id`` column plus
        one column per target, named after the target without its ``_mean``
        suffix. Targets listed as log-scaled are returned as ``10 ** value``.
    """
    target_names = ['X4_mean', 'X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean']
    log_scaled = {'X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean'}

    # Rebuild the network and restore the trained weights.
    network = CustomEffnet()
    saved = torch.load(checkpoint)
    network.load_state_dict(saved['model_state_dict'])

    # Test table, the fitted scaler whose transform is undone below, and the dataset.
    test_frame = pd.read_pickle('../data/2024/processed/test.pkl')
    scaler = joblib.load('../data/2024/processed/scaler.joblib')
    dataset = TestDataset(test_frame['jpeg_bytes'].values, test_frame, test_frame['id'].values)

    network.eval()
    network.cuda()

    records = []
    for image, sample_id in tqdm(dataset):
        with torch.no_grad():
            raw = network(image).detach().cpu().numpy()

        # Back to the unscaled space; only the first six values are used.
        restored = scaler.inverse_transform(raw).squeeze()[:6]

        record = {'id': sample_id}
        for name, value in zip(target_names, restored):
            short_name = name.replace('_mean', '')
            record[short_name] = 10 ** value if name in log_scaled else value

        records.append(record)

    return pd.DataFrame(records)

In [None]:
# Cross-validation configuration: the fixed seed makes the shuffled fold
# assignment repeatable.
SEED = 2024
folds = KFold(
    n_splits=5,
    shuffle=True,
    random_state=SEED,
)

# Training table; the fold loop further down reads it and overwrites part of
# it in place.
data = pd.read_csv('../data/2024/processed/train.csv')

# Container for test-set predictions.
test_preds = []


class PlantDataset(Dataset):
    """Image dataset yielding ``(image, targets)`` pairs for training/validation.

    Args:
        df: Frame with an ``id`` column, the six target columns listed in
            ``self.columns`` and a ``box`` column holding four numbers as
            text. Rows are looked up with ``.loc[idx]``, so the index must be
            ``0..len(df)-1``.
        transformer: Optional callable invoked as ``transformer(image=img)``
            that returns a mapping with an ``'image'`` entry.

    Note:
        ``df`` is modified in place: its ``box`` column is removed and kept,
        parsed, on ``self.boxes``. Code that uses the frame after building the
        dataset sees it without ``box``.
    """

    def __init__(self, df, transformer=None):
        self.df = df
        self.columns = ['X4_mean', 'X11_mean', 'X18_mean', 'X50_mean', 'X26_mean', 'X3112_mean']
        self.dir = '../data/2024/train_images/'
        # pop() deliberately mutates the caller's frame (see the class note).
        self.boxes = self.df.pop('box').apply(self._parse_box)
        self.transform = transformer

    @staticmethod
    def _parse_box(text):
        """Turn the text form of a box (e.g. ``'[1. 2.\\n 3. 4.]'``) into a float array."""
        # Newlines and brackets become spaces so that a line break between two
        # numbers can never glue them together.
        cleaned = text.replace('\n', ' ').replace('[', ' ').replace(']', ' ').strip()
        return np.fromstring(cleaned, sep=' ')

    def __len__(self):
        """Number of rows in the frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, crop and optionally transform the image of row ``idx``.

        Returns:
            ``(img, y)`` where ``img`` is the cropped RGB image (or whatever
            the transformer produced from it) and ``y`` is a float32 tensor
            with the six targets.

        Raises:
            IndexError: when ``idx`` is past the end of the data.
            FileNotFoundError: when the image file cannot be read.
        """
        if idx >= len(self):
            raise IndexError(f'index {idx} is out of range for {len(self)} samples')

        img_id = self.df.loc[idx, 'id']
        # An explicit dtype keeps the conversion working even when the row
        # comes back as an object-dtype Series.
        targets = self.df.loc[idx, self.columns].to_numpy(dtype=np.float32)
        y = torch.tensor(targets, dtype=torch.float32)

        path = f'{self.dir}/{img_id}.jpeg'
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'could not read image file: {path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # The box is used as (col_start, row_start, col_stop, row_stop).
        box = self.boxes.loc[idx]
        img = img[int(box[1]):int(box[3]), int(box[0]):int(box[2])]

        if self.transform is not None:
            img = self.transform(image=img)['image']

        return img, y

In [4]:
# Repeated K-fold training. After each fold the validation rows of `data` get
# their target columns replaced by a 0.7 / 0.3 blend of the current targets
# and the fold model's predictions, so later folds and rounds train on the
# blended values.
N_ROUNDS = 3     # full passes over all folds
BATCH_SIZE = 14  # mini-batch size of both loaders

for round_id in range(N_ROUNDS):
    for f_id, (t_idx, v_idx) in enumerate(folds.split(data)):
        train_data = data.iloc[t_idx, :].reset_index(drop=True)
        val_data = data.iloc[v_idx, :].reset_index(drop=True)

        # Columns written back below: everything between the first and the
        # last column. Assumes the layout is id, targets..., box - TODO confirm.
        target_cols = data.columns[1:-1]

        train_dataset = PlantDataset(train_data, TRANSFORMER)
        valid_dataset = PlantDataset(val_data, TEST_TRANSFORMER)

        train_dataloader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, drop_last=True)
        val_dataloader = DataLoader(valid_dataset, BATCH_SIZE, shuffle=False, drop_last=False)

        # A fresh model per fold; nothing is carried over between folds.
        model = CustomEffnet()

        # Compile is project code. The four positional numbers are passed
        # through unchanged; presumably learning rate, weight decay, epochs
        # and batch size - TODO confirm against train.Compile.
        complied = Compile(model,
                           nn.MSELoss,
                           torch.optim.AdamW,
                           1e-5,
                           1e-4,
                           10,
                           14,
                           train_loader=train_dataloader,
                           save_to=f'best_checkpoint_f_{f_id}.pth',
                           val_loader=val_dataloader,
                           metrics={'r2': R2Score(6).cuda(),
                                    'mae': MeanAbsoluteError().cuda()})

        complied.fit()
        model.eval()

        # Predict the validation fold one sample at a time. Assumes fit()
        # leaves the model on the GPU - TODO confirm.
        pred_val_fold = []
        for x, _target in tqdm(valid_dataset):
            with torch.no_grad():
                y = model(x.cuda().unsqueeze(0)).detach().cpu().numpy()

            pred_val_fold.append(y.squeeze(0))

        # v_idx holds the exact row positions of this fold, so write back by
        # position instead of re-deriving the rows from id membership (which
        # breaks on duplicate ids or a non-default index). Columns are picked
        # by name and the values passed as a plain array so nothing depends on
        # index alignment or on which columns val_data still carries.
        blended = 0.7 * val_data[target_cols].to_numpy() + 0.3 * np.asarray(pred_val_fold)
        data.iloc[v_idx, 1:-1] = blended

        predicted_test = predict_test(f'./best_checkpoint_f_{f_id}.pth')
        # Keep every fold's test predictions; `predicted_test` alone only
        # retains the most recent one.
        test_preds.append(predicted_test)

100%|██████████| 8655/8655 [03:22<00:00, 42.71it/s]
100%|██████████| 6545/6545 [02:34<00:00, 42.38it/s]


In [5]:
# Show the test predictions returned by the most recent predict_test() call
# of the training loop (the variable is reassigned on every fold).
predicted_test

Unnamed: 0,id,X4,X11,X18,X50,X26,X3112
0,201238668,0.506703,12.684425,0.830860,1.506088,2.338350,359.560772
1,202310319,0.450098,17.277229,0.367876,1.488224,0.959753,464.126863
2,202604412,0.623843,10.397853,1.436415,2.171376,4.843397,382.602407
3,201353439,0.461603,16.264346,0.423372,1.207422,0.435874,394.140362
4,195351745,0.479500,12.647679,0.181226,1.541777,0.510777,144.144784
...,...,...,...,...,...,...,...
6540,195548469,0.517062,10.396581,0.596170,1.796500,1.858637,196.254767
6541,199261251,0.675147,13.105791,4.437296,1.500954,7.251051,1985.673451
6542,203031744,0.501308,14.724135,0.689901,1.420328,2.059553,588.073152
6543,197736382,0.397435,21.128257,0.383783,1.299778,0.554640,300.732035
