In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from sklearn.metrics import accuracy_score
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.callbacks import EarlyStopping

In [2]:
train_dir = 'train'
test_dir = 'test'

In [3]:
image_size = (224, 224)
batch_size = 32

In [4]:
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

In [5]:
test_datagen = ImageDataGenerator(rescale=1./255)

In [6]:
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical'
)

Found 32398 images belonging to 3 classes.


In [7]:
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical'
)

Found 10500 images belonging to 3 classes.


In [8]:
X_train, y_train = next(train_generator)
X_test, y_test = next(test_generator)

In [9]:
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(3, activation='softmax')
])

  super().__init__(


In [10]:
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [11]:
history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // train_generator.batch_size,
    epochs=30,
    validation_data=test_generator,
    validation_steps=test_generator.samples // test_generator.batch_size
)

Epoch 1/30


  self._warn_if_super_not_called()


[1m1012/1012[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3071s[0m 3s/step - accuracy: 0.8427 - loss: 0.4088 - val_accuracy: 0.9015 - val_loss: 0.2808
Epoch 2/30
[1m1012/1012[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 211us/step - accuracy: 0.9375 - loss: 0.0726 - val_accuracy: 1.0000 - val_loss: 0.0833
Epoch 3/30


  self.gen.throw(typ, value, traceback)


[1m1012/1012[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3004s[0m 3s/step - accuracy: 0.9262 - loss: 0.2193 - val_accuracy: 0.8647 - val_loss: 0.3627
Epoch 4/30
[1m1012/1012[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 75us/step - accuracy: 0.8438 - loss: 0.1729 - val_accuracy: 1.0000 - val_loss: 0.0279
Epoch 5/30
[1m1012/1012[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2990s[0m 3s/step - accuracy: 0.9333 - loss: 0.1964 - val_accuracy: 0.9233 - val_loss: 0.2149
Epoch 6/30
[1m1012/1012[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 90us/step - accuracy: 0.9688 - loss: 0.0463 - val_accuracy: 1.0000 - val_loss: 0.0100
Epoch 7/30
[1m1012/1012[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2982s[0m 3s/step - accuracy: 0.9394 - loss: 0.1849 - val_accuracy: 0.9303 - val_loss: 0.1939
Epoch 8/30
[1m1012/1012[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 78us/step - accuracy: 0.9375 - loss: 0.1282 - val_accuracy: 1.0000 - val_loss: 0.0531
Epoch 9/30
[1m101

In [12]:
test_loss, test_acc = model.evaluate(test_generator, steps=test_generator.samples // test_generator.batch_size)
print('Test accuracy:', test_acc)

[1m328/328[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m278s[0m 847ms/step - accuracy: 0.9413 - loss: 0.1755
Test accuracy: 0.941692054271698


In [13]:
from tensorflow.keras.models import save_model
from tensorflow.keras.models import load_model

In [14]:
model.save('fire_detection_model.h5')



In [15]:
model.save('fire_detection_model.keras')