Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions tests/unit/s2n_client_psk_extension_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@ struct s2n_psk_test_case {
size_t identity_size;
};

uint16_t s2n_test_customer_wire_index_choice;
static int s2n_test_select_psk_identity_callback(struct s2n_connection *conn,
static int s2n_test_select_psk_identity_callback(struct s2n_connection *conn, void *context,
struct s2n_offered_psk_list *psk_identity_list)
{
uint16_t *wire_index_choice = (uint16_t*) context;

struct s2n_offered_psk offered_psk = { 0 };
uint16_t idx = 0;
while(s2n_offered_psk_list_has_next(psk_identity_list)) {
POSIX_GUARD(s2n_offered_psk_list_next(psk_identity_list, &offered_psk));
if (idx == s2n_test_customer_wire_index_choice) {
if (idx == *wire_index_choice) {
POSIX_GUARD(s2n_offered_psk_list_choose_psk(psk_identity_list, &offered_psk));
break;
}
Expand All @@ -49,7 +50,7 @@ static int s2n_test_select_psk_identity_callback(struct s2n_connection *conn,
return S2N_SUCCESS;
}

static int s2n_test_error_select_psk_identity_callback(struct s2n_connection *conn,
static int s2n_test_error_select_psk_identity_callback(struct s2n_connection *conn, void *context,
struct s2n_offered_psk_list *psk_identity_list)
{
POSIX_BAIL(S2N_ERR_UNIMPLEMENTED);
Expand Down Expand Up @@ -574,7 +575,7 @@ int main(int argc, char **argv)
{
struct s2n_config *config = s2n_config_new();
EXPECT_NOT_NULL(config);
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, s2n_test_error_select_psk_identity_callback));
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, s2n_test_error_select_psk_identity_callback, NULL));

struct s2n_connection *conn;
EXPECT_NOT_NULL(conn = s2n_connection_new(S2N_SERVER));
Expand All @@ -594,7 +595,9 @@ int main(int argc, char **argv)
{
struct s2n_config *config = s2n_config_new();
EXPECT_NOT_NULL(config);
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, s2n_test_select_psk_identity_callback));

uint16_t expected_wire_choice = 0;
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, s2n_test_select_psk_identity_callback, &expected_wire_choice));

struct s2n_connection *conn;
EXPECT_NOT_NULL(conn = s2n_connection_new(S2N_SERVER));
Expand All @@ -609,7 +612,6 @@ int main(int argc, char **argv)
EXPECT_SUCCESS(s2n_stuffer_growable_alloc(&wire_identities_in, 0));
EXPECT_OK(s2n_write_test_identity(&wire_identities_in, &match_psk->identity));

s2n_test_customer_wire_index_choice = 0;
EXPECT_OK(s2n_client_psk_recv_identity_list(conn, &wire_identities_in));
EXPECT_EQUAL(conn->psk_params.chosen_psk, match_psk);
EXPECT_EQUAL(conn->psk_params.chosen_psk_wire_index, 0);
Expand All @@ -623,7 +625,9 @@ int main(int argc, char **argv)
{
struct s2n_config *config = s2n_config_new();
EXPECT_NOT_NULL(config);
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, s2n_test_select_psk_identity_callback));

uint16_t expected_wire_choice = 10;
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, s2n_test_select_psk_identity_callback, &expected_wire_choice));

struct s2n_connection *conn;
EXPECT_NOT_NULL(conn = s2n_connection_new(S2N_SERVER));
Expand All @@ -638,7 +642,6 @@ int main(int argc, char **argv)
EXPECT_SUCCESS(s2n_stuffer_growable_alloc(&wire_identities_in, 0));
EXPECT_OK(s2n_write_test_identity(&wire_identities_in, &match_psk->identity));

s2n_test_customer_wire_index_choice = 10;
EXPECT_ERROR(s2n_client_psk_recv_identity_list(conn, &wire_identities_in));
EXPECT_EQUAL(conn->psk_params.chosen_psk, NULL);

Expand All @@ -651,7 +654,9 @@ int main(int argc, char **argv)
{
struct s2n_config *config = s2n_config_new();
EXPECT_NOT_NULL(config);
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, s2n_test_select_psk_identity_callback));

uint16_t expected_wire_choice = 0;
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, s2n_test_select_psk_identity_callback, &expected_wire_choice));

struct s2n_connection *conn;
EXPECT_NOT_NULL(conn = s2n_connection_new(S2N_SERVER));
Expand All @@ -669,7 +674,6 @@ int main(int argc, char **argv)
EXPECT_SUCCESS(s2n_stuffer_growable_alloc(&wire_identities_in, 0));
EXPECT_OK(s2n_write_test_identity(&wire_identities_in, &wire_identity));

s2n_test_customer_wire_index_choice = 0;
EXPECT_ERROR(s2n_client_psk_recv_identity_list(conn, &wire_identities_in));
EXPECT_EQUAL(conn->psk_params.chosen_psk, NULL);

Expand Down
23 changes: 15 additions & 8 deletions tests/unit/s2n_config_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
#include "tls/s2n_security_policies.h"
#include "tls/s2n_tls13.h"

static int s2n_test_select_psk_identity_callback(struct s2n_connection *conn, struct s2n_offered_psk_list *psk_identity_list)
static int s2n_test_select_psk_identity_callback(struct s2n_connection *conn, void *context,
struct s2n_offered_psk_list *psk_identity_list)
{
return S2N_SUCCESS;
}
Expand Down Expand Up @@ -139,16 +140,22 @@ int main(int argc, char **argv)
{
struct s2n_config *config = NULL;
EXPECT_NOT_NULL(config = s2n_config_new());
uint8_t context = 13;

/* Safety checks */
{
EXPECT_FAILURE_WITH_ERRNO(s2n_config_set_psk_selection_callback(NULL, s2n_test_select_psk_identity_callback), S2N_ERR_NULL);
EXPECT_FAILURE_WITH_ERRNO(s2n_config_set_psk_selection_callback(config, NULL), S2N_ERR_NULL);
}

/* Safety check */
EXPECT_FAILURE_WITH_ERRNO(s2n_config_set_psk_selection_callback(
NULL, s2n_test_select_psk_identity_callback, &context), S2N_ERR_NULL);
EXPECT_NULL(config->psk_selection_cb);
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, s2n_test_select_psk_identity_callback));
EXPECT_NULL(config->psk_selection_ctx);

EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, s2n_test_select_psk_identity_callback, &context));
EXPECT_EQUAL(config->psk_selection_cb, s2n_test_select_psk_identity_callback);
EXPECT_EQUAL(config->psk_selection_ctx, &context);

EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(config, NULL, NULL));
EXPECT_NULL(config->psk_selection_cb);
EXPECT_NULL(config->psk_selection_ctx);

EXPECT_SUCCESS(s2n_config_free(config));
}

Expand Down
5 changes: 3 additions & 2 deletions tests/unit/s2n_self_talk_psk_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ static s2n_result validate_chosen_psk(struct s2n_connection *server_conn, uint8_
return S2N_RESULT_OK;
}

static int s2n_test_select_psk_identity_callback(struct s2n_connection *conn, struct s2n_offered_psk_list *psk_identity_list)
static int s2n_test_select_psk_identity_callback(struct s2n_connection *conn, void *context,
struct s2n_offered_psk_list *psk_identity_list)
{
struct s2n_offered_psk offered_psk = { 0 };
uint16_t idx = 0;
Expand Down Expand Up @@ -206,7 +207,7 @@ int main(int argc, char **argv)
EXPECT_OK(setup_server_psks(server_conn));

/* Set the customer callback to select PSK identity */
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(server_conn->config, s2n_test_select_psk_identity_callback));
EXPECT_SUCCESS(s2n_config_set_psk_selection_callback(server_conn->config, s2n_test_select_psk_identity_callback, NULL));
EXPECT_EQUAL(server_conn->config->psk_selection_cb, s2n_test_select_psk_identity_callback);

/* Negotiate handshake */
Expand Down
3 changes: 2 additions & 1 deletion tls/extensions/s2n_client_psk.c
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ static S2N_RESULT s2n_select_psk_identity(struct s2n_connection *conn, struct s2
static S2N_RESULT s2n_client_psk_recv_identity_list(struct s2n_connection *conn, struct s2n_stuffer *wire_identities_in)
{
RESULT_ENSURE_REF(conn);
RESULT_ENSURE_REF(conn->config);
RESULT_ENSURE_REF(wire_identities_in);

struct s2n_offered_psk_list identity_list = {
Expand All @@ -192,7 +193,7 @@ static S2N_RESULT s2n_client_psk_recv_identity_list(struct s2n_connection *conn,
};

if (conn->config->psk_selection_cb) {
RESULT_GUARD_POSIX(conn->config->psk_selection_cb(conn, &identity_list));
RESULT_GUARD_POSIX(conn->config->psk_selection_cb(conn, conn->config->psk_selection_ctx, &identity_list));
} else {
RESULT_GUARD(s2n_select_psk_identity(conn, &identity_list));
}
Expand Down
4 changes: 2 additions & 2 deletions tls/s2n_config.c
Original file line number Diff line number Diff line change
Expand Up @@ -846,11 +846,11 @@ int s2n_config_enable_cert_req_dss_legacy_compat(struct s2n_config *config)
return S2N_SUCCESS;
}

int s2n_config_set_psk_selection_callback(struct s2n_config *config, s2n_psk_selection_callback cb)
int s2n_config_set_psk_selection_callback(struct s2n_config *config, s2n_psk_selection_callback cb, void *context)
{
POSIX_ENSURE_REF(config);
POSIX_ENSURE_REF(cb);
config->psk_selection_cb = cb;
config->psk_selection_ctx = context;
return S2N_SUCCESS;
}

Expand Down
2 changes: 2 additions & 0 deletions tls/s2n_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ struct s2n_config {
uint16_t max_verify_cert_chain_depth;

s2n_async_pkey_fn async_pkey_cb;

s2n_psk_selection_callback psk_selection_cb;
void *psk_selection_ctx;

s2n_key_log_fn key_log_cb;
void *key_log_ctx;
Expand Down
4 changes: 2 additions & 2 deletions tls/s2n_psk.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,6 @@ int s2n_offered_psk_list_next(struct s2n_offered_psk_list *psk_list, struct s2n_
int s2n_offered_psk_list_reread(struct s2n_offered_psk_list *psk_list);
int s2n_offered_psk_list_choose_psk(struct s2n_offered_psk_list *psk_list, struct s2n_offered_psk *psk);

typedef int (*s2n_psk_selection_callback)(struct s2n_connection *conn,
typedef int (*s2n_psk_selection_callback)(struct s2n_connection *conn, void *context,
struct s2n_offered_psk_list *psk_list);
int s2n_config_set_psk_selection_callback(struct s2n_config *config, s2n_psk_selection_callback cb);
int s2n_config_set_psk_selection_callback(struct s2n_config *config, s2n_psk_selection_callback cb, void *context);