In [None]:
# Environment setup: install the segmentation library, then clone the repo's
# `inference` branch and work from inside it (the checkpoint and sample-image
# paths used later are relative, presumably to this checkout).
# NOTE(review): `gradio` is imported in the last cell but not installed here —
# confirm the runtime provides it.
# NOTE(review): not re-runnable as-is — a second run fails on `git clone`
# (directory exists) and `%cd` from inside the repo.
! pip install segmentation_models_pytorch -q
! git clone -b inference https://github.com/DavidFM43/landcover-segmentation.git
%cd landcover-segmentation

In [None]:
# Run configuration. Only `downsize_res`, `model_architecture` and `model_config`
# are read by the inference cells here; batch_size/epochs/lr are presumably
# the training settings the checkpoint was produced with.
config = dict(
    downsize_res=512,           # passed to transforms.Resize before the forward pass
    batch_size=6,
    epochs=30,
    lr=3e-4,
    model_architecture="Unet",  # looked up as an attribute of segmentation_models_pytorch
    model_config=dict(          # keyword arguments for that architecture's constructor
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=3,
        classes=7,
    ),
)

In [None]:
import torch 
from torchvision import transforms
import segmentation_models_pytorch as smp

device = "cuda" if torch.cuda.is_available() else "cpu"

# Rebuild the network named in `config` (smp.<model_architecture>) and restore
# the trained weights, mapping tensors onto the selected device.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
cp_path = "checkpoints/CP_epoch20.pth"
model_architecture = getattr(smp, config["model_architecture"])
model = model_architecture(**config["model_config"])
model.load_state_dict(
    torch.load(cp_path, map_location=torch.device(device))
)
model.to(device).eval();

# Preprocessing: PIL image -> float tensor (ToTensor), and a resize whose int
# size sets the shorter side to `downsize_res`.
# No normalization is applied. The statistics that were once considered are
# mean = [0.4085, 0.3798, 0.2822], std = [0.1410, 0.1051, 0.0927].
# NOTE(review): smp's Unet expects H and W divisible by 32; a non-square input
# may violate this after an int-sized Resize — confirm for your images.
downsize_res = config["downsize_res"]
downsize_t = transforms.Resize(downsize_res, antialias=True)
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
import os 
from PIL import Image

images_dir = "data/sample_sat_images"
# os.listdir order is arbitrary. Sorting makes the chosen sample (and the Gradio
# examples built from image_ids) stable across runs. Hidden entries such as
# .ipynb_checkpoints are skipped because they are not images.
image_ids = sorted(
    name for name in os.listdir(images_dir) if not name.startswith(".")
)
if not image_ids:
    raise FileNotFoundError(f"No images found in {images_dir!r}")
sample_id = image_ids[0]
image_path = f"{images_dir}/{sample_id}"
# Force 3 channels: the model is built with in_channels=3, while e.g. PNGs may
# be RGBA or palette images.
sat_img = Image.open(image_path).convert("RGB")
# Downscaled preview for display only; the model input is built from sat_img.
sat_img.resize((512, 512))


In [None]:
# Batched float tensor (1, C, H, W) on the inference device at full resolution,
# plus the downsized copy that is actually fed to the network.
X = transform(sat_img).unsqueeze(0).to(device)
X_down = downsize_t(X)

In [None]:
%%time
# forward pass
logits = model(X_down)
preds = torch.argmax(logits, 1).detach()
# resize to evaluate with the original image
preds = transforms.functional.resize(preds, X.shape[-2:], antialias=True)        

In [None]:
from torchvision.utils import draw_segmentation_masks
from torchvision.io import read_image
import torch.nn.functional as F

def label_to_onehot(mask, num_classes):
    """Convert an integer label map into boolean per-class masks.

    Args:
        mask: integer tensor of class ids, either (H, W) or (N, H, W).
        num_classes: number of classes, i.e. the size of the one-hot axis.

    Returns:
        Bool tensor with the class axis placed before the spatial axes:
        (num_classes, H, W) for a 2-D mask, (N, num_classes, H, W) otherwise.
    """
    onehot = F.one_hot(mask.long(), num_classes=num_classes).bool()
    # one_hot appends the class axis last; move it in front of H and W.
    if mask.ndim == 2:
        return onehot.permute(2, 0, 1)
    return onehot.permute(0, 3, 1, 2)

from torchvision.io import ImageReadMode

# uint8 image tensor to draw on. RGB mode matches the 3-channel model input;
# draw_segmentation_masks rejects non-RGB (e.g. RGBA) images.
sat_img = read_image(image_path, mode=ImageReadMode.RGB)
# One boolean mask per class. They are placed on the image's device because
# drawing indexes the image with them.
masks = label_to_onehot(preds.squeeze(), config["model_config"]["classes"])
masks = masks.to(sat_img.device)


In [None]:
# One RGB color per class id: colors[i] is paired with masks[i], the one-hot
# channel for class i. This is presumably the DeepGlobe land-cover palette.
# NOTE(review): confirm the class order against the training labels.
class_rgb_colors = [
    (0, 255, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 255),
    (0, 0, 0),
]

# Blend the class masks onto the satellite image with alpha 0.2.
mask_over_image = draw_segmentation_masks(
    sat_img,
    masks=masks,
    alpha=0.2,
    colors=class_rgb_colors,
)


In [None]:
# Inspect the overlay tensor's shape (compare with sat_img.shape).
mask_over_image.shape

In [None]:
# Render the overlay tensor as a PIL image for inline display.
transforms.functional.to_pil_image(mask_over_image)


In [None]:
def get_overlay(sat_img, preds, alpha, colors=None, num_classes=None):
    """Draw per-class segmentation colors on top of an image.

    Args:
        sat_img: uint8 image tensor to draw on, in the layout
            ``draw_segmentation_masks`` accepts (3, H, W).
        preds: integer label map at the image's resolution. Singleton dims are
            squeezed, so (1, H, W) and (H, W) both work.
        alpha: mask opacity passed to ``draw_segmentation_masks``. 1 paints
            solid class colors; smaller values blend with the image.
        colors: one RGB tuple per class id (index i colors class i). Defaults to
            the notebook's 7-class palette.
        num_classes: number of one-hot channels. Defaults to ``len(colors)``.

    Returns:
        The image tensor with the class colors drawn on it.
    """
    if colors is None:
        colors = [(0, 255, 255), (255, 255, 0), (255, 0, 255), (0, 255, 0),
                  (0, 0, 255), (255, 255, 255), (0, 0, 0)]
    if num_classes is None:
        num_classes = len(colors)
    masks = label_to_onehot(preds.squeeze(), num_classes)
    # draw_segmentation_masks indexes the image with the masks, so both must be
    # on the same device (preds may come straight off the GPU).
    masks = masks.to(sat_img.device)
    # A list, because a bare tuple of tuples could be read as a single color.
    return draw_segmentation_masks(
        sat_img, masks=masks, alpha=alpha, colors=list(colors)
    )

def _load_rgb(image):
    """Return ``image`` (numpy array, PIL image, or file path) as an RGB PIL image."""
    if image is None:
        raise ValueError("No image provided to segment.")
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, (str, os.PathLike)):
        with Image.open(image) as img:
            return img.convert("RGB")
    # Otherwise assume an array such as the one gr.Image(type="numpy") delivers.
    return Image.fromarray(image).convert("RGB")


def segment(sat_image):
    """Segment ``sat_image`` and return its class-color mask and an overlay.

    Args:
        sat_image: the image to segment. This can be a numpy array (what
            Gradio's Image component passes), a PIL image, or a file path.

    Returns:
        ``(raw_masks, overlay)`` as (H, W, C) uint8 numpy arrays at the input's
        resolution. The first holds solid class colors; the second blends the
        class colors onto the image with alpha 0.2.

    Raises:
        ValueError: if ``sat_image`` is None (e.g. submitted without an image).
    """
    pil_img = _load_rgb(sat_image)
    # uint8 (3, H, W) tensor of the same pixels, used as the drawing canvas.
    canvas = transforms.functional.pil_to_tensor(pil_img)

    # Same preprocessing as the interactive cells: float tensor, batch dim, downsize.
    X = transform(pil_img).unsqueeze(0).to(device)
    X_down = downsize_t(X)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        logits = model(X_down)
    # Class ids on CPU so they can be combined with the CPU canvas and turned into numpy.
    preds = torch.argmax(logits, 1).cpu()
    # Back to the input resolution. NEAREST keeps class ids intact, whereas
    # bilinear would blend neighbouring ids into wrong classes.
    preds = transforms.functional.resize(
        preds, list(X.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
    )

    overlay = get_overlay(canvas, preds, 0.2)
    raw_masks = get_overlay(torch.zeros_like(canvas), preds, 1)

    # (C, H, W) -> (H, W, C) for numpy image consumers such as Gradio.
    return raw_masks.permute(1, 2, 0).numpy(), overlay.permute(1, 2, 0).numpy()

In [None]:
import gradio as gr

# Gradio demo: one image in; two images out, in segment()'s return order
# (solid class mask, overlay).
# gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4, while the
# top-level gr.Image works in both. type="numpy" passes the image as an array.
i = gr.Image(type="numpy", label="Satellite image")
o = [gr.Image(label="Predicted classes"), gr.Image(label="Overlay")]

# Every sample image doubles as a clickable example.
examples = [f"{images_dir}/{image_id}" for image_id in image_ids]
title = "Satellite Images Landcover Segmentation"
description = "Upload an image or select from examples to segment"

# share=True requests a public link; debug=True keeps the cell attached to
# the running app so errors are printed.
gr.Interface(
    segment, i, o, examples=examples, title=title, description=description
).launch(share=True, debug=True)