diff --git a/main.go b/main.go index dc53a047d..3d8d830c2 100644 --- a/main.go +++ b/main.go @@ -57,6 +57,9 @@ type Options struct { // Server listen ports ListenPorts []int `yaml:"listen-ports" short:"p" long:"port" description:"Listening ports. Zero value disables TCP and UDP listeners"` + // HTTP listen ports + HTTPListenPorts []int `yaml:"http-port" short:"i" long:"http-port" description:"Listening ports for DNS-over-HTTP"` + // HTTPS listen ports HTTPSListenPorts []int `yaml:"https-port" short:"s" long:"https-port" description:"Listening ports for DNS-over-HTTPS"` @@ -245,9 +248,9 @@ func main() { if err != nil { if flagsErr, ok := err.(*goFlags.Error); ok && flagsErr.Type == goFlags.ErrHelp { os.Exit(0) + } else { + log.Fatalf("failed to parse args: %v", err) } - - os.Exit(1) } run(options) @@ -621,6 +624,13 @@ func initListenAddrs(config *proxy.Config, options *Options) { } } + for _, port := range options.HTTPListenPorts { + for _, ip := range listenIPs { + a := net.TCPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16(port))) + config.HTTPListenAddr = append(config.HTTPSListenAddr, a) + } + } + if config.DNSCryptResolverCert != nil && config.DNSCryptProviderName != "" { for _, port := range options.DNSCryptListenPorts { for _, ip := range listenIPs { diff --git a/proxy/config.go b/proxy/config.go index 33c0fd296..b5ff64eb1 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -49,6 +49,7 @@ type Config struct { UDPListenAddr []*net.UDPAddr // if nil, then it does not listen for UDP TCPListenAddr []*net.TCPAddr // if nil, then it does not listen for TCP + HTTPListenAddr []*net.TCPAddr // if nil, then it does not listen for HTTP (DoH) HTTPSListenAddr []*net.TCPAddr // if nil, then it does not listen for HTTPS (DoH) TLSListenAddr []*net.TCPAddr // if nil, then it does not listen for TLS (DoT) QUICListenAddr []*net.UDPAddr // if nil, then it does not listen for QUIC (DoQ) @@ -332,6 +333,7 @@ func (p *Proxy) hasListenAddrs() bool { return p.UDPListenAddr != nil || p.TCPListenAddr != nil || p.TLSListenAddr != nil || + p.HTTPListenAddr != nil || p.HTTPSListenAddr != nil || p.QUICListenAddr != nil || p.DNSCryptUDPListenAddr != nil || diff --git a/proxy/proxy.go b/proxy/proxy.go index 5aa8b4bd1..33e0bc763 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -86,6 +86,9 @@ type Proxy struct { // quicListen are the listened QUIC connections. quicListen []*quic.EarlyListener + httpListen []net.Listener // HTTP listeners + httpServer *http.Server // HTTP server instance + // httpsListen are the listened HTTPS connections. httpsListen []net.Listener diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 53d7b4f44..d4cbe2a76 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1225,6 +1225,7 @@ func createTestProxy(t *testing.T, tlsConfig *tls.Config) *Proxy { } else { p.UDPListenAddr = []*net.UDPAddr{{IP: ip, Port: 0}} p.TCPListenAddr = []*net.TCPAddr{{IP: ip, Port: 0}} + p.HTTPListenAddr = []*net.TCPAddr{{Port: 0, IP: net.ParseIP(listenIP)}} } upstreams := make([]upstream.Upstream, 0) dnsUpstream, err := upstream.AddressToUpstream( diff --git a/proxy/server.go b/proxy/server.go index a01b2bc78..9dfad3135 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -23,6 +23,11 @@ func (p *Proxy) startListeners(ctx context.Context) error { return err } + err = p.createHTTPListeners() + if err != nil { + return err + } + err = p.createTLSListeners() if err != nil { return err @@ -55,6 +60,10 @@ func (p *Proxy) startListeners(ctx context.Context) error { go p.tcpPacketLoop(l, ProtoTLS, p.requestsSema) } + for _, l := range p.httpListen { + go func(l net.Listener) { _ = p.httpServer.Serve(l) }(l) + } + for _, l := range p.httpsListen { go func(l net.Listener) { _ = p.httpsServer.Serve(l) }(l) } diff --git a/proxy/server_https.go b/proxy/server_https.go index 1175590db..b2ede1a6f 100644 --- a/proxy/server_https.go +++ b/proxy/server_https.go @@ -19,10 +19,10 @@ import ( "golang.org/x/net/http2" ) -// listenHTTP creates instances of TLS listeners that will be used to run an +// listenHTTPS creates instances of TLS listeners that will be used to run an // H1/H2 server. Returns the address the listener actually listens to (useful // in the case if port 0 is specified). -func (p *Proxy) listenHTTP(addr *net.TCPAddr) (laddr *net.TCPAddr, err error) { +func (p *Proxy) listenHTTPS(addr *net.TCPAddr) (laddr *net.TCPAddr, err error) { tcpListen, err := net.ListenTCP("tcp", addr) if err != nil { return nil, fmt.Errorf("tcp listener: %w", err) @@ -38,6 +38,21 @@ func (p *Proxy) listenHTTP(addr *net.TCPAddr) (laddr *net.TCPAddr, err error) { return tcpListen.Addr().(*net.TCPAddr), nil } +// listenHTTP creates instances of TCP listeners that will be used to run an +// H1 server. Returns the address the listener actually listens to (useful +// in the case if port 0 is specified). +func (p *Proxy) listenHTTP(addr *net.TCPAddr) (laddr *net.TCPAddr, err error) { + tcpListen, err := net.ListenTCP("tcp", addr) + if err != nil { + return nil, fmt.Errorf("tcp listener: %w", err) + } + log.Info("Listening to http://%s", tcpListen.Addr()) + + p.httpListen = append(p.httpListen, tcpListen) + + return tcpListen.Addr().(*net.TCPAddr), nil +} + // listenH3 creates instances of QUIC listeners that will be used for running // an HTTP/3 server. func (p *Proxy) listenH3(addr *net.UDPAddr) (err error) { @@ -70,10 +85,9 @@ func (p *Proxy) createHTTPSListeners() (err error) { for _, addr := range p.HTTPSListenAddr { log.Info("Creating an HTTPS server") - - tcpAddr, lErr := p.listenHTTP(addr) - if lErr != nil { - return fmt.Errorf("failed to start HTTPS server on %s: %w", addr, lErr) + tcpAddr, err := p.listenHTTPS(addr) + if err != nil { + return fmt.Errorf("failed to start HTTPS server on %s: %w", addr, err) } if p.HTTP3 { @@ -90,6 +104,26 @@ func (p *Proxy) createHTTPSListeners() (err error) { return nil } +// createHTTPListeners creates the cleartext HTTP listener for DNS-over-HTTPS (behind a proxy doing TLS termination). +func (p *Proxy) createHTTPListeners() (err error) { + p.httpServer = &http.Server{ + Handler: p, + ReadHeaderTimeout: defaultTimeout, + WriteTimeout: defaultTimeout, + } + + for _, addr := range p.HTTPListenAddr { + log.Info("Creating an HTTP server") + + _, err := p.listenHTTP(addr) + if err != nil { + return fmt.Errorf("failed to start HTTP server on %s: %w", addr, err) + } + } + + return nil +} + // ServeHTTP is the http.Handler implementation that handles DoH queries. // Here is what it returns: //