/
shared_key_credential.go
131 lines (107 loc) · 3.79 KB
/
shared_key_credential.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package azcosmos
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strings"
"sync/atomic"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
)
// NewKeyCredential creates a KeyCredential from the account's primary or
// secondary key.
//
// accountKey must be standard (padded) base64; an error is returned if it
// cannot be decoded.
func NewKeyCredential(accountKey string) (KeyCredential, error) {
	c := KeyCredential{}
	if err := c.Update(accountKey); err != nil {
		return KeyCredential{}, err
	}
	return c, nil
}

// KeyCredential holds an account's primary or secondary key and uses it to
// sign requests.
//
// The decoded key is kept in an atomic.Value, so Update may be called while
// other goroutines are signing requests with the same KeyCredential. The key
// lives inside the struct value itself: a copy of a KeyCredential keeps the
// key it held when it was copied and does not observe later Update calls on
// the original.
type KeyCredential struct {
	// Only NewKeyCredential and Update should set this; all other methods
	// should treat it as read-only.
	accountKey atomic.Value // []byte
}

// Update replaces the existing account key with the specified standard
// base64-encoded account key. If accountKey cannot be decoded, an error is
// returned and the previously stored key is left unchanged.
func (c *KeyCredential) Update(accountKey string) error {
	key, err := base64.StdEncoding.DecodeString(accountKey)
	if err != nil {
		return fmt.Errorf("decode account key: %w", err)
	}
	c.accountKey.Store(key)
	return nil
}

// computeHMACSHA256 returns the base64 (standard encoding) HMAC-SHA256 of s,
// keyed with the stored account key.
//
// A zero-value KeyCredential has no key stored. Rather than panicking inside
// the request pipeline on the type assertion, an empty key is used; the
// resulting signature is expected to be rejected by the service.
func (c *KeyCredential) computeHMACSHA256(s string) (base64String string) {
	key, _ := c.accountKey.Load().([]byte)
	h := hmac.New(sha256.New, key)
	_, _ = h.Write([]byte(s)) // hash.Hash.Write never returns an error.
	return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
// buildCanonicalizedAuthHeaderFromRequest derives the signing inputs from the
// request's pipelineRequestOptions and returns the URL-escaped Authorization
// value produced by buildCanonicalizedAuthHeader, using token type "master"
// and version "1.0".
//
// When the request carries no pipelineRequestOptions, it returns "" and a nil
// error. Any error from getResourcePath is returned unchanged.
func (c *KeyCredential) buildCanonicalizedAuthHeaderFromRequest(req *policy.Request) (string, error) {
	var opValues pipelineRequestOptions
	if !req.OperationValue(&opValues) {
		return "", nil
	}

	resourceTypePath, err := getResourcePath(opValues.resourceType)
	if err != nil {
		return "", err
	}

	address := opValues.resourceAddress
	if opValues.isRidBased {
		// RID-based addresses are signed in lower case.
		address = strings.ToLower(address)
	}

	raw := req.Raw()
	header := c.buildCanonicalizedAuthHeader(raw.Method, resourceTypePath, address, raw.Header.Get(headerXmsDate), "master", "1.0")
	return header, nil
}
// buildCanonicalizedAuthHeader signs a request with the account key and
// returns the URL-escaped Authorization header value
// "type=<tokenType>&ver=<version>&sig=<signature>".
//
// The string to sign consists of these fields, each terminated by "\n":
//   - the lower-cased method
//   - the lower-cased resource type
//   - the resource address, as given
//   - the lower-cased date
//   - an empty field
//
// xmsDate is the x-ms-date value, formatted like time.RFC1123 but with the
// time zone hard-coded to GMT (http.TimeFormat).
//
// It returns "" when method or resourceType is empty.
func (c *KeyCredential) buildCanonicalizedAuthHeader(method, resourceType, resourceAddress, xmsDate, tokenType, version string) string {
	if method == "" || resourceType == "" {
		return ""
	}
	// https://docs.microsoft.com/en-us/rest/api/cosmos-db/access-control-on-cosmosdb-resources#constructkeytoken
	stringToSign := join(strings.ToLower(method), "\n", strings.ToLower(resourceType), "\n", resourceAddress, "\n", strings.ToLower(xmsDate), "\n", "", "\n")
	signature := c.computeHMACSHA256(stringToSign)
	return url.QueryEscape("type=" + tokenType + "&ver=" + version + "&sig=" + signature)
}
// sharedKeyCredPolicy is a pipeline policy that signs each outgoing request
// with the KeyCredential it holds.
type sharedKeyCredPolicy struct {
	cred KeyCredential
}

// newSharedKeyCredPolicy returns a policy that signs requests using cred.
//
// cred is stored by value, so the policy works on its own copy.
// NOTE(review): KeyCredential keeps its key inline (not behind a pointer), so
// a later Update on the caller's credential is not seen by this policy.
// Confirm whether key rotation is expected to reach existing clients.
func newSharedKeyCredPolicy(cred KeyCredential) *sharedKeyCredPolicy {
	return &sharedKeyCredPolicy{cred: cred}
}
// Do signs the request with the policy's credential and forwards it down the
// pipeline.
//
// If the request has no x-ms-date header, one is added with the current UTC
// time in http.TimeFormat, because the date is part of the signed payload.
// The Authorization header is set only when a non-empty signature was built;
// otherwise the request is forwarded unsigned. Errors from building the
// signature are returned without sending the request.
//
// When the response is 403 Forbidden, the non-secret signing inputs (method,
// path, x-ms-date) are logged to help diagnose signature mismatches such as
// clock skew. The Authorization value itself is deliberately not logged.
func (s *sharedKeyCredPolicy) Do(req *policy.Request) (*http.Response, error) {
	raw := req.Raw()
	if raw.Header.Get(headerXmsDate) == "" {
		raw.Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
	}

	authHeader, err := s.cred.buildCanonicalizedAuthHeaderFromRequest(req)
	if err != nil {
		return nil, err
	}
	if authHeader != "" {
		raw.Header.Set(headerAuthorization, authHeader)
	}

	response, err := req.Next()
	// A 403 normally comes back as a non-nil response with a nil error, so
	// check the response on its own. The previous `err != nil && ...` guard
	// made this diagnostic effectively unreachable.
	if response != nil && response.StatusCode == http.StatusForbidden {
		log.Write(log.EventResponse, "===== HTTP Forbidden status, signed request:\n"+
			raw.Method+" "+raw.URL.Path+"\n"+
			headerXmsDate+": "+raw.Header.Get(headerXmsDate)+"\n=====\n")
	}
	return response, err
}
func join(strs ...string) string {
var sb strings.Builder
for _, str := range strs {
sb.WriteString(str)
}
return sb.String()
}