In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import time
from IPython import display
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import models, callbacks, applications
from tensorflow.keras import backend as K
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, roc_curve, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from scipy.linalg import sqrtm
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import plot_model
import seaborn as sns
import pandas as pd
import rasterio
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau
import re

In [None]:
# Enable on-demand GPU memory allocation instead of letting TensorFlow grab
# all GPU memory up front.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth needs to be the same across GPUs, so configure every
        # physical GPU first.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Listing logical devices initializes the runtime, after which memory
        # growth can no longer be changed. It therefore has to run once, AFTER
        # the loop, not inside it (otherwise the second GPU raises RuntimeError).
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

In [None]:
# Shape of a single model input sample; same value as the `input_shape`
# passed to the first Conv3D layer of the model defined further down.
# NOTE(review): presumably (time steps, height, width, channels), going by the
# shape comments in load_temporal_data -- confirm.
IMG_SHAPE = (3,8,8,3)

In [None]:
def load_temporal_data(base_dir, patches, days):
    """
    Loads NPZ files from multiple temporal patches and combines them into a single NumPy array.

    Files are expected at ``base_dir/<patch>/<day>/*.npz`` with a name that
    contains ``_<number>_segment_<segment id>_``. A segment is kept only when a
    file for it exists in every folder listed in ``days``; its per-day images
    are stacked in the order given by ``days``.

    File names are processed in sorted order, so the order of the returned
    samples is reproducible across runs and file systems. The sort is
    lexicographic on the file name (segment "10" comes before segment "2").

    Parameters:
        base_dir (str): Base directory containing the temporal data.
        patches (list): List of patch names (e.g., ['F440', 'F450']).
        days (list): List of day folder names corresponding to different timestamps.

    Returns:
        np.ndarray or None: Array of shape (num_segments, len(days), ...image shape)
        containing every complete segment of every patch, or None (after
        printing a message) when no complete segment was found.
    """
    all_segments = []

    for patch in patches:
        # segment id -> {day index -> path of that day's NPZ file}
        segment_groups = {}

        # Index the files of every day folder by segment id.
        for day_idx, day in enumerate(days):
            patch_dir = os.path.join(base_dir, patch, day)

            if not os.path.exists(patch_dir):
                print(f"Warning: Directory {patch_dir} does not exist. Skipping...")
                continue

            # os.listdir() returns entries in arbitrary order; sorting makes the
            # sample order (and every downstream seeded split) deterministic.
            for filename in sorted(os.listdir(patch_dir)):
                if not filename.endswith(".npz"):
                    continue
                match = re.search(r'_(\d+)_segment_(\d+)_', filename)  # Extract segment ID
                if match is None:
                    continue
                segment_id = match.group(2)
                segment_groups.setdefault(segment_id, {})[day_idx] = os.path.join(patch_dir, filename)

        # Keep only segments that have a file for every requested day.
        for segment_id, files in segment_groups.items():
            if len(files) != len(days):
                continue

            frames = []
            for day_idx in range(len(days)):
                # NpzFile keeps the archive open until closed; the context
                # manager releases the handle once the array has been read.
                with np.load(files[day_idx]) as data:
                    frames.append(data['image'])  # Assuming the key is 'image'

            # Stack along the first axis to maintain temporal ordering
            all_segments.append(np.array(frames))  # Shape: (len(days), H, W, C)

    # Convert the collected data into a NumPy array
    if all_segments:
        return np.array(all_segments)  # Shape: (num_segments, len(days), H, W, C)

    print("No valid temporal data found.")
    return None

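# Example usage (left commented out; adjust the paths and folder names to your data).
# NOTE: the cells below read `winter_barley_array` and `winter_wheat_array`,
# so the corresponding load_temporal_data(...) calls must be uncommented and
# run before continuing.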
# base_directory = "New folder/"
# patches = ["c11_c22_c33"]
# days = ["_0208", "_0912", "_1609"]  # Replace with actual day folder names
# winter_barley_array = load_temporal_data(base_directory, patches, days)


# base_directory = "dataset/segmented_npz/processed_segmented_npz/_8x8/temporal/c11_c22_c33/winter_barley/"
# patches = ["F440", "F450"]
# days = ["_0209", "_0912", "_1609"]  # Replace with actual day folder names

# winter_barley_array = load_temporal_data(base_directory, patches, days)

# # Example usage
# base_directory = "dataset/segmented_npz/processed_segmented_npz/_8x8/temporal/c11_c22_c33/winter_wheat/"
# patches = ["F230", "F250"]
# days = ["_0209", "_0912", "_1609"]  # Replace with actual day folder names

# winter_wheat_array = load_temporal_data(base_directory, patches, days)

In [None]:
# Integer class labels, one per loaded sample: 0 = winter barley, 1 = winter wheat.
winter_barley_y = [0] * len(winter_barley_array)
winter_wheat_y = [1] * len(winter_wheat_array)

In [None]:
# Hold out 20% of the winter-barley samples as that class's share of the test set.
(
    winter_barley_array1,
    winter_barley_test,
    winter_barley_y1,
    winter_barley_ytest,
) = train_test_split(
    winter_barley_array,
    winter_barley_y,
    test_size=0.2,
    random_state=7,
)

In [None]:
# Hold out 20% of the winter-wheat samples as that class's share of the test set.
(
    winter_wheat_array1,
    winter_wheat_test,
    winter_wheat_y1,
    winter_wheat_ytest,
) = train_test_split(
    winter_wheat_array,
    winter_wheat_y,
    test_size=0.2,
    random_state=7,
)

In [None]:
# Test set: barley hold-out followed by wheat hold-out (joined along the sample axis).
xtest = np.concatenate((winter_barley_test, winter_wheat_test))
ytest = np.concatenate((winter_barley_ytest, winter_wheat_ytest))

In [None]:
# Carve 10% of the remaining winter-barley samples off as validation data.
(
    xtrain_barley,
    xval_barley,
    ytrain_barley,
    yval_barley,
) = train_test_split(
    winter_barley_array1,
    winter_barley_y1,
    test_size=0.1,
    random_state=7,
)

In [None]:
# Carve 10% of the remaining winter-wheat samples off as validation data.
(
    xtrain_wheat,
    xval_wheat,
    ytrain_wheat,
    yval_wheat,
) = train_test_split(
    winter_wheat_array1,
    winter_wheat_y1,
    test_size=0.1,
    random_state=7,
)

In [None]:
# Training set: barley samples followed by wheat samples (joined along the sample axis).
xtrain = np.concatenate((xtrain_barley, xtrain_wheat))
ytrain = np.concatenate((ytrain_barley, ytrain_wheat))

In [None]:
# Validation set: barley samples followed by wheat samples (joined along the sample axis).
xval = np.concatenate((xval_barley, xval_wheat))
yval = np.concatenate((yval_barley, yval_wheat))

In [None]:
# Class imbalance of the training labels: (# class-1 samples) - (# class-0 samples).
train_class_counts = np.bincount(ytrain)
train_class_counts[1] - train_class_counts[0]

In [None]:
# Augmentation settings: small random rotations and shifts only. No
# featurewise/samplewise statistics are enabled, so `datagen.fit` is not needed.
datagen = ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    rotation_range=25,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=False,
    vertical_flip=False
)

# Number of synthetic samples to create. Later cells build one class-0 label
# per augmented sample.
# NOTE(review): 1159 is presumably the train-set class imbalance displayed by
# the previous cell -- confirm before changing the dataset.
NUM_AUGMENTED = 1159
AUGMENT_SEED = 7

# The augmented samples are labelled as class 0 downstream, so they must be
# generated from class-0 training samples only. (Previously frames were drawn
# at random from BOTH classes and unrelated frames were glued together into
# one "temporal" sample.)
minority_samples = xtrain[np.asarray(ytrain) == 0]
if len(minority_samples) == 0:
    raise ValueError("xtrain contains no class-0 samples to augment.")

# ImageDataGenerator draws its random parameters from NumPy's global RNG;
# seeding it makes the augmentation reproducible.
np.random.seed(AUGMENT_SEED)
source_indices = np.random.randint(len(minority_samples), size=NUM_AUGMENTED)

# Store the augmented samples, each keeping its (time, H, W, C) layout.
image = []
for source_idx in source_indices:
    # flow() used to hand back float32 frames; keep that dtype.
    sequence = np.asarray(minority_samples[source_idx], dtype="float32")
    # Draw ONE transform per sample and apply it to every time step so the
    # frames of a sequence stay spatially aligned and in temporal order.
    transform = datagen.get_random_transform(sequence.shape[1:])
    image.append([datagen.apply_transform(frame, transform) for frame in sequence])

image = np.array(image)  # Final shape: (NUM_AUGMENTED, 3, 8, 8, 3) for this dataset

print("Augmented Image Shape:", image.shape)  # Expected: (NUM_AUGMENTED, 3, 8, 8, 3)

In [None]:
# Training inputs: the real samples first, followed by the augmented ones.
xtrain_augmented = np.concatenate([xtrain, image])

In [None]:
# One class-0 label per augmented sample. The count is taken from `image`
# itself rather than repeated as a literal, so labels and augmented data
# cannot drift out of sync.
ytrain1 = [0] * len(image)
ytrain_augmented = np.concatenate((ytrain, ytrain1), axis=0)

In [None]:
# Pool everything along the sample axis, in the order: train (+augmented), validation, test.
x = np.concatenate([xtrain_augmented, xval, xtest])
y = np.concatenate([ytrain_augmented, yval, ytest])

In [None]:
def z_score_normalization(data, axis=None):
    """Standardize every sample to zero mean / unit variance, per channel.

    The input is treated as a channels-last batch whose first axis indexes
    samples, e.g. (N, H, W, C) images or (N, T, H, W, C) temporal stacks.

    Parameters:
        data (array-like): Batch to normalize.
        axis (int | tuple of int | None): Axes the statistics are computed
            over. When None (default):
              * rank <= 4 input uses (1, 2), exactly as before;
              * higher-rank input uses every axis between the sample axis and
                the channel axis, e.g. (1, 2, 3) for (N, T, H, W, C). The old
                fixed (1, 2) reduced over time and height there, which left
                width un-normalized and mixed unrelated axes.

    Returns:
        np.ndarray: Array with the same shape as ``data``.
    """
    data = np.asarray(data)
    if axis is None:
        axis = (1, 2) if data.ndim <= 4 else tuple(range(1, data.ndim - 1))
    mean = np.mean(data, axis=axis, keepdims=True)  # per sample, per channel
    std = np.std(data, axis=axis, keepdims=True)
    return (data - mean) / (std + 1e-7)  # small epsilon avoids division by zero

In [None]:
# Z-score normalize the pooled dataset.
x_normalized = z_score_normalization(data=x)

In [None]:
# Recover the ORIGINAL train / validation / test partitions from the pooled,
# normalized array instead of randomly re-splitting it.
#
# `x` was built as [xtrain_augmented, xval, xtest] in that order, so slicing at
# the partition boundaries gives back each split. A random re-split here would
# scatter augmented copies of training samples into the validation and test
# sets (data leakage) and would overwrite `ytrain_augmented` / `ytest`, making
# the earlier cells non-idempotent.
n_train = len(xtrain_augmented)
n_val = len(xval)
xtrain_normalized, xval_normalized, xtest_normalized = np.split(
    x_normalized, [n_train, n_train + n_val]
)

# Labels keep the values assigned by the earlier cells; just verify alignment.
assert len(xtrain_normalized) == len(ytrain_augmented)
assert len(xval_normalized) == len(yval)
assert len(xtest_normalized) == len(ytest)

In [None]:
# 3D-CNN binary classifier over (3, 8, 8, 3) inputs.
cnn_model = models.Sequential([
    # --- 3D convolutional feature extractor ---
    layers.Conv3D(16, (3, 3, 3), activation='relu', input_shape=(3, 8, 8, 3), padding='same'),
    layers.BatchNormalization(),

    layers.Conv3D(32, (3, 3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),

    layers.Conv3D(64, (3, 3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),

    # Pool size 1 on the first axis: only the two spatial axes are halved.
    layers.MaxPooling3D(pool_size=(1, 2, 2)),

    layers.Conv3D(128, (3, 3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),

    # Second spatial-only pooling step.
    layers.MaxPooling3D(pool_size=(1, 2, 2)),

    # --- Classifier head ---
    layers.Flatten(),

    layers.Dense(128, activation='relu'),
    layers.BatchNormalization(),
    layers.Dropout(0.3),

    # Single sigmoid unit for binary classification.
    layers.Dense(1, activation='sigmoid'),
])

# Compile with the default Adam optimizer and binary cross-entropy.
cnn_model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Print the layer-by-layer summary.
cnn_model.summary()

In [None]:
# Training hyper-parameters.
batch_size = 10
print(batch_size)

# Early-stopping callback with a patience of 100 epochs.
early_stopping_callback = callbacks.EarlyStopping(patience=100)

# Re-compile the model, this time with an explicit Adam learning rate of 0.01.
cnn_model.compile(
    optimizer=Adam(learning_rate=0.01),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['accuracy'],
)

In [None]:
# Per-epoch test metrics. TestAfterEpoch appends one entry to each list at the
# end of every epoch.
precision_0 = []
precision_1 = []
recall_0 = []
recall_1 = []
accuracy = []
auc_score = []


class TestAfterEpoch(tf.keras.callbacks.Callback):
    """Keras callback that evaluates the model on a held-out set after every epoch.

    At the end of each epoch it:
      * runs ``model.evaluate`` and ``model.predict`` on the given data,
      * prints and plots the confusion matrix (predictions thresholded at 0.5),
      * prints the classification report and appends per-class precision/recall
        and accuracy to the module-level lists ``precision_0``, ``precision_1``,
        ``recall_0``, ``recall_1`` and ``accuracy``,
      * computes the ROC curve and AUC, appends the AUC to ``auc_score``, appends
        fpr/tpr/auc to ``3d_cnn_3_channel.txt`` and plots the ROC curve.

    Parameters:
        test_data: Inputs passed to ``model.evaluate`` / ``model.predict``.
        test_labels: Binary (0/1) ground-truth labels for ``test_data``.
    """

    def __init__(self, test_data, test_labels):
        # Let the Keras base class initialise its own attributes.
        super().__init__()
        self.test_data = test_data
        self.test_labels = test_labels

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the stored data and record/plot the metrics for this epoch."""
        self.model.evaluate(self.test_data, self.test_labels)

        # Flatten to 1-D so labels and scores have matching shapes for sklearn.
        predictions = np.asarray(self.model.predict(self.test_data)).ravel()
        y_true = np.asarray(self.test_labels).ravel()
        y_pred_classes = np.where(predictions > 0.5, 1, 0)

        # Pin the label set so the matrix is always 2x2 and matches the two
        # display labels, even when the model predicts a single class.
        cm = confusion_matrix(y_true, y_pred_classes, labels=[0, 1])
        print(cm)

        labels = ['Class 0', 'Class 1']  # Optional: class labels

        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
        disp.plot(cmap='Blues', values_format='d')  # Customize color and format

        # Customize plot appearance
        disp.ax_.set_title('Confusion Matrix', fontsize=16)
        disp.ax_.set_xlabel('Predicted Labels', fontsize=14)
        disp.ax_.set_ylabel('True Labels', fontsize=14)
        disp.ax_.tick_params(axis='both', labelsize=12)
        plt.show()
        # This runs once per epoch; close the figure so figures do not pile up.
        plt.close(disp.figure_)

        print("\n\n")

        # zero_division=0 yields the same values as the default but without a
        # warning on every epoch in which one class is never predicted.
        report = classification_report(
            y_true, y_pred_classes, digits=5, output_dict=True, zero_division=0
        )
        precision_0.append(report["0"]["precision"])
        precision_1.append(report["1"]["precision"])
        recall_0.append(report["0"]["recall"])
        recall_1.append(report["1"]["recall"])
        accuracy.append(report["accuracy"])

        print(classification_report(y_true, y_pred_classes, digits=5, zero_division=0))

        # ROC curve and AUC from the raw scores.
        fpr, tpr, thresholds = roc_curve(y_true, predictions)
        auc = roc_auc_score(y_true, predictions)
        auc_score.append(auc)

        with open("3d_cnn_3_channel.txt", "a") as file:
            # tolist() writes every value; formatting the NumPy array directly
            # abbreviates long arrays with "...".
            file.write(f"fpr: {fpr.tolist()}")
            file.write(f"\n tpr: {tpr.tolist()}")
            # Trailing newline so the next epoch's record starts on its own line.
            file.write(f"\n auc: {auc:.4f}\n")

        # Plot ROC curve
        fig, ax = plt.subplots()
        ax.plot(fpr, tpr, label=f'AUC = {auc:.4f}')
        ax.plot([0, 1], [0, 1], 'k--')  # Diagonal line for random classifier
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('ROC Curve')
        ax.legend(loc='lower right')
        plt.show()
        plt.close(fig)

In [None]:
# Per-epoch evaluation callback.
test_callback = TestAfterEpoch(xtest_normalized, ytest)

# Multiply the learning rate by 0.1 after 5 epochs without val_loss
# improvement, never going below 1e-3.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=1e-3)

# `batch_size` and `early_stopping_callback` are configured in an earlier cell
# but were never handed to fit(), so training silently ran with Keras' default
# batch size and without early stopping.
epochs_info = cnn_model.fit(
    xtrain_normalized,
    ytrain_augmented,
    epochs=500,
    batch_size=batch_size,
    validation_data=(xval_normalized, yval),
    callbacks=[test_callback, reduce_lr, early_stopping_callback],
)