In [None]:
from datasets import Dataset, DatasetDict, Image
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset
import torch
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode
from PIL import Image
from torch.utils.data import DataLoader
from PIL import Image
import torch
from transformers import Dinov2Model, Dinov2PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput
import evaluate


# Per-channel normalisation statistics. They are written on the 0-255 scale
# and divided by 255 so they can be applied to tensors scaled to [0, 1].
ADE_MEAN = np.array((123.675, 116.280, 103.530)) / 255
ADE_STD = np.array((58.395, 57.120, 57.375)) / 255

# Run parameters.
dataset_name = "ACPDS"  # one of: ACPDS | PKLOT
# Fractions of the samples used for training and testing.
train_ratio = 0.6
test_ratio = 0.3


In [None]:
# Locate the dataset on disk: <cwd>/<dataset_name>/<dataset_name>/{images,int_masks}.
root = os.getcwd()
dataset_root = f"{root}/{dataset_name}/{dataset_name}"
image_dir = os.path.join(dataset_root, "images")
mask_dir = os.path.join(dataset_root, "int_masks")

# List the image directory once and derive BOTH path lists from the same
# sorted filename list, so entry i of `image_paths` and entry i of
# `label_paths` always refer to the same sample. Sorting the two lists
# independently can disagree (e.g. "x.jpg" sorts before "x.k.jpg", but
# "x.png" sorts after "x.k.png"), which would silently pair images with the
# wrong masks.
image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(".jpg"))

image_paths = [os.path.join(image_dir, f) for f in image_files]
# Swap only the trailing extension; the mask shares the image's base name.
label_paths = [
    os.path.join(mask_dir, os.path.splitext(f)[0] + ".png")
    for f in image_files
]

# Shuffle image/mask pairs together with a fixed seed for a reproducible split.
combined = list(zip(image_paths, label_paths))
random.seed(42)
random.shuffle(combined)

# Train and test sizes come from the configured ratios; validation gets the rest.
train_size = int(len(combined) * train_ratio)
test_size = int(len(combined) * test_ratio)
val_size = len(combined) - train_size - test_size

train_split = combined[:train_size]
test_split = combined[train_size:train_size + test_size]
val_split = combined[train_size + test_size:]

# Unzip each split back into parallel (images, masks) tuples.
train_imgs, train_masks = zip(*train_split)
test_imgs, test_masks = zip(*test_split)
val_imgs, val_masks = zip(*val_split)

def create_dataset(image_paths, label_paths):
    """Build a Hugging Face dataset of lazily decoded image/mask pairs.

    Args:
        image_paths: Sequence of image file paths.
        label_paths: Sequence of mask file paths, parallel to `image_paths`.

    Returns:
        A `datasets.Dataset` with two columns, "image" and "label", both cast
        to the `datasets.Image` feature.
    """
    # Import under explicit aliases. At module level the names `Dataset` and
    # `Image` are rebound by later imports (`torch.utils.data.Dataset` and
    # `PIL.Image`), so using them here would call the wrong objects.
    from datasets import Dataset as HFDataset, Image as HFImage

    dataset = HFDataset.from_dict({
        "image": list(image_paths),
        "label": list(label_paths),
    })
    dataset = dataset.cast_column("image", HFImage())
    dataset = dataset.cast_column("label", HFImage())
    return dataset

# Assemble the three splits into one DatasetDict, keyed by split name.
dataset = DatasetDict()
dataset["train"] = create_dataset(train_imgs, train_masks)
dataset["test"] = create_dataset(test_imgs, test_masks)
dataset["validation"] = create_dataset(val_imgs, val_masks)

dataset

In [None]:
# Take the first training example and show its image as the cell output.
example = dataset["train"][0]
image = example["image"]
image

In [None]:
# Show the label (mask) that belongs to the same example.
segmentation_map = example["label"]
segmentation_map

In [None]:
# Convert the mask to a NumPy array so its raw values can be inspected.
segmentation_map = np.array(segmentation_map)
segmentation_map

In [None]:
# Class id -> human-readable class name.
id2label = {0: "background", 1: "free", 2: "occupied"}

# Class id -> RGB colour used when drawing a segmentation map.
id2color = {
    0: (255, 255, 255),  # background: white
    1: (0, 255, 0),      # free: green
    2: (255, 0, 0),      # occupied: red
}

def visualize_map(image, segmentation_map):
    """Overlay a colour-coded segmentation map on an image and display it.

    Args:
        image: Image to draw on; anything `np.array` can convert to an array
            that broadcasts against an (H, W, 3) colour map.
        segmentation_map: 2-D map of class ids (the keys of `id2color`).
            May be a NumPy array or anything `np.asarray` can convert, such
            as a CPU torch tensor.
    """
    # Normalise the input to NumPy up front so the comparison and boolean
    # indexing below never mix tensor types with NumPy arrays.
    segmentation_map = np.asarray(segmentation_map)

    color_seg = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8)
    for label, color in id2color.items():
        color_seg[segmentation_map == label, :] = color

    # Blend image and colour map 50/50, then go back to 8-bit for display.
    img = np.array(image) * 0.5 + color_seg * 0.5
    img = img.astype(np.uint8)

    plt.figure(figsize=(15, 10))
    plt.imshow(img)
    plt.show()

# Overlay the example's ground-truth mask on its image.
visualize_map(image, segmentation_map)

In [None]:
class SegmentationDataset(Dataset):
  """Turns image/mask records into resized, normalised tensors.

  Each item is the 4-tuple ``(image, mask, og_image, og_mask)``: the resized
  and normalised image tensor, the resized mask of class ids, and the two
  untouched NumPy arrays they were derived from.
  """

  def __init__(self, dataset, resize_size=(448, 448)):
    # `dataset` must support len() and integer indexing, and return records
    # that have "image" and "label" entries.
    self.dataset = dataset
    self.resize_size = resize_size
    self.normalize = transforms.Normalize(mean=ADE_MEAN, std=ADE_STD)

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    record = self.dataset[idx]
    raw_image = np.array(record["image"])
    raw_mask = np.array(record["label"])

    # Image branch: to_tensor gives [C, H, W] scaled from [0, 255] to [0, 1];
    # then resize bilinearly and normalise.
    pixels = F.resize(
        F.to_tensor(raw_image),
        self.resize_size,
        interpolation=InterpolationMode.BILINEAR,
    )
    pixels = self.normalize(pixels)

    # Mask branch: class ids as int64. A leading dim is added for the resize
    # and removed afterwards; nearest-neighbour keeps the ids intact.
    class_ids = torch.from_numpy(raw_mask).long().unsqueeze(0)
    class_ids = F.resize(
        class_ids,
        self.resize_size,
        interpolation=InterpolationMode.NEAREST,
    ).squeeze(0)

    return pixels, class_ids, raw_image, raw_mask

In [None]:
# Wrap the train and validation splits with the preprocessing dataset.
train_dataset = SegmentationDataset(dataset["train"])
val_dataset = SegmentationDataset(dataset["validation"])

In [None]:
# Fetch one training sample and print the shapes of its image and mask tensors.
pixel_values, target, original_image, original_segmentation_map = train_dataset[3]
print(pixel_values.shape)
print(target.shape)

In [None]:
# Same shape check for one validation sample.
pixel_values_val, target_val, original_image_val, original_segmentation_map_val = val_dataset[3]
print(pixel_values_val.shape)
print(target_val.shape)

In [None]:
# Show the untouched training image (NumPy array converted back to a PIL image).
Image.fromarray(original_image)

In [None]:
# Names of the classes that occur in this sample's original mask.
[id2label[id] for id in np.unique(original_segmentation_map).tolist()]

In [None]:
def collate_fn(inputs):
    """Merge a list of dataset samples into one batch dictionary.

    Args:
        inputs: List of samples, each indexable as
            (image tensor, mask tensor, original image, original mask).

    Returns:
        Dict with "pixel_values" and "labels" stacked along a new leading
        batch dimension, plus "original_images" and
        "original_segmentation_maps" as plain lists (not stacked).
    """
    return {
        "pixel_values": torch.stack([sample[0] for sample in inputs], dim=0),
        "labels": torch.stack([sample[1] for sample in inputs], dim=0),
        "original_images": [sample[2] for sample in inputs],
        "original_segmentation_maps": [sample[3] for sample in inputs],
    }

# Batch loaders: training data is reshuffled on each pass, validation keeps its order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=3,
    shuffle=True,
    collate_fn=collate_fn,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=3,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
# Pull one batch and print the shape of every tensor entry in it.
batch = next(iter(train_dataloader))
for key, value in batch.items():
  if not isinstance(value, torch.Tensor):
    continue  # the "original_*" entries are plain lists
  print(key, value.shape)

In [None]:
# dtype of the batched image tensor.
batch["pixel_values"].dtype

In [None]:
# dtype of the batched mask tensor.
batch["labels"].dtype

In [None]:
# Undo the normalisation of the first image in the batch (x * std + mean),
# with the per-channel statistics broadcast over the [C, H, W] tensor.
unnormalized_image = (batch["pixel_values"][0].numpy() * np.array(ADE_STD)[:, None, None]) + np.array(ADE_MEAN)[:, None, None]
# Round and clip before the uint8 cast. A bare cast truncates, so float error
# would darken pixels by one level, and any value outside [0, 255] would wrap.
unnormalized_image = np.clip(np.rint(unnormalized_image * 255), 0, 255).astype(np.uint8)
# Channels-first -> channels-last for PIL.
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
unnormalized_image = Image.fromarray(unnormalized_image)
unnormalized_image

In [None]:
# Names of the classes that occur in the first (resized) mask of the batch.
[id2label[id] for id in torch.unique(batch["labels"][0]).tolist()]

In [None]:
# Overlay the resized mask on the un-normalised image to check they still line up.
visualize_map(unnormalized_image, batch["labels"][0].numpy())

In [None]:
class LinearClassifier(torch.nn.Module):
    """Per-token linear head implemented as a 1x1 convolution.

    Args:
        in_channels: Size of each token embedding.
        tokenW: Width of the token grid.
        tokenH: Height of the token grid.
        num_labels: Number of output channels (one per class).
    """

    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1):
        super().__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        self.classifier = torch.nn.Conv2d(in_channels, num_labels, (1, 1))

    def forward(self, embeddings):
        """Map token embeddings to a [batch, num_labels, height, width] logit grid."""
        # Lay the tokens out as a [batch, height, width, channels] grid, then
        # move channels to dim 1 as Conv2d expects.
        grid = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        channels_first = grid.permute(0, 3, 1, 2)
        return self.classifier(channels_first)


class Dinov2ForSemanticSegmentation(Dinov2PreTrainedModel):
  """DINOv2 backbone with a 1x1-convolution head that outputs per-pixel class logits."""

  def __init__(self, config):
    super().__init__(config)

    self.dinov2 = Dinov2Model(config)
    # The head is built for a fixed 32x32 token grid.
    # NOTE(review): presumably this matches the 448x448 inputs used in this
    # notebook divided by the backbone's patch size - confirm before changing
    # the input resolution.
    self.classifier = LinearClassifier(config.hidden_size, 32, 32, config.num_labels)

  def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, labels=None):
    """Run the backbone and head; also compute the loss when `labels` is given.

    Returns a `SemanticSegmenterOutput` whose logits are upsampled to the
    spatial size of `pixel_values`. `loss` is None when `labels` is None.
    """
    # Nothing is frozen here; if the backbone should stay fixed, the caller
    # has to disable gradients on its parameters.
    backbone_out = self.dinov2(
        pixel_values,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
    )
    # Keep only the patch tokens: position 0 is the CLS token.
    patch_tokens = backbone_out.last_hidden_state[:, 1:, :]

    # Token-grid logits, upsampled bilinearly to the input resolution.
    logits = torch.nn.functional.interpolate(
        self.classifier(patch_tokens),
        size=pixel_values.shape[2:],
        mode="bilinear",
        align_corners=False,
    )

    if labels is None:
      loss = None
    else:
      # Class 0 (background) is the ignore index instead of the default -100,
      # so background pixels do not contribute to the loss.
      criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
      loss = criterion(logits, labels)

    return SemanticSegmenterOutput(
        loss=loss,
        logits=logits,
        hidden_states=backbone_out.hidden_states,
        attentions=backbone_out.attentions,
    )

In [None]:
# Load the "facebook/dinov2-base" checkpoint into the segmentation model, with
# one output label per entry of id2label.
model = Dinov2ForSemanticSegmentation.from_pretrained("facebook/dinov2-base", id2label=id2label, num_labels=len(id2label))

In [None]:
# Freeze the backbone: every parameter whose name starts with "dinov2" stops
# receiving gradients; all other parameters stay trainable.
for name, param in model.named_parameters():
  if not name.startswith("dinov2"):
    continue
  param.requires_grad = False

In [None]:
# Sanity check: one forward pass on the sample batch, printing logits shape and loss.
outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])
print(outputs.logits.shape)
print(outputs.loss)

In [None]:
# Load evaluation metrics by name via `evaluate`.
# NOTE(review): only iou_metric is used in the visible cells; precision_metric
# and recall_metric are loaded but not referenced here - confirm they are needed.
iou_metric = evaluate.load("mean_iou")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")

In [None]:
from torch.optim import AdamW
from tqdm.auto import tqdm

# Training hyperparameters.
# NOTE: these are arbitrary starting values, not tuned in any way;
# feel free to experiment, see also the DINOv2 paper.
learning_rate = 1e-2
epochs = 2

optimizer = AdamW(model.parameters(), lr=learning_rate)

# Use the GPU when one is available (in Google Colab, set the runtime to GPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Switch the model to training mode.
model.train()

for epoch in range(epochs):
  print("Epoch:", epoch)
  for idx, batch in enumerate(tqdm(train_dataloader)):
      pixel_values = batch["pixel_values"].to(device)
      labels = batch["labels"].to(device)

      # Forward pass; the model returns the loss because labels are supplied.
      outputs = model(pixel_values, labels=labels)
      loss = outputs.loss

      # Backward pass and parameter update, then clear gradients for the next step.
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      # Accumulate this batch's predictions into the running IoU metric.
      with torch.no_grad():
        predicted = outputs.logits.argmax(dim=1)

        # The metric expects predictions and references as NumPy arrays.
        iou_metric.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=labels.detach().cpu().numpy(),
        )

      # Report loss and metrics only on every 100th batch.
      if idx % 100 != 0:
        continue

      metrics = iou_metric.compute(
          num_labels=len(id2label),
          ignore_index=0,
          reduce_labels=False,
      )

      print("Loss:", loss.item())
      print("Mean_iou:", metrics["mean_iou"])
      print("Mean accuracy:", metrics["mean_accuracy"])

In [None]:
from PIL import Image
# Pick one image from the held-out test split and display it.
og_test_image = dataset["test"][16]["image"]
og_test_image

In [None]:
# Preprocess the test image with the same steps as SegmentationDataset:
# to_tensor -> bilinear resize to 448x448 -> normalise -> add a batch dimension.
normalize = transforms.Normalize(mean=ADE_MEAN, std=ADE_STD)

test_image = F.resize(
    F.to_tensor(og_test_image),
    (448, 448),
    interpolation=InterpolationMode.BILINEAR,
)
pixel_values = normalize(test_image).unsqueeze(0)
print(pixel_values.shape)

In [None]:
# Switch to evaluation mode first: the training cell leaves the model in
# training mode, and train-only layers (e.g. dropout) must be off for inference.
model.eval()

# Forward pass without gradient tracking.
with torch.no_grad():
  outputs = model(pixel_values.to(device))

In [None]:
# Upsample the logits to the original image resolution. PIL's `size` is
# (width, height), so it is reversed to get (height, width).
upsampled_logits = torch.nn.functional.interpolate(outputs.logits,
                                                   size=og_test_image.size[::-1],
                                                   mode="bilinear", align_corners=False)
# Per-pixel class id = index of the largest logit along the class dimension.
predicted_map = upsampled_logits.argmax(dim=1)

In [None]:
# Overlay the predicted class map on the original test image.
visualize_map(og_test_image, predicted_map.squeeze().cpu())

In [None]:
# Save the whole model object to disk.
# NOTE(review): this serialises the module itself rather than its state_dict,
# so loading presumably needs the class definitions from this notebook to be
# available - confirm, and consider saving model.state_dict() instead.
torch.save(model, 'model.pt')