In [1]:
import os
import torch
import torchvision
import numpy as np
from PIL import Image
from UnimatchV2_LULC.model.semseg.dpt import DPT
from utils.utils import unpatchify, decode_segmap, LABELS, COLORS

xFormers not available
xFormers not available


In [4]:
def make_patches(image: np.ndarray, patch_size: int) -> tuple[np.ndarray, tuple[int, int]]:
    """
    Splits a PIL image into patches of size (patch_size, patch_size),
    zero-padding the remaining parts if needed.

    Args:
        image (PIL.Image.Image): Input image.
        patch_size (int): Size of each patch (square).

    Returns:
        numpy.ndarray: Array of patches with shape 
        (num_patches_vertical, num_patches_horizontal, 1, patch_size, patch_size, 3)
    """
    # image = image.convert('RGB')
    image_size = image.shape[1], image.shape[0]
    # image_np = np.array(image)
    image_np = image
    h, w, c = image_np.shape

    num_patches_vertical = (h + patch_size - 1) // patch_size
    num_patches_horizontal = (w + patch_size - 1) // patch_size

    patches = torch.zeros(
        (num_patches_vertical, num_patches_horizontal, 1, patch_size, patch_size, c)
    )

    for i in range(num_patches_vertical):
        for j in range(num_patches_horizontal):
            y_start = i * patch_size
            x_start = j * patch_size
            patch = image_np[y_start:y_start +
                             patch_size, x_start:x_start + patch_size]

            # Handle padding if needed
            padded_patch = torch.zeros(
                (patch_size, patch_size, c))
            padded_patch[:patch.shape[0], :patch.shape[1], :] = patch

            patches[i, j, 0] = padded_patch

    return patches, image_size

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DPT(
    **{'encoder_size': 'base', 'features': 128, 'out_channels': [96, 192, 384, 768],
       'nclass': 6})
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.to(device)

unimatch_path = '/opt/models/exp/unimatchv2_0.pth'
checkpoint = torch.load(
    unimatch_path, map_location='cpu', weights_only=False)
new_state_dict = {}
for k, v in checkpoint['model'].items():
    new_key = k.replace('module.', '')
    new_state_dict[new_key] = v
model.load_state_dict(new_state_dict)


def predict(image: np.ndarray, patch_size: int = 518) -> tuple[Image.Image, Image.Image, list]:
    image = image[:, :, :3]
    original_image = image.copy()
    
    image = torchvision.transforms.functional.to_tensor(
    image)
    image = torchvision.transforms.functional.normalize(
    image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).permute((1, 2, 0))

    patch_images, image_size = make_patches(image, patch_size)

    size_y, size_x, _, p_s_1, p_s_2, channels = patch_images.shape
    patch_images = patch_images.reshape(
        size_x * size_y, p_s_1, p_s_2, channels).permute((0, 3, 1, 2)).cuda()

    output_images = []
    count_array = np.zeros(6, dtype=np.int_)

    for image in patch_images:
        model.eval()

        # image = torchvision.transforms.functional.to_tensor(
        #     image).to(device)
        # image = torchvision.transforms.functional.normalize(
        #     image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).reshape(1, channels, patch_size, patch_size)
        image = image.reshape(1, channels, patch_size, patch_size)

        image = image.to(dtype=torch.float32)

        output = model(image)
        output = output.detach().max(dim=1)[1].cpu().numpy().squeeze(axis=0)

        unique, counter = np.unique(output, return_counts=True)
        count_temp = np.zeros(6, dtype=np.int_)
        count_temp[unique] = counter
        count_array += count_temp

        output = decode_segmap(output)
        output = Image.fromarray(output)
        output_images.append(output)

    output_images = np.stack(output_images, axis=0).reshape(
        size_y, size_x, 1, p_s_1, p_s_2, channels)

    output_image = unpatchify(output_images, image_size)
    output_image = Image.fromarray(output_image)

    labels = LABELS.get('lulc')
    colors = [str(color) for color in COLORS.get('lulc')[1:]]
    area = [f'{val * 4.92e-6:,.2f}' for val in count_array[1:]]
    max_pixel = np.sum(count_array[1:]) or 1
    count_array = [f'{val / max_pixel * 100:.2f}%' for val in count_array[1:]]
    table = list(zip(labels, list(count_array), area, colors))

    torch.cuda.empty_cache()

    return Image.fromarray(original_image), output_image, table


In [16]:
img = Image.open('/home/skeptic/web-lulc/trash/full_image.png')
img = torchvision.transforms.functional.to_tensor(
    img)
img = torchvision.transforms.functional.normalize(
    img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).permute((1, 2, 0))

# patched_images = make_patches(img, 512)[0]
pred = predict(np.asarray(img))[1]
# # pred
# patched_images
pred

TypeError: Cannot handle this data type: (1, 1, 3), <f4