### Imports

In [12]:
import sys
sys.path.insert(0, '../')
from preprocessing.preprocess_data import get_preprocessed_data, get_google_maps_data, get_massa_data, create_preprocessed_dataset, get_deepglobe_data
import numpy as np
import torch
import torch.optim as optim
from segmentation_models_pytorch.utils.metrics import IoU, Precision, Recall, Fscore
from segmentation_models_pytorch.utils.losses import DiceLoss
from torch.utils.data import DataLoader, TensorDataset
import segmentation_models_pytorch as smp
import os
import matplotlib.pyplot as plt

### Data loading

In [5]:
N_FILES = 500

deepglobe_path = '/workspace/Road-Segmentation-Comp/data/deepglobe/train/'

print('-' * 100)
print('Getting DeepGlobe data')
deep_train_x, deep_train_y, deep_val_x, deep_val_y = get_deepglobe_data(deepglobe_path, N_FILES)
print(deep_train_x.shape, deep_train_y.shape, deep_val_x.shape, deep_val_y.shape)


train_path = "/workspace/Road-Segmentation-Comp/data/ethz-cil-road-segmentation-2024/training/"
print('-' * 100)
print('Getting Project data')
(train_x, train_y), (val_x, val_y), (ori_val_x, ori_val_y) = get_preprocessed_data(path=train_path)
print(train_x.shape, train_y.shape, val_x.shape, val_y.shape)

eliot_path = "/workspace/Road-Segmentation-Comp/data/eliot_dataset/"

print('-' * 100)
print('Getting Eliot data')
eliot_train_x, eliot_train_y, eliot_val_x, eliot_val_y = get_google_maps_data(eliot_path, N_FILES)
print(eliot_train_x.shape, eliot_train_y.shape, eliot_val_x.shape, eliot_val_y.shape)


massachusetts_path = "/workspace/Road-Segmentation-Comp/data/massa/tiff/"

print('-' * 100)
print('Getting Massachussets data')
massa_train_x, massa_train_y, massa_val_x, massa_val_y = get_massa_data(massachusetts_path, N_FILES)
print(massa_train_x.shape, massa_train_y.shape, massa_val_x.shape, massa_val_y.shape)

# Concatenate datasets

print('-' * 30)
print('Concatenating datasets')
X_train = np.concatenate([train_x, deep_train_x, eliot_train_x, massa_train_x], axis=0)
Y_train = np.round(np.concatenate([train_y, deep_train_y, eliot_train_y, massa_train_y], axis=0))
X_val = np.concatenate([val_x, deep_val_x, eliot_val_x, massa_val_x], axis=0)
Y_val = np.round(np.concatenate([val_y, deep_val_y, eliot_val_y, massa_val_y], axis=0))

print(X_train.shape, Y_train.shape, X_val.shape, Y_val.shape)

----------------------------------------------------------------------------------------------------
Getting DeepGlobe data
Listing files
Filtering files
Choosing N_FILES
Getting masks
Preprocessing files
Opening cropped images...
Opening resized images...
Preprocessing images...
Splitting data into training and validation sets...
(400, 224, 224, 3) (400, 224, 224, 1) (100, 224, 224, 3) (100, 224, 224, 1)
----------------------------------------------------------------------------------------------------
Getting Project data
CHECK
(575, 224, 224, 3) (575, 224, 224, 1) (145, 224, 224, 3) (145, 224, 224, 1)
----------------------------------------------------------------------------------------------------
Getting Eliot data
Listing files
Choosing N_FILES
Getting masks
Preprocessing files
Opening cropped images...
Opening resized images...
Preprocessing images...
Splitting data into training and validation sets...
(400, 224, 224, 3) (400, 224, 224, 1) (100, 224, 224, 3) (100, 224, 224, 1

In [7]:
# Convert data to PyTorch tensors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)
Y_train_tensor = torch.tensor(Y_train, dtype=torch.float32).permute(0, 3, 1, 2)
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).permute(0, 3, 1, 2)
Y_val_tensor = torch.tensor(Y_val, dtype=torch.float32).permute(0, 3, 1, 2)

# Create TensorDataset and DataLoader
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, Y_val_tensor)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

### Model

#### Setup

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the model
pspnet = smp.PSPNet(
    encoder_name="resnet152", 
    encoder_weights="imagenet", 
    classes=1, 
    activation='sigmoid'
).to(device)

In [9]:
loss = DiceLoss()

metrics = [
    IoU(threshold=0.5),
    Precision(threshold=0.5),
    Recall(threshold=0.5),
    Fscore(threshold=0.5),
]

optimizer = optim.Adam(pspnet.parameters(), lr=0.0008, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=5, T_mult=1, eta_min=1e-6,
)

train_epoch = smp.utils.train.TrainEpoch(
    pspnet, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=device,
    verbose=True
)

valid_epoch = smp.utils.train.ValidEpoch(
    pspnet, 
    loss=loss, 
    metrics=metrics, 
    device=device,
    verbose=True,
)

In [10]:
def save_validation_prediction_plot(model, val_loader, device, epoch, output_dir='./plots'):
    """
    Save validation prediction plot.
    
    Args:
        model (torch.nn.Module): The trained model.
        val_loader (torch.utils.data.DataLoader): The validation data loader.
        device (torch.device): The device to run the model on.
        epoch (int): The current epoch number.
        output_dir (str, optional): The directory to save the plot. Defaults to './plots'.
    """
    
    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    model.eval()
    with torch.no_grad():
        # Plot the images
        fig, ax = plt.subplots(5, 3, figsize=(15, 25))

        i = 0
        for val_data in val_loader:
            val_images, val_masks = val_data
            val_images = val_images.to(device)
            val_masks = val_masks.to(device)

            outputs = model(val_images)
            outputs = outputs.cpu().numpy()
            val_images = val_images.cpu().numpy()
            val_masks = val_masks.cpu().numpy()

            # Select the first image in the batch for visualization
            img = val_images[0].transpose(1, 2, 0)
            true_mask = val_masks[0].transpose(1, 2, 0)
            pred_mask = outputs[0].transpose(1, 2, 0)

            ax[i, 0].imshow(img)
            ax[i, 0].set_title("Input Image")
            ax[i, 0].axis('off')
            ax[i, 1].imshow(true_mask.squeeze(), cmap='gray')
            ax[i, 1].set_title("True Mask")
            ax[i, 1].axis('off')
            ax[i, 2].imshow(pred_mask.squeeze(), cmap='gray')
            ax[i, 2].set_title("Predicted Mask")
            ax[i, 2].axis('off')
            # Only show/save the first batch
            i += 1
            if i == 5:
                break
        # Save the plot
        plot_path = os.path.join(output_dir, f'epoch_{epoch + 1}.png')
        plt.savefig(plot_path)
        plt.close()

#### Training

In [13]:
EPOCHS = 50

best_iou_score = 0.0
train_logs_list, valid_logs_list = [], []
for i in range(0, EPOCHS):
    # Perform training & validation
    print('\nEpoch: {}'.format(i+1))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(val_loader)
    train_logs_list.append(train_logs)
    valid_logs_list.append(valid_logs)
    # Save model if a better val F1 score is obtained
    if best_iou_score < valid_logs['fscore']:
        best_iou_score = valid_logs['fscore']
        torch.save(pspnet, './model_params/pspnet_curr.pth')
        print('Model saved!')
    
    # Save a validation prediction to evaluate performance at current epoch
    save_validation_prediction_plot(pspnet, val_loader, device, i, output_dir='./predictions/pspnet')


Epoch: 1
train: 100%|██████████| 111/111 [00:14<00:00,  7.92it/s, dice_loss - 0.4753, iou_score - 0.3605, precision - 0.4708, recall - 0.6092, fscore - 0.5276]
valid: 100%|██████████| 28/28 [00:02<00:00,  9.62it/s, dice_loss - 0.4752, iou_score - 0.3646, precision - 0.5335, recall - 0.5504, fscore - 0.5277]
Model saved!

Epoch: 2
train: 100%|██████████| 111/111 [00:13<00:00,  8.43it/s, dice_loss - 0.4596, iou_score - 0.375, precision - 0.4865, recall - 0.6247, fscore - 0.5426] 
valid: 100%|██████████| 28/28 [00:02<00:00, 10.36it/s, dice_loss - 0.5427, iou_score - 0.3136, precision - 0.5513, recall - 0.4684, fscore - 0.4584]

Epoch: 3
train: 100%|██████████| 111/111 [00:12<00:00,  8.74it/s, dice_loss - 0.4433, iou_score - 0.3901, precision - 0.5108, recall - 0.6249, fscore - 0.5589]
valid: 100%|██████████| 28/28 [00:02<00:00, 10.53it/s, dice_loss - 0.4896, iou_score - 0.3506, precision - 0.5335, recall - 0.522, fscore - 0.5121] 

Epoch: 4
train: 100%|██████████| 111/111 [00:13<00:00,  

: 