In [1]:
import os, random, shutil
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras import layers, Model
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=25,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    brightness_range=[0.7,1.3],
    horizontal_flip=True,
    fill_mode='nearest'
)

val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

train_gen = train_datagen.flow_from_directory(
    'dataset/train',
    target_size=(224,224),
    batch_size=16,
    class_mode='binary',
    color_mode='rgb'
)

val_gen = val_datagen.flow_from_directory(
    'dataset/val',
    target_size=(224,224),
    batch_size=16,
    class_mode='binary',
    color_mode='rgb'
)

test_gen = test_datagen.flow_from_directory(
    'dataset/test',
    target_size=(224,224),
    batch_size=16,
    class_mode='binary',
    color_mode='rgb',
    shuffle=False
)

Found 661 images belonging to 2 classes.
Found 453 images belonging to 2 classes.
Found 449 images belonging to 2 classes.


In [5]:
# fake_count = len('/kaggle/input/fake-vs-real-medicine-datasets-images/dataset/train/Fake')
# real_count = len('/kaggle/input/fake-vs-real-medicine-datasets-images/dataset/train/Real')
fake_count = len(os.listdir('dataset/Fake'))
real_count = len(os.listdir('dataset/Real'))

total = fake_count + real_count

weight_for_fake = (1 / fake_count) * (total / 2.0)
weight_for_real = (1 / real_count) * (total / 2.0)

class_weight = {0: weight_for_real, 1: weight_for_fake}
print("Class weights:", class_weight)

Class weights: {0: 0.7843601895734598, 1: 1.3791666666666667}


In [6]:
base = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(224,224,3)
)
base.trainable = False  # freeze the convolutional base

x = layers.GlobalAveragePooling2D()(base.output)
x = layers.Dropout(0.3)(x)
out = layers.Dense(1, activation='sigmoid')(x)

model = Model(inputs=base.input, outputs=out)

model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy', 
             tf.keras.metrics.Precision(name='precision'), 
             tf.keras.metrics.Recall(name='recall')]
)

model.summary()

In [7]:
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "best_model.keras", save_best_only=True, monitor="val_loss", mode="min"
)
earlystop_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=5, restore_best_weights=True
)
callbacks = [checkpoint_cb, earlystop_cb]

In [8]:
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=50,  # increased epochs with early stopping
    class_weight=class_weight,
    callbacks=callbacks
)

  self._warn_if_super_not_called()


Epoch 1/50
[1m 3/42[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m1:53[0m 3s/step - accuracy: 0.5208 - loss: 2.9211 - precision: 0.7738 - recall: 0.3165

KeyboardInterrupt: 

In [None]:
test_loss, test_acc, test_precision, test_recall = model.evaluate(test_gen)
print(f"Test Accuracy: {test_acc:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}")

In [None]:
print(train_gen.class_indices)

In [None]:
# Predict labels
y_pred_prob = model.predict(test_gen)
y_pred = (y_pred_prob > 0.5).astype(int).flatten()
y_true = test_gen.classes  # uncomment this line

# Class labels from generator
class_labels = list(test_gen.class_indices.keys())

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
print("Confusion Matrix:\n", cm)

# Plot confusion matrix with class names
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()

# Classification report
print("\nClassification Report:\n", classification_report(y_true, y_pred, target_names=class_labels))

In [None]:
plt.plot(history.history['accuracy'], label='Train Acc')
plt.plot(history.history['val_accuracy'], label='Val Acc')
plt.legend(); plt.show()