In [9]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

print(tf.__version__)

batch_size = 32
img_height = 530
img_width = 1020


train_dir = 'C:\\Users\\moizk\\Documents\\5_shot\\train'
test_dir = 'C:\\Users\\moizk\\Documents\\5_shot\\test'

# Load Images and split then info a training set and a validation set 
train_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode='categorical')

val_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode='categorical')

# Configure
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

for images, labels in train_ds.take(1):
    print(images.shape, labels.shape)

2.15.0
Found 110 files belonging to 22 classes.
Using 88 files for training.
Found 110 files belonging to 22 classes.
Using 22 files for validation.
(24, 530, 1020, 3) (24, 22)


In [10]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.layers import Dropout

base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(530, 1020, 3))

model = tf.keras.models.Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(1024, activation='relu'),
    Dropout(0.5),
    Dense(22, activation='softmax')
])

# Freeze the base_model layers
for layer in base_model.layers:
    layer.trainable = False
    
base_model.trainable = True

# Fine-tune from this layer onwards
fine_tune_at = 100  # Adjust this according to the base model's layer count

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])



Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5



In [None]:
# Train Model

epochs = 10  

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)