Import Section

In [None]:
import os
os.environ["MXNET_CUDNN_LIB_CHECKING"] = "0"
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
import time
import random
import numpy as np
import mxnet as mx
from mxnet import  autograd, context
from mxnet.base import MXNetError
from mxnet.gluon.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

import myModel
from image_Dictionary import ImageDict
import const
from my_Save import saveAsCSV, SaveModels
from mx_Train import myTrain
from decode.postprocessing.instance_segmentation import InstSegm
from myPlots import lossPlot, visualize_all, plotPredictedImape
from iou import  get_iou
from image_Augmentor import GeoTiffDataset, SatelliteImageAugmentor

Set the image type: VNIR or NDV.

In [None]:
# Image type used to pick which files are loaded and to name the output folder:
# "VNIR" when isVnir is set, otherwise "NDV".
isVnir = True
imageType = "VNIR" if isVnir else "NDV"

Select the GPU context, define the output paths and create the output folders.

In [None]:
# Compute device and number of images to load for this run.
ctx = context.gpu()
mx.nd.waitall()
numberOfimages = 648

# Every output of this run lives under <result_2022_aug>/<imageType>/<numberOfimages>/.
input_directory = const.images_2022
output_directory = os.path.join(
    const.result_2022_aug,
    imageType,
    str(numberOfimages),
)
output_models = os.path.join(output_directory, "models")   # per-epoch checkpoints
result_path = os.path.join(output_directory, "result")     # prediction outputs
lossFile = os.path.join(output_directory, "loss.csv")      # per-epoch loss table

def makedir(path):
    """Ensure the directory ``path`` exists (creating parents as needed) and print it.

    Raises the same error as ``os.makedirs`` if ``path`` exists but is not a directory.
    """
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)
    print(path)
    
# Create the run folder plus its models/ and result/ subfolders.
for folder in (output_directory, output_models, result_path):
    makedir(folder)

# Image and Mask Loading with Preprocessing

The following code is used to load satellite images and their corresponding masks, with an option to preprocess them by selecting all available bands or just the standard three (e.g., RGB).

# Parameters

isAllband:

True: Loads all available spectral bands of the image.

False: Loads only the three standard bands (typically RGB).

numberOfimages:
Specifies how many images to load.

Set this to control the size of your dataset for training and validation.

In [None]:
# Load the 2022 images and their masks, then split the mask IDs 80/20 into train/validation.
# NOTE(review): the second ImageDict argument is False for images and True for masks —
# presumably an "is mask" flag; confirm in image_Dictionary.ImageDict.
images_2022= ImageDict(const.images_2022,  False)
image_dict_2022  = images_2022.load_tif_files(imageType, numberOfimages = numberOfimages, isAllband= True)
masks_2022 = ImageDict(const.masks_2022, True)
# The mask loader is given image_dict_2022, so images must be loaded first.
# NOTE(review): presumably used to match masks to the loaded images — confirm.
mask_dict_2022  = masks_2022.load_tif_files(imageType, image_dict_2022, numberOfimages = numberOfimages, isAllband= True)
# Fixed random_state keeps the train/validation split reproducible across runs.
train_ids, val_ids = train_test_split(list(mask_dict_2022.keys()), test_size=0.2, random_state=42)

In [None]:
# Load the 2010 images (no masks are loaded for 2010) and create their output folder.
testimages = 648
images_2010 = ImageDict(const.images_2010, False)
image_dict_2010 = images_2010.load_tif_files(
    imageType,
    numberOfimages=testimages,
    isAllband=True,
)
# NOTE(review): 2010 outputs are written under const.result_2022_3 — confirm this is intended.
output_directory_2010 = os.path.join(const.result_2022_3, imageType)
makedir(output_directory_2010)

# Loading Images and Masks
Extract the image arrays for a list of IDs from a dictionary and stack them into MXNet NDArrays for the training and validation splits.

In [None]:
def get_images(ids, image_dict):
    """Stack the ``.image`` arrays of the given IDs into one MXNet NDArray.

    Args:
        ids: Keys into ``image_dict``; their order is the order along the first axis.
        image_dict: Mapping from ID to an object exposing an ``image`` array.

    Returns:
        ``mx.nd.NDArray`` built from the stacked numpy arrays.
    """
    stacked = np.asarray([image_dict[key].image for key in ids])
    return mx.nd.array(stacked)

# Materialise images and masks for both splits as NDArrays.
train_images, train_masks = get_images(train_ids, image_dict_2022), get_images(train_ids, mask_dict_2022)
val_images, val_masks = get_images(val_ids, image_dict_2022), get_images(val_ids, mask_dict_2022)

Create batched data loaders for training and validation with MXNet's `DataLoader` for mini-batch processing.

Use the cell below if you are training without hyperparameter tuning, Mixup, or CutMix.

In [None]:
# Non-augmented loaders built directly from the in-memory NDArrays.
batch_size = 4
try:
    train_dataset = mx.gluon.data.ArrayDataset(train_images, train_masks)
    train_loader = mx.gluon.data.DataLoader(
        train_dataset, batch_size=batch_size, num_workers=0, shuffle=True
    )
    val_loader = mx.gluon.data.DataLoader(
        mx.gluon.data.ArrayDataset(val_images, val_masks),
        batch_size=batch_size, num_workers=0, shuffle=False,
    )
except Exception as e:
    print(f"Error creating data loaders: {e}")
    # Re-raise: swallowing the error would leave train_loader/val_loader undefined
    # (or stale from an earlier run), and the training cell would fail confusingly.
    raise

This code selects the training subset from the full dataset, applies data augmentation with `SatelliteImageAugmentor`, and prepares a custom training loader and a standard validation loader.
Use this cell if you are training with augmentation; it replaces the `train_loader` and `val_loader` created by the previous cell.

In [None]:
def creatImgDict(ids, dict):
    """Return a new dict holding only the entries of ``dict`` whose keys are in ``ids``.

    Keys are inserted in the order given by ``ids``; an ID missing from ``dict``
    raises ``KeyError``. The parameter name ``dict`` shadows the builtin but is
    kept so existing callers are unaffected.
    """
    return {key: dict[key] for key in ids}
    
# Restrict the augmented dataset to the training split (validation IDs are excluded).
new_image_dict_2022 = creatImgDict(train_ids, image_dict_2022)
new_mask_dict_2022 = creatImgDict(train_ids, mask_dict_2022)

# Initialize the augmentor and wrap the training images/masks in the augmenting dataset.
augmentor = SatelliteImageAugmentor()
train_dataset = GeoTiffDataset(new_image_dict_2022, new_mask_dict_2022, augmentor=augmentor)
print('train_dataset', len(train_dataset))

# Create DataLoaders. This overwrites batch_size, train_loader and val_loader from the
# non-augmented loader cell. Validation uses the un-augmented val_images/val_masks arrays.
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(val_images, val_masks), batch_size=batch_size, num_workers=0, shuffle=False)


Main training section

The model from each epoch is saved.

Training and validation losses for each epoch are saved to a CSV file (`loss.csv`).

In [None]:
# Train with automatic retries. A cuDNN execution failure
# (CUDNN_STATUS_EXECUTION_FAILED) re-creates the trainer and starts over, up to
# max_retries attempts; any other MXNetError is re-raised immediately.
max_retries = 5
retry_count = 0

while retry_count < max_retries:
    try:
        mxTn = myTrain(train_loader, val_loader)
        loss_each_epoch, model_list, epoch = mxTn.train(ctx, epochs = 50)
        # Persist per-epoch losses, the per-epoch models and the loss plot.
        saveAsCSV(["Current Epoch", "Training Loss", "Validation Loss"], lossFile, loss_each_epoch)
        SaveModels(output_models, model_list)
        lossPlot(loss_each_epoch, output_directory)
        break
    except MXNetError as e:
        if 'CUDNN_STATUS_EXECUTION_FAILED' not in str(e):
            raise  # Not a cuDNN execution failure: retrying will not help.
        print(f"cuDNN execution failed. Retrying... ({retry_count + 1}/{max_retries})")
        # waitall() only blocks until queued asynchronous operations finish;
        # it does NOT free any memory by itself.
        mx.nd.waitall()
        # Release MXNet's cached GPU memory pool before retrying
        # (Context.empty_cache exists in MXNet >= 1.6; skipped on older versions).
        if hasattr(ctx, "empty_cache"):
            ctx.empty_cache()
        time.sleep(5)  # Give the device a moment before retrying.
        retry_count += 1

# Every attempt hit a cuDNN execution failure.
if retry_count == max_retries:
    print("Maximum retries reached. Training failed due to cuDNN error.")

Returns the most recent model checkpoint file from the output folder.

In [None]:
def get_model_file_name(models_dir=None):
    """Return the path of the highest-numbered ``*_<N>.params`` checkpoint.

    The checkpoint is chosen by the integer ``N`` parsed from the file name, not by
    counting files. Other files in the folder (CSV, JSON, ...) are therefore ignored,
    and the name prefix (e.g. ``model_VNIR_`` or ``model_NDV_``) is not hard-coded.

    Args:
        models_dir: Folder to search. Defaults to the notebook-level ``output_models``.

    Returns:
        Full path of the checkpoint with the largest trailing index. If no
        ``*_<N>.params`` file is found, a message is printed and ``None`` is returned.

    Raises:
        FileNotFoundError: if the folder itself does not exist (from ``os.listdir``).
    """
    import re

    folder = output_models if models_dir is None else models_dir
    index_pattern = re.compile(r"_(\d+)\.params$")

    checkpoints = []
    for name in os.listdir(folder):
        match = index_pattern.search(name)
        if match:
            # Compare numerically so that _10 sorts after _9.
            checkpoints.append((int(match.group(1)), name))

    if not checkpoints:
        print(f"No model checkpoints (*_<N>.params) found in {folder}.")
        return None

    _, last_name = max(checkpoints)
    last_file = os.path.join(folder, last_name)
    print(f"This model is using: {last_file}")
    return last_file

Loads an image, its mask (2022 only) and its metadata for the given year (2022 or 2010).

In [None]:
def get_img_metadata(id, is2022):
    """Load one tile, its mask and its metadata entry for the requested year.

    Args:
        id: Tile ID; key into ``image_dict_2022`` or ``image_dict_2010``.
        is2022: True for the 2022 set (image + mask), False for 2010 (image only).

    Returns:
        ``(img, mask, currentMetadata)``. ``img``/``mask`` come from
        ``ImageDict.getImage`` on ``ctx``; ``mask`` is ``None`` for 2010;
        ``currentMetadata`` is the dictionary entry for ``id``.
    """
    if not is2022:
        img_2010 = images_2010.getImage(id, image_dict_2010, ctx)
        print('img:', img_2010.shape)
        meta_2010 = image_dict_2010[id]
        print('currentMetadata:', meta_2010.image.shape)
        return img_2010, None, meta_2010

    return (
        images_2022.getImage(id, image_dict_2022, ctx),
        masks_2022.getImage(id, mask_dict_2022, ctx),
        image_dict_2022[id],
    )

Returns the appropriate reference shapefile path based on the year.

In [None]:
def get_ref_path(is2022):
    """Return the folder of reference shapefiles for the 2022 (True) or 2010 (False) tiles."""
    if is2022:
        return const.output_ref_2022
    return const.output_ref_2010

   - Loads the model and performs predictions on a list of image IDs.
   - Outputs: segmentation, boundaries, distances, and instance masks.
   - Visualizes results, saves prediction shapefiles, and computes IoU scores.
   - All outputs (visuals + metrics) are stored in `result_path`.

In [None]:
def visualize_predictions(result_path, val_ids, t_ext , t_bound , is2022):
    """Run the latest checkpoint on each ID and save visuals, shapefiles and IoU scores.

    For every ID, the segmentation, boundary and distance maps are taken from the
    network outputs. Instances are extracted with ``InstSegm``, outputs are written
    under ``result_path`` via ``visualize_all``, and the IoU against the reference
    shapefile is computed. A failure on one ID is printed and that ID is skipped,
    so the remaining IDs still run.

    Args:
        result_path: Root output folder for per-ID results.
        val_ids: IDs to process (all of them are processed).
        t_ext: Extent threshold passed to ``InstSegm``.
        t_bound: Boundary threshold passed to ``InstSegm``.
        is2022: True for 2022 tiles (masks + comparison plots), False for 2010.

    Returns:
        List of results from ``plotPredictedImape`` (populated only when ``is2022``).

    Raises:
        FileNotFoundError: if no model checkpoint is available.
    """
    print(f"Starting visualization with t_ext = {t_ext}, t_bound = {t_bound}")
    modelPath = get_model_file_name()
    if modelPath is None:
        # Fail clearly instead of handing the string "None" to the model loader.
        raise FileNotFoundError(f"No model checkpoint found in {output_models}")
    ref_path = get_ref_path(is2022)

    netPredict = myModel.MyFractalResUNetcmtsk(True, modelPath, ctx)
    # First row records this run's thresholds (stored under the ID/IOU columns).
    # NOTE: ious is cumulative across IDs, so each per-ID iou.csv receives every row
    # collected so far — NOTE(review): confirm this is intended.
    ious = [{"ID": f't_ext: {t_ext}', "IOU": f't_bound: {t_bound}'}]
    plotColl = []

    for id in val_ids:
        print(f"Processing image ID: {id}")
        try:
            img, mask, currentMetadata = get_img_metadata(id, is2022)
            with autograd.predict_mode():
                outputs = netPredict.net(img)
                # Channel 1 of the first sample of each head: segmentation, boundary, distance.
                pred_segm = outputs[0][0, 1, :, :].asnumpy()
                pred_bound = outputs[1][0, 1, :, :].asnumpy()
                pred_dists = outputs[2][0, 1, :, :].asnumpy()
                # NOTE(review): the segmentation map is inverted before InstSegm — confirm intent.
                pred_segm = 1 - pred_segm
                inst = InstSegm(pred_segm, pred_bound, t_ext=t_ext, t_bound=t_bound)  # instance segmentation
                inst = np.nan_to_num(inst, nan=0)
                if is2022:
                    imgColl = plotPredictedImape(id, img, mask, pred_segm, pred_bound, pred_dists, inst, ref_path)
                    plotColl.append(imgColl)
                output_shapefile_path = visualize_all(id, img, currentMetadata, outputs, pred_segm, pred_bound, inst, result_path)
                print("Start IOU calculation")
                csv_file_path = os.path.join(result_path, str(id), "iou.csv")
                iou_score = get_iou(
                    os.path.join(ref_path, f'tile_{id}.shp'),
                    os.path.join(output_shapefile_path, f'{str(id)}.shp'),
                )
                ious.append({"ID": id, "IOU": iou_score})
                saveAsCSV(["ID", "IOU"], csv_file_path, ious, True)
        except Exception as e:
            # Best-effort per ID: report and continue with the next tile.
            print(f"Error processing image ID {id}: {e}")

    return plotColl
       
def visualize(result_path, val_ids, t_ext, t_bound, is2022=True):
    """Run ``visualize_predictions`` on every ID in ``val_ids``; ``is2022`` defaults to True.

    Returns whatever ``visualize_predictions`` returns. No random sampling is done:
    the previous unused ``random.choice`` call was dead code. It also advanced the
    global RNG and raised ``IndexError`` on an empty ID list.
    """
    return visualize_predictions(result_path, val_ids, t_ext=t_ext, t_bound=t_bound, is2022=is2022)

# Applying the model to the 2022 images

In [None]:
# Thresholds used for VNIR: t_ext = 0.6, t_bound = 0.1
# 6612  (NOTE(review): the meaning of this number is not stated in the notebook)
results_2022 = visualize(
    result_path,
    val_ids,
    t_ext=0.6,
    t_bound=0.1,
    is2022=True,
)

# Applying the model to the 2010 images

In [None]:
print(output_directory_2010)
# Only tile 206 is processed here; pass list(image_dict_2010.keys()) instead
# to run every loaded 2010 tile.
results_2010 = visualize(
    output_directory_2010,
    [206],
    t_ext=0.6,
    t_bound=0.1,
    is2022=False,
)