# In [None]:
# === TF/Keras: VGG16 Transfer Learning + Report + 10 Sample Predictions (PRINT ONLY; 2 FIGURES) ===
# - Copies dataset once from Drive -> local for faster I/O
# - VGG16 (ImageNet) backbone + small custom head
# - Prints training/testing time
# - Figure 1: Epoch vs Validation Accuracy (baseline run)
# - Figure 2: Grid of 10 test images with GT + predicted labels (+ table)
# - Amount-of-data vs performance table
# - No files are saved

import os, time, random, shutil
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
import matplotlib.pyplot as plt

# ---------------- CONFIG ----------------
# Source dataset folder on the mounted Drive and the local directory it is
# copied to before training.
# NOTE(review): the path contains a space before "/final"
# ("AI assignment /final") — confirm it matches the Drive folder name exactly.
DRIVE_ROOT = "/content/drive/MyDrive/AI assignment /final"   # must contain train/val/test
LOCAL_ROOT = "/content/dataset_local/final"

IMG_SIZE  = (224, 224)  # size images are resized to; also the model input size
BATCH     = 32          # default batch size for every dataset built below
EPOCHS    = 12          # epochs passed to model.fit for each training run
SEED      = 42          # shared seed: global RNGs, dataset shuffle, subset/preview sampling
AUTOTUNE  = tf.data.AUTOTUNE
FRACTIONS = [1.00, 0.50, 0.25]  # training-set fractions for the data-vs-performance table
# ---------------------------------------

# Report which GPUs TensorFlow can see.
print("Physical GPUs:", tf.config.list_physical_devices('GPU'))
# Request float16 compute globally; a failure is reported but does not stop
# the script.
# NOTE(review): the policy is requested regardless of whether a GPU was
# listed above — confirm this is intended for CPU-only runs.
try:
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
except Exception as e:
    print("Mixed precision not set:", e)

# Seed the random number generators with the shared SEED.
# NOTE(review): tf.keras.utils.set_random_seed is documented to seed Python,
# NumPy and TensorFlow together, so the two explicit calls below presumably
# just re-apply the same seed — confirm before removing them.
tf.keras.utils.set_random_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# ---------------- Dataset staging ----------------
def stage_dataset(drive_root, local_root):
    """Ensure a local copy of the dataset exists and return its path.

    The source must contain ``train``, ``val`` and ``test`` sub-folders. If
    ``local_root`` already contains all three, it is reused as-is; otherwise
    any existing ``local_root`` is removed and the whole source tree is
    copied again.

    Args:
        drive_root: Source dataset directory.
        local_root: Destination directory for the local copy. May be a bare
            relative name (no parent component).

    Returns:
        ``local_root`` unchanged.

    Raises:
        FileNotFoundError: If a required split folder is missing under
            ``drive_root``.
    """
    splits = ("train", "val", "test")
    for split in splits:
        if not os.path.exists(os.path.join(drive_root, split)):
            raise FileNotFoundError(f"Missing folder: {split} under {drive_root}")

    # All three split folders present implies local_root itself exists.
    if all(os.path.exists(os.path.join(local_root, split)) for split in splits):
        print("Using existing local copy:", local_root)
        return local_root

    # Incomplete or missing copy: start from a clean destination.
    if os.path.exists(local_root):
        shutil.rmtree(local_root)
    parent = os.path.dirname(local_root)
    if parent:
        # os.makedirs("") raises, so skip it for a bare relative destination.
        os.makedirs(parent, exist_ok=True)
    shutil.copytree(drive_root, local_root)
    print("Copied dataset to local:", local_root)
    return local_root

# Stage the dataset locally; every split below is read from DATA_ROOT.
DATA_ROOT = stage_dataset(DRIVE_ROOT, LOCAL_ROOT)

# ---------------- Dataset helpers ----------------
def list_class_files(dir_path):
    """Collect image paths and integer labels from a class-per-folder tree.

    Each immediate sub-directory of ``dir_path`` is one class; class indices
    follow the sorted sub-directory names. Images are found recursively and
    matched by file extension (case-insensitive).

    Directories and filenames are visited in sorted order so the returned
    lists are identical across runs and filesystems; ``os.walk`` alone gives
    no ordering guarantee, which would make seeded sampling over these lists
    non-reproducible.

    Args:
        dir_path: Directory whose sub-directories are the classes.

    Returns:
        ``(files, labels, classes)``: parallel lists of file paths and class
        indices, plus the sorted list of class names.
    """
    # NOTE(review): tf.image.decode_image (used when these files are loaded)
    # is documented for BMP/GIF/JPEG/PNG — confirm that .webp/.tif/.tiff
    # files decode on the installed TensorFlow version.
    image_exts = (".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff")

    classes = sorted(
        d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))
    )
    files, labels = [], []
    for idx, cname in enumerate(classes):
        for root, dirnames, fnames in os.walk(os.path.join(dir_path, cname)):
            dirnames.sort()  # in-place, so os.walk descends in a stable order
            for fn in sorted(fnames):
                if fn.lower().endswith(image_exts):
                    files.append(os.path.join(root, fn))
                    labels.append(idx)
    return files, labels, classes

def make_ds_from_files(files, labels, batch=BATCH, shuffle=False):
    """Build a cached, batched tf.data pipeline of VGG16-preprocessed images.

    Each file is read, decoded to 3 channels, resized to ``IMG_SIZE``, cast
    to float32 and passed through VGG16 ``preprocess_input``.

    Args:
        files: Sequence of image file paths.
        labels: Sequence of integer class labels, parallel to ``files``.
        batch: Batch size.
        shuffle: If True, elements are reshuffled on every iteration.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image_batch, label_batch)``.
    """
    n_files = len(files)
    files  = tf.convert_to_tensor(files, dtype=tf.string)
    labels = tf.convert_to_tensor(labels, dtype=tf.int32)
    ds = tf.data.Dataset.from_tensor_slices((files, labels))

    def _load(path, y):
        img = tf.io.read_file(path)
        # expand_animations=False keeps the result 3-D so resize() accepts it.
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, IMG_SIZE)
        img.set_shape(IMG_SIZE + (3,))
        img = tf.cast(img, tf.float32)
        img = preprocess_input(img)
        return img, y

    ds = ds.map(_load, num_parallel_calls=AUTOTUNE)
    ds = ds.cache()
    if shuffle:
        # Shuffle AFTER cache(): cache() replays elements in the order of the
        # first pass, so shuffling before it would freeze one order for every
        # epoch and make reshuffle_each_iteration ineffective.
        # max(1, ...) keeps the buffer size valid for an empty file list.
        ds = ds.shuffle(buffer_size=max(1, n_files), seed=SEED,
                        reshuffle_each_iteration=True)
    ds = ds.batch(batch).prefetch(AUTOTUNE)
    return ds

def load_full_splits(root):
    """List all three splits and build the validation and test datasets.

    Class indices are defined by the train split. Validation and test labels
    are remapped by class *name* onto those indices, so a split with a
    missing class folder cannot silently shift the label numbering.

    Args:
        root: Directory containing ``train``, ``val`` and ``test`` folders.

    Returns:
        ``(meta, val_ds, test_ds)`` where ``meta`` holds ``class_names`` and
        the file/label lists of every split (train is returned as lists only;
        no training dataset is built here).

    Raises:
        ValueError: If val or test contains images of a class that the train
            split does not have.
    """
    tr_files, tr_labels, class_names = list_class_files(os.path.join(root, "train"))
    class_to_idx = {name: i for i, name in enumerate(class_names)}

    def _aligned_split(split):
        # Re-express this split's labels in the train split's class indexing.
        files, labels, names = list_class_files(os.path.join(root, split))
        unknown = sorted({names[lab] for lab in labels} - class_to_idx.keys())
        if unknown:
            raise ValueError(
                f"{split} split has classes not present in train: {unknown}"
            )
        return files, [class_to_idx[names[lab]] for lab in labels]

    va_files, va_labels = _aligned_split("val")
    te_files, te_labels = _aligned_split("test")

    val_ds  = make_ds_from_files(va_files, va_labels)
    test_ds = make_ds_from_files(te_files, te_labels)
    meta = {
        "class_names": class_names,
        "train_files": tr_files, "train_labels": tr_labels,
        "val_files": va_files,   "val_labels": va_labels,
        "test_files": te_files,  "test_labels": te_labels
    }
    return meta, val_ds, test_ds

def make_train_subset(meta, fraction):
    """Build a shuffled training dataset from a seeded random subset.

    The subset is drawn from a permutation seeded with ``SEED``, so a given
    fraction selects the same indices on every call. The subset is not
    stratified by class.

    Args:
        meta: Dict with ``train_files`` and ``train_labels`` lists.
        fraction: Share of the training files to keep; at least one file is
            always kept and at most all of them.

    Returns:
        ``(train_ds, n_keep)`` where ``n_keep`` is the number of files
        actually placed in the dataset.

    Raises:
        ValueError: If the training split has no files.
    """
    tr_files, tr_labels = meta["train_files"], meta["train_labels"]
    n_total = len(tr_files)
    if n_total == 0:
        raise ValueError("Training split contains no image files")

    # Clamp so the reported count always equals the subset size, even when
    # fraction > 1.
    n_keep = min(n_total, max(1, int(n_total * fraction)))
    idx = np.random.default_rng(SEED).permutation(n_total)[:n_keep]
    files_sub  = [tr_files[i]  for i in idx]
    labels_sub = [tr_labels[i] for i in idx]
    return make_ds_from_files(files_sub, labels_sub, shuffle=True), n_keep

# ---------------- Model ----------------
def build_vgg16_transfer(num_classes):
    """Assemble a frozen ImageNet VGG16 backbone with a small classifier head.

    Head: global average pooling -> Dense(256, relu) -> Dropout(0.3) ->
    softmax over ``num_classes``. The softmax layer is created with
    ``dtype="float32"``.

    Args:
        num_classes: Number of output classes.

    Returns:
        An uncompiled Keras ``Model``.
    """
    input_shape = IMG_SIZE + (3,)

    # Backbone without its classifier top; weights are not trained.
    backbone = VGG16(include_top=False, weights="imagenet", input_shape=input_shape)
    backbone.trainable = False

    image_in = layers.Input(shape=input_shape)
    # training=False keeps the backbone in inference mode inside the model.
    features = backbone(image_in, training=False)
    pooled = layers.GlobalAveragePooling2D()(features)
    hidden = layers.Dense(256, activation="relu")(pooled)
    hidden = layers.Dropout(0.3)(hidden)
    class_probs = layers.Dense(num_classes, activation="softmax", dtype="float32")(hidden)
    return models.Model(image_in, class_probs)

def compile_model(m):
    """Compile ``m`` in place for integer-label classification.

    Uses Adam (learning rate 3e-4), sparse categorical cross-entropy and the
    accuracy metric.

    Args:
        m: Keras model to compile.

    Returns:
        The same model object, for call chaining.
    """
    m.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return m

# ---------------- Train once ----------------
def train_once(fraction, meta, val_ds, test_ds, show_curve=False, return_model=False):
    """Train a fresh model on a fraction of the training data and evaluate it.

    Prints a one-line report and optionally plots validation accuracy per
    epoch.

    Args:
        fraction: Share of the training files to use.
        meta: Split metadata dict (needs ``class_names`` and the train lists).
        val_ds: Validation dataset passed to ``fit``.
        test_ds: Test dataset passed to ``evaluate``.
        show_curve: If True, plot epoch vs. validation accuracy.
        return_model: If True, also return the trained model.

    Returns:
        A summary dict (``fraction``, ``params``, ``train_samples``,
        ``train_time_s``, ``test_time_s``, ``val_acc_last``, ``test_acc``),
        or ``(summary, model)`` when ``return_model`` is True.
    """
    train_ds, n_train = make_train_subset(meta, fraction)
    model = build_vgg16_transfer(len(meta["class_names"]))
    compile_model(model)
    n_params = model.count_params()

    # perf_counter is monotonic, so elapsed times cannot be skewed by
    # system clock adjustments.
    # NOTE(review): the datasets in this file are built with .cache(), so the
    # first run that iterates val_ds/test_ds presumably also pays their decode
    # cost — keep that in mind when comparing timings between runs.
    t0 = time.perf_counter()
    hist = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, verbose=0)
    train_time = time.perf_counter() - t0

    t1 = time.perf_counter()
    test_loss, test_acc = model.evaluate(test_ds, verbose=0)
    test_time = time.perf_counter() - t1

    print(f"\n[VGG16 Transfer | fraction={fraction:.2f}]  "
          f"train_samples={n_train}  params={n_params:,}  "
          f"train_time={train_time:.1f}s  test_time={test_time:.1f}s  "
          f"test_acc={test_acc:.3f}")

    val_acc = hist.history["val_accuracy"]

    if show_curve:
        plt.figure(figsize=(6,4))
        # Use the recorded history length for the x-axis so the plot stays
        # valid even if fewer than EPOCHS epochs were recorded.
        plt.plot(range(1, len(val_acc) + 1), val_acc, marker="o")
        plt.xlabel("Epoch")
        plt.ylabel("Validation Accuracy")
        plt.title("Epoch vs Validation Accuracy")
        plt.grid(True)
        plt.show()

    summary = {
        "fraction": fraction,
        "params": n_params,
        "train_samples": n_train,
        "train_time_s": train_time,
        "test_time_s": test_time,
        "val_acc_last": val_acc[-1],
        "test_acc": test_acc
    }
    return (summary, model) if return_model else summary

# ---------------- Predictions ----------------
def show_predictions(model, meta, k=10):
    """Plot up to ``k`` random test images with true and predicted labels.

    Images are sampled without replacement using a generator seeded with
    ``SEED``, preprocessed the same way as the training pipeline (decode,
    resize to ``IMG_SIZE``, VGG16 ``preprocess_input``) and classified one at
    a time. A table of file name, ground truth, prediction and top
    probability is printed after the figure.

    Args:
        model: Trained Keras model.
        meta: Split metadata dict (needs ``test_files``, ``test_labels`` and
            ``class_names``).
        k: Number of images requested; clamped to the number of test files.
    """
    n_test = len(meta["test_files"])
    # Sampling without replacement raises if k exceeds the population.
    k = min(k, n_test)
    if k <= 0:
        print("No test images available for prediction preview.")
        return

    rng = np.random.default_rng(SEED)
    idxs = rng.choice(n_test, size=k, replace=False)

    rows = []
    ncols = 5
    nrows = int(np.ceil(k / ncols))
    plt.figure(figsize=(3*ncols, 2.6*nrows))

    for i, idx in enumerate(idxs, 1):
        path = meta["test_files"][idx]
        y_true = meta["test_labels"][idx]

        img = tf.io.read_file(path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, IMG_SIZE)

        # Keep `img` un-normalised for display; only the network input is
        # passed through preprocess_input.
        img_net = preprocess_input(tf.cast(img, tf.float32))
        probs = model(tf.expand_dims(img_net, 0), training=False).numpy()[0]
        pred = np.argmax(probs)
        top_prob = np.max(probs)

        gt_name = meta["class_names"][y_true]
        pr_name = meta["class_names"][pred]

        plt.subplot(nrows, ncols, i)
        plt.imshow(tf.cast(img, tf.uint8).numpy())
        plt.axis("off")
        plt.title(f"gt:{gt_name}\npred:{pr_name} ({top_prob:.2f})")

        rows.append([os.path.basename(path), gt_name, pr_name, top_prob])

    plt.tight_layout()
    plt.show()

    df = pd.DataFrame(rows, columns=["file", "gt", "pred", "prob"])
    print(df.to_string(index=False))

# ================= RUN =================
# List all splits and build the shared validation/test datasets.
meta, val_ds, test_ds = load_full_splits(DATA_ROOT)
print("Classes:", meta["class_names"])
print(f"Samples → train={len(meta['train_files'])}, val={len(meta['val_files'])}, test={len(meta['test_files'])}")

# Baseline run on the full training set: Figure 1 (validation curve) and
# Figure 2 (sample predictions).
baseline, model = train_once(1.00, meta, val_ds, test_ds, show_curve=True, return_model=True)
show_predictions(model, meta, k=10)

# Amount-of-data study: one fresh training run per fraction.
# NOTE(review): FRACTIONS includes 1.00, so the full-data model is trained a
# second time here; reuse `baseline` instead if the extra run is not wanted.
rows = [train_once(f, meta, val_ds, test_ds) for f in FRACTIONS]
print("\n=== Amount of Data vs Performance (VGG16 Transfer) ===")
print(pd.DataFrame(rows)[["fraction","train_samples","train_time_s","test_time_s","val_acc_last","test_acc"]]
      .sort_values("fraction"))
