-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
tenant_id.go
129 lines (109 loc) · 3.26 KB
/
tenant_id.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package logstorage
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/VictoriaMetrics/VictoriaMetrics/lib/encoding"
)
// TenantID identifies the tenant that owns log streams.
//
// Every log stream belongs to exactly one TenantID.
type TenantID struct {
	// AccountID identifies the account that owns the log stream.
	AccountID uint32
	// ProjectID identifies the project that owns the log stream.
	ProjectID uint32
}

// Reset sets both AccountID and ProjectID of tid back to zero.
func (tid *TenantID) Reset() {
	*tid = TenantID{}
}

// String returns a human-readable representation of tid,
// for example "{accountID=1,projectID=2}".
func (tid *TenantID) String() string {
	return fmt.Sprintf("{accountID=%d,projectID=%d}", tid.AccountID, tid.ProjectID)
}

// equal reports whether tid and a carry the same AccountID and ProjectID.
func (tid *TenantID) equal(a *TenantID) bool {
	// Both fields are plain uint32 values, so struct equality compares exactly them.
	return *tid == *a
}

// less reports whether tid sorts before a: AccountID is compared first,
// and ProjectID only breaks ties between identical AccountIDs.
func (tid *TenantID) less(a *TenantID) bool {
	if tid.AccountID == a.AccountID {
		return tid.ProjectID < a.ProjectID
	}
	return tid.AccountID < a.AccountID
}
// marshal appends the binary form of tid to dst and returns the extended slice.
//
// AccountID is written first and ProjectID second, both via
// encoding.MarshalUint32; unmarshal reads them back in the same order.
func (tid *TenantID) marshal(dst []byte) []byte {
	dst = encoding.MarshalUint32(dst, tid.AccountID)
	return encoding.MarshalUint32(dst, tid.ProjectID)
}
// unmarshal decodes tid from the head of src and returns the bytes that follow it.
//
// It consumes exactly 8 bytes: AccountID from src[0:4], then ProjectID from
// the 4 bytes after that. If src is shorter than 8 bytes, tid is left
// untouched and src is returned unchanged together with an error.
func (tid *TenantID) unmarshal(src []byte) ([]byte, error) {
	const encodedLen = 8
	if len(src) < encodedLen {
		return src, fmt.Errorf("cannot unmarshal tenantID from %d bytes; need at least 8 bytes", len(src))
	}
	accountID := encoding.UnmarshalUint32(src[:4])
	tid.AccountID = accountID
	projectID := encoding.UnmarshalUint32(src[4:])
	tid.ProjectID = projectID
	tail := src[encodedLen:]
	return tail, nil
}
// GetTenantIDFromRequest builds a TenantID from the AccountID and ProjectID
// headers of r.
//
// A header that is absent or empty yields zero for its field. If either header
// holds a value that cannot be parsed, the zero TenantID and the parse error
// are returned.
func GetTenantIDFromRequest(r *http.Request) (TenantID, error) {
	accountID, err := getUint32FromHeader(r, "AccountID")
	if err != nil {
		return TenantID{}, err
	}
	projectID, err := getUint32FromHeader(r, "ProjectID")
	if err != nil {
		return TenantID{}, err
	}
	return TenantID{
		AccountID: accountID,
		ProjectID: projectID,
	}, nil
}
// GetTenantIDFromString returns the TenantID encoded in s.
//
// s is expected in the form accountID:projectID. The ":projectID" suffix may be
// omitted, in which case ProjectID is zero. Only the first colon separates the
// two parts, so "1:2:3" fails because "2:3" is not a valid projectID. An empty
// part (or an empty s) parses as zero.
//
// On error the returned TenantID is always the zero value. It is never
// partially populated.
func GetTenantIDFromString(s string) (TenantID, error) {
	colon := strings.Index(s, ":")
	if colon < 0 {
		accountID, err := getUint32FromString(s)
		if err != nil {
			return TenantID{}, fmt.Errorf("cannot parse accountID from %q: %w", s, err)
		}
		return TenantID{AccountID: accountID}, nil
	}

	accountID, err := getUint32FromString(s[:colon])
	if err != nil {
		return TenantID{}, fmt.Errorf("cannot parse accountID part from %q: %w", s, err)
	}
	projectID, err := getUint32FromString(s[colon+1:])
	if err != nil {
		// Do not leak the already-parsed AccountID alongside the error.
		return TenantID{}, fmt.Errorf("cannot parse projectID part from %q: %w", s, err)
	}
	return TenantID{AccountID: accountID, ProjectID: projectID}, nil
}
// getUint32FromHeader parses the value of the headerName header of r as a
// base-10 uint32.
//
// A missing or empty header yields 0 without error. A malformed value yields an
// error that names the offending header, so callers that read several headers
// can tell which one was bad.
func getUint32FromHeader(r *http.Request, headerName string) (uint32, error) {
	s := r.Header.Get(headerName)
	if len(s) == 0 {
		return 0, nil
	}
	n, err := getUint32FromString(s)
	if err != nil {
		return 0, fmt.Errorf("cannot parse %s header: %w", headerName, err)
	}
	return n, nil
}
// getUint32FromString parses s as a base-10 uint32.
//
// An empty s is treated as 0 rather than as an error. Values outside the
// uint32 range and non-numeric input produce an error that quotes s.
func getUint32FromString(s string) (uint32, error) {
	var parsed uint64
	if s != "" {
		var err error
		if parsed, err = strconv.ParseUint(s, 10, 32); err != nil {
			return 0, fmt.Errorf("cannot parse %q as uint32: %w", s, err)
		}
	}
	return uint32(parsed), nil
}