In [2]:
import numpy as np
import os
# import PIL
# import PIL.Image
import tensorflow as tf
# import tensorflow_datasets as tfds
import pathlib
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

# Root folder of the image dataset; the loader infers one class per
# sub-directory.
dataDir = "/code/archive/garbage_classification/Garbage_classification"

# Size (height, width) every image is resized to on load, and the batch size.
imgHeight, imgWidth = 384, 512
batchSize = 32

# Options common to both splits. The two calls must share the same
# validation_split and seed so the training and validation subsets are
# complementary rather than overlapping.
splitOptions = dict(
    validation_split=0.2,
    seed=123,
    image_size=(imgHeight, imgWidth),
    batch_size=batchSize,
)

# 80% of the files: used to fit the model.
trainDs = tf.keras.utils.image_dataset_from_directory(
    dataDir, subset="training", **splitOptions)

# Remaining 20% of the files: held out for validation.
valDs = tf.keras.utils.image_dataset_from_directory(
    dataDir, subset="validation", **splitOptions)

# Let tf.data choose the prefetch buffer size at runtime.
AUTOTUNE = tf.data.AUTOTUNE

# NOTE(review): cache() without a filename keeps the decoded dataset in memory.
# Assuming float32 RGB images at 384x512, the ~1900 training images come to
# roughly 4.5 GB -- confirm the machine has room, or pass a cache file path.
#
# shuffle() is placed AFTER cache(): caching freezes whatever order the first
# epoch produced, so without it every epoch would replay the exact same batch
# sequence. The buffer (1000) exceeds the number of batches, so the batch
# order is fully reshuffled each epoch.
trainDs = trainDs.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)

# Validation order does not affect the metrics, so no shuffle is needed here.
valDs = valDs.cache().prefetch(buffer_size=AUTOTUNE)

# Random geometric augmentation: flips along both axes, rotation (factor 0.2)
# and zoom (factor 0.1).
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip(mode="horizontal_and_vertical"),
    layers.RandomRotation(factor=0.2),
    layers.RandomZoom(height_factor=0.1),
])

# Photometric augmentation: random contrast adjustment (factor 0.3).
data_augmentation2 = tf.keras.Sequential([
    layers.RandomContrast(factor=0.3),
])

# Input preprocessing: halve the spatial resolution, then map pixel values
# from [0, 255] to [0, 1].
halfHeight, halfWidth = imgHeight // 2, imgWidth // 2
resize_and_rescale = tf.keras.Sequential([
    layers.Resizing(height=halfHeight, width=halfWidth),
    layers.Rescaling(scale=1.0 / 255),
])

# Small CNN classifier. Preprocessing and augmentation are part of the model,
# so it consumes the raw batches produced by the dataset loader.
model = tf.keras.Sequential([
    # Preprocessing, followed by random geometric augmentation.
    resize_and_rescale,
    data_augmentation,
    # Three 3x3 convolution stages with growing width; the first two are each
    # followed by 2x2 max-pooling.
    layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(filters=128, kernel_size=(3, 3), activation='relu'),
    # Regularize the convolutional features before the dense head.
    layers.Dropout(rate=0.3),
    layers.Flatten(),
    layers.Dense(units=128, activation='relu'),
    # One raw (un-normalized) output per class; the dataset loader reports
    # 5 classes. No softmax here -- the loss is configured with
    # from_logits=True.
    layers.Dense(units=5),
])



# The model's last layer has no activation, so the loss is told to expect
# logits. The "sparse" variant takes integer class labels, not one-hot vectors.
logitLoss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=logitLoss, metrics=['accuracy'])

# Train for 12 epochs, evaluating on the held-out split after each one.
# `history.history` collects the per-epoch loss/accuracy values.
history = model.fit(trainDs, validation_data=valDs, epochs=12)

# Plot per-epoch accuracy for the training and validation splits.
# Labels are attached to each line (instead of a positional legend list) so a
# curve can never be paired with the wrong name. The second curve is the
# validation split passed to fit(), not a separate test set, so it is labelled
# accordingly.
plt.plot(history.history['accuracy'], label='train')
plt.plot(history.history['val_accuracy'], label='validation')
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(loc='upper left')
plt.show()


Found 2390 files belonging to 5 classes.
Using 1912 files for training.
Found 2390 files belonging to 5 classes.
Using 478 files for validation.
Epoch 1/12

: 

: 