diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 14ee55decb5..4c94fc9929d 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -256,34 +256,15 @@ func ValidateClientID(clientID string) (err error) { // is the server name of the host. cliSrvName is the server name as sent by the // client. func clientIDFromClientServerName(hostSrvName, cliSrvName string) (clientID string, err error) { - if len(hostSrvName) < 2 { - return "", fmt.Errorf("bad host server name %q", hostSrvName) - } - - // Firstly, check a simple host-to-host match. - if hostSrvName[0] != '*' && hostSrvName[1] != '.' { - if hostSrvName != cliSrvName { - return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName) - } - - return "", nil - } - - // Secondly, check for a simple "example.com" with "*.example.com" - // match. If matched, we have no client ID. - if cliSrvName == hostSrvName[2:] { + if hostSrvName == cliSrvName { return "", nil } - // Next, make sure that this is even the right host. That is, error out - // on stuff like "client.example.net" with "*.example.com". - if !strings.HasSuffix(cliSrvName, hostSrvName[1:]) { - return "", fmt.Errorf("client server name %q doesn't match host server name wildcard %q", cliSrvName, hostSrvName) + if !strings.HasSuffix(cliSrvName, hostSrvName) { + return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName) } - // Slice away the suffix. We already know that cliSrvName matches, so - // no need for strings.StripSuffix. - clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)+1] + clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1] err = ValidateClientID(clientID) if err != nil { return "", fmt.Errorf("invalid client id: %w", err) diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index 9bd465e283a..0ae18d54f20 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -53,41 +53,25 @@ func TestProcessClientID(t *testing.T) { wantErrMsg: "", wantRes: resultCodeSuccess, }, { - name: "tls_no_wildcard_error", + name: "tls_client_id", proto: proxy.ProtoTLS, hostSrvName: "example.com", cliSrvName: "cli.example.com", - wantClientID: "", - wantErrMsg: `client id check: client server name "cli.example.com" doesn't match host server name "example.com"`, - wantRes: resultCodeError, - }, { - name: "tls_no_client_id_wildcard", - proto: proxy.ProtoTLS, - hostSrvName: "*.example.com", - cliSrvName: "example.com", - wantClientID: "", - wantErrMsg: "", - wantRes: resultCodeSuccess, - }, { - name: "tls_client_id_wildcard", - proto: proxy.ProtoTLS, - hostSrvName: "*.example.com", - cliSrvName: "cli.example.com", wantClientID: "cli", wantErrMsg: "", wantRes: resultCodeSuccess, }, { - name: "tls_client_id_wildcard_error", + name: "tls_client_id_hostname_error", proto: proxy.ProtoTLS, - hostSrvName: "*.example.com", + hostSrvName: "example.com", cliSrvName: "cli.example.net", wantClientID: "", - wantErrMsg: `client id check: client server name "cli.example.net" doesn't match host server name wildcard "*.example.com"`, + wantErrMsg: `client id check: client server name "cli.example.net" doesn't match host server name "example.com"`, wantRes: resultCodeError, }, { name: "tls_invalid_client_id", proto: proxy.ProtoTLS, - hostSrvName: "*.example.com", + hostSrvName: "example.com", cliSrvName: "!!!.example.com", wantClientID: "", wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`, @@ -95,7 +79,7 @@ func TestProcessClientID(t *testing.T) { }, { name: "tls_client_id_too_long", proto: proxy.ProtoTLS, - hostSrvName: "*.example.com", + hostSrvName: "example.com", cliSrvName: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789.example.com", wantClientID: "", wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" is too long, max: 64`, diff --git a/internal/home/config.go b/internal/home/config.go index dc455dcee77..65a9401c4e9 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -6,7 +6,6 @@ import ( "net" "os" "path/filepath" - "strings" "sync" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" @@ -120,18 +119,6 @@ type tlsConfigSettings struct { dnsforward.TLSConfig `yaml:",inline" json:",inline"` } -// hostname returns the hostname from the server name of the configuration. -// -// TODO(a.garipov): Think of a better way to do this. Perhaps, caching on -// change? -func (s *tlsConfigSettings) hostname() (h string) { - if s == nil { - return "" - } - - return strings.TrimPrefix(s.ServerName, "*.") -} - // initialize to default values, will be changed later when reading config or parsing command line var config = configuration{ BindPort: 3000, diff --git a/internal/home/dns.go b/internal/home/dns.go index 040e7476d1c..b0f52f71da1 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -216,7 +216,7 @@ func getDNSEncryption() (de dnsEncryption) { Context.tls.WriteDiskConfig(&tlsConf) if tlsConf.Enabled && len(tlsConf.ServerName) != 0 { - hostname := tlsConf.hostname() + hostname := tlsConf.ServerName if tlsConf.PortHTTPS != 0 { addr := hostname if tlsConf.PortHTTPS != 443 { diff --git a/internal/home/home.go b/internal/home/home.go index f72b47ed461..939d501fdda 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -583,11 +583,10 @@ func printHTTPAddresses(proto string) { var hostStr string if proto == "https" && tlsConf.ServerName != "" { - hostname := tlsConf.hostname() if tlsConf.PortHTTPS == 443 { - log.Printf("Go to https://%s", hostname) + log.Printf("Go to https://%s", tlsConf.ServerName) } else { - log.Printf("Go to https://%s:%s", hostname, port) + log.Printf("Go to https://%s:%s", tlsConf.ServerName, port) } } else if config.BindHost.IsUnspecified() { log.Println("AdGuard Home is available on the following addresses:") diff --git a/internal/home/mobileconfig.go b/internal/home/mobileconfig.go index 185a2b27f3d..3953e2e64a5 100644 --- a/internal/home/mobileconfig.go +++ b/internal/home/mobileconfig.go @@ -6,7 +6,6 @@ import ( "net/http" "net/url" "path" - "strings" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/log" @@ -53,28 +52,25 @@ const ( ) func getMobileConfig(d dnsSettings) ([]byte, error) { - srvName := strings.TrimPrefix(d.ServerName, "*.") - var dspName string switch d.DNSProtocol { case dnsProtoHTTPS: - dspName = fmt.Sprintf("%s DoH", srvName) - - d.ServerName = srvName + dspName = fmt.Sprintf("%s DoH", d.ServerName) u := &url.URL{ Scheme: "https", - Host: srvName, + Host: d.ServerName, Path: "/dns-query", } if d.clientID != "" { u.Path = path.Join(u.Path, d.clientID) } + d.ServerURL = u.String() case dnsProtoTLS: - dspName = fmt.Sprintf("%s DoT", srvName) + dspName = fmt.Sprintf("%s DoT", d.ServerName) if d.clientID != "" { - d.ServerName = d.clientID + "." + srvName + d.ServerName = d.clientID + "." + d.ServerName } default: return nil, fmt.Errorf("bad dns protocol %q", d.DNSProtocol) @@ -103,14 +99,6 @@ func getMobileConfig(d dnsSettings) ([]byte, error) { return plist.MarshalIndent(data, plist.XMLFormat, "\t") } -func canUseClientID(srvName, clientID string) (ok bool) { - if clientID == "" { - return true - } - - return strings.HasPrefix(srvName, "*.") -} - func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) { var err error @@ -149,20 +137,6 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) { return } - if !canUseClientID(host, clientID) { - w.WriteHeader(http.StatusBadRequest) - - msg := fmt.Sprintf("can't use client_id with server name %q", host) - err = json.NewEncoder(w).Encode(&jsonError{ - Message: msg, - }) - if err != nil { - log.Debug("writing 400 json response: %s", err) - } - - return - } - d := dnsSettings{ DNSProtocol: dnsp, ServerName: host, diff --git a/internal/home/mobileconfig_test.go b/internal/home/mobileconfig_test.go index 6cc2dad6ee1..9dcafc972d2 100644 --- a/internal/home/mobileconfig_test.go +++ b/internal/home/mobileconfig_test.go @@ -75,7 +75,7 @@ func TestHandleMobileConfigDOH(t *testing.T) { }) t.Run("client_id", func(t *testing.T) { - r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=*.example.org&client_id=cli42", nil) + r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil) assert.Nil(t, err) w := httptest.NewRecorder() @@ -160,7 +160,7 @@ func TestHandleMobileConfigDOT(t *testing.T) { }) t.Run("client_id", func(t *testing.T) { - r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=*.example.org&client_id=cli42", nil) + r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil) assert.Nil(t, err) w := httptest.NewRecorder() diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index f888f7e75cd..cb679f01491 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -1104,16 +1104,13 @@ `tls.server_name` from the configuration file is used. If `tls.server_name` is not set, the API returns an error with a 500 status. - - Can be a wildcard host, for example `"*.example.org"`. 'example': 'example.org' 'in': 'query' 'name': 'host' 'schema': 'type': 'string' - 'description': > - Client ID. This can only be set it the host is a wildcard, for - example `"*.example.org"`. + Client ID. 'example': 'client-1' 'in': 'query' 'name': 'client_id' @@ -1141,16 +1138,13 @@ `tls.server_name` from the configuration file is used. If `tls.server_name` is not set, the API returns an error with a 500 status. - - Can be a wildcard host, for example `"*.example.org"`. 'example': 'example.org' 'in': 'query' 'name': 'host' 'schema': 'type': 'string' - 'description': > - Client ID. This can only be set it the host is a wildcard, for - example `"*.example.org"`. + Client ID. 'example': 'client-1' 'in': 'query' 'name': 'client_id'