Skip to content

Commit

Permalink
adressed #845
Browse files Browse the repository at this point in the history
  • Loading branch information
janzandr committed Mar 18, 2024
1 parent 5f4dbb9 commit 4d7ae55
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def shortDescription(self) -> str:
'Stehman (2014): https://doi.org/10.1080/01431161.2014.930207. ' \
'Note that (simple) random sampling is a special case of stratified random sampling, ' \
'with exactly one stratum. \n' \
'Observed and predicted categories are matched by name.'
'Observed and predicted categories are matched by name, if possible. ' \
'Otherwise, categories are matched by order (in this case, a warning message is logged).'

def helpParameters(self) -> List[Tuple[str, str]]:
return [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@

import numpy as np

from enmapbox.typeguard import typechecked
from enmapboxprocessing.algorithm.rasterizecategorizedvectoralgorithm import RasterizeCategorizedVectorAlgorithm
from enmapboxprocessing.algorithm.translatecategorizedrasteralgorithm import TranslateCategorizedRasterAlgorithm
from enmapboxprocessing.enmapalgorithm import EnMAPProcessingAlgorithm, Group
from enmapboxprocessing.rasterreader import RasterReader
from enmapboxprocessing.reportwriter import HtmlReportWriter, CsvReportWriter, MultiReportWriter
from enmapboxprocessing.utils import Utils
from qgis.core import (QgsProcessingContext, QgsProcessingFeedback, QgsVectorLayer, QgsRasterLayer, QgsUnitTypes)
from enmapbox.typeguard import typechecked


@typechecked
Expand All @@ -34,7 +34,8 @@ def displayName(cls) -> str:
def shortDescription(self) -> str:
return 'Estimates map accuracy and area proportions for stratified random sampling as described in ' \
'Stehman (2014): https://doi.org/10.1080/01431161.2014.930207. \n' \
'Observed and predicted categories are matched by name.'
'Observed and predicted categories are matched by name, if possible. ' \
'Otherwise, categories are matched by order (in this case, a warning message is logged).'

def helpParameters(self) -> List[Tuple[str, str]]:
return [
Expand Down Expand Up @@ -168,17 +169,22 @@ def processAlgorithm(
yMap = arrayPrediction[valid].astype(np.float32)
# - remap class ids by name
yMapRemapped = yMap.copy() # this initial state is correct for matching by order (see #845)
classNamesMatching = list()
for i, cP in enumerate(categoriesPrediction):
found = False
for cR in categoriesReference:
if cR.name == cP.name:
yMapRemapped[yMap == cP.value] = cR.value
found = True
classNamesMatching.append([cP.name, cR.name])
if not found:
feedback.pushWarning(
f'predicted class "{categoriesPrediction[i].name}" not found in reference classes, '
f'and will be matched by order to class "".'
f'predicted class "{categoriesPrediction[i].name}" not found in reference classes. '
f'class will be matched by order: '
f'"{cP.name}" -> "{categoriesReference[i].name}".'
)
classNamesMatching.append([cP.name, categoriesReference[i].name])

yMap = yMapRemapped
# - prepare strata
stratum = arrayStratification[valid]
Expand All @@ -197,7 +203,10 @@ def processAlgorithm(
stats = stratifiedAccuracyAssessment(stratum, yReference, yMap, h, N_h, classValues, classNames)
pixelUnits = QgsUnitTypes.toString(classification.crs().mapUnits())
pixelArea = classification.rasterUnitsPerPixelX() * classification.rasterUnitsPerPixelY()
self.writeReport(filename, stats, pixelUnits=pixelUnits, pixelArea=pixelArea)

self.writeReport(
filename, stats, pixelUnits=pixelUnits, pixelArea=pixelArea, classNamesMatching=classNamesMatching
)
# dump json
with open(filename + '.json', 'w') as file:
file.write(json.dumps(stats.__dict__, indent=4))
Expand All @@ -211,7 +220,10 @@ def processAlgorithm(
return result

@classmethod
def writeReport(cls, filename: str, stats: 'StratifiedAccuracyAssessmentResult', pixelUnits='pixel', pixelArea=1.):
def writeReport(
cls, filename: str, stats: 'StratifiedAccuracyAssessmentResult', pixelUnits='pixel', pixelArea=1.,
classNamesMatching: list = None
):

def smartRound(obj, ndigits):
if isinstance(obj, list):
Expand Down Expand Up @@ -242,6 +254,9 @@ def confidenceIntervall(mean, se):
report.writeParagraph(f'Sample size: {stats.n} px')
report.writeParagraph(f'Area size: {smartRound(stats.N, 2)} {pixelUnits}')

if classNamesMatching is not None:
report.writeTable(classNamesMatching, 'Class matching', ['predicted', 'observed'])

values = smartRound(stats.confusion_matrix_counts, 2)
report.writeTable(
values, 'Adjusted confusion matrix counts: predicted (rows) vs. observed (columns)',
Expand Down

0 comments on commit 4d7ae55

Please sign in to comment.