Skip to content

Commit

Permalink
Async DNS over TCP (#25460)
Browse files Browse the repository at this point in the history
  • Loading branch information
raboof committed Sep 27, 2018
1 parent d952fd4 commit 42ed417
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 13 deletions.
@@ -0,0 +1,59 @@
/*
* Copyright (C) 2018 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.io.dns

import java.net.InetAddress

import akka.io.dns.DnsProtocol.{ Ip, RequestType, Srv }
import akka.io.{ Dns, IO }
import akka.pattern.ask
import akka.testkit.AkkaSpec
import akka.util.Timeout

import scala.concurrent.duration._

/*
Relies on being run while online
*/
class OnlineAsyncDnsResolverIntegrationSpec extends AkkaSpec(
"""
akka.loglevel = DEBUG
akka.io.dns.resolver = async-dns
akka.io.dns.async-dns.nameservers = default
""") {
val duration = 10.seconds
implicit val timeout = Timeout(duration)

"Resolver" must {

"resolve mixed A/AAAA records" in {
val name = "akka.io"
val answer = resolve(name)
answer.name shouldEqual name

answer.records.collect { case r: ARecord r.ip }.toSet shouldEqual Set(
InetAddress.getByName("104.31.90.133"),
InetAddress.getByName("104.31.91.133")
)

answer.records.collect { case r: AAAARecord r.ip }.toSet shouldEqual Set(
InetAddress.getByName("2606:4700:30::681f:5a85"),
InetAddress.getByName("2606:4700:30::681f:5b85")
)
}

"resolve queries that are too big for UDP" in {
val name = "many.bzzt.net"
val answer = resolve(name)
answer.name shouldEqual name
answer.records.length should be(48)
}

def resolve(name: String, requestType: RequestType = Ip()): DnsProtocol.Resolved = {
(IO(Dns) ? DnsProtocol.Resolve(name, requestType)).mapTo[DnsProtocol.Resolved].futureValue
}

}
}
36 changes: 23 additions & 13 deletions akka-actor/src/main/scala/akka/io/dns/internal/DnsClient.scala
Expand Up @@ -7,7 +7,7 @@ package akka.io.dns.internal
import java.net.{ InetAddress, InetSocketAddress }

import akka.actor.Status.Failure
import akka.actor.{ Actor, ActorLogging, ActorRef, NoSerializationVerificationNeeded, Stash }
import akka.actor.{ Actor, ActorLogging, ActorRef, NoSerializationVerificationNeeded, Props, Stash }
import akka.annotation.InternalApi
import akka.io.dns.{ RecordClass, RecordType, ResourceRecord }
import akka.io.{ IO, Udp }
Expand Down Expand Up @@ -40,7 +40,9 @@ import scala.util.Try

IO(Udp) ! Udp.Bind(self, new InetSocketAddress(InetAddress.getByAddress(Array.ofDim(4)), 0))

var inflightRequests: Map[Short, ActorRef] = Map.empty
var inflightRequests: Map[Short, (ActorRef, Message)] = Map.empty

val tcpDnsClient = context.actorOf(Props(classOf[TcpDnsClient], ns), "tcpDnsClient")

def receive: Receive = {
case Udp.Bound(local)
Expand All @@ -65,22 +67,22 @@ import scala.util.Try
inflightRequests -= id
case Question4(id, name)
log.debug("Resolving [{}] (A)", name)
inflightRequests += (id -> sender())
val msg = message(name, id, RecordType.A)
inflightRequests += (id -> (sender(), msg))
log.debug(s"Message [{}] to [{}]: [{}]", id, ns, msg)
socket ! Udp.Send(msg.write(), ns)

case Question6(id, name)
log.debug("Resolving [{}] (AAAA)", name)
inflightRequests += (id -> sender())
val msg = message(name, id, RecordType.AAAA)
inflightRequests += (id -> (sender(), msg))
log.debug(s"Message to [{}]: [{}]", ns, msg)
socket ! Udp.Send(msg.write(), ns)

case SrvQuestion(id, name)
log.debug("Resolving [{}] (SRV)", name)
inflightRequests += (id -> sender())
val msg = message(name, id, RecordType.SRV)
inflightRequests += (id -> (sender(), msg))
log.debug(s"Message to {}: msg", ns, msg)
socket ! Udp.Send(msg.write(), ns)

Expand All @@ -91,9 +93,10 @@ import scala.util.Try
// best effort, don't throw
Try {
val msg = Message.parse(send.payload)
inflightRequests.get(msg.id).foreach { s
s ! Failure(new RuntimeException("Send failed to nameserver"))
inflightRequests -= msg.id
inflightRequests.get(msg.id).foreach {
case (s, _)
s ! Failure(new RuntimeException("Send failed to nameserver"))
inflightRequests -= msg.id
}
}
case _
Expand All @@ -105,18 +108,25 @@ import scala.util.Try
log.debug(s"Decoded: $msg")
// TODO remove me when #25460 is implemented
if (msg.flags.isTruncated) {
log.warning("DNS response truncated and fallback to TCP is not yet implemented. See #25460")
log.debug("DNS response truncated, falling back to TCP")
inflightRequests.get(msg.id) match {
case Some((_, msg))
tcpDnsClient ! msg
case _
log.debug("Client for id {} not found. Discarding unsuccessful response.", msg.id)
}
} else {
val (recs, additionalRecs) = if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil)
self ! Answer(msg.id, recs, additionalRecs)
}
val (recs, additionalRecs) = if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil)
val response = Answer(msg.id, recs, additionalRecs)
case response: Answer
inflightRequests.get(response.id) match {
case Some(reply)
case Some((reply, _))
reply ! response
inflightRequests -= response.id
case None
log.debug("Client for id {} not found. Discarding response.", response.id)
}

case Udp.Unbind socket ! Udp.Unbind
case Udp.Unbound context.stop(self)
}
Expand Down
71 changes: 71 additions & 0 deletions akka-actor/src/main/scala/akka/io/dns/internal/TcpDnsClient.scala
@@ -0,0 +1,71 @@
/*
* Copyright (C) 2018 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.io.dns.internal

import java.net.InetSocketAddress

import akka.actor.{ Actor, ActorLogging, ActorRef, Stash }
import akka.annotation.InternalApi
import akka.io.Tcp._
import akka.io.dns.internal.DnsClient.{ Answer, DnsQuestion, Question4 }
import akka.io.{ IO, Tcp }
import akka.util.ByteString

/**
* INTERNAL API
*/
@InternalApi private[akka] class TcpDnsClient(ns: InetSocketAddress) extends Actor with ActorLogging with Stash {

import context.system

log.warning("Connecting to [{}]", ns)
IO(Tcp) ! Tcp.Connect(ns)

override def receive: Receive = {
case CommandFailed(_: Connect)
log.warning("Failed to connect to [{}]", ns)
// TODO
case _: Tcp.Connected
log.debug(s"Connected to TCP address [{}]", ns)
val connection = sender()
context.become(ready(connection))
connection ! Register(self)
unstashAll()
case _: Message
stash()
}

def encodeLength(length: Int): ByteString =
ByteString((length / 256).toByte, length.toByte)

def decodeLength(data: ByteString): Int =
((data(0).toInt + 256) % 256) * 256 + ((data(1) + 256) % 256)

def ready(connection: ActorRef): Receive = {
case msg: Message
log.warning("Sending message to connection")
val bytes = msg.write()
connection ! Tcp.Write(encodeLength(bytes.length))
connection ! Tcp.Write(bytes)
case CommandFailed(_: Write)
// TODO
log.warning("Write failed")
case Received(data)
log.warning("Received data")
require(data.length > 2, "Expected a response datagram starting with the size")
val expectedLength = decodeLength(data)
log.warning(s"First 2 bytes are ${data(0)} and ${data(1)}, totalling $expectedLength")
require(data.length == expectedLength + 2, s"Expected a full response datagram of length ${expectedLength}, got ${data.length - 2} data bytes instead.")
val msg = Message.parse(data.drop(2))
log.debug(s"Decoded: $msg")
if (msg.flags.isTruncated) {
log.warning("TCP DNS response truncated")
}
val (recs, additionalRecs) = if (msg.flags.responseCode == ResponseCode.SUCCESS) (msg.answerRecs, msg.additionalRecs) else (Nil, Nil)
context.parent ! Answer(msg.id, recs, additionalRecs)
case PeerClosed
log.warning("Peer closed")
}
}

0 comments on commit 42ed417

Please sign in to comment.