Skip to content

Commit

Permalink
[mdns] Implement Known Answer suppression for received
Browse files Browse the repository at this point in the history
queries.

This commit introduces the functionality where mDNS server is not
populating/creating an answer if the answer it would give can be found
in the answer section of the incoming query message.

Signed-off-by: Cristian Bulacu <cristian.bulacu@nxp.com>
  • Loading branch information
Cristib05 committed Nov 24, 2023
1 parent 4ba9625 commit d509a11
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 21 deletions.
243 changes: 228 additions & 15 deletions src/core/net/mdns_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -693,17 +693,19 @@ Message *MdnsServer::FindQueryByName(const char *aName)
return matchedQuery;
}

Header::Response MdnsServer::ResolveQuestion(const char *aName,
const Question &aQuestion,
Header &aResponseHeader,
Message &aResponseMessage,
NameCompressInfo &aCompressInfo,
bool aAdditional)
{
const Service *service = nullptr;
uint16_t qtype = aQuestion.GetType();
bool needAdditionalAaaaRecord = false;
Header::Response responseCode = Header::kResponseSuccess;
Header::Response MdnsServer::ResolveQuestion(const char *aName,
const Question &aQuestion,
Header &aResponseHeader,
Message &aResponseMessage,
NameCompressInfo &aCompressInfo,
bool aAdditional,
LinkedList<KnownAnswerEntry> &aKnownAnswersList)
{
const Service *service = nullptr;
uint16_t qtype = aQuestion.GetType();
bool needAdditionalAaaaRecord = false;
bool shouldSuppressAnswer = false;
Header::Response responseCode = Header::kResponseSuccess;

bool serviceNameMatched = false;

Expand Down Expand Up @@ -738,6 +740,22 @@ Header::Response MdnsServer::ResolveQuestion(const char *aName,
bool txtQueryMatched =
(qtype == ResourceRecord::kTypeTxt || qtype == ResourceRecord::kTypeAny) && instanceNameMatched;

for (KnownAnswerEntry &entry : aKnownAnswersList)
{
if (StringMatch(entry.GetServiceName(), service->GetServiceName(), kStringCaseInsensitiveMatch) &&
StringMatch(entry.GetInstanceName(), service->GetInstanceName(), kStringCaseInsensitiveMatch) &&
entry.GetRecord().GetTtl() > service->GetTtl() / 2)
{
shouldSuppressAnswer = true;
break;
}
}

if (shouldSuppressAnswer)
{
break;
}

if (ptrQueryMatched || srvQueryMatched)
{
needAdditionalAaaaRecord = true;
Expand Down Expand Up @@ -810,6 +828,134 @@ Header::Response MdnsServer::ResolveQuestion(const char *aName,
return responseCode;
}

Header::Response MdnsServer::ResolveQuestionBySrp(const char *aName,
const Question &aQuestion,
Header &aResponseHeader,
Message &aResponseMessage,
NameCompressInfo &aCompressInfo,
bool aAdditional,
LinkedList<KnownAnswerEntry> &aKnownAnswersList)
{
Error error = kErrorNone;
const Srp::Server::Host *host = nullptr;
TimeMilli now = TimerMilli::GetNow();
uint16_t qtype = aQuestion.GetType();
Header::Response response = Header::kResponseNameError;
bool shouldSuppressAnswer = false;
KnownAnswerEntry *kaElement = nullptr;

if(!aKnownAnswersList.IsEmpty())
{
kaElement = aKnownAnswersList.GetHead();
}

while ((host = Get<Server>().GetNextSrpHost(host)) != nullptr)
{
bool needAdditionalAaaaRecord = false;
const char *hostName = host->GetFullName();

// Handle PTR/SRV/TXT/ANY query
if (qtype == ResourceRecord::kTypePtr || qtype == ResourceRecord::kTypeSrv ||
qtype == ResourceRecord::kTypeTxt || qtype == ResourceRecord::kTypeAny)
{
const Srp::Server::Service *service = nullptr;

while ((service = Get<Server>().GetNextSrpService(*host, service)) != nullptr)
{
uint32_t instanceTtl = TimeMilli::MsecToSec(service->GetExpireTime() - TimerMilli::GetNow());
const char *instanceName = service->GetInstanceName();
char convertedInstanceName[Dns::Name::kMaxNameSize];
char convertedServiceName[Dns::Name::kMaxNameSize];
bool serviceNameMatched = service->MatchesServiceName(aName);
bool instanceNameMatched = (!service->IsSubType() && service->MatchesInstanceName(aName));
bool ptrQueryMatched =
(qtype == ResourceRecord::kTypePtr || qtype == ResourceRecord::kTypeAny) && serviceNameMatched;
bool srvQueryMatched =
(qtype == ResourceRecord::kTypeSrv || qtype == ResourceRecord::kTypeAny) && instanceNameMatched;
bool txtQueryMatched =
(qtype == ResourceRecord::kTypeTxt || qtype == ResourceRecord::kTypeAny) && instanceNameMatched;

for (; kaElement != nullptr; kaElement = kaElement->GetNext())
{
ConvertDomainName(convertedInstanceName, kaElement->GetInstanceName(), kDefaultDomainName,
Server::kDefaultDomainName);
ConvertDomainName(convertedServiceName, kaElement->GetServiceName(), kDefaultDomainName,
Server::kDefaultDomainName);

if (StringMatch(convertedServiceName, service->GetServiceName(), kStringCaseInsensitiveMatch) &&
StringMatch(convertedInstanceName, service->GetInstanceName(), kStringCaseInsensitiveMatch) &&
kaElement->GetRecord().GetTtl() > service->GetTtl() / 2)
{
shouldSuppressAnswer = true;
break;
}
}

if (shouldSuppressAnswer)
{
break;
}

if (ptrQueryMatched || srvQueryMatched)
{
needAdditionalAaaaRecord = true;
}

if (!aAdditional && ptrQueryMatched)
{
SuccessOrExit(
error = Server::AppendPtrRecord(aResponseMessage, aName, instanceName, instanceTtl, aCompressInfo));
Server::IncResourceRecordCount(aResponseHeader, aAdditional);
response = Header::kResponseSuccess;
}

if ((!aAdditional && srvQueryMatched) ||
(aAdditional && ptrQueryMatched &&
!Server::HasQuestion(aResponseHeader, aResponseMessage, instanceName, ResourceRecord::kTypeSrv)))
{
SuccessOrExit(error = Server::AppendSrvRecord(aResponseMessage, instanceName, hostName, instanceTtl,
service->GetPriority(), service->GetWeight(),
service->GetPort(), aCompressInfo));
Server::IncResourceRecordCount(aResponseHeader, aAdditional);
response = Header::kResponseSuccess;
}

if ((!aAdditional && txtQueryMatched) ||
(aAdditional && ptrQueryMatched &&
!Server::HasQuestion(aResponseHeader, aResponseMessage, instanceName, ResourceRecord::kTypeTxt)))
{
SuccessOrExit(error = Server::AppendTxtRecord(aResponseMessage, instanceName, service->GetTxtData(),
service->GetTxtDataLength(), instanceTtl, aCompressInfo));
Server::IncResourceRecordCount(aResponseHeader, aAdditional);
response = Header::kResponseSuccess;
}
}
}

// Handle AAAA query
if ((!aAdditional && (qtype == ResourceRecord::kTypeAaaa || qtype == ResourceRecord::kTypeAny) &&
host->Matches(aName)) ||
(aAdditional && needAdditionalAaaaRecord &&
!Server::HasQuestion(aResponseHeader, aResponseMessage, hostName, ResourceRecord::kTypeAaaa)))
{
uint8_t addrNum;
const Ip6::Address *addrs = host->GetAddresses(addrNum);
uint32_t hostTtl = TimeMilli::MsecToSec(host->GetExpireTime() - now);

for (uint8_t i = 0; i < addrNum; i++)
{
SuccessOrExit(error = Server::AppendAaaaRecord(aResponseMessage, hostName, addrs[i], hostTtl, aCompressInfo));
Server::IncResourceRecordCount(aResponseHeader, aAdditional);
}

response = Header::kResponseSuccess;
}
}

exit:
return error == kErrorNone ? response : Header::kResponseServerFailure;
}

Header::Response MdnsServer::ResolveQuery(const Header &aRequestHeader,
const Message &aRequestMessage,
Header &aResponseHeader,
Expand All @@ -821,9 +967,34 @@ Header::Response MdnsServer::ResolveQuery(const Header &aRequestHead
uint16_t readOffset;
NameComponentsOffsetInfo nameComponentsOffsetInfo;
Header::Response responseCode = Header::kResponseSuccess;
uint16_t knownAnswerOffset = ReturnKnownAnswerOffsetFromQuery(aRequestHeader, aRequestMessage);

readOffset = sizeof(Header);

if(knownAnswerOffset)
{
ResourceRecord record;
for (uint8_t index = 0; index < aRequestHeader.GetAnswerCount(); index++)
{
char instanceName[Dns::Name::kMaxNameSize];
char serviceName[Dns::Name::kMaxNameSize];

Name::ReadName(aRequestMessage, knownAnswerOffset, serviceName, sizeof(serviceName));
ResourceRecord::ReadRecord(aRequestMessage, knownAnswerOffset, record);
Name::ReadName(aRequestMessage, knownAnswerOffset, instanceName, sizeof(instanceName));

KnownAnswerEntry *kaEntry = KnownAnswerEntry::AllocateAndInit(serviceName, instanceName, record);
if (!mReceivedKnownAnswers.ContainsMatching(*kaEntry))
{
mReceivedKnownAnswers.Push(*kaEntry);
}
else
{
kaEntry->Free();
}
}
}

/* Go through each question and attach the corresponding RRs in the answer section */
for (uint16_t i = 0; i < aRequestHeader.GetQuestionCount(); i++)
{
Expand All @@ -850,13 +1021,13 @@ Header::Response MdnsServer::ResolveQuery(const Header &aRequestHead
responseCode = Header::kResponseNameError);

SuccessOrExit(responseCode =
ResolveQuestion(name, question, aResponseHeader, aResponseMessage, aCompressInfo, false));
ResolveQuestion(name, question, aResponseHeader, aResponseMessage, aCompressInfo, false, mReceivedKnownAnswers));

#if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
// Convert from kDefaultMcastDomainName to kDefaultDomainName (.local -> default.service.arpa) for searching
memcpy(name + nameComponentsOffsetInfo.mDomainOffset, kThreadDefaultDomainName,
sizeof(kThreadDefaultDomainName));
Get<Server>().ResolveQuestionBySrp(name, question, aResponseHeader, aResponseMessage, aCompressInfo, false);
ResolveQuestionBySrp(name, question, aResponseHeader, aResponseMessage, aCompressInfo, false, mReceivedKnownAnswers);
#endif
}

Expand All @@ -881,19 +1052,20 @@ Header::Response MdnsServer::ResolveQuery(const Header &aRequestHead
responseCode = Header::kResponseNameError);

SuccessOrExit(responseCode =
ResolveQuestion(name, question, aResponseHeader, aResponseMessage, aCompressInfo, true));
ResolveQuestion(name, question, aResponseHeader, aResponseMessage, aCompressInfo, true, mReceivedKnownAnswers));

#if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
// Convert from kDefaultMcastDomainName to kDefaultDomainName (.local -> default.service.arpa) for
// searching
memcpy(name + nameComponentsOffsetInfo.mDomainOffset, kThreadDefaultDomainName,
sizeof(kThreadDefaultDomainName));
Get<Server>().ResolveQuestionBySrp(name, question, aResponseHeader, aResponseMessage, aCompressInfo, true);
ResolveQuestionBySrp(name, question, aResponseHeader, aResponseMessage, aCompressInfo, true, mReceivedKnownAnswers);
#endif
}
}

exit:
RemoveAllKnownAnswerEntries();
return responseCode;
}

Expand Down Expand Up @@ -3563,6 +3735,47 @@ Error MdnsServer::SrpAdvertisingServiceInfo::Init(const char *aServiceName, cons
return error;
}

uint16_t MdnsServer::ReturnKnownAnswerOffsetFromQuery(const Header &aHeader, const Message &aMessage)
{
uint16_t retOffset = 0;

if (aHeader.GetAnswerCount())
{
uint16_t readOffset = sizeof(Header);
Name aName(aMessage, readOffset);

for (uint16_t i = 0; i < aHeader.GetQuestionCount(); i++)
{
Question question;

Name::CompareName(aMessage, readOffset, aName);
IgnoreError(aMessage.Read(readOffset, question));
readOffset += sizeof(question);
retOffset = readOffset;
}
}
return retOffset;
}

Error MdnsServer::KnownAnswerEntry::Init(char *aServiceName, char *aInstanceName, ResourceRecord &aRecord)
{
mServiceName.Set(aServiceName);
mInstanceName.Set(aInstanceName);
mRecord = aRecord;

return kErrorNone;
}

void MdnsServer::RemoveAllKnownAnswerEntries(void)
{
while(!mReceivedKnownAnswers.IsEmpty())
{
MdnsServer::KnownAnswerEntry *entry = mReceivedKnownAnswers.GetHead();
IgnoreError(mReceivedKnownAnswers.Remove(*entry));
entry->Free();
}
}

} // namespace ServiceDiscovery
} // namespace Dns
} // namespace ot
Expand Down
48 changes: 42 additions & 6 deletions src/core/net/mdns_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,30 @@ class MdnsServer : public InstanceLocator, private NonCopyable
Client::ServiceCallback mServiceCallback;
};

struct KnownAnswerEntry : public LinkedListEntry<KnownAnswerEntry>, public Heap::Allocatable<KnownAnswerEntry>
{
friend class LinkedListEntry<KnownAnswerEntry>;
friend class Heap::Allocatable<KnownAnswerEntry>;

public:
Error Init(char *aServiceName, char *aInstanceName, ResourceRecord &aRecord);
ResourceRecord GetRecord() { return mRecord; }
const char *GetServiceName() { return mServiceName.AsCString(); }
const char *GetInstanceName() { return mInstanceName.AsCString(); }
bool Matches(const KnownAnswerEntry &aEntry) const
{
return StringMatch(mInstanceName.AsCString(), AsNonConst(aEntry).GetInstanceName(), kStringCaseInsensitiveMatch) &&
StringMatch(mServiceName.AsCString(), AsNonConst(aEntry).GetServiceName(), kStringCaseInsensitiveMatch);
}

private:
ResourceRecord mRecord;
Heap::String mInstanceName;
Heap::String mServiceName;

KnownAnswerEntry *mNext;
};

static void HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo);
void HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo);
#if MDNS_USE_TASKLET
Expand Down Expand Up @@ -785,12 +809,21 @@ class MdnsServer : public InstanceLocator, private NonCopyable
Message &aResponseMessage,
Server::NameCompressInfo &aCompressInfo,
bool &bUnicastResponse);
Header::Response ResolveQuestion(const char *aName,
const Question &aQuestion,
Header &aResponseHeader,
Message &aResponseMessage,
NameCompressInfo &aCompressInfo,
bool aAdditional);
Header::Response ResolveQuestion(const char *aName,
const Question &aQuestion,
Header &aResponseHeader,
Message &aResponseMessage,
NameCompressInfo &aCompressInfo,
bool aAdditional,
LinkedList<KnownAnswerEntry> &aKnownAnswersList);

Header::Response ResolveQuestionBySrp(const char *aName,
const Question &aQuestion,
Header &aResponseHeader,
Message &aResponseMessage,
NameCompressInfo &aCompressInfo,
bool aAdditional,
LinkedList<KnownAnswerEntry> &aKnownAnswersList);

static void SrpAdvertisingProxyHandler(otSrpServerServiceUpdateId aId,
const otSrpServerHost *aHost,
Expand Down Expand Up @@ -819,6 +852,8 @@ class MdnsServer : public InstanceLocator, private NonCopyable
Prober *AllocateProber(bool aProbeForHost, const otSrpServerHost *aHost, uint32_t aId);
Error UpdateExistingAnnouncerDataEntries(Announcer &aAnnouncer, Service &aService);
Announcer *AllocateAnnouncer(uint32_t aId);
uint16_t ReturnKnownAnswerOffsetFromQuery(const Header &aHeader, const Message &aMessage);
void RemoveAllKnownAnswerEntries(void);

using RetryTimer = TimerMilliIn<MdnsServer, &MdnsServer::HandleTimer>;
#if MDNS_USE_TASKLET
Expand All @@ -844,6 +879,7 @@ class MdnsServer : public InstanceLocator, private NonCopyable
LinkedList<Prober> mProbingInstances;
LinkedList<Announcer> mAnnouncingInstances;
Callback<MdnsProbingCallback> mCallback;
LinkedList<KnownAnswerEntry> mReceivedKnownAnswers;
};

} // namespace ServiceDiscovery
Expand Down

0 comments on commit d509a11

Please sign in to comment.