Skip to content

Commit

Permalink
[nla,transport] move public key retrieval to transport IO.
Browse files Browse the repository at this point in the history
  • Loading branch information
llyzs authored and mfleisz committed Jan 3, 2024
1 parent 87557b1 commit 33447dc
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 8 deletions.
3 changes: 3 additions & 0 deletions include/freerdp/transport_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ extern "C"
typedef BOOL (*pTransportAttach)(rdpTransport* transport, int sockfd);
typedef int (*pTransportRWFkt)(rdpTransport* transport, wStream* s);
typedef SSIZE_T (*pTransportRead)(rdpTransport* transport, BYTE* data, size_t bytes);
typedef BOOL (*pTransportGetPublicKey)(rdpTransport* transport, const BYTE** data,
DWORD* length);

struct rdp_transport_io
{
Expand All @@ -52,6 +54,7 @@ extern "C"
pTransportRWFkt ReadPdu; /* Reads a whole PDU from the transport */
pTransportRWFkt WritePdu; /* Writes a whole PDU to the transport */
pTransportRead ReadBytes; /* Reads up to a requested amount of bytes from the transport */
pTransportGetPublicKey GetPublicKey;
};
typedef struct rdp_transport_io rdpTransportIo;

Expand Down
21 changes: 13 additions & 8 deletions libfreerdp/core/nla.c
Original file line number Diff line number Diff line change
Expand Up @@ -451,15 +451,15 @@ static int nla_client_init(rdpNla* nla)
if (!credssp_auth_setup_client(nla->auth, "TERMSRV", hostname, nla->identity, nla->pkinitArgs))
return -1;

rdpTls* tls = transport_get_tls(nla->transport);

if (!tls)
const BYTE* data = NULL;
DWORD length = 0;
if (!transport_get_public_key(nla->transport, &data, &length))
{
WLog_ERR(TAG, "Unknown NLA transport layer");
WLog_ERR(TAG, "Failed to get public key");
return -1;
}

if (!nla_sec_buffer_alloc_from_data(&nla->PublicKey, tls->PublicKey, 0, tls->PublicKeyLength))
if (!nla_sec_buffer_alloc_from_data(&nla->PublicKey, data, 0, length))
{
WLog_ERR(TAG, "Failed to allocate sspi secBuffer");
return -1;
Expand Down Expand Up @@ -662,10 +662,15 @@ static int nla_server_init(rdpNla* nla)
{
WINPR_ASSERT(nla);

rdpTls* tls = transport_get_tls(nla->transport);
WINPR_ASSERT(tls);
const BYTE* data = NULL;
DWORD length = 0;
if (!transport_get_public_key(nla->transport, &data, &length))
{
WLog_ERR(TAG, "Failed to get public key");
return -1;
}

if (!nla_sec_buffer_alloc_from_data(&nla->PublicKey, tls->PublicKey, 0, tls->PublicKeyLength))
if (!nla_sec_buffer_alloc_from_data(&nla->PublicKey, data, 0, length))
{
WLog_ERR(TAG, "Failed to allocate SecBuffer for public key");
return -1;
Expand Down
19 changes: 19 additions & 0 deletions libfreerdp/core/transport.c
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,24 @@ static int transport_default_write(rdpTransport* transport, wStream* s)
return status;
}

BOOL transport_get_public_key(rdpTransport* transport, const BYTE** data, DWORD* length)
{
return IFCALLRESULT(FALSE, transport->io.GetPublicKey, transport, data, length);
}

static BOOL transport_default_get_public_key(rdpTransport* transport, const BYTE** data,
DWORD* length)
{
rdpTls* tls = transport_get_tls(transport);
if (!tls)
return FALSE;

*data = tls->PublicKey;
*length = tls->PublicKeyLength;

return TRUE;
}

DWORD transport_get_event_handles(rdpTransport* transport, HANDLE* events, DWORD count)
{
DWORD nCount = 0; /* always the reread Event */
Expand Down Expand Up @@ -1535,6 +1553,7 @@ rdpTransport* transport_new(rdpContext* context)
transport->io.ReadPdu = transport_default_read_pdu;
transport->io.WritePdu = transport_default_write;
transport->io.ReadBytes = transport_read_layer;
transport->io.GetPublicKey = transport_default_get_public_key;

transport->context = context;
transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE);
Expand Down
3 changes: 3 additions & 0 deletions libfreerdp/core/transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ FREERDP_LOCAL BOOL transport_accept_rdstls(rdpTransport* transport);
FREERDP_LOCAL int transport_read_pdu(rdpTransport* transport, wStream* s);
FREERDP_LOCAL int transport_write(rdpTransport* transport, wStream* s);

FREERDP_LOCAL BOOL transport_get_public_key(rdpTransport* transport, const BYTE** data,
DWORD* length);

#if defined(WITH_FREERDP_DEPRECATED)
FREERDP_LOCAL void transport_get_fds(rdpTransport* transport, void** rfds, int* rcount);
#endif
Expand Down

0 comments on commit 33447dc

Please sign in to comment.