### Transfer Learning Task

- 쌀 이파리 병에 대한 분류

In [2]:
from glob import glob
import os

root = './datasets/rice_leaf_diseases_dataset/original'

directories =  glob(os.path.join(root, '*'))
directory_names = []


for directory in directories:
    directory_names.append(directory[directory.rindex('\\') + 1:])

print(directory_names)

['Bacterialblight', 'Brownspot', 'Leafsmut']


In [3]:
root = './datasets/rice_leaf_diseases_dataset/original/'


for name in directory_names:
    for i, file_name in enumerate(os.listdir(os.path.join(root, name))):
        old_file = os.path.join(root + name + '/', file_name)
        new_file = os.path.join(root + name + '/', name+str(i+1)+'.png')

    os.rename(old_file,new_file)

In [4]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

BATCH_SIZE = 64

image_generator = ImageDataGenerator(rescale = 1./255)

generator = image_generator.flow_from_directory(root, target_size=(244,244), batch_size=BATCH_SIZE, class_mode='categorical')
print(generator.class_indices)

Found 4684 images belonging to 3 classes.
{'Bacterialblight': 0, 'Brownspot': 1, 'Leafsmut': 2}


In [5]:
import pandas as pd
r_df = pd.DataFrame({'file_paths': generator.filepaths, 'targets':generator.classes})
r_df.file_paths
r_df.targets

0       0
1       0
2       0
3       0
4       0
       ..
4679    2
4680    2
4681    2
4682    2
4683    2
Name: targets, Length: 4684, dtype: int32

In [6]:
r_df.loc[:, 'file_paths'] = r_df.file_paths.apply(lambda x: x.replace('\\','/')).reset_index(drop=True)
display(r_df.file_paths)

0       ./datasets/rice_leaf_diseases_dataset/original...
1       ./datasets/rice_leaf_diseases_dataset/original...
2       ./datasets/rice_leaf_diseases_dataset/original...
3       ./datasets/rice_leaf_diseases_dataset/original...
4       ./datasets/rice_leaf_diseases_dataset/original...
                              ...                        
4679    ./datasets/rice_leaf_diseases_dataset/original...
4680    ./datasets/rice_leaf_diseases_dataset/original...
4681    ./datasets/rice_leaf_diseases_dataset/original...
4682    ./datasets/rice_leaf_diseases_dataset/original...
4683    ./datasets/rice_leaf_diseases_dataset/original...
Name: file_paths, Length: 4684, dtype: object

In [7]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test =\
train_test_split(r_df.file_paths, r_df.targets, stratify=r_df.targets, test_size = 0.2, random_state=124)

X_train, X_val, y_train, y_val =\
train_test_split(X_train, y_train, stratify= y_train, test_size=0.2, random_state=124)

print(y_train.value_counts())
print(y_val.value_counts())
print(y_test.value_counts())

targets
1    1037
0    1026
2     934
Name: count, dtype: int64
targets
1    259
0    257
2    234
Name: count, dtype: int64
targets
1    324
0    321
2    292
Name: count, dtype: int64


In [8]:
import shutil

root = './datasets/rice_leaf_diseases_dataset/'

for file_path in X_train:
    print(file_path)
    rice_dir = file_path[len(root+'original/'):file_path.rindex('/')]
    print(rice_dir)
    destination = os.path.join(root,'train/'+ rice_dir)
    if not os.path.exists(destination):
        os.makedirs(destination)
    shutil.copy2(file_path, destination)

./datasets/rice_leaf_diseases_dataset/original/Brownspot/BROWNSPOT3_108.jpg
Brownspot
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERIALBLIGHT2_138.JPG
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERIALBLIGHT_108.JPG
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERAILBLIGHT3_262.jpg
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Brownspot/BROWNSPOT5_052.jpg
Brownspot
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST2_032.jpg
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERIALBLIGHT1_172.jpg
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST4_119.JPG
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST3_154.jpg
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST8_105.jpg
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST8_116.jpg
Leafsmut
./datasets/ri

In [9]:
import shutil

for file_path in X_val:
    print(file_path)
    rice_dir = file_path[len(root+'original/'):file_path.rindex('/')]
    print(rice_dir)
    destination = os.path.join(root,'validation/'+rice_dir)
    if not os.path.exists(destination):
        os.makedirs(destination)
    shutil.copy2(file_path, destination)

./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST7_008.jpg
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST3_072.jpg
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERIALBLIGHT2_010.jpg
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Brownspot/brownspot_orig_052.jpg
Brownspot
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERIALBLIGHT2_062.jpg
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERIALBLIGHT2_021.jpg
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST9_052.jpg
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST3_134.jpg
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST8_115.JPG
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST9_054.JPG
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST4_050.JPG
Leafsmut
./datasets/rice_leaf_diseases_dataset/

In [10]:
import shutil

for file_path in X_test:
    print(file_path)
    rice_dir = file_path[len(root+'original/'):file_path.rindex('/')]
    print(rice_dir)
    destination = os.path.join(root,'test/'+rice_dir)
    if not os.path.exists(destination):
        os.makedirs(destination)
    shutil.copy2(file_path, destination)

./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERAILBLIGHT4_222.jpg
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Leafsmut/BLAST9_155.jpg
Leafsmut
./datasets/rice_leaf_diseases_dataset/original/Brownspot/BROWNSPOT1_129.jpg
Brownspot
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERIALBLIGHT2_162.jpg
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Brownspot/BROWNSPOT1_097.jpg
Brownspot
./datasets/rice_leaf_diseases_dataset/original/Brownspot/BROWNSPOT6_145.jpg
Brownspot
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERAILBLIGHT5_094.jpg
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERAILBLIGHT5_136.JPG
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERAILBLIGHT3_033.jpg
Bacterialblight
./datasets/rice_leaf_diseases_dataset/original/Bacterialblight/BACTERIALBLIGHT_189.jpg
Bacterialblight
./datasets/rice_leaf_diseases_dataset/o

In [29]:
IMAGE_SIZE = 64
BATCH_SIZE = 32

In [30]:
import albumentations as A

def transform(image):
    aug = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ],p=0.5)

    return aug(image=image)['image']

In [31]:
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image  import ImageDataGenerator

train_dir = './datasets/rice_leaf_diseases_dataset//train'
validation_dir = './datasets/rice_leaf_diseases_dataset/validation/'
test_dir = './datasets/rice_leaf_diseases_dataset/test'

train_generator = ImageDataGenerator(preprocessing_function= transform, rescale = 1./255)
validation_generator = ImageDataGenerator(rescale = 1./255)
test_generator = ImageDataGenerator(rescale = 1./255)

train_flow = train_generator.flow_from_directory(
    train_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle = True
)

validation_flow = validation_generator.flow_from_directory(
    validation_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)

test_flow  = test_generator.flow_from_directory(
    test_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)


Found 2997 images belonging to 3 classes.
Found 750 images belonging to 3 classes.
Found 937 images belonging to 3 classes.


In [32]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Conv2D, Dropout, Flatten, Activation, MaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.applications import VGG16


def create_model(verbose= False):

    input_tensor = Input(shape=(IMAGE_SIZE,IMAGE_SIZE,3))
    model = VGG16(input_tensor = input_tensor , include_top = False, weights='imagenet')
    
    
    x = model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(50, activation = 'relu')(x)
    output = Dense(3, activation = 'softmax')(x)
    
    
    model = Model(inputs = model.input, outputs= output)
    if verbose: 
        model.summary()

    return model

In [33]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.optimizers import Adam 
from tensorflow.keras.losses import CategoricalCrossentropy


mcp_cb = ModelCheckpoint(
    filepath="./callback_files/weights.{epoch:03d}-{val_loss:.4f}-{acc:.4f}.weights.h5",
    monitor='val_loss',
    save_best_only=False,
    save_weights_only=True,
    mode='min'
)

rlr_cb = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=2,
    mode='min'
)

ely_cb = EarlyStopping(
    monitor='val_loss',
    patience=4,
    mode='min'
)
model = create_model(verbose=True)
model.compile(optimizer=Adam(), loss=CategoricalCrossentropy(), metrics=['acc'])

In [34]:
import gc
# 불필요한 오브젝트를 지우는 작업
gc.collect()

2499

In [None]:
history = model.fit(
    train_flow,
    batch_size=BATCH_SIZE,
    epochs=10,
    validation_data=validation_flow,
    callbacks=[mcp_cb, rlr_cb, ely_cb]
)

Epoch 1/10
[1m26/94[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m1:45[0m 2s/step - acc: 0.3161 - loss: 6.7144

In [23]:
model.evaluate(test_flow)

[1m 2/15[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m38s[0m 3s/step - acc: 0.5625 - loss: 0.8484

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

def show_history(history):
    plt.figure(figsize= (6,6))
    plt.ylim(np.arange(0,1,0.05))
    plt.plot(history.history['acc'], label='train')
    plt.plot(history.history['val_acc'], label='validation')
    plt.legend()
    
show_history(history)