In [None]:
import os
import cv2
import hashlib
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt


def remove_unreadable_images(directory, extensions=('.jpg', '.png', '.jpeg')):
    """Delete image files in *directory* that OpenCV cannot decode.

    Only the top level of *directory* is scanned. Entries whose name does not
    end with one of *extensions* (case-insensitive), and entries that are not
    regular files, are left untouched.

    Args:
        directory: Folder to scan.
        extensions: Tuple of lower-case filename suffixes treated as images.

    Returns:
        Names of the files that were deleted, in sorted order.

    Raises:
        OSError: if *directory* cannot be listed or a file cannot be deleted.
    """
    removed = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if not name.lower().endswith(extensions) or not os.path.isfile(path):
            continue
        # Keep the try body to imread only, so a failing os.remove is not
        # mistaken for an unreadable image (the old bare except did that, then
        # called os.remove a second time).
        try:
            # imread reports an undecodable file by returning None.
            unreadable = cv2.imread(path) is None
        except cv2.error as err:
            print(f"couldn't read image {name}: {err}")
            unreadable = True
        else:
            if unreadable:
                print(f"corrupted image: {name}")
        if unreadable:
            os.remove(path)
            removed.append(name)
    return removed


# NOTE(review): flow_from_directory('train') further down treats every
# sub-folder of 'train' as a class, so 'train/images' is presumably one class
# folder -- confirm the dataset layout. The 'valid' split is not cleaned here.
remove_unreadable_images('train/images')

def remove_duplicate_images(directory, extensions=('.jpeg', '.png', '.jpg')):
    """Delete byte-identical duplicate image files in *directory*.

    Files are visited in sorted name order, so within each group of identical
    files the alphabetically first name is kept and the rest are deleted.
    Only the top level of *directory* is scanned; entries that are not regular
    files or whose name does not end with one of *extensions*
    (case-insensitive) are skipped.

    Args:
        directory: Folder to scan.
        extensions: Tuple of lower-case filename suffixes treated as images.

    Returns:
        Dict mapping the MD5 hex digest of each kept file's bytes to its name.

    Raises:
        OSError: if *directory* cannot be listed or a file cannot be
            read/deleted.
    """
    seen = {}
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if not name.lower().endswith(extensions) or not os.path.isfile(path):
            continue
        with open(path, 'rb') as fh:
            digest = hashlib.md5(fh.read()).hexdigest()
        # Delete only after the handle is closed: removing an open file
        # raises PermissionError on Windows.
        if digest in seen:
            os.remove(path)
        else:
            seen[digest] = name
    return seen


# NOTE(review): duplicates shared between the train and valid splits are not
# detected here; consider checking them too to avoid train/valid leakage.
hashes = remove_duplicate_images('train/images')

def _check_binary_split(data, split_name, reference=None):
    """Raise ValueError unless *data* is a usable two-class binary split.

    flow_from_directory infers classes from the sub-folder names of its
    directory, and nothing else in this script verifies that there are
    exactly two of them. A wrong folder layout would therefore train on
    meaningless labels without any error.

    Args:
        data: Iterator returned by flow_from_directory.
        split_name: Name used in error messages.
        reference: Optional iterator whose class-name -> index mapping *data*
            must reproduce (the training iterator, for the validation one).

    Raises:
        ValueError: if the class count is not 2, a class folder holds no
            images, or the mapping differs from *reference*'s.
    """
    from collections import Counter

    if data.num_classes != 2:
        raise ValueError(
            f"{split_name}: class_mode='binary' needs exactly 2 class folders, "
            f"found {data.num_classes}: {sorted(data.class_indices)}"
        )
    counts = Counter(int(label) for label in data.classes)
    empty = sorted(
        name for name, idx in data.class_indices.items() if counts[idx] == 0
    )
    if empty:
        raise ValueError(f"{split_name}: class folder(s) with no images: {empty}")
    if reference is not None and data.class_indices != reference.class_indices:
        raise ValueError(
            f"{split_name}: class mapping {data.class_indices} differs from "
            f"{reference.class_indices}"
        )


# (height, width) every image is resized to; must match the model's input.
IMG_SIZE = (128, 128)
BATCH_SIZE = 32

# rescale=1/255 maps 8-bit pixel values into [0, 1]; no augmentation is used.
train_datagen = ImageDataGenerator(rescale=1./255)
valid_datagen = ImageDataGenerator(rescale=1./255)

train_data = train_datagen.flow_from_directory(
    'train',
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)
valid_data = valid_datagen.flow_from_directory(
    'valid',
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
)

# Fail fast instead of silently training on a degenerate label set.
_check_binary_split(train_data, 'train')
_check_binary_split(valid_data, 'valid', reference=train_data)

# Small two-stage CNN ending in a single sigmoid unit (binary classifier).
# The input shape (128, 128, 3) must agree with the generators'
# target_size=(128, 128) and their default color_mode='rgb' (3 channels).
# An explicit Input replaces the first layer's input_shape argument, which
# Keras 3 deprecates (and warns about) inside Sequential models.
model = Sequential([
    tf.keras.Input(shape=(128, 128, 3)),

    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),

    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(1, activation='sigmoid'),  # probability of the class with index 1
])

# binary_crossentropy pairs with the single sigmoid output and class_mode='binary'.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

# Train for 10 epochs, evaluating on the validation iterator after each one.
# The returned History object's .history dict holds the per-epoch metrics
# plotted below.
history = model.fit(
    x=train_data,
    validation_data=valid_data,
    epochs=10,
)



# Per-epoch accuracy recorded by model.fit. The 'accuracy' key comes from
# metrics=['accuracy'], and 'val_accuracy' exists because validation_data
# was passed to fit.
train_acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
# plt.plot(y) alone would number epochs from 0; Keras reports them from 1.
epoch_numbers = range(1, len(train_acc) + 1)

plt.plot(epoch_numbers, train_acc, label='train acc')
plt.plot(epoch_numbers, val_acc, label='val acc')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()


