In [31]:
# This is MAC branch
import os

import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from einops import rearrange
from scipy.special import softmax
from torch.utils.data import DataLoader

from PL_Support_Codes.models import build_model
from PL_Support_Codes.tools import load_cfg_file
from PL_Support_Codes.datasets.utils import generate_image_slice_object
from PL_Support_Codes.utils.utils_image import ImageStitcher_v2 as ImageStitcher
from PL_Support_Codes.datasets import build_dataset, tensors_and_lists_collate_fn

from PL_Support_Codes.models.lf_model import LateFusionModel
from PL_Support_Codes.models.ef_model import EarlyFusionModel
from PL_Support_Codes.models.water_seg_model import WaterSegmentationModel

In [32]:
# Setup model parameters
dataset_name = "batch_infer"
infer_split = "all"
infer_seed_num = 0
infer_train_split_pct = 0.0
infer_num_workers = 0
n_classes_model = 3
model_used_here = "unet_orig" #unet_cbam
# optimizer_used = "adam"
model_loss_fn_a_infer = "cross_entropy"
model_loss_fn_b_infer = "cross_entropy"
model_loss_fn_a_infer_ratio = 1
model_loss_fn_b_infer_ratio = 0


base_save_dir = r"E:\Zhijie_PL_Pipeline\Infered_result\RGV30_CSDA_UNET_HPC"
checkpoint_path = r"E:\Zhijie_PL_Pipeline\Trained_model\UNET\checkpoints\CSDA_UNET_HPC.ckpt"

# Root folder containing the directories you waant to run inference on, under this folder, there should be different dates folder, within the dates folder, there should be imgs
ROOT_FOLDER = r"E:\Zhijie_PL_Pipeline\DATA\RGV_local\\"
# JSON file path
JSON_FILE = r"E:\Zhijie_PL_Pipeline\Zhijie_PL_Pipeline\dataset_dirs.json"

In [33]:
def infer():
    # Load configuration file.
    path_components = checkpoint_path.split(os.sep)
    experiment_dir = os.sep.join(path_components[:4]) 

    cfg_path = os.path.join(experiment_dir, 'config.yaml')
    print("check point file path: ", checkpoint_path)
    cfg = load_cfg_file(cfg_path)

    if 'model_n_classes' in cfg:
        n_classes_used = cfg.model_n_classes
    else:
        n_classes_used = n_classes_model
    
    if 'model_used' in cfg:
        model_used_infer = cfg.model_used
    else:
        model_used_infer = model_used_here


    


    if not os.path.exists(base_save_dir):
        os.makedirs(base_save_dir)
    print("Saving inference to: ",base_save_dir)
    # Load dataset.
    slice_params = generate_image_slice_object(cfg.crop_height, cfg.crop_width, min(cfg.crop_height, cfg.crop_width))
    eval_dataset = build_dataset(dataset_name,
                                 infer_split,
                                 slice_params,
                                 sensor=cfg.dataset.sensor,
                                 channels=cfg.dataset.channels,
                                 n_classes=n_classes_used,
                                 norm_mode=cfg.norm_mode,
                                 eval_region=cfg.eval_region,
                                 ignore_index=cfg.ignore_index,
                                 seed_num=infer_seed_num,
                                 train_split_pct=infer_train_split_pct,
                                 output_metadata=True,
                                 # ** allows us to pass in any additional arguments to the dataset as dictionary.
                                 **cfg.dataset.dataset_kwargs)

    eval_loader = DataLoader(eval_dataset,
                             batch_size=cfg.batch_size,
                             shuffle=False,
                             num_workers=infer_num_workers, collate_fn=tensors_and_lists_collate_fn)
    
    MODELS = {
        'ms_model': WaterSegmentationModel,
        'ef_model': EarlyFusionModel,
        'lf_model': LateFusionModel
    }
# here need to retrain model and save new config file, in the new one it
#should be   cfg.model_used
    model = MODELS[cfg.model.name].load_from_checkpoint(checkpoint_path,
                                       in_channels=eval_dataset.n_channels,
                                       n_classes=eval_dataset.n_classes,
                                       lr=cfg.lr,
                                       log_image_iter=cfg.log_image_iter,
                                       to_rgb_fcn=eval_dataset.to_RGB,
                                       ignore_index=eval_dataset.ignore_index,
                                       model_used=model_used_infer,
                                       model_loss_fn_a = model_loss_fn_a_infer,
                                       model_loss_fn_b = model_loss_fn_b_infer,
                                       model_loss_fn_a_ratio = model_loss_fn_a_infer_ratio,
                                       model_loss_fn_b_ratio = model_loss_fn_b_infer_ratio,
                                       **cfg.model.model_kwargs)
    model._set_model_to_eval()

    # Get device.
    if torch.cuda.is_available():
        device = 'cuda'
        print("!!!!!! CUDA is available!!!!!!")
    else:
        device = 'mps'
        print("!!!!!! CUDA is not available, using MPS !!!!!!")
    model = model.to(device)

    # Generate predictions on target dataset.
    pred_canvases = {}
    with torch.no_grad(): #no_grad() prevents gradiant calculation, which is not needed for inference.
        # breakpoint()
        for batch in tqdm(eval_loader, colour='green', desc='Generating predictions'):
            # Move batch to device.
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch[key] = value.to(dtype=torch.float32).to(device)

            # Generate predictions.
            # this pass the current batch into model, to generate prediction, 
            #at this stage, the output is the raw output that is the probability distribution over the classes for the corresponding pixel in the input image. This distribution can be interpreted as 
            #the model's confidence in each class for that pixel. They are the raw score of what model thinks the possibility of each class for each pixel.
            output = model(batch).detach().cpu().numpy()
            # convert the each class probability distribution to the softmax probability distribution. meaning that the probabilities will add up to 1 between different classes
            preds = softmax(output, axis=1)

            input_images = batch['image'].detach().cpu().numpy()
            # rearrange the tensor to the format of (batch, height, width, channel)
            preds = rearrange(preds, 'b c h w -> b h w c')
            input_images = rearrange(input_images, 'b c h w -> b h w c')
            batch_mean = rearrange(batch['mean'], 'b c 1 1 -> b 1 1 c').detach().cpu().numpy()
            batch_std = rearrange(batch['std'], 'b c 1 1 -> b 1 1 c').detach().cpu().numpy()

            for b in range(output.shape[0]):# output.shape[0] is the batch size. so this code is iterating through each image in the batch

                pred = preds[b]
                metadata = batch['metadata'][b]
                input_image = input_images[b]
                region_name = metadata['region_name']

                # Check if image stitcher exists for this region.
                if region_name not in pred_canvases.keys():
                    # Get base save directories.
                    pred_save_dir = os.path.join(base_save_dir, region_name + '_pred')

                    # Initialize image stitchers.
                    pred_canvases[region_name] = ImageStitcher(pred_save_dir, save_backend='tifffile', save_ext='.tif')
                
                # Add input image and prediction to stitchers.
                unnorm_img = (input_image * batch_std[b]) + batch_mean[b]
                image_name = os.path.splitext(os.path.split(metadata['image_path'])[1])[0]
                pred_canvases[region_name].add_image(pred, image_name, metadata['crop_params'], metadata['crop_params'].og_height, metadata['crop_params'].og_width)

    # Convert stitched images to proper format.
    for region_name in pred_canvases.keys():
        # Combine images.
        pred_canvas = pred_canvases[region_name].get_combined_images()

        for image_name, image in pred_canvas.items():
            # Figure out the predicted class.
            pred = np.clip(image.argmax(axis=2), 0, 1)
            # save_path = os.path.join(pred_canvases[region_name].save_dir, image_name + '.tif')
            if not os.path.exists(os.path.join(base_save_dir, region_name)):
                os.makedirs(os.path.join(base_save_dir, region_name))
            save_path = os.path.join(base_save_dir, region_name, image_name + '.tif')
            print(f'Saving {save_path}')
            Image.fromarray((pred*255).astype('uint8')).save(save_path)

In [34]:
import os
import json
from tqdm import tqdm

counter = 1
# Loop through each sub-directory in the root folder
for dir in os.listdir(ROOT_FOLDER):
    full_dir_path = os.path.join(ROOT_FOLDER, dir)
    if os.path.isdir(full_dir_path):
        FOLDER_NAME = full_dir_path

        # Update the JSON file
        with open(JSON_FILE, 'r') as json_file:
            data = json.load(json_file)
            data['batch_infer'] = FOLDER_NAME

        with open(JSON_FILE, 'w') as json_file:
            json.dump(data, json_file)

        # Execute the command
        print("This is the ", counter, "th iteration, out of ", len(os.listdir(ROOT_FOLDER)))
        print("This is the ", counter, "th iteration, out of ", len(os.listdir(ROOT_FOLDER)))
        print("We are infering ", full_dir_path)
        print("We are infering ", full_dir_path)
        infer()
        

Seed set to 0


This is the  1 th iteration, out of  1
This is the  1 th iteration, out of  1
We are infering  E:\Zhijie_PL_Pipeline\DATA\RGV_local\\RGV_240603_30
We are infering  E:\Zhijie_PL_Pipeline\DATA\RGV_local\\RGV_240603_30
check point file path:  E:\Zhijie_PL_Pipeline\Trained_model\UNET\checkpoints\CSDA_UNET_HPC.ckpt
Saving inference to:  E:\Zhijie_PL_Pipeline\Infered_result\RGV30_CSDA_UNET_HPC
RGV_240603_30
Number of images in all dataset: 30


Lightning automatically upgraded your loaded checkpoint from v1.8.2 to v2.2.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint E:\Zhijie_PL_Pipeline\Trained_model\UNET\checkpoints\CSDA_UNET_HPC.ckpt`


Model used!!!!!!!!!:  <class 'PL_Support_Codes.models.unet.UNet_Orig'>
!!!!!!!!!!!!
!!!!!!!!!!!!
Model used:  unet_orig
n_classes:  3
in_channels:  {'ms_image': 4}
ignore_index:  2
optimizer_name:  adam
0.0005
!!!!!!!!!!!!
!!!!!!!!!!!!
!!!!!! CUDA is available!!!!!!


Generating predictions: 100%|[32m██████████[0m| 30/30 [00:13<00:00,  2.15it/s]
Combining images: 100%|[32m██████████[0m| 30/30 [00:01<00:00, 21.60it/s]


Saving E:\Zhijie_PL_Pipeline\Infered_result\RGV30_CSDA_UNET_HPC\RGV_240603_30\20180623La_Paloma_Northeast.tif
Saving E:\Zhijie_PL_Pipeline\Infered_result\RGV30_CSDA_UNET_HPC\RGV_240603_30\20180623La_Paloma_SouthCentral.tif
Saving E:\Zhijie_PL_Pipeline\Infered_result\RGV30_CSDA_UNET_HPC\RGV_240603_30\20180623_Heidelberg_Indian_Hills_Capisallo_East.tif
Saving E:\Zhijie_PL_Pipeline\Infered_result\RGV30_CSDA_UNET_HPC\RGV_240603_30\20180623_Heidelberg_Indian_Hills_Capisallo_Far_North_West.tif
Saving E:\Zhijie_PL_Pipeline\Infered_result\RGV30_CSDA_UNET_HPC\RGV_240603_30\20180623_Heidelberg_Indian_Hills_Capisallo_North_West.tif
Saving E:\Zhijie_PL_Pipeline\Infered_result\RGV30_CSDA_UNET_HPC\RGV_240603_30\20180623_Heidelberg_Indian_Hills_Capisallo_Northeast.tif
Saving E:\Zhijie_PL_Pipeline\Infered_result\RGV30_CSDA_UNET_HPC\RGV_240603_30\20180623_Heidelberg_Indian_Hills_Capisallo_West.tif
Saving E:\Zhijie_PL_Pipeline\Infered_result\RGV30_CSDA_UNET_HPC\RGV_240603_30\20180623_Heidelberg_Indian_H

In [35]:
# checkpoint_path = r"E:\Zhijie_PL_Pipeline\Trained_model\Unet_PS_models\checkpoints\model-epoch=06-val_MulticlassJaccardIndex=0.8755.ckpt"
# path_components = checkpoint_path.split(os.sep)
# experiment_dir = os.sep.join(path_components[:4]) 
# # experiment_dir = os.path.dirname(checkpoint_path)
# # experiment_dir = os.path.join(*checkpoint_path.split(os.sep)[:-2])

# print("experiment_dir: ", experiment_dir)
# print(checkpoint_path.split('\\')[:-2])
