Skip to content

Commit

Permalink
Merge pull request #9926 from rouault/fix_GDALNoDataMaskBand_IRasterI…
Browse files Browse the repository at this point in the history
…O_failure

GDALNoDataMaskBand::IRasterIO(): fix crash on memory allocation failure
  • Loading branch information
rouault authored May 28, 2024
2 parents 06e1043 + 5b28d32 commit 6bc311f
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 39 deletions.
55 changes: 42 additions & 13 deletions autotest/gcore/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,10 @@ def test_mask_27():


@pytest.mark.parametrize("dt", [gdal.GDT_Byte, gdal.GDT_Int64, gdal.GDT_UInt64])
def test_mask_setting_nodata(dt):
@pytest.mark.parametrize(
"GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND", [None, "YES", "ALWAYS"]
)
def test_mask_setting_nodata(dt, GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND):
def set_nodata_value(ds, val):
if dt == gdal.GDT_Byte:
ds.GetRasterBand(1).SetNoDataValue(val)
Expand All @@ -1005,15 +1008,41 @@ def set_nodata_value(ds, val):
else:
ds.GetRasterBand(1).SetNoDataValueAsUInt64(val)

ds = gdal.GetDriverByName("MEM").Create("", 1, 1, 1, dt)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
set_nodata_value(ds, 0)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 0)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 0)
set_nodata_value(ds, 1)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
set_nodata_value(ds, 0)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 0)
ds.GetRasterBand(1).DeleteNoDataValue()
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
def test():
ds = gdal.GetDriverByName("MEM").Create("__debug__", 1, 1, 1, dt)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
set_nodata_value(ds, 0)
got = ds.GetRasterBand(1).GetMaskBand().ReadRaster()
if (
GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND == "ALWAYS"
and dt != gdal.GDT_Byte
):
assert got is None
assert gdal.GetLastErrorType() == gdal.CE_Failure
else:
if (
GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND == "YES"
and dt != gdal.GDT_Byte
):
assert gdal.GetLastErrorType() == gdal.CE_Warning
assert got == struct.pack("B", 0)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 0)
set_nodata_value(ds, 1)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack(
"B", 255
)
set_nodata_value(ds, 0)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 0)

ds.GetRasterBand(1).DeleteNoDataValue()
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)

if GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND:
with gdal.quiet_errors(), gdal.config_option(
"GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND",
GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND,
):
test()
else:
test()
94 changes: 68 additions & 26 deletions gcore/gdalnodatamaskband.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include <algorithm>
#include <cstring>
#include <utility>

#include "cpl_conv.h"
#include "cpl_error.h"
Expand Down Expand Up @@ -166,7 +167,6 @@ bool GDALNoDataMaskBand::IsNoDataInRange(double dfNoDataValue,
{
return GDALIsValueInRange<GUInt32>(dfNoDataValue);
}

case GDT_Int32:
{
return GDALIsValueInRange<GInt32>(dfNoDataValue);
Expand Down Expand Up @@ -279,19 +279,65 @@ CPLErr GDALNoDataMaskBand::IRasterIO(GDALRWFlag eRWFlag, int nXOff, int nYOff,
return CE_None;
}

if (eBufType == GDT_Byte)
const auto AllocTempBufferOrFallback =
[this, eRWFlag, nXOff, nYOff, nXSize, nYSize, pData, nBufXSize,
nBufYSize, eBufType, nPixelSpace, nLineSpace,
psExtraArg](int nWrkDTSize) -> std::pair<CPLErr, void *>
{
const int nWrkDTSize = GDALGetDataTypeSizeBytes(eWrkDT);
void *pTemp = VSI_MALLOC3_VERBOSE(nWrkDTSize, nBufXSize, nBufYSize);
if (pTemp == nullptr)
auto poParentDS = m_poParent->GetDataset();
// Check if we must simulate a memory allocation failure
// Before checking the env variable, which is slightly expensive,
// check first for a special dataset name, which is a cheap test.
const char *pszOptVal =
poParentDS && strcmp(poParentDS->GetDescription(), "__debug__") == 0
? CPLGetConfigOption(
"GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND", "NO")
: "NO";
const bool bSimulMemAllocFailure =
EQUAL(pszOptVal, "ALWAYS") ||
(CPLTestBool(pszOptVal) &&
GDALMajorObject::GetMetadataItem(__func__, "__INTERNAL__") ==
nullptr);
void *pTemp = nullptr;
if (!bSimulMemAllocFailure)
{
CPLErrorStateBackuper oErrorStateBackuper(CPLQuietErrorHandler);
pTemp = VSI_MALLOC3_VERBOSE(nWrkDTSize, nBufXSize, nBufYSize);
}
if (!pTemp)
{
return GDALRasterBand::IRasterIO(
eRWFlag, nXOff, nYOff, nXSize, nYSize, pTemp, nBufXSize,
nBufYSize, eWrkDT, nWrkDTSize,
static_cast<GSpacing>(nBufXSize) * nWrkDTSize, psExtraArg);
const bool bAllocHasAlreadyFailed =
GDALMajorObject::GetMetadataItem(__func__, "__INTERNAL__") !=
nullptr;
CPLError(bAllocHasAlreadyFailed ? CE_Failure : CE_Warning,
CPLE_OutOfMemory,
"GDALNoDataMaskBand::IRasterIO(): cannot allocate %d x %d "
"x %d bytes%s",
nBufXSize, nBufYSize, nWrkDTSize,
bAllocHasAlreadyFailed
? ""
: ". Falling back to block-based approach");
if (bAllocHasAlreadyFailed)
return std::pair(CE_Failure, nullptr);
// Sets a metadata item to prevent potential infinite recursion
GDALMajorObject::SetMetadataItem(__func__, "IN", "__INTERNAL__");
const CPLErr eErr = GDALRasterBand::IRasterIO(
eRWFlag, nXOff, nYOff, nXSize, nYSize, pData, nBufXSize,
nBufYSize, eBufType, nPixelSpace, nLineSpace, psExtraArg);
GDALMajorObject::SetMetadataItem(__func__, nullptr, "__INTERNAL__");
return std::pair(eErr, nullptr);
}
return std::pair(CE_None, pTemp);
};

const CPLErr eErr = m_poParent->RasterIO(
if (eBufType == GDT_Byte)
{
const int nWrkDTSize = GDALGetDataTypeSizeBytes(eWrkDT);
auto [eErr, pTemp] = AllocTempBufferOrFallback(nWrkDTSize);
if (!pTemp)
return eErr;

eErr = m_poParent->RasterIO(
GF_Read, nXOff, nYOff, nXSize, nYSize, pTemp, nBufXSize, nBufYSize,
eWrkDT, nWrkDTSize, static_cast<GSpacing>(nBufXSize) * nWrkDTSize,
psExtraArg);
Expand Down Expand Up @@ -453,30 +499,26 @@ CPLErr GDALNoDataMaskBand::IRasterIO(GDALRWFlag eRWFlag, int nXOff, int nYOff,

// Output buffer is non-Byte. Ask for Byte and expand to user requested
// type
GByte *pabyBuf =
static_cast<GByte *>(VSI_MALLOC2_VERBOSE(nBufXSize, nBufYSize));
if (pabyBuf == nullptr)
{
return GDALRasterBand::IRasterIO(eRWFlag, nXOff, nYOff, nXSize, nYSize,
pData, nBufXSize, nBufYSize, eBufType,
nPixelSpace, nLineSpace, psExtraArg);
}
const CPLErr eErr =
IRasterIO(eRWFlag, nXOff, nYOff, nXSize, nYSize, pabyBuf, nBufXSize,
nBufYSize, GDT_Byte, 1, nBufXSize, psExtraArg);
auto [eErr, pTemp] = AllocTempBufferOrFallback(sizeof(GByte));
if (!pTemp)
return eErr;

eErr = IRasterIO(eRWFlag, nXOff, nYOff, nXSize, nYSize, pTemp, nBufXSize,
nBufYSize, GDT_Byte, 1, nBufXSize, psExtraArg);
if (eErr != CE_None)
{
VSIFree(pabyBuf);
VSIFree(pTemp);
return eErr;
}

for (int iY = 0; iY < nBufYSize; iY++)
{
GDALCopyWords(pabyBuf + static_cast<size_t>(iY) * nBufXSize, GDT_Byte,
1, static_cast<GByte *>(pData) + iY * nLineSpace,
eBufType, static_cast<int>(nPixelSpace), nBufXSize);
GDALCopyWords(
static_cast<GByte *>(pTemp) + static_cast<size_t>(iY) * nBufXSize,
GDT_Byte, 1, static_cast<GByte *>(pData) + iY * nLineSpace,
eBufType, static_cast<int>(nPixelSpace), nBufXSize);
}
VSIFree(pabyBuf);
VSIFree(pTemp);
return CE_None;
}

Expand Down

0 comments on commit 6bc311f

Please sign in to comment.