In [32]:
import numpy as np
import os
import matplotlib.pyplot as plt
import random

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
#from transformers import DeepLabV3Processor, DeepLabV3ForSemanticSegmentation

from PIL import Image
import evaluate
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from tqdm.notebook import tqdm
import wandb
import segmentation_models_pytorch as smp


In [33]:
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(self, root_dir, transform=None, train=True):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            image_processor (SegFormerImageProcessor): image processor to prepare images + segmentation maps.
            train (bool): Whether to load "training" or "validation" images + annotations.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.train = train

        sub_path = "training" if self.train else "validation"
        self.img_dir = os.path.join(self.root_dir, sub_path, "images")
        self.ann_dir = os.path.join(self.root_dir, sub_path, "groundtruth")
        self.img_dir = os.path.normpath(self.img_dir)
        self.ann_dir = os.path.normpath(self.ann_dir)
        # read images
        image_file_names = []
        for root, dirs, files in os.walk(self.img_dir):
          if os.path.basename(root) != "images":
            files = [os.path.basename(root) + "/" + f for f in files]
          image_file_names.extend(files)
        self.images = sorted(image_file_names)
        # read annotations
        annotation_file_names = []
        for root, dirs, files in os.walk(self.ann_dir):
          if os.path.basename(root) != "groundtruth":
            files = [os.path.basename(root) + "/" + f for f in files]
          annotation_file_names.extend(files)
        self.annotations = sorted(annotation_file_names)

        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.img_dir, self.images[idx])).convert("RGB")
        segmentation_map = Image.open(os.path.join(self.ann_dir, self.annotations[idx]))
        # Convert to tensor
        image = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
        
        # Convert segmentation map to tensor, and ensure it's an integer type
        segmentation_map = torch.tensor(np.array(segmentation_map), dtype=torch.float32)

        # Apply transformation (if any)
        if self.transform:
            image, segmentation_map = self.transform(image, segmentation_map)

        return image, segmentation_map

In [34]:
# Set training parameters
batch_size = 4
epochs = 10

# Set seed for reproducibility
seed = 21
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.set_num_threads(1)  # Use a single thread for PyTorch
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["PYTHONHASHSEED"] = str(seed)  # Prevent hash-based randomness in Python

# Load dataset
root_dir = 'C:/Users/Qrnqult/Documents/GitHub/ML_AGA/data'


In [35]:
train_dataset = SemanticSegmentationDataset(root_dir=root_dir)
valid_dataset = SemanticSegmentationDataset(root_dir=root_dir, train=False)

print(train_dataset[0])
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(valid_dataset))

(tensor([[[ 82.,  80.,  80.,  ...,  79.,  78.,  78.],
         [ 81.,  80.,  79.,  ...,  80.,  80.,  80.],
         [ 81.,  81.,  80.,  ...,  80.,  80.,  80.],
         ...,
         [  8.,   9.,  11.,  ...,  94.,  94.,  92.],
         [  8.,   9.,  10.,  ..., 141., 143., 138.],
         [  7.,   9.,  11.,  ..., 161., 158., 145.]],

        [[ 76.,  75.,  74.,  ...,  78.,  78.,  79.],
         [ 75.,  74.,  73.,  ...,  80.,  80.,  81.],
         [ 75.,  75.,  74.,  ...,  80.,  80.,  81.],
         ...,
         [  8.,   9.,  11.,  ...,  92.,  92.,  91.],
         [  8.,   9.,  11.,  ..., 140., 141., 136.],
         [  8.,  10.,  12.,  ..., 159., 156., 143.]],

        [[ 68.,  66.,  65.,  ...,  74.,  73.,  74.],
         [ 67.,  65.,  64.,  ...,  75.,  75.,  76.],
         [ 67.,  67.,  66.,  ...,  75.,  75.,  76.],
         ...,
         [  5.,   7.,   9.,  ...,  82.,  83.,  81.],
         [  4.,   5.,   7.,  ..., 125., 128., 124.],
         [  3.,   6.,   7.,  ..., 144., 142., 130.]]

In [None]:
# Create dataloaders
train_dataloader = DataLoader(train_dataset, 
                              batch_size=batch_size, 
                              shuffle=True, 
                              worker_init_fn=lambda worker_id: np.random.seed(seed + worker_id),
                              generator=torch.Generator().manual_seed(seed))
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size)

# Define model
model = smp.DeepLabV3Plus(
    encoder_name="resnet34", 
    encoder_weights="imagenet", 
    classes=2,  # Number of classes (including background)
    activation="softmax"  # or 'sigmoid' for binary segmentation
)

# Load evaluation metric
metric = evaluate.load("mean_iou")

# Define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)

# Move model to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Initialize W&B run
wandb.login()
job_type = "train_model"
config = {
    "optimizer": "adam",
    "batch_size": batch_size,
    "epochs": epochs,
}
run = wandb.init(project="ml_p2", job_type=job_type, config=config)

DeepLabV3Plus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tr

In [None]:
best_f1 = 0.0

# Begin training
print("Starting training")
model.train()
for epoch in range(epochs):  # loop over the dataset multiple times
    print("Epoch:", epoch+1)
    model.train()
    for idx, batch in enumerate((train_dataloader)):
        # Get the inputs
        image, mask = batch
        image = image.to(device)
        mask = mask.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(image)
        one_part = outputs[:, 1, :, :]
        print(one_part.shape)
        print(mask.shape)
        loss = nn.CrossEntropyLoss()(one_part, mask)  # Use cross-entropy loss for segmentation

        loss.backward()
        optimizer.step()

        # Evaluate
        with torch.no_grad():
            predicted = outputs.argmax(dim=1)
            # Update metric
            metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=mask.detach().cpu().numpy())

        # Print loss and metrics every 100 batches
        if idx % 100 == 0:
            metrics = metric._compute(
              predictions=predicted.cpu(),
              references=mask.cpu(),
              num_labels=2,
              ignore_index=None,
              reduce_labels=False,
            )
            print(f"Loss: {loss.item()}")
            print("Mean IoU:", metrics["mean_iou"])
            print("Pixel-wise accuracy:", metrics["overall_accuracy"])
            wandb.log({"Epoch": epoch, "Train accuracy": metrics["overall_accuracy"], "Train loss": loss.item(), "Mean IoU": metrics["mean_iou"]})

    # Validation
    all_preds = []
    all_labels = []

    model.eval()  # Set model to evaluation mode
    validation_loss = 0.0
    with torch.no_grad():
        for val_batch in tqdm(valid_dataloader, desc="Validation"):
            # Get the inputs
            image, mask = val_batch
            print("Shape", image.shape)
            image = image.to(device)
            mask = mask.to(device)

            # Forward pass
            outputs = model(image)
            val_loss = nn.CrossEntropyLoss()(outputs, mask)
            validation_loss += val_loss.item()

            # Evaluate validation metrics
            predicted = outputs.argmax(dim=1)
            metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=mask.detach().cpu().numpy())

            # Collect predictions and labels for F1 score
            all_preds.append(predicted.cpu().numpy().flatten())
            all_labels.append(mask.cpu().numpy().flatten())

    # Compute overall metrics for validation
    val_metrics = metric._compute(
      predictions=predicted.cpu(),
      references=mask.cpu(),
      num_labels=2,
      ignore_index=None,
      reduce_labels=False,
    )
    avg_val_loss = validation_loss / len(valid_dataloader)

    # Compute pixel-wise scores for the entire validation set
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    pixelwise_f1_val = f1_score(all_labels, all_preds, average="binary")
    pixelwise_recall_val = recall_score(all_labels, all_preds)
    pixelwise_precision_val = precision_score(all_labels, all_preds)
    pixelwise_accuracy_val = accuracy_score(all_labels, all_preds)

    print(f"Validation Loss: {avg_val_loss:.4f} | Mean IOU: {val_metrics['mean_iou']:.4f} | Mean Accuracy: {val_metrics['mean_accuracy']:.4f}")
    print(f"Validation F1 Score: {pixelwise_f1_val:.4f}")
    print(f"Validation Recall: {pixelwise_recall_val:.4f}")
    print(f"Validation Precision: {pixelwise_precision_val:.4f}")
    print(f"Validation Accuracy: {pixelwise_accuracy_val:.4f}")

    wandb.log({"Epoch": epoch+1, "Val accuracy": pixelwise_accuracy_val, "Val loss": avg_val_loss, "Val F1": pixelwise_f1_val, "Val recall": pixelwise_recall_val, "Val precision": pixelwise_precision_val, "Mean accuracy": val_metrics['mean_accuracy'], "Mean IoU": val_metrics['mean_iou']})

    # if pixelwise_f1_val > best_f1:
    #     best_f1 = pixelwise_f1_val
    #     print("Saving best model")
    #     torch.save(model.state_dict(), f"./models/finetuned_deeplabv3plus_{epoch+1}.pth")  # Save model
    #     torch.save(optimizer.state_dict(), f"./models/finetuned_deeplabv3plus_optimizer_{epoch+1}.pth")

Starting training
Epoch: 1
torch.Size([4, 400, 400])
torch.Size([4, 400, 400])




Loss: 139539.578125
Mean IoU: 0.15224774471133762
Pixel-wise accuracy: 0.34371168620159537
torch.Size([4, 400, 400])
torch.Size([4, 400, 400])


KeyboardInterrupt: 