In [8]:
import os
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import KNeighborsClassifier
from sklearn import svm
import tensorflow as tf
from tensorflow.keras.applications import resnet

# Size passed to tf.image.resize in preprocess_image; every loaded image is resized to it.
target_shape = (300, 300)

In [9]:
def preprocess_image(filename):
    """Load one JPEG file and return it as a resized float32 image tensor.

    The file is read from disk, decoded as a 3-channel JPEG, converted to
    float32 with `tf.image.convert_image_dtype`, and finally resized to the
    module-level `target_shape`.
    """
    raw_bytes = tf.io.read_file(filename)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    return tf.image.resize(as_float, target_shape)

In [10]:
# Root folders of the training and test splits; read_image_label_name_from_path
# looks for one sub-folder per class name inside each of them.
training_data_path = r'./data/plant_dataset/training_set'
test_data_path = r'./data/plant_dataset/test_set'
# Class sub-folder names; the same strings are used as the (string) labels.
CLASSES = ['burn_disease', 'healthy', 'leafspot']


def read_image_label_name_from_path(data_path, class_names):
    """Collect image file paths and their class labels from a dataset folder.

    Args:
        data_path: Root directory containing one sub-directory per class.
        class_names: Names of the class sub-directories to scan, in order.

    Returns:
        A tuple ``(images, labels)`` of equal-length lists. ``images`` holds
        the path of every regular file found under each class directory and
        ``labels`` holds the class name of the corresponding file. Files are
        listed class by class, sorted by file name within each class.
    """
    images, labels = [], []
    for class_name in class_names:
        class_dir = os.path.join(data_path, class_name)
        # os.listdir returns entries in arbitrary order, so sort them to get a
        # reproducible ordering. Entries that are not regular files (e.g. a
        # `.ipynb_checkpoints` directory) are skipped because they cannot be
        # decoded as images later on.
        candidate_paths = (
            os.path.join(class_dir, entry) for entry in sorted(os.listdir(class_dir))
        )
        image_paths = [path for path in candidate_paths if os.path.isfile(path)]
        images.extend(image_paths)
        labels.extend([class_name] * len(image_paths))
    return images, labels


# Read the training and test image paths together with their corresponding labels
train_images, train_labels = read_image_label_name_from_path(training_data_path, CLASSES)
test_images, test_labels = read_image_label_name_from_path(test_data_path, CLASSES)

In [11]:
# Build the tf.data pipelines for the training and test sets:
# file paths -> decoded/resized images -> batches of 32 (last partial batch kept)
# -> prefetch of 8 batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .map(preprocess_image)
    .batch(32, drop_remainder=False)
    .prefetch(8)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices(test_images)
    .map(preprocess_image)
    .batch(32, drop_remainder=False)
    .prefetch(8)
)

In [12]:
# Load the feature extractor (embedding model) that was fine-tuned in the previous step
embedding = tf.keras.models.load_model("embeddings/embeddings_e25_v1")



In [13]:
# Run every batch of the training and test datasets through the embedding
# model and keep the per-batch feature arrays in lists.
train_embeds = [
    embedding(resnet.preprocess_input(batch)).numpy() for batch in train_dataset
]
test_embeds = [
    embedding(resnet.preprocess_input(batch)).numpy() for batch in test_dataset
]

In [14]:
# Stack the per-batch embedding arrays row-wise into one array per split.
train_embeds, test_embeds = np.vstack(train_embeds), np.vstack(test_embeds)

In [15]:
# Fit the label encoder on the training labels, then encode both splits as integers.
le = LabelEncoder().fit(train_labels)
train_labels = le.transform(train_labels)
test_labels = le.transform(test_labels)

In [16]:
from sklearn.preprocessing import StandardScaler

In [17]:
# Standardize the embeddings with a StandardScaler whose statistics are
# fitted on the training embeddings only, then applied to both splits.
scaler = StandardScaler()
scaler.fit(train_embeds)
std_train_embeds, std_test_embeds = (
    scaler.transform(split) for split in (train_embeds, test_embeds)
)

In [18]:
# Shuffle the training embeddings and labels with one shared permutation so
# every row stays paired with its label. The generator is seeded so the
# shuffled order is the same after a kernel restart.
SHUFFLE_SEED = 42
shuffled_idx = np.random.default_rng(SHUFFLE_SEED).permutation(len(std_train_embeds))
std_train_embeds = std_train_embeds[shuffled_idx]
train_labels = train_labels[shuffled_idx]

In [22]:
# Sanity check: print the shape of the features, then of the labels.
for checked_array in (std_train_embeds, train_labels):
    print(checked_array.shape)

(1284, 256)
(1284,)


In [61]:
from sklearn.svm import SVC
from sklearn.multiclass import OneVsRestClassifier
# Train a one-vs-rest SVM on the standardized training embeddings.
ovr_clf = OneVsRestClassifier(SVC()).fit(std_train_embeds, train_labels)
# NOTE: this is the accuracy on the training data itself, not on held-out data.
ovr_clf.score(std_train_embeds, train_labels)

0.8761682242990654

In [62]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
# Train a one-vs-rest RandomForest on the standardized training embeddings.
# This rebinds `ovr_clf`, replacing whatever classifier the name held before.
# A fixed random_state makes the fitted forest, and therefore the predictions
# derived from it, reproducible across runs.
ovr_clf = OneVsRestClassifier(RandomForestClassifier(random_state=42))
ovr_clf.fit(std_train_embeds, train_labels)
# NOTE: this is the accuracy on the training data itself, not on held-out data.
ovr_clf.score(std_train_embeds, train_labels)

1.0

In [63]:
# Predict the test set with the classifier currently bound to `ovr_clf`; the
# outputs are the integer class codes produced by the label encoder `le`.
predict_result = ovr_clf.predict(std_test_embeds)

In [64]:
import pandas as pd
# One row per test image: its file path and the predicted class code.
df = pd.DataFrame({'Name': test_images, 'pred': predict_result})

In [65]:
def _class_and_filename(path):
    """Return the last two components of `path` joined by a backslash."""
    components = path.split(os.path.sep)
    return '\\'.join(components[-2:])


# Shorten each path to "<class folder>\<file name>".
df['Name'] = df['Name'].map(_class_and_filename)
df

Unnamed: 0,Name,pred
0,burn_disease\burn1.jpeg,0
1,burn_disease\burn2.jpeg,0
2,burn_disease\burn3.jpeg,0
3,burn_disease\burn4.jpeg,0
4,burn_disease\burn5.jpeg,0
5,healthy\hel1.jpeg,0
6,healthy\hel2.jpeg,1
7,healthy\hel3.jpeg,1
8,healthy\hel4.jpeg,1
9,healthy\hel5.jpeg,1


In [66]:
# Write the predictions to submit.csv, omitting the DataFrame index column.
df.to_csv("submit.csv", index=False)