# Milestone 3: Fine-tune the model on the sidewalks dataset

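First, manually copy the extracted sidewalks data into the current folder, so that it contains the `Train/` and `Test/` image folders and the matching `Label/Train/` and `Label/Test/` mask folders.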

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tifffile
import os, tarfile, sys
from pathlib import Path
from transformers import SamProcessor, SamModel, SamConfig

from tqdm import tqdm
from statistics import mean

import torch
from torch.optim import Adam
from torch.utils.data import IterableDataset, DataLoader

import monai
import cv2

# All data is expected to live next to the notebook (the current working directory):
# images under Train/ and Test/, ground-truth masks under Label/Train/ and Label/Test/.
dataset_path = Path().absolute()
train_images_path = dataset_path / "Train"
train_label_path = dataset_path / "Label" / "Train"
test_images_path = dataset_path / "Test"
test_label_path = dataset_path / "Label" / "Test"

2024-04-12 06:02:58.903692: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Wrap the dataset in an iterable dataset class, following the example.

In [2]:
def read_dir(path):
    """Lazily yield every image stored directly inside a directory.

    Entries are visited in sorted file-name order. The order matters because
    images and masks are read from two separate directories and paired purely
    by position; ``os.listdir`` alone returns entries in arbitrary order, which
    could silently pair an image with the wrong mask.

    Args:
        path: Directory to read, as a ``pathlib.Path`` or a plain string.

    Yields:
        Whatever ``tifffile.imread`` returns for each remaining entry.
    """
    directory = str(path)
    for file_name in sorted(os.listdir(directory)):
        # Skip ".tfw" sidecar files and the "yolact" entry; neither is an image to load.
        if file_name.endswith(".tfw") or file_name == "yolact":
            continue
        yield tifffile.imread(os.path.join(directory, file_name))


def read_tar(tar):
    """Lazily decode the images stored in an already-open tar archive.

    Args:
        tar: An open ``tarfile.TarFile``; members are visited in archive order.
            The caller keeps ownership of the archive and must close it.

    Yields:
        The result of ``cv2.imdecode(..., cv2.IMREAD_COLOR)`` for every regular
        file whose payload passes the size threshold below.
    """
    for member in tar:
        # extractfile() returns None for directories and other non-file
        # members, so calling .read() on it would raise; skip those entries.
        if not member.isfile():
            continue
        payload = tar.extractfile(member).read()
        # sys.getsizeof() counts the bytes object's own overhead as well as the
        # payload, so the effective cut-off depends on the interpreter.
        # NOTE(review): presumably meant to drop tiny placeholder files -- confirm.
        if sys.getsizeof(payload) > 266:
            raw = np.frombuffer(payload, dtype=np.uint8)
            yield cv2.imdecode(raw, cv2.IMREAD_COLOR)


# Derive a (randomly jittered) box prompt from a segmentation mask.
def get_bounding_box(ground_truth_map):
    """Return an ``[x_min, y_min, x_max, y_max]`` box around the mask foreground.

    Foreground is every element of the 2-D ``ground_truth_map`` that is > 0.
    Each edge of the tight box is pushed outwards by a random amount drawn from
    ``np.random.randint(0, 20)`` and then clamped to the map's extent. If the
    mask has no foreground at all, a random box at least 20 pixels wide and
    high is returned instead.
    """
    rows, cols = np.where(ground_truth_map > 0)
    height, width = ground_truth_map.shape

    if not (rows.size and cols.size):
        # Empty mask: there is nothing to enclose, so invent a random box.
        left = np.random.randint(0, width - 20)
        top = np.random.randint(0, height - 20)
        box_w = np.random.randint(20, width - left)
        box_h = np.random.randint(20, height - top)
        return [left, top, left + box_w, top + box_h]

    # Tight box first; jitter is drawn in the order x_min, x_max, y_min, y_max.
    tight_left, tight_right = np.min(cols), np.max(cols)
    tight_top, tight_bottom = np.min(rows), np.max(rows)
    left = max(0, tight_left - np.random.randint(0, 20))
    right = min(width, tight_right + np.random.randint(0, 20))
    top = max(0, tight_top - np.random.randint(0, 20))
    bottom = min(height, tight_bottom + np.random.randint(0, 20))
    return [left, top, right, bottom]


class SAMDataset(IterableDataset):
    """
    Streams (image, mask) pairs as SAM model inputs.

    We need to use generators instead of naively storing all data in memory due to the sheer size of the training data.

    The underlying generators are created afresh on every ``__iter__`` call, so
    the dataset can be iterated more than once (several epochs, or a new
    ``DataLoader`` over the same dataset). Creating them once in ``__init__``
    would leave the dataset empty after the first pass.

    Args:
        path_x: Directory (or tar archive when ``is_tar`` is true) with the images.
        path_y: Directory (or tar archive when ``is_tar`` is true) with the masks.
            Images and masks are paired by iteration order.
        processor: Callable invoked as
            ``processor(image, input_boxes=[[box]], return_tensors="pt")`` that
            returns a mapping of tensors carrying a leading batch dimension.
        is_tar: Read both sources as tar archives instead of directories.
    """

    def __init__(self, path_x, path_y, processor, is_tar=False):
        self.processor = processor
        self.path_x = path_x
        self.path_y = path_y
        self.is_tar = is_tar

    def __iter__(self):
        open_archives = []
        try:
            if self.is_tar:
                for archive_path in (self.path_x, self.path_y):
                    open_archives.append(tarfile.open(str(archive_path), "r:*"))
                images = read_tar(open_archives[0])
                masks = read_tar(open_archives[1])
            else:
                images = read_dir(self.path_x)
                masks = read_dir(self.path_y)

            for image, ground_truth_mask in zip(images, masks):
                # get bounding box prompt
                prompt = get_bounding_box(ground_truth_mask)
                # prepare image and prompt for the model
                inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")
                # remove batch dimension which the processor adds by default
                inputs = {k: v.squeeze(0) for k, v in inputs.items()}
                # add ground truth segmentation
                inputs["ground_truth_mask"] = ground_truth_mask
                yield inputs
        finally:
            # Runs when the pass finishes or the iterator is closed early, so
            # archive handles never outlive a pass over the data.
            for archive in open_archives:
                archive.close()

In [3]:
# Pretrained SAM checkpoint plus its matching pre-processor.
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model = SamModel.from_pretrained("facebook/sam-vit-huge")

# Streaming datasets, served one sample per batch.
trainset = SAMDataset(train_images_path, train_label_path, processor)
testset = SAMDataset(test_images_path, test_label_path, processor)
trainloader = DataLoader(trainset, batch_size=1)
testloader = DataLoader(testset, batch_size=1)

# Freeze the image and prompt encoders so that gradients are only computed
# for the remaining parameters (the mask decoder).
for name, param in model.named_parameters():
    if name.startswith(("vision_encoder", "prompt_encoder")):
        param.requires_grad_(False)

In [4]:
# Initialize the optimizer and the loss function. Only the mask decoder is optimised.
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
# The loss is built with sigmoid=True, i.e. it applies the sigmoid itself and
# must therefore be fed raw logits (see the training loop below).
# Alternatives worth trying: DiceFocalLoss, FocalLoss.
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

# Training loop
num_epochs = 1

# Fall back to the CPU instead of crashing when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch_num, batch in enumerate(tqdm(trainloader)):
        # forward pass
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=True)

        # Raw mask logits with the prompt dimension dropped. No sigmoid here:
        # applying one in addition to the loss's own would squash the predictions twice.
        mask_logits = outputs.pred_masks.squeeze(1)

        # Binarise the mask the same way get_bounding_box does (> 0 is foreground)
        # and replicate it once per predicted mask channel.
        # NOTE(review): assumes the masks already have the decoder's output
        # resolution (the original code hard-coded 256x256) -- confirm.
        ground_truth_masks = (batch["ground_truth_mask"] > 0).float().unsqueeze(1)
        ground_truth_masks = ground_truth_masks.repeat(1, mask_logits.shape[1], 1, 1).to(device)
        loss = seg_loss(mask_logits, ground_truth_masks)

        # backward pass (compute gradients of parameters w.r.t. loss)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

        # save latest model for backup, just in case. it should back up every 10 minutes on my local machine
        if batch_num % 2000 == 0:
            torch.save(model.state_dict(), 'sidewalks-sam-backup.pt')

    print(f'EPOCH: {epoch}')
    if epoch_losses:
        print(f'Mean loss: {mean(epoch_losses)}')
    else:
        # mean() raises on an empty list; report the real problem instead.
        print('Mean loss: n/a (the training loader produced no batches)')

torch.save(model.state_dict(), 'sidewalks-sam.pt')

174it [10:26,  3.60s/it]


KeyboardInterrupt: 

Test inference

In [2]:
# The architecture must match the checkpoint that was fine-tuned and saved to
# "sidewalks-sam.pt" (facebook/sam-vit-huge); building a different SAM variant
# makes load_state_dict fail with size mismatches.
model_config = SamConfig.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

sidewalk_model = SamModel(config=model_config)

# Defined here as well so this cell also works in a fresh kernel where the
# training cell was skipped; falls back to the CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Update the model by loading the weights from the saved file. map_location
# lets a checkpoint written on a GPU machine be loaded on a CPU-only one.
sidewalk_model.load_state_dict(torch.load("sidewalks-sam.pt", map_location=device))

sidewalk_model = sidewalk_model.to(device)


NameError: name 'SamConfig' is not defined

Plot a sample test image

In [None]:
idx = 2

# Fetch the idx-th batch from the (streaming) test loader.
selected_batch = None
for index, batch in enumerate(testloader):
    if index == idx:
        selected_batch = batch
        break
if selected_batch is None:
    raise IndexError(f"testloader yielded fewer than {idx + 1} batches")

sidewalk_model.eval()
# Use whatever device the model lives on, so this cell does not depend on a
# `device` variable defined in another cell.
model_device = next(sidewalk_model.parameters()).device

# forward pass
with torch.no_grad():
    outputs = sidewalk_model(pixel_values=selected_batch["pixel_values"].to(model_device),
                             input_boxes=selected_batch["input_boxes"].to(model_device),
                             multimask_output=True)

# multimask_output=True yields several candidate masks per prompt; imshow needs
# a single 2-D map, so keep the candidate with the highest predicted IoU.
# NOTE(review): assumes pred_masks is (batch, prompts, masks, h, w) and
# iou_scores is (batch, prompts, masks), as documented for SamModel -- confirm.
best_mask = int(outputs.iou_scores[0, 0].argmax())
# apply sigmoid to get a probability map
seg_prob = torch.sigmoid(outputs.pred_masks[0, 0, best_mask]).cpu().numpy()
# convert soft mask to hard mask
seg = (seg_prob > 0.5).astype(np.uint8)

# Drop the batch dimension and move channels last for imshow. The values come
# from the processor rather than raw pixels, so rescale them to [0, 1] for display.
image = selected_batch["pixel_values"][0].permute(1, 2, 0).numpy()
image = (image - image.min()) / max(float(image.max() - image.min()), 1e-8)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Left: the model input image
axes[0].imshow(image)
axes[0].set_title("Image")

# Middle: the thresholded (hard) mask
axes[1].imshow(seg, cmap='gray')
axes[1].set_title("Mask")

# Right: the per-pixel foreground probability
axes[2].imshow(seg_prob)
axes[2].set_title("Probability Map")

# Hide axis ticks and labels
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

# Display the images side by side
plt.show()

Compute the overall IoU score on the test images

In [3]:
# Build a fresh dataset so evaluation always starts from the first test image,
# regardless of how much of `testset` earlier cells have already consumed.
testloader = DataLoader(SAMDataset(test_images_path, test_label_path, processor), batch_size=1)

sidewalk_model.eval()
# Use whatever device the model lives on, so this cell does not depend on a
# `device` variable defined in another cell.
model_device = next(sidewalk_model.parameters()).device

ious = []
with torch.no_grad():  # inference only: no gradients needed
    for test_batch in testloader:
        outputs = sidewalk_model(pixel_values=test_batch["pixel_values"].to(model_device),
                                 input_boxes=test_batch["input_boxes"].to(model_device),
                                 multimask_output=True)

        # compute_iou expects binarised masks, so threshold the probability maps.
        predicted_masks = (torch.sigmoid(outputs.pred_masks.squeeze(1)) > 0.5).float()

        # Ground truth of the *current* batch, binarised (> 0 is foreground) and
        # replicated once per predicted mask channel so both tensors share a shape.
        # NOTE(review): assumes the masks already have the decoder's output
        # resolution -- confirm.
        ground_truth_masks = (test_batch["ground_truth_mask"] > 0).float().unsqueeze(1)
        ground_truth_masks = ground_truth_masks.repeat(1, predicted_masks.shape[1], 1, 1).to(model_device)

        # Move the result to the CPU before converting it to NumPy.
        ious.append(monai.metrics.compute_iou(predicted_masks, ground_truth_masks).cpu().numpy())

# NaN entries (which compute_iou may return for empty ground-truth masks) are
# ignored when averaging.
mean_iou = float(np.nanmean(np.concatenate(ious))) if ious else float("nan")
print(f"Mean IoU over {len(ious)} test images: {mean_iou:.4f}")
    


SyntaxError: incomplete input (3257055356.py, line 2)