<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [None]:
import multiprocessing as mp
import os

import cv2
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.patches import Patch
from shapely import geometry
from tqdm import tqdm
# NOTE(review): scikit-image >= 0.19 no longer provides `watershed` in
# skimage.morphology (it lives in skimage.segmentation) — pin skimage or migrate.
from skimage.morphology import h_minima, watershed, label

from dh_segment_torch.inference import InferenceModel

In [None]:
# Root of the data share; checkpoints and image splits live underneath it.
DATA = '/dhlabdata4/benali/'
MODELS = DATA + 'models/1848/'
IMAGES = DATA + 'cadaster_1848_test/splits/'

# Image to vectorize and the checkpoint file stems (without extension).
IMG_NAME = 'cannaregio_07-11_crop05'
MODEL_EDGES_NAME = 'model_edges2_n02'
MODEL_CLASSES_NAME = 'model_classes_inv_n01'
# Stem of a single model trained jointly on classes and edges; only used to
# name the GeoJSON export when TYPE == 'full'. Must be set before using 'full'.
MODEL_FULL_NAME = None

IMG_FORMAT = '.tif'
# 'sep'  -> classes and edges come from two separately trained models.
# 'full' -> one model trained for both classes and edges (set MODEL_FULL_NAME).
TYPE = 'sep'

IMG = IMAGES + IMG_NAME + IMG_FORMAT
EDGES = MODELS + MODEL_EDGES_NAME + '.pth'
# Alternative edge checkpoint from the 1808 models:
#   DATA + 'models/1808/' + MODEL_EDGES_NAME + '.pth'
CLASSES = MODELS + MODEL_CLASSES_NAME + '.pth'

# Preferred GPU; fall back to CPU when this machine exposes fewer CUDA devices
# (device_count() is 0 when CUDA is unavailable), instead of crashing at load.
GPU_INDEX = 4
DEVICE = f'cuda:{GPU_INDEX}' if torch.cuda.device_count() > GPU_INDEX else 'cpu'
SAVE_GEOJSON = False

In [None]:
def make_inference_params(state_dict_path, num_classes, patches_overlap,
                          patch_size=(500, 500), patches_batch_size=8):
    """Build the `InferenceModel.from_params` configuration for one checkpoint.

    Both models share the same architecture (copied from the training config);
    a fresh dict is built on every call so that a consumer that pops keys from
    it cannot affect the other model's configuration.

    Parameters
    ----------
    state_dict_path : str
        Path to the `.pth` weights.
    num_classes : int
        Number of output channels of the checkpoint (to be read from its
        training config).
    patches_overlap : float
        Overlap between neighbouring patches, between 0 and 1.
    patch_size : tuple of int
        Patch size used for tiled inference; adapt to the image / GPU memory.
    patches_batch_size : int
        Number of patches predicted per batch.
    """
    return {
        "model": {
            "encoder": "resnet50",
            "decoder": {
                "decoder_channels": [512, 256, 128, 64, 32],
                "max_channels": 512
            }
        },
        "num_classes": num_classes,
        "model_state_dict": state_dict_path,
        "device": DEVICE,
        "patch_size": patch_size,
        "patches_batch_size": patches_batch_size,
        "patches_overlap": patches_overlap,
        "multilabel": False,
    }


# Edge model: 2 output channels, overlapping patches.
model_edges = InferenceModel.from_params(
    make_inference_params(EDGES, num_classes=2, patches_overlap=0.2))

# Class model: 6 output channels, non-overlapping patches.
model_classes = InferenceModel.from_params(
    make_inference_params(CLASSES, num_classes=6, patches_overlap=0))

In [None]:
# cv2.imread returns None (rather than raising) when the file is missing or
# unreadable; fail here with a clear message instead of a TypeError on the slice.
img_bgr = cv2.imread(IMG)
if img_bgr is None:
    raise FileNotFoundError(f'Could not read image: {IMG}')
# OpenCV loads channels as BGR; reverse the channel axis to get RGB and
# materialize the reversed view with copy().
img = img_bgr[:, :, ::-1].copy()

In [None]:
# HWC uint8 RGB -> CHW, scaled to [0, 1] (float64 at this point).
img_chw = np.moveaxis(img, -1, 0) / 255
# Cast to float32 and add a leading batch axis: (1, C, H, W).
img_torch = torch.from_numpy(img_chw).float()[None]

In [None]:
# Run both models on the same input; predict_patches returns a batch, and [0]
# keeps the maps of our single image (presumably (num_classes, H, W)).
probs_edges, probs_classes = (
    segmenter.predict_patches(img_torch)[0]
    for segmenter in (model_edges, model_classes)
)

In [None]:
# The input cadaster crop, as loaded (RGB).
fig_input, ax_input = plt.subplots(figsize=(10, 10))
ax_input.imshow(img)

In [None]:
# Channel 0 of the edge model output in grayscale (the watershed cell below
# uses channel 1 as its relief map).
edge_channel0 = probs_edges[0].cpu().numpy()
fig_edges, ax_edges = plt.subplots(figsize=(10, 10))
ax_edges.imshow(edge_channel0, cmap='gray')

In [None]:
# Class-model channel 4 ('courtyard' in the idx2class mapping defined below).
class_channel4 = probs_classes[4].cpu().numpy()
fig_classes, ax_classes = plt.subplots(figsize=(10, 10))
ax_classes.imshow(class_channel4)

In [None]:
# Channel 1 of the edge model is the relief map flooded by the watershed.
countours = probs_edges[1].cpu().numpy()

# Seeds: h-minima of the edge map (h = 0.1), each connected seed region
# receiving its own integer label.
minimas = label(h_minima(countours, 0.1))

# Flood from the seeds over the edge map rescaled to integers in 0..255.
flood_relief = (countours * 255).astype(int)
watershed_parcels = watershed(flood_relief, minimas)

In [None]:
# Side-by-side comparison of the input and the watershed label image,
# with linked zoom/pan.
fig, axes = plt.subplots(ncols=2, figsize=(16, 8), sharex=True, sharey=True)
ax = axes.ravel()

panels = [
    (img, plt.cm.gray, 'Cadaster'),
    (watershed_parcels, plt.cm.nipy_spectral, 'Watershed parcels'),
]
for panel_ax, (panel_data, panel_cmap, panel_title) in zip(ax, panels):
    panel_ax.imshow(panel_data, cmap=panel_cmap)
    panel_ax.set_title(panel_title)

In [None]:
# Class channel index -> class name; assumed to follow the channel order the
# 6-class model was trained with (TODO confirm against the training config).
CLASS_NAMES = ('background', 'street', 'water', 'church', 'courtyard', 'building')
idx2class = {channel: name for channel, name in enumerate(CLASS_NAMES)}

# Sorted distinct watershed labels, i.e. the parcel ids to vectorize
# (an array of ids, not a count).
num_parcels = np.unique(watershed_parcels)

In [None]:
from affine import Affine

# Affine transform mapping pixel (col, row) to output coordinates; used to
# georeference the parcel polygons.
if IMG_FORMAT.lower() in ('.tif', '.tiff'):
    # GeoTIFF: reuse the transform stored in the raster. The context manager
    # closes the dataset handle once the transform has been read.
    with rasterio.open(IMG) as src:
        transform = src.transform
else:
    # No georeferencing available: GDAL-order (c, a, b, f, d, e) giving
    # x = col and y = -row, i.e. pixel units with the y axis pointing up.
    geotransform = (0, 1, 0.0, 0, 0, -1)
    transform = Affine.from_gdal(*geotransform)

In [None]:
def _contour_to_ring(contour, epsilon=1):
    """Simplify one OpenCV contour and map it through `transform` to a closed ring.

    Returns None when the simplified contour has fewer than 3 vertices, since
    shapely cannot build a valid ring from it (e.g. a one-pixel parcel).
    """
    # approxPolyDP returns an (N, 1, 2) array of (col, row) pixel points.
    pixels = cv2.approxPolyDP(contour, epsilon, closed=True)[:, 0, :].tolist()
    if len(pixels) < 3:
        return None
    ring = [transform * tuple(point) for point in pixels]
    ring.append(ring[0])  # close the ring explicitly
    return ring


def num_parcel2res(parcel_idx, probs_classes=probs_classes.cpu().numpy(), watershed_parcels=watershed_parcels):
    """Vectorize one watershed parcel and assign it a semantic class.

    Parameters
    ----------
    parcel_idx : int
        Label of the parcel in `watershed_parcels`.
    probs_classes : numpy.ndarray or torch.Tensor
        Class scores indexed as (class, row, col). The default is converted to
        numpy once at definition time, so looping over parcels does not copy the
        whole tensor to the CPU for every parcel. Tensors are still accepted.
    watershed_parcels : numpy.ndarray
        Label image produced by the watershed.

    Returns
    -------
    (shapely.geometry.Polygon, str)
        Parcel outline (with holes) in the coordinates of the global
        `transform`, and the majority-vote class name from `idx2class`. The
        polygon is empty when the parcel is too small to form a valid ring.

    Raises
    ------
    ValueError
        If `parcel_idx` does not occur in `watershed_parcels`.
    """
    mask_parcel = watershed_parcels == parcel_idx
    if not mask_parcel.any():
        raise ValueError(f'parcel {parcel_idx} does not occur in watershed_parcels')

    if torch.is_tensor(probs_classes):
        probs_classes = probs_classes.cpu().numpy()
    # Majority vote over the per-pixel arg-max class inside the parcel.
    pixel_classes = probs_classes[:, mask_parcel].argmax(axis=0)
    class_idx = int(np.bincount(pixel_classes).argmax())

    # [-2:] keeps (contours, hierarchy) on every OpenCV version
    # (OpenCV 3.x returns an extra leading image).
    contours, hierarchy = cv2.findContours(
        mask_parcel.astype(np.uint8), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)[-2:]
    # RETR_CCOMP: hierarchy[0][i] = [next, previous, first_child, parent];
    # parent == -1 marks an outer boundary, children of an outer are its holes.
    parents = hierarchy[0][:, 3]
    # Pick the largest outer boundary as the shell instead of relying on the
    # contour order; other disconnected pieces of the same label (expected to be
    # rare for watershed regions) are dropped.
    outer = np.flatnonzero(parents == -1)
    shell_idx = max(outer, key=lambda i: cv2.contourArea(contours[i]))

    shell = _contour_to_ring(contours[shell_idx])
    if shell is None:
        return geometry.Polygon(), idx2class[class_idx]

    holes = []
    for hole_idx in np.flatnonzero(parents == shell_idx):
        ring = _contour_to_ring(contours[hole_idx])
        if ring is not None:  # skip holes too small to form a valid ring
            holes.append(ring)
    return geometry.Polygon(shell, holes=holes), idx2class[class_idx]

In [None]:
# Vectorize every watershed parcel: one (polygon, class name) pair per label.
results = [num_parcel2res(parcel_id) for parcel_id in tqdm(num_parcels)]

all_classes = [class_name for _, class_name in results]
# EPSG:3004 is only meaningful when `transform` came from the georeferenced
# GeoTIFF; presumably it is that raster's CRS (TODO confirm / read from raster).
all_polys = gpd.GeoSeries([poly for poly, _ in results], crs='EPSG:3004')

In [None]:
# One row per parcel: its outline plus the predicted class name.
geodata = all_polys.to_frame(name='geometry').assign(**{'class': all_classes})

In [None]:
# Colour per class index (same order as idx2class).
palette = ['#000000', '#FFFF00', '#00FFFF', '#FF0000', '#00FF00', '#FF00FF']
fig, axes = plt.subplots(ncols=2, figsize=(16, 8))
ax = axes.ravel()

ax[0].imshow(watershed_parcels, cmap=plt.cm.nipy_spectral)
ax[0].set_title('Watershed parcels')

# The loop variable is `class_name`, not `label`: `label` is the skimage
# labelling function used by the watershed cell, and rebinding it here would
# break that cell when it is re-run.
for class_idx, class_name in idx2class.items():
    class_geoms = geodata.loc[geodata['class'] == class_name, 'geometry']
    if not class_geoms.empty:  # avoid geopandas' empty-plot warning
        class_geoms.plot(color=palette[class_idx], ax=ax[1])
ax[1].set_title('Parcels by predicted class')
ax[1].legend(
    handles=[Patch(color=palette[i], label=name) for i, name in idx2class.items()],
    loc='upper right',
)

In [None]:
# Optionally export the classified parcels as GeoJSON under MODELS/geojsons/,
# named after the image and the checkpoint(s) that produced them.
if SAVE_GEOJSON:
    if TYPE == 'sep':
        model_tag = MODEL_EDGES_NAME + MODEL_CLASSES_NAME
    elif TYPE == 'full':
        if not MODEL_FULL_NAME:
            raise ValueError("TYPE == 'full' requires MODEL_FULL_NAME to be set in the configuration cell")
        model_tag = MODEL_FULL_NAME
    else:
        # Fail loudly rather than silently skipping the export.
        raise ValueError(f"Unknown TYPE {TYPE!r}; expected 'sep' or 'full'")

    geojson_dir = MODELS + 'geojsons/'
    # Create the output directory up front; the GeoJSON writer is not
    # expected to create missing parent directories.
    os.makedirs(geojson_dir, exist_ok=True)
    geodata.to_file(geojson_dir + IMG_NAME + model_tag + '.geojson', driver='GeoJSON')