<a href="https://colab.research.google.com/github/DevashreePatrikar/MLA_SpyNet/blob/main/MLA_SpyNet_Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
from tensorflow.keras.models import load_model
import cv2
import os

# ---------------------------
# Load trained MLA-SpyNet model
# ---------------------------
# Restore the trained network from the saved .h5 file. The path is relative,
# so it is resolved against the notebook's current working directory.
mla_model = load_model(filepath="models/mla_spynet_trained.h5")

# ---------------------------
# Load test dataset
# ---------------------------
def load_test_data(im_dir, seq_len=20, frame_size=(128, 128)):
    """Load grayscale test clips from ``im_dir`` as fixed-length sequences.

    ``im_dir`` holds one sub-directory per clip. Every image OpenCV can decode
    inside a clip directory is read in sorted filename order, converted to
    grayscale and resized. Each clip's frames are then cut into consecutive,
    non-overlapping sequences of ``seq_len`` frames.

    Args:
        im_dir: Root directory containing one folder per clip.
        seq_len: Number of frames per sequence (default 20).
        frame_size: Output ``(height, width)`` of every frame.

    Returns:
        uint8 array of shape ``(num_sequences, seq_len, height, width, 1)``.
        It has zero sequences when no complete sequence could be built.

    Notes:
        Sequences never span two clips. Trailing frames of a clip that do not
        fill a whole sequence are dropped. Previously all frames were pooled
        and reshaped, which either raised or glued the end of one clip onto
        the start of the next. Non-directory entries directly under
        ``im_dir`` and files OpenCV cannot decode are skipped.
    """
    height, width = frame_size
    sequences = []
    for clip_name in sorted(os.listdir(im_dir)):
        clip_dir = os.path.join(im_dir, clip_name)
        if not os.path.isdir(clip_dir):
            continue  # stray files next to the clip folders are not clips

        frames = []
        for image_file in sorted(os.listdir(clip_dir)):
            image = cv2.imread(os.path.join(clip_dir, image_file), cv2.IMREAD_GRAYSCALE)
            if image is None:
                # cv2.imread signals unreadable/non-image files by returning None.
                continue
            # cv2.resize expects dsize as (width, height).
            frames.append(cv2.resize(image, (width, height)))

        # Split this clip only; any remainder shorter than seq_len is dropped.
        usable = len(frames) - len(frames) % seq_len
        for start in range(0, usable, seq_len):
            sequences.append(np.stack(frames[start:start + seq_len]))

    if not sequences:
        return np.empty((0, seq_len, height, width, 1), dtype=np.uint8)
    # Add the trailing single-channel axis expected downstream.
    return np.stack(sequences)[..., np.newaxis]

test_data = load_test_data("/content/drive/MyDrive/Test_Data")

# ---------------------------
# Load Bayesian threshold from training
# ---------------------------
# Assumes the threshold was saved with np.save during training.
# .item() turns the loaded 0-d / single-element array into a plain Python
# scalar, so `epe > R_th` is an ordinary float comparison. A file holding
# more than one value fails here with a clear ValueError rather than later
# inside the test loop.
R_th = float(np.load("models/bayesian_threshold.npy").item())

# ---------------------------
# Compute End-Point Error (EPE)
# ---------------------------
def compute_epe(R_L, R_gt):
    """Return the mean Euclidean distance between ``R_L`` and ``R_gt``.

    The last axis holds the vector components. The Euclidean norm of the
    difference is taken along that axis, and the result is averaged over all
    remaining positions.

    Args:
        R_L: Predicted values; any array-like broadcastable against ``R_gt``.
        R_gt: Reference values.

    Returns:
        The mean end-point error as a NumPy float64 scalar.
    """
    # Cast before subtracting: with unsigned integer inputs (e.g. uint8
    # frames), ``a - b`` and ``** 2`` would silently wrap around.
    diff = np.asarray(R_L, dtype=np.float64) - np.asarray(R_gt, dtype=np.float64)
    return np.sqrt(np.sum(diff ** 2, axis=-1)).mean()

# ---------------------------
# Run Test
# ---------------------------
# For each sequence, feed every consecutive frame pair through MLA-SpyNet,
# score it with compute_epe, and report pairs whose score exceeds R_th.
for seq_idx, sequence in enumerate(test_data):
    print(f"Processing sequence {seq_idx + 1}/{len(test_data)}")

    frames = sequence[..., 0]  # drop the trailing channel axis
    if frames.shape[0] < 2:
        continue  # no consecutive pair to evaluate

    # Pair i is (frame i, frame i+1) stacked on the last axis. This is the
    # same input the model previously received one pair at a time.
    pairs = np.stack([frames[:-1], frames[1:]], axis=-1)

    # Predicted optical flow by MLA-SpyNet. Use one batched call per sequence
    # instead of one call per pair, because predict() has a sizeable fixed
    # per-call cost.
    R_L_batch = mla_model.predict(pairs, verbose=0)

    # Ground truth: consecutive frames (same as training).
    # NOTE(review): R_gt is the raw frame pair, not a reference flow field,
    # so "EPE" here is a distance to the input frames. Confirm this matches
    # how R_th was derived.
    for i, (R_L, R_gt) in enumerate(zip(R_L_batch, pairs)):
        epe = compute_epe(R_L, R_gt)

        # Detect anomaly using Bayesian threshold
        if epe > R_th:
            print(f"Anomaly detected at frame {i} in sequence {seq_idx + 1}")
