Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GDALNoDataMaskBand::IRasterIO(): fix crash on memory allocation failure #9926

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions autotest/gcore/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,10 @@ def test_mask_27():


@pytest.mark.parametrize("dt", [gdal.GDT_Byte, gdal.GDT_Int64, gdal.GDT_UInt64])
def test_mask_setting_nodata(dt):
@pytest.mark.parametrize(
"GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND", [None, "YES", "ALWAYS"]
)
def test_mask_setting_nodata(dt, GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND):
def set_nodata_value(ds, val):
if dt == gdal.GDT_Byte:
ds.GetRasterBand(1).SetNoDataValue(val)
Expand All @@ -1005,15 +1008,41 @@ def set_nodata_value(ds, val):
else:
ds.GetRasterBand(1).SetNoDataValueAsUInt64(val)

ds = gdal.GetDriverByName("MEM").Create("", 1, 1, 1, dt)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
set_nodata_value(ds, 0)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 0)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 0)
set_nodata_value(ds, 1)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
set_nodata_value(ds, 0)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 0)
ds.GetRasterBand(1).DeleteNoDataValue()
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
def test():
ds = gdal.GetDriverByName("MEM").Create("__debug__", 1, 1, 1, dt)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)
set_nodata_value(ds, 0)
got = ds.GetRasterBand(1).GetMaskBand().ReadRaster()
if (
GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND == "ALWAYS"
and dt != gdal.GDT_Byte
):
assert got is None
assert gdal.GetLastErrorType() == gdal.CE_Failure
else:
if (
GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND == "YES"
and dt != gdal.GDT_Byte
):
assert gdal.GetLastErrorType() == gdal.CE_Warning
assert got == struct.pack("B", 0)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 0)
set_nodata_value(ds, 1)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack(
"B", 255
)
set_nodata_value(ds, 0)
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 0)

ds.GetRasterBand(1).DeleteNoDataValue()
assert ds.GetRasterBand(1).GetMaskBand().ReadRaster() == struct.pack("B", 255)

if GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND:
with gdal.quiet_errors(), gdal.config_option(
"GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND",
GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND,
):
test()
else:
test()
94 changes: 68 additions & 26 deletions gcore/gdalnodatamaskband.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include <algorithm>
#include <cstring>
#include <utility>

#include "cpl_conv.h"
#include "cpl_error.h"
Expand Down Expand Up @@ -166,7 +167,6 @@ bool GDALNoDataMaskBand::IsNoDataInRange(double dfNoDataValue,
{
return GDALIsValueInRange<GUInt32>(dfNoDataValue);
}

case GDT_Int32:
{
return GDALIsValueInRange<GInt32>(dfNoDataValue);
Expand Down Expand Up @@ -279,19 +279,65 @@ CPLErr GDALNoDataMaskBand::IRasterIO(GDALRWFlag eRWFlag, int nXOff, int nYOff,
return CE_None;
}

if (eBufType == GDT_Byte)
const auto AllocTempBufferOrFallback =
[this, eRWFlag, nXOff, nYOff, nXSize, nYSize, pData, nBufXSize,
nBufYSize, eBufType, nPixelSpace, nLineSpace,
psExtraArg](int nWrkDTSize) -> std::pair<CPLErr, void *>
{
const int nWrkDTSize = GDALGetDataTypeSizeBytes(eWrkDT);
void *pTemp = VSI_MALLOC3_VERBOSE(nWrkDTSize, nBufXSize, nBufYSize);
if (pTemp == nullptr)
auto poParentDS = m_poParent->GetDataset();
// Check if we must simulate a memory allocation failure
// Before checking the env variable, which is slightly expensive,
// check first for a special dataset name, which is a cheap test.
const char *pszOptVal =
poParentDS && strcmp(poParentDS->GetDescription(), "__debug__") == 0
? CPLGetConfigOption(
"GDAL_SIMUL_MEM_ALLOC_FAILURE_NODATA_MASK_BAND", "NO")
: "NO";
const bool bSimulMemAllocFailure =
EQUAL(pszOptVal, "ALWAYS") ||
(CPLTestBool(pszOptVal) &&
GDALMajorObject::GetMetadataItem(__func__, "__INTERNAL__") ==
nullptr);
void *pTemp = nullptr;
if (!bSimulMemAllocFailure)
{
CPLErrorStateBackuper oErrorStateBackuper(CPLQuietErrorHandler);
pTemp = VSI_MALLOC3_VERBOSE(nWrkDTSize, nBufXSize, nBufYSize);
}
if (!pTemp)
{
return GDALRasterBand::IRasterIO(
eRWFlag, nXOff, nYOff, nXSize, nYSize, pTemp, nBufXSize,
nBufYSize, eWrkDT, nWrkDTSize,
static_cast<GSpacing>(nBufXSize) * nWrkDTSize, psExtraArg);
const bool bAllocHasAlreadyFailed =
GDALMajorObject::GetMetadataItem(__func__, "__INTERNAL__") !=
nullptr;
CPLError(bAllocHasAlreadyFailed ? CE_Failure : CE_Warning,
CPLE_OutOfMemory,
"GDALNoDataMaskBand::IRasterIO(): cannot allocate %d x %d "
"x %d bytes%s",
nBufXSize, nBufYSize, nWrkDTSize,
bAllocHasAlreadyFailed
? ""
: ". Falling back to block-based approach");
if (bAllocHasAlreadyFailed)
return std::pair(CE_Failure, nullptr);
// Sets a metadata item to prevent potential infinite recursion
GDALMajorObject::SetMetadataItem(__func__, "IN", "__INTERNAL__");
const CPLErr eErr = GDALRasterBand::IRasterIO(
eRWFlag, nXOff, nYOff, nXSize, nYSize, pData, nBufXSize,
nBufYSize, eBufType, nPixelSpace, nLineSpace, psExtraArg);
GDALMajorObject::SetMetadataItem(__func__, nullptr, "__INTERNAL__");
return std::pair(eErr, nullptr);
}
return std::pair(CE_None, pTemp);
};

const CPLErr eErr = m_poParent->RasterIO(
if (eBufType == GDT_Byte)
{
const int nWrkDTSize = GDALGetDataTypeSizeBytes(eWrkDT);
auto [eErr, pTemp] = AllocTempBufferOrFallback(nWrkDTSize);
if (!pTemp)
return eErr;

eErr = m_poParent->RasterIO(
GF_Read, nXOff, nYOff, nXSize, nYSize, pTemp, nBufXSize, nBufYSize,
eWrkDT, nWrkDTSize, static_cast<GSpacing>(nBufXSize) * nWrkDTSize,
psExtraArg);
Expand Down Expand Up @@ -453,30 +499,26 @@ CPLErr GDALNoDataMaskBand::IRasterIO(GDALRWFlag eRWFlag, int nXOff, int nYOff,

// Output buffer is non-Byte. Ask for Byte and expand to user requested
// type
GByte *pabyBuf =
static_cast<GByte *>(VSI_MALLOC2_VERBOSE(nBufXSize, nBufYSize));
if (pabyBuf == nullptr)
{
return GDALRasterBand::IRasterIO(eRWFlag, nXOff, nYOff, nXSize, nYSize,
pData, nBufXSize, nBufYSize, eBufType,
nPixelSpace, nLineSpace, psExtraArg);
}
const CPLErr eErr =
IRasterIO(eRWFlag, nXOff, nYOff, nXSize, nYSize, pabyBuf, nBufXSize,
nBufYSize, GDT_Byte, 1, nBufXSize, psExtraArg);
auto [eErr, pTemp] = AllocTempBufferOrFallback(sizeof(GByte));
if (!pTemp)
return eErr;

eErr = IRasterIO(eRWFlag, nXOff, nYOff, nXSize, nYSize, pTemp, nBufXSize,
nBufYSize, GDT_Byte, 1, nBufXSize, psExtraArg);
if (eErr != CE_None)
{
VSIFree(pabyBuf);
VSIFree(pTemp);
return eErr;
}

for (int iY = 0; iY < nBufYSize; iY++)
{
GDALCopyWords(pabyBuf + static_cast<size_t>(iY) * nBufXSize, GDT_Byte,
1, static_cast<GByte *>(pData) + iY * nLineSpace,
eBufType, static_cast<int>(nPixelSpace), nBufXSize);
GDALCopyWords(
static_cast<GByte *>(pTemp) + static_cast<size_t>(iY) * nBufXSize,
GDT_Byte, 1, static_cast<GByte *>(pData) + iY * nLineSpace,
eBufType, static_cast<int>(nPixelSpace), nBufXSize);
}
VSIFree(pabyBuf);
VSIFree(pTemp);
return CE_None;
}

Expand Down
Loading