In [None]:
import os
import time
import re
import numpy as np
import matplotlib.pyplot as plt
import rasterio
import mxnet as mx
from mxnet import gluon, nd, autograd
from mxnet.gluon import nn
from mxnet.gluon.data import DataLoader, ArrayDataset
from mxnet.lr_scheduler import FactorScheduler
from mxnet.gluon.data.vision import transforms
from sklearn.model_selection import train_test_split

In [None]:
# Paths
# NOTE(review): these are absolute, machine-specific Windows paths; the
# notebook only runs unchanged on a machine with this directory layout.
image_dir = r"D:\Source\Test\TextMxnet\data\2022\BB\08X_Features_Multi"
mask_dir = r"D:\Source\Test\TextMxnet\data\2022\BB\XX_Reference_Masks_ResUNetA"
loss_path = r"D:\Source\Test\TextMxnet\data\2022\BB\Output\Loss\loss_plot5.png"
result_path = r"D:\Source\Test\TextMxnet\data\2022\BB\Output\Result"
trained_model_path = r"D:\Source\Test\TextMxnet\data\2022\BB\Output\Model_params\resunet_model.params"

# Training hyperparameters (read as globals by the training and plotting cells)
epochs = 50
batch_size = 8
learning_rate = 0.000001

# Select the compute device: first GPU when MXNet reports one, otherwise CPU.
try:
    ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()
except Exception as e:
    # Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate,
    # and report why the GPU probe failed before falling back to the CPU.
    print(f"GPU detection failed ({e}); falling back to CPU.")
    ctx = mx.cpu()

In [None]:
# ResUNetA model definition
class ResUNetA(nn.HybridBlock):
    """Small fully-convolutional network that emits a one-channel score map.

    The network is two identical stages ("encoder" then "decoder"), each made
    of a 3x3 convolution with 64 filters and padding 1, batch normalisation,
    ReLU and 50% dropout, followed by a 1x1 convolution down to one channel.
    No sigmoid is applied inside the block; callers receive the raw output.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        with self.name_scope():
            # Layers are instantiated in the order conv -> bn -> relu -> dropout
            # for each stage, encoder first.
            self.encoder = nn.HybridSequential()
            self.encoder.add(
                nn.Conv2D(64, kernel_size=3, padding=1),
                nn.BatchNorm(),
                nn.Activation('relu'),
                nn.Dropout(0.5),  # dropout added to stabilize training
            )

            self.decoder = nn.HybridSequential()
            self.decoder.add(
                nn.Conv2D(64, kernel_size=3, padding=1),
                nn.BatchNorm(),
                nn.Activation('relu'),
                nn.Dropout(0.5),  # dropout added to stabilize training
            )

            # 1x1 convolution producing the single binary-segmentation channel
            self.output = nn.Conv2D(1, kernel_size=1)

    def hybrid_forward(self, F, x):
        """Run encoder, decoder and the 1x1 output convolution in sequence."""
        return self.output(self.decoder(self.encoder(x)))

In [None]:
# Pull the numeric identifier out of a filename
def extract_number(filename):
    """Return the first run of digits in ``filename`` as an int.

    Args:
        filename (str): Name to search.

    Returns:
        int | None: The first digit sequence converted to int, or None when
        the name contains no digits.
    """
    digits = re.search(r'\d+', filename)
    if digits is None:
        return None
    return int(digits.group())

In [None]:
# Load TIFF files
def load_tif_files_with_id(directory, num_files=None):
    """Read the ``.tif`` files of a directory, grouped by numeric file id.

    The id of a file is the first run of digits in its name (see
    ``extract_number``). Files sharing an id are appended to the same list in
    the order they are visited.

    Args:
        directory (str): Directory to scan (not recursive).
        num_files (int | None): Stop after this many ``.tif`` files have been
            visited. The limit counts files, not ids, and also counts ``.tif``
            files whose name holds no digits. ``None`` or 0 means no limit.

    Returns:
        dict[int, list[numpy.ndarray]]: id -> arrays returned by
        ``rasterio``'s ``read()`` for each file with that id.
        On any error the message is printed and whatever was loaded so far is
        returned (best effort), so the result may be partial or empty.
    """
    file_dict = {}
    count = 0
    try:
        # os.listdir() yields entries in arbitrary order. Sorting makes both the
        # subset selected by num_files and the per-id array order reproducible;
        # downstream code indexes the per-id list by position.
        for filename in sorted(os.listdir(directory)):
            if not filename.endswith(".tif"):
                continue
            file_id = extract_number(filename)
            if file_id is not None:
                with rasterio.open(os.path.join(directory, filename)) as src:
                    file_dict.setdefault(file_id, []).append(src.read())
            # Limit the number of files loaded, if specified
            count += 1
            if num_files and count >= num_files:
                break
    except Exception as e:
        print(f"Error loading TIFF files: {e}")
    return file_dict

In [None]:
# Load images and masks from directories with numeric identifiers as keys
# Each dict maps a file id to the list of arrays read from the files that
# share that id.
# NOTE: num_files caps the number of .tif FILES visited per directory, not the
# number of ids, so a directory holding several files per id yields fewer ids
# than one holding a single file per id.
images_dict = load_tif_files_with_id(image_dir, num_files=10)
masks_dict = load_tif_files_with_id(mask_dir, num_files=10)
#print(len(images_dict))

In [None]:
# Function to calculate NDVI given NIR and Red bands
def calculate_ndvi(image, nir_index, red_index):
    """Compute NDVI = (NIR - Red) / (NIR + Red) for one band-first image.

    Args:
        image: Array-like indexed by band; ``image[k]`` must yield band ``k``.
        nir_index (int): Index of the near-infrared band.
        red_index (int): Index of the red band.

    Returns:
        numpy.ndarray: Floating-point NDVI with the shape of a single band.
        Pixels where NIR + Red is 0 evaluate to 0 thanks to the small epsilon
        in the denominator.

    Raises:
        IndexError: If a band index is out of range for ``image``.
    """
    nir = np.asarray(image[nir_index])
    red = np.asarray(image[red_index])

    # Promote to floating point BEFORE the arithmetic: with integer rasters
    # (unsigned ones in particular) ``nir - red`` wraps around instead of going
    # negative and ``nir + red`` can overflow. Float inputs keep their dtype.
    float_dtype = np.result_type(nir.dtype, red.dtype, np.float32)
    nir = nir.astype(float_dtype, copy=False)
    red = red.astype(float_dtype, copy=False)

    # Small constant avoids division by zero
    return (nir - red) / (nir + red + 1e-5)

In [None]:
# Assuming NIR is the 4th band and Red is the 1st band
nir_band_index = 3
red_band_index = 0

for key, values in images_dict.items():
    # Two images per id are expected (plotted below as "NDV" and "VNIR")
    if len(values) < 2:
        print(f"Warning: Insufficient images for key {key}")
        continue

    # Images and masks are loaded independently, so an id may have no mask;
    # skip it instead of failing the whole cell with a KeyError.
    if not masks_dict.get(key):
        print(f"Warning: No mask found for key {key}")
        continue

    # Calculate NDVI for each image
    # NOTE(review): the third panel applies the NDVI formula to the mask's
    # bands - confirm this is the intended way to visualise the mask.
    try:
        panels = [
            ("NDV", calculate_ndvi(values[0], nir_band_index, red_band_index)),
            ("VNIR", calculate_ndvi(values[1], nir_band_index, red_band_index)),
            ("Mask", calculate_ndvi(masks_dict[key][0], nir_band_index, red_band_index)),
        ]
    except IndexError:
        print(f"Error: Invalid band indices for key {key}")
        continue

    # Plot the NDVI results with color bars, one panel per array
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (panel_name, panel_ndvi) in zip(axes, panels):
        panel_img = ax.imshow(panel_ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
        fig.colorbar(panel_img, ax=ax, label="NDVI")
        ax.set_title(f"NDVI - {panel_name} Image (fid: {key})")
        ax.axis('off')

    fig.tight_layout()
    plt.show()
    # Release the figure so looping over many ids does not accumulate memory
    plt.close(fig)


In [None]:
def preprocess_data_dict(images_dict, masks_dict):
    """Normalise images and binarise masks, keeping only ids that have both.

    For every id in ``images_dict``:
      * the list of image arrays is stacked with ``np.array`` and divided by
        255.0;
      * the list of mask arrays is stacked, NaN/inf are replaced by 0, the
        leading axis is summed when it is longer than 1, and the result is
        thresholded to 0/1 (``> 0``).

    Ids without a mask, or whose arrays cannot be processed, are reported and
    skipped individually so one bad id does not drop all the ids after it.

    NOTE(review): dividing by 255 assumes 8-bit imagery - confirm for this
    data. The sum in the mask branch runs over axis 0 of the STACKED array,
    i.e. over the files of an id, not over the bands of one mask; with a single
    mask file per id the bands are kept as-is. Verify this matches the intent.

    Args:
        images_dict (dict): id -> list of image arrays.
        masks_dict (dict): id -> list of mask arrays.

    Returns:
        tuple[dict, dict]: ``(images_preprocessed, masks_preprocessed)``, both
        keyed by id and holding float32 arrays; the two dicts always share the
        same keys in the same order.
    """
    images_preprocessed = {}
    masks_preprocessed = {}

    for key, image_list in images_dict.items():
        mask_list = masks_dict.get(key)
        if mask_list is None:
            print(f"Warning: no mask for key {key}; skipping")
            continue

        try:
            # Normalize image to range [0, 1]
            img = (np.array(image_list) / 255.0).astype('float32')

            # Replace NaN, inf, -inf with 0 before thresholding
            msk = np.nan_to_num(np.array(mask_list))

            # Collapse the leading axis by summing when it has several entries
            if msk.ndim > 2 and msk.shape[0] > 1:
                msk = np.sum(msk, axis=0, keepdims=True)

            # Convert to binary (0 or 1)
            msk = np.where(msk > 0, 1, 0).astype('float32')
        except Exception as e:
            print(f"Error during preprocessing: {e}")
            continue

        # Store both together so the output dicts stay key-aligned
        images_preprocessed[key] = img
        masks_preprocessed[key] = msk

    return images_preprocessed, masks_preprocessed

In [None]:
# Preprocess the loaded images and masks
# Produces two dicts keyed by file id: scaled images and binarised masks.
images_preprocessed, masks_preprocessed = preprocess_data_dict(images_dict, masks_dict)

In [None]:
def train_test_split_dict(images_dict, masks_dict, test_size=0.2, random_state=None):
    """
    Partition image and mask dictionaries into train/validation subsets by key.

    The keys of ``images_dict`` are split once with ``train_test_split``; the
    same key partition is then applied to both dictionaries, and the size of
    every selected entry is printed.

    Args:
        images_dict (dict): Images keyed by identifier.
        masks_dict (dict): Masks keyed by the same identifiers.
        test_size (float): Share of the keys assigned to the validation split.
        random_state (int): Seed forwarded to ``train_test_split``.

    Returns:
        dict: Training images.
        dict: Validation images.
        dict: Training masks.
        dict: Validation masks.
    """
    train_ids, val_ids = train_test_split(
        list(images_dict.keys()), test_size=test_size, random_state=random_state
    )

    def pick(source, ids):
        # Sub-dictionary of `source` restricted to `ids`, in the order of `ids`
        return {fid: source[fid] for fid in ids}

    images_train = pick(images_dict, train_ids)
    images_val = pick(images_dict, val_ids)
    masks_train = pick(masks_dict, train_ids)
    masks_val = pick(masks_dict, val_ids)

    # Report the length of every entry, split by split
    for fid, entry in images_train.items():
        print(f"image train fid:{fid}. Size: {len(entry)}")
    for fid, entry in images_val.items():
        print(f"val fid:{fid}. Size: {len(entry)}")
    for fid, entry in masks_train.items():
        print(f"mask train fid:{fid}. Size: {len(entry)}")
    for fid, entry in masks_val.items():
        print(f"mask val fid:{fid}. Size: {len(entry)}")

    return images_train, images_val, masks_train, masks_val

In [None]:
# Split images and masks into training and validation sets
# 80/20 split by file id; random_state=42 is passed through to make the split repeatable.
X_train, X_val, Y_train, Y_val = train_test_split_dict(images_preprocessed, masks_preprocessed, test_size=0.2, random_state=42)

In [None]:
# Function to compute IoU
def iou_metric(pred, mask):
    """Intersection-over-Union between a thresholded prediction and a mask.

    Args:
        pred: Array-like of scores; values above 0.5 count as foreground.
        mask: Array-like reference; non-zero values count as foreground.

    Returns:
        float: ``|pred AND mask| / |pred OR mask|``. When both are entirely
        background the union is empty and 1.0 is returned (the two agree
        perfectly) instead of the nan a 0/0 division would give. On an
        unexpected error the message is printed and 0 is returned.
    """
    try:
        pred_fg = np.asarray(pred) > 0.5
        union = np.sum(np.logical_or(mask, pred_fg))
        if union == 0:
            # Nothing predicted and nothing to find: define IoU as 1
            return 1.0
        intersection = np.sum(np.logical_and(mask, pred_fg))
        iou_score = intersection / union
    except Exception as e:
        print(f"Error computing IoU: {e}")
        iou_score = 0
    return iou_score

In [None]:
# Function for validation
def validate(model, val_data, ctx):
    """Mean IoU of ``model`` over the batches of ``val_data``.

    Batches that arrive 5-D are brought to the 4-D layout the convolutional
    model consumes, using the same rule as the loader-inspection cell of this
    notebook: all axes between the batch axis and the last two axes are merged
    into the channel axis, and only the first label channel is kept.

    Args:
        model: Callable network returning raw scores (sigmoid is applied here).
        val_data: Iterable of ``(data, label)`` NDArray batches.
        ctx: MXNet context the batches are moved to.

    Returns:
        float: Mean of the per-batch IoU scores, or nan when no batch could be
        scored (empty loader, or an error on the first batch). Errors are
        printed and stop the evaluation; batches scored before the error still
        contribute.
    """
    iou_scores = []
    try:
        for data, label in val_data:
            if data.ndim == 5:
                data = data.reshape((data.shape[0], -1, data.shape[3], data.shape[4]))
            if label.ndim == 5:
                label = label.reshape((label.shape[0], -1, label.shape[3], label.shape[4]))
            if label.ndim == 4 and label.shape[1] > 1:
                label = label[:, 0:1, :, :]  # Keep only the first channel

            data = data.as_in_context(ctx)
            label = label.as_in_context(ctx)

            # iou_metric thresholds at 0.5 itself, so pass the probabilities
            prob = nd.sigmoid(model(data))
            iou_scores.append(iou_metric(prob.asnumpy(), label.asnumpy()))
    except Exception as e:
        print(f"Error during validation: {e}")

    if not iou_scores:
        # np.mean([]) would also give nan, but with a RuntimeWarning
        return float("nan")
    return np.mean(iou_scores)

In [None]:
def plot_loss(train_losses):
    """Plot the per-epoch training loss, save it to ``loss_path`` and show it.

    The global hyperparameters ``epochs``, ``batch_size`` and ``learning_rate``
    are written into a text box anchored at the last epoch / highest loss.

    Args:
        train_losses: Sequence of per-epoch loss values. When it is empty
            (e.g. no epoch completed) a message is printed and nothing is
            plotted, since there is no maximum to anchor the text box to.

    Returns:
        None
    """
    if len(train_losses) == 0:
        print("No training losses recorded; skipping loss plot.")
        return

    fig = plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss Over Time')
    plt.legend()
    plt.grid(True)
    # Adding text to the right corner
    info_text = f'Epochs: {epochs}\nBatch Size: {batch_size}\nLearning Rate: {learning_rate}'
    plt.text(len(train_losses) - 1, max(train_losses), info_text,
             ha='right', va='top', fontsize=10, bbox=dict(facecolor='white', alpha=0.5))

    # savefig does not create missing directories
    out_dir = os.path.dirname(loss_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(loss_path)
    plt.show()
    plt.close(fig)

In [None]:
def get_data_loader(xDict, yDict, batch_size, isTrue):
    """Build a Gluon DataLoader over the values of two dictionaries.

    Samples are paired by position: the i-th value of ``xDict`` goes with the
    i-th value of ``yDict``, so both dictionaries must iterate their keys in
    the same order.

    Args:
        xDict (dict): Input arrays; only the values are used.
        yDict (dict): Target arrays; only the values are used.
        batch_size (int): Number of samples per batch.
        isTrue (bool): Passed as ``shuffle`` - whether to shuffle the samples.

    Returns:
        DataLoader: Loader yielding ``(x_batch, y_batch)`` pairs.
    """
    features = list(xDict.values())
    targets = list(yDict.values())
    paired = ArrayDataset(features, targets)
    return DataLoader(paired, batch_size=batch_size, shuffle=isTrue)

In [221]:
# Build the training / validation loaders and print the shape of every
# training batch before and after collapsing it to 4-D.
train_data = get_data_loader(X_train, Y_train, batch_size, True)
val_data = get_data_loader(X_val, Y_val, batch_size, False)

for i, (data, label) in enumerate(train_data):
    print("Original Data shape:", data.shape)
    print("Original Label shape:", label.shape)
    
    # Reshape `data` to combine the extra dimension without altering the batch size
    # (axes 1 and 2 are merged into one channel axis; the last two axes are kept)
    if data.ndim == 5:
        data = data.reshape((data.shape[0], -1, data.shape[3], data.shape[4]))  # Combine channels
        print("Reshaped Data shape:", data.shape)
    
    # Reshape `label` to combine extra dimensions without altering the batch size
    if label.ndim == 5:
        label = label.reshape((label.shape[0], -1, label.shape[3], label.shape[4]))  # Combine channels if needed
        print("Reshaped Label shape before channel adjustment:", label.shape)
        
        # If `label` has more channels, select only the first channel
        # (the slice 0:1 keeps the channel axis, so the label stays 4-D)
        if label.shape[1] > 1:
            label = label[:, 0:1, :, :]  # Keep only the first channel
        print("Final Label shape:", label.shape)

    # Send data and label to the designated device (CPU or GPU)
    data = data.as_in_context(ctx)
    label = label.as_in_context(ctx)

    print("Data shape:", data.shape)
    print("Label shape:", label.shape)



Original Data shape: (4, 2, 4, 256, 256)
Original Label shape: (4, 1, 6, 256, 256)
Reshaped Data shape: (4, 8, 256, 256)
Reshaped Label shape before channel adjustment: (4, 6, 256, 256)
Final Label shape: (4, 1, 256, 256)
Data shape: (4, 8, 256, 256)
Label shape: (4, 1, 256, 256)


In [None]:
def get_model():
    """Create a ``ResUNetA`` network, initialise it on ``ctx`` and hybridize it.

    Returns:
        ResUNetA: The initialised, hybridized model.

    Raises:
        Exception: Whatever ``initialize`` / ``hybridize`` raised. The error is
            printed and re-raised rather than swallowed, because returning
            ``None`` here would only surface later as an unrelated
            ``AttributeError`` in the caller.
    """
    model = ResUNetA()
    try:
        model.initialize(ctx=ctx)
        model.hybridize()
    except Exception as e:
        print(f"Error initializing or hybridizing model: {e}")
        raise
    print('model is initiated')
    return model

In [None]:
def _to_4d_batch(data, label):
    """Bring a loader batch to the 4-D layout the convolutional model consumes.

    5-D arrays have every axis between the batch axis and the last two axes
    merged into a single channel axis, so the batch size is unchanged. Only the
    first label channel is kept, matching the model's one-channel output.
    """
    if data.ndim == 5:
        data = data.reshape((data.shape[0], -1, data.shape[3], data.shape[4]))
    if label.ndim == 5:
        label = label.reshape((label.shape[0], -1, label.shape[3], label.shape[4]))
    if label.ndim == 4 and label.shape[1] > 1:
        label = label[:, 0:1, :, :]
    return data, label


# Training function with loss plotting and hyperparameter control
def train_model():
    """Train a fresh model on ``X_train``/``Y_train`` and validate each epoch.

    Uses the notebook globals ``epochs``, ``batch_size``, ``learning_rate``,
    ``ctx`` and ``trained_model_path``. After every epoch that completes, the
    mean batch loss and the validation IoU are recorded and the parameters are
    written to ``trained_model_path``. An epoch that raises is reported and
    skipped; training continues with the next one.

    Returns:
        tuple: ``(model, train_losses, val_ious)`` - the model and one entry
        per completed epoch in each list. The lists are empty when the data
        loaders could not be created or no epoch completed.
    """
    model = get_model()

    # DataLoaders for training and validation
    try:
        train_data = get_data_loader(X_train, Y_train, batch_size, True)
        val_data = get_data_loader(X_val, Y_val, batch_size, False)
    except Exception as e:
        print(f"Error creating data loaders: {e}")
        # Keep the documented 3-tuple shape so callers can still unpack
        return model, [], []

    # The loss is fed raw scores (from_sigmoid defaults to False) so the
    # sigmoid and the log are evaluated together in a numerically stable form.
    loss_fn = gluon.loss.SigmoidBinaryCrossEntropyLoss()

    # Reduce learning rate by a factor of 0.5 every 10 epochs. The scheduler
    # counts optimizer updates, so the step is expressed in batches, and it
    # only takes effect when handed to the optimizer.
    updates_per_epoch = max(1, len(train_data))
    lr_scheduler = FactorScheduler(step=10 * updates_per_epoch, factor=0.5)
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': learning_rate,
                             'clip_gradient': 0.1,
                             'lr_scheduler': lr_scheduler})

    # save_parameters does not create missing directories
    model_dir = os.path.dirname(trained_model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    train_losses = []
    val_ious = []

    # Training loop
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        try:
            for data, label in train_data:
                data, label = _to_4d_batch(data, label)

                # Send data and label to the designated device (CPU or GPU)
                data = data.as_in_context(ctx)
                label = label.as_in_context(ctx)

                with autograd.record():
                    output = model(data)            # raw scores, no sigmoid
                    loss = loss_fn(output, label)
                loss.backward()

                # Normalise by the real batch size; the last batch may be short
                trainer.step(data.shape[0])

                epoch_loss += nd.mean(loss).asscalar()
                num_batches += 1

            if num_batches == 0:
                raise ValueError("training data loader produced no batches")

            avg_loss = epoch_loss / num_batches
            val_iou = validate(model, val_data, ctx)

            train_losses.append(avg_loss)
            val_ious.append(val_iou)

            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Validation IoU: {val_iou:.4f}")

            # Save the model parameters
            model.save_parameters(trained_model_path)
            print(f"Model parameters saved to {trained_model_path}")
        except Exception as e:
            print(f"Error during training epoch {epoch + 1}: {e}")

    # Plot training loss (nothing to plot when every epoch failed)
    if train_losses:
        plot_loss(train_losses)
    else:
        print("No epoch completed; skipping loss plot.")

    # Return model and metrics
    return model, train_losses, val_ious

In [None]:
# Run training with hyperparameters
# The three result names are pre-set to None so they exist in the notebook
# namespace even when training raises and the assignment below never happens.
model= None
train_losses= None
val_ious = None
try:
    start_time = time.time()

   # Print the start time with a formatted string
    print(f"Start time for training model: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}")

    model, train_losses, val_ious = train_model()
    end_time = time.time()
    print(f"End time for train model: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))}")
    execution_time = end_time - start_time  # Calculate the duration
    print(f"Execution time for train model: {execution_time} seconds")
except Exception as e:
    # Any failure inside train_model (or while unpacking its result) is only reported here
    print(f"Error during model training: {e}")

In [None]:
def visualize_predictions():
    """Plot input, reference mask and prediction for every validation sample.

    Loads the parameters stored at ``trained_model_path`` into a new
    ``ResUNetA`` on the CPU, runs it over ``X_val``/``Y_val`` in batches of
    ``batch_size`` and, for each sample, saves a three-panel figure to
    ``result_path`` as ``final<n>.png`` (``n`` counts samples from 0) and
    shows it.

    Batches that arrive 5-D are collapsed to 4-D by merging every axis between
    the batch axis and the last two axes into the channel axis; only the first
    label channel is displayed.

    Errors are printed and the affected batch or sample is skipped.

    Returns:
        None
    """
    try:
        fids = [fid for fid in X_val if fid in Y_val]
        if not fids:
            print("No validation samples to visualize.")
            return

        dataset = ArrayDataset(fids,
                               [X_val[fid] for fid in fids],
                               [Y_val[fid] for fid in fids])
        data_loader = DataLoader(dataset, batch_size, shuffle=False)

        model = ResUNetA()
        model.load_parameters(trained_model_path, mx.cpu())
        model.hybridize()

        # savefig does not create missing directories
        os.makedirs(result_path, exist_ok=True)
    except Exception as e:
        print("Error during model setup or loading parameters:", str(e))
        return

    figure_index = 0
    for i, (fid_batch, data, label) in enumerate(data_loader):
        try:
            data = data.as_in_context(mx.cpu())
            label = label.as_in_context(mx.cpu())

            if data.ndim == 5:
                data = data.reshape((data.shape[0], -1, data.shape[3], data.shape[4]))
            if label.ndim == 5:
                label = label.reshape((label.shape[0], -1, label.shape[3], label.shape[4]))
            if label.ndim == 4 and label.shape[1] > 1:
                label = label[:, 0:1, :, :]  # Keep only the first channel

            # Plain forward pass: running it under autograd.record() would put
            # the network in training mode and keep dropout active.
            prediction = nd.sigmoid(model(data)) > 0.5

            sample_ids = fid_batch.asnumpy().ravel().tolist()
            images_np = data.asnumpy()
            masks_np = label.asnumpy()
            predictions_np = prediction.asnumpy()
        except Exception as e:
            print(f"Error during prediction at iteration {i}:", str(e))
            continue

        # One figure per sample, so batches of any size are handled
        for sample_id, image_np, mask_np, prediction_np in zip(
                sample_ids, images_np, masks_np, predictions_np):
            sample_id = int(sample_id)
            try:
                # Show the first three channels as a colour image.
                # NOTE(review): assumes those channels are a sensible RGB
                # composite - confirm the band order of the input rasters.
                if image_np.ndim == 3 and image_np.shape[0] >= 3:
                    display_img = image_np[:3].transpose((1, 2, 0))
                elif image_np.ndim == 3:
                    display_img = image_np[0]
                else:
                    display_img = image_np
                if display_img.max() > 1:
                    display_img = display_img / 255.0  # Normalize if values are in [0, 255]
                display_img = np.clip(display_img, 0, 1)

                mask_2d = mask_np[0] if mask_np.ndim == 3 else mask_np
                if mask_2d.max() > 1:
                    mask_2d = mask_2d / 255.0  # Normalize if values are in [0, 255]

                # The prediction is already 0/1, so it needs no rescaling
                prediction_2d = prediction_np[0] if prediction_np.ndim == 3 else prediction_np

                fig, axes = plt.subplots(1, 3, figsize=(10, 5))

                # Display the input image
                axes[0].imshow(display_img, cmap='gray' if display_img.ndim == 2 else None)
                axes[0].set_title(f"Original: {sample_id}")
                axes[0].axis('off')

                # Display the corresponding mask
                axes[1].imshow(mask_2d, cmap='viridis', vmin=0, vmax=1)
                axes[1].set_title(f"Mask: {sample_id}")
                axes[1].axis('off')

                # Display the prediction
                axes[2].imshow(prediction_2d, cmap='viridis', vmin=0, vmax=1)
                axes[2].set_title("Prediction")
                axes[2].axis('off')

                fig.tight_layout()
                fig.savefig(os.path.join(result_path, f"final{figure_index}.png"))
                plt.show()
                plt.close(fig)
            except Exception as e:
                print(f"Error during visualization or saving figure at iteration {i}:", str(e))
            figure_index += 1


In [None]:
# Run the prediction visualisation over the validation set and report how long it took.
try:
    start_time1 = time.time()
    print(f"Start time for visualize_predictions: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time1))}")
    # Create Dataloader for predictions using validation dataset
    visualize_predictions()
    end_time1 = time.time()
    print(f"End time for visualize_predictions: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time1))}")
    execution_time1 = end_time1 - start_time1  # Calculate the duration
    print(f"Execution time for visualize_predictions: {execution_time1}")
except Exception as e:
    # Name the step that actually failed (this cell does not train anything)
    print(f"Error during visualize_predictions: {e}")