In [None]:
# the script was used to validate the networks within the publication 

In [None]:
import os
from PIL import Image
import numpy as np

import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader

from torchvision.utils import make_grid
from torchvision import transforms
from torchvision.transforms import functional as F, ToTensor, InterpolationMode
import torchvision.models as models
from torchvision.models.segmentation import deeplabv3_resnet101

import torch.nn as nn
import torch.optim as optim

class SegmentationDataset(Dataset):
    def __init__(self, root_dirs, resize=256, crop_size=256):
        self.root_dirs = root_dirs
        self.image_dirs = [os.path.join(root_dir, 'image') for root_dir in root_dirs]
        self.mask_dirs = [os.path.join(root_dir, 'mask') for root_dir in root_dirs]
        self.image_filenames = self.collect_image_filenames()
        self.transforms = transforms.Compose([
            transforms.CenterCrop(crop_size),
            transforms.Resize([resize], interpolation=InterpolationMode.BILINEAR),
            transforms.ToTensor()
        ])

    def collect_image_filenames(self):
        image_filenames = []
        for image_dir in self.image_dirs:
            image_filenames += os.listdir(image_dir)
        return image_filenames

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_name = self.image_filenames[idx]
        image_path = self.find_image_path(image_name)
        mask_path = self.find_mask_path(image_name)

        image = self.load_image(image_path)
        mask = self.load_mask(mask_path)

        return image, mask

    def find_image_path(self, image_name):
        for image_dir in self.image_dirs:
            image_path = os.path.join(image_dir, image_name)
            if os.path.exists(image_path):
                return image_path
        raise FileNotFoundError(f"Image file not found: {image_name}")

    def find_mask_path(self, image_name):
        for mask_dir in self.mask_dirs:
            mask_path = os.path.join(mask_dir, image_name)
            if os.path.exists(mask_path):
                return mask_path
        raise FileNotFoundError(f"Mask file not found: {image_name}")

    def load_image(self, path):
        image = Image.open(path).convert('RGB')
        image = self.transforms(image)
        return image

    def load_mask(self, path):
        mask = Image.open(path).convert('L')
        mask = self.transforms(mask)
        mask = torch.squeeze(mask, dim=0)  # Squeeze the mask
        return mask


In [2]:
use_cuda = True

if use_cuda and not torch.cuda.is_available():
    print("Error: cuda requested but not available, will use cpu instead!")
    device = torch.device('cpu')
elif not use_cuda:
    print("Info: will use cpu!")
    device = torch.device('cpu')
else:
    print(f"Info: Devices: {torch.cuda.device_count()} {torch.cuda.get_device_name(0)} GPU available, will use gpu!")
    device = torch.device('cuda')
    
print(f"Number of CPU cores: {os.cpu_count()}")

Info: Devices: 4 A100-SXM4-40GB GPU available, will use gpu!
Number of CPU cores: 128


In [3]:
## fixed parameters
batch_size = 8
learning_rate = 0.00001
epochs = 100

dataset_path = '/share/data1/pv_segmentation/'

In [4]:
def plot_images_masks_predictions(images, masks, predicted_masks):
    batch_size = images.shape[0]
    fig, axs = plt.subplots(batch_size, 3, figsize=(15, 5*batch_size))

    for i in range(batch_size):
        # Plot the image
        axs[i, 0].imshow(images[i].transpose(1, 2, 0))
        axs[i, 0].set_title('Image {}'.format(i+1))
        axs[i, 0].axis('off')

        # Plot the ground truth mask
        axs[i, 1].imshow(masks[i], cmap='gray')
        axs[i, 1].set_title('Ground Truth Mask {}'.format(i+1))
        axs[i, 1].axis('off')
   
        predicted_mask = predicted_masks[i]
        if predicted_mask.ndim == 2:
            axs[i, 2].imshow(predicted_mask, cmap='gray')
        elif predicted_mask.ndim == 3:
            axs[i, 2].imshow(predicted_mask.transpose(1, 2, 0))
        else:
            raise ValueError('Invalid shape of predicted mask')
        axs[i, 2].set_title('Predicted Mask {}'.format(i+1))
        axs[i, 2].axis('off')

    plt.tight_layout()
    plt.show()
    
def test_segmentation_network(model, dataloader):
    # Get a batch of images and masks
    images, masks = next(iter(dataloader))

    # Move images and masks to the device
    device = next(model.parameters()).device
    images = images.to(device)
    masks = masks.to(device)

    # Run the model to get the predictions
    model.eval()
    with torch.no_grad():
        outputs = model(images)['out']
        predictions = torch.sigmoid(outputs).cpu().numpy()

    # Convert tensors to numpy arrays
    images = images.cpu().numpy()
    masks = masks.cpu().numpy()

    plot_images_masks_predictions(images, masks, predicted_masks)

    
def plot_batch(dataloader,batch_size):
    # Get a batch of images and masks
    images, masks = next(iter(dataloader))

    # Convert tensors to numpy arrays
    images = images.numpy()
    masks = masks.numpy()

    # Plot the images and masks
    fig, axs = plt.subplots(batch_size, 2, figsize=(10, 5*batch_size))

    for i in range(batch_size):
        # Plot the image
        axs[i, 0].imshow(images[i].transpose(1, 2, 0))
        axs[i, 0].set_title('Image {}'.format(i+1))
        axs[i, 0].axis('off')

        # Plot the mask
        axs[i, 1].imshow(masks[i], cmap='gray')
        axs[i, 1].set_title('Mask {}'.format(i+1))
        axs[i, 1].axis('off')

    plt.tight_layout()
    plt.show()   
    
# import metrics
from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryPrecision, BinaryF1Score, BinaryJaccardIndex
calculate_accuracy = BinaryAccuracy()
calculate_precision = BinaryPrecision()
calculate_recall = BinaryRecall()
calculate_f1_score = BinaryF1Score()
calculate__iou = BinaryJaccardIndex()

In [5]:
def eval(model,eval_dataloader, dataset_path_eval):
    model.eval()
    with torch.no_grad():
        images, masks = next(iter(eval_dataloader))

        # Move images and masks to the device
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = model(images)['out']
        predicted_masks = torch.sigmoid(outputs) > 0.5

        # Convert tensors to numpy arrays
        images = images.cpu()
        masks = masks.cpu()
        predicted_masks = predicted_masks.cpu()

        threshold = 0.5
        mask = (masks >= threshold).int()
        predicted_mask = (predicted_masks >= threshold).int()

        # Calculate evaluation metrics
        accuracy = calculate_accuracy(predicted_mask.squeeze(), mask)
        precision = calculate_precision(predicted_mask.squeeze(), mask)
        recall = calculate_recall(predicted_mask.squeeze(), mask)
        f1_score = calculate_f1_score(predicted_mask.squeeze(), mask)
        iou = calculate__iou(predicted_mask.squeeze(), mask)

        # Print evaluation metrics
#        print('Accuracy: {:.4f}'.format(accuracy))
#        print('Precision: {:.4f}'.format(precision))
#        print('Recall: {:.4f}'.format(recall))
        print('{:.4f}'.format(f1_score))
#        print('{:.4f}'.format(iou))
#        print('used dataset s for validation', dataset_path_eval)
        
        # Write validation metrics to TensorBoard
      #  writer.add_scalar('Accuracy', accuracy, epoch)
       # writer.add_scalar('Precision', precision, epoch)
      #  writer.add_scalar('Recall', recall, epoch)
      #  writer.add_scalar('F1 Score', f1_score, epoch)
      #  writer.add_scalar('IoU', iou, epoch)

        # Plot the images, masks, and predicted masks
       # plot_images_masks_predictions(images.numpy(), masks.numpy(), predicted_masks.numpy())

        # Save the trained model
     #   torch.save(model, save_path + "/model_" + str(epoch) + ".pt")
     #   print('Model saved to:', save_path)

In [6]:
## diff vals
val_path1 = '/share/data1/pv_segmentation/PV01_val'
val_path2 = '/share/data1/pv_segmentation/PV02_val'
val_path3 = '/share/data1/pv_segmentation/PV03_val'
val_path4 = '/share/data1/pv_segmentation/PV08_val'
val_path5 = '/share/data1/pv_segmentation/PV16_val'
val_path6 = '/share/data1/pv_segmentation/PV32_val'

test_path1 = '/share/data1/pv_segmentation/PV01_test'
test_path2 = '/share/data1/pv_segmentation/PV02_test'
test_path3 = '/share/data1/pv_segmentation/PV03_test'
test_path4 = '/share/data1/pv_segmentation/PV08_test'
test_path5 = '/share/data1/pv_segmentation/PV16_test'
test_path6 = '/share/data1/pv_segmentation/PV32_test'

def run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6):
    dataset_path_eval = [test_path1]
    dataset_eval = SegmentationDataset(dataset_path_eval)
    dataloader_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=100, shuffle=False)
    eval(model,dataloader_eval,dataset_path_eval)

    dataset_path_eval = [test_path2]
    dataset_eval = SegmentationDataset(dataset_path_eval)
    dataloader_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=100, shuffle=False)
    eval(model,dataloader_eval,dataset_path_eval)

    dataset_path_eval = [test_path3]
    dataset_eval = SegmentationDataset(dataset_path_eval)
    dataloader_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=100, shuffle=False)
    eval(model,dataloader_eval,dataset_path_eval)

    dataset_path_eval = [test_path4]
    dataset_eval = SegmentationDataset(dataset_path_eval)
    dataloader_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=100, shuffle=False)
    eval(model,dataloader_eval,dataset_path_eval)

    dataset_path_eval = [test_path5]
    dataset_eval = SegmentationDataset(dataset_path_eval)
    dataloader_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=100, shuffle=False)
    eval(model,dataloader_eval,dataset_path_eval)
    
    dataset_path_eval = [test_path6]
    dataset_eval = SegmentationDataset(dataset_path_eval)
    dataloader_eval = torch.utils.data.DataLoader(dataset_eval, batch_size=100, shuffle=False)
    eval(model,dataloader_eval,dataset_path_eval)

In [7]:
# 1
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV01.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.9599
0.7612
0.6892
0.4679
0.1681
0.4006


In [8]:
# 2
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV02.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.8595
0.9540
0.3014
0.2341
0.0887
0.1327


In [9]:
# 3
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV03.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.5703
0.0645
0.9753
0.7970
0.6864
0.4721


In [10]:
# 4
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV08.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.6596
0.7818
0.7746
0.9553
0.9118
0.6945


In [11]:
# 4
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV16.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.3154
0.1412
0.5626
0.8547
0.9500
0.8571


In [12]:
# 4
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV32.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.1734
0.0276
0.4032
0.4882
0.7760
0.8659


In [13]:
# 5
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV01_PV02.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.9574
0.9550
0.6233
0.5298
0.2594
0.4041


In [14]:
# 6
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV01_PV02_PV03.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.9585
0.9482
0.9760
0.8584
0.5933
0.6050


In [15]:
# 7
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV01_PV02_PV03_PV08.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.9590
0.9557
0.9717
0.9649
0.8735
0.6749


In [16]:
# 7
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV01_PV02_PV03_PV08_PV16.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.9587
0.9587
0.9697
0.9644
0.9517
0.7555


In [17]:
# 7
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV01_PV02_PV03_PV08_PV16_PV32.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.9600
0.9565
0.9746
0.9634
0.9544
0.9071


In [18]:
# 7
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV02_PV03_PV08_PV16_PV32.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.8704
0.9611
0.9736
0.9660
0.9550
0.9023


In [19]:
# 7
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV03_PV08_PV16_PV32.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.6687
0.8007
0.9735
0.9660
0.9539
0.9122


In [20]:
# 7
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV08_PV16_PV32.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.7355
0.7808
0.7321
0.9590
0.9558
0.8987


In [21]:
# 7
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV16_PV32.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.3918
0.2355
0.2425
0.7383
0.9526
0.9114


In [22]:
# 7
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV32.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.1734
0.0276
0.4032
0.4882
0.7760
0.8659


In [23]:
# 7
# Load the model
model = torch.load('/share/data1/mkleebauer/pv_segmentation/model_PV32.pt')

# Set the model to the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

#Run
run_validation(model, test_path1, test_path2, test_path3, test_path4, test_path5, test_path6)

0.1734
0.0276
0.4032
0.4882
0.7760
0.8659


In [24]:
#test_segmentation_network(model, torch.utils.data.DataLoader(dataset_eval, batch_size=4, shuffle=True))


In [25]:
#eval(model,dataloader_eval)

In [26]:
#eval(model,dataloader_eval)

In [27]:
#from ptflops import get_model_complexity_info
# Get a sample input from the dataloader
#sample_input, _ = next(iter(dataloader))
#input_size = tuple(sample_input.shape[1:])  # Get the input size

# Compute the FLOPs of your model
#flops, params = get_model_complexity_info(model, input_size)
#gflops = flops / 1e9  # Convert FLOPs to GFLOPs

# Log the GFLOPs
#print(f"GFLOPs: {gflops} billion")

In [28]:
# Example usage

# Load the model
#model = torch.load(save_path)

# Set the model to the device (CPU or GPU)
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model = model.to(device)

#test_segmentation_network(model, dataloader)