In [76]:
import tensorflow as tf
from tensorflow.keras import models, layers
import matplotlib.pyplot as plt

In [223]:
BATCH_SIZE = 128
CHANNELS=3
EPOCHS=50

In [224]:
dataset = tf.keras.utils.image_dataset_from_directory(
    "DiseaseDataset",
    shuffle = True,
    seed=123,
    image_size = (294,222),
    batch_size = BATCH_SIZE
)

Found 27153 files belonging to 10 classes.


In [225]:
class_names = dataset.class_names
class_names

['1. Eczema 1677',
 '10. Warts Molluscum and other Viral Infections - 2103',
 '2. Melanoma 15.75k',
 '3. Atopic Dermatitis - 1.25k',
 '4. Basal Cell Carcinoma (BCC) 3323',
 '5. Melanocytic Nevi (NV) - 7970',
 '6. Benign Keratosis-like Lesions (BKL) 2624',
 '7. Psoriasis pictures Lichen Planus and related diseases - 2k',
 '8. Seborrheic Keratoses and other Benign Tumors - 1.8k',
 '9. Tinea Ringworm Candidiasis and other Fungal Infections - 1.7k']

In [221]:
for image_batch, label_batch in dataset.take(1):
    print(image_batch[0].shape)

(294, 222, 3)


In [200]:
len(dataset)

849

In [201]:
train_size = 0.8
len(dataset)*train_size

679.2

In [202]:
train_ds=dataset.take(679)
len(train_ds)

679

In [203]:
test_ds = dataset.skip(679)
len(test_ds)

170

In [204]:
val_ds = test_ds.take(85)
test_ds = test_ds.skip(85)

In [205]:
len(test_ds)

85

In [206]:
len(val_ds)

85

In [207]:
len(train_ds)

679

In [226]:
def get_dataset_partitions_tf(ds, train_split=0.8, val_split=0.1, test_split=0.1, shuffle=True, shuffle_size=10000):
    assert (train_split + test_split + val_split) == 1
    ds_size=len(ds)

    if shuffle:
        ds=ds.shuffle(shuffle_size, seed=12)
    
    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)
    
    train_ds = ds.take(train_size)    
    val_ds = ds.skip(train_size).take(val_size)
    test_ds = ds.skip(train_size).skip(val_size)
    return train_ds, val_ds, test_ds
    

In [227]:
train_ds, val_ds, test_ds = get_dataset_partitions_tf(dataset)

In [228]:
len(train_ds)

170

In [229]:
len(val_ds)

21

In [230]:
len(test_ds)

22

In [231]:
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = val_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)
test_ds = test_ds.cache().shuffle(1000).prefetch(buffer_size=tf.data.AUTOTUNE)

In [232]:
resize_and_rescale = tf.keras.Sequential([
    layers.experimental.preprocessing.Resizing(294, 222),
    layers.experimental.preprocessing.Rescaling(1./255)
])

In [233]:
data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
    layers.experimental.preprocessing.RandomRotation(0.2),
])

In [234]:
train_ds = train_ds.map(
    lambda x, y: (data_augmentation(x, training=True), y)
).prefetch(buffer_size=tf.data.AUTOTUNE)

In [245]:
input_shape = (BATCH_SIZE, 294, 222, CHANNELS)

model = models.Sequential([
    resize_and_rescale,
    data_augmentation,
    layers.Conv2D(filters=32, kernel_size=(3,3), activation='relu', input_shape=(294,222,3)),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(filters=64, kernel_size=(3,3), activation='relu', input_shape=(294,222,3)),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(filters=64, kernel_size=(3,3), activation='relu', input_shape=(294,222,3)),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(filters=64, kernel_size=(3,3), activation='relu', input_shape=(294,222,3)),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(filters=64, kernel_size=(3,3), activation='relu', input_shape=(294,222,3)),
    layers.MaxPooling2D((2,2)),
    layers.Conv2D(filters=64, kernel_size=(3,3), activation='relu', input_shape=(294,222,3)),
    layers.MaxPooling2D((2,2)),

    layers.Flatten(),
    layers.Dense(64,activation='relu'),
    layers.Dense(10,activation='softmax'),    
])

model.build(input_shape=input_shape)

In [246]:
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

In [247]:
history = model.fit(
    train_ds,
    batch_size=BATCH_SIZE,
    validation_data=val_ds,
    verbose=1,
    epochs=5
)

Epoch 1/5

KeyboardInterrupt: 