In [None]:
import gc
import sys
from dataclasses import dataclass
from typing import Callable, Optional, TypedDict

import numpy as np
import segmentation_models_pytorch as smp
import torch
from PIL import Image
from torchvision import transforms as T

sys.path.append("../src")
from dataset import DeepGlobeDataset

In [7]:
@dataclass(frozen=True)
class StandartEvalConfig:
    """Paths and hyper-parameters shared by the evaluation cells.

    The cells read these values straight off the class
    (``cfg = StandartEvalConfig; cfg.epochs``); dataclass defaults remain
    class attributes, so that keeps working. Instantiating the class, e.g.
    ``StandartEvalConfig(epochs=4)``, gives an immutable copy with overrides.
    """

    data_dir: str = "../dataset/train"
    val_size: float = 0.2
    test_size: float = 0.1
    transform: Optional[Callable] = None
    target_transform: Optional[Callable] = None
    batch_size: int = 1
    learning_rate: float = 2e-4  # also part of the checkpoint file names
    epochs: int = 5  # also part of the checkpoint file names
    encoder_name: str = "resnet18"
    encoder_weights: str = "imagenet"
    activation: str = "logsoftmax"
    in_channels: int = 3
    classes: int = 7
    device: str = "mps"
    checkpoints_dir: str = "../src/checkpoints"
    freeze_encoder_layers: int = -2  # freeze encoder excluding 2 last layers

In [16]:
# Sample tiles used for the qualitative comparison, in display order.
SAMPLE_IDS = ["6399", "10901", "855"]
SAMPLE_DIR = "../dataset/train"


def load_image(path):
    """Read an image fully into memory and close its file handle.

    ``Image.open`` is lazy and keeps the file open; copying inside the
    ``with`` block decodes the pixels and lets the handle be released.
    """
    with Image.open(path) as opened:
        return opened.copy()


images = [load_image(f"{SAMPLE_DIR}/{sample_id}_sat.jpg") for sample_id in SAMPLE_IDS]
masks = [load_image(f"{SAMPLE_DIR}/{sample_id}_mask.png") for sample_id in SAMPLE_IDS]

In [17]:
def mask2tensor(mask, resize=False, size=(2464, 2464), class_mapping=None):
    """Convert a colour-coded segmentation mask into a 2-D tensor of class ids.

    Args:
        mask: PIL image or array-like; assumed to be (H, W, 3) RGB (a
            single-channel mask would not match the per-pixel RGB comparison).
        resize: If True, resize the label map to ``size``.
        size: (height, width) used when ``resize`` is True. The default
            2464 = 77 * 32, presumably chosen so encoder strides divide evenly
            (TODO confirm).
        class_mapping: Dict mapping RGB colours to integer labels. Defaults to
            ``DeepGlobeDataset.class_mapping``.

    Returns:
        ``torch.long`` tensor of shape (H, W), or ``size`` when resizing.
        Pixels whose colour is not in the mapping are labelled 0.
    """
    if class_mapping is None:
        class_mapping = DeepGlobeDataset.class_mapping

    mask_array = np.array(mask)
    class_mask = np.zeros(mask_array.shape[:2], dtype=np.uint8)
    for rgb, label in class_mapping.items():
        class_mask[(mask_array == rgb).all(axis=-1)] = label

    class_tensor = torch.from_numpy(class_mask).long()
    if resize:
        # Labels are categorical. Interpolating between them (T.Resize defaults
        # to bilinear) would invent class ids along region boundaries, so use
        # nearest-neighbour.
        class_tensor = torch.nn.functional.interpolate(
            class_tensor[None, None].float(), size=tuple(size), mode="nearest"
        )[0, 0].long()
    return class_tensor

def img2tensor(image, resize=False, size=(2464, 2464)):
    """Convert a PIL image to a float tensor, optionally resized.

    Args:
        image: PIL image (or ndarray) accepted by ``T.ToTensor``.
        resize: If True, resize to ``size``.
        size: (height, width) used when ``resize`` is True. The default
            2464 = 77 * 32, presumably so encoder strides divide evenly
            (TODO confirm).

    Returns:
        The ``T.ToTensor`` output (channel-first), resized when requested.
    """
    image_tensor = T.ToTensor()(image)
    if resize:
        # antialias is passed explicitly because its default differs across
        # torchvision versions (it changed to True in v0.17).
        image_tensor = T.Resize(tuple(size), antialias=True)(image_tensor)
    return image_tensor

# (image tensor, class-index mask) pairs at native resolution, in sample order.
dataset = [
    (img2tensor(sat_image), mask2tensor(rgb_mask))
    for sat_image, rgb_mask in zip(images, masks)
]

In [20]:
cfg = StandartEvalConfig
model_path = f"{cfg.checkpoints_dir}/pspnet_Epochs:{cfg.epochs}_lf:DiceLoss_lr:{cfg.learning_rate}_best.pth"
# Build the architecture only. The strict load_state_dict below overwrites
# every parameter and buffer, so downloading ImageNet weights would be wasted.
model = smp.PSPNet(
    encoder_name=cfg.encoder_name,
    encoder_weights=None,
    in_channels=cfg.in_channels,
    classes=cfg.classes,
    activation=cfg.activation,
)
# The checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
state_dict = torch.load(model_path, map_location=torch.device(cfg.device), weights_only=True)
model.load_state_dict(state_dict)
model.to(cfg.device)
model.eval()

pred_pspnet = []
with torch.no_grad():
    for img, _mask in dataset:
        logits = model(img.unsqueeze(0).to(cfg.device))
        # argmax over the class dimension -> (H, W) label map as numpy
        pred_pspnet.append(torch.argmax(logits, dim=1).squeeze(0).cpu().numpy())

# Free the model and the device-resident state dict before the next model loads.
del model, state_dict
gc.collect()
if cfg.device == "mps":
    torch.mps.empty_cache()

In [23]:
cfg = StandartEvalConfig
# The U-Net checkpoint file name uses 4 epochs rather than cfg.epochs.
UNET_EPOCHS = 4
model_path = f"{cfg.checkpoints_dir}/unet_Epochs:{UNET_EPOCHS}_lf:DiceLoss_lr:{cfg.learning_rate}_best.pth"
# Build the architecture only. The strict load_state_dict below overwrites
# every parameter and buffer, so downloading ImageNet weights would be wasted.
model = smp.Unet(
    encoder_name=cfg.encoder_name,
    encoder_weights=None,
    in_channels=cfg.in_channels,
    classes=cfg.classes,
    activation=cfg.activation,
)
# The checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
state_dict = torch.load(model_path, map_location=torch.device(cfg.device), weights_only=True)
model.load_state_dict(state_dict)
model.to(cfg.device)
model.eval()

pred_unet = []
with torch.no_grad():
    for img, _mask in dataset:
        logits = model(img.unsqueeze(0).to(cfg.device))
        # argmax over the class dimension -> (H, W) label map as numpy
        pred_unet.append(torch.argmax(logits, dim=1).squeeze(0).cpu().numpy())

# Free the model and the device-resident state dict before the next model loads.
del model, state_dict
gc.collect()
if cfg.device == "mps":
    torch.mps.empty_cache()

In [None]:
# FPN is evaluated on inputs resized to 2464x2464 (= 77 * 32), presumably so its
# strides divide the spatial size evenly (TODO confirm). The masks are resized
# as well to keep pairs aligned, although only the images are used for inference.
dataset_fpn = [
    (img2tensor(img, resize=True), mask2tensor(mask, resize=True))
    for img, mask in zip(images, masks)
]
cfg = StandartEvalConfig
model_path = f"{cfg.checkpoints_dir}/fpn_Epochs:{cfg.epochs}_lf:DiceLoss_lr:{cfg.learning_rate}_best.pth"
# Build the architecture only. The strict load_state_dict below overwrites
# every parameter and buffer, so downloading ImageNet weights would be wasted.
model = smp.FPN(
    encoder_name=cfg.encoder_name,
    encoder_weights=None,
    in_channels=cfg.in_channels,
    classes=cfg.classes,
    activation=cfg.activation,
)
# The checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
state_dict = torch.load(model_path, map_location=torch.device(cfg.device), weights_only=True)
model.load_state_dict(state_dict)
model.to(cfg.device)
model.eval()

pred_fpn = []
with torch.no_grad():
    for img, _mask in dataset_fpn:
        logits = model(img.unsqueeze(0).to(cfg.device))
        # argmax over the class dimension -> (H, W) label map as numpy
        pred_fpn.append(torch.argmax(logits, dim=1).squeeze(0).cpu().numpy())

# The model must be deleted before gc/empty_cache can actually release its memory.
del model, state_dict
gc.collect()
if cfg.device == "mps":
    torch.mps.empty_cache()

In [29]:
def plot_combined(idx, save_path=None):
    """Assemble a comparison grid for one sample.

    Each row holds, left to right: the ground-truth mask, the satellite image,
    and one model's predicted mask. Rows are PSPNet, U-Net and FPN in that
    order. Tiles are pasted at their own resolution into cells sized for the
    largest tile, with a 5-pixel gap between cells.

    Args:
        idx: Index into the notebook globals ``images``, ``masks``,
            ``pred_pspnet``, ``pred_unet`` and ``pred_fpn``.
        save_path: Optional output path; missing parent directories are created.

    Returns:
        The combined PIL image.
    """
    from pathlib import Path

    gap = 5
    preds = [pred_pspnet[idx], pred_unet[idx], pred_fpn[idx]]
    gt_tile = masks[idx].convert("RGB")
    sat_tile = images[idx].convert("RGB")

    # Size cells from every tile that is pasted, not just two of the predictions.
    cell_h = max([gt_tile.height, sat_tile.height] + [pred.shape[0] for pred in preds])
    cell_w = max([gt_tile.width, sat_tile.width] + [pred.shape[1] for pred in preds])

    n_cols, n_rows = 3, len(preds)
    combined = Image.new(
        "RGB",
        (n_cols * cell_w + (n_cols - 1) * gap, n_rows * cell_h + (n_rows - 1) * gap),
    )
    for row, pred in enumerate(preds):
        top = row * (cell_h + gap)
        combined.paste(gt_tile, (0, top))
        combined.paste(sat_tile, (cell_w + gap, top))
        combined.paste(DeepGlobeDataset.label_to_rgb_mask(pred), (2 * (cell_w + gap), top))

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        combined.save(save_path)
    return combined

In [None]:
# Sample 0: grid of ground truth | image | prediction, one row per model; also saved to disk.
plot_combined(idx=0, save_path="../results/collected/comparison_0.png")

In [None]:
# Sample 1: grid of ground truth | image | prediction, one row per model; also saved to disk.
plot_combined(idx=1, save_path="../results/collected/comparison_1.png")

In [None]:
# Sample 2: grid of ground truth | image | prediction, one row per model; also saved to disk.
plot_combined(idx=2, save_path="../results/collected/comparison_2.png")