In [1]:
# --------------------------------------------------------------
# ROBUST RECAPTURE DETECTOR – 3 DEVICES, NO RECAPTURE BIAS
# --------------------------------------------------------------

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os, json, numpy as np, tensorflow as tf
from pathlib import Path
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras import applications, layers, models, callbacks

# ---------- 0. SETTINGS ----------
MERGED_ROOT   = "/content/merged_dataset"
RESULTS_DIR   = "/content/drive/MyDrive/Merged_Model/results_final"
CHECKPOINT_DIR = f"{RESULTS_DIR}/checkpoints"
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)

# ---------- 1. LINK ALL IMAGES (flat) ----------
!rm -rf {MERGED_ROOT}
os.makedirs(f"{MERGED_ROOT}/originals", exist_ok=True)
os.makedirs(f"{MERGED_ROOT}/recaptures", exist_ok=True)

def link_flat(src, dst, prefix=None):
    if not Path(src).exists(): return 0
    cnt = 0
    for img in Path(src).iterdir():
        if img.suffix.lower() in {".jpg",".jpeg",".png",".heic"}:
            name = f"{prefix}_{img.name}" if prefix else img.name
            dst_path = Path(dst) / name
            if not dst_path.exists():
                os.symlink(img, dst_path)
                cnt += 1
    return cnt

print("Linking 3 devices …")
devices = [
    ("NTU-Roselab-Dataset", None),
    ("Nothing_2a",          "nothing"),
    ("Poco-M3",             "poco")
]
for name, pref in devices:
    link_flat(f"/content/drive/MyDrive/{name}/originals",   f"{MERGED_ROOT}/originals",   pref)
    link_flat(f"/content/drive/MyDrive/{name}/recaptures",  f"{MERGED_ROOT}/recaptures",  pref)

# ---------- 2. SPLIT 80/20 (per class) ----------
def get_paths(folder, prefix=None):
    return [p for p in Path(folder).iterdir()
            if p.suffix.lower() in {".jpg",".jpeg",".png",".heic"}
            and (prefix is None or p.name.startswith(f"{prefix}_"))]

all_orig = get_paths(f"{MERGED_ROOT}/originals")
all_recp = get_paths(f"{MERGED_ROOT}/recaptures")

np.random.seed(42)
np.random.shuffle(all_orig); np.random.shuffle(all_recp)

train_orig = all_orig[:int(0.8*len(all_orig))]
val_orig   = all_orig[int(0.8*len(all_orig)):]
train_recp = all_recp[:int(0.8*len(all_recp))]
val_recp   = all_recp[int(0.8*len(all_recp)):]

print(f"TRAIN → Original {len(train_orig)} | Recaptured {len(train_recp)}")
print(f"VAL   → Original {len(val_orig)} | Recaptured {len(val_recp)}")

# ---------- 3. tf.data PIPELINES ----------
def decode(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = img[:, :, :3]
    img = tf.image.resize(img, (224, 224))
    return tf.keras.applications.efficientnet.preprocess_input(img)

def get_label(path):
    return tf.cast(tf.strings.split(path, os.sep)[-2] == "recaptures", tf.float32)

# ----- TRAIN (oversample *Original* 4×) -----
train_paths = [str(p) for p in train_orig + train_recp]
train_ds = tf.data.Dataset.from_tensor_slices(train_paths)
train_ds = train_ds.map(lambda p: (p, tf.strings.split(p, os.sep)[-1]))

orig_ds = train_ds.filter(lambda p, fn: ~tf.strings.regex_full_match(fn, ".*_.*"))
recp_ds = train_ds.filter(lambda p, fn:  tf.strings.regex_full_match(fn, ".*_.*"))

orig_ds = orig_ds.map(lambda p, _: (decode(p), get_label(p))).repeat(4).shuffle(2000)
recp_ds = recp_ds.map(lambda p, _: (decode(p), get_label(p)))

train_ds = orig_ds.concatenate(recp_ds).shuffle(4000).batch(16).prefetch(tf.data.AUTOTUNE)

# ----- VALIDATION -----
val_paths = [str(p) for p in val_orig + val_recp]
val_ds = tf.data.Dataset.from_tensor_slices(val_paths)
val_ds = val_ds.map(lambda p: (decode(p), get_label(p))).batch(32).prefetch(tf.data.AUTOTUNE)

# ---------- 4. MODEL (EfficientNetB0 – swap to ResNet50 below) ----------
inp = tf.keras.Input((224, 224, 3))
base = applications.EfficientNetB0(include_top=False, weights='imagenet')(inp)
# base = applications.ResNet50(include_top=False, weights='imagenet')(inp)   # <-- uncomment for ResNet50
x = layers.GlobalAveragePooling2D()(base)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.5)(x)
out = layers.Dense(1, activation='sigmoid')(x)
model = models.Model(inp, out)
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy', 'Precision', 'Recall'])

# ---------- 5. CALLBACKS ----------
checkpoint_path = f"{CHECKPOINT_DIR}/cp-best.weights.h5"

cp_callback = callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor='val_f1_original',
    mode='max',
    save_best_only=True,
    verbose=1
)

class F1OriginalCallback(tf.keras.callbacks.Callback):
    def __init__(self, val_ds):
        super().__init__()
        self.val_ds = val_ds
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        y_true, y_pred = [], []
        for xb, yb in self.val_ds:
            pred = (self.model.predict(xb, verbose=0) > 0.5).astype(int)
            y_true.extend(yb.numpy())
            y_pred.extend(pred.flatten())
        f1 = classification_report(y_true, y_pred,
                                   target_names=['Original','Recaptured'],
                                   output_dict=True)['Original']['f1-score']
        logs['val_f1_original'] = f1
        print(f" — val_f1_original: {f1:.4f}")

f1_cb = F1OriginalCallback(val_ds)

early_stop = callbacks.EarlyStopping(
    monitor='val_f1_original',
    mode='max',
    patience=5,
    restore_best_weights=True,
    verbose=1
)

# RESUME
if os.path.exists(checkpoint_path):
    print("Resuming from checkpoint...")
    model.load_weights(checkpoint_path)

# ---------- 6. TRAIN ----------
print("\nStarting training (oversample Original)…")
model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=30,
    steps_per_epoch=250,
    callbacks=[cp_callback, f1_cb, early_stop],
    verbose=1
)

# ---------- 7. FINAL EVALUATION ----------
y_true, y_pred = [], []
for xb, yb in val_ds:
    pred = model.predict(xb, verbose=0)
    y_true.extend(yb.numpy())
    y_pred.extend((pred > 0.5).astype(int).flatten())

y_true = np.array(y_true)
y_pred = np.array(y_pred)

cm = confusion_matrix(y_true, y_pred)
report = classification_report(y_true, y_pred,
                               target_names=['Original','Recaptured'],
                               output_dict=True)
acc = cm.diagonal().sum() / cm.sum()
f1_orig = report['Original']['f1-score']
f1_recp = report['Recaptured']['f1-score']

# ----- SAVE RESULTS (2 dp) -----
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Original','Recaptured'],
            yticklabels=['Original','Recaptured'])
plt.title(f'CM – Acc: {acc:.2%} | F1-Orig: {f1_orig:.2f}')
plt.ylabel('True'); plt.xlabel('Predicted')
plt.savefig(f"{RESULTS_DIR}/confusion_matrix.png", dpi=200, bbox_inches='tight')
plt.close()

df_rep = pd.DataFrame(report).transpose().round(2)
df_rep.to_csv(f"{RESULTS_DIR}/classification_report.csv")

metrics = {
    "val_accuracy": round(float(acc), 2),
    "f1_original": round(f1_orig, 2),
    "f1_recaptured": round(f1_recp, 2),
    "original_correct": int(cm[0,0]),
    "recaptured_correct": int(cm[1,1]),
    "total_samples": int(cm.sum())
}
with open(f"{RESULTS_DIR}/metrics_summary.json", 'w') as f:
    json.dump(metrics, f, indent=2)

print("\n" + "="*60)
print("FINAL RESULTS (2 dp)")
print(f"Val Accuracy : {acc:.2%}")
print(f"F1 Original  : {f1_orig:.2f}")
print(f"F1 Recaptured: {f1_recp:.2f}")
print(f"Original     : {cm[0,0]}/{cm[0].sum()}")
print(f"Recaptured   : {cm[1,1]}/{cm[1].sum()}")
print("="*60)

# ---------- 8. SAVE MODEL + TFLite ----------
model.save(f"{RESULTS_DIR}/final_model.keras")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite = converter.convert()
with open(f"{RESULTS_DIR}/recapture_detector.tflite", 'wb') as f:
    f.write(tflite)

print(f"TFLite saved: {RESULTS_DIR}/recapture_detector.tflite")


Mounted at /content/drive
Linking 3 devices …
TRAIN → Original 1151 | Recaptured 1146
VAL   → Original 288 | Recaptured 287
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5
[1m16705208/16705208[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step

Starting training (oversample Original)…
Epoch 1/30
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - Precision: 0.4992 - Recall: 0.4496 - accuracy: 0.8642 - loss: 0.3580

  if self._should_save_model(epoch, batch, logs, filepath):


 — val_f1_original: 0.6866
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1191s[0m 3s/step - Precision: 0.4999 - Recall: 0.4505 - accuracy: 0.8641 - loss: 0.3579 - val_Precision: 0.9630 - val_Recall: 0.0906 - val_accuracy: 0.5443 - val_loss: 2.3843 - val_f1_original: 0.6866
Epoch 2/30
[1m 73/250[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m7:14[0m 2s/step - Precision: 0.7255 - Recall: 0.7315 - accuracy: 0.8676 - loss: 0.3365

  if self._should_save_model(epoch, batch, logs, filepath):


 — val_f1_original: 0.8600
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m308s[0m 1s/step - Precision: 0.7361 - Recall: 0.7289 - accuracy: 0.8671 - loss: 0.3134 - val_Precision: 0.8945 - val_Recall: 0.7979 - val_accuracy: 0.8522 - val_loss: 0.4165 - val_f1_original: 0.8600
Epoch 3/30
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - Precision: 0.6861 - Recall: 0.7082 - accuracy: 0.9185 - loss: 0.2122

  if self._should_save_model(epoch, batch, logs, filepath):


 — val_f1_original: 0.8348
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1077s[0m 3s/step - Precision: 0.6865 - Recall: 0.7085 - accuracy: 0.9185 - loss: 0.2122 - val_Precision: 0.8233 - val_Recall: 0.8606 - val_accuracy: 0.8383 - val_loss: 0.6115 - val_f1_original: 0.8348
Epoch 4/30
[1m 73/250[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m7:30[0m 3s/step - Precision: 0.8096 - Recall: 0.8956 - accuracy: 0.9189 - loss: 0.2154

  if self._should_save_model(epoch, batch, logs, filepath):


 — val_f1_original: 0.8125
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m283s[0m 1s/step - Precision: 0.8315 - Recall: 0.8873 - accuracy: 0.9228 - loss: 0.2094 - val_Precision: 0.9215 - val_Recall: 0.6132 - val_accuracy: 0.7809 - val_loss: 0.5967 - val_f1_original: 0.8125
Epoch 5/30
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - Precision: 0.7454 - Recall: 0.7439 - accuracy: 0.9431 - loss: 0.1523

  if self._should_save_model(epoch, batch, logs, filepath):


 — val_f1_original: 0.8023
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1040s[0m 3s/step - Precision: 0.7459 - Recall: 0.7443 - accuracy: 0.9431 - loss: 0.1523 - val_Precision: 0.9412 - val_Recall: 0.5575 - val_accuracy: 0.7617 - val_loss: 0.6721 - val_f1_original: 0.8023
Epoch 6/30
[1m 73/250[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m7:42[0m 3s/step - Precision: 0.8463 - Recall: 0.7744 - accuracy: 0.9077 - loss: 0.2047

  if self._should_save_model(epoch, batch, logs, filepath):


 — val_f1_original: 0.8379
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m332s[0m 1s/step - Precision: 0.8276 - Recall: 0.7782 - accuracy: 0.9050 - loss: 0.2142 - val_Precision: 0.9330 - val_Recall: 0.6794 - val_accuracy: 0.8157 - val_loss: 0.5737 - val_f1_original: 0.8379
Epoch 7/30
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - Precision: 0.8007 - Recall: 0.8387 - accuracy: 0.9520 - loss: 0.1315

  if self._should_save_model(epoch, batch, logs, filepath):


 — val_f1_original: 0.8322
[1m250/250[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1078s[0m 3s/step - Precision: 0.8010 - Recall: 0.8388 - accuracy: 0.9520 - loss: 0.1315 - val_Precision: 0.8627 - val_Recall: 0.7666 - val_accuracy: 0.8226 - val_loss: 0.3829 - val_f1_original: 0.8322
Epoch 7: early stopping
Restoring model weights from the end of the best epoch: 2.

FINAL RESULTS (2 dp)
Val Accuracy : 85.22%
F1 Original  : 0.86
F1 Recaptured: 0.84
Original     : 261/288
Recaptured   : 229/287
Saved artifact at '/tmp/tmpit5ari09'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor')
Output Type:
  TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)
Captures:
  138848177467664: TensorSpec(shape=(1, 1, 1, 3), dtype=tf.float32, name=None)
  138848177469584: Tensor