<a href="https://colab.research.google.com/github/DevashreePatrikar/MLA_SpyNet/blob/main/MLA_SpyNet_Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import cv2
import os
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Input, Multiply, GlobalAveragePooling2D, Reshape
from tensorflow.keras.models import Model, load_model
from sklearn.model_selection import train_test_split
from scipy.stats import norm

# ---------------------------
# CBAM Attention Module
# ---------------------------
def cbam_block(input_tensor, reduction_ratio=16):
    """Apply channel attention followed by spatial attention (CBAM-style).

    Channel branch: global average pooling -> 1x1 conv bottleneck (ReLU) ->
    1x1 conv back to C channels (sigmoid), multiplied onto the input.
    Spatial branch: a single 7x7 sigmoid conv over the channel-attended
    features, producing one map that is multiplied back onto them.

    NOTE(review): this is a simplified variant of CBAM (Woo et al., 2018),
    which additionally max-pools in the channel branch and convolves
    channel-wise avg/max maps in the spatial branch. It is kept as-is here so
    the network architecture does not change.

    Args:
        input_tensor: 4-D Keras tensor (batch, H, W, C) with a statically
            known channel count C.
        reduction_ratio: Channel bottleneck ratio; the hidden width is
            ``max(1, C // reduction_ratio)``.

    Returns:
        Tensor with the same shape as ``input_tensor``.

    Raises:
        ValueError: if C is unknown or ``reduction_ratio`` is below 1.
    """
    channels = input_tensor.shape[-1]
    if channels is None:
        raise ValueError("cbam_block requires a statically known channel dimension")
    if reduction_ratio < 1:
        raise ValueError(f"reduction_ratio must be >= 1, got {reduction_ratio}")
    # Clamp to 1 so inputs with fewer channels than reduction_ratio do not
    # build a 0-filter convolution (e.g. 8 // 16 == 0).
    hidden = max(1, channels // reduction_ratio)

    # Channel Attention
    channel = GlobalAveragePooling2D()(input_tensor)
    channel = Reshape((1, 1, channels))(channel)
    channel = Conv2D(hidden, (1, 1), activation='relu', padding='same')(channel)
    channel = Conv2D(channels, (1, 1), activation='sigmoid', padding='same')(channel)
    channel_att = Multiply()([input_tensor, channel])

    # Spatial Attention: one sigmoid map broadcast across all channels.
    spatial = Conv2D(1, (7, 7), activation='sigmoid', padding='same')(channel_att)
    return Multiply()([channel_att, spatial])

# ---------------------------
# SpyNet Block
# ---------------------------
def spynet_block(input_tensor):
    """Map a stacked frame pair to a 2-channel output map.

    Three ReLU convolution stages (64 filters @7x7, 64 @5x5, 32 @3x3), each
    followed by a ``cbam_block`` attention step, then a linear 3x3
    convolution down to 2 channels (the optical-flow output).
    """
    features = input_tensor
    for n_filters, k in ((64, 7), (64, 5), (32, 3)):
        features = Conv2D(n_filters, (k, k), padding="same", activation="relu")(features)
        features = cbam_block(features)
    # No activation on the head: flow components may take any sign.
    return Conv2D(2, (3, 3), padding="same")(features)

# ---------------------------
# MLA-SpyNet Model
# ---------------------------
def build_mla_spynet_model(input_shape=(128, 128, 2)):
    """Build the (uncompiled) MLA-SpyNet Keras model.

    Args:
        input_shape: Per-sample ``(height, width, channels)``. The default
            ``(128, 128, 2)`` holds two 128x128 grayscale frames stacked on
            the last axis. Other spatial sizes should work as long as
            ``spynet_block`` stays fully convolutional (same-padded convs and
            channel-wise attention only).

    Returns:
        A ``tf.keras.Model`` named ``"MLA_SpyNet"`` whose output is the
        2-channel map produced by ``spynet_block``.
    """
    input_layer = Input(shape=input_shape)
    output_layer = spynet_block(input_layer)
    return Model(inputs=input_layer, outputs=output_layer, name="MLA_SpyNet")

mla_spynet_model = build_mla_spynet_model()
# Plain MSE regression on the 2-channel flow output, Adam with lr = 1e-4.
mla_spynet_model.compile(
    loss="mse",
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
)

# ---------------------------
# Load Dataset
# ---------------------------
def load_data(im_dir, frames_per_sequence=20, image_size=(128, 128)):
    """Load grayscale frame sequences from a directory of clip folders.

    Each sub-directory of ``im_dir`` is treated as one clip. Its files are
    read in sorted filename order, decoded as grayscale, resized, and cut
    into consecutive, non-overlapping sequences of ``frames_per_sequence``
    frames. Sequences never span two clips.

    Args:
        im_dir: Directory whose sub-directories contain the frame images.
            Plain files directly inside ``im_dir`` are ignored.
        frames_per_sequence: Number of frames per returned sequence.
        image_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        Array of shape ``(num_sequences, frames_per_sequence, height, width, 1)``
        with the dtype of the decoded images (uint8 for 8-bit files). If no
        frames are found, an empty uint8 array with that trailing shape.

    Raises:
        ValueError: if a frame file cannot be decoded, or a clip's frame
            count is not a multiple of ``frames_per_sequence``. The previous
            implementation flattened all clips and silently produced
            sequences that straddled clip boundaries in that case.
    """
    width, height = image_size
    clips = []
    for clip_name in sorted(os.listdir(im_dir)):
        clip_dir = os.path.join(im_dir, clip_name)
        if not os.path.isdir(clip_dir):
            continue  # stray files (e.g. .DS_Store) next to the clip folders

        frames = []
        for image_file in sorted(os.listdir(clip_dir)):
            if image_file.startswith("."):
                continue  # hidden/metadata files are not frames
            image_path = os.path.join(clip_dir, image_file)
            image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if image is None:  # imread signals failure by returning None
                raise ValueError(f"could not read image: {image_path}")
            frames.append(cv2.resize(image, (width, height)))

        if not frames:
            continue
        if len(frames) % frames_per_sequence:
            raise ValueError(
                f"{clip_dir} has {len(frames)} frames, not a multiple of "
                f"frames_per_sequence={frames_per_sequence}"
            )
        clips.append(
            np.stack(frames).reshape(-1, frames_per_sequence, height, width, 1)
        )

    if not clips:
        return np.empty((0, frames_per_sequence, height, width, 1), dtype=np.uint8)
    return np.concatenate(clips)

# Frame sequences read from the mounted Google Drive (Colab paths).
train_data, test_data = (
    load_data("/content/drive/MyDrive/Train_Data"),
    load_data("/content/drive/MyDrive/Test_Data"),
)

# ---------------------------
# Load DSAPM Model
# ---------------------------
# NOTE(review): dsapm_model is not referenced anywhere in the code visible
# here; confirm it is still needed before paying the load cost.
dsapm_model = load_model("models/dsapm.h5")

# ---------------------------
# Prepare Training Data
# ---------------------------
# Build (frame n, frame n+1) input pairs and their flow targets.
# The targets used to be the network's own predictions. That makes the MSE
# loss and its gradient identically zero, so fit() could never change the
# weights. Dense Farneback optical flow (OpenCV) between the same two frames
# is used instead as a pseudo ground truth. Replace it with real flow
# annotations if they become available.
X_train, Y_train = [], []
for sequence in train_data:
    # A sequence of T frames has T - 1 consecutive pairs (i, i + 1); the old
    # range(18) dropped the last pair of every 20-frame sequence.
    for i in range(sequence.shape[0] - 1):
        Z_n = sequence[i, ..., 0]
        Z_n1 = sequence[i + 1, ..., 0]
        X_train.append(np.stack([Z_n, Z_n1], axis=-1))
        # Result has shape (H, W, 2): per-pixel (dx, dy) displacement.
        Y_train.append(
            cv2.calcOpticalFlowFarneback(
                Z_n, Z_n1, None,
                pyr_scale=0.5, levels=3, winsize=15,
                iterations=3, poly_n=5, poly_sigma=1.2, flags=0,
            )
        )

X_train = np.array(X_train)
Y_train = np.array(Y_train)

# ---------------------------
# Train-validation split
# ---------------------------
# Hold out 20% of the pairs for validation (fixed seed for reproducibility).
# NOTE(review): the split is over individual frame pairs, so near-identical
# neighbouring pairs from one clip can land on both sides and inflate the
# validation score. Consider splitting by sequence instead.
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, test_size=0.2, random_state=42
)

# Train MLA-SpyNet, then write it out as an .h5 file.
mla_spynet_model.fit(
    X_train,
    Y_train,
    validation_data=(X_val, Y_val),
    epochs=20,
    batch_size=16,
)
mla_spynet_model.save("models/mla_spynet_trained.h5")

# ---------------------------
# Compute End-Point Error (EPE)
# ---------------------------
def compute_epe(predicted_flow, gt_flow):
    """Return the mean end-point error between two flow fields.

    The last axis holds the flow vector components; the per-element
    Euclidean distance along that axis is averaged over everything else.
    """
    residual = predicted_flow - gt_flow
    per_pixel = np.sqrt(np.square(residual).sum(axis=-1))
    return per_pixel.mean()

# ---------------------------
# Compute Bayesian Threshold R_th
# ---------------------------
# Per-pair EPE of the trained model on the (normal) training sequences.
# Dense Farneback flow is the pseudo ground truth. The previous "ground
# truth" was the two raw frames stacked together, i.e. pixel intensities
# compared against flow vectors, which is not an end-point error.
normal_pairs, normal_gt = [], []
for sequence in train_data:
    # Every consecutive pair (i, i + 1) of a T-frame sequence: T - 1 pairs
    # (the old range(17) skipped the last two).
    for i in range(sequence.shape[0] - 1):
        Z_n = sequence[i, ..., 0]
        Z_n1 = sequence[i + 1, ..., 0]
        normal_pairs.append(np.stack([Z_n, Z_n1], axis=-1))
        normal_gt.append(
            cv2.calcOpticalFlowFarneback(
                Z_n, Z_n1, None,
                pyr_scale=0.5, levels=3, winsize=15,
                iterations=3, poly_n=5, poly_sigma=1.2, flags=0,
            )
        )

if not normal_pairs:
    raise ValueError("train_data contains no consecutive frame pairs")

# One batched predict over all pairs instead of one call per pair.
predicted_flows = mla_spynet_model.predict(np.array(normal_pairs), batch_size=16)
normal_epe_values = np.array(
    [compute_epe(pred, gt) for pred, gt in zip(predicted_flows, normal_gt)]
)
mu_epe = normal_epe_values.mean()
sigma_epe = normal_epe_values.std()

# Set Bayesian threshold (5% tail).
# NOTE(review): this assumes anomalies are flagged when EPE exceeds R_th
# (the usage is not shown here). In that case the cut-off belongs in the
# UPPER tail of the fitted Gaussian, P(EPE > R_th) = prob_threshold. The old
# norm.ppf(0.05) put it in the lower tail, which ~95% of normal pairs exceed.
prob_threshold = 0.05
if sigma_epe > 0:
    R_th = norm.isf(prob_threshold, loc=mu_epe, scale=sigma_epe)
else:
    R_th = mu_epe  # degenerate: every normal EPE is identical

print(f"Bayesian Threshold R_th: {R_th}")
