Skip to content

Commit

Permalink
all: remove wildcard requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed Jan 25, 2021
1 parent 3b67948 commit f046352
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 103 deletions.
27 changes: 4 additions & 23 deletions internal/dnsforward/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,34 +256,15 @@ func ValidateClientID(clientID string) (err error) {
// is the server name of the host. cliSrvName is the server name as sent by the
// client.
func clientIDFromClientServerName(hostSrvName, cliSrvName string) (clientID string, err error) {
if len(hostSrvName) < 2 {
return "", fmt.Errorf("bad host server name %q", hostSrvName)
}

// Firstly, check a simple host-to-host match.
if hostSrvName[0] != '*' && hostSrvName[1] != '.' {
if hostSrvName != cliSrvName {
return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
}

return "", nil
}

// Secondly, check for a simple "example.com" with "*.example.com"
// match. If matched, we have no client ID.
if cliSrvName == hostSrvName[2:] {
if hostSrvName == cliSrvName {
return "", nil
}

// Next, make sure that this is even the right host. That is, error out
// on stuff like "client.example.net" with "*.example.com".
if !strings.HasSuffix(cliSrvName, hostSrvName[1:]) {
return "", fmt.Errorf("client server name %q doesn't match host server name wildcard %q", cliSrvName, hostSrvName)
if !strings.HasSuffix(cliSrvName, hostSrvName) {
return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName)
}

// Slice away the suffix. We already know that cliSrvName matches, so
// no need for strings.StripSuffix.
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)+1]
clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1]
err = ValidateClientID(clientID)
if err != nil {
return "", fmt.Errorf("invalid client id: %w", err)
Expand Down
28 changes: 6 additions & 22 deletions internal/dnsforward/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,49 +53,33 @@ func TestProcessClientID(t *testing.T) {
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "tls_no_wildcard_error",
name: "tls_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "example.com",
cliSrvName: "cli.example.com",
wantClientID: "",
wantErrMsg: `client id check: client server name "cli.example.com" doesn't match host server name "example.com"`,
wantRes: resultCodeError,
}, {
name: "tls_no_client_id_wildcard",
proto: proxy.ProtoTLS,
hostSrvName: "*.example.com",
cliSrvName: "example.com",
wantClientID: "",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "tls_client_id_wildcard",
proto: proxy.ProtoTLS,
hostSrvName: "*.example.com",
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
wantRes: resultCodeSuccess,
}, {
name: "tls_client_id_wildcard_error",
name: "tls_client_id_hostname_error",
proto: proxy.ProtoTLS,
hostSrvName: "*.example.com",
hostSrvName: "example.com",
cliSrvName: "cli.example.net",
wantClientID: "",
wantErrMsg: `client id check: client server name "cli.example.net" doesn't match host server name wildcard "*.example.com"`,
wantErrMsg: `client id check: client server name "cli.example.net" doesn't match host server name "example.com"`,
wantRes: resultCodeError,
}, {
name: "tls_invalid_client_id",
proto: proxy.ProtoTLS,
hostSrvName: "*.example.com",
hostSrvName: "example.com",
cliSrvName: "!!!.example.com",
wantClientID: "",
wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`,
wantRes: resultCodeError,
}, {
name: "tls_client_id_too_long",
proto: proxy.ProtoTLS,
hostSrvName: "*.example.com",
hostSrvName: "example.com",
cliSrvName: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789.example.com",
wantClientID: "",
wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" is too long, max: 64`,
Expand Down
13 changes: 0 additions & 13 deletions internal/home/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net"
"os"
"path/filepath"
"strings"
"sync"

"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
Expand Down Expand Up @@ -120,18 +119,6 @@ type tlsConfigSettings struct {
dnsforward.TLSConfig `yaml:",inline" json:",inline"`
}

// hostname returns the hostname from the server name of the configuration.
//
// TODO(a.garipov): Think of a better way to do this. Perhaps, caching on
// change?
func (s *tlsConfigSettings) hostname() (h string) {
if s == nil {
return ""
}

return strings.TrimPrefix(s.ServerName, "*.")
}

// initialize to default values, will be changed later when reading config or parsing command line
var config = configuration{
BindPort: 3000,
Expand Down
2 changes: 1 addition & 1 deletion internal/home/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func getDNSEncryption() (de dnsEncryption) {
Context.tls.WriteDiskConfig(&tlsConf)

if tlsConf.Enabled && len(tlsConf.ServerName) != 0 {
hostname := tlsConf.hostname()
hostname := tlsConf.ServerName
if tlsConf.PortHTTPS != 0 {
addr := hostname
if tlsConf.PortHTTPS != 443 {
Expand Down
5 changes: 2 additions & 3 deletions internal/home/home.go
Original file line number Diff line number Diff line change
Expand Up @@ -583,11 +583,10 @@ func printHTTPAddresses(proto string) {

var hostStr string
if proto == "https" && tlsConf.ServerName != "" {
hostname := tlsConf.hostname()
if tlsConf.PortHTTPS == 443 {
log.Printf("Go to https://%s", hostname)
log.Printf("Go to https://%s", tlsConf.ServerName)
} else {
log.Printf("Go to https://%s:%s", hostname, port)
log.Printf("Go to https://%s:%s", tlsConf.ServerName, port)
}
} else if config.BindHost.IsUnspecified() {
log.Println("AdGuard Home is available on the following addresses:")
Expand Down
36 changes: 5 additions & 31 deletions internal/home/mobileconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net/http"
"net/url"
"path"
"strings"

"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/log"
Expand Down Expand Up @@ -53,28 +52,25 @@ const (
)

func getMobileConfig(d dnsSettings) ([]byte, error) {
srvName := strings.TrimPrefix(d.ServerName, "*.")

var dspName string
switch d.DNSProtocol {
case dnsProtoHTTPS:
dspName = fmt.Sprintf("%s DoH", srvName)

d.ServerName = srvName
dspName = fmt.Sprintf("%s DoH", d.ServerName)

u := &url.URL{
Scheme: "https",
Host: srvName,
Host: d.ServerName,
Path: "/dns-query",
}
if d.clientID != "" {
u.Path = path.Join(u.Path, d.clientID)
}

d.ServerURL = u.String()
case dnsProtoTLS:
dspName = fmt.Sprintf("%s DoT", srvName)
dspName = fmt.Sprintf("%s DoT", d.ServerName)
if d.clientID != "" {
d.ServerName = d.clientID + "." + srvName
d.ServerName = d.clientID + "." + d.ServerName
}
default:
return nil, fmt.Errorf("bad dns protocol %q", d.DNSProtocol)
Expand Down Expand Up @@ -103,14 +99,6 @@ func getMobileConfig(d dnsSettings) ([]byte, error) {
return plist.MarshalIndent(data, plist.XMLFormat, "\t")
}

func canUseClientID(srvName, clientID string) (ok bool) {
if clientID == "" {
return true
}

return strings.HasPrefix(srvName, "*.")
}

func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
var err error

Expand Down Expand Up @@ -149,20 +137,6 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
return
}

if !canUseClientID(host, clientID) {
w.WriteHeader(http.StatusBadRequest)

msg := fmt.Sprintf("can't use client_id with server name %q", host)
err = json.NewEncoder(w).Encode(&jsonError{
Message: msg,
})
if err != nil {
log.Debug("writing 400 json response: %s", err)
}

return
}

d := dnsSettings{
DNSProtocol: dnsp,
ServerName: host,
Expand Down
4 changes: 2 additions & 2 deletions internal/home/mobileconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestHandleMobileConfigDOH(t *testing.T) {
})

t.Run("client_id", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=*.example.org&client_id=cli42", nil)
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil)
assert.Nil(t, err)

w := httptest.NewRecorder()
Expand Down Expand Up @@ -160,7 +160,7 @@ func TestHandleMobileConfigDOT(t *testing.T) {
})

t.Run("client_id", func(t *testing.T) {
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=*.example.org&client_id=cli42", nil)
r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil)
assert.Nil(t, err)

w := httptest.NewRecorder()
Expand Down
10 changes: 2 additions & 8 deletions openapi/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1104,16 +1104,13 @@
`tls.server_name` from the configuration file is used. If
`tls.server_name` is not set, the API returns an error with a 500
status.
Can be a wildcard host, for example `"*.example.org"`.
'example': 'example.org'
'in': 'query'
'name': 'host'
'schema':
'type': 'string'
- 'description': >
Client ID. This can only be set it the host is a wildcard, for
example `"*.example.org"`.
Client ID.
'example': 'client-1'
'in': 'query'
'name': 'client_id'
Expand Down Expand Up @@ -1141,16 +1138,13 @@
`tls.server_name` from the configuration file is used. If
`tls.server_name` is not set, the API returns an error with a 500
status.
Can be a wildcard host, for example `"*.example.org"`.
'example': 'example.org'
'in': 'query'
'name': 'host'
'schema':
'type': 'string'
- 'description': >
Client ID. This can only be set it the host is a wildcard, for
example `"*.example.org"`.
Client ID.
'example': 'client-1'
'in': 'query'
'name': 'client_id'
Expand Down

0 comments on commit f046352

Please sign in to comment.