In [None]:

########################################  Model for training   ###################################################################
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # Force TensorFlow to use CPU before loading anything else
import numpy as np
import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.image import ssim
import matplotlib.pyplot as plt

# Enable on-demand GPU memory allocation so TensorFlow does not reserve all GPU memory up front.
# The setting is applied to every visible GPU: TensorFlow refuses configurations where memory
# growth differs between GPUs, so enabling it on gpus[0] alone breaks multi-GPU machines.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Memory growth enabled for GPU")
    except RuntimeError as e:
        # set_memory_growth must run before the GPUs are initialized (e.g. this fails
        # when the cell is re-run after TensorFlow has already used the GPU).
        print(f"Error enabling memory growth: {e}")

# Resize a frame to the model's spatial size, given as (height, width).
# NOTE: `resize_frame` is defined again further down this file with the opposite,
# cv2-style (width, height) convention and a (340, 300) default; that later definition
# rebinds the name, so it is the one the loaders call when the file runs top to bottom.
def resize_frame(frame, target_size=(140, 100)):
    """Return ``frame`` resized to ``target_size`` = (height, width), or None if invalid.

    The default (140, 100) matches the per-frame size of the ConvLSTM input
    ``keras.Input(shape=(None, 140, 100, 1))``.

    Args:
        frame: Image array whose first two axes are (height, width), or None.
        target_size: Desired (height, width). ``cv2.resize`` expects (width, height),
            so the tuple is swapped before calling it.

    Returns:
        ``frame`` itself when it already has the requested size, a resized copy
        otherwise, or None when ``frame`` is None or has a zero-length spatial axis.
    """
    if frame is None or frame.shape[0] == 0 or frame.shape[1] == 0:
        print(f"⚠ Invalid frame detected: {frame}")
        return None  # Signal the invalid frame to the caller instead of raising

    height, width = frame.shape[:2]

    # tuple() so a list such as [140, 100] compares equal to a matching frame size.
    if (height, width) != tuple(target_size):
        print(f"⚠ Resizing: {frame.shape} -> {target_size}")
        # cv2.resize takes dsize as (width, height), hence the swap.
        frame = cv2.resize(frame, (target_size[1], target_size[0]))

    return frame

# ConvLSTM next-frame model: sequences of any length made of 140x100 single-channel
# frames go in, and a sequence of the same length of sigmoid-activated frames comes out.
seq = keras.Sequential(
    [keras.Input(shape=(None, 140, 100, 1))]  # (time, height, width, channels); time may vary
    # Three identical recurrent stages. Each ConvLSTM2D returns the full sequence
    # (return_sequences=True) so the next stage and the Conv3D head see every timestep.
    + [
        layer
        for _stage in range(3)
        for layer in (
            layers.ConvLSTM2D(filters=50, kernel_size=(5, 5), padding="same", return_sequences=True),
            layers.LeakyReLU(alpha=0.1),
            layers.Dropout(0.2),
        )
    ]
    # Head: kernel depth 1 along time, so each output frame is a 3x3 spatial convolution
    # of its own timestep's 50 feature maps down to a single sigmoid channel.
    + [layers.Conv3D(filters=1, kernel_size=(1, 3, 3), activation="sigmoid", padding="same")]
)

# Stand-alone EarlyStopping / ReduceLROnPlateau instances, both monitoring validation loss.
# NOTE(review): neither `early_stopping` nor `reduce_lr` is passed to the `seq.fit` call
# later in this file. That call builds its own `callbacks` list with different patience
# values, so these two objects appear unused — confirm before relying on them.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=10,
    restore_best_weights=True,
)
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.8,
    patience=8,
)


# Temporal-consistency term: penalize mismatch between the frame-to-frame changes of the
# prediction and those of the ground truth, so predicted motion follows the real motion.
@keras.utils.register_keras_serializable()
def temporal_loss(y_true, y_pred):
    """Huber loss between consecutive-frame differences of ``y_true`` and ``y_pred``.

    Args:
        y_true: Target sequence; axis 1 is treated as the time axis.
            (Presumably (batch, time, height, width, channels) — TODO confirm.)
        y_pred: Predicted sequence with the same shape as ``y_true``.

    Returns:
        Scalar tensor: Huber(delta=1.0) loss between ``y_true[:, t+1] - y_true[:, t]``
        and ``y_pred[:, t+1] - y_pred[:, t]``.
    """
    # Frame-to-frame changes along the time axis.
    diff_true = y_true[:, 1:] - y_true[:, :-1]
    diff_pred = y_pred[:, 1:] - y_pred[:, :-1]

    # Huber is quadratic for |error| <= delta and linear beyond it, which limits the
    # influence of large, sudden changes. Adjust delta to change that threshold.
    huber_loss = tf.keras.losses.Huber(delta=1.0)
    loss_value = huber_loss(diff_true, diff_pred)

    # reduce_mean guarantees a scalar result.
    return tf.reduce_mean(loss_value)

@keras.utils.register_keras_serializable()
def custom_loss(y_true, y_pred):
    """Combined training loss: (1 - SSIM) + 0.1 * MSE + 0.2 * temporal loss.

    Registered with Keras, like ``temporal_loss``, because this is the function passed
    to ``seq.compile``. Registration lets a saved ``.keras`` model resolve the loss by
    name when it is reloaded in a process that has defined this function.

    Args:
        y_true: Target frames; expected in [0, 1] since ``ssim`` uses max_val=1.0.
        y_pred: Predicted frames, same shape as ``y_true``.

    Returns:
        Scalar loss tensor.

    Raises:
        tf.errors.InvalidArgumentError: Raised by ``tf.debugging.check_numerics`` if any
            component loss contains NaN or Inf.
    """
    # Clip both tensors into (0, 1) to avoid extreme values.
    y_pred = tf.clip_by_value(y_pred, 1e-6, 1.0 - 1e-6)
    y_true = tf.clip_by_value(y_true, 1e-6, 1.0 - 1e-6)

    # The constant 1e-8 offsets below only shift the reported loss value; being
    # constants, they do not change the gradients.

    # Structural dissimilarity: 1 - mean SSIM over all images in the batch.
    ssim_loss = 1 - tf.reduce_mean(ssim(y_true, y_pred, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)) + 1e-8

    # Pixel-wise mean squared error.
    mse_loss = tf.reduce_mean(tf.square(y_true - y_pred)) + 1e-8

    # Motion-consistency term.
    temp_loss = temporal_loss(y_true, y_pred) + 1e-8

    # Fail fast with a named error as soon as any component becomes NaN/Inf.
    mse_loss = tf.debugging.check_numerics(mse_loss, "NaN detected in MSE loss")
    temp_loss = tf.debugging.check_numerics(temp_loss, "NaN detected in Temporal loss")
    ssim_loss = tf.debugging.check_numerics(ssim_loss, "NaN detected in SSIM loss")

    # Weighted sum of the three terms.
    total_loss = ssim_loss + 0.1 * mse_loss + 0.2 * temp_loss
    return total_loss

# Compile with Adam at learning rate 0.005 and gradient-norm clipping at 1.0.
# NOTE(review): 0.005 is higher than Keras' documented Adam default (0.001), so it is
# not a "reduced" rate — confirm the value is intended.
optimizer = keras.optimizers.Adam(
    learning_rate=0.005,
    clipnorm=1.0,
)
seq.compile(optimizer=optimizer, loss=custom_loss)
seq.summary()

####################################################################################################

# List the frame numbers present in a dataset folder.
def get_available_frames(folder_path):
    """Return the sorted numbers N of every ``frame_N.png`` file in ``folder_path``.

    Files that do not start with ``frame_`` or do not end with ``.png`` are ignored.
    So are names whose second ``_``-separated field (up to the first ``.``) is not an
    integer.
    """
    numbers = []
    for name in os.listdir(folder_path):
        if not (name.startswith("frame_") and name.endswith(".png")):
            continue
        token = name.split("_")[1].split(".")[0]
        try:
            numbers.append(int(token))
        except ValueError:
            pass  # Malformed name: skip it.
    numbers.sort()
    return numbers

# Resize a frame for the data loaders, using cv2's (width, height) size convention.
def resize_frame(frame, target_size=(340, 300)):
    """Return ``frame`` resized to ``target_size`` with ``cv2.resize``.

    NOTE: this rebinds the ``resize_frame`` defined earlier in this file, and its
    ``target_size`` follows the cv2 convention (width, height), the opposite order of
    the earlier definition. For example, ``target_size=(100, 140)`` yields a frame of
    shape (140, 100). The (340, 300) default is kept for backward compatibility; the
    loader visible in this file always passes ``target_size`` explicitly.

    Args:
        frame: Image array, e.g. the result of ``cv2.imread``.
        target_size: Output size as (width, height).

    Returns:
        The resized image.

    Raises:
        ValueError: If ``frame`` is None (``cv2.imread`` returns None for unreadable
            files) or is empty. This replaces cv2's opaque assertion error.
    """
    if frame is None or frame.size == 0:
        raise ValueError(
            "resize_frame received an empty frame; the image could not be read or is empty"
        )
    return cv2.resize(frame, target_size)

def load_image_sequence(folder_path, n_frames=20, start_frame=0, target_size=(140, 100)):
    """Load ``n_frames`` consecutive grayscale frames from ``folder_path``.

    All ``.png`` files are ordered by the integer between the first ``_`` and the
    following ``.`` in their name (e.g. ``frame_12.png`` -> 12). Positions
    ``start_frame`` .. ``start_frame + n_frames - 1`` of that ordering are loaded.
    A position outside the file list, or a file ``cv2.imread`` cannot decode, is
    replaced by an all-black frame, so the sequence always has ``n_frames`` entries.

    Args:
        folder_path: Directory containing the frame images.
        n_frames: Number of frames to return.
        start_frame: Index into the sorted file list of the first frame.
        target_size: Output size as (width, height), the order ``cv2.resize`` uses.
            Each returned frame therefore has shape (target_size[1], target_size[0]).

    Returns:
        uint8 array of shape (n_frames, target_size[1], target_size[0]).
    """
    width, height = target_size

    all_frames = [f for f in os.listdir(folder_path) if f.endswith('.png')]
    # Numeric (not lexicographic) order so that frame_10 sorts after frame_9.
    all_frames.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))

    sequence = []
    for i in range(start_frame, start_frame + n_frames):
        img = None
        # 0 <= i: a negative index would otherwise wrap around to the end of the list.
        if 0 <= i < len(all_frames):
            image_path = os.path.join(folder_path, all_frames[i])
            img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                print(f"⚠ Could not read {image_path}, using black frame")
        else:
            print(f"⚠ Missing frame index {i}, using black frame")

        if img is None:
            # Same (height, width) that cv2.resize produces below; np.zeros takes
            # (rows, cols), the reverse of cv2's (width, height) dsize.
            sequence.append(np.zeros((height, width), dtype=np.uint8))
        else:
            sequence.append(cv2.resize(img, target_size))

    # Debug output: report each frame's shape.
    for idx, frame in enumerate(sequence):
        print(f"Frame {idx} shape: {frame.shape}")

    if not sequence:
        # Keep a 3-D shape even when no frames were requested.
        return np.zeros((0, height, width), dtype=np.uint8)
    return np.array(sequence)

# Load a batch of randomly positioned frame sequences from one folder.
def load_batch_of_frames(batch_size, folder_path, n_frames=20, target_size=(140, 100)):
    """Return ``batch_size`` random windows of ``n_frames`` consecutive frames.

    ``.png`` files are ordered numerically by the integer between the first ``_`` and
    the following ``.`` (e.g. ``frame_12.png`` -> 12). For each sample, a start index is
    drawn with ``np.random.randint`` so that the whole window fits in the folder. Only
    ``batch_size * n_frames`` images are read, not the entire folder.

    Args:
        batch_size: Number of sequences to return.
        folder_path: Directory containing the frame images.
        n_frames: Frames per sequence.
        target_size: Forwarded to ``resize_frame``. When this file runs top to bottom,
            the later, cv2-convention ``resize_frame`` is the one bound, so this is
            (width, height).

    Returns:
        Array of shape (batch_size, n_frames, ...) holding the resized grayscale frames.

    Raises:
        ValueError: If ``batch_size > 0`` and the folder holds fewer than ``n_frames``
            ``.png`` files, or if an image cannot be read.
    """
    all_frames = [f for f in os.listdir(folder_path) if f.endswith('.png')]
    all_frames.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))  # Numeric order

    # randint(0, high) needs high > 0; report the real problem instead of numpy's
    # generic "high <= low".
    if batch_size > 0 and len(all_frames) < n_frames:
        raise ValueError(
            f"{folder_path} contains {len(all_frames)} .png frames, "
            f"but {n_frames} are needed per sequence"
        )

    batch = []
    for _ in range(batch_size):
        # randint's upper bound is exclusive, so the last valid start is len - n_frames.
        start_frame = np.random.randint(0, len(all_frames) - n_frames + 1)

        sequence = []
        for name in all_frames[start_frame:start_frame + n_frames]:
            image_path = os.path.join(folder_path, name)
            img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals unreadable files by returning None.
                raise ValueError(f"Could not read image: {image_path}")
            sequence.append(resize_frame(img, target_size))

        batch.append(np.array(sequence))

    return np.array(batch)

# Build (input, next-frame target) movie pairs from several frame folders.
def generate_movies_from_multiple_folders(folder_paths, n_samples_per_folder=30, n_frames=20, batch_size=1, target_size=(100, 140)):
    """Collect next-frame-prediction training pairs from every folder.

    For each existing folder, ``n_samples_per_folder`` calls to ``load_batch_of_frames``
    each contribute ``batch_size`` random sequences. For a sequence ``movie``, the
    target is ``movie`` shifted one step forward in time, with the last frame repeated
    so both have the same length: target[t] = movie[t + 1].

    Loading is best-effort: missing folders are skipped, and samples that fail to load
    are reported and skipped. The result may therefore hold fewer sequences than
    requested.

    Args:
        folder_paths: Iterable of directories to read.
        n_samples_per_folder: Loader calls per folder.
        n_frames: Frames per sequence.
        batch_size: Sequences per loader call.
        target_size: Frame size passed to ``load_batch_of_frames``. With the
            cv2-convention ``resize_frame`` this is (width, height).

    Returns:
        ``(noisy_movies, shifted_movies)``, each with a trailing channel axis. When no
        sequence was collected, both are empty uint8 arrays of shape
        (0, n_frames, target_size[1], target_size[0], 1) rather than (0, 1), so
        shape-dependent code downstream still sees a 5-D array.
    """
    noisy_movies = []
    shifted_movies = []

    for folder_path in folder_paths:
        if not os.path.exists(folder_path):
            print(f"❌ Folder not found: {folder_path}")
            continue

        print(f"📂 Processing folder: {folder_path}")

        for sample_idx in range(n_samples_per_folder):
            try:
                movie_frames_batch = load_batch_of_frames(batch_size, folder_path, n_frames=n_frames, target_size=target_size)
            except Exception as e:
                # Best-effort: report and skip this sample instead of aborting the run.
                print(f"❌ Error loading sequence from {folder_path}, sample {sample_idx + 1}: {e}")
                continue

            for movie_frames in movie_frames_batch:
                noisy_movies.append(movie_frames)
                # Next-frame target: drop the first frame, then repeat the last one to
                # restore the original length.
                shifted_movie = movie_frames[1:]
                if shifted_movie.shape[0] < n_frames:
                    shifted_movie = np.concatenate((shifted_movie, movie_frames[-1:]), axis=0)
                shifted_movies.append(shifted_movie)

            print(f"   ✅ Sample {sample_idx + 1}/{n_samples_per_folder} from {folder_path}")

    print(f"📊 Total noisy movies collected: {len(noisy_movies)}")
    print(f"📊 Total shifted movies collected: {len(shifted_movies)}")

    if not noisy_movies:
        width, height = target_size
        empty = np.zeros((0, n_frames, height, width, 1), dtype=np.uint8)
        return empty, empty.copy()

    # Add the grayscale channel axis expected by the ConvLSTM input.
    noisy_movies = np.array(noisy_movies)[..., np.newaxis]
    shifted_movies = np.array(shifted_movies)[..., np.newaxis]

    return noisy_movies, shifted_movies

# Training recordings; all of them live directly under the same al5083/train directory.
dataset_folders = [
    r"/mnt/ssd/Ulster/Vision Project/Papers/Report 2/Datasets/al5083/train/" + recording
    for recording in (
        "170815-133921-Al 2mm",
        "170815-134756-Al 2mm",
        "170904-112347-Al 2mm",
        "170904-113012-Al 2mm-part1",
        "170904-141232-Al 2mm-part3",
        "170904-141730-Al 2mm-part3",
        "170906-113317-Al 2mm-part3",
        "170906-153326-Al 2mm-part2",
        "170913-143933-Al 2mm-part2",
        "170913-151508-Al 2mm-part1",
    )
]

# Build (input, next-frame target) pairs: 30 random 20-frame windows per folder.
# target_size=(100, 140) is in cv2 (width, height) order, giving 140x100 frames.
noisy_movies, shifted_movies = generate_movies_from_multiple_folders(dataset_folders, n_samples_per_folder=30, n_frames=20, target_size=(100, 140))

# If every folder loads, both arrays are (len(dataset_folders) * 30, 20, 140, 100, 1),
# i.e. (300, 20, 140, 100, 1) for the 10 folders listed (batch_size defaults to 1).
print("Noisy movies shape:", noisy_movies.shape)
print("Shifted movies shape:", shifted_movies.shape)

print(f"Shape of noisy_movies: {noisy_movies.shape}")
print(f"Shape of shifted_movies: {shifted_movies.shape}")

# The loader already appended the trailing channel axis, so the arrays are
# (samples, time_steps, height, width, 1). The "_resized" names are plain aliases
# of the same arrays; no copy is made here.
noisy_movies_resized = noisy_movies
shifted_movies_resized = shifted_movies

print(f"Final Noisy Movies Shape: {noisy_movies_resized.shape}")
print(f"Final Shifted Movies Shape: {shifted_movies_resized.shape}")

# Keep the first five raw (pre-normalization) frames of the first sample and of its
# next-frame target for visual inspection.
num_frames = 5
sample_noisy = noisy_movies[0, :num_frames]
sample_shifted = shifted_movies[0, :num_frames]
######################################################################################
# Convert to float32 for training.
noisy_movies_resized = noisy_movies_resized.astype('float32')
shifted_movies_resized = shifted_movies_resized.astype('float32')

# Validate, rather than reshape, against the model's per-sample shape
# (time_steps, height, width, channels) = (20, 140, 100, 1) from keras.Input.
# Reshaping to that shape would silently scramble pixels if the loader produced,
# e.g., (..., 100, 140, 1) frames; failing loudly makes such a mix-up visible.
_expected_sample_shape = (20, 140, 100, 1)
if (noisy_movies_resized.shape[1:] != _expected_sample_shape
        or shifted_movies_resized.shape[1:] != _expected_sample_shape):
    raise ValueError(
        "Expected arrays of shape (num_samples, 20, 140, 100, 1); got "
        f"{noisy_movies_resized.shape} and {shifted_movies_resized.shape}"
    )
print(noisy_movies_resized.shape)  # (num_samples, 20, 140, 100, 1)
print(shifted_movies_resized.shape)  # (num_samples, 20, 140, 100, 1)

############# Normalization #####################
# Map 0-255 pixel values to [0, 1]. This is the range of the model's sigmoid output and
# the max_val=1.0 assumed by the SSIM term of the loss.
noisy_movies_resized = noisy_movies_resized / 255.0
shifted_movies_resized = shifted_movies_resized / 255.0

#######################################################################
# Show the first five normalized input frames of the first sample (top row) above
# their next-frame targets (bottom row), to check the prepared training data.
num_frames = 5
sample_noisy2 = noisy_movies_resized[0, :num_frames]  # Normalized input frames
sample_shifted2 = shifted_movies_resized[0, :num_frames]  # Normalized next-frame targets

fig, axes = plt.subplots(2, num_frames, figsize=(15, 6))
for i in range(num_frames):
    axes[0, i].imshow(sample_noisy2[i].squeeze(), cmap="gray")
    axes[0, i].set_title(f"Noisy Frame {i+1}")
    axes[0, i].axis("off")

    axes[1, i].imshow(sample_shifted2[i].squeeze(), cmap="gray")
    axes[1, i].set_title(f"Shifted Frame {i+1}")
    axes[1, i].axis("off")

plt.tight_layout()
plt.show()
###################################################

# Debug hook: print which metrics are reported at the end of each epoch, to confirm
# that val_loss is actually being computed.
def on_epoch_end(epoch, logs=None):
    """Print the metric names available in ``logs`` for ``epoch``.

    Args:
        epoch: Epoch index, as passed by the LambdaCallback that wraps this function.
        logs: Metrics dict for the epoch, or None (the signature's default). None is
            treated as an empty dict instead of raising AttributeError.
    """
    logs = logs if logs is not None else {}
    print(f"Epoch {epoch} - Available logs: {logs.keys()}")

# Callbacks used by seq.fit:
#   - early stopping on val_loss with patience 15, restoring the best weights;
#   - learning-rate reduction by factor 0.8 on a val_loss plateau (patience 13);
#   - a per-epoch debug print of the available log keys.
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=15,
        restore_best_weights=True,
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.8,
        patience=13,
    ),
    keras.callbacks.LambdaCallback(on_epoch_end=on_epoch_end),
]

# Train on the prepared sequences, one sequence per gradient step, for at most 30 epochs.
# NOTE(review): Keras takes the validation_split fraction from the *end* of the arrays,
# before shuffling. Samples are appended folder by folder, so validation likely consists
# of the last folders' sequences — confirm this hold-out is intended.
seq.fit(
    x=noisy_movies_resized,
    y=shifted_movies_resized,
    batch_size=1,
    epochs=30,
    validation_split=0.2,
    callbacks=callbacks,
)

# Persist the trained model in the native .keras format.
seq.save("conv_lstm_model-38.keras")