diff --git a/bin/main.c b/bin/main.c index ffd9655..a6c5f87 100644 --- a/bin/main.c +++ b/bin/main.c @@ -120,29 +120,6 @@ keydup(const options_t *o, uint8_t **key) return size; } -static ssize_t -srv_psk_cb(void *m, const char *username, uint8_t **key) -{ - const options_t *o = m; - - if (strcmp(username, o->psku) != 0) - return -1; - - return keydup(o, key); -} - -static ssize_t -clt_psk_cb(void *m, char **username, uint8_t **key) -{ - const options_t *o = m; - - *username = strdup(o->psku); - if (!*username) - return -1; - - return keydup(o, key); -} - static status_t on_conn(options_t *opts, int con, int in, int out, const struct addrinfo *ai) { @@ -154,31 +131,91 @@ on_conn(options_t *opts, int con, int in, int out, const struct addrinfo *ai) if (ai->ai_protocol == IPPROTO_TLS) { int ret; + int true_indicator = 1; + int false_indicator = 0; + socklen_t *val_len = NULL; + size_t size; + const char *user = NULL; + uint8_t **key = NULL; if (opts->listen) { - tls_srv_handshake_t srv = { .misc = opts }; - - if (opts->psku) - srv.psk = srv_psk_cb; - ret = non_setsockopt(con, IPPROTO_TLS, - TLS_SRV_HANDSHAKE, &srv, sizeof(srv)); + TLS_IS_SERVER, &true_indicator, sizeof(true_indicator)); + + if (ret != 0) + goto fail; + + if (opts->psku) { + ret = non_setsockopt(con, IPPROTO_TLS, + TLS_PSK, &true_indicator, sizeof(true_indicator)); + } + + if (ret != 0) + goto fail; + + while (handshake(con) == -1) { + fprintf(stderr, "Entering server handshake function.\nErrno is %d\n", errno); + switch (errno) { + case ENOKEY: + ret = getsockopt(con, IPPROTO_TLS, + TLS_PSK_USER, &user, val_len); + + if (ret != 0 || *val_len != strlen(user)) + goto fail; + else { + if (strcmp(user, opts->psku) == 0) { + size = keydup(opts, key); + ret = non_setsockopt(con, IPPROTO_TLS, + TLS_PSK_KEY, key, size); + if (ret != 0) + goto fail; + } + } + break; + } + } } else { - tls_clt_handshake_t clt = { .misc = opts }; - - if (opts->psku) - clt.psk = clt_psk_cb; - ret = non_setsockopt(con, IPPROTO_TLS, - TLS_CLT_HANDSHAKE, &clt, sizeof(clt)); + TLS_IS_SERVER, &false_indicator, sizeof(false_indicator)); + + if (ret != 0) + goto fail; + + while (handshake(con) == -1) { + int val; + fprintf(stderr, "Entering client handshake function.\nErrno is %d.\n", errno); + switch (errno) { + case ENOKEY: + ret = getsockopt(con, IPPROTO_TLS, + TLS_PSK, &val, val_len); + + if (ret != 0) + goto fail; + + if (val) { + ret = non_setsockopt(con, IPPROTO_TLS, TLS_PSK_USER, + opts->psku, strlen(opts->psku)); + if (ret != 0) + goto fail; + + size = keydup(opts, key); + ret = non_setsockopt(con, IPPROTO_TLS, + TLS_PSK_KEY, key, size); + if (ret != 0) + goto fail; + } + break; + } + } } - if (ret != 0) { + fail: fprintf(stderr, "%m: Unable to complete TLS handshake!\n"); shutdown(con, SHUT_RDWR); return STATUS_FAILURE; } - } + + //fprintf(stderr, "Complete TLS handshake!\n"); while (poll(pfds, 2, -1) >= 0) { char buffer[64 * 1024] = {};