/
azure.go
210 lines (180 loc) · 7.26 KB
/
azure.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package filemanager
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strconv"
"strings"
"time"
"github.com/Azure/azure-storage-blob-go/azblob"
"github.com/Microsoft/confidential-sidecar-containers/pkg/common"
"github.com/Microsoft/confidential-sidecar-containers/pkg/msi"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// tokenRefresher is the refresh callback passed to azblob.NewTokenCredential
// when creating token credentials. It requests a new token for the same
// audience ("aud" claim) and client ID ("appid" claim) as the credential's
// current token, installs it with SetToken, and returns the new token's
// lifetime (its ExpiresIn value, in seconds) as the delay before the next
// refresh.
//
// On any failure it returns 0 and leaves the current token in place; per the
// azblob NewTokenCredential documentation, returning 0 stops further refresh
// attempts.
func tokenRefresher(credential azblob.TokenCredential) (t time.Duration) {
	// we extract the audience from the existing token so that we can set the
	// resource id for retrieving a new (refresh) token for the same audience.
	audience, clientID, err := azureTokenAudienceAndClientID(credential.Token())
	if err != nil {
		logrus.Errorf("Error reading current token claims: %s", err)
		return 0
	}
	identity := common.Identity{
		ClientId: clientID,
	}

	// retrieve token using the existing token audience
	logrus.Debugf("Retrieving new token for audience %s and identity %s", audience, identity)
	refreshToken, err := common.GetToken(audience, identity)
	if err != nil {
		logrus.Errorf("Error retrieving token: %s", err)
		return 0
	}
	// The token itself is a bearer credential and is deliberately not logged.
	logrus.Debug("Retrieved new token")

	expiresInSeconds, err := strconv.ParseInt(refreshToken.ExpiresIn, 10, 64)
	if err != nil {
		logrus.Errorf("Error parsing token expiration to seconds: %s", err)
		return 0
	}
	credential.SetToken(refreshToken.AccessToken)
	// NOTE(review): the next refresh is scheduled exactly at expiry, leaving no
	// safety margin for requests issued just before it completes — consider
	// refreshing somewhat earlier.
	return time.Duration(expiresInSeconds) * time.Second
}

// azureTokenAudienceAndClientID decodes the claims of a JWT access token and
// returns its "aud" (audience) and "appid" (client ID) claims. A JWT is made
// of dot-separated base64url fields; the second one is the payload (claims).
// The token signature is not verified.
//
// It returns an error instead of panicking when the token is malformed or a
// claim is missing or not a string.
func azureTokenAudienceAndClientID(token string) (audience, clientID string, err error) {
	fields := strings.Split(token, ".")
	if len(fields) < 2 {
		return "", "", errors.Errorf("malformed token: expected dot-separated JWT fields, got %d field(s)", len(fields))
	}
	payload, err := base64.RawURLEncoding.DecodeString(fields[1])
	if err != nil {
		return "", "", errors.Wrap(err, "decoding base64 token payload")
	}
	var claims map[string]interface{}
	if err = json.Unmarshal(payload, &claims); err != nil {
		return "", "", errors.Wrap(err, "unmarshalling token payload")
	}
	audience, ok := claims["aud"].(string)
	if !ok {
		return "", "", errors.New(`token payload has no string "aud" claim`)
	}
	clientID, ok = claims["appid"].(string)
	if !ok {
		return "", "", errors.New(`token payload has no string "appid" claim`)
	}
	return audience, clientID, nil
}
// AzureSetup connects the package-level file manager fm to the Azure page blob
// at urlString.
//
// If urlPrivate is true, requests are authorized with a bearer token for the
// blob's host. The token comes from workload identity when
// msi.WorkloadIdentityEnabled() reports true. Otherwise it comes from
// common.GetToken for identity, retried while the ACI identity sidecar may
// still be starting. If urlPrivate is false, anonymous credentials are used.
//
// On success it sets fm.blobURL, fm.ctx and fm.contentLength, and installs
// AzureDownloadBlock / AzureUploadBlock as the block transfer functions. An
// error is returned if urlString cannot be parsed or has no host, if a token
// cannot be obtained, or if the blob's properties cannot be read.
//
// For more information about the library used to access Azure:
//
// https://pkg.go.dev/github.com/Azure/azure-storage-blob-go/azblob
func AzureSetup(urlString string, urlPrivate bool, identity common.Identity) error {
	logrus.Info("Connecting to Azure...")
	u, err := url.Parse(urlString)
	if err != nil {
		return errors.Wrapf(err, "Can't parse URL string %s", urlString)
	}
	// The host is both the storage endpoint and (for private blobs) the token
	// audience; without it nothing below can succeed.
	if u.Host == "" {
		return errors.Errorf("URL string %s has no host", urlString)
	}

	// Create a page blob URL object that wraps the blob's URL and a default
	// request pipeline.
	//
	// The pipeline indicates how the outgoing HTTP request and incoming HTTP
	// response is processed. It specifies things like retry policies, logging,
	// deserialization of HTTP response payloads, and more:
	//
	// https://pkg.go.dev/github.com/Azure/azure-storage-blob-go/azblob#hdr-URL_Types
	var credential azblob.Credential
	if urlPrivate {
		tokenCredential, err := azureBlobTokenCredential(u.Host, identity)
		if err != nil {
			return err
		}
		credential = tokenCredential
	} else {
		// we can use anonymous credentials to access public azure blob storage
		logrus.Trace("Using anonymous credentials to access public azure blob storage...")
		credential = azblob.NewAnonymousCredential()
		logrus.Debugf("Anonymous credential created: %s", credential)
	}
	fm.blobURL = azblob.NewPageBlobURL(*u, azblob.NewPipeline(credential, azblob.PipelineOptions{}))
	logrus.Debugf("Blob URL created: %s", fm.blobURL)

	// Use a never-expiring context
	fm.ctx = context.Background()

	logrus.Trace("Getting size of file...")
	props, err := fm.blobURL.GetProperties(fm.ctx, azblob.BlobAccessConditions{},
		azblob.ClientProvidedKeyOptions{})
	if err != nil {
		return errors.Wrap(err, "Can't get blob file size")
	}
	fm.contentLength = props.ContentLength()
	logrus.Tracef("Blob Size: %d bytes", fm.contentLength)

	// Setup data downloader and uploader
	fm.downloadBlock = AzureDownloadBlock
	fm.uploadBlock = AzureUploadBlock
	return nil
}

// azureBlobTokenCredential returns a token credential for the storage endpoint
// at host, using "https://<host>" as the token resource/audience.
//
// With workload identity, the token comes from
// msi.GetAccessTokenFromFederatedToken, bounded by
// msi.WorkloadIdentityRquestTokenTimeout, and no refresher is installed.
// Otherwise the token comes from common.GetToken for identity, and
// tokenRefresher is installed to renew it.
func azureBlobTokenCredential(host string, identity common.Identity) (azblob.TokenCredential, error) {
	resource := fmt.Sprintf("https://%s", host)

	if msi.WorkloadIdentityEnabled() {
		logrus.Infof("Requesting token for using workload identity from %s", resource)
		ctx, cancel := context.WithTimeout(context.Background(), msi.WorkloadIdentityRquestTokenTimeout)
		defer cancel()
		accessToken, err := msi.GetAccessTokenFromFederatedToken(ctx, resource)
		if err != nil {
			return nil, errors.Wrapf(err, "retrieving authentication token using workload identity failed")
		}
		// NOTE(review): with a nil refresher this token is never renewed, so
		// requests will fail once it expires — confirm this is intended.
		cred := azblob.NewTokenCredential(accessToken, nil)
		logrus.Debug("Token credential created")
		return cred, nil
	}

	// we use token credentials to access private azure blob storage; the blob's
	// url Host denotes the scope/audience for which we need to get a token
	logrus.Trace("Using token credentials to access private azure blob storage...")
	accessToken, err := azureGetTokenWithRetry(resource, identity)
	if err != nil {
		return nil, err
	}
	cred := azblob.NewTokenCredential(accessToken, tokenRefresher)
	// The token is a bearer credential and is deliberately not logged.
	logrus.Debug("Token credential created")
	return cred, nil
}

// azureGetTokenWithRetry fetches an access token for resource with
// common.GetToken. Failures are retried in case the ACI identity sidecar is
// not running yet: up to 20 attempts, 3 seconds apart. No sleep follows the
// final failed attempt; its error is returned wrapped.
func azureGetTokenWithRetry(resource string, identity common.Identity) (string, error) {
	const (
		maxAttempts   = 20
		retryInterval = 3 * time.Second
	)
	logrus.Debugf("Getting token for %s", resource)
	for attempt := 1; ; attempt++ {
		token, err := common.GetToken(resource, identity)
		if err == nil {
			logrus.Debug("Token obtained")
			return token.AccessToken, nil
		}
		if attempt >= maxAttempts {
			return "", errors.Wrapf(err, "Could not obtain token after %d attempts", maxAttempts)
		}
		logrus.Infof("Can't obtain a token required for accessing private blobs (attempt %d of %d: %s). Will retry in case the ACI identity sidecar is not running yet...", attempt, maxAttempts, err)
		time.Sleep(retryInterval)
	}
}
// AzureUploadBlock writes b to the page blob held in fm.blobURL. The write
// starts at byte offset blockIndex * GetBlockSize(). Any service error is
// returned wrapped with "Can't upload block".
func AzureUploadBlock(blockIndex int64, b []byte) (err error) {
	logrus.Info("Uploading block...")
	blockSize := GetBlockSize()
	offset := blockIndex * blockSize
	logrus.Tracef("Block offset %d = block index %d * bytes in block %d", offset, blockIndex, blockSize)

	if _, err = fm.blobURL.UploadPages(
		fm.ctx,
		offset,
		bytes.NewReader(b),
		azblob.PageBlobAccessConditions{},
		nil, // no transactional MD5
		azblob.NewClientProvidedKeyOptions(nil, nil, nil),
	); err != nil {
		return errors.Wrapf(err, "Can't upload block")
	}
	return nil
}
// AzureDownloadBlock reads up to GetBlockSize() bytes from the blob in
// fm.blobURL. The read starts at byte offset blockIndex * GetBlockSize().
//
// Note the (error, data) result order, which callers depend on. On failure
// the data slice is nil and the error is wrapped with context.
func AzureDownloadBlock(blockIndex int64) (err error, b []byte) {
	logrus.Info("Downloading block...")
	blockSize := GetBlockSize()
	offset := blockIndex * blockSize
	logrus.Tracef("Block offset %d = block index %d * bytes in block %d", offset, blockIndex, blockSize)

	resp, err := fm.blobURL.Download(fm.ctx, offset, blockSize, azblob.BlobAccessConditions{},
		false, azblob.ClientProvidedKeyOptions{})
	if err != nil {
		return errors.Wrapf(err, "Can't download block"), nil
	}

	body := resp.Body(azblob.RetryReaderOptions{})
	// The client must close the response body when finished with it
	defer body.Close()

	var buf bytes.Buffer
	if _, err = buf.ReadFrom(body); err != nil {
		return errors.Wrapf(err, "ReadFrom() failed for block"), nil
	}
	return nil, buf.Bytes()
}