In [1]:
import torch
import numpy as np
import faiss
import glob
import rasterio
from rasterio.windows import Window
from tqdm import tqdm
from transformers import AutoImageProcessor, Dinov2WithRegistersModel

In [2]:
file_list = glob.glob('./examples/**.tif')

centroids_array_path = './centroids_6k.npy'
centroids_arr = np.load(centroids_array_path).astype(np.float32)
cetroid_index = faiss.IndexFlatIP(centroids_arr.shape[1])
cetroid_index.add(centroids_arr)

colormap_array_path = './color_map_rgb_6k.npy'
colormap_arr = np.load(colormap_array_path).astype(np.uint8)

In [3]:

image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
model = Dinov2WithRegistersModel.from_pretrained("facebook/dinov2-with-registers-base")

def process_image(image):
    features_shape = (16, 16)

    h, w, c = image.shape
    inputs = image_processor(image, return_tensors="pt")
    with torch.no_grad():
        output = model(**inputs)
        last_hidden_state = output.last_hidden_state
        features_flat = last_hidden_state[:, 1 + model.config.num_register_tokens:, :]
        features = features_flat.unflatten(1, features_shape)
        features_np = features.detach().cpu().numpy().astype(np.float16)

    eh, ew, ec = features_np.shape[1], features_np.shape[2], features_np.shape[3]
    kmeans_2c = np.zeros([eh, ew, 3], dtype=np.uint8)
    for ih in range(eh):
        for iw in range(ew):
            pxe = features_np[0, ih, iw, :].reshape(1, ec)
            _, kmeans_index = cetroid_index.search(pxe, 1) 
            color_rgb = colormap_arr[kmeans_index[0][0]]
            kmeans_2c[ih, iw] = color_rgb
    return kmeans_2c

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
sample_size = 224
features_size = 16
singlecolor_count_threshold = 768
one_feature_in_pixels = sample_size / features_size
border_in_features = 3
one_border_in_pixels = (border_in_features * one_feature_in_pixels)
borders_in_pixels = one_border_in_pixels * 2
sample_stride_in_pixels = sample_size - borders_in_pixels
sample_stride_in_features = features_size - (2 * border_in_features)


def calulate_singlecolor(sample_image, threshold=1):
    t = torch.from_numpy(np.mean(sample_image, axis=2))
    th, tw = t.shape
    ts = torch.nn.functional.pixel_unshuffle(t.reshape(1, 1, th, tw), 7)
    ts_min = torch.min(ts, dim=1).values
    ts_max = torch.max(ts, dim=1).values
    ts_diff = ts_max - ts_min
    ts_diff.shape
    singlecolor_count = int(torch.sum(ts_diff[0,:,:] < threshold))
    return singlecolor_count


def process_tif_to_tif(src_fp, dst_fp):
    with rasterio.open(src_fp) as src:
        profile = src.profile
        src_crs = src.crs
        height_pixels, width_pixels = src.height, src.width
    
        h1_list_in_pixels = []
        h1 = 0
        while True:
            if h1 + sample_size > height_pixels:
                break
            h1_list_in_pixels.append(h1)
            h1 += sample_stride_in_pixels
        b = ((h1 + sample_size) - height_pixels) // 2
        h1_list_in_pixels = [i + b for i in h1_list_in_pixels]
        
        w1_list_in_pixels = []
        w1 = 0
        while True:
            if w1 + sample_size > width_pixels:
                break
            w1_list_in_pixels.append(w1)
            w1 += sample_stride_in_pixels
        b = ((w1 + sample_size) - width_pixels) // 2
        w1_list_in_pixels = [i + b for i in w1_list_in_pixels]
        
        
        h1w1_list_in_pixels = []
        for h1 in h1_list_in_pixels:
            for w1 in w1_list_in_pixels:
                h1w1_list_in_pixels.append((h1, w1))
        
        top_pixel = h1_list_in_pixels[0] + one_border_in_pixels
        bottom_pixel = (h1_list_in_pixels[-1] + sample_size) - one_border_in_pixels
        
        left_pixel = w1_list_in_pixels[0] + one_border_in_pixels
        right_pixel = (w1_list_in_pixels[-1] + sample_size) - one_border_in_pixels
        
        dst_west, dst_north = src.xy(top_pixel, left_pixel)
        dst_east, dst_south = src.xy(bottom_pixel, right_pixel)
        
        dst_height, dst_width = (bottom_pixel - top_pixel) / one_feature_in_pixels, (right_pixel - left_pixel) / one_feature_in_pixels
        dst_transform = rasterio.transform.from_bounds(dst_west, dst_south, dst_east, dst_north, dst_width, dst_height)

        profile = {
            'driver': 'GTiff',
            'dtype': 'uint8',
            'count': 3,
            'height': dst_height,
            'width': dst_width,
            'compress': 'LZW',
            'crs':src_crs,
            'transform':dst_transform
        }
    
        with rasterio.open(dst_fp, 'w', **profile) as dst:
            for h1, w1 in tqdm(h1w1_list_in_pixels):
                sample = src.read(window=((h1, h1 + sample_size), (w1, w1 + sample_size)), boundless=True, fill_value=0)
                sample = np.transpose(sample[:3, :], (1, 2, 0))
                singlecolor_count = calulate_singlecolor(sample, threshold=2)
                if singlecolor_count > singlecolor_count_threshold:
                    continue
    
                fetures_image = process_image(sample)

                x, y = src.xy(h1 + one_border_in_pixels, w1 + one_border_in_pixels)
                dst_h1, dst_w1 = dst.index(x, y)
                window = Window(col_off=dst_w1, row_off=dst_h1, width=sample_stride_in_features, height=sample_stride_in_features)
                for b in range(3):
                    dst.write(fetures_image[border_in_features:-border_in_features,border_in_features:-border_in_features,b], b+1, window=window)

In [5]:
for fp in file_list:
    print(fp)
    process_tif_to_tif(fp, './i_' + fp.split('/')[-1])
print('Done!')

./examples/65dd2b5c7175970001f718ba.tif


100%|███████████████████████████████████████| 1025/1025 [03:23<00:00,  5.04it/s]


./examples/64a71afc64adbc00012e09b6.tif


100%|█████████████████████████████████████████| 728/728 [04:07<00:00,  2.94it/s]


./examples/62e90217140bec00067db70d.tif


100%|███████████████████████████████████████| 2601/2601 [07:56<00:00,  5.46it/s]

Done!



