<a href="https://colab.research.google.com/github/HNik2/plant_diseases_detection_with_vgg/blob/master/vgg_model_for_plant_disease.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Import des bibliothèques nécessaire
import numpy as np 
import pandas as pd 
import seaborn as sns
import sys
import os
from numpy import load
from matplotlib import pyplot
from sklearn.model_selection import train_test_split
from keras import backend
from keras.layers import Dense
from keras.layers import Flatten
from keras.optimizers import SGD
from keras.applications.vgg16 import VGG16
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from keras.layers import Dropout
from keras.layers.normalization import BatchNormalization


Using TensorFlow backend.


In [0]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
# Copie du jeu de données des maladies des plantes dans google colaboratory
cp /content/drive/My\ Drive/dataset.zip /

In [0]:
# Copie du jeu de données de test des maladies des plantes dans google colaboratory
cp /content/drive/My\ Drive/test.zip /

In [0]:
import IPython.display as display
from PIL import Image
import os
import zipfile
import pathlib

#Décompression du jeu de données
local_zip = '/dataset.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('/')
zip_ref.close()

In [0]:
#Décompression du jeu de données pour les test
local_zip = '/test.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('/')
zip_ref.close()

In [0]:
#Définition des répertoires pour les données d'entrainement et les données de validation et pour les test
train_dir = '/dataset/train/'
valid_dir = '/dataset/valid/'
test_dir = '/test'

In [0]:
# definition du modèle cnn à partir du modèle vgg
def define_model(in_shape=(224, 224, 3), out_shape=38):
	# chargement du modèle vgg
  model = VGG16(weights='imagenet',input_shape=in_shape, include_top=False)
  for layer in model.layers:
    layer.trainable = False
	
	# permettre au dernier bloc vgg d'être entraînable
  model.get_layer('block5_conv1').trainable=True
  model.get_layer('block5_conv2').trainable=True
  model.get_layer('block5_conv3').trainable=True
  model.get_layer('block5_pool').trainable=True

	#Ajout des nouvelles couches de classification
  flat1=Flatten()(model.layers[-1].output)
  fcon1 = Dense(4096, activation='relu', kernel_initializer='he_uniform')(flat1)
  fdrop1 = Dropout(0.25)(fcon1)
  fbn1 = BatchNormalization()(fdrop1)
  fcon2 = Dense(4096, activation='relu', kernel_initializer='he_uniform')(fbn1)
  fdrop2 = Dropout(0.25)(fcon2)
  fbn2 = BatchNormalization()(fdrop2)
  output = Dense(out_shape, activation='softmax')(fbn2)
  #définition de notre modèle
  model = Model(inputs=model.inputs, outputs=output)
	#compiler le modèle
  opt = SGD(lr=0.01, momentum=0.9,decay=0.005)
  model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
  return model

In [0]:
# fonction pour afficher les graphes de la précision et l'erreur de l'entrainement
def summarize_diagnostics(history):
  sns.set()
  # plot le graphe pour l'erreur
  pyplot.subplot(211)
  pyplot.title('Cross Entropy Loss')
  pyplot.plot(history.history['loss'], color='blue', label='train')
  pyplot.plot(history.history['val_loss'], color='orange', label='valid')
  pyplot.xlabel('Epoch')
  pyplot.ylabel('Loss')
  pyplot.legend()
  # plot le graphe pour la précsion
  pyplot.subplot(212)
  pyplot.title('Classification Accuracy')
  pyplot.plot(history.history['acc'], color='blue', label='train')
  pyplot.plot(history.history['val_acc'], color='orange', label='valid')
  pyplot.xlabel('Epoch')
  pyplot.ylabel('Accuracy')
  pyplot.legend()
  # sauvegarder le résultat dans un fichier png
  filename = sys.argv[0].split('/')[-1]
  pyplot.savefig(filename + '_plot.png')
  pyplot.close()

In [0]:
# charger les données pour tester la prédiction de notre modèle
def load_image(filename):
	# load the image
	img = load_img(filename, target_size=(224, 224))
	# convert to array
	img = img_to_array(img)
	# reshape into a single sample with 3 channels
	img = img.reshape(1, 224, 224, 3)
    #rescale
	img = img/255
	# center pixel data
	#img = img.astype('float32')
	#img = img - [123.68, 116.779, 103.939]
	return img    

In [0]:
#prétraitement des images du jeu de données avant de le passer en entrée de notre modèle
batch_size = 128

train_datagen = ImageDataGenerator(rescale=1./255)
                                   
valid_datagen = ImageDataGenerator(rescale=1./255)

training_iterator = train_datagen.flow_from_directory(train_dir,
                                                 target_size=(224, 224),
                                                 batch_size=batch_size,
                                                 class_mode='categorical')

test_iterator = valid_datagen.flow_from_directory(valid_dir,
                                            target_size=(224, 224),
                                            batch_size=batch_size,
                                            class_mode='categorical')

Found 70295 images belonging to 38 classes.
Found 17572 images belonging to 38 classes.


In [0]:
#Affichage des différentes classes du jeu de données
class_dict = training_iterator.class_indices
print(class_dict)

{'Apple___Apple_scab': 0, 'Apple___Black_rot': 1, 'Apple___Cedar_apple_rust': 2, 'Apple___healthy': 3, 'Blueberry___healthy': 4, 'Cherry_(including_sour)___Powdery_mildew': 5, 'Cherry_(including_sour)___healthy': 6, 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot': 7, 'Corn_(maize)___Common_rust_': 8, 'Corn_(maize)___Northern_Leaf_Blight': 9, 'Corn_(maize)___healthy': 10, 'Grape___Black_rot': 11, 'Grape___Esca_(Black_Measles)': 12, 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)': 13, 'Grape___healthy': 14, 'Orange___Haunglongbing_(Citrus_greening)': 15, 'Peach___Bacterial_spot': 16, 'Peach___healthy': 17, 'Pepper,_bell___Bacterial_spot': 18, 'Pepper,_bell___healthy': 19, 'Potato___Early_blight': 20, 'Potato___Late_blight': 21, 'Potato___healthy': 22, 'Raspberry___healthy': 23, 'Soybean___healthy': 24, 'Squash___Powdery_mildew': 25, 'Strawberry___Leaf_scorch': 26, 'Strawberry___healthy': 27, 'Tomato___Bacterial_spot': 28, 'Tomato___Early_blight': 29, 'Tomato___Late_blight': 30, 'Tomato

In [0]:
#Affichage des différentes classes du jeu de données
class_labels = list(class_dict.keys())
print(class_labels)

['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Sp

In [0]:
train_num_samples = training_iterator.samples
valid_num_samples = test_iterator.samples
# création du modèle
model = define_model()
model.summary()

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 224, 224, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________

In [0]:

weightsfilepath = "/bestweights.hdf5"
checkpoint = ModelCheckpoint(weightsfilepath, monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=True, mode='max')
callbacks_list = [checkpoint]

In [0]:
#Entrainement du modèle
history = model.fit_generator(training_iterator, steps_per_epoch=len(training_iterator),
		validation_data=test_iterator, validation_steps=len(test_iterator), epochs=8, callbacks=callbacks_list, verbose=2)

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Epoch 1/8
 - 869s - loss: 0.4585 - acc: 0.8594 - val_loss: 0.2152 - val_acc: 0.9272

Epoch 00001: val_acc improved from -inf to 0.92721, saving model to /bestweights.hdf5
Epoch 2/8
 - 848s - loss: 0.0991 - acc: 0.9665 - val_loss: 0.1126 - val_acc: 0.9610

Epoch 00002: val_acc improved from 0.92721 to 0.96102, saving model to /bestweights.hdf5
Epoch 3/8
 - 848s - loss: 0.0572 - acc: 0.9814 - val_loss: 0.0820 - val_acc: 0.9721

Epoch 00003: val_acc improved from 0.96102 to 0.97206, saving model to /bestweights.hdf5
Epoch 4/8
 - 848s - loss: 0.0382 - acc: 0.9883 - val_loss: 0.0670 - val_acc: 0.9774

Epoch 00004: val_acc improved from 0.97206 to 0.97735, saving model to /bestweights.hdf5
Epoch 5/8
 - 847s - loss: 0.0298 - acc: 0.9907 - val_loss: 0.0631 - val_acc: 0.9786

Epoch 00005: val_acc improved from 0.97735 to 0.97860, saving model to /bestweights.hdf5
Epoch 6/8
 - 847s - loss: 0.0226 - acc:

In [0]:
# évaluation du modèle
_, acc = model.evaluate_generator(test_iterator, steps=len(test_iterator), verbose=1)
print('> %.3f' % (acc * 100.0))

> 98.219


In [0]:
# graphes de précision et d'erreur du modèle
summarize_diagnostics(history)

In [0]:
#sauvegarde du modèle
model.save('plantdisease_detection_vgg16model.h5')

In [0]:
# Test de prédiction
img = load_image('../test/AppleScab3.JPG')
print("Prediction for AppleScab3:")
prediction = model.predict(img)
predicted_class_name = class_labels[np.argmax(prediction)]
print("Detected the leaf as ", predicted_class_name)  

Prediction for AppleScab3:
Detected the leaf as  Squash___Powdery_mildew


In [0]:

for filename in os.listdir(test_dir):
    filepath = test_dir + '/' + filename
    img = load_image(filepath)
    prediction = model.predict(img)
    predicted_class_name = class_labels[np.argmax(prediction)]
    print(filename, "  predicted as ", predicted_class_name)  

TomatoEarlyBlight2.JPG   predicted as  Tomato___Late_blight
PotatoHealthy2.JPG   predicted as  Potato___healthy
TomatoYellowCurlVirus1.JPG   predicted as  Tomato___Tomato_Yellow_Leaf_Curl_Virus
TomatoYellowCurlVirus6.JPG   predicted as  Tomato___Tomato_Yellow_Leaf_Curl_Virus
TomatoEarlyBlight5.JPG   predicted as  Tomato___Early_blight
PotatoEarlyBlight2.JPG   predicted as  Potato___Early_blight
AppleScab3.JPG   predicted as  Squash___Powdery_mildew
PotatoEarlyBlight1.JPG   predicted as  Potato___Early_blight
PotatoEarlyBlight5.JPG   predicted as  Potato___Early_blight
TomatoYellowCurlVirus2.JPG   predicted as  Tomato___Tomato_Yellow_Leaf_Curl_Virus
PotatoHealthy1.JPG   predicted as  Potato___healthy
TomatoHealthy1.JPG   predicted as  Tomato___healthy
TomatoEarlyBlight1.JPG   predicted as  Tomato___Early_blight
AppleScab1.JPG   predicted as  Apple___Apple_scab
TomatoHealthy3.JPG   predicted as  Tomato___healthy
AppleCedarRust3.JPG   predicted as  Apple___Cedar_apple_rust
TomatoEarlyBlig

In [0]:
#Téléchargement du modèle
from google.colab import files
files.download('plantdisease_detection_vgg16model.h5')

In [0]:
#Convertion du modèle en format tflite
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model_file('plantdisease_detection_vgg16model.h5')
tflite_model = converter.convert()

with open('model_vgg.tflite', 'wb') as f:
  f.write(tflite_model)

Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
INFO:tensorflow:Froze 40 variables.
INFO:tensorflow:Converted 40 variables to const ops.


In [0]:
#Création du fichier labels contenant les différentes classes de maladies
print (training_iterator.class_indices)

labels = '\n'.join(sorted(training_iterator.class_indices.keys()))

with open('labels.txt', 'w') as f:
  f.write(labels)

{'Apple___Apple_scab': 0, 'Apple___Black_rot': 1, 'Apple___Cedar_apple_rust': 2, 'Apple___healthy': 3, 'Blueberry___healthy': 4, 'Cherry_(including_sour)___Powdery_mildew': 5, 'Cherry_(including_sour)___healthy': 6, 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot': 7, 'Corn_(maize)___Common_rust_': 8, 'Corn_(maize)___Northern_Leaf_Blight': 9, 'Corn_(maize)___healthy': 10, 'Grape___Black_rot': 11, 'Grape___Esca_(Black_Measles)': 12, 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)': 13, 'Grape___healthy': 14, 'Orange___Haunglongbing_(Citrus_greening)': 15, 'Peach___Bacterial_spot': 16, 'Peach___healthy': 17, 'Pepper,_bell___Bacterial_spot': 18, 'Pepper,_bell___healthy': 19, 'Potato___Early_blight': 20, 'Potato___Late_blight': 21, 'Potato___healthy': 22, 'Raspberry___healthy': 23, 'Soybean___healthy': 24, 'Squash___Powdery_mildew': 25, 'Strawberry___Leaf_scorch': 26, 'Strawberry___healthy': 27, 'Tomato___Bacterial_spot': 28, 'Tomato___Early_blight': 29, 'Tomato___Late_blight': 30, 'Tomato

In [0]:
#Téléchargement du modèle tflite et du fichier labels
from google.colab import files

files.download('model_vgg.tflite')
files.download('labels.txt')