Skip to content

Commit

Permalink
Per channel keys
Browse files Browse the repository at this point in the history
  • Loading branch information
alaingilbert committed Jan 29, 2018
1 parent 801ab79 commit f1b21b3
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 38 deletions.
6 changes: 3 additions & 3 deletions broker/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (c *Conn) onSubscribe(mqttTopic []byte) *EventError {
}

// Check if the key has the permission for the required channel
if key.Target() != 0 && key.Target() != channel.Target() {
if !key.ValidateChannel(string(channel.Channel)) {
return ErrUnauthorized
}

Expand Down Expand Up @@ -131,7 +131,7 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *EventError {
}

// Check if the key has the permission for the required channel
if key.Target() != 0 && key.Target() != channel.Target() {
if !key.ValidateChannel(string(channel.Channel)) {
return ErrUnauthorized
}

Expand Down Expand Up @@ -187,7 +187,7 @@ func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *EventError {
}

// Check if the key has the permission for the required channel
if key.Target() != 0 && key.Target() != channel.Target() {
if !key.ValidateChannel(string(channel.Channel)) {
return ErrUnauthorized
}

Expand Down
2 changes: 1 addition & 1 deletion broker/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ func (s *Service) onHTTPPublish(w http.ResponseWriter, r *http.Request) {
}

// Check if the key has the permission for the required channel
if key.Target() != 0 && key.Target() != channel.Target() {
if !key.ValidateChannel(string(channel.Channel)) {
w.WriteHeader(http.StatusUnauthorized)
return
}
Expand Down
5 changes: 0 additions & 5 deletions security/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ type Channel struct {
ChannelType uint8
}

// Target returns the channel target (first element of the query, second element of an SSID)
func (c *Channel) Target() uint32 {
return c.Query[0]
}

// TTL returns a Time-To-Live option
func (c *Channel) TTL() (uint32, bool) {
return c.getOptUint("ttl")
Expand Down
16 changes: 0 additions & 16 deletions security/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,19 +152,3 @@ func TestGetChannelLast(t *testing.T) {
assert.Equal(t, hasValue, tc.ok)
}
}

func TestGetChannelTarget(t *testing.T) {
tests := []struct {
channel string
target uint32
}{
{channel: "emitter/a/?ttl=42&abc=9", target: 0xc103eab3},
}

for _, tc := range tests {
channel := ParseChannel([]byte(tc.channel))
target := channel.Target()

assert.Equal(t, tc.target, target)
}
}
4 changes: 1 addition & 3 deletions security/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ import (
"math/big"
"strconv"
"time"

"github.com/emitter-io/emitter/utils"
)

const (
Expand Down Expand Up @@ -144,7 +142,7 @@ func (c *Cipher) GenerateKey(masterKey Key, channel string, permissions uint32,
key.SetContract(masterKey.Contract())
key.SetSignature(masterKey.Signature())
key.SetPermissions(permissions)
key.SetTarget(utils.GetHash([]byte(channel)))
key.SetTarget(channel)
key.SetExpires(expires)
return c.EncryptKey(key)
}
Expand Down
81 changes: 73 additions & 8 deletions security/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
package security

import (
"errors"
"strings"
"time"

"github.com/emitter-io/emitter/utils"
)

// Access types for a security key.
Expand Down Expand Up @@ -89,28 +93,89 @@ func (k Key) SetSignature(value uint32) {

// Permissions gets the permission flags.
func (k Key) Permissions() uint32 {
return uint32(k[12])<<24 | uint32(k[13])<<16 | uint32(k[14])<<8 | uint32(k[15])
return uint32(k[12])
}

// SetPermissions sets the permission flags.
func (k Key) SetPermissions(value uint32) {
k[12] = byte(value >> 24)
k[13] = byte(value >> 16)
k[14] = byte(value >> 8)
k[15] = byte(value)
k[12] = byte(value)
}

// Target gets the target for the key.
func (k Key) Target() uint32 {
return uint32(k[16])<<24 | uint32(k[17])<<16 | uint32(k[18])<<8 | uint32(k[19])
func (k Key) ValidateChannel(channel string) bool {
channel = strings.TrimRight(channel, "/")
parts := strings.Split(channel, "/")
wc := parts[len(parts)-1] == "#"
if wc {
parts = parts[0 : len(parts)-1]
}

targetBytes := uint16(k[13])<<8 | uint16(k[14])
maxDepth := 0
for i := 0; i < 15; i++ {
if ((targetBytes >> (14 - uint16(i))) & 1) == 1 {
maxDepth = i
}
}
maxDepth += 1

keyIsExactTarget := ((targetBytes >> 15) & 1) == 0
if len(parts) < maxDepth || (keyIsExactTarget && len(parts) != maxDepth) {
return false
}

for idx, part := range parts {
if part == "+" {
if ((targetBytes >> (14 - uint16(idx))) & 1) == 1 {
return false
}
}
if ((targetBytes >> (14 - uint16(idx))) & 1) == 0 {
parts[idx] = "+"
}
}

newChannel := strings.Join(parts[0:maxDepth], "/") + "/"
h := utils.GetHash([]byte(newChannel))

// Bytes 16-17-18-19 contains target hash
keyHash := uint32(k[16])<<24 | uint32(k[17])<<16 | uint32(k[18])<<8 | uint32(k[19])
return h == keyHash
}

// SetTarget sets the target for the key.
func (k Key) SetTarget(value uint32) {
func (k Key) SetTarget(channel string) error {
channel = strings.TrimRight(channel, "/")
parts := strings.Split(channel, "/")
var bitPath uint16 = 0
wc := parts[len(parts)-1] == "#"
if wc {
parts = parts[0 : len(parts)-1]
bitPath |= uint16(1 << 15)
}

if len(parts) > 15 {
return errors.New("Channel can not have more than 15 parts.")
}

for idx, part := range parts {
if part != "+" && part != "#" {
bitPath |= uint16(1 << (14 - uint16(idx)))
}
}

newChannel := strings.Join(parts, "/") + "/"
value := utils.GetHash([]byte(newChannel))

k[13] = byte(bitPath >> 8)
k[14] = byte(bitPath)

k[16] = byte(value >> 24)
k[17] = byte(value >> 16)
k[18] = byte(value >> 8)
k[19] = byte(value)

return nil
}

// Expires gets the expiration date for the key.
Expand Down
33 changes: 31 additions & 2 deletions security/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,37 @@ func TestKeyIsEmpty(t *testing.T) {
assert.True(t, true, key.IsEmpty())
}

func TestKey1(t *testing.T) {
key := Key(make([]byte, 24))

// Test exact channel
key.SetTarget("a/b/c/")
assert.False(t, key.ValidateChannel("a/b/"))
assert.True(t, key.ValidateChannel("a/b/c/"))
assert.False(t, key.ValidateChannel("a/b/c/d/"))

// Test exact channel with wildcard
key.SetTarget("a/+/c/")
assert.True(t, key.ValidateChannel("a/b/c/"))
assert.True(t, key.ValidateChannel("a/c/c/"))
assert.True(t, key.ValidateChannel("a/d/c/"))
assert.True(t, key.ValidateChannel("a/+/c/"))
assert.False(t, key.ValidateChannel("a/b/+/"))

// Test open channel
key.SetTarget("a/b/c/#/")
assert.False(t, key.ValidateChannel("a/b/"))
assert.True(t, key.ValidateChannel("a/b/c/"))
assert.True(t, key.ValidateChannel("a/b/c/d/"))
assert.True(t, key.ValidateChannel("a/b/c/d/e/"))
assert.True(t, key.ValidateChannel("a/b/c/d/+/f/"))
assert.True(t, key.ValidateChannel("a/b/c/d/+/f/#/"))

assert.Nil(t, key.SetTarget("1/2/3/4/5/6/7/8/9/10/11/12/13/14/15/"))
assert.Nil(t, key.SetTarget("1/2/3/4/5/6/7/8/9/10/11/12/13/14/15/#/"))
assert.NotNil(t, key.SetTarget("1/2/3/4/5/6/7/8/9/10/11/12/13/14/15/16/"))
}

func TestKey(t *testing.T) {
key := Key(make([]byte, 24))

Expand All @@ -20,15 +51,13 @@ func TestKey(t *testing.T) {
key.SetContract(123)
key.SetSignature(777)
key.SetPermissions(AllowReadWrite)
key.SetTarget(56789)
key.SetExpires(time.Unix(1497683272, 0).UTC())

assert.Equal(t, uint16(999), key.Salt())
assert.Equal(t, uint16(2), key.Master())
assert.Equal(t, uint32(123), key.Contract())
assert.Equal(t, uint32(777), key.Signature())
assert.Equal(t, AllowReadWrite, key.Permissions())
assert.Equal(t, uint32(56789), key.Target())
assert.Equal(t, time.Unix(1497683272, 0).UTC(), key.Expires())

key.SetExpires(time.Unix(0, 0))
Expand Down

0 comments on commit f1b21b3

Please sign in to comment.