Skip to content

Commit

Permalink
Sanitize dump payload: fail RESTORE if memory allocation fails
Browse files Browse the repository at this point in the history
When RDB input attempts to make a huge memory allocation that fails,
RESTORE should fail gracefully rather than die with panic
  • Loading branch information
oranagra committed Dec 6, 2020
1 parent 3716950 commit 7ca00d6
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 83 deletions.
29 changes: 26 additions & 3 deletions src/dict.c
Expand Up @@ -143,9 +143,13 @@ int dictResize(dict *d)
return dictExpand(d, minimal);
}

/* Expand or create the hash table */
int dictExpand(dict *d, unsigned long size)
/* Expand or create the hash table,
* when malloc_failed is non-NULL, it'll avoid panic if malloc fails (in which case it'll be set to 1).
* Returns DICT_OK if expand was performed, and DICT_ERR if skipped. */
int _dictExpand(dict *d, unsigned long size, int* malloc_failed)
{
if (malloc_failed) *malloc_failed = 0;

/* the size is invalid if it is smaller than the number of
* elements already inside the hash table */
if (dictIsRehashing(d) || d->ht[0].used > size)
Expand All @@ -160,7 +164,14 @@ int dictExpand(dict *d, unsigned long size)
/* Allocate the new hash table and initialize all pointers to NULL */
n.size = realsize;
n.sizemask = realsize-1;
n.table = zcalloc(realsize*sizeof(dictEntry*));
if (malloc_failed) {
n.table = ztrycalloc(realsize*sizeof(dictEntry*));
*malloc_failed = n.table == NULL;
if (*malloc_failed)
return DICT_ERR;
} else
n.table = zcalloc(realsize*sizeof(dictEntry*));

n.used = 0;

/* Is this the first initialization? If so it's not really a rehashing
Expand All @@ -176,6 +187,18 @@ int dictExpand(dict *d, unsigned long size)
return DICT_OK;
}

/* return DICT_ERR if expand was not performed */
int dictExpand(dict *d, unsigned long size) {
return _dictExpand(d, size, NULL);
}

/* return DICT_ERR if expand failed due to memory allocation failure */
int dictTryExpand(dict *d, unsigned long size) {
int malloc_failed;
_dictExpand(d, size, &malloc_failed);
return malloc_failed? DICT_ERR : DICT_OK;
}

/* Performs N steps of incremental rehashing. Returns 1 if there are still
* keys to move from the old to the new hash table, otherwise 0 is returned.
*
Expand Down
1 change: 1 addition & 0 deletions src/dict.h
Expand Up @@ -151,6 +151,7 @@ typedef void (dictScanBucketFunction)(void *privdata, dictEntry **bucketref);
/* API */
dict *dictCreate(dictType *type, void *privDataPtr);
int dictExpand(dict *d, unsigned long size);
int dictTryExpand(dict *d, unsigned long size);
int dictAdd(dict *d, void *key, void *val);
dictEntry *dictAddRaw(dict *d, void *key, dictEntry **existing);
dictEntry *dictAddOrFind(dict *d, void *key);
Expand Down
56 changes: 42 additions & 14 deletions src/rdb.c
Expand Up @@ -387,14 +387,22 @@ void *rdbLoadLzfStringObject(rio *rdb, int flags, size_t *lenptr) {

if ((clen = rdbLoadLen(rdb,NULL)) == RDB_LENERR) return NULL;
if ((len = rdbLoadLen(rdb,NULL)) == RDB_LENERR) return NULL;
if ((c = zmalloc(clen)) == NULL) goto err;
if ((c = ztrymalloc(clen)) == NULL) {
serverLog(server.loading? LL_WARNING: LL_VERBOSE, "rdbLoadLzfStringObject failed allocating %llu bytes", (unsigned long long)clen);
goto err;
}

/* Allocate our target according to the uncompressed size. */
if (plain) {
val = zmalloc(len);
val = ztrymalloc(len);
} else {
val = sdsnewlen(SDS_NOINIT,len);
val = sdstrynewlen(SDS_NOINIT,len);
}
if (!val) {
serverLog(server.loading? LL_WARNING: LL_VERBOSE, "rdbLoadLzfStringObject failed allocating %llu bytes", (unsigned long long)len);
goto err;
}

if (lenptr) *lenptr = len;

/* Load the compressed representation and uncompress it to target. */
Expand Down Expand Up @@ -522,7 +530,11 @@ void *rdbGenericLoadStringObject(rio *rdb, int flags, size_t *lenptr) {
}

if (plain || sds) {
void *buf = plain ? zmalloc(len) : sdsnewlen(SDS_NOINIT,len);
void *buf = plain ? ztrymalloc(len) : sdstrynewlen(SDS_NOINIT,len);
if (!buf) {
serverLog(server.loading? LL_WARNING: LL_VERBOSE, "rdbGenericLoadStringObject failed allocating %llu bytes", len);
return NULL;
}
if (lenptr) *lenptr = len;
if (len && rioRead(rdb,buf,len) == 0) {
if (plain)
Expand Down Expand Up @@ -1545,8 +1557,11 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
o = createSetObject();
/* It's faster to expand the dict to the right size asap in order
* to avoid rehashing */
if (len > DICT_HT_INITIAL_SIZE)
dictExpand(o->ptr,len);
if (len > DICT_HT_INITIAL_SIZE && dictTryExpand(o->ptr,len) != DICT_OK) {
rdbReportCorruptRDB("OOM in dictTryExpand %llu", (unsigned long long)len);
decrRefCount(o);
return NULL;
}
} else {
o = createIntsetObject();
}
Expand Down Expand Up @@ -1574,7 +1589,12 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
}
} else {
setTypeConvert(o,OBJ_ENCODING_HT);
dictExpand(o->ptr,len);
if (dictTryExpand(o->ptr,len) != DICT_OK) {
rdbReportCorruptRDB("OOM in dictTryExpand %llu", (unsigned long long)len);
sdsfree(sdsele);
decrRefCount(o);
return NULL;
}
}
}

Expand All @@ -1601,8 +1621,11 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
o = createZsetObject();
zs = o->ptr;

if (zsetlen > DICT_HT_INITIAL_SIZE)
dictExpand(zs->dict,zsetlen);
if (zsetlen > DICT_HT_INITIAL_SIZE && dictTryExpand(zs->dict,zsetlen) != DICT_OK) {
rdbReportCorruptRDB("OOM in dictTryExpand %llu", (unsigned long long)zsetlen);
decrRefCount(o);
return NULL;
}

/* Load every single element of the sorted set. */
while(zsetlen--) {
Expand Down Expand Up @@ -1723,8 +1746,13 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
dupSearchDict = NULL;
}

if (o->encoding == OBJ_ENCODING_HT && len > DICT_HT_INITIAL_SIZE)
dictExpand(o->ptr,len);
if (o->encoding == OBJ_ENCODING_HT && len > DICT_HT_INITIAL_SIZE) {
if (dictTryExpand(o->ptr,len) != DICT_OK) {
rdbReportCorruptRDB("OOM in dictTryExpand %llu", (unsigned long long)len);
decrRefCount(o);
return NULL;
}
}

/* Load remaining fields and values into the hash table */
while (o->encoding == OBJ_ENCODING_HT && len > 0) {
Expand Down Expand Up @@ -1823,9 +1851,9 @@ robj *rdbLoadObject(int rdbtype, rio *rdb, sds key) {
zl = ziplistPush(zl, vstr, vlen, ZIPLIST_TAIL);

/* search for duplicate records */
sds field = sdsnewlen(fstr, flen);
if (dictAdd(dupSearchDict, field, NULL) != DICT_OK) {
rdbReportCorruptRDB("Hash zipmap with dup elements");
sds field = sdstrynewlen(fstr, flen);
if (!field || dictAdd(dupSearchDict, field, NULL) != DICT_OK) {
rdbReportCorruptRDB("Hash zipmap with dup elements, or big length (%u)", flen);
dictRelease(dupSearchDict);
sdsfree(field);
zfree(encoded);
Expand Down
14 changes: 12 additions & 2 deletions src/sds.c
Expand Up @@ -100,7 +100,7 @@ static inline size_t sdsTypeMaxSize(char type) {
* You can print the string with printf() as there is an implicit \0 at the
* end of the string. However the string is binary safe and can contain
* \0 characters in the middle, as the length is stored in the sds header. */
sds sdsnewlen(const void *init, size_t initlen) {
sds _sdsnewlen(const void *init, size_t initlen, int trymalloc) {
void *sh;
sds s;
char type = sdsReqType(initlen);
Expand All @@ -111,7 +111,9 @@ sds sdsnewlen(const void *init, size_t initlen) {
unsigned char *fp; /* flags pointer. */
size_t usable;

sh = s_malloc_usable(hdrlen+initlen+1, &usable);
sh = trymalloc?
s_trymalloc_usable(hdrlen+initlen+1, &usable) :
s_malloc_usable(hdrlen+initlen+1, &usable);
if (sh == NULL) return NULL;
if (init==SDS_NOINIT)
init = NULL;
Expand Down Expand Up @@ -162,6 +164,14 @@ sds sdsnewlen(const void *init, size_t initlen) {
return s;
}

sds sdsnewlen(const void *init, size_t initlen) {
return _sdsnewlen(init, initlen, 0);
}

sds sdstrynewlen(const void *init, size_t initlen) {
return _sdsnewlen(init, initlen, 1);
}

/* Create an empty (zero length) sds string. Even in this case the string
* always has an implicit null term. */
sds sdsempty(void) {
Expand Down
1 change: 1 addition & 0 deletions src/sds.h
Expand Up @@ -216,6 +216,7 @@ static inline void sdssetalloc(sds s, size_t newlen) {
}

sds sdsnewlen(const void *init, size_t initlen);
sds sdstrynewlen(const void *init, size_t initlen);
sds sdsnew(const char *init);
sds sdsempty(void);
sds sdsdup(const sds s);
Expand Down
4 changes: 4 additions & 0 deletions src/sdsalloc.h
Expand Up @@ -42,9 +42,13 @@
#include "zmalloc.h"
#define s_malloc zmalloc
#define s_realloc zrealloc
#define s_trymalloc ztrymalloc
#define s_tryrealloc ztryrealloc
#define s_free zfree
#define s_malloc_usable zmalloc_usable
#define s_realloc_usable zrealloc_usable
#define s_trymalloc_usable ztrymalloc_usable
#define s_tryrealloc_usable ztryrealloc_usable
#define s_free_usable zfree_usable

#endif

0 comments on commit 7ca00d6

Please sign in to comment.