diff --git a/internal/router/init.go b/internal/router/init.go index 9d58cb3..b78c676 100644 --- a/internal/router/init.go +++ b/internal/router/init.go @@ -17,6 +17,8 @@ import ( var ( lock sync.RWMutex cancel = make(chan bool) + + Verifier = NewChallenger() ) func Setup(errorChan chan<- error, iptables bool) (err error) { diff --git a/internal/router/session_manager.go b/internal/router/session_manager.go new file mode 100644 index 0000000..429b0d9 --- /dev/null +++ b/internal/router/session_manager.go @@ -0,0 +1,155 @@ +package router + +import ( + "crypto/subtle" + "fmt" + "log" + "net/http" + "sync" + "time" + + "github.com/NHAS/wag/internal/data" + "github.com/NHAS/wag/internal/users" + "github.com/NHAS/wag/internal/utils" + "github.com/gorilla/websocket" +) + +type wsConnWrapper struct { + *websocket.Conn + wait chan interface{} +} + +func (ws *wsConnWrapper) Await() <-chan interface{} { + return ws.wait +} + +func (ws *wsConnWrapper) Close() error { + close(ws.wait) + return ws.Conn.Close() +} + +type Challenger struct { + sync.RWMutex + connections map[string]*wsConnWrapper + + upgrader websocket.Upgrader +} + +func NewChallenger() *Challenger { + r := &Challenger{ + connections: make(map[string]*wsConnWrapper), + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + domain, err := data.GetDomain() + if err != nil { + log.Println("was unable to get the wag domain: ", err) + return false + } + + valid := r.Header.Get("Origin") == domain + if !valid { + log.Printf("websocket origin does not equal expected value: %q != %q", r.Header.Get("Origin"), domain) + } + + return valid + }, + }, + } + + return r +} + +func (c *Challenger) Challenge(address string) error { + c.RLock() + defer c.RUnlock() + + conn, ok := c.connections[address] + if !ok { + return fmt.Errorf("no connection found for device: %s", address) + } + + err := conn.SetWriteDeadline(time.Now().Add(2 * time.Second)) + if err != nil { + return err + } + + err = conn.WriteJSON("challenge") + if err != nil { + return err + } + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + if err != nil { + return err + } + + msg := struct{ Challenge string }{} + err = conn.ReadJSON(&msg) + if err != nil { + return err + } + + deviceDetails, err := data.GetDeviceByAddress(address) + if err != nil { + return fmt.Errorf("failed to get device address for ws challenge: %s", err) + } + + if subtle.ConstantTimeCompare([]byte(deviceDetails.Challenge), []byte(msg.Challenge)) != 1 { + return fmt.Errorf("challenge does not match") + } + + return nil +} + +func (c *Challenger) WS(w http.ResponseWriter, r *http.Request) { + remoteAddress := utils.GetIPFromRequest(r) + user, err := users.GetUserFromAddress(remoteAddress) + if err != nil { + log.Println("unknown", remoteAddress, "Could not find user: ", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + // Upgrade HTTP connection to WebSocket connection + _c, err := c.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println(user.Username, remoteAddress, "failed to create websocket challenger:", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + + conn := &wsConnWrapper{Conn: _c, wait: make(chan interface{})} + + defer func() { + if conn != nil { + conn.Close() + } + + c.Lock() + delete(c.connections, remoteAddress.String()) + c.Unlock() + + }() + + c.Lock() + + if prev, ok := c.connections[remoteAddress.String()]; ok && prev != nil { + prev.Close() + } + + c.connections[remoteAddress.String()] = conn + c.Unlock() + + err = c.Challenge(remoteAddress.String()) + if err != nil { + log.Printf("client did not complete ws challenge: %s", err) + return + } + + select { + case <-cancel: + case <-conn.Await(): + } +} diff --git a/internal/router/statemachine.go b/internal/router/statemachine.go index 093c700..4f02b47 100644 --- a/internal/router/statemachine.go +++ b/internal/router/statemachine.go @@ -99,8 +99,20 @@ func deviceChanges(_ string, current, previous data.Device, et data.EventType) e return fmt.Errorf("cannot get lockout: %s", err) } + if current.Endpoint.String() != previous.Endpoint.String() { + + // Will take at most 4 seconds + err := Verifier.Challenge(current.Address) + if err != nil { + log.Printf("%s:%s failed to pass websockets challenge: %s", current.Username, current.Address, err) + err := Deauthenticate(current.Address) + if err != nil { + return fmt.Errorf("cannot deauthenticate device %s: %s", current.Address, err) + } + } + } + if current.Attempts > lockout || // If the number of authentication attempts on a device has exceeded the max - current.Endpoint.String() != previous.Endpoint.String() || // If the client ip has changed current.Authorised.IsZero() { // If we've explicitly deauthorised a device var reasons []string @@ -108,10 +120,6 @@ func deviceChanges(_ string, current, previous data.Device, et data.EventType) e reasons = []string{fmt.Sprintf("exceeded lockout (%d)", current.Attempts)} } - if current.Endpoint.String() != previous.Endpoint.String() { - reasons = append(reasons, "endpoint changed") - } - if current.Authorised.IsZero() { reasons = append(reasons, "session terminated") } diff --git a/internal/webserver/authenticators/authenticators.go b/internal/webserver/authenticators/authenticators.go index 2596aeb..e817875 100644 --- a/internal/webserver/authenticators/authenticators.go +++ b/internal/webserver/authenticators/authenticators.go @@ -20,8 +20,6 @@ var ( types.Pam: new(Pam), } lck sync.RWMutex - - ChallengesManager = NewChallenger() ) func GetMethod(method string) (Authenticator, bool) { diff --git a/internal/webserver/authenticators/challenger.go b/internal/webserver/authenticators/challenger.go deleted file mode 100644 index e99a063..0000000 --- a/internal/webserver/authenticators/challenger.go +++ /dev/null @@ -1,198 +0,0 @@ -package authenticators - -import ( - "crypto/subtle" - "errors" - "fmt" - "log" - "net/http" - "sync" - "time" - - "github.com/NHAS/wag/internal/data" - "github.com/NHAS/wag/internal/users" - "github.com/NHAS/wag/internal/utils" - "github.com/gorilla/websocket" -) - -type Challenger struct { - sync.RWMutex - listenerKey string - challenges map[string]*websocket.Conn - - upgrader websocket.Upgrader -} - -func NewChallenger() *Challenger { - r := &Challenger{ - challenges: make(map[string]*websocket.Conn), - upgrader: websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - domain, err := data.GetDomain() - if err != nil { - log.Println("was unable to get the wag domain: ", err) - return false - } - - valid := r.Header.Get("Origin") == domain - if !valid { - log.Printf("websocket origin does not equal expected value: %q != %q", r.Header.Get("Origin"), domain) - } - - return valid - }, - }, - } - - return r -} - -func (c *Challenger) Start() error { - c.Lock() - defer c.Unlock() - - key, err := data.RegisterEventListener(data.DevicesPrefix, true, c.deviceChangeHandler) - if err != nil { - return fmt.Errorf("unable to register device change listener for challenger: %s", err) - } - c.listenerKey = key - - return err -} - -func (c *Challenger) Stop() error { - c.Lock() - defer c.Unlock() - - var errs []error - if c.listenerKey != "" { - err := data.DeregisterEventListener(c.listenerKey) - if err != nil { - errs = append(errs, err) - } - } - - c.listenerKey = "" - - for i := range c.challenges { - if c.challenges[i] != nil { - c.challenges[i].Close() - } - } - - clear(c.challenges) - - return errors.Join(errs...) -} - -func (c *Challenger) deviceChangeHandler(_ string, current, previous data.Device, et data.EventType) error { - - switch et { - case data.MODIFIED: - c.Lock() - defer c.Unlock() - - conn, ok := c.challenges[current.Address] - if !ok { - // we dont have a challenge for this device - return nil - } - - if current.Challenge != previous.Challenge || - current.Endpoint.String() != previous.Endpoint.String() { - - conn.SetWriteDeadline(time.Now().Add(3 * time.Second)) - err := conn.WriteJSON(struct{ Type string }{Type: "check"}) - if err != nil { - conn.Close() - log.Println("failed to check authorisation: ", err) - return nil - } - conn.SetWriteDeadline(time.Time{}) - - } - - case data.DELETED: - c.Lock() - defer c.Unlock() - - conn, ok := c.challenges[current.Address] - if !ok { - // we dont have a challenge for this device - return nil - } - - conn.Close() - - delete(c.challenges, current.Address) - - } - - return nil -} - -func (c *Challenger) WS(w http.ResponseWriter, r *http.Request) { - remoteAddress := utils.GetIPFromRequest(r) - user, err := users.GetUserFromAddress(remoteAddress) - if err != nil { - log.Println("unknown", remoteAddress, "Could not find user: ", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - // Upgrade HTTP connection to WebSocket connection - conn, err := c.upgrader.Upgrade(w, r, nil) - if err != nil { - log.Println(user.Username, remoteAddress, "failed to create websocket challenger:", err) - http.Error(w, "Server Error", http.StatusInternalServerError) - return - } - - c.Lock() - if conn, ok := c.challenges[remoteAddress.String()]; ok && conn != nil { - conn.Close() - delete(c.challenges, remoteAddress.String()) - data.DeauthenticateDevice(remoteAddress.String()) - } - c.challenges[remoteAddress.String()] = conn - c.Unlock() - - var response struct { - Challenge string - } - - defer func() { - c.Lock() - defer c.Unlock() - - conn.Close() - delete(c.challenges, remoteAddress.String()) - }() - - for { - - err := conn.ReadJSON(&response) - if err != nil { - return - } - - d, err := data.GetDeviceByAddress(remoteAddress.String()) - if err != nil { - return - } - - if subtle.ConstantTimeCompare([]byte(d.Challenge), []byte(response.Challenge)) == 1 { - _, err := data.AuthoriseDevice(user.Username, remoteAddress.String()) - if err != nil { - log.Println("unable to authorise device based on challenge: ", err) - return - } - } else { - data.DeauthenticateDevice(remoteAddress.String()) - return - } - } - -} diff --git a/internal/webserver/web.go b/internal/webserver/web.go index 3b711e0..16bde3d 100644 --- a/internal/webserver/web.go +++ b/internal/webserver/web.go @@ -58,10 +58,6 @@ func Teardown() { publicTLSServ.Close() } - if authenticators.ChallengesManager != nil { - authenticators.ChallengesManager.Stop() - } - log.Println("Stopped MFA portal") } @@ -182,12 +178,7 @@ func Start(errChan chan<- error) error { tunnel.Get("/public_key/", publicKey) - err = authenticators.ChallengesManager.Start() - if err != nil { - return fmt.Errorf("unable to start challenge manager: %s", err) - } - - tunnel.Get("/challenge/", authenticators.ChallengesManager.WS) + tunnel.Get("/challenge/", router.Verifier.WS) tunnel.GetOrPost("/", index)