Skip to content

Commit

Permalink
home: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Nov 2, 2023
1 parent 2956aed commit 1eef672
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 175 deletions.
116 changes: 108 additions & 8 deletions internal/home/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
"fmt"
"net/http"
"sync"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"go.etcd.io/bbolt"
"golang.org/x/crypto/bcrypt"
)

// sessionTokenSize is the length of session token in bytes.
Expand Down Expand Up @@ -50,14 +54,16 @@ func (s *session) deserialize(data []byte) bool {
// Auth - global object
type Auth struct {
db *bbolt.DB
raleLimiter *authRateLimiter
rateLimiter *authRateLimiter
sessions map[string]*session
users []webUser
lock sync.Mutex
sessionTTL uint32
}

// webUser represents a user of the Web UI.
//
// TODO(s.chzhen): Improve naming.
type webUser struct {
Name string `yaml:"name"`
PasswordHash string `yaml:"password"`
Expand All @@ -69,7 +75,7 @@ func InitAuth(dbFilename string, users []webUser, sessionTTL uint32, rateLimiter

a := &Auth{
sessionTTL: sessionTTL,
raleLimiter: rateLimiter,
rateLimiter: rateLimiter,
sessions: make(map[string]*session),
users: users,
}
Expand Down Expand Up @@ -197,8 +203,8 @@ func (a *Auth) storeSession(data []byte, s *session) bool {
return true
}

// remove session from file
func (a *Auth) removeSession(sess []byte) {
// removeSessionFromFile removes a stored session from the DB file on disk.
func (a *Auth) removeSessionFromFile(sess []byte) {
tx, err := a.db.Begin(true)
if err != nil {
log.Error("auth: bbolt.Begin: %s", err)
Expand Down Expand Up @@ -260,7 +266,7 @@ func (a *Auth) checkSession(sess string) (res checkSessionResult) {
if s.expire <= now {
delete(a.sessions, sess)
key, _ := hex.DecodeString(sess)
a.removeSession(key)
a.removeSessionFromFile(key)

return checkSessionExpired
}
Expand All @@ -282,13 +288,107 @@ func (a *Auth) checkSession(sess string) (res checkSessionResult) {
return checkSessionOK
}

// RemoveSession - remove session
func (a *Auth) RemoveSession(sess string) {
// removeSession removes the session from the active sessions and the disk.
func (a *Auth) removeSession(sess string) {
key, _ := hex.DecodeString(sess)
a.lock.Lock()
delete(a.sessions, sess)
a.lock.Unlock()
a.removeSession(key)
a.removeSessionFromFile(key)
}

// addUser adds a new user with the given password.
func (a *Auth) addUser(u *webUser, password string) (err error) {
if len(password) == 0 {
return errors.Error("empty password")
}

hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("generating hash: %w", err)
}

u.PasswordHash = string(hash)

a.lock.Lock()
defer a.lock.Unlock()

a.users = append(a.users, *u)

log.Debug("auth: added user with login %q", u.Name)

return nil
}

// findUser returns a user if there is one.
func (a *Auth) findUser(login, password string) (u webUser, ok bool) {
a.lock.Lock()
defer a.lock.Unlock()

for _, u = range a.users {
if u.Name == login &&
bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)) == nil {
return u, true
}
}

return webUser{}, false
}

// getCurrentUser returns the current user. It returns an empty User if the
// user is not found.
func (a *Auth) getCurrentUser(r *http.Request) (u webUser) {
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
// There's no Cookie, check Basic authentication.
user, pass, ok := r.BasicAuth()
if ok {
u, _ = Context.auth.findUser(user, pass)

return u
}

return webUser{}
}

a.lock.Lock()
defer a.lock.Unlock()

s, ok := a.sessions[cookie.Value]
if !ok {
return webUser{}
}

for _, u = range a.users {
if u.Name == s.userName {
return u
}
}

return webUser{}
}

// usersList returns a copy of a users list.
func (a *Auth) usersList() (users []webUser) {
a.lock.Lock()
defer a.lock.Unlock()

users = make([]webUser, len(a.users))
copy(users, a.users)

return users
}

// authRequired returns true if a authentication is required.
func (a *Auth) authRequired() bool {
if GLMode {
return true
}

a.lock.Lock()
defer a.lock.Unlock()

return len(a.users) != 0
}

// newSessionToken returns cryptographically secure randomly generated slice of
Expand Down
62 changes: 62 additions & 0 deletions internal/home/auth_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package home
import (
"bytes"
"crypto/rand"
"encoding/hex"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -25,3 +28,62 @@ func TestNewSessionToken(t *testing.T) {
require.Error(t, err)
assert.Empty(t, token)
}

func TestAuth(t *testing.T) {
dir := t.TempDir()
fn := filepath.Join(dir, "sessions.db")

users := []webUser{{
Name: "name",
PasswordHash: "$2y$05$..vyzAECIhJPfaQiOK17IukcQnqEgKJHy0iETyYqxn3YXJl8yZuo2",
}}
a := InitAuth(fn, nil, 60, nil)
s := session{}

user := webUser{Name: "name"}
err := a.addUser(&user, "password")
require.NoError(t, err)

assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
a.removeSession("notfound")

sess, err := newSessionToken()
require.NoError(t, err)
sessStr := hex.EncodeToString(sess)

now := time.Now().UTC().Unix()
// check expiration
s.expire = uint32(now)
a.addSession(sess, &s)
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))

// add session with TTL = 2 sec
s = session{}
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.addSession(sess, &s)
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))

a.Close()

// load saved session
a = InitAuth(fn, users, 60, nil)

// the session is still alive
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
// reset our expiration time because checkSession() has just updated it
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.storeSession(sess, &s)
a.Close()

u, ok := a.findUser("name", "password")
assert.True(t, ok)
assert.NotEmpty(t, u.Name)

time.Sleep(3 * time.Second)

// load and remove expired sessions
a = InitAuth(fn, users, 60, nil)
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))

a.Close()
}

0 comments on commit 1eef672

Please sign in to comment.