Skip to content

Commit

Permalink
Add hash of message body to auth tokens in slurm_send_node_msg()
Browse files Browse the repository at this point in the history
  • Loading branch information
fafik23 authored and wickberg committed May 10, 2022
1 parent c86bd97 commit 40099fb
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 20 deletions.
2 changes: 2 additions & 0 deletions NEWS
Expand Up @@ -237,6 +237,8 @@ documents those changes that are of interest to users and administrators.
-- openapi/dbv0.0.38 - add with_deleted input to GET /qos.
-- openapi/dbv0.0.38 - set with_deleted to false by default for GET /account[s]
-- openapi/dbv0.0.38 - add with_deleted input to GET /account[s].
-- Include k12 hash of the RPC message body in the auth/munge tokens to
provide for additional communication resiliency.

* Changes in Slurm 21.08.9
==========================
Expand Down
107 changes: 87 additions & 20 deletions src/common/slurm_protocol_api.c
Expand Up @@ -60,6 +60,7 @@
#include "src/common/assoc_mgr.h"
#include "src/common/fd.h"
#include "src/common/forward.h"
#include "src/common/hash.h"
#include "src/common/log.h"
#include "src/common/macros.h"
#include "src/common/pack.h"
Expand Down Expand Up @@ -181,28 +182,89 @@ static int _check_hash(buf_t *buffer, header_t *header, slurm_msg_t *msg,
int rc;
static time_t config_update = (time_t) -1;
static bool block_null_hash = true;
static bool block_zero_hash = true;

if (config_update != slurm_conf.last_update) {
block_null_hash = (xstrcasestr(slurm_conf.comm_params,
"block_null_hash"));
block_zero_hash = (xstrcasestr(slurm_conf.comm_params,
"block_zero_hash"));
config_update = slurm_conf.last_update;
}

rc = auth_g_get_data(cred, &cred_hash, &cred_hash_len);
if (!slurm_get_plugin_hash_enable(msg->auth_index))
return SLURM_SUCCESS;

if (cred_hash || cred_hash_len) {
if (cred_hash_len != 3 || cred_hash[0] != 1 ||
memcmp(cred_hash + 1,
&msg->msg_type, sizeof(msg->msg_type)))
rc = SLURM_ERROR;
} else if (block_null_hash &&
slurm_get_plugin_hash_enable(msg->auth_index))
rc = auth_g_get_data(cred, &cred_hash, &cred_hash_len);
if (cred_hash_len) {
log_flag_hex(NET_RAW, cred_hash, cred_hash_len,
"%s: cred_hash:", __func__);
if (cred_hash[0] == HASH_PLUGIN_NONE) {
if (block_zero_hash || (cred_hash_len != 3) ||
memcmp(cred_hash + 1, &msg->msg_type,
sizeof(msg->msg_type)))
rc = SLURM_ERROR;
else
msg->hash_index = HASH_PLUGIN_NONE;
} else {
char *data;
uint32_t size = header->body_length;
slurm_hash_t hash = { 0 };
int h_len;
uint16_t msg_type = htons(msg->msg_type);

data = get_buf_data(buffer) + get_buf_offset(buffer);
hash.type = cred_hash[0];

h_len = hash_g_compute(data, size, (char *) &msg_type,
sizeof(msg_type), &hash);
if ((h_len + 1) != cred_hash_len ||
memcmp(cred_hash + 1, hash.hash, h_len))
rc = SLURM_ERROR;
else
msg->hash_index = hash.type;
log_flag_hex(NET_RAW, &hash, sizeof(hash),
"%s: hash:", __func__);
}
} else if (block_null_hash)
rc = SLURM_ERROR;

xfree(cred_hash);
return rc;
}

static int _compute_hash(buf_t *buffer, slurm_msg_t *msg, slurm_hash_t *hash)
{
int h_len = 0;

if (slurm_get_plugin_hash_enable(msg->auth_index)) {
if (msg->hash_index != HASH_PLUGIN_DEFAULT)
hash->type = msg->hash_index;
else if (msg->protocol_version <= SLURM_21_08_PROTOCOL_VERSION)
hash->type = HASH_PLUGIN_NONE;

if (hash->type == HASH_PLUGIN_NONE) {
memcpy(hash->hash, &msg->msg_type,
sizeof(msg->msg_type));
h_len = sizeof(msg->msg_type);
} else {
uint16_t msg_type = htons(msg->msg_type);

h_len = hash_g_compute(get_buf_data(buffer),
get_buf_offset(buffer),
(char *) &msg_type,
sizeof(msg_type), hash);
}

if (h_len < 0)
return h_len;
h_len++;
}

return h_len;

}

static int _get_tres_id(char *type, char *name)
{
slurmdb_tres_rec_t tres_rec;
Expand Down Expand Up @@ -1701,7 +1763,8 @@ int slurm_send_node_msg(int fd, slurm_msg_t * msg)
int rc;
void * auth_cred;
time_t start_time = time(NULL);
uint8_t auth_payload[3] = { 1 }; /* uint8_t + uint16_t (msg_type) */
slurm_hash_t hash = { 0 };
int h_len;

if (msg->conn) {
persist_msg_t persist_msg;
Expand Down Expand Up @@ -1737,7 +1800,6 @@ int slurm_send_node_msg(int fd, slurm_msg_t * msg)

if (!msg->restrict_uid_set)
fatal("%s: restrict_uid is not set", __func__);
memcpy(auth_payload + 1, &msg->msg_type, sizeof(msg->msg_type));
/*
* Pack message into buffer
*/
Expand All @@ -1754,14 +1816,21 @@ int slurm_send_node_msg(int fd, slurm_msg_t * msg)
* but we may need to generate the credential again later if we
* wait too long for the incoming message.
*/
h_len = _compute_hash(buffers.body, msg, &hash);
if (h_len < 0) {
error("%s: hash_g_compute: %s has error",
__func__, rpc_num2string(msg->msg_type));
free_buf(buffers.body);
slurm_seterrno_ret(SLURM_UNEXPECTED_MSG_ERROR);
}
log_flag_hex(NET_RAW, &hash, sizeof(hash),
"%s: hash:", __func__);
if (msg->flags & SLURM_GLOBAL_AUTH_KEY) {
auth_cred = auth_g_create(msg->auth_index, _global_auth_key(),
msg->restrict_uid, auth_payload,
sizeof(auth_payload));
msg->restrict_uid, &hash, h_len);
} else {
auth_cred = auth_g_create(msg->auth_index, slurm_conf.authinfo,
msg->restrict_uid, auth_payload,
sizeof(auth_payload));
msg->restrict_uid, &hash, h_len);
}

if (msg->forward.init != FORWARD_INIT) {
Expand All @@ -1779,15 +1848,13 @@ int slurm_send_node_msg(int fd, slurm_msg_t * msg)
if (msg->flags & SLURM_GLOBAL_AUTH_KEY) {
auth_cred = auth_g_create(msg->auth_index,
_global_auth_key(),
msg->restrict_uid,
auth_payload,
sizeof(auth_payload));
msg->restrict_uid, &hash,
h_len);
} else {
auth_cred = auth_g_create(msg->auth_index,
slurm_conf.authinfo,
msg->restrict_uid,
auth_payload,
sizeof(auth_payload));
msg->restrict_uid, &hash,
h_len);
}
}
if (auth_cred == NULL) {
Expand Down

0 comments on commit 40099fb

Please sign in to comment.