/
slope_classification.py
240 lines (209 loc) · 7.79 KB
/
slope_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#!/usr/bin/env python
# coding: utf8
#
# Copyright (c) 2022 Centre National d'Etudes Spatiales (CNES).
#
# This file is part of demcompare
# (see https://github.com/CNES/demcompare).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Mainly contains the SlopeClassificationLayer class.
"""
# Standard imports
import collections
import logging
from typing import Dict, List, Optional

# Third party imports
import numpy as np
import xarray as xr

# DEMcompare imports
from demcompare.dem_tools import create_dem

from ..internal_typing import ConfigType
from .classification_layer import ClassificationLayer
from .classification_layer_template import ClassificationLayerTemplate
@ClassificationLayer.register("slope")
class SlopeClassificationLayer(ClassificationLayerTemplate):
    """
    SlopeClassificationLayer

    Classification layer registered under the "slope" kind. Each pixel of
    the "ref_slope" / "sec_slope" images of the input dem is labelled with
    the lower bound of the configured slope range ("ranges") it falls into.
    """

    # Default slope range lower bounds, used when the configuration has no
    # "ranges" entry. fill_conf_and_schema hands out a copy so that no
    # configuration ever aliases this class-level list.
    _RANGES = [0, 5, 10, 25, 45]

    def __init__(
        self,
        name: str,
        classification_layer_kind: str,
        cfg: Dict,
        dem: Optional[xr.Dataset] = None,
    ):
        """
        Init function

        :param name: classification layer name
        :type name: str
        :param classification_layer_kind: classification layer kind
        :type classification_layer_kind: str
        :param cfg: layer's configuration
        :type cfg: ConfigType
        :param dem: dem, or None to only check the configuration
        :type dem: xr.DataSet containing :

                - image : 2D (row, col) xr.DataArray float32
                - georef_transform: 1D (trans_len) xr.DataArray
                - classification_layer_masks : 3D (row, col, indicator)
                  xr.DataArray
        :return: None
        """
        # Call generic init before supercharging
        super().__init__(
            name,
            classification_layer_kind,
            cfg,
            dem,
        )

        # Ranges: lower bounds of the slope intervals.
        # NOTE(review): self.cfg is presumably filled by the parent init
        # (via fill_conf_and_schema), which guarantees a "ranges" key.
        self.ranges: List = self.cfg["ranges"]

        # Checking configuration during initialisation step
        # doesn't require classification layers
        if dem is not None:
            # Create labelled map to classification_layer from
            self._create_labelled_map()
            # Create class masks (inherited implementation)
            self._create_class_masks()

        logging.debug("ClassificationLayer created as: %s", self)

    def fill_conf_and_schema(
        self, cfg: Optional[ConfigType] = None
    ) -> ConfigType:
        """
        Add default values to the dictionary if there are missing
        elements, define the configuration schema and check the ranges

        :param cfg: slope classification layer configuration
        :type cfg: ConfigType
        :return cfg: slope classification layer configuration updated
        :rtype: ConfigType
        :raises TypeError: if "ranges" is not a list of int
        """
        # Call generic fill_conf_and_schema
        cfg = super().fill_conf_and_schema(cfg)

        # Give the default value if the required element is not in the
        # configuration. Copy it: assigning self._RANGES itself would make
        # every defaulted configuration share (and possibly mutate) the
        # class-level list.
        if "ranges" not in cfg:
            cfg["ranges"] = list(self._RANGES)

        # Add subclass parameter to the default schema
        self.schema["ranges"] = list

        self.check_ranges(cfg)
        return cfg

    @staticmethod
    def check_ranges(cfg: dict) -> None:
        """
        Verify users configuration for ranges in slope classification

        cfg["ranges"] must be a list whose items are all int (bool values
        are rejected even though bool subclasses int). An empty list is
        accepted. Ranges are expected in increasing order: with unsorted
        values the half-open intervals built in _classify_slope_by_ranges
        can be empty.

        :param cfg: slope configuration
        :type cfg: dict
        :return: None
        :raises TypeError: if cfg["ranges"] is not a list of int
        """
        ranges = cfg["ranges"]
        # Check the container type explicitly: iterating alone would accept
        # tuples and fail with an unrelated error on a bare int.
        is_list_of_int = isinstance(ranges, list) and all(
            isinstance(value, int) and not isinstance(value, bool)
            for value in ranges
        )
        if not is_list_of_int:
            raise TypeError("Ranges must be a list of int")

    def _create_labelled_map(self) -> None:
        """
        Create the labelled map and save it if necessary

        Builds self.classes from self.ranges, then classifies the slope
        images contained in self.dem.

        :return: None
        """
        # transform 'ranges' to 'classes'
        self.classes: collections.OrderedDict = self._generate_classes(
            self.ranges
        )
        # create slope maps of ref and sec
        self._create_slope_map_datasets(self.dem)

    def _create_slope_map_datasets(self, dem: xr.Dataset) -> None:
        """
        Create slope map datasets

        For each of "ref_slope" / "sec_slope" present in dem, wrap the slope
        image in a dem dataset (same georef_transform, crs and self.nodata)
        and classify it for the "ref" / "sec" support respectively.

        :param dem: input dem
        :type dem: xr.DataSet containing :

                - image : 2D (row, col) xr.DataArray float32
                - georef_transform: 1D (trans_len) xr.DataArray
                - classification_layer_masks : 3D (row, col, indicator)
                  xr.DataArray
        :return: None
        """
        # Classify slope
        dict_slope = {"ref_slope": "ref", "sec_slope": "sec"}
        for slope_name, support in dict_slope.items():
            if slope_name not in dem:
                continue
            # Read the image from the same dataset that was tested and
            # whose georeferencing is used below.
            slope_img = dem[slope_name].data[:, :]
            slope_dataset = create_dem(
                slope_img,
                transform=dem.georef_transform.data,
                img_crs=dem.crs,
                nodata=self.nodata,
            )
            # Create the layer map for each slope
            self._classify_slope_by_ranges(slope_dataset, support)

    @staticmethod
    def _generate_classes(ranges) -> collections.OrderedDict:
        """
        Create classes from ranges

        Each lower bound becomes a half-open interval label mapped to that
        bound, the last one being open towards infinity, e.g.
        [0, 5] -> {"[0%;5%[": 0, "[5%;inf[": 5}.

        :param ranges: ranges
        :type ranges: List
        :return: classes
        :rtype: collections.OrderedDict
        """
        # Change the intervals into a list to make 'classes' generic
        classes = collections.OrderedDict()
        for idx, lower in enumerate(ranges):
            if idx + 1 < len(ranges):
                key = f"[{lower}%;{ranges[idx + 1]}%["
            else:
                key = f"[{lower}%;inf["
            classes[key] = lower
        return classes

    def _classify_slope_by_ranges(
        self, slope_dataset: xr.Dataset, support: str = "ref"
    ) -> None:
        """
        Create the map for each slope using the input ranges
        (value interval is transformed into 1 value (interval minimum value))

        Pixels that are NaN or below the first range keep self.nodata.
        The map is stored in self.map_image[support] and, when
        self.output_dir is set, passed to save_map_img.

        :param slope_dataset: slope dataset
        :type slope_dataset: xr.DataSet containing :

                - image : 2D (row, col) xr.DataArray float32
                - georef_transform: 1D (trans_len) xr.DataArray
                - classification_layer_masks : 3D (row, col, indicator)
                  xr.DataArray
        :param support: support dem, ref or sec
        :type support: str
        :return: None
        """
        slope = slope_dataset["image"].data
        # Initialize map with nodata everywhere
        map_img = np.full(slope.shape, self.nodata, dtype=np.float64)
        valid = ~np.isnan(slope)

        # For each slope range, label the valid pixels within
        # [lower, next_lower[ with lower; the last range has no upper bound
        # (so +inf slopes still land in it).
        for idx, lower in enumerate(self.ranges):
            in_range = valid & (slope >= lower)
            if idx + 1 < len(self.ranges):
                in_range &= slope < self.ranges[idx + 1]
            map_img[in_range] = lower

        # Store map_image
        self.map_image[support] = map_img
        # If output_dir is set, create map_dataset and save
        if self.output_dir:
            self.save_map_img(map_img, support)