Skip to content

Commit

Permalink
all: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Mar 14, 2024
1 parent 2aee9a6 commit 7f11807
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 31 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Expand Up @@ -30,11 +30,17 @@ NOTE: Add new changes BELOW THIS COMMENT.
- Ability to define custom directories for storage of query log files and
statistics ([#5992]).

### Changed

- Failed authentication attempts show the originating IP address in the logs, if
the request was routed through trusted proxy ([#5829]).

### Fixed

- Missing "served from cache" label on long DNS server strings ([#6740]).
- Incorrect tracking of the system hosts file's changes ([#6711]).

[#5829]: https://github.com/AdguardTeam/AdGuardHome/issues/5829
[#5992]: https://github.com/AdguardTeam/AdGuardHome/issues/5992
[#6610]: https://github.com/AdguardTeam/AdGuardHome/issues/6610
[#6711]: https://github.com/AdguardTeam/AdGuardHome/issues/6711
Expand Down
22 changes: 12 additions & 10 deletions internal/home/auth.go
Expand Up @@ -52,13 +52,13 @@ func (s *session) deserialize(data []byte) bool {
return true
}

// Auth - global object
// Auth is the global authentication object.
type Auth struct {
db *bbolt.DB
rateLimiter *authRateLimiter
sessions map[string]*session
users []webUser
trustedProxies []netutil.Prefix
trustedProxies netutil.SubnetSet
lock sync.Mutex
sessionTTL uint32
}
Expand All @@ -71,17 +71,17 @@ type webUser struct {
PasswordHash string `yaml:"password"`
}

// InitAuth - create a global object
// InitAuth initializes the global authentication object.
func InitAuth(
dbFilename string,
users []webUser,
sessionTTL uint32,
rateLimiter *authRateLimiter,
trustedProxies []netutil.Prefix,
) *Auth {
trustedProxies netutil.SubnetSet,
) (a *Auth) {
log.Info("Initializing auth module: %s", dbFilename)

a := &Auth{
a = &Auth{
sessionTTL: sessionTTL,
rateLimiter: rateLimiter,
sessions: make(map[string]*session),
Expand All @@ -104,7 +104,7 @@ func InitAuth(
return a
}

// Close - close module
// Close closes the authentication database.
func (a *Auth) Close() {
_ = a.db.Close()
}
Expand All @@ -113,7 +113,8 @@ func bucketName() []byte {
return []byte("sessions-2")
}

// load sessions from file, remove expired sessions
// loadSessions loads sessions from the database file and removes expired
// sessions.
func (a *Auth) loadSessions() {
tx, err := a.db.Begin(true)
if err != nil {
Expand Down Expand Up @@ -165,7 +166,8 @@ func (a *Auth) loadSessions() {
log.Debug("auth: loaded %d sessions from DB (removed %d expired)", len(a.sessions), removed)
}

// store session data in file
// addSession adds a new session to the list of sessions and saves it in the
// database file.
func (a *Auth) addSession(data []byte, s *session) {
name := hex.EncodeToString(data)
a.lock.Lock()
Expand All @@ -176,7 +178,7 @@ func (a *Auth) addSession(data []byte, s *session) {
}
}

// store session data in file
// storeSession saves a session in the database file.
func (a *Auth) storeSession(data []byte, s *session) bool {
tx, err := a.db.Begin(true)
if err != nil {
Expand Down
24 changes: 4 additions & 20 deletions internal/home/authhttp.go
Expand Up @@ -96,15 +96,15 @@ func realIP(r *http.Request) (ip netip.Addr, err error) {
// If none of the above yielded any results, get the leftmost IP address
// from the X-Forwarded-For header.
s := r.Header.Get(httphdr.XForwardedFor)
ipStrs := strings.SplitN(s, ", ", 2)
ip, err = netip.ParseAddr(ipStrs[0])
ipStr, _, _ := strings.Cut(s, ",")
ip, err = netip.ParseAddr(ipStr)
if err == nil {
return ip, nil
}

// When everything else fails, just return the remote address as understood
// by the stdlib.
ipStr, err := netutil.SplitHost(r.RemoteAddr)
ipStr, err = netutil.SplitHost(r.RemoteAddr)
if err != nil {
return netip.Addr{}, fmt.Errorf("getting ip from client addr: %w", err)
}
Expand Down Expand Up @@ -142,8 +142,6 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
// to security issues.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
//
// TODO(e.burkov): Use realIP when the issue will be fixed.
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
writeErrorWithIP(
r,
Expand Down Expand Up @@ -173,7 +171,6 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
}
}

// Use realIP here, since this IP address is only used for logging.
ip, err := realIP(r)
if err != nil {
log.Error("auth: getting real ip from request with remote ip %s: %s", remoteIP, err)
Expand All @@ -182,7 +179,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
cookie, err := Context.auth.newCookie(req, remoteIP)
if err != nil {
logIP := remoteIP
if isTrustedIP(ip, Context.auth.trustedProxies) {
if Context.auth.trustedProxies.Contains(ip.Unmap()) {
logIP = ip.String()
}

Expand All @@ -203,19 +200,6 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
aghhttp.OK(w)
}

// isTrustedIP returns true if the trustedProxies include ip.
func isTrustedIP(ip netip.Addr, trustedProxies []netutil.Prefix) (ok bool) {
ip = ip.Unmap()

for _, p := range trustedProxies {
if p.Contains(ip) {
return true
}
}

return false
}

// handleLogout is the handler for the GET /control/logout HTTP API.
func handleLogout(w http.ResponseWriter, r *http.Request) {
respHdr := w.Header()
Expand Down
2 changes: 1 addition & 1 deletion internal/home/home.go
Expand Up @@ -667,7 +667,7 @@ func initUsers() (auth *Auth, err error) {
log.Info("authratelimiter is disabled")
}

trustedProxies := slices.Clone(config.DNS.TrustedProxies)
trustedProxies := netutil.SliceSubnetSet(netutil.UnembedPrefixes(config.DNS.TrustedProxies))

sessionTTL := config.HTTPConfig.SessionTTL.Seconds()
auth = InitAuth(sessFilename, config.Users, uint32(sessionTTL), rateLimiter, trustedProxies)
Expand Down

0 comments on commit 7f11807

Please sign in to comment.