In [1]:
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
import os

In [2]:
IMG_SIZE = (224, 224)  # ResNet50 expects 224x224 images
BATCH_SIZE = 32


In [3]:
train_dir = "C:\\Users\\USER\\Desktop\\capstone\\archive\\data\\train"
val_dir = "C:\\Users\\USER\\Desktop\\capstone\\archive\\data\\val"

In [4]:
train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)

val_dataset = image_dataset_from_directory(val_dir,
                                           shuffle=True,
                                           batch_size=BATCH_SIZE,
                                           image_size=IMG_SIZE)

Found 33984 files belonging to 4 classes.
Found 6400 files belonging to 4 classes.


In [5]:
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_dataset = train_dataset.map(lambda x, y: (normalization_layer(x), y))
val_dataset = val_dataset.map(lambda x, y: (normalization_layer(x), y))

In [6]:
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
base_model.trainable = False  # Freeze the base model

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 0us/step


In [7]:
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
x = Dropout(0.5)(x)
predictions = Dense(1, activation='sigmoid')(x)  

In [8]:
model = Model(inputs=base_model.input, outputs=predictions)

In [9]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss='binary_crossentropy',  # Change to 'categorical_crossentropy' for multi-class
              metrics=['accuracy'])

In [10]:
model_save_dir = "model_checkpoints"
os.makedirs(model_save_dir, exist_ok=True)

In [11]:
for epoch in range(10):
    model.fit(train_dataset, validation_data=val_dataset, epochs=1)
    model.save(os.path.join(model_save_dir, f'epoch_{epoch+1:02d}.h5'))

[1m1062/1062[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3290s[0m 3s/step - accuracy: 0.1906 - loss: -125.5954 - val_accuracy: 0.0100 - val_loss: -1860.2554




[1m1062/1062[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19146s[0m 18s/step - accuracy: 0.1907 - loss: -1484.3706 - val_accuracy: 0.0100 - val_loss: -6567.1787




[1m1062/1062[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2233s[0m 2s/step - accuracy: 0.1911 - loss: -4212.5747 - val_accuracy: 0.0100 - val_loss: -13453.7930




[1m1062/1062[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2281s[0m 2s/step - accuracy: 0.1912 - loss: -8002.1514 - val_accuracy: 0.0100 - val_loss: -22175.5176




[1m1062/1062[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2124s[0m 2s/step - accuracy: 0.1912 - loss: -12748.7920 - val_accuracy: 0.0100 - val_loss: -32568.4805




[1m1062/1062[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2138s[0m 2s/step - accuracy: 0.1911 - loss: -18351.0156 - val_accuracy: 0.0100 - val_loss: -44595.4883




[1m1062/1062[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2425s[0m 2s/step - accuracy: 0.1907 - loss: -24841.5391 - val_accuracy: 0.0100 - val_loss: -58170.5234




[1m1062/1062[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2214s[0m 2s/step - accuracy: 0.1909 - loss: -31946.1016 - val_accuracy: 0.0100 - val_loss: -73302.0625




[1m1062/1062[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5798s[0m 5s/step - accuracy: 0.1915 - loss: -40073.1914 - val_accuracy: 0.0100 - val_loss: -89996.3438




[1m1062/1062[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2515s[0m 2s/step - accuracy: 0.1910 - loss: -48784.1641 - val_accuracy: 0.0100 - val_loss: -108175.2812


