In [None]:
## UTIL

import numpy as np
import matplotlib.pyplot as plt

def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent (alpha 0.6) colour overlay.

    Args:
        mask: Array-like exposing ``.shape`` and ``.reshape`` whose last two dims are
            (height, width); any leading dims must have size 1 so the reshape succeeds.
        ax: Axes-like object; only ``ax.imshow`` is called.
        random_color: Use a random RGB colour instead of the fixed blue (30, 144, 255).
    """
    opacity = 0.6
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30 / 255, 144 / 255, 255 / 255])
    rgba = np.concatenate([rgb, np.array([opacity])], axis=0)

    height, width = mask.shape[-2:]
    # (H, W, 1) * (1, 1, 4) broadcasts to an (H, W, 4) RGBA image; zero pixels stay transparent.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)


def show_box(box, ax):
    """Outline the box ``(x_min, y_min, x_max, y_max)`` on ``ax`` in green.

    Only the first four entries of ``box`` are read. The rectangle has a fully
    transparent face and an edge line width of 2.
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

def show_boxes_on_image(raw_image, boxes):
    """Show ``raw_image`` in a new 10x10 figure with every box in ``boxes`` outlined.

    Each box is drawn with ``show_box``; the axes are left visible.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    current_ax = plt.gca()
    for bbox in boxes:
        show_box(bbox, current_ax)
    plt.axis('on')
    plt.show()

def show_points_on_image(raw_image, input_points, input_labels=None):
    """Show ``raw_image`` in a new 10x10 figure with prompt points drawn on top.

    Args:
        raw_image: Anything ``plt.imshow`` accepts.
        input_points: Point coordinates convertible with ``np.array`` to a 2-D array
            whose columns 0 and 1 are x and y.
        input_labels: Optional per-point labels; when omitted every point gets label 1.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    points = np.array(input_points)
    labels = np.ones_like(points[:, 0]) if input_labels is None else np.array(input_labels)
    show_points(points, labels, plt.gca())
    plt.axis('on')
    plt.show()

def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
    """Show ``raw_image`` in a new 10x10 figure with prompt points and boxes on top.

    Args:
        raw_image: Anything ``plt.imshow`` accepts.
        boxes: Iterable of boxes; each one is drawn with ``show_box``.
        input_points: Point coordinates convertible with ``np.array`` to a 2-D array
            whose columns 0 and 1 are x and y.
        input_labels: Optional per-point labels passed to ``show_points``
            (1 = positive, 0 = negative); when omitted every point gets label 1.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    for box in boxes:
        show_box(box, plt.gca())
    plt.axis('on')
    plt.show()


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as star markers with a white edge.

    Points whose label equals 1 are drawn green and those whose label equals 0 red;
    any other label value is not drawn.

    Args:
        coords: Array indexable by a boolean mask, 2-D with columns (x, y).
        labels: Array of per-point labels aligned with ``coords``.
        ax: Axes-like object; only ``ax.scatter`` is called.
        marker_size: Marker size forwarded as ``s`` to ``scatter``.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )


def show_masks_on_image(raw_image, masks, scores):
    """Plot each predicted mask over ``raw_image`` in one row of subplots, titled with its score.

    Args:
        raw_image: Image accepted by ``np.array`` (e.g. a PIL image or an ndarray).
        masks: Predicted masks. A 4-D input is squeezed of all size-1 dims; if that
            leaves a single (H, W) mask, a leading mask axis is restored. Objects with a
            ``.cpu`` method (torch tensors) are moved to CPU and detached; others
            (e.g. numpy arrays) are passed to ``show_mask`` unchanged.
        scores: One score per mask; squeezed when its leading dim has size 1. Each
            score must support ``.item()``.
    """
    if len(masks.shape) == 4:
        masks = masks.squeeze()
        # A lone mask, e.g. (1, 1, H, W), squeezes down to (H, W); restore the mask axis
        # so the loop below iterates over whole masks rather than image rows.
        if len(masks.shape) == 2:
            masks = masks[None]
    if len(scores.shape) > 0 and scores.shape[0] == 1:
        scores = scores.squeeze()
    # A single score squeezes to a 0-d value with no last axis to count; make it 1-D.
    if len(scores.shape) == 0:
        scores = scores.reshape(1)

    nb_predictions = scores.shape[-1]
    # squeeze=False keeps `axes` 2-D even for a single subplot, so indexing always works.
    _, axes = plt.subplots(1, nb_predictions, figsize=(15, 15), squeeze=False)
    axes = axes[0]
    image = np.array(raw_image)  # loop-invariant; convert once

    for i, (mask, score) in enumerate(zip(masks, scores)):
        if hasattr(mask, "cpu"):  # torch tensor: bring to CPU, drop autograd history
            mask = mask.cpu().detach()
        axes[i].imshow(image)
        show_mask(mask, axes[i])
        axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
        axes[i].axis("off")
    plt.show()

In [10]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from PIL import Image
from transformers import SamModel, SamProcessor
from segment_anything import SamPredictor, sam_model_registry
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import SamModel, SamProcessor
from segment_anything import SamPredictor, sam_model_registry
from peft import LoraConfig, get_peft_model
import torch.nn as nn

In [7]:
# Load SAM ViT-H twice: once via Hugging Face `transformers` (`model`/`processor`, the
# model that the LoRA cell wraps) and once via `segment_anything` (wrapped in `predictor`).
# NOTE(review): two ViT-H copies on the same device double memory use, and `predictor`
# is not referenced in the cells visible here — confirm it is still needed.
HF_SAM_CHECKPOINT = "facebook/sam-vit-huge"
SAM_VIT_H_WEIGHTS = "./model/sam_vit_h_4b8939.pth"

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

model = SamModel.from_pretrained(HF_SAM_CHECKPOINT).to(device)
processor = SamProcessor.from_pretrained(HF_SAM_CHECKPOINT)

sam = sam_model_registry["vit_h"](checkpoint=SAM_VIT_H_WEIGHTS)
sam.to(device)
predictor = SamPredictor(sam)

In [8]:
# Pair every image in `panos_dir` with its mask `mask_<image stem>.png` in `masks_dir`.
# Images without a matching mask file are reported and skipped.

import os

panos_dir = "./data/training/0000/directions"
masks_dir = "./data/training/0000/direction_masks"
IMAGE_EXTENSIONS = (".jpg", ".png")

image_paths = []
mask_paths = []

# os.listdir returns entries in arbitrary order; sort so the pairing order is reproducible.
for filename in sorted(os.listdir(panos_dir)):
    # Compare case-insensitively so e.g. "IMG_01.JPG" is not silently skipped.
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue

    image_path = os.path.join(panos_dir, filename)
    mask_name = f"mask_{os.path.splitext(filename)[0]}.png"
    mask_path = os.path.join(masks_dir, mask_name)

    if os.path.exists(mask_path):
        image_paths.append(image_path)
        mask_paths.append(mask_path)
    else:
        print(f"No mask for {image_path} found")

print(f"Amount of images found: {len(image_paths)}")
print(f"Amount of masks found: {len(mask_paths)}")


Amount of images found: 128
Amount of masks found: 128


In [9]:
# Dataset definition

class WindowSegmentationDataset(Dataset):
    """Image/mask pairs loaded from two index-aligned path lists.

    Args:
        image_paths: Paths to the input images; each is converted to RGB.
        mask_paths: Paths to the masks, aligned index-by-index with ``image_paths``;
            each is converted to 8-bit greyscale ("L").
        transform: Optional callable applied to the image and, when ``mask_transform``
            is not given, to the mask as well.
        mask_transform: Optional callable applied to the mask only, e.g. a resize with
            nearest-neighbour interpolation so label values are not blended.

    Raises:
        ValueError: if the two path lists have different lengths.
    """

    def __init__(self, image_paths, mask_paths, transform=None, mask_transform=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(mask_paths)} mask paths"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        # Without a dedicated mask transform, fall back to the shared one.
        self.mask_transform = mask_transform if mask_transform is not None else transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The `with` blocks close the files Image.open keeps open; convert() returns a
        # loaded copy, so the pixel data stays valid after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        with Image.open(self.mask_paths[idx]) as msk:
            mask = msk.convert("L")

        if self.transform is not None:
            image = self.transform(image)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return image, mask


# Target size handed to transforms.Resize, as (height, width).
IMAGE_SIZE = (409, 204)

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Masks are resized with nearest-neighbour interpolation so edge pixels keep their
# original label values instead of being interpolated like image pixels.
mask_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

dataset = WindowSegmentationDataset(
    image_paths, mask_paths, transform=transform, mask_transform=mask_transform
)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [11]:
class DiceLoss(nn.Module):
    """Soft Dice loss over the whole batch (all elements are summed together).

    ``pred`` is treated as raw logits: a sigmoid is always applied, so passing
    probabilities would squash them a second time.

    Args:
        smooth: Additive smoothing term in the numerator and denominator of the Dice ratio.
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        # Logits -> probabilities (unconditionally; see class docstring).
        pred = torch.sigmoid(pred)
        intersection = (pred * target).sum()
        return 1 - (2. * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)


class BCEDiceLoss(nn.Module):
    """Weighted sum of ``BCEWithLogitsLoss`` and ``DiceLoss``, both applied to logits.

    Args:
        bce_weight: Weight of the BCE term, in [0, 1]; the Dice term gets
            ``1 - bce_weight``. The default 0.5 gives an equal mix.
        smooth: Smoothing passed through to ``DiceLoss``.

    Raises:
        ValueError: if ``bce_weight`` lies outside [0, 1].
    """

    def __init__(self, bce_weight=0.5, smooth=1.0):
        super(BCEDiceLoss, self).__init__()
        if not 0.0 <= bce_weight <= 1.0:
            raise ValueError(f"bce_weight must be in [0, 1], got {bce_weight}")
        self.bce_weight = bce_weight
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss(smooth=smooth)

    def forward(self, pred, target):
        bce_term = self.bce(pred, target)
        dice_term = self.dice(pred, target)
        return self.bce_weight * bce_term + (1.0 - self.bce_weight) * dice_term


loss_fn = BCEDiceLoss()


In [None]:
# LoRA fine-tuning of the Hugging Face SamModel on the window-segmentation dataset.
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    # NOTE(review): the previous target "mask_decoder.layers.0.conv2d" does not appear to
    # name a module of transformers' SamModel. A string target is full-regex-matched by
    # PEFT; this one selects the q/v attention projections inside the mask decoder.
    # Verify with [name for name, _ in model.named_modules()] before training.
    target_modules=r"mask_decoder\..*\.(q_proj|v_proj)",
)

# Re-running this cell would apply LoRA to the already-wrapped model; reload `model` first.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Only the LoRA parameters remain trainable; give just those to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = Adam(trainable_params, lr=LEARNING_RATE)

model.train()
for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    for images, masks in data_loader:
        images = images.to(device)
        masks = masks.to(device)

        # Clear gradients from the previous step; otherwise they accumulate.
        optimizer.zero_grad()
        # TODO(review): SAM's image encoder expects preprocessed (normalised, 1024x1024)
        # pixel_values, e.g. from `processor`; the dataset yields Resize((409, 204)) tensors.
        outputs = model(pixel_values=images, multimask_output=False)
        # Assumes pred_masks is (batch, point_batch, num_masks, h, w) — confirm.
        pred_masks = outputs.pred_masks[:, 0]
        # Bring the predicted logits to the target mask resolution before the loss.
        pred_masks = nn.functional.interpolate(
            pred_masks, size=masks.shape[-2:], mode="bilinear", align_corners=False
        )

        # Compute the loss value (the old code only constructed a loss module here).
        loss = loss_fn(pred_masks, masks)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(f"epoch {epoch + 1}/{NUM_EPOCHS}: mean loss {epoch_loss / max(len(data_loader), 1):.4f}")

In [None]:
## rename
# One-off migration: "mask_<dir>_mask_<rest>.png" -> "mask_<dir>_<rest>.png" for
# dir in {left, right, front, back}.
# NOTE(review): if the image/mask pairing cell relies on these new names, this cell has
# to run before it; in the current notebook order it comes after it.

import os
import re

folder = "./data/training/0000/direction_masks"
# fullmatch (not match): names with trailing text after ".png" (e.g. "...png.bak") are
# left alone instead of being renamed with that suffix silently dropped.
pattern = re.compile(r"(mask_(left|right|front|back))_mask_(.*\.png)")

for filename in sorted(os.listdir(folder)):
    match = pattern.fullmatch(filename)
    if not match:
        continue

    new_name = f"{match.group(1)}_{match.group(3)}"
    old_path = os.path.join(folder, filename)
    new_path = os.path.join(folder, new_name)

    # os.rename replaces an existing destination file on POSIX; never clobber a mask.
    if os.path.exists(new_path):
        print(f"Skipped: {filename} -> {new_name} (target exists)")
        continue

    os.rename(old_path, new_path)
    print(f"Renamed: {filename} -> {new_name}")

print("done")