In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import accuracy_score
from torch import nn, optim
from tqdm.autonotebook import tqdm
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split, RandomSampler
from torch.utils.tensorboard import SummaryWriter
import argparse  
from sklearn.metrics import jaccard_score

from model_segmentation import *
from data_segmentation import create_dataset

# Seed numpy first, then torch, so stochastic steps later in the notebook
# (e.g. the shuffled DataLoader) are repeatable between runs.
SEED = 3
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

In [None]:
# Load the evaluation dataset.
valdata = create_dataset(
    datadir='/path/to/images',
    seglabeldir='/path/to/segmentation_labels/',
    mult=3,
)

# batch_size == 1 enables the per-image diagnostic figures further down;
# any other value only computes the aggregate metrics.
batch_size = 1
all_dl = DataLoader(valdata, batch_size=batch_size, shuffle=True)
# NOTE: tqdm(enumerate(...)) is a single-pass iterator - it is exhausted
# after one full loop over it.
progress = tqdm(enumerate(all_dl), total=len(all_dl))

# Restore the trained weights (mapped onto CPU) and switch to inference mode.
# NOTE(review): `model` (and `device` used later) are not defined in this
# notebook; presumably they come from `from model_segmentation import *`.
# NOTE: torch.load unpickles the file - only load trusted checkpoints.
checkpoint_state = torch.load('/path/to/segmentation.model',
                              map_location=torch.device('cpu'))
model.load_state_dict(checkpoint_state)
model.eval()


save_path = '/path/to/segImages/'
# Diagnostic figures are written into save_path; make sure it exists.
os.makedirs(save_path, exist_ok=True)

# define loss function (not used by the evaluation loop below)
loss_fn = nn.BCEWithLogitsLoss()

# `progress` is a one-shot iterator; rebuild it so this cell can be re-run.
progress = tqdm(enumerate(all_dl), total=len(all_dl))

# run through test data
all_ious = []
all_accs = []
all_arearatios = []
for i, batch in progress:
    x, y = batch['img'].float().to(device), batch['fpt'].float().to(device)
    idx = batch['idx']

    # Inference only: without no_grad the output carries an autograd graph
    # and cannot be converted to numpy.
    with torch.no_grad():
        output = model(x)

    # Host copies of everything used below (works for CPU and GPU devices).
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()

    # obtain binary prediction map: logit >= 0  <=>  sigmoid(logit) >= 0.5
    pred = (output.cpu().numpy() >= 0).astype(float)

    # derive IoU score; only images where BOTH the ground truth and the
    # prediction contain positive pixels are scored
    cropped_iou = []
    for j in range(y_np.shape[0]):
        true_mask = y_np[j].flatten()
        pred_mask = pred[j][0].flatten()
        if np.sum(pred_mask) != 0 and np.sum(true_mask) != 0:
            cropped_iou.append(jaccard_score(true_mask, pred_mask,
                                             zero_division=1))
    all_ious.extend(cropped_iou)

    # derive scalar binary labels on a per-image basis
    y_bin = (np.sum(y_np, axis=(1, 2)) != 0).astype(int)
    prediction = (np.sum(pred, axis=(1, 2, 3)) != 0).astype(int)

    # derive image-wise accuracy for this batch
    all_accs.append(accuracy_score(y_bin, prediction))

    # derive smoke areas (pixel sums of the binary maps)
    area_pred = np.sum(pred, axis=(1, 2, 3))
    area_true = np.sum(y_np, axis=(1, 2))

    # derive smoke area ratios: 1 if both maps are empty, 0 if only the
    # ground truth is empty, pred/true otherwise
    for a_pred, a_true in zip(area_pred, area_true):
        if a_pred == 0 and a_true == 0:
            all_arearatios.append(1)
        elif a_true == 0:
            all_arearatios.append(0)
        else:
            all_arearatios.append(a_pred / a_true)

    # per-image diagnostic figure (prediction/y_bin hold a single entry here)
    if batch_size == 1:
        res = {(1, 1): 'true_pos', (0, 0): 'true_neg',
               (0, 1): 'false_neg', (1, 0): 'false_pos'}[
                   (int(prediction[0]), int(y_bin[0]))]

        # create plot
        f, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(1, 3))
        img = x_np[0]  # assumes >= 11 channels (indices 0..10 used) - TODO confirm

        # RGB plot: channels 3/2/1, min-max stretched (gain 1.5, offset 0.2);
        # clip explicitly so imshow does not have to (and warn about it)
        rgb = np.dstack([img[3], img[2], img[1]])
        rgb = 0.2 + 1.5 * (rgb - rgb.min()) / ((rgb.max() - rgb.min()) or 1.0)
        ax1.imshow(np.clip(rgb, 0, 1), origin='upper')
        ax1.set_title({'true_pos': 'True Positive',
                       'true_neg': 'True Negative',
                       'false_pos': 'False Positive',
                       'false_neg': 'False Negative'}[res],
                      fontsize=8)
        ax1.set_xticks([])
        ax1.set_yticks([])

        # false color plot: channels 0/9/10, min-max stretched (offset 0.2)
        false_color = np.dstack([img[0], img[9], img[10]])
        false_color = 0.2 + (false_color - false_color.min()) / (
            (false_color.max() - false_color.min()) or 1.0)
        ax2.imshow(np.clip(false_color, 0, 1), origin='upper')
        ax2.set_xticks([])
        ax2.set_yticks([])

        # segmentation ground-truth (red) and prediction (green)
        ax3.imshow(y_np[0], cmap='Reds', alpha=0.3)
        ax3.imshow(pred[0][0], cmap='Greens', alpha=0.3)
        ax3.set_xticks([])
        ax3.set_yticks([])

        this_iou = jaccard_score(y_np[0].flatten(), pred[0][0].flatten(),
                                 zero_division=1)
        ax3.annotate("IoU={:.2f}".format(this_iou), xy=(5, 15), fontsize=8)

        f.subplots_adjust(0.05, 0.02, 0.95, 0.9, 0.05, 0.05)

        # save into save_path (previously defined but ignored -> CWD)
        filename = os.path.split(batch['imgfile'][0])[1].replace(
            '.tif', '_eval.png').replace(':', '_')
        f.savefig(os.path.join(save_path, res + filename), dpi=200)
        plt.close(f)

all_arearatios = np.ravel(all_arearatios)

# standard error of the mean area ratio (undefined for fewer than 2 values)
n_ratios = len(all_arearatios)
ratio_sem = (np.std(all_arearatios) / np.sqrt(n_ratios - 1)
             if n_ratios > 1 else float('nan'))

print('iou:', len(all_ious), np.average(all_ious))
print('accuracy:', len(all_accs), np.average(all_accs))
print('mean area ratio:', n_ratios, np.average(all_arearatios), ratio_sem)

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
save_path = '/path/to/segImages'
# Diagnostic figures are written into save_path; make sure it exists.
os.makedirs(save_path, exist_ok=True)

# One probability cut-off for EVERY binary decision below (IoU, per-image
# label, area, plot annotation) so all metrics describe the same mask.
# 0.5 on the sigmoid is equivalent to the logit >= 0 rule.
threshold = 0.5

# Define loss function (not used by the loop below)
loss_fn = nn.BCEWithLogitsLoss()

images = []  # paths of the saved diagnostic figures
all_ious = []
all_accs = []
all_arearatios = []
count1 = 0  # true positives
count2 = 0  # true negatives
count3 = 0  # false negatives
count4 = 0  # false positives

# `progress` is a one-shot iterator (tqdm over enumerate). If an earlier cell
# already consumed it, this loop would silently run zero times - rebuild it.
progress = tqdm(enumerate(all_dl), total=len(all_dl))

for i, batch in progress:
    x, y = batch['img'].float().to(device), batch['fpt'].float().to(device)
    idx = batch['idx']

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(x)
    pred = torch.sigmoid(output)  # Apply sigmoid to convert logits to probabilities

    # Host copies for numpy/sklearn/matplotlib (works for CPU and GPU devices).
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()
    prob = pred.cpu().numpy()
    pred_bin = prob >= threshold

    # Derive IoU scores; images where truth and prediction are both empty
    # are skipped
    cropped_iou = []
    for j in range(y_np.shape[0]):
        true_mask = y_np[j].flatten()
        pred_mask = pred_bin[j][0].flatten()
        if np.sum(true_mask) == 0 and np.sum(pred_mask) == 0:
            continue
        cropped_iou.append(jaccard_score(true_mask, pred_mask, zero_division=1))
    all_ious.extend(cropped_iou)

    # Derive binary labels on a per-image basis
    y_bin = (np.sum(y_np, axis=(1, 2)) != 0).astype(int)
    prediction = (np.sum(pred_bin, axis=(1, 2, 3)) != 0).astype(int)

    # Calculate image-wise accuracy for this batch
    all_accs.append(accuracy_score(y_bin, prediction))

    # Derive smoke areas from the binary maps
    area_pred = np.sum(pred_bin, axis=(1, 2, 3))
    area_true = np.sum(y_np, axis=(1, 2))

    # Derive smoke area ratios: 1 if both empty, 0 if only truth is empty
    for a_pred, a_true in zip(area_pred, area_true):
        if a_pred == 0 and a_true == 0:
            all_arearatios.append(1)
        elif a_true == 0:
            all_arearatios.append(0)
        else:
            all_arearatios.append(a_pred / a_true)

    # Per-image diagnostics (prediction/y_bin hold a single entry here)
    if batch_size == 1:
        if prediction[0] == 1 and y_bin[0] == 1:
            res = 'true_pos'
            count1 += 1
        elif prediction[0] == 0 and y_bin[0] == 0:
            res = 'true_neg'
            count2 += 1
        elif prediction[0] == 0 and y_bin[0] == 1:
            res = 'false_neg'
            count3 += 1
        else:
            res = 'false_pos'
            count4 += 1

        # Create plot
        f, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(1, 3))
        img = x_np[0]  # assumes >= 11 channels (indices 0..10 used) - TODO confirm

        # RGB plot: channels 3/2/1, min-max stretched (gain 1.5, offset 0.2),
        # clipped explicitly to imshow's valid float range
        rgb = np.dstack([img[3], img[2], img[1]])
        rgb = 0.2 + 1.5 * (rgb - rgb.min()) / ((rgb.max() - rgb.min()) or 1.0)
        ax1.imshow(np.clip(rgb, 0, 1), origin='upper')
        ax1.set_title({'true_pos': 'True Positive', 'true_neg': 'True Negative', 'false_pos': 'False Positive', 'false_neg': 'False Negative'}[res], fontsize=8)
        ax1.set_xticks([])
        ax1.set_yticks([])

        # False color plot: channels 0/9/10, min-max stretched (offset 0.2)
        false_color = np.dstack([img[0], img[9], img[10]])
        false_color = 0.2 + (false_color - false_color.min()) / (
            (false_color.max() - false_color.min()) or 1.0)
        ax2.imshow(np.clip(false_color, 0, 1), origin='upper')
        ax2.set_xticks([])
        ax2.set_yticks([])

        # Segmentation ground-truth (red) and predicted probability map (green)
        ax3.imshow(y_np[0], cmap='Reds', alpha=0.3)
        ax3.imshow(prob[0][0], cmap='Greens', alpha=0.3)
        ax3.set_xticks([])
        ax3.set_yticks([])

        # Annotated IoU uses the same threshold as the aggregate metrics
        this_iou = jaccard_score(y_np[0].flatten().astype(bool), pred_bin[0][0].flatten(), zero_division=1)
        ax3.annotate("IoU={:.2f}".format(this_iou), xy=(5, 15), fontsize=8)

        f.subplots_adjust(0.05, 0.02, 0.95, 0.9, 0.05, 0.05)

        # Construct the filename and save the figure inside save_path
        filename = os.path.split(batch['imgfile'][0])[1].replace('.tif', '_eval.png').replace(':', '_')
        complete_path = os.path.join(save_path, filename)
        f.savefig(complete_path, dpi=200)
        images.append(complete_path)
        plt.close(f)

# Standard error of the mean area ratio (undefined for fewer than 2 values)
n_ratios = len(all_arearatios)
ratio_sem = (np.std(all_arearatios) / np.sqrt(n_ratios - 1)
             if n_ratios > 1 else float('nan'))

print('iou:', len(all_ious), np.average(all_ious))
print('accuracy:', len(all_accs), np.average(all_accs))
print('mean area ratio:', n_ratios, np.average(all_arearatios), ratio_sem)
print('TP:', count1, ' TN:', count2, ' FN:', count3, ' FP:', count4)