In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import albumentations as A
from tqdm.auto import tqdm
import glob
import os
import pandas as pd
from sklearn.model_selection import train_test_split

from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
import cv2

import segmentation_models_pytorch as smp
from torch.profiler import profile, record_function, ProfilerActivity

In [2]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print('Using device:', device)

Using device: cuda


In [3]:
# Legacy placeholder: never populated in this cell (the RGB mask paths live in
# `rgb_masks_path`); kept so the name stays defined.
masks_path = []

# Collect the file paths. Sorting makes the i-th image line up with the i-th
# label as long as both folders use the same file stems.
images_path = sorted(glob.glob('data/original_images/*.jpg'))
labels_path = sorted(glob.glob('data/label_images_semantic/*.png'))
rgb_masks_path = sorted(glob.glob('data/RGB_color_image_masks/*.png'))

# Fail early with a readable message instead of silently pairing an image with
# the wrong label (or crashing later inside column_stack / train_test_split).
if not images_path:
    raise FileNotFoundError('No images found under data/original_images/*.jpg')
if len(images_path) != len(labels_path):
    raise ValueError(
        f'Found {len(images_path)} images but {len(labels_path)} label masks; '
        'every image needs exactly one label file'
    )
image_stems = [os.path.splitext(os.path.basename(p))[0] for p in images_path]
label_stems = [os.path.splitext(os.path.basename(p))[0] for p in labels_path]
if image_stems != label_stems:
    raise ValueError('Image and label file names do not match one-to-one')

# One row per sample: [image path, label path]
paths = np.column_stack((images_path, labels_path))
print(paths.shape)
print(paths[0])

# Apply 80-10-10 split (fixed seeds keep the split reproducible)
train_split, valtest_split = train_test_split(paths, test_size=0.2, random_state=69420)
val_split, test_split = train_test_split(valtest_split, test_size=0.5, random_state=69420)

(400, 2)
['data/original_images/000.jpg' 'data/label_images_semantic/000.png']


In [4]:
# for img in tqdm(paths[:, 0]):
#     shape = cv2.imread(img).shape
#     print(f'shape for img {img} is {shape}')
#     if shape != (4000, 6000, 3):
#         print(f'wrong shape for img {img}')

In [5]:
# Load the class dictionary (one row per label: name and RGB colour)
class_dict_csv = 'data/class_dict_seg.csv'
classes_df = pd.read_csv(class_dict_csv)
classes_df

Unnamed: 0,name,r,g,b
0,unlabeled,0,0,0
1,paved-area,128,64,128
2,dirt,130,76,0
3,grass,0,102,0
4,gravel,112,103,87
5,water,28,42,168
6,rocks,48,41,30
7,pool,0,50,89
8,vegetation,107,142,35
9,roof,70,70,70


In [6]:
# # Check if there are any conflicting labels in the masks
# for img in tqdm(paths[:, 1]):
#     mask = cv2.imread(img, cv2.IMREAD_GRAYSCALE)
#     unlabelled = np.sum(mask == 23)
#     if unlabelled > 0:
#         print(f'Unlabelled pixels in img {img}: {unlabelled} ({unlabelled / mask.size * 100:.2f}%)')

In [7]:
# The check above found no pixels with the "conflicting" label, so that row of
# the class table is dropped: the model predicts one class fewer than the CSV lists.
nr_classes = classes_df.shape[0] - 1
print(f'Number of classes: {nr_classes}')

Number of classes: 23


In [8]:
class TilesDataset(Dataset):
    """Dataset of (image, mask) pairs, optionally cut into square tiles.

    Args:
        image_paths: indexable collection whose items unpack to
            ``(image_path, mask_path)``.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        tiles: when True, every sample is split into non-overlapping tiles.
        tiles_dim: side length (pixels) of the tiles cut from the image.
        resize_dim: tiles larger than this are downsampled to
            ``resize_dim x resize_dim``; ``None`` disables the downsampling.
    """

    def __init__(self, image_paths, transform=None, tiles=True, tiles_dim=512, resize_dim=256):
        self.image_paths = image_paths
        self.transform = transform
        self.tiles = tiles
        self.tiles_dim = tiles_dim
        self.resize_dim = resize_dim

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``, apply the transform and (optionally) tile it.

        Returns:
            ``(img, mask)``. With ``tiles=True`` these are the stacked tiles
            produced by :meth:`create_tiles`.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask.
        """
        img_path, mask_path = self.image_paths[idx]

        # cv2.imread returns None instead of raising when a file is missing/corrupt.
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise FileNotFoundError(f'Could not read image: {img_path}')
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f'Could not read mask: {mask_path}')

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']

        if self.tiles:
            img, mask = self.create_tiles(img, mask, self.tiles_dim, self.resize_dim)

        return img, mask

    def create_tiles(self, img, mask, tiles_dim, resize_dim=256):
        """Split an image/mask pair into non-overlapping square tiles.

        Args:
            img: ``(C, H, W)`` tensor, or an ``(H, W, C)`` numpy array (converted
                to a float ``(C, H, W)`` tensor, which covers ``transform=None``).
            mask: ``(H, W)`` tensor or numpy array.
            tiles_dim: side length of each tile.
            resize_dim: tiles larger than this are downsampled to
                ``resize_dim x resize_dim``; ``None`` keeps them at ``tiles_dim``.

        Returns:
            ``(img_tiles, mask_tiles)`` with shapes ``(num_tiles, C, d, d)`` and
            ``(num_tiles, d, d)``, tiles ordered row by row.

        Raises:
            ValueError: if the image is smaller than a single tile.
        """
        # Accept raw numpy arrays so the dataset also works without a transform.
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        channels, height, width = img.shape

        # Round down to the nearest multiple of tiles_dim
        new_height = height // tiles_dim * tiles_dim
        new_width = width // tiles_dim * tiles_dim
        if new_height == 0 or new_width == 0:
            raise ValueError(
                f'Image of size {height}x{width} is smaller than one '
                f'{tiles_dim}x{tiles_dim} tile'
            )
        new_shp = (new_height, new_width)

        # Resize (not crop) so the whole scene is kept; nearest-neighbour for the
        # mask so that no new label values are invented.
        img = F.interpolate(img.unsqueeze(0), size=new_shp, mode='bilinear', align_corners=False).squeeze(0)
        mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=new_shp, mode='nearest').squeeze(0).squeeze(0)

        # (C, H, W) -> (C, nH, nW, d, d) -> (C, nH*nW, d, d) -> (nH*nW, C, d, d)
        img_tiles = img.unfold(1, tiles_dim, tiles_dim).unfold(2, tiles_dim, tiles_dim)
        img_tiles = img_tiles.contiguous().view(channels, -1, tiles_dim, tiles_dim)
        img_tiles = img_tiles.permute(1, 0, 2, 3)

        # (H, W) -> (nH, nW, d, d) -> (nH*nW, d, d)
        mask_tiles = mask.unfold(0, tiles_dim, tiles_dim).unfold(1, tiles_dim, tiles_dim)
        mask_tiles = mask_tiles.contiguous().view(-1, tiles_dim, tiles_dim)

        # Tiles larger than resize_dim are downsampled to resize_dim x resize_dim
        if resize_dim is not None and tiles_dim > resize_dim:
            img_tiles = F.interpolate(img_tiles, size=(resize_dim, resize_dim), mode='bilinear', align_corners=False)
            mask_tiles = F.interpolate(mask_tiles.unsqueeze(1), size=(resize_dim, resize_dim), mode='nearest').squeeze(1)

        return img_tiles, mask_tiles
    
# Tile side length, plus the 4000x6000 frame size rounded UP to a whole number
# of tiles (ceil-division written as negated floor-division).
tiles_dim = 512
new_height = -(-4000 // tiles_dim) * tiles_dim
new_width = -(-6000 // tiles_dim) * tiles_dim

# Define Albumentations transformations.
# Training: random geometric + photometric augmentation, then normalisation.
train_transform = A.Compose([
    # A.Resize(new_height, new_width, p=1.0),  # Resize the image to the desired shape
    A.HorizontalFlip(p=0.5),  # Apply horizontal flip with 50% probability
    A.VerticalFlip(p=0.5),  # Apply vertical flip with 50% probability
    A.RandomBrightnessContrast(p=0.2),  # Randomly change brightness and contrast
    A.OneOf([
        A.GaussianBlur(p=1.0),  # Apply Gaussian blur
        A.MotionBlur(p=1.0),  # Apply motion blur
    ], p=0.2),  # Apply one of the blur operations with 20% probability
    A.HueSaturationValue(p=0.2),  # Randomly change hue, saturation, and value
    A.RandomGamma(p=0.2),  # Randomly change gamma
    A.CLAHE(p=0.2),  # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # Normalize the image
    ToTensorV2(),  # Convert image and mask to PyTorch tensors
])

# Validation / test: deterministic preprocessing only. Random flips are left out
# on purpose so that evaluation metrics are reproducible from run to run and
# comparable across epochs.
valtest_transform = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # Same normalisation as training
    ToTensorV2(),  # Convert image and mask to PyTorch tensors
])

# Build the three tiled datasets; only the training split gets the augmenting transform
train_ds, val_ds, test_ds = (
    TilesDataset(split, transform=split_transform, tiles_dim=tiles_dim)
    for split, split_transform in (
        (train_split, train_transform),
        (val_split, valtest_transform),
        (test_split, valtest_transform),
    )
)

In [9]:
# Sanity check of the split sizes (number of full images, not tiles)
n_train, n_val, n_test = len(train_ds), len(val_ds), len(test_ds)
print(f'Train dataset length: {n_train}, Val dataset length: {n_val}, Test dataset length: {n_test}')

Train dataset length: 320, Val dataset length: 40, Test dataset length: 40


In [10]:
# One full image per batch: its tiles are flattened into the model batch later on.
# Only the training loader shuffles.
loader_kwargs = dict(batch_size=1, pin_memory=False, num_workers=4)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

### Metrics

In [11]:
def compute_metrics_torch(y_true, y_pred, num_classes):
    """Compute pixel accuracy, per-class IoU and per-class Dice for label maps.

    Args:
        y_true: tensor of ground-truth class indices (any shape).
        y_pred: tensor of predicted class indices, same number of elements and
            same device as ``y_true``.
        num_classes: classes ``0 .. num_classes - 1`` are scored. Pixels whose
            label lies outside that range still count towards ``accuracy`` but
            belong to no class for IoU / Dice.

    Returns:
        dict with ``'mean_iou'``, ``'per_class_iou'``, ``'accuracy'``,
        ``'mean_dice'`` and ``'per_class_dice'``. A class that appears in neither
        tensor scores 0 and is still included in the means.

    Note:
        Labels are expected to be integer class indices (they are cast to
        ``long`` for counting).
    """
    # reshape (unlike view) also accepts non-contiguous tensors
    y_true_flat = y_true.reshape(-1)
    y_pred_flat = y_pred.reshape(-1)

    # Compute overall accuracy
    acc = (y_true_flat == y_pred_flat).float().mean().item()

    # Only labels inside [0, num_classes) can be counted for a class.
    true_valid = (y_true_flat >= 0) & (y_true_flat < num_classes)
    pred_valid = (y_pred_flat >= 0) & (y_pred_flat < num_classes)
    matched = true_valid & (y_true_flat == y_pred_flat)

    # Exact integer pixel counts per class, computed in one pass each instead of
    # a float32 sum (inexact above 2**24 pixels) and several host syncs per class.
    true_count = torch.bincount(y_true_flat[true_valid].long(), minlength=num_classes)
    pred_count = torch.bincount(y_pred_flat[pred_valid].long(), minlength=num_classes)
    intersection = torch.bincount(y_true_flat[matched].long(), minlength=num_classes)

    # Single transfer to the host, then float64 arithmetic.
    counts = torch.stack([true_count, pred_count, intersection]).cpu().double()
    total = counts[0] + counts[1]          # |true| + |pred|
    inter = counts[2]                      # |true AND pred|
    union = total - inter                  # |true OR pred|

    zeros = torch.zeros_like(total)
    iou_list = torch.where(union > 0, inter / union, zeros).tolist()
    dice_list = torch.where(total > 0, 2 * inter / total, zeros).tolist()

    mean_iou = np.mean(iou_list)
    mean_dice = np.mean(dice_list)

    # Return the metrics
    return {
        'mean_iou': mean_iou,
        'per_class_iou': iou_list,
        'accuracy': acc,
        'mean_dice': mean_dice,
        'per_class_dice': dice_list
    }

### Create network

In [12]:
# Keyword arguments forwarded verbatim to smp.create_model; also reused below
# to name the TensorBoard log directory.
config = dict(
    arch='unet',
    encoder_name='resnet34',
    encoder_weights='imagenet',
    in_channels=3,
    classes=nr_classes,
)

model = smp.create_model(**config)
model

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

### TensorBoard

In [13]:
from datetime import datetime
from torch.utils.tensorboard.writer import SummaryWriter
import os
from tensorboard import program

# One log directory per run: logs/<arch>/<encoder>/<timestamp>, stored as an absolute path
run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logs_dir = os.path.abspath(f'logs/{config["arch"]}/{config["encoder_name"]}/{run_stamp}')
os.makedirs(logs_dir, exist_ok=True)  # Ensure the logs directory exists
print(f"TensorBoard logs directory: {logs_dir}")

# Writer used by the training code to emit events into that directory
writer = SummaryWriter(log_dir=logs_dir)

# Start a TensorBoard server from inside the kernel, pointed at the same directory
tb = program.TensorBoard()
tb.configure(argv=[None, '--logdir', logs_dir])
url = tb.launch()

print(f"TensorBoard is running at {url}")

2024-07-26 22:29:16.718462: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-26 22:29:16.733147: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-26 22:29:16.737627: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-26 22:29:16.749170: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
INFO:root:Server binary (from Python package v0.7.2):

TensorBoard logs directory: /home/andrea/unimib/BigImaging/Exam/logs/unet/resnet34/2024-07-26_22-29-18


INFO:pytorch_profiler:Monitor runs begin


TensorBoard is running at http://localhost:6008/


### Train

In [14]:
from torch.amp.grad_scaler import GradScaler
from torch.amp.autocast_mode import autocast

def reshape_imgs_masks(imgs, masks):
    """Move a batch of tiled samples to `device` and merge batch and tile axes.

    Args:
        imgs: tensor shaped [batch_size, num_tiles, channels, height, width].
        masks: tensor shaped [batch_size, num_tiles, height, width].

    Returns:
        (imgs, masks) shaped [batch_size * num_tiles, channels, height, width]
        and [batch_size * num_tiles, height, width]; the masks are cast to long.
    """
    imgs = imgs.to(device)
    masks = masks.to(device)

    # Collapse the two leading axes of the images
    channels, img_h, img_w = imgs.shape[2], imgs.shape[3], imgs.shape[4]
    flat_imgs = imgs.view(-1, channels, img_h, img_w)

    # Same for the masks, which are also converted to integer class indices
    mask_h, mask_w = masks.shape[2], masks.shape[3]
    flat_masks = masks.view(-1, mask_h, mask_w).long()

    return flat_imgs, flat_masks

def train(train_loss, imgs, masks, accumulation_steps=1, step=None):
    """Run one mixed-precision training step with optional gradient accumulation.

    Uses the module-level `model`, `criterion`, `optimizer`, `scaler` and `device`.

    Args:
        train_loss: running sum of the (unscaled) batch losses so far.
        imgs: image batch, passed through `reshape_imgs_masks`.
        masks: mask batch, passed through `reshape_imgs_masks`.
        accumulation_steps: number of consecutive calls whose gradients are
            summed before the optimizer is stepped.
        step: zero-based index of the current batch. When omitted, the
            module-level loop variable `i` is used (the original behaviour),
            so the calling loop must then be written as `for i, ... in enumerate(...)`.

    Returns:
        `train_loss` plus this batch's loss.

    Note:
        Gradients of an incomplete accumulation window at the end of an epoch
        are discarded when the next window starts.
    """
    if step is None:
        step = i  # legacy fallback: batch index of the enclosing loop

    imgs, masks = reshape_imgs_masks(imgs, masks)

    # Clear gradients only at the START of an accumulation window. Clearing them
    # on every call (as before) threw away everything accumulated so far, so only
    # the last micro-batch ever reached the optimizer.
    if step % accumulation_steps == 0:
        optimizer.zero_grad()

    with autocast(device_type=device.type):
        outputs = model(imgs)
        # Divide so the summed gradients average over the window
        loss = criterion(outputs, masks) / accumulation_steps
    scaler.scale(loss).backward()

    # End of the window: apply the accumulated gradients
    if (step + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    train_loss += loss.item() * accumulation_steps  # Adjust for scaled loss

    return train_loss

# Model, loss and optimizer (the optimizer is built after the model is moved to `device`)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Gradient scaler for mixed precision training
scaler = GradScaler()

# Number of batches whose gradients are accumulated per optimizer step
# (adjust based on GPU memory)
accumulation_steps = 1

# Training length and number of predicted classes
num_epochs = 100
num_classes = nr_classes

# Short stand-alone CUDA profiling pass, run before the real training loop.
# The profiler schedule only covers wait + warmup + active steps, so the pass
# stops once those have been recorded instead of running a whole epoch.
# (The former top-level `break` was a SyntaxError that stopped the whole cell
# from compiling; the loop now ends itself.)
profile_wait, profile_warmup, profile_active = 1, 1, 50
profile_steps = profile_wait + profile_warmup + profile_active

model.train()
train_loss = 0.0
with torch.profiler.profile(
    activities=[ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=profile_wait, warmup=profile_warmup, active=profile_active, repeat=1
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(logs_dir),
    record_shapes=True, profile_memory=True, with_stack=True
) as prof:
    # The loop variable must be called `i`: `train` reads it to decide when to
    # step the optimizer.
    for i, (imgs, masks) in enumerate(tqdm(train_loader, desc=f'Profiling Training')):
        train_loss = train(train_loss, imgs, masks, accumulation_steps)
        prof.step()
        if i + 1 >= profile_steps:
            break

# Main training loop: one training pass and one validation pass per epoch.
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    # The loop variable must be called `i`: `train` reads it to decide when to
    # step the optimizer.
    if epoch == 0:
        # Profile (CPU + CUDA) the first steps of the first epoch only
        with torch.profiler.profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=50, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(logs_dir),
            record_shapes=True, profile_memory=True, with_stack=True
        ) as prof:
            for i, (imgs, masks) in enumerate(tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}')):
                train_loss = train(train_loss, imgs, masks, accumulation_steps)
                prof.step()
    else:
        for i, (imgs, masks) in enumerate(tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}')):
            train_loss = train(train_loss, imgs, masks, accumulation_steps)

    model.eval()
    val_loss = 0.0
    all_y_true = []
    all_y_pred = []

    with torch.no_grad():
        for imgs, masks in tqdm(val_loader, desc=f'Validation Epoch {epoch+1}/{num_epochs}'):

            imgs, masks = reshape_imgs_masks(imgs, masks)

            # autocast requires an explicit device_type
            with autocast(device_type=device.type):
                outputs = model(imgs)
                loss = criterion(outputs, masks)
            val_loss += loss.item()

            preds = torch.argmax(outputs, dim=1)

            # Collect EVERY batch (previously only the last one was kept because
            # the appends sat outside the loop). Kept on the CPU so the whole
            # validation set does not pile up in GPU memory.
            all_y_true.append(masks.cpu())
            all_y_pred.append(preds.cpu())

    all_y_true = torch.cat(all_y_true, dim=0)
    all_y_pred = torch.cat(all_y_pred, dim=0)

    metrics = compute_metrics_torch(all_y_true, all_y_pred, num_classes)

    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)

    # Log the epoch summary to TensorBoard
    writer.add_scalar('Loss/train', avg_train_loss, epoch)
    writer.add_scalar('Loss/val', avg_val_loss, epoch)
    writer.add_scalar('Metrics/mean_iou', metrics['mean_iou'], epoch)
    writer.add_scalar('Metrics/accuracy', metrics['accuracy'], epoch)
    writer.add_scalar('Metrics/mean_dice', metrics['mean_dice'], epoch)
    writer.flush()

    print(f'Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.3f}')
    print(f'Validation Loss: {val_loss/len(val_loader):.3f}, Mean IoU: {metrics["mean_iou"]:.3f}, '
      f'Accuracy: {metrics["accuracy"]:.3f}, Dice Score: {metrics["mean_dice"]:.3f}, '
      f'per-class IoU: {[f"Class {i}: {iou:.3f}" for i, iou in enumerate(metrics["per_class_iou"])]}')

# Close the writer once, after the last epoch (it used to be closed inside the loop)
writer.close()

I0000 00:00:1722025760.775644  231941 work_stealing_thread_pool.cc:320] WorkStealingThreadPoolImpl::PrepareFork
I0000 00:00:1722025760.829260  231941 work_stealing_thread_pool.cc:320] WorkStealingThreadPoolImpl::PrepareFork
I0000 00:00:1722025760.882680  231941 work_stealing_thread_pool.cc:320] WorkStealingThreadPoolImpl::PrepareFork
I0000 00:00:1722025760.926715  231941 work_stealing_thread_pool.cc:320] WorkStealingThreadPoolImpl::PrepareFork


Profiling Training:   0%|          | 0/320 [00:00<?, ?it/s]

I0000 00:00:1722025761.034623  231941 work_stealing_thread_pool.cc:320] WorkStealingThreadPoolImpl::PrepareFork
I0000 00:00:1722025761.124129  231941 work_stealing_thread_pool.cc:320] WorkStealingThreadPoolImpl::PrepareFork
I0000 00:00:1722025761.189050  231941 work_stealing_thread_pool.cc:320] WorkStealingThreadPoolImpl::PrepareFork
I0000 00:00:1722025761.354685  231941 work_stealing_thread_pool.cc:320] WorkStealingThreadPoolImpl::PrepareFork
[2024-07-26T20:29:23Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:29:23Z INFO  rustboard_core::cli] Finished load cycle (386.167µs)
[2024-07-26T20:29:28Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:29:28Z INFO  rustboard_core::cli] Finished load cycle (691.585µs)
[2024-07-26T20:29:33Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:29:33Z INFO  rustboard_core::cli] Finished load cycle (476.926µs)
[2024-07-26T20:29:38Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:29:38Z INFO  rustb

INFO:werkzeug:127.0.0.1 - - [26/Jul/2024 22:29:22] "GET / HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [26/Jul/2024 22:29:22] "GET /font-roboto/oMMgfZMQthOryQo9n22dcuvvDin1pK8aKteLpeZ5c0A.woff2 HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [26/Jul/2024 22:29:22] "GET /icon_bundle.svg HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [26/Jul/2024 22:29:22] "GET /font-roboto/RxZJdnzeo3R5zSexge8UUZBw1xU1rKptJj_0jans920.woff2 HTTP/1.1" 200 -
INFO:tensorboard:Plugin listing: is_active() for timeseries took 0.000 seconds
INFO:tensorboard:Thread-11 (process_request_thread)[7fcb94a80640] ENTER GrpcDataProvider.list_tensors
INFO:tensorboard:Plugin listing: is_active() for scalars took 0.000 seconds
INFO:tensorboard:Thread-11 (process_request_thread)[7fcb94a80640]   ENTER build request
INFO:tensorboard:Plugin listing: is_active() for custom_scalars took 0.000 seconds
INFO:tensorboard:Thread-11 (process_request_thread)[7fcb94a80640]   LEAVE build request - 0.002032s elapsed
INFO:tensorboard:Thread-11 (pro

KeyboardInterrupt: 

[2024-07-26T20:30:23Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:30:23Z INFO  rustboard_core::cli] Finished load cycle (175.279µs)
[2024-07-26T20:30:28Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:30:28Z INFO  rustboard_core::cli] Finished load cycle (193.618µs)
[2024-07-26T20:30:33Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:30:33Z INFO  rustboard_core::cli] Finished load cycle (330.877µs)
[2024-07-26T20:30:38Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:30:38Z INFO  rustboard_core::cli] Finished load cycle (133.078µs)
[2024-07-26T20:30:43Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:30:43Z INFO  rustboard_core::cli] Finished load cycle (124.909µs)
[2024-07-26T20:30:48Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:30:48Z INFO  rustboard_core::cli] Finished load cycle (152.259µs)
[2024-07-26T20:30:53Z INFO  rustboard_core::cli] Starting load cycle
[2024-07-26T20:30:53Z INFO  rus