In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy

In [2]:
# Load dataset
# A fixed seed makes the file order, and therefore the train/val/test split below,
# identical on every kernel restart. Without it, a model saved from an earlier session
# may already have trained on images that land in this session's test split.
SEED = 12
DATA_DIR = r'C:/Users/Atharv/Downloads/new_dataset'  # NOTE(review): absolute local path; consider making this configurable
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 32

# NOTE(review): image_dataset_from_directory presumably also applies a small local
# shuffle on every iteration when shuffle=True; samples near batch boundaries may then
# still move between splits. Verify, and consider splitting at the file level if it matters.
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    DATA_DIR,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    seed=SEED,
)

class_names = dataset.class_names
num_classes = len(class_names)

Found 47481 files belonging to 60 classes.


In [3]:
def get_dataset_partitions_tf(ds, train_split=0.8, val_split=0.1, test_split=0.1, shuffle=True, shuffle_size=10000, seed=12):
    """Split a finite dataset into train / validation / test partitions.

    The split is done on dataset *elements*. For a batched dataset (such as the one
    returned by ``image_dataset_from_directory``), whole batches are assigned to a split.

    Args:
        ds: Dataset with a known length (``len(ds)`` must work).
        train_split: Fraction of elements used for training.
        val_split: Fraction of elements used for validation.
        test_split: Fraction intended for testing. The test partition receives every
            element not taken by train/val, so this value is only used to check that
            the three fractions sum to 1.
        shuffle: Whether to shuffle elements once before splitting.
        shuffle_size: Shuffle buffer size.
        seed: Shuffle seed.

    Returns:
        A ``(train_ds, val_ds, test_ds)`` tuple.

    Raises:
        ValueError: If ``train_split + val_split + test_split`` is not 1.
    """
    total = train_split + val_split + test_split
    if abs(total - 1.0) > 1e-9:
        raise ValueError(
            f"train_split + val_split + test_split must sum to 1, got {total}"
        )

    ds_size = len(ds)
    if shuffle:
        # reshuffle_each_iteration=False is essential. Each take()/skip() view below
        # starts its own iteration over `ds`. With the default (True), every iteration
        # would see a different order and the three partitions would overlap.
        ds = ds.shuffle(shuffle_size, seed=seed, reshuffle_each_iteration=False)

    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)

    train_ds = ds.take(train_size)
    val_ds = ds.skip(train_size).take(val_size)
    test_ds = ds.skip(train_size + val_size)  # remainder goes to test

    return train_ds, val_ds, test_ds

# Partition the loaded batches into train / validation / test views using the default fractions.
(train_ds, val_ds, test_ds) = get_dataset_partitions_tf(ds=dataset)

In [4]:
AUTOTUNE = tf.data.AUTOTUNE

# Cache each split after its first pass and prefetch so that input preparation
# overlaps with training/evaluation.
# NOTE(review): cache() without a filename keeps data in memory. At 256x256 RGB float32,
# ~47k images would be roughly 37 GB in total. Confirm the machine has the RAM, or pass
# a cache file path.
train_ds, val_ds, test_ds = (
    split.cache().prefetch(buffer_size=AUTOTUNE)
    for split in (train_ds, val_ds, test_ds)
)

In [None]:
# Restore the previously saved network. compile=False loads it without compiling,
# so it can be compiled below with the current optimizer/loss settings.
model = load_model('../api/saved_models/50_20.h5', compile=False)

# NOTE(review): from_logits=True assumes the final layer emits raw logits (no softmax).
# Confirm this against the saved model's architecture.
model.compile(
    loss=SparseCategoricalCrossentropy(from_logits=True),
    optimizer=Adam(learning_rate=1e-5),
    metrics=['accuracy'],
)

In [None]:
# Fine-tuning: mark the final ten entries of model.layers as trainable.
# NOTE(review): model.layers lists only top-level layers. If one of them is a nested base
# model (e.g. 'vgg16', as the commented-out cell further down suggests), setting its
# .trainable unfreezes that entire sub-model rather than ten individual layers. Confirm.
for layer in model.layers[-10:]:
    layer.trainable = True

# Changes to .trainable only take effect after compile() is called again.
# The learning rate is the same 1e-5 used in the previous compile.
model.compile(
    loss=SparseCategoricalCrossentropy(from_logits=True),
    optimizer=Adam(learning_rate=1e-5),
    metrics=['accuracy'],
)

In [None]:
# Fine-tune for a single epoch, reporting validation metrics after it.
history = model.fit(train_ds, epochs=1, validation_data=val_ds)

In [None]:
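# Alternative, currently disabled: unfreeze only the last 4 layers inside a nested
# 'vgg16' sub-model instead of the last 10 top-level layers.
# for layer in model.get_layer('vgg16').layers[-4:]:  # Last few layers of VGG16
#     layer.trainable = True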


In [None]:
# Evaluate on the held-out test split. Because the model is compiled with
# metrics=['accuracy'], evaluate() returns [loss, accuracy].
test_loss, accuracy = model.evaluate(test_ds)
print(f"Test accuracy: {accuracy*100:.2f}%")