**Background**

This notebook presents a comprehensive framework for semantic segmentation tasks, with a focus on fine-tuning models on the ISPRS dataset. The core structure is organized into *functional and class-based modules* for data management and model training, followed by the main training script in the final section. You may *skip the first two sections and go directly to the main training script to understand the overall training process*. If you need further details, feel free to refer back to the corresponding functions or classes as needed.

**Overview of Code Structure**
1. Data Management: Reads image files and their corresponding labels using `read_isprs_images`. Converts RGB labels to class IDs and performs resizing/normalization using the `ISPRSDataset` class.

2. Model Training Workflow:
Supports both DeepLabV3-ResNet101 and FCN-ResNet101 architectures for fine-tuning. All pretrained layers are frozen, and only the classifier and auxiliary classifier heads are left trainable (`initialize_fcn_resnet101_model` and `initialize_deeplabv3_resnet101_model`). Training and validation for each epoch are handled by `train_one_epoch` and `test_model`, respectively. The `train_val` function runs one hyperparameter configuration with early stopping and logs the results to TensorBoard. Because the classes in the ISPRS dataset are imbalanced, this project reports intersection over union (IoU) rather than pixel accuracy.

3. Main Training:
Integrates the data management and training modules, executing a full training pipeline with hyperparameter tuning. After training, prediction results on the test set are visualized using the best-performing model configuration.
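
To make the choice of IoU over pixel accuracy concrete, here is a minimal sketch on a made-up 10x10 label map (the masks and the helper `class_iou` are illustrative assumptions, not part of the notebook's pipeline). A prediction that ignores the minority class entirely still scores high on accuracy, while its mean IoU drops sharply.

```python
import numpy as np

# Hypothetical 10x10 label map: 95 pixels of class 0 and 5 pixels of class 1.
true_mask = np.zeros((10, 10), dtype=np.int64)
true_mask[0, :5] = 1

# A degenerate prediction that assigns every pixel to the majority class.
pred_mask = np.zeros_like(true_mask)

# Fraction of pixels whose predicted class matches the label.
pixel_accuracy = float((pred_mask == true_mask).mean())


def class_iou(pred, true, class_id):
    """IoU of a single class; NaN when the class is absent from both masks."""
    pred_c = pred == class_id
    true_c = true == class_id
    union = np.logical_or(pred_c, true_c).sum()
    if union == 0:
        return float("nan")
    return float(np.logical_and(pred_c, true_c).sum() / union)


iou_per_class = [class_iou(pred_mask, true_mask, c) for c in (0, 1)]
mean_iou = float(np.nanmean(iou_per_class))

print(f"pixel accuracy: {pixel_accuracy:.3f}")
print(f"IoU per class:  {iou_per_class}")
print(f"mean IoU:       {mean_iou:.3f}")
```

The minority class contributes an IoU of zero, which halves the mean even though only 5% of the pixels are wrong.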

**Result**



In [None]:
!pip install rasterio

In [None]:
# data preparation
import warnings
import rasterio
from rasterio.errors import NotGeoreferencedWarning
import kagglehub
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

warnings.filterwarnings("ignore", category=NotGeoreferencedWarning)

# data augmentation
import cv2
from torchvision import transforms
import torch
from torch.utils.data import Dataset, DataLoader

# model
from torchvision.models.segmentation import fcn_resnet101, FCN_ResNet101_Weights
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import itertools
import random
import torch.optim.lr_scheduler as lr_scheduler


# 1. Data Management

## 1.1 Loading Images

In [None]:
def read_isprs_images(image_path, label_path):
  img_file_list = os.listdir(image_path)
  lab_file_list = os.listdir(label_path)

  img_file_list = [image_path+'/'+i for i in img_file_list]
  img_file_list.sort()
  lab_file_list = [label_path+'/'+i for i in lab_file_list]
  lab_file_list.sort()

  img, lab = [], []
  for img_file, lab_file in zip(img_file_list, lab_file_list):
    with rasterio.open(img_file) as src:
      img.append(torch.tensor(src.read(), dtype=torch.float32))

    with rasterio.open(lab_file) as src:
      lab.append(torch.tensor(src.read(), dtype=torch.long))

  return img, lab

In [None]:
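class ISPRSDataset(torch.utils.data.Dataset):
  """For loading the ISPRS dataset"""
  def __init__(self, images, labels):
    # ImageNet channel statistics. They assume inputs scaled to [0, 1], whereas
    # read_isprs_images returns the raw raster values, so consider rescaling
    # beforehand if the rasters are stored as 8-bit.
    self.img_transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    # Nearest-neighbour resizing keeps every mask value a valid class ID
    self.mask_transform = transforms.Compose([
          transforms.Resize((300, 300), interpolation=transforms.InterpolationMode.NEAREST)])

    self.images = images
    self.labels = labels

  def __getitem__(self, idx):
    image = self.images[idx]
    label = self.labels[idx]

    # Expand single-band images to three channels first: Normalize expects (C, H, W)
    if image.ndim == 2:
      image = image.unsqueeze(0).repeat(3, 1, 1)
    if self.img_transform:
      image = self.img_transform(image)

    if self.mask_transform:
      # Round-trip through PIL so the (H, W) class-ID mask can be resized
      label_pil_for_resize = Image.fromarray(label.cpu().numpy().astype(np.uint8), mode='L')
      label_pil_resized = self.mask_transform(label_pil_for_resize)
      label = torch.from_numpy(np.array(label_pil_resized)).long()
    return image, label.long()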

  def __len__(self):
    return len(self.images)

In [None]:
def voc_colormap2label():
  """Build a lookup table from 24-bit RGB values to ISPRS class indices"""
  colormap2label = torch.zeros(256 ** 3, dtype=torch.long)
  for i, colormap in enumerate(ISPRS_COLORMAP):
      colormap2label[
          (colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i
  return colormap2label

# `colormap` is the RGB label image (3, H, W); each pixel is converted to its class index
def voc_label_indices(colormap, colormap2label):
  """Map RGB values in ISPRS label images to their class indices"""
  colormap = colormap.permute(1, 2, 0).numpy().astype('int32')
  idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
          + colormap[:, :, 2])
  return colormap2label[idx]
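
The lookup above packs each RGB triple into one integer, `(R * 256 + G) * 256 + B`, and uses it to index a table with one slot per 24-bit colour. The sketch below reproduces that idea on a hand-made 2x2 label. It uses NumPy instead of torch and repeats the palette locally so that it runs on its own; the names `PALETTE`, `lookup` and `rgb` are illustrative only.

```python
import numpy as np

# Same six colours as ISPRS_COLORMAP in this notebook, repeated here so the
# example is self-contained.
PALETTE = [[255, 255, 255], [0, 0, 255], [0, 255, 255],
           [0, 255, 0], [255, 255, 0], [255, 0, 0]]

# One slot per possible 24-bit colour; unlisted colours keep the default 0.
lookup = np.zeros(256 ** 3, dtype=np.int64)
for class_id, (r, g, b) in enumerate(PALETTE):
    lookup[(r * 256 + g) * 256 + b] = class_id

# A 2x2 RGB label in (H, W, 3) layout: white, blue / green, red.
rgb = np.array([[[255, 255, 255], [0, 0, 255]],
                [[0, 255, 0], [255, 0, 0]]], dtype=np.int32)

# Pack each pixel into a single integer key, then look up its class index.
keys = (rgb[..., 0] * 256 + rgb[..., 1]) * 256 + rgb[..., 2]
class_ids = lookup[keys]
print(class_ids)
```

Because the table is initialised with zeros, any colour that is not in the palette silently maps to class 0, which is worth keeping in mind if the label rasters contain compression artefacts or boundary colours.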

In [None]:
def display_image_and_label(images_list, labels_list, index_to_display=0, is_class_id_label=False):
  """
  Displays an image and its corresponding label mask from lists of PyTorch tensors.
  """

  # Fall back to the first sample when no index is given
  if index_to_display is None:
    index_to_display = 0

  if not (0 <= index_to_display < len(images_list)):
    print(f"Error: Index {index_to_display} is out of bounds for images_list (size {len(images_list)}).")
    return
  if not (0 <= index_to_display < len(labels_list)):
    print(f"Error: Index {index_to_display} is out of bounds for labels_list (size {len(labels_list)}).")
    return

  img_tensor = images_list[index_to_display]
  label_tensor = labels_list[index_to_display]

  img_np = img_tensor.numpy()
  img_np = np.transpose(img_np, (1, 2, 0))

  if not is_class_id_label:
    label_indices_np = voc_label_indices(label_tensor.cpu(), voc_colormap2label()).numpy()
    label_np_processed = label_indices_np
  else:
    label_np_processed = label_tensor.numpy()

  def normalize_to_0_1(img_np):
      min_val = np.min(img_np)
      max_val = np.max(img_np)
      if max_val > min_val:
          img_np_scaled = (img_np - min_val) / (max_val - min_val)
      else:
          img_np_scaled = np.zeros_like(img_np)

      return img_np_scaled

  img_np_normalized = normalize_to_0_1(img_np)
  def mask_to_rgb(mask, colormap):
      """Converts a class index mask (H, W) to an RGB image (H, W, 3) using a colormap."""
      if mask.ndim != 2:
          raise ValueError("Input mask to mask_to_rgb must be 2D (H, W)")

      h, w = mask.shape
      output_rgb = np.zeros((h, w, 3), dtype=np.uint8)
      # Paint each class ID with its colour from the colormap argument
      for class_id, color in enumerate(colormap):
          output_rgb[mask == class_id] = color
      return output_rgb

  label_display_rgb = mask_to_rgb(label_np_processed, ISPRS_COLORMAP)

  plt.figure(figsize=(12, 6))
  plt.subplot(1, 2, 1)
  plt.imshow(img_np_normalized)
  plt.title(f"Image (Index: {index_to_display})")
  plt.axis('off')

  plt.subplot(1, 2, 2)
  plt.imshow(label_display_rgb)
  plt.title(f"Label (Index: {index_to_display})")
  plt.axis('off')

  plt.tight_layout()
  plt.show()

# 2. Model Training Workflow

## 2.1 Model Architecture: FCN_ResNet101

In [None]:
def initialize_fcn_resnet101_model(num_output_classes, device):
  weights = FCN_ResNet101_Weights.DEFAULT
  model_fcn = fcn_resnet101(weights=weights, progress=False)

  # Replace the final 1x1 convolution so the head predicts num_output_classes
  in_features = model_fcn.classifier[4].in_channels
  model_fcn.classifier[4] = nn.Conv2d(in_features, num_output_classes, kernel_size=(1, 1), stride=(1, 1))

  if model_fcn.aux_classifier is not None:
    aux_in_features = model_fcn.aux_classifier[4].in_channels
    model_fcn.aux_classifier[4] = nn.Conv2d(aux_in_features, num_output_classes, kernel_size=(1, 1), stride=(1, 1))

  # Move to the device only after the new layers exist, so they are moved too
  model_fcn = model_fcn.to(device)

  # freeze all parameters
  for param in model_fcn.parameters():
    param.requires_grad = False

  # Unfreeze the parameters of the classifier head
  for param in model_fcn.classifier.parameters():
    param.requires_grad = True

  # Unfreeze the parameters of the auxiliary classifier
  if model_fcn.aux_classifier is not None:
    for param in model_fcn.aux_classifier.parameters():
        param.requires_grad = True

  return model_fcn

## 2.2 Model Architecture: DeepLabV3_ResNet101

In [None]:

def initialize_deeplabv3_resnet101_model(num_output_classes, device):
    """
    Initializes and configures the DeepLabV3_ResNet101 model for semantic segmentation.
    """
    weights = DeepLabV3_ResNet101_Weights.DEFAULT
    model = deeplabv3_resnet101(weights=weights, progress=False)

    # Replace the final 1x1 convolution so the head predicts num_output_classes
    in_features_main = model.classifier[4].in_channels
    model.classifier[4] = nn.Conv2d(in_features_main, num_output_classes, kernel_size=(1, 1), stride=(1, 1))

    if model.aux_classifier is not None:
        aux_in_features = model.aux_classifier[4].in_channels
        model.aux_classifier[4] = nn.Conv2d(aux_in_features, num_output_classes, kernel_size=(1, 1), stride=(1, 1))

    # Move to the device only after the new layers exist, so they are moved too
    model = model.to(device)

    for param in model.parameters():
        param.requires_grad = False

    for param in model.classifier.parameters():
        param.requires_grad = True

    if model.aux_classifier is not None:
        for param in model.aux_classifier.parameters():
            param.requires_grad = True
    return model
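
Both initializers follow the same freeze-then-unfreeze pattern. The sketch below shows it on a tiny stand-in model (the `toy` module and its layer sizes are assumptions made for illustration, not the torchvision networks), and also shows one optional refinement: handing only the trainable parameters to the optimizer.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in for a "backbone + head" network.
toy = nn.ModuleDict({
    "backbone": nn.Conv2d(3, 8, kernel_size=3, padding=1),
    "classifier": nn.Conv2d(8, 6, kernel_size=1),
})

# Step 1: freeze every parameter.
for param in toy.parameters():
    param.requires_grad = False

# Step 2: unfreeze only the classification head.
for param in toy["classifier"].parameters():
    param.requires_grad = True

# Count parameters to confirm that only the head will be updated.
trainable = [p for p in toy.parameters() if p.requires_grad]
n_trainable = sum(p.numel() for p in trainable)
n_total = sum(p.numel() for p in toy.parameters())
print(f"trainable parameters: {n_trainable} of {n_total}")

# Optional: give the optimizer only the trainable parameters.
optimizer = optim.Adam(trainable, lr=1e-3)
```

Passing `model.parameters()` as the notebook does also works, since frozen parameters never receive gradients; filtering simply makes the intent explicit.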

## 2.3 Training and Validation Logic

In [None]:
def train_one_epoch(model, train_loader, optimizer, loss_fn, device):
  """
  Run a full training cycle (epoch)
  """
  model.train()
  running_train_loss = 0.0

  for i, (features, labels) in enumerate(train_loader):
    features, labels = features.to(device), labels.to(device)
    optimizer.zero_grad()
    classes_preds = model(features)['out']
    training_loss = loss_fn(classes_preds, labels)
    training_loss.backward()
    optimizer.step()
    running_train_loss += training_loss.item()

  avg_train_loss = running_train_loss / len(train_loader)
  return model, avg_train_loss

In [None]:
def test_model(model, val_loader, loss_fn, device, num_classes, ignore_index):
  """
  Perform a full validation cycle (epoch)
  """
  model.eval()
  running_val_loss = 0.0
  running_val_iou = 0.0
  num_batches_with_valid_iou = 0

  with torch.no_grad():
    for j, (features, labels) in enumerate(val_loader):
      features, labels = features.to(device), labels.to(device)

      classes_preds = model(features)['out']
      val_loss = loss_fn(classes_preds, labels)
      predicted_labels = torch.argmax(classes_preds, dim=1)
      batch_iou, _ = calculate_iou_mask(
          predicted_labels.cpu().numpy(),
          labels.cpu().numpy(),
          num_classes=num_classes,
          ignore_index=ignore_index
      )

      running_val_loss += val_loss.item()
      if not np.isnan(batch_iou):
        running_val_iou += batch_iou
        num_batches_with_valid_iou += 1

  avg_val_loss = running_val_loss / len(val_loader)
  avg_val_iou = running_val_iou / num_batches_with_valid_iou if num_batches_with_valid_iou > 0 else np.nan

  return avg_val_loss, avg_val_iou

In [None]:
def train_val(model, hparams, writer):
  # Device configuration
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  print(f"Using device: {device}")

  # hyperparameters
  lr = hparams['learning_rate']
  bs = hparams['batch_size']
  opt_name = hparams['optimizer']
  gamma = hparams['gamma']

  model = model.to(device)
  loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

  if opt_name == 'Adam':
    optimizer = optim.Adam(model.parameters(), lr=lr)
  elif opt_name == 'SGD':
    optimizer = optim.SGD(model.parameters(), lr=lr)
  else:
    raise ValueError(f"Unknown optimizer: {opt_name}")
      
  scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
  
  # dataset
  train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True, worker_init_fn=worker_init_fn)
  val_loader = DataLoader(val_dataset, batch_size=bs, shuffle=False, worker_init_fn=worker_init_fn)

  # Early stopping
  best_val_loss = float('inf')
  patience_counter = 0
  patience = 2

  n_epoch = 20
  print(f"\nStarting training ...")
  for epoch in range(n_epoch):
    # --- Training Phase ---
    model, avg_train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    writer.add_scalar('Loss/Train', avg_train_loss, epoch)

    # --- Validation Phase ---
    avg_val_loss,avg_val_iou = test_model(model, val_loader, loss_fn,
                                                device, num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX)

    writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
    writer.add_scalar('Metrics/Validation_IoU', avg_val_iou, epoch)

    print(f"Epoch {epoch+1}/{n_epoch} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val IoU: {avg_val_iou:.4f}")
      
    # --- early stop ---
    if avg_val_loss < best_val_loss:
      best_val_loss = avg_val_loss
      patience_counter = 0
    else:
      patience_counter += 1
      if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch+1}")
        break
  
    # Decay the learning rate once per epoch
    scheduler.step()

  final_loss = avg_val_loss
  final_iou = avg_val_iou
  print("\nTraining complete.")
  return model, final_loss, avg_val_iou
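
`train_val` stops after `patience` epochs without improvement and returns the metrics of the last epoch it ran, which is not necessarily the best one. One possible variation, sketched below under assumed names (`EarlyStopper` and the loss values are hypothetical), keeps a copy of the best state so it can be restored afterwards. A plain dictionary stands in for `model.state_dict()` to keep the example independent of torch.

```python
import copy


class EarlyStopper:
    """Tracks the best validation loss and keeps a copy of the matching state."""

    def __init__(self, patience=2):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_state = None
        self.bad_epochs = 0

    def step(self, val_loss, state):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(state)
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Hypothetical validation losses, one per epoch.
losses = [0.90, 0.70, 0.75, 0.72, 0.60]

stopper = EarlyStopper(patience=2)
stopped_at = None
for epoch, loss in enumerate(losses):
    state = {"epoch": epoch}  # stands in for model.state_dict()
    if stopper.step(loss, state):
        stopped_at = epoch
        break

print(f"stopped at epoch index {stopped_at}, best loss {stopper.best_loss}")
print(f"state to restore: {stopper.best_state}")
```

With real models the restore step would be `model.load_state_dict(stopper.best_state)`. Note that a patience of 2 stops before the later improvement to 0.60 in this made-up sequence, which illustrates the usual trade-off between training time and the risk of stopping too early.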

In [None]:
def calculate_iou_mask(pred_mask, true_mask, num_classes, ignore_index):
  """
  Calculates Intersection over Union (IoU) for segmentation masks.
  Can handle binary or multi-class masks.
  """
  if pred_mask.shape != true_mask.shape:
    raise ValueError("Predicted mask and true mask must have the same shape.")

  # Ensure masks are integer types
  pred_mask = pred_mask.astype(np.int64)
  true_mask = true_mask.astype(np.int64)

  # Flatten masks for easier comparison
  pred_flat = pred_mask.flatten()
  true_flat = true_mask.flatten()

  if num_classes is None:
    all_classes = np.unique(np.concatenate((pred_flat, true_flat)))
    if ignore_index is not None:
        all_classes = all_classes[all_classes != ignore_index]
    num_classes = int(np.max(all_classes)) + 1 if len(all_classes) > 0 else 0
    if num_classes == 0:
        # Keep the (mean_iou, iou_per_class) return shape used at the end of the function
        return (0.0 if ignore_index is None else np.nan), {}

  iou_per_class = {}
  total_iou = 0.0
  valid_classes_count = 0

  for class_id in range(num_classes):
    if class_id == ignore_index:
      continue

    # Create binary masks for the current class
    pred_binary = (pred_flat == class_id)
    true_binary = (true_flat == class_id)

    intersection = np.sum(pred_binary & true_binary)
    union = np.sum(pred_binary | true_binary)

    if union == 0:
      iou_score = np.nan
    else:
      iou_score = intersection / union

    iou_per_class[class_id] = iou_score

    if not np.isnan(iou_score):
      total_iou += iou_score
      valid_classes_count += 1

  mean_iou = total_iou / valid_classes_count if valid_classes_count > 0 else np.nan

  return mean_iou, iou_per_class
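
`test_model` averages the mean IoU of each batch. A common alternative is to accumulate one confusion matrix over all batches and compute IoU once at the end, so that rare classes are not affected by which batch they happen to fall in. The sketch below illustrates that alternative with hypothetical helper names (`confusion_matrix`, `iou_from_confusion`) and two made-up batches; it is not the method used in this notebook. Unlike `calculate_iou_mask`, it also drops pixels whose label equals `ignore_index`.

```python
import numpy as np


def confusion_matrix(pred, true, num_classes, ignore_index=255):
    """Rows are true classes, columns are predicted classes."""
    keep = true != ignore_index
    flat = true[keep] * num_classes + pred[keep]
    counts = np.bincount(flat, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def iou_from_confusion(cm):
    """Per-class IoU; NaN for classes that never occur."""
    intersection = np.diag(cm).astype(np.float64)
    union = cm.sum(axis=0) + cm.sum(axis=1) - np.diag(cm)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(union > 0, intersection / union, np.nan)


# Two made-up batches of flattened masks with two classes.
batches = [
    (np.array([0, 1, 1, 1]), np.array([0, 0, 1, 1])),    # (pred, true)
    (np.array([0, 0, 0, 1]), np.array([0, 0, 0, 255])),  # last pixel ignored
]

cm = np.zeros((2, 2), dtype=np.int64)
for pred, true in batches:
    cm += confusion_matrix(pred, true, num_classes=2)

iou = iou_from_confusion(cm)
print(cm)
print(iou, np.nanmean(iou))
```

Predictions are assumed to lie in `range(num_classes)`, which holds for the output of `torch.argmax` over the class dimension.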

# 3. Main Training Script

## 3.1 Data Management

In [None]:
ISPRS_COLORMAP = [[255, 255, 255], [0, 0, 255], [0, 255, 255], [0, 255, 0],
                [255, 255, 0], [255, 0, 0]]

ISPRS_CLASSES = ['Impervious Surfaces', 'Building', 'Low Vegetation', 'Tree',
               'Car','Clutter/Background']

In [None]:
# 0. Device configuration: fall back to the CPU when no GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'Use: {device}')

# 1. loading data
path = kagglehub.dataset_download("jahidhasan66/isprs-potsdam")
path = path+'/patches'
print("Path to dataset files:", path)
image_path = path+'/Images'
label_path = path+'/Labels'
images, labels = read_isprs_images(image_path, label_path)

colormap2label = voc_colormap2label()
labels_class_ids = [voc_label_indices(label, colormap2label) for label in labels]
print(f"Converted {len(labels)} RGB labels to class ID labels (H, W).")

# 2. split data
train_ratio = 0.70
val_ratio = 0.15
test_ratio = 0.15

total_samples = min(len(images), len(labels_class_ids))
train_count = int(total_samples * train_ratio)
val_count = int(total_samples * val_ratio)
test_count = total_samples - train_count - val_count

train_images = images[:train_count]
train_labels = labels_class_ids[:train_count]

val_images = images[train_count : train_count + val_count]
val_labels = labels_class_ids[train_count : train_count + val_count]

test_images = images[train_count + val_count : total_samples]
test_labels = labels_class_ids[train_count + val_count : total_samples]

print(f"Final train_images count: {len(train_images)}, train_labels count: {len(train_labels)}")
print(f"Final val_images count: {len(val_images)}, val_labels count: {len(val_labels)}")
print(f"Final test_images count: {len(test_images)}, test_labels count: {len(test_labels)}")
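
The split above slices the sorted file lists in order. If a random assignment is preferred, a seeded shuffle of the indices is one option, sketched below with a hypothetical helper `split_indices`. Whether shuffling is actually preferable depends on how the patches were cut: placing neighbouring patches of the same tile in different splits can make validation scores look better than they should.

```python
import random


def split_indices(n_samples, train_ratio=0.70, val_ratio=0.15, seed=77):
    """Return shuffled (train, val, test) index lists; test takes the remainder."""
    indices = list(range(n_samples))
    # A dedicated Random instance leaves the global random state untouched.
    random.Random(seed).shuffle(indices)
    n_train = int(n_samples * train_ratio)
    n_val = int(n_samples * val_ratio)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


train_idx, val_idx, test_idx = split_indices(100)
print(len(train_idx), len(val_idx), len(test_idx))
```

The index lists can then be used to pick matching entries from the image and label lists, for example `[images[i] for i in train_idx]`.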


In [None]:
# custom datasets
train_dataset = ISPRSDataset(train_images, train_labels)
val_dataset = ISPRSDataset(val_images, val_labels)
test_dataset = ISPRSDataset(test_images, test_labels)

batch_size = 12
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# # of data, # of batches, check input and output shape
print(f"Train images length: {len(train_images)}, Train labels length: {len(train_labels)}")
print(f"Validation images length: {len(val_images)}, Validation labels length: {len(val_labels)}")
print(f"Test images length: {len(test_images)}, Test labels length: {len(test_labels)}")

print(f"Number of batches in train_loader: {len(train_loader)}")

for X, Y in train_loader:
  print(X.shape)
  print(Y.shape)
  break

# Visualize an image and its mask
display_example_index = 337
print('Visualizing an example')
display_image_and_label(images, labels, display_example_index, is_class_id_label=False)

## 3.2 Model Training

In [None]:
# reproducible
SEED = 77
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(SEED)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False

def worker_init_fn(worker_id):
  seed = SEED + worker_id
  torch.manual_seed(seed)
  np.random.seed(seed)
  random.seed(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# file
base_log_dir = "runs/hparam_search_results"
os.makedirs(base_log_dir, exist_ok=True)

# no. of classes, ignore_index
NUM_CLASSES = len(ISPRS_CLASSES)
IGNORE_INDEX = 255

# hyperparameters
learning_rates = [0.005,0.001]
batch_sizes = [32, 64]
optimizers = ['Adam']
gammas = [0.90]

# Store the initializer functions rather than model instances, so that every
# run starts from freshly loaded pretrained weights instead of continuing
# from the previous run.
MODELS_TO_TEST = [
    ("FCN_ResNet101", initialize_fcn_resnet101_model),
    ("DeepLabV3_ResNet101", initialize_deeplabv3_resnet101_model)
]

HPARAM_COMBINATIONS = list(itertools.product(learning_rates, batch_sizes, optimizers, gammas))

# training
print("\nStarting hyperparameter search across multiple models...")

for model_name, model_init_func in MODELS_TO_TEST:
    print(f"\n--- Starting Hyperparameter Search for Model: {model_name} ---")
    for i, (lr, bs, opt_name, gamma) in enumerate(HPARAM_COMBINATIONS):
      hparams = {
          'learning_rate': lr,
          'batch_size': bs,
          'optimizer': opt_name,
          'model_name': model_name,
          'gamma': gamma  # the value for this run, not the whole list
      }

      run_name = f"{model_name}_run_{i}_lr{lr}_bs{bs}_opt{opt_name}"
      log_path = os.path.join(base_log_dir, run_name)
      writer = SummaryWriter(log_dir=log_path)

      print(f"\n--- Run {i+1}/{len(HPARAM_COMBINATIONS)}: {hparams} ---")
      current_model = model_init_func(NUM_CLASSES, device)
      _, final_loss, final_iou = train_val(current_model, hparams, writer)
    
      writer.add_hparams(
          hparam_dict=hparams,
          metric_dict={
              'hparam/final_loss': final_loss,
              'hparam/final_iou': final_iou
          }
      )
    
      writer.close()
      print(f"Experiment {i+1} for {model_name} completed. Final Validation Loss: {final_loss:.4f}, Final Validation IoU: {final_iou:.4f}")

print(f"\nAll results have been logged to: {base_log_dir}")


%load_ext tensorboard
%tensorboard --logdir runs/hparam_search_results


In [None]:
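# best model: retrain with the selected configuration
best_hparams = {'learning_rate': 0.001, 'batch_size': 32, 'optimizer': 'Adam', 'model_name': 'DeepLabV3_ResNet101', 'gamma': 0.9}

# Use a dedicated writer; the one from the search loop has already been closed
best_writer = SummaryWriter(log_dir=os.path.join(base_log_dir, "best_model"))
best_model = initialize_deeplabv3_resnet101_model(NUM_CLASSES, device)
model, best_final_loss, best_final_iou = train_val(best_model, best_hparams, best_writer)
best_writer.close()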

## 3.3 Evaluate Model Performance

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
avg_test_loss, avg_test_iou = test_model(model, test_loader, loss_fn,
                                         device, num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX)

print(f"Average Test Loss: {avg_test_loss:.4f}")
print(f"Mean Test IoU: {avg_test_iou:.4f}")

## 3.4 Visualize Prediction Results

In [None]:

num_visualizations_per_class = 3
test_indices = list(range(len(test_dataset)))
random_test_indices = random.sample(test_indices, num_visualizations_per_class)

model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def mask_to_rgb(mask, colormap):
  """Converts a class index mask (H, W) to an RGB image (H, W, 3) using a colormap."""
  h, w = mask.shape
  output_rgb = np.zeros((h, w, 3), dtype=np.uint8)
  for class_id, color in enumerate(colormap):
      output_rgb[mask == class_id] = color
  return output_rgb

plt.figure(figsize=(15, num_visualizations_per_class * 5))

with torch.no_grad():
  for i, idx in enumerate(random_test_indices):
    image, label = test_dataset[idx]
    image = image.unsqueeze(0).to(device)
    prediction = model(image)['out']
    predicted_mask = torch.argmax(prediction, dim=1).squeeze(0)
    image = image.squeeze(0).cpu()
    label = label.cpu()
    predicted_mask = predicted_mask.cpu()

    ground_truth_rgb = mask_to_rgb(label.numpy(), ISPRS_COLORMAP)
    predicted_rgb = mask_to_rgb(predicted_mask.numpy(), ISPRS_COLORMAP)

    image_display = image.permute(1, 2, 0).numpy()
    if image_display.max() > 255 or image_display.min() < 0:
          image_display = (image_display - image_display.min()) / (image_display.max() - image_display.min()) * 255
    image_display = image_display.astype(np.uint8)

    plt.subplot(num_visualizations_per_class, 3, i * 3 + 1)
    plt.imshow(image_display)
    plt.title(f"Original Image {idx}")
    plt.axis('off')

    plt.subplot(num_visualizations_per_class, 3, i * 3 + 2)
    plt.imshow(ground_truth_rgb)
    plt.title(f"Ground Truth {idx}")
    plt.axis('off')

    plt.subplot(num_visualizations_per_class, 3, i * 3 + 3)
    plt.imshow(predicted_rgb)
    plt.title(f"Predicted {idx}")
    plt.axis('off')

plt.tight_layout()
plt.show()