In [1]:
!pip install catboost -q
!pip install lightgbm -q
!pip install gdown -q

In [2]:
import os
import cv2
import gdown
import random
import zipfile
import numpy as np
from tqdm import tqdm
import scipy.stats as st
from sklearn.svm import SVC
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

In [3]:
RANDOM_STATE = 42
random.seed(RANDOM_STATE)

In [4]:
DATASET_DIR = os.path.join(os.getcwd(), "dataset")
TRAIN_DIR = os.path.join(DATASET_DIR, "train")
TEST_DIR = os.path.join(DATASET_DIR, "test")

TEMP_DIR = os.path.join(os.getcwd(), "temp")
TEMP_TRAIN_DIR = os.path.join(TEMP_DIR, "train")
TEMP_TEST_DIR = os.path.join(TEMP_DIR, "test")

ZIP_PATH = os.path.join(os.getcwd(), "dataset_32_classes.zip")
os.makedirs(DATASET_DIR, exist_ok=True)

In [5]:
# file_id = "1EbvcZbzVXSmB2N1SZYNeUfUuXb8wp3-k"
file_id = "1-1ehpRd0TnwB1hTHQbFHzdf55SrIri4f"
if os.path.exists(ZIP_PATH):
    print("Архив уже добавлен")
else:
    gdown.download(
        f"https://drive.google.com/uc?id={file_id}", os.path.join(os.getcwd(), "dataset_32_classes.zip"), quiet=False
    )

    zip_file_name = "dataset_32_classes.zip"
    os.makedirs("dataset", exist_ok=True)

    with zipfile.ZipFile(zip_file_name, "r") as zip_ref:
        zip_ref.extractall("dataset")

Архив уже добавлен


Сделаем ограничение на количество дискрипторов в 128 шт. По опыту других ниболее оптимальный выбор. И так как в картинках может встретиться разное количество признаков, то выполним преобразование:


*   Если количество дескрипторов меньше num_features, функция вычисляет среднее значение всех дескрипторов и присваивает его вектору признако
*   Если количество дескрипторов больше или равно num_features, функция берет только первые num_features дескрипторов и вычисляет их среднее значение


In [6]:
def get_SIFT_descriptors(img):
    sift = cv2.SIFT_create()
    keypoints, descriptors = sift.detectAndCompute(img, None)
    return descriptors

In [7]:
def create_feature_vector(descriptors, num_features=128):
    feature_vector = np.zeros(num_features)

    if descriptors is not None and len(descriptors) > 0:
        if descriptors.shape[0] < num_features:
            feature_vector = np.mean(descriptors, axis=0)
        else:
            feature_vector = np.mean(descriptors[:num_features], axis=0)

    return feature_vector

Функция предназначена для перебора всех картинок с датасета, преобразования их к типу SIFT и получения дискрипторов и названий классов.

In [8]:
def analyze_dataset(image_folder, size_img):
    features = []
    labels = []

    for class_name in os.listdir(image_folder):
        class_path = os.path.join(image_folder, class_name)

        if os.path.isdir(class_path):
            for filename in tqdm(os.listdir(class_path), desc=f"Обработка {class_name}", unit="image"):
                if filename.endswith(".jpg") or filename.endswith(".jpeg"):
                    image_path = os.path.join(class_path, filename)
                    img = cv2.imread(image_path)
                    img = cv2.resize(img, (size_img, size_img), interpolation=cv2.INTER_CUBIC)
                    descriptors = get_SIFT_descriptors(img)

                    feature_vector = create_feature_vector(descriptors)
                    features.append(feature_vector)
                    labels.append(class_name)

    features_array = np.array(features)
    labels_array = np.array(labels)

    return features_array, labels_array

# Обучение моделей

*   Получим дескрипторы с каждого изображения и названия классов с размером изображений 64 px
*   Обучим нелинейные модели RandomForest, LightGBM, CatBoost.
*   Подберем гиперпараметры для моделей c помощью гридсерча
*   Посмотрим метрики на различных модельках





In [9]:
X_train, y_train = analyze_dataset(TRAIN_DIR, 64)
X_test, y_test = analyze_dataset(TEST_DIR, 64)

Обработка Apple: 100%|██████████| 1120/1120 [00:01<00:00, 845.05image/s]
Обработка Avocado: 100%|██████████| 1120/1120 [00:01<00:00, 811.47image/s]
Обработка Banana: 100%|██████████| 1120/1120 [00:01<00:00, 838.34image/s]
Обработка Bean: 100%|██████████| 1120/1120 [00:01<00:00, 567.88image/s]
Обработка Bitter_Gourd: 100%|██████████| 1120/1120 [00:02<00:00, 540.52image/s]
Обработка Bottle_Gourd: 100%|██████████| 1120/1120 [00:01<00:00, 572.35image/s]
Обработка Brinjal: 100%|██████████| 1120/1120 [00:02<00:00, 550.00image/s]
Обработка Broccoli: 100%|██████████| 1120/1120 [00:02<00:00, 549.51image/s]
Обработка Cabbage: 100%|██████████| 1120/1120 [00:01<00:00, 576.08image/s]
Обработка Capsicum: 100%|██████████| 1120/1120 [00:01<00:00, 606.99image/s]
Обработка Carrot: 100%|██████████| 1120/1120 [00:01<00:00, 618.57image/s]
Обработка Cauliflower: 100%|██████████| 1120/1120 [00:01<00:00, 586.36image/s]
Обработка Cherry: 100%|██████████| 1120/1120 [00:01<00:00, 894.61image/s]
Обработка Cucumbe

## Проверка бейзлайна

In [10]:
svc = SVC()

In [11]:
svc.fit(X_train, y_train)
svc_pred = svc.predict(X_test)

In [12]:
print(classification_report(y_test, svc_pred))

              precision    recall  f1-score   support

       Apple       0.47      0.35      0.40       280
     Avocado       0.67      0.59      0.63       280
      Banana       0.74      0.73      0.74       280
        Bean       0.52      0.71      0.60       280
Bitter_Gourd       0.51      0.70      0.59       280
Bottle_Gourd       0.59      0.76      0.67       280
     Brinjal       0.50      0.41      0.45       280
    Broccoli       0.46      0.59      0.52       280
     Cabbage       0.40      0.53      0.46       280
    Capsicum       0.50      0.64      0.56       280
      Carrot       0.74      0.66      0.70       280
 Cauliflower       0.67      0.81      0.73       280
      Cherry       0.71      0.75      0.73       280
    Cucumber       0.70      0.47      0.56       280
       Grape       0.70      0.73      0.72       280
        Kiwi       0.64      0.63      0.64       280
       Mango       0.73      0.50      0.59       280
         Nut       0.69    

## Проверка лучшей модели с SVC

In [13]:
svc_best = SVC(C=8.1, kernel="rbf")

In [14]:
svc_best.fit(X_train, y_train)
svc_best_pred = svc_best.predict(X_test)

In [15]:
print(classification_report(y_test, svc_best_pred))

              precision    recall  f1-score   support

       Apple       0.58      0.50      0.54       280
     Avocado       0.75      0.75      0.75       280
      Banana       0.78      0.78      0.78       280
        Bean       0.63      0.75      0.69       280
Bitter_Gourd       0.61      0.73      0.67       280
Bottle_Gourd       0.68      0.85      0.76       280
     Brinjal       0.55      0.57      0.56       280
    Broccoli       0.51      0.73      0.60       280
     Cabbage       0.50      0.61      0.55       280
    Capsicum       0.58      0.74      0.65       280
      Carrot       0.81      0.80      0.81       280
 Cauliflower       0.75      0.84      0.79       280
      Cherry       0.79      0.84      0.81       280
    Cucumber       0.77      0.61      0.68       280
       Grape       0.82      0.84      0.83       280
        Kiwi       0.74      0.70      0.72       280
       Mango       0.81      0.67      0.74       280
         Nut       0.85    

## Проверка RandomForest

In [16]:
rf = RandomForestClassifier(random_state=RANDOM_STATE)

rf.fit(X_train, y_train)
rf_pred = rf.predict(X_test)
print(classification_report(y_test, rf_pred))

              precision    recall  f1-score   support

       Apple       0.43      0.38      0.40       280
     Avocado       0.75      0.67      0.71       280
      Banana       0.68      0.77      0.72       280
        Bean       0.54      0.72      0.61       280
Bitter_Gourd       0.47      0.70      0.56       280
Bottle_Gourd       0.52      0.72      0.60       280
     Brinjal       0.40      0.36      0.38       280
    Broccoli       0.40      0.51      0.45       280
     Cabbage       0.38      0.42      0.40       280
    Capsicum       0.48      0.56      0.52       280
      Carrot       0.74      0.68      0.71       280
 Cauliflower       0.63      0.76      0.69       280
      Cherry       0.70      0.74      0.72       280
    Cucumber       0.81      0.52      0.63       280
       Grape       0.78      0.75      0.76       280
        Kiwi       0.63      0.62      0.63       280
       Mango       0.84      0.63      0.72       280
         Nut       0.80    

Подберу гиперпаматеры, попробую улучшить метрики

In [17]:
rf = RandomForestClassifier(random_state=RANDOM_STATE)

params = {
    "n_estimators": [100, 200, 500],
    "max_features": ["sqrt", "log2", None],
    "max_depth": [None, 4, 6, 8],
    "criterion": ["gini", "entropy", "log_loss"],
}

rs_rf = RandomizedSearchCV(rf, params, cv=3)

In [18]:
rs_rf.fit(X_train, y_train)
rs_rf.best_params_

{'n_estimators': 200,
 'max_features': 'sqrt',
 'max_depth': None,
 'criterion': 'entropy'}

In [19]:
rf = RandomForestClassifier(random_state=RANDOM_STATE, **rs_rf.best_params_)

rf.fit(X_train, y_train)
rf_pred = rf.predict(X_test)
print(classification_report(y_test, rf_pred))

              precision    recall  f1-score   support

       Apple       0.69      0.38      0.49       280
     Avocado       0.83      0.69      0.75       280
      Banana       0.68      0.80      0.73       280
        Bean       0.61      0.75      0.67       280
Bitter_Gourd       0.50      0.72      0.59       280
Bottle_Gourd       0.54      0.76      0.63       280
     Brinjal       0.62      0.43      0.51       280
    Broccoli       0.41      0.56      0.47       280
     Cabbage       0.47      0.47      0.47       280
    Capsicum       0.56      0.63      0.59       280
      Carrot       0.78      0.69      0.73       280
 Cauliflower       0.62      0.79      0.69       280
      Cherry       0.79      0.80      0.79       280
    Cucumber       0.87      0.57      0.69       280
       Grape       0.85      0.79      0.82       280
        Kiwi       0.70      0.66      0.68       280
       Mango       0.91      0.69      0.78       280
         Nut       0.84    

## Проверка LightGBM

In [20]:
lgbm = LGBMClassifier(random_state=RANDOM_STATE)
lgbm.fit(X_train, y_train)
lgbm_pred = lgbm.predict(X_test)

[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.009874 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 32640
[LightGBM] [Info] Number of data points in the train set: 35840, number of used features: 128
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightG

In [21]:
print(classification_report(y_test, lgbm_pred))

              precision    recall  f1-score   support

       Apple       0.49      0.49      0.49       280
     Avocado       0.82      0.69      0.75       280
      Banana       0.80      0.78      0.79       280
        Bean       0.65      0.72      0.68       280
Bitter_Gourd       0.61      0.66      0.63       280
Bottle_Gourd       0.68      0.82      0.74       280
     Brinjal       0.53      0.55      0.54       280
    Broccoli       0.49      0.65      0.56       280
     Cabbage       0.45      0.56      0.50       280
    Capsicum       0.55      0.65      0.60       280
      Carrot       0.82      0.76      0.79       280
 Cauliflower       0.76      0.77      0.76       280
      Cherry       0.78      0.75      0.77       280
    Cucumber       0.75      0.66      0.70       280
       Grape       0.87      0.81      0.84       280
        Kiwi       0.70      0.65      0.67       280
       Mango       0.76      0.65      0.70       280
         Nut       0.86    

Подберу гиперпаматеры, попробую улучшить метрики

In [22]:
lgbm = LGBMClassifier(random_state=RANDOM_STATE)

params = {
    "min_child_samples": range(5, 101),
    "num_leaves": range(2, 257),
    "reg_alpha": st.loguniform(1e-8, 10.0),
    "reg_lambda": st.loguniform(1e-8, 10.0),
}

rs_lgbm = RandomizedSearchCV(lgbm, params, cv=3)

In [23]:
rs_lgbm.fit(X_train, y_train)

[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.005560 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 32640
[LightGBM] [Info] Number of data points in the train set: 23893, number of used features: 128
[LightGBM] [Info] Start training from score -3.466615
[LightGBM] [Info] Start training from score -3.465276
[LightGBM] [Info] Start training from score -3.465276
[LightGBM] [Info] Start training from score -3.466615
[LightGBM] [Info] Start training from score -3.465276
[LightGBM] [Info] Start training from score -3.465276
[LightGBM] [Info] Start training from score -3.466615
[LightGBM] [Info] Start training from score -3.465276
[LightGBM] [Info] Start training from score -3.465276
[LightGBM] [Info] Start training from score -3.466615
[LightGBM] [Info] Start training from score -3.465276
[LightGBM] [Info] Start training from score -3.465276
[LightGBM] [Info] Start training from score -3.466615
[LightG

In [24]:
rs_lgbm.best_params_

{'min_child_samples': 66,
 'num_leaves': 165,
 'reg_alpha': 0.0009335425359470147,
 'reg_lambda': 0.0007403650842074439}

In [25]:
lgbm_best = LGBMClassifier(**rs_lgbm.best_params_, random_state=RANDOM_STATE)
lgbm_best.fit(X_train, y_train)
lgbm_best_pred = lgbm_best.predict(X_test)

[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.006403 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 32640
[LightGBM] [Info] Number of data points in the train set: 35840, number of used features: 128
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightGBM] [Info] Start training from score -3.465736
[LightG

In [26]:
print(classification_report(y_test, lgbm_best_pred))

              precision    recall  f1-score   support

       Apple       0.64      0.46      0.53       280
     Avocado       0.86      0.72      0.79       280
      Banana       0.80      0.79      0.80       280
        Bean       0.65      0.79      0.71       280
Bitter_Gourd       0.62      0.73      0.67       280
Bottle_Gourd       0.70      0.85      0.77       280
     Brinjal       0.60      0.57      0.59       280
    Broccoli       0.50      0.70      0.59       280
     Cabbage       0.53      0.57      0.55       280
    Capsicum       0.59      0.72      0.65       280
      Carrot       0.87      0.79      0.83       280
 Cauliflower       0.71      0.83      0.76       280
      Cherry       0.83      0.81      0.82       280
    Cucumber       0.86      0.68      0.76       280
       Grape       0.88      0.85      0.87       280
        Kiwi       0.73      0.70      0.71       280
       Mango       0.89      0.72      0.80       280
         Nut       0.88    

## Проверка CatBoost

In [27]:
cat = CatBoostClassifier(logging_level="Silent", random_state=RANDOM_STATE)
cat.fit(X_train, y_train)
cat_pred = cat.predict(X_test)

In [28]:
print(classification_report(y_test, cat_pred))

              precision    recall  f1-score   support

       Apple       0.58      0.43      0.49       280
     Avocado       0.80      0.72      0.76       280
      Banana       0.76      0.78      0.77       280
        Bean       0.67      0.77      0.72       280
Bitter_Gourd       0.63      0.75      0.68       280
Bottle_Gourd       0.66      0.81      0.73       280
     Brinjal       0.59      0.54      0.56       280
    Broccoli       0.54      0.74      0.62       280
     Cabbage       0.56      0.59      0.57       280
    Capsicum       0.61      0.70      0.65       280
      Carrot       0.82      0.76      0.79       280
 Cauliflower       0.76      0.82      0.79       280
      Cherry       0.81      0.82      0.82       280
    Cucumber       0.79      0.65      0.72       280
       Grape       0.83      0.82      0.82       280
        Kiwi       0.67      0.67      0.67       280
       Mango       0.84      0.64      0.72       280
         Nut       0.83    

Подберу гиперпаматеры, попробую улучшить метрики

In [29]:
cat = CatBoostClassifier(logging_level="Silent", random_state=RANDOM_STATE, gpu_ram_part=0.9, task_type="GPU")

params = {
    "min_child_samples": range(5, 101),
    "learning_rate": [0.001, 0.01, 0.03, 0.1],
    "depth": [4, 6, 10],
    "reg_lambda": st.loguniform(1e-8, 10.0),
}

rs_cat = RandomizedSearchCV(cat, params, cv=3)

In [30]:
rs_cat.fit(X_train, y_train)
rs_cat.best_params_

{'depth': 10,
 'learning_rate': 0.1,
 'min_child_samples': 63,
 'reg_lambda': 7.771853478232728e-05}

In [31]:
cat_best = CatBoostClassifier(**rs_cat.best_params_, random_state=RANDOM_STATE, gpu_ram_part=0.9, task_type="GPU")

cat_best.fit(X_train, y_train)
cat_best_pred = cat_best.predict(X_test)

0:	learn: 3.0828450	total: 248ms	remaining: 4m 7s
1:	learn: 2.8513500	total: 501ms	remaining: 4m 10s
2:	learn: 2.6804868	total: 723ms	remaining: 4m
3:	learn: 2.5329670	total: 964ms	remaining: 3m 59s
4:	learn: 2.4215271	total: 1.18s	remaining: 3m 55s
5:	learn: 2.2978099	total: 1.43s	remaining: 3m 57s
6:	learn: 2.2063620	total: 1.66s	remaining: 3m 55s
7:	learn: 2.1102014	total: 1.9s	remaining: 3m 55s
8:	learn: 2.0233865	total: 2.14s	remaining: 3m 55s
9:	learn: 1.9482193	total: 2.37s	remaining: 3m 54s
10:	learn: 1.8931747	total: 2.58s	remaining: 3m 52s
11:	learn: 1.8291587	total: 2.83s	remaining: 3m 52s
12:	learn: 1.7532966	total: 3.08s	remaining: 3m 53s
13:	learn: 1.7027422	total: 3.31s	remaining: 3m 53s
14:	learn: 1.6340595	total: 3.56s	remaining: 3m 54s
15:	learn: 1.5979207	total: 3.77s	remaining: 3m 51s
16:	learn: 1.5428044	total: 4.02s	remaining: 3m 52s
17:	learn: 1.4985643	total: 4.26s	remaining: 3m 52s
18:	learn: 1.4476803	total: 4.52s	remaining: 3m 53s
19:	learn: 1.4217834	total: 

In [32]:
print(classification_report(y_test, cat_best_pred))

              precision    recall  f1-score   support

       Apple       0.60      0.52      0.56       280
     Avocado       0.89      0.72      0.80       280
      Banana       0.81      0.80      0.80       280
        Bean       0.68      0.80      0.73       280
Bitter_Gourd       0.59      0.73      0.65       280
Bottle_Gourd       0.64      0.84      0.72       280
     Brinjal       0.57      0.51      0.54       280
    Broccoli       0.50      0.67      0.57       280
     Cabbage       0.47      0.57      0.52       280
    Capsicum       0.55      0.71      0.62       280
      Carrot       0.84      0.75      0.79       280
 Cauliflower       0.72      0.81      0.77       280
      Cherry       0.83      0.83      0.83       280
    Cucumber       0.87      0.68      0.76       280
       Grape       0.87      0.84      0.85       280
        Kiwi       0.71      0.69      0.70       280
       Mango       0.89      0.69      0.78       280
         Nut       0.88    

# Таблица результатов

| Модель         | Гиперпараметры                                                                    | Размер изображения | Цветное | accuracy Test |
|----------------|-----------------------------------------------------------------------------------|--------------------|---------|---------------|
| RandomForest   | n_estimators: 100, criterion: "gini", max_depth: None, max_features: "sqrt"       | 64px               | цветное | 0.6           |
| RandomForest   | n_estimators: 200, criterion: "entropy", max_depth: None, max_features: "sqrt"    | 64px               | цветное | 0\.65         |
| LightGBM       | min_child_samples: 20, num_leaves: 31, reg_alpha: 0, reg_lambda': 0               | 64px               | цветное | 0\.67         |
| LightGBM       | min_child_samples: 66, num_leaves: 165, reg_alpha: 0.00093, reg_lambda': 0.00074  | 64px               | цветное | 0\.71         |
| CatBoost       | depth: None, learning_rate: None, min_child_samples: None, 'reg_lambda': None     | 64px               | цветное | 0\.69         |
| CatBoost       | depth: 10, learning_rate: 0.1, min_child_samples: 63, 'reg_lambda': 7.77e-05      | 64px               | цветное | 0\.7          |

# Выводы:

1.   Наилучшие результаты из нелинейных моделей дает LightGBM с подобранными гиперпараметрами, accuracy = 0.71
2.   Также хороший результат дает CatBoost с подобранными гиперпараметрами, accuracy = 0.7
