In [1]:
import glob
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from matplotlib.animation import FuncAnimation
from rasterio.warp import Resampling, reproject

In [38]:
# Root directory holding the test tiles together with their *mask*.tif files.
mask_path = '/scratch2/ziyliu/pro_data/sate_dataset_V4/test_test/'
# The input images sit next to their masks, so they share the same root.
img_path = mask_path
# Directory with the inpainted outputs of the model being evaluated.
predict_path = (
    '/scratch2/ziyliu/LAMA/lama/sate_dataset_V4_output/'
    'trained_test_inet256_sate_test-pretrained-inet256_model210000/inpainted/'
)

In [39]:
# find all img, mask, output pairs and visualize them
mask_names = sorted(list(glob.glob(os.path.join(mask_path, '**', '*mask*.tif'), recursive=True)))

# mask_names = ['/scratch2/ziyliu/pro_data/sate_dataset_V4/test_test/ZA_A001_20APR09081310-P1BS-014905565010_01_P003_test_mask.tif']
img_names = [fname.rsplit('_mask', 1)[0]+'.tif' for fname in mask_names]
#replace test_img with test_img_output
output_names = [os.path.join(predict_path, os.path.splitext(fname[len(mask_path):])[0])+'.tif'  for fname in img_names]
# output_names = [os.path.splitext(fname[len(img_path)+1:])[0]+'.tif'  for fname in img_names]

model_name = 'inet210000'

In [2]:
# Single-sample override: replaces the glob-based file lists built in the
# previous cell with one BRZ test tile and its prediction from
# test_output_baseline.
mask_names = [
    '/scratch2/ziyliu/LAMA/lama/BRZ_test/2_7723282967316973_014886554010_01_P002_mask.tif',
]
img_names = [
    '/scratch2/ziyliu/LAMA/lama/BRZ_test/2_7723282967316973_014886554010_01_P002.tif',
]
output_names = [
    '/scratch2/ziyliu/LAMA/lama/test_output_baseline/2_7723282967316973_014886554010_01_P002_mask.tif',
]

# Label shown in the figure titles.
# Alternatives used for other runs: 'original', 'fine-tuned (lr:1e-5)'.
model_name = 'opencv'

In [None]:
# Visualise every (image, mask, inpainted output) triple.
# The ground-truth image and the prediction are panchromatic grayscale rasters
# whose nodata value is 2**16 - 1. Nodata pixels are merged into the occlusion
# mask, the mask is dilated by DILATION_RADIUS pixels, and the outlines of the
# dilated regions are drawn into both the image and the prediction. For each
# region that passes the size/nodata filters, a 2x3 figure shows the full scene
# (top row) and an ROI crop around the region (bottom row).
NODATA_VALUE = 2**16 - 1
DILATION_RADIUS = 3            # pixels of growth on every side of the mask
MIN_CONTOUR_AREA = 2000        # other thresholds tried: 2600, 4000
ROI_PAD_Y, ROI_PAD_X = 20, 60  # padding around each region's bounding box
MAX_NAN_IN_ROI = 20            # skip crops with more nodata pixels than this
MIN_MASKED_IN_ROI = 1000       # skip crops with fewer occluded pixels than this
MAX_FIGURES = 21               # stop after this many figures

result_num = 0
for mask_name, img_name, output_name in zip(mask_names, img_names, output_names):
    raw_image = cv2.imread(img_name, cv2.IMREAD_UNCHANGED)
    mask_ = cv2.imread(mask_name, cv2.IMREAD_UNCHANGED)
    output = cv2.imread(output_name, cv2.IMREAD_UNCHANGED)
    # cv2.imread reports a missing/unreadable file by returning None.
    for path, arr in ((img_name, raw_image), (mask_name, mask_), (output_name, output)):
        if arr is None:
            raise FileNotFoundError(f'cv2.imread could not read {path}')

    # Nodata -> NaN, so it can be counted per ROI and is not shown as a bright value.
    nodata_mask = raw_image == NODATA_VALUE
    image = raw_image.astype(np.float32)
    image[nodata_mask] = np.nan

    # Binary 0/1 uint8 mask: cv2.findContours needs 8-bit single-channel input,
    # and a binary mask keeps the per-ROI count a pixel count (e.g. for 0/255 masks).
    mask = (nodata_mask | (mask_ > 0)).astype(np.uint8)

    # Display range from the untouched prediction; the outlines drawn below have
    # value 0 and would otherwise pin vmin to 0.
    vmin = np.min(output)
    vmax = np.max(output)

    # Odd kernel size -> symmetric growth of DILATION_RADIUS around the anchor.
    kernel_size = 2 * DILATION_RADIUS + 1
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    mask_dilation = cv2.dilate(mask, kernel, iterations=1)

    contours, _ = cv2.findContours(mask_dilation, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if contours:
        # Grayscale images: only the first colour component (0) is used by OpenCV.
        cv2.drawContours(image, contours, -1, (0, 0, 255), 1)
        cv2.drawContours(output, contours, -1, (0, 0, 255), 1)

    for contour in contours:
        if cv2.contourArea(contour) < MIN_CONTOUR_AREA:
            continue

        x, y, w, h = cv2.boundingRect(contour)
        # Built-in max/min keep plain Python ints, which cv2.rectangle accepts as points.
        a = max(0, y - ROI_PAD_Y)
        b = min(y + h + ROI_PAD_Y, image.shape[0])
        c = max(0, x - ROI_PAD_X)
        d = min(x + w + ROI_PAD_X, image.shape[1])
        roi_image = image[a:b, c:d]
        roi_output = output[a:b, c:d]
        roi_mask = mask[a:b, c:d]

        # Too many nodata (NaN) pixels in the image crop: do not show it.
        if np.count_nonzero(np.isnan(roi_image)) > MAX_NAN_IN_ROI:
            continue
        # Too few occluded pixels in the mask crop: do not show it.
        if np.count_nonzero(roi_mask) < MIN_MASKED_IN_ROI:
            continue

        # Full image / mask / output on the top row, the matching ROI crops below.
        fig, ax = plt.subplots(2, 3, figsize=(15, 10))
        # Mark the ROI on a copy of the full image (grayscale -> drawn with value 0).
        image_show = cv2.rectangle(image.copy(), (c, a), (d, b), (0, 255, 0), 10)

        fig.suptitle(f'Inputs and outputs for {model_name} model')
        # (axis, data, title, use the prediction's vmin/vmax?)
        panels = [
            (ax[0, 0], image_show, 'Input: ground truth image', True),
            (ax[0, 1], mask, 'Input: occlusion mask', False),
            (ax[0, 2], output, f'Output: inpainted image: {model_name} model', True),
            (ax[1, 0], roi_image, 'ROI crop of gt image', True),
            (ax[1, 1], roi_mask, 'ROI crop of occlusion mask', False),
            (ax[1, 2], roi_output, f'ROI crop of inpainted image: {model_name} model', True),
        ]
        for axis, data, title, scaled in panels:
            if scaled:
                axis.imshow(data, cmap='gray', vmin=vmin, vmax=vmax)
            else:
                axis.imshow(data, cmap='gray')
            axis.set_title(title)
            axis.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close(fig)

        result_num += 1
        if result_num >= MAX_FIGURES:
            break
    if result_num >= MAX_FIGURES:
        break

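## GIF version

Same region selection as above, but the bottom-left ROI panel alternates between the ground-truth crop and the inpainted crop, and each figure is saved as an animated GIF.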

In [None]:
## plot the gif
# Same triples and region selection as the static figures. The bottom-left
# panel is animated (gt crop, gt crop, inpainted crop, inpainted crop), and each
# figure is saved as GIF_DIR/loop_animation_<n>.gif.
NODATA_VALUE = 2**16 - 1
DILATION_RADIUS = 3        # pixels of growth on every side of the mask
MIN_CONTOUR_AREA = 2000    # other thresholds tried: 2600, 4000
ROI_PAD = 20               # padding around each region's bounding box
MAX_NAN_IN_ROI = 20        # skip crops with more nodata pixels than this
MIN_MASKED_IN_ROI = 1000   # skip crops with fewer occluded pixels than this
MAX_FIGURES = 21           # stop after this many animations
GIF_DIR = 'gif_result_BRZ'

# ani.save does not create missing directories.
os.makedirs(GIF_DIR, exist_ok=True)


def update(frame):
    """FuncAnimation callback: show frame `frame` of the current `images`/`titles`.

    Reads the module-level `im`, `title_text`, `images` and `titles` set up in
    the loop below, and returns the artists it changed (required for blit=True).
    """
    im.set_array(images[frame])
    title_text.set_text(titles[frame])
    return im, title_text


result_num = 0
for mask_name, img_name, output_name in zip(mask_names, img_names, output_names):
    raw_image = cv2.imread(img_name, cv2.IMREAD_UNCHANGED)
    mask_ = cv2.imread(mask_name, cv2.IMREAD_UNCHANGED)
    output = cv2.imread(output_name, cv2.IMREAD_UNCHANGED)
    # cv2.imread reports a missing/unreadable file by returning None.
    for path, arr in ((img_name, raw_image), (mask_name, mask_), (output_name, output)):
        if arr is None:
            raise FileNotFoundError(f'cv2.imread could not read {path}')

    # Nodata -> NaN, so it can be counted per ROI and is not shown as a bright value.
    nodata_mask = raw_image == NODATA_VALUE
    image = raw_image.astype(np.float32)
    image[nodata_mask] = np.nan

    # Binary 0/1 uint8 mask: cv2.findContours needs 8-bit single-channel input,
    # and a binary mask keeps the per-ROI count a pixel count (e.g. for 0/255 masks).
    mask = (nodata_mask | (mask_ > 0)).astype(np.uint8)

    # Display range from the untouched prediction; the outlines drawn below have
    # value 0 and would otherwise pin vmin to 0.
    vmin = np.min(output)
    vmax = np.max(output)

    # Odd kernel size -> symmetric growth of DILATION_RADIUS around the anchor.
    kernel_size = 2 * DILATION_RADIUS + 1
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    mask_dilation = cv2.dilate(mask, kernel, iterations=1)

    contours, _ = cv2.findContours(mask_dilation, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if contours:
        # Grayscale images: only the first colour component (0) is used by OpenCV.
        cv2.drawContours(image, contours, -1, (0, 0, 255), 1)
        cv2.drawContours(output, contours, -1, (0, 0, 255), 1)

    for contour in contours:
        if cv2.contourArea(contour) < MIN_CONTOUR_AREA:
            continue

        x, y, w, h = cv2.boundingRect(contour)
        # Built-in max/min keep plain Python ints, which cv2.rectangle accepts as points.
        a = max(0, y - ROI_PAD)
        b = min(y + h + ROI_PAD, image.shape[0])
        c = max(0, x - ROI_PAD)
        d = min(x + w + ROI_PAD, image.shape[1])
        roi_image = image[a:b, c:d]
        roi_output = output[a:b, c:d]
        roi_mask = mask[a:b, c:d]

        # Too many nodata (NaN) pixels in the image crop: do not show it.
        if np.count_nonzero(np.isnan(roi_image)) > MAX_NAN_IN_ROI:
            continue
        # Too few occluded pixels in the mask crop: do not show it.
        if np.count_nonzero(roi_mask) < MIN_MASKED_IN_ROI:
            continue

        # Each crop is held for two frames: gt, gt, inpainted, inpainted.
        images = [roi_image, roi_image, roi_output, roi_output]
        titles = ['ROI crop of gt image', 'ROI crop of gt image', 'ROI crop of inpainted image', 'ROI crop of inpainted image']

        # Full image / mask / output on the top row, ROI crops below; ax[1, 0] is animated.
        fig, ax = plt.subplots(2, 3, figsize=(15, 10))
        # Mark the ROI on a copy of the full image (grayscale -> drawn with value 0).
        image_show = cv2.rectangle(image.copy(), (c, a), (d, b), (0, 255, 0), 10)

        im = ax[1, 0].imshow(images[0], cmap='gray', vmin=vmin, vmax=vmax)
        title_text = ax[1, 0].set_title(titles[0])
        ax[1, 0].axis('off')

        fig.suptitle(f'Inputs and outputs for {model_name} model')
        # (axis, data, title, use the prediction's vmin/vmax?)
        panels = [
            (ax[0, 0], image_show, 'Input: ground truth image', True),
            (ax[0, 1], mask, 'Input: occlusion mask', False),
            (ax[0, 2], output, f'Output: inpainted image: {model_name} model', True),
            (ax[1, 1], roi_mask, 'ROI crop of occlusion mask', False),
            (ax[1, 2], roi_output, f'ROI crop of inpainted image: {model_name} model', True),
        ]
        for axis, data, title, scaled in panels:
            if scaled:
                axis.imshow(data, cmap='gray', vmin=vmin, vmax=vmax)
            else:
                axis.imshow(data, cmap='gray')
            axis.set_title(title)
            axis.axis('off')
        plt.tight_layout()

        ani = FuncAnimation(fig, update, frames=range(len(images)), repeat=True, interval=4000, blit=True)
        # The 'pillow' writer ships with matplotlib; 'imagemagick' needs an external binary.
        ani.save(os.path.join(GIF_DIR, f'loop_animation_{result_num}.gif'), writer='pillow', fps=1,
                 savefig_kwargs={'facecolor': 'white'})

        plt.show()
        plt.close(fig)

        result_num += 1
        if result_num >= MAX_FIGURES:
            break
    if result_num >= MAX_FIGURES:
        break

In [13]:
# Build two single-band GeoTIFFs on the grid of one training image:
#   no_data_mask.tif     -- 1 where the image equals the nodata value 2**16 - 1, else 0
#   reprojected_mask.tif -- the occlusion mask resampled (nearest) onto the image grid,
#                           binarised (> 0 -> 1) and OR-ed with the nodata mask
NODATA_VALUE = 2**16 - 1
# Named img_file (not `dir`) so the builtin dir() is not shadowed in the kernel.
img_file = '/scratch2/ziyliu/pro_data/sate_dataset_V2/train/NZ_A001/image/21FEB11013741-P1BS-014422115010_01_P002.tif'
mask_dir = '/scratch2/ziyliu/pro_data/sate_dataset_V2/train/NZ_A001/mask/13MAR24221955-P1BS-014418373010_01_P005_mask.tif'

# Read the image once and keep the georeferencing needed for the reprojection.
with rasterio.open(img_file) as src_img:
    img_array = src_img.read(1).astype('float32')
    img_crs = src_img.crs
    img_transform = src_img.transform
    img_width = src_img.width
    img_height = src_img.height
    profile = src_img.profile.copy()

# 2-D (H, W) 0/1 nodata mask, written as one uint16 band regardless of the
# image's own band count or dtype.
no_data = (img_array == NODATA_VALUE).astype('uint16')
profile.update(count=1, dtype='uint16')
with rasterio.open('no_data_mask.tif', 'w', **profile) as dst:
    dst.write(no_data, 1)

# Reproject the mask onto the image grid (nearest neighbour keeps labels discrete).
with rasterio.open(mask_dir) as src_mask:
    kwargs = src_mask.meta.copy()
    kwargs.update({
        'crs': img_crs,
        'transform': img_transform,
        'width': img_width,
        'height': img_height,
        'count': 1,
        'dtype': 'uint16',  # matches mask_array written below
        'nodata': src_mask.nodata,
    })
    # Zero-initialised so pixels outside the mask's footprint read as "not
    # occluded"; np.empty would leave arbitrary values there.
    # NOTE(review): if the mask file declares a positive nodata value, reproject
    # may fill uncovered pixels with it, and they would then count as occluded.
    projected_mask = np.zeros((img_height, img_width), dtype='float32')
    # dst_transform plus the destination shape fully define the output grid.
    reproject(
        source=rasterio.band(src_mask, 1),
        destination=projected_mask,
        src_transform=src_mask.transform,
        src_crs=src_mask.crs,
        dst_transform=img_transform,
        dst_crs=img_crs,
        resampling=Resampling.nearest)

# Any positive value counts as occluded; image nodata pixels are occluded as well.
mask_array = (projected_mask > 0).astype('uint16') | no_data
with rasterio.open('reprojected_mask.tif', 'w', **kwargs) as dst:
    dst.write(mask_array, 1)
