In [None]:
# This notebook trains the segmentation networks used in the publication.

In [None]:
import os
from PIL import Image
import numpy as np

import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader

from torchvision.utils import make_grid
from torchvision import transforms
from torchvision.transforms import functional as F, ToTensor, InterpolationMode
import torchvision.models as models
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights 

import torch.nn as nn
import torch.optim as optim

class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs gathered from one or more root directories.

    Every root directory is expected to hold an ``image`` and a ``mask``
    sub-directory; a mask is looked up under the same file name as its image.

    Args:
        root_dirs: A single directory path or an iterable of directory paths.
        resize (int): Size handed to ``transforms.Resize`` as ``[resize]``.
        crop_size (int): Size of the centre crop applied before resizing.

    Attributes:
        samples: List of ``(image_path, mask_path)`` tuples, one per listed image.
        image_filenames: File names of the listed images, in ``samples`` order.
        transforms: Pipeline applied to images (bilinear resize).
        mask_transforms: Pipeline applied to masks (nearest-neighbour resize).
    """

    def __init__(self, root_dirs, resize=256, crop_size=256):
        # A bare string would otherwise be iterated character by character.
        if isinstance(root_dirs, (str, os.PathLike)):
            root_dirs = [root_dirs]
        root_dirs = list(root_dirs)

        self.root_dirs = root_dirs
        self.image_dirs = [os.path.join(root_dir, 'image') for root_dir in root_dirs]
        self.mask_dirs = [os.path.join(root_dir, 'mask') for root_dir in root_dirs]

        # Full paths are resolved once, per root, so that identical file names
        # in different roots stay distinct and each image keeps its own mask.
        self.samples = self._collect_samples()
        self.image_filenames = [os.path.basename(image_path) for image_path, _ in self.samples]

        self.transforms = self._build_transforms(crop_size, resize, InterpolationMode.BILINEAR)
        # Nearest-neighbour keeps mask values unblended when a resize happens.
        self.mask_transforms = self._build_transforms(crop_size, resize, InterpolationMode.NEAREST)

    @staticmethod
    def _build_transforms(crop_size, resize, interpolation):
        """Build the crop -> resize -> tensor pipeline with the given interpolation."""
        return transforms.Compose([
            transforms.CenterCrop(crop_size),
            transforms.Resize([resize], interpolation=interpolation),
            transforms.ToTensor()
        ])

    def _collect_samples(self):
        """Return ``(image_path, mask_path)`` for every entry of every image directory.

        Entries are sorted per directory so the sample order does not depend on
        the order in which the file system happens to list them.
        """
        samples = []
        for image_dir, mask_dir in zip(self.image_dirs, self.mask_dirs):
            for name in sorted(os.listdir(image_dir)):
                samples.append((os.path.join(image_dir, name), os.path.join(mask_dir, name)))
        return samples

    def collect_image_filenames(self):
        """Return the file names found in all image directories (sorted per directory)."""
        return [os.path.basename(image_path) for image_path, _ in self._collect_samples()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensors of sample ``idx``.

        Raises:
            FileNotFoundError: If no mask with the image's file name exists.
        """
        image_path, mask_path = self.samples[idx]
        if not os.path.exists(mask_path):
            # Fall back to the previous lookup across every mask directory.
            mask_path = self.find_mask_path(os.path.basename(image_path))

        image = self.load_image(image_path)
        mask = self.load_mask(mask_path)

        return image, mask

    def find_image_path(self, image_name):
        """Return the first existing ``<image_dir>/<image_name>`` path.

        Raises:
            FileNotFoundError: If no image directory contains ``image_name``.
        """
        for image_dir in self.image_dirs:
            image_path = os.path.join(image_dir, image_name)
            if os.path.exists(image_path):
                return image_path
        raise FileNotFoundError(f"Image file not found: {image_name}")

    def find_mask_path(self, image_name):
        """Return the first existing ``<mask_dir>/<image_name>`` path.

        Raises:
            FileNotFoundError: If no mask directory contains ``image_name``.
        """
        for mask_dir in self.mask_dirs:
            mask_path = os.path.join(mask_dir, image_name)
            if os.path.exists(mask_path):
                return mask_path
        raise FileNotFoundError(f"Mask file not found: {image_name}")

    def load_image(self, path):
        """Load ``path`` as RGB and apply the image pipeline."""
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(path) as handle:
            image = handle.convert('RGB')
        return self.transforms(image)

    def load_mask(self, path):
        """Load ``path`` as single-channel and return it without the channel axis."""
        with Image.open(path) as handle:
            mask = handle.convert('L')
        mask = self.mask_transforms(mask)
        mask = torch.squeeze(mask, dim=0)  # (1, H, W) -> (H, W)
        return mask


In [None]:
# Set to False to force CPU execution even when a GPU is present.
use_cuda = True

# Resolve the torch device once; later cells read the module-level `device`.
if not use_cuda:
    print("Info: will use cpu!")
    device = torch.device('cpu')
elif torch.cuda.is_available():
    print(f"Info: Devices: {torch.cuda.device_count()} {torch.cuda.get_device_name(0)} GPU available, will use gpu!")
    device = torch.device('cuda')
else:
    # CUDA was requested but cannot be used: fall back to CPU.
    print("Error: cuda requested but not available, will use cpu instead!")
    device = torch.device('cpu')

print(f"Number of CPU cores: {os.cpu_count()}")

In [None]:
# Fixed hyper-parameters shared by every training run below.
batch_size = 16        # samples per batch
learning_rate = 1e-5   # learning rate passed to the training function
epochs = 100           # number of training epochs per run

# Base directory of the data; also reused below as save_path and log_dir.
dataset_path = '/share/data1/pv_segmentation/'

In [None]:
def plot_images_masks_predictions(images, masks, predicted_masks):
    """Plot image, ground-truth mask and predicted mask side by side, one row per sample.

    Args:
        images: Array whose first axis indexes samples; each ``images[i]`` is
            transposed with ``(1, 2, 0)`` before display (channel-first input).
        masks: Array whose ``masks[i]`` is shown as a grey-scale image.
        predicted_masks: Array whose ``predicted_masks[i]`` is either 2-D
            (shown in grey-scale) or 3-D channel-first. A 3-D prediction with a
            single channel is shown in grey-scale as well.

    Raises:
        ValueError: If a predicted mask is neither 2-D nor 3-D.
    """
    batch_size = images.shape[0]
    # squeeze=False keeps `axs` two-dimensional even for a single sample, so
    # the axs[i, j] indexing below is always valid.
    fig, axs = plt.subplots(batch_size, 3, figsize=(15, 5*batch_size), squeeze=False)

    for i in range(batch_size):
        # Plot the image
        axs[i, 0].imshow(images[i].transpose(1, 2, 0))
        axs[i, 0].set_title('Image {}'.format(i+1))
        axs[i, 0].axis('off')

        # Plot the ground truth mask
        axs[i, 1].imshow(masks[i], cmap='gray')
        axs[i, 1].set_title('Ground Truth Mask {}'.format(i+1))
        axs[i, 1].axis('off')

        # Plot the predicted mask
        predicted_mask = predicted_masks[i]
        if predicted_mask.ndim == 3 and predicted_mask.shape[0] == 1:
            # A (1, H, W) map would become (H, W, 1) after the transpose, which
            # imshow does not accept; drop the channel axis instead.
            predicted_mask = predicted_mask[0]
        if predicted_mask.ndim == 2:
            axs[i, 2].imshow(predicted_mask, cmap='gray')
        elif predicted_mask.ndim == 3:
            axs[i, 2].imshow(predicted_mask.transpose(1, 2, 0))
        else:
            raise ValueError('Invalid shape of predicted mask')
        axs[i, 2].set_title('Predicted Mask {}'.format(i+1))
        axs[i, 2].axis('off')

    plt.tight_layout()
    plt.show()
    
def test_segmentation_network(model, dataloader):
    """Run ``model`` on the first batch of ``dataloader`` and plot the result.

    The model is switched to evaluation mode and left in that mode.

    Args:
        model: Network whose forward pass returns a mapping with an ``'out'`` entry.
        dataloader: Iterable yielding ``(images, masks)`` batches.
    """
    # Get a batch of images and masks
    images, masks = next(iter(dataloader))

    # Run on whatever device the model already lives on
    device = next(model.parameters()).device
    images = images.to(device)

    # Run the model to get the predictions
    model.eval()
    with torch.no_grad():
        outputs = model(images)['out']
        # squeeze(1) removes a single-channel axis (and only that axis), so each
        # prediction is a 2-D map the plotting helper can show in grey-scale.
        predictions = torch.sigmoid(outputs).squeeze(1).cpu().numpy()

    # Convert tensors to numpy arrays
    images = images.cpu().numpy()
    masks = masks.cpu().numpy()

    # The original call passed an undefined name here instead of `predictions`.
    plot_images_masks_predictions(images, masks, predictions)

    
def plot_batch(dataloader,batch_size):
    """Plot the images and masks of the first batch of ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(images, masks)`` tensor batches.
        batch_size (int): Number of samples to plot; capped at the number of
            samples the batch actually contains.
    """
    # Get a batch of images and masks
    images, masks = next(iter(dataloader))

    # Convert tensors to numpy arrays
    images = images.numpy()
    masks = masks.numpy()

    # Never index past the end of a batch that is shorter than requested.
    num_rows = min(batch_size, images.shape[0])

    # squeeze=False keeps `axs` two-dimensional even when one row is plotted.
    fig, axs = plt.subplots(num_rows, 2, figsize=(10, 5*num_rows), squeeze=False)

    for i in range(num_rows):
        # Plot the image
        axs[i, 0].imshow(images[i].transpose(1, 2, 0))
        axs[i, 0].set_title('Image {}'.format(i+1))
        axs[i, 0].axis('off')

        # Plot the mask
        axs[i, 1].imshow(masks[i], cmap='gray')
        axs[i, 1].set_title('Mask {}'.format(i+1))
        axs[i, 1].axis('off')

    plt.tight_layout()
    plt.show()
    
# Binary classification metrics from torchmetrics. The instances below are
# module-level and are called once per evaluation batch inside
# train_segmentation_network as metric(predicted_mask, mask).
# NOTE(review): torchmetrics metric objects presumably also accumulate internal
# state across calls; only the per-call return values are used in this file.
from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryPrecision, BinaryF1Score, BinaryJaccardIndex
calculate_accuracy = BinaryAccuracy()
calculate_precision = BinaryPrecision()
calculate_recall = BinaryRecall()
calculate_f1_score = BinaryF1Score()
calculate_iou = BinaryJaccardIndex()  # reported as "IoU" during evaluation

In [None]:
# load and prepare model
def load_model(device=None):
    """Build a DeepLabV3-ResNet101 with a single-channel output head.

    Args:
        device (torch.device | str | None): Device to place the model on. When
            ``None`` (the default), CUDA is used if available, otherwise the CPU.

    Returns:
        The model, moved to the selected device.
    """
    model = deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1)

    # NOTE(review): this installs a freshly initialised first convolution. Its
    # configuration looks like the stock ResNet stem; if so, the only effect is
    # to discard the pretrained weights of that layer -- confirm this is intended.
    model.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

    num_classes = 1  # Assuming binary segmentation (1 class)
    model.classifier = models.segmentation.deeplabv3.DeepLabHead(2048, num_classes)

    # Set up the device (CPU or GPU) unless the caller chose one
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return model

In [None]:
from torch.utils.tensorboard import SummaryWriter

def _evaluate_segmentation(model, eval_dataloader, device, threshold=0.5):
    """Average the binary segmentation metrics of ``model`` over the evaluation batches.

    Each metric is computed per batch and the per-batch values are averaged, so
    the result is a mean over batches, not over pixels.

    Returns:
        dict: Metric name (also used as TensorBoard tag) -> averaged value.
    """
    metric_fns = {
        'Accuracy': calculate_accuracy,
        'Precision': calculate_precision,
        'Recall': calculate_recall,
        'F1 Score': calculate_f1_score,
        'IoU': calculate_iou,
    }
    per_batch = {name: [] for name in metric_fns}

    model.eval()
    with torch.no_grad():
        for images, masks in eval_dataloader:
            outputs = model(images.to(device))['out']
            # squeeze(1) drops only the channel axis; a bare squeeze() would
            # also drop the batch axis of a single-sample batch.
            predicted_mask = (torch.sigmoid(outputs) > threshold).squeeze(1).int().cpu()
            mask = (masks.cpu() >= threshold).int()

            for name, metric in metric_fns.items():
                per_batch[name].append(metric(predicted_mask, mask))

    return {name: float(sum(values) / len(values)) for name, values in per_batch.items()}


def train_segmentation_network(model, dataset, dataset_eval, num_epochs, batch_size, learning_rate, save_path, log_dir, dataset_paths):
    """Trains a segmentation network using the provided dataset and saves the best model.

    The model is evaluated on ``dataset_eval`` every 10th epoch; whenever the
    averaged IoU exceeds the best value seen so far in this call, the whole
    model is saved to ``<save_path>/model_<marker>.pt``, where ``<marker>`` is
    the last path component of every entry of ``dataset_paths`` joined by ``_``.

    Args:
        model (torch.nn.Module): The model for training; its forward pass must
            return a mapping with an ``'out'`` entry. Training runs on the
            device the model's parameters are on.
        dataset (torch.utils.data.Dataset): The dataset for training.
        dataset_eval (torch.utils.data.Dataset): The dataset for evaluation.
        num_epochs (int): The number of training epochs.
        batch_size (int): The batch size for training and evaluation.
        learning_rate (float): The learning rate for optimization.
        save_path (str): The directory to save the trained model in.
        log_dir (str): The directory path to save the TensorBoard logs.
        dataset_paths (list[str]): Training directories, used to name the saved model.

    """
    # Use the model's own device so inputs can never end up on a different
    # device than the weights.
    device = next(model.parameters()).device

    # Define the loss function and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Create the data loaders for training and evaluation
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    eval_dataloader = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

    # normpath makes the marker robust against trailing slashes
    marker = '_'.join(os.path.basename(os.path.normpath(path)) for path in dataset_paths)
    checkpoint_path = os.path.join(save_path, "model_" + marker + ".pt")

    # Highest IoU seen so far; must live outside the epoch loop, otherwise every
    # evaluation would overwrite the checkpoint.
    best_iou = 0.0

    # Set up TensorBoard writer
    writer = SummaryWriter(log_dir=log_dir)
    try:
        # Training loop
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch + 1, num_epochs))
            print('-' * 10)

            running_loss = 0.0
            model.train()

            for images, masks in dataloader:
                # Move images and masks to the device
                images = images.to(device)
                masks = masks.to(device)

                # Zero the gradients
                optimizer.zero_grad()

                # Forward pass; remove only the channel axis so the batch axis
                # survives even when the batch holds a single sample.
                outputs = model(images)['out'].squeeze(1)
                loss = criterion(outputs, masks)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * images.size(0)

            epoch_loss = running_loss / len(dataset)
            print('Loss: {:.4f}'.format(epoch_loss))

            # Evaluate the model every 10th epoch only
            if (epoch + 1) % 10 != 0:
                continue

            metrics = _evaluate_segmentation(model, eval_dataloader, device)

            # Print average evaluation metrics
            for name, value in metrics.items():
                print('{}: {:.4f}'.format(name, value))

            # Write average validation metrics to TensorBoard
            for name, value in metrics.items():
                writer.add_scalar(name, value, epoch)

            # Save the model whenever the validation IoU improves
            if metrics['IoU'] > best_iou:
                best_iou = metrics['IoU']
                torch.save(model, checkpoint_path)
                print('Model saved to:', checkpoint_path)
    finally:
        # Flush and close the writer even if training is interrupted
        writer.flush()
        writer.close()
    print('TensorBoard writer saved to:', log_dir)


In [None]:
# Training directories; each one is handed to SegmentationDataset as a root
# directory in the run cells below.
dataset_path1 = '/share/data1/pv_segmentation/PV01_train'
dataset_path2 = '/share/data1/pv_segmentation/PV02_train'
dataset_path3 = '/share/data1/pv_segmentation/PV03_train'
dataset_path4 = '/share/data1/pv_segmentation/PV08_train'
dataset_path5 = '/share/data1/pv_segmentation/PV16_train'
dataset_path6 = '/share/data1/pv_segmentation/PV32_train'

# Validation directories matching the training directories above by name.
val_path1 = '/share/data1/pv_segmentation/PV01_val'
val_path2 = '/share/data1/pv_segmentation/PV02_val'
val_path3 = '/share/data1/pv_segmentation/PV03_val'
val_path4 = '/share/data1/pv_segmentation/PV08_val'
val_path5 = '/share/data1/pv_segmentation/PV16_val'
val_path6 = '/share/data1/pv_segmentation/PV32_val'

# Saved models and TensorBoard logs both go to the base data directory.
save_path = dataset_path 
log_dir = dataset_path

In [None]:
# Train on PV01_train, validate on PV01_val.
dataset_paths = [dataset_path1]
dataset_path_eval = [val_path1]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
# Train on PV02_train, validate on PV02_val.
dataset_paths = [dataset_path2]
dataset_path_eval = [val_path2]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
# Train on PV03_train, validate on PV03_val.
dataset_paths = [dataset_path3]
dataset_path_eval = [val_path3]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
# Train on PV08_train, validate on PV08_val.
dataset_paths = [dataset_path4]
dataset_path_eval = [val_path4]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
# Train on PV16_train, validate on PV16_val.
dataset_paths = [dataset_path5]
dataset_path_eval = [val_path5]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
# Train on PV32_train, validate on PV32_val.
dataset_paths = [dataset_path6]
dataset_path_eval = [val_path6]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
# Train on PV01_train + PV02_train; validation uses PV01_val only.
dataset_paths = [dataset_path1, dataset_path2]
dataset_path_eval = [val_path1]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
# Train on PV01_train + PV02_train + PV03_train; validation uses PV01_val only.
dataset_paths = [dataset_path1, dataset_path2, dataset_path3]
dataset_path_eval = [val_path1]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
# Train on PV01_train + PV02_train + PV03_train + PV08_train; validation uses PV01_val only.
dataset_paths = [dataset_path1, dataset_path2, dataset_path3, dataset_path4]
dataset_path_eval = [val_path1]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
# Train on the PV01, PV02, PV03, PV08 and PV16 training sets; validation uses PV01_val only.
dataset_paths = [dataset_path1, dataset_path2, dataset_path3, dataset_path4, dataset_path5]
dataset_path_eval = [val_path1]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
# Train on all six training sets (PV01 ... PV32); validation uses PV01_val only.
dataset_paths = [
    dataset_path1,
    dataset_path2,
    dataset_path3,
    dataset_path4,
    dataset_path5,
    dataset_path6,
]
dataset_path_eval = [val_path1]

dataset = SegmentationDataset(root_dirs=dataset_paths)
dataset_eval = SegmentationDataset(root_dirs=dataset_path_eval)

# These loaders are not passed to the training call below, which receives the datasets.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

model = load_model()
train_segmentation_network(
    model=model,
    dataset=dataset,
    dataset_eval=dataset_eval,
    num_epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_path=save_path,
    log_dir=log_dir,
    dataset_paths=dataset_paths,
)

In [None]:
#eval(model,dataloader_eval)

In [None]:
#from ptflops import get_model_complexity_info
# Get a sample input from the dataloader
#sample_input, _ = next(iter(dataloader))
#input_size = tuple(sample_input.shape[1:])  # Get the input size

# Compute the FLOPs of your model
#flops, params = get_model_complexity_info(model, input_size)
#gflops = flops / 1e9  # Convert FLOPs to GFLOPs

# Log the GFLOPs
#print(f"GFLOPs: {gflops} billion")

In [None]:
# Example usage

# Load the model
#model = torch.load(save_path)

# Set the model to the device (CPU or GPU)
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model = model.to(device)

#test_segmentation_network(model, dataloader)