/
pathio.go
413 lines (364 loc) · 11.9 KB
/
pathio.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
// Package pathio is a package that allows writing to and reading from different types of paths transparently.
// It supports two types of paths:
// 1. Local file paths
// 2. S3 File Paths (s3://bucket/key)
//
// Note that using s3 paths requires setting two environment variables
// 1. AWS_SECRET_ACCESS_KEY
// 2. AWS_ACCESS_KEY_ID
package pathio
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
)
// Defaults used when talking to S3.
const (
// defaultLocation is the region used for the bucket-location lookup, and the
// region reported for buckets whose location constraint comes back empty.
defaultLocation = "us-east-1"
// aesAlgo is the server-side encryption algorithm requested on S3 uploads
// unless encryption has been disabled on the client.
aesAlgo = "AES256"
)
// generate a mock for Pathio
//go:generate bin/mockgen -source=$GOFILE -destination=gen_mock_s3handler.go -package=pathio
// Pathio is the interface for accessing both S3 and local files. Client, the
// implementation in this file, treats any path starting with "s3://" as an S3
// path and everything else as a local file path.
type Pathio interface {
// Reader opens path for reading; the caller is responsible for closing rc.
Reader(path string) (rc io.ReadCloser, err error)
// Write writes the bytes in input to path.
Write(path string, input []byte) error
// WriteReader writes everything readable from input to path.
WriteReader(path string, input io.ReadSeeker) error
// Delete removes the object or file at path.
Delete(path string) error
// ListFiles lists the entries directly under path, without recursing.
ListFiles(path string) ([]string, error)
// Exists reports whether path exists.
Exists(path string) (bool, error)
}
// Client is the pathio client used to access the local file system and S3.
// To configure options on the client, create a new Client and call its methods
// directly.
// &Client{
// disableS3Encryption: true, // disables encryption
// Region: "us-east-1", // hardcodes the s3 region, instead of looking it up
// }.Write(...)
type Client struct {
// disableS3Encryption, when true, stops uploads from requesting AES256
// server-side encryption.
disableS3Encryption bool
// Region is the S3 region to use. When empty, the region is looked up from
// the bucket on every S3 operation.
Region string
// providedConfig, when non-nil, is the AWS config used to build S3 handlers
// instead of a default config.
providedConfig *aws.Config
}
// DefaultClient is the pathio client used by the package-level Reader, Write,
// WriteReader, Delete, ListFiles, and Exists functions. It has S3 encryption
// enabled.
var DefaultClient Pathio = &Client{}
// NewClient returns a Client whose S3 handlers are built from cfg rather than
// from a default AWS config. Supplying a config is a way to run with more
// limited permissions.
func NewClient(cfg *aws.Config) *Client {
	client := new(Client)
	client.providedConfig = cfg
	return client
}
// Reader opens path for reading by calling DefaultClient's Reader method.
// The caller is responsible for closing rc.
func Reader(path string) (rc io.ReadCloser, err error) {
return DefaultClient.Reader(path)
}
// Write writes input to path by calling DefaultClient's Write method.
func Write(path string, input []byte) error {
return DefaultClient.Write(path, input)
}
// WriteReader writes the contents of input to path by calling DefaultClient's
// WriteReader method.
func WriteReader(path string, input io.ReadSeeker) error {
return DefaultClient.WriteReader(path, input)
}
// Delete removes path by calling DefaultClient's Delete method.
func Delete(path string) error {
return DefaultClient.Delete(path)
}
// ListFiles lists the entries under path by calling DefaultClient's ListFiles
// method.
func ListFiles(path string) ([]string, error) {
return DefaultClient.ListFiles(path)
}
// Exists reports whether path exists by calling DefaultClient's Exists method.
func Exists(path string) (bool, error) {
return DefaultClient.Exists(path)
}
// s3Handler defines the interface that pathio needs for AWS access. It is the
// subset of S3 operations this package calls; liveS3Handler implements it by
// delegating to an *s3.S3 client.
type s3Handler interface {
GetBucketLocation(input *s3.GetBucketLocationInput) (*s3.GetBucketLocationOutput, error)
GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error)
DeleteObject(input *s3.DeleteObjectInput) (*s3.DeleteObjectOutput, error)
PutObject(input *s3.PutObjectInput) (*s3.PutObjectOutput, error)
ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error)
HeadObject(input *s3.HeadObjectInput) (*s3.HeadObjectOutput, error)
}
// s3Connection bundles everything needed to act on one S3 path: the handler
// to issue requests through, plus the bucket and key parsed from the path.
type s3Connection struct {
// handler issues the S3 requests.
handler s3Handler
// bucket is the bucket name parsed from the s3:// path.
bucket string
// key is everything after the bucket in the s3:// path.
key string
}
// Reader returns an io.ReadCloser for the specified path. The path can either
// be a local file path or an S3 path (s3://bucket/key). It is the caller's
// responsibility to close rc.
//
// When err is non-nil, rc is a nil interface value.
func (c *Client) Reader(path string) (rc io.ReadCloser, err error) {
	if strings.HasPrefix(path, "s3://") {
		s3Conn, connErr := c.s3ConnectionInformation(path, c.Region)
		if connErr != nil {
			return nil, connErr
		}
		return s3FileReader(s3Conn)
	}

	// Local file path. Return an untyped nil on failure: handing back the nil
	// *os.File directly would produce a non-nil io.ReadCloser wrapping a nil
	// pointer, so an `rc != nil` check by the caller would wrongly succeed.
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	return file, nil
}
// Write stores the given bytes at path, which may be a local file path or an
// S3 path. It wraps input in a reader and hands it to WriteReader.
func (c *Client) Write(path string, input []byte) error {
	payload := bytes.NewReader(input)
	return c.WriteReader(path, payload)
}
// WriteReader writes all the data read from input to the output path. The
// path can be either a local file path or an S3 path.
//
// input is rewound to its start before anything is written; an error is
// returned if the seek fails or does not land on offset 0.
func (c *Client) WriteReader(path string, input io.ReadSeeker) error {
	// Return the file pointer to the start before reading from it when writing.
	// %v is used for err because it is nil when only the offset is wrong.
	if offset, err := input.Seek(0, io.SeekStart); err != nil || offset != 0 {
		return fmt.Errorf("failed to reset the file pointer to 0. offset: %d; error %v", offset, err)
	}

	if strings.HasPrefix(path, "s3://") {
		s3Conn, err := c.s3ConnectionInformation(path, c.Region)
		if err != nil {
			return err
		}
		return writeToS3(s3Conn, input, c.disableS3Encryption)
	}
	return writeToLocalFile(path, input)
}
// Delete removes whatever lives at path. Paths beginning with "s3://" are
// deleted from S3; anything else is removed from the local file system.
func (c *Client) Delete(path string) error {
	if !strings.HasPrefix(path, "s3://") {
		// Local file path
		return os.Remove(path)
	}

	conn, err := c.s3ConnectionInformation(path, c.Region)
	if err != nil {
		return err
	}
	return deleteS3Object(conn)
}
// ListFiles returns the files and directories found directly under path,
// which may be a local directory or an S3 path. It does not recurse.
func (c *Client) ListFiles(path string) ([]string, error) {
	if !strings.HasPrefix(path, "s3://") {
		return lsLocal(path)
	}

	conn, err := c.s3ConnectionInformation(path, c.Region)
	if err != nil {
		return nil, err
	}
	return lsS3(conn)
}
// Exists reports whether path exists, for either a local path or an S3 path.
// NOTE: S3 is eventually consistent so keep in mind that there is a delay.
func (c *Client) Exists(path string) (bool, error) {
	if !strings.HasPrefix(path, "s3://") {
		return existsLocal(path)
	}

	conn, err := c.s3ConnectionInformation(path, c.Region)
	if err != nil {
		return false, err
	}
	return existsS3(conn)
}
// existsS3 reports whether the object named by s3Conn's bucket and key exists,
// determined with a HeadObject request. A request failure with status code 404
// is reported as (false, nil); any other error is returned to the caller.
func existsS3(s3Conn s3Connection) (bool, error) {
	head := &s3.HeadObjectInput{
		Bucket: aws.String(s3Conn.bucket),
		Key:    aws.String(s3Conn.key),
	}

	_, err := s3Conn.handler.HeadObject(head)
	if err == nil {
		return true, nil
	}
	if failure, ok := err.(s3.RequestFailure); ok && failure.StatusCode() == 404 {
		return false, nil
	}
	return false, err
}
// existsLocal reports whether path exists on the local file system. A
// "does not exist" result from os.Stat is reported as (false, nil); any other
// stat error is returned alongside false.
func existsLocal(path string) (bool, error) {
	switch _, statErr := os.Stat(path); {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return false, statErr
	}
}
func lsS3(s3Conn s3Connection) ([]string, error) {
params := s3.ListObjectsInput{
Bucket: aws.String(s3Conn.bucket),
Prefix: aws.String(s3Conn.key),
Delimiter: aws.String("/"),
}
finalResults := []string{}
// s3 ListObjects limits the respose to 1000 objects and marks as truncated if there were more
// In this case we set a Marker that the next query will start from.
// We also ensure that prefixes are not duplicated
for {
resp, err := s3Conn.handler.ListObjects(¶ms)
if err != nil {
return nil, err
}
if len(resp.CommonPrefixes) > 0 && elementInSlice(finalResults, *resp.CommonPrefixes[0].Prefix) {
resp.CommonPrefixes = resp.CommonPrefixes[1:]
}
results := make([]string, len(resp.Contents)+len(resp.CommonPrefixes))
for i, val := range resp.CommonPrefixes {
results[i] = *val.Prefix
}
for i, val := range resp.Contents {
results[i+len(resp.CommonPrefixes)] = *val.Key
}
finalResults = append(finalResults, results...)
if resp.IsTruncated != nil && *resp.IsTruncated {
params.Marker = aws.String(results[len(results)-1])
} else {
break
}
}
return finalResults, nil
}
// elementInSlice reports whether elem is equal to any entry of slice.
func elementInSlice(slice []string, elem string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == elem {
			return true
		}
	}
	return false
}
// lsLocal returns the names of the entries directly inside the local
// directory at path, in the order ioutil.ReadDir reports them.
func lsLocal(path string) ([]string, error) {
	entries, err := ioutil.ReadDir(path)
	if err != nil {
		return nil, err
	}

	names := make([]string, 0, len(entries))
	for _, entry := range entries {
		names = append(names, entry.Name())
	}
	return names, nil
}
// s3FileReader converts an S3Path into an io.ReadCloser
func s3FileReader(s3Conn s3Connection) (io.ReadCloser, error) {
params := s3.GetObjectInput{
Bucket: aws.String(s3Conn.bucket),
Key: aws.String(s3Conn.key),
}
resp, err := s3Conn.handler.GetObject(¶ms)
if err != nil {
return nil, err
}
return resp.Body, nil
}
// writeToS3 uploads the given file to S3
func writeToS3(s3Conn s3Connection, input io.ReadSeeker, disableEncryption bool) error {
params := s3.PutObjectInput{
Bucket: aws.String(s3Conn.bucket),
Key: aws.String(s3Conn.key),
Body: input,
}
if !disableEncryption {
algo := aesAlgo
params.ServerSideEncryption = &algo
}
_, err := s3Conn.handler.PutObject(¶ms)
return err
}
// deleteS3Object deletes the file on S3 at the given path
func deleteS3Object(s3Conn s3Connection) error {
params := s3.DeleteObjectInput{
Bucket: aws.String(s3Conn.bucket),
Key: aws.String(s3Conn.key),
}
_, err := s3Conn.handler.DeleteObject(¶ms)
return err
}
// writeToLocalFile writes everything readable from input to the local file at
// path, creating any missing parent directories (mode 0700) and truncating the
// file if it already exists.
//
// It returns the first error hit while creating the directories, creating the
// file, copying the data, or closing the file.
func writeToLocalFile(path string, input io.ReadSeeker) error {
	if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
		return err
	}

	file, err := os.Create(path)
	if err != nil {
		return err
	}

	if _, err := io.Copy(file, input); err != nil {
		// Best-effort close; the copy failure is the error worth reporting.
		file.Close()
		return err
	}

	// Close can surface a deferred write failure, so its error is returned
	// rather than discarded.
	return file.Close()
}
// parseS3Path splits an S3 path of the form s3://bucket/key into its bucket
// and key. It returns an error when the path has fewer than three slashes,
// i.e. when there is no key portion at all.
func parseS3Path(path string) (string, string, error) {
	// Cutting "s3://bucket/key" on "/" at most four ways yields
	// ["s3:", "", "bucket", "key"]; the final piece keeps any further slashes,
	// so the key is everything after the third slash.
	const wantParts = 4

	parts := strings.SplitN(path, "/", wantParts)
	if len(parts) >= wantParts {
		return parts[2], parts[3], nil
	}
	return "", "", fmt.Errorf("Invalid s3 path %s", path)
}
// s3ConnectionInformation turns an S3 path into an s3Connection: the parsed
// bucket and key plus a handler bound to the right region. When region is
// empty, the bucket's region is looked up in S3 first.
func (c *Client) s3ConnectionInformation(path, region string) (s3Connection, error) {
	var none s3Connection

	bucket, key, err := parseS3Path(path)
	if err != nil {
		return none, err
	}

	// If no region passed in, look up region in S3
	if region == "" {
		lookup := c.newS3Handler(defaultLocation)
		if region, err = getRegionForBucket(lookup, bucket); err != nil {
			return none, err
		}
	}

	return s3Connection{
		handler: c.newS3Handler(region),
		bucket:  bucket,
		key:     key,
	}, nil
}
// getRegionForBucket looks up the region name for the given bucket
func getRegionForBucket(svc s3Handler, name string) (string, error) {
// Any region will work for the region lookup, but the request MUST use
// PathStyle
params := s3.GetBucketLocationInput{
Bucket: aws.String(name),
}
resp, err := svc.GetBucketLocation(¶ms)
if err != nil {
return "", fmt.Errorf("Failed to get location for bucket '%s', %s", name, err)
}
if resp.LocationConstraint == nil {
// "US Standard", returns an empty region, which means us-east-1
// See http://docs.aws.amazon.com/AmazonS3/latest/API/RESTBucketGETlocation.html
return defaultLocation, nil
}
return *resp.LocationConstraint, nil
}
// liveS3Handler is the s3Handler implementation backed by a real S3 client.
// Every method forwards its input to the same-named method on that client.
type liveS3Handler struct {
	liveS3 *s3.S3
}

// GetBucketLocation forwards to the underlying client's GetBucketLocation.
func (h *liveS3Handler) GetBucketLocation(input *s3.GetBucketLocationInput) (*s3.GetBucketLocationOutput, error) {
	return h.liveS3.GetBucketLocation(input)
}

// GetObject forwards to the underlying client's GetObject.
func (h *liveS3Handler) GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error) {
	return h.liveS3.GetObject(input)
}

// DeleteObject forwards to the underlying client's DeleteObject.
func (h *liveS3Handler) DeleteObject(input *s3.DeleteObjectInput) (*s3.DeleteObjectOutput, error) {
	return h.liveS3.DeleteObject(input)
}

// PutObject forwards to the underlying client's PutObject.
func (h *liveS3Handler) PutObject(input *s3.PutObjectInput) (*s3.PutObjectOutput, error) {
	return h.liveS3.PutObject(input)
}

// ListObjects forwards to the underlying client's ListObjects.
func (h *liveS3Handler) ListObjects(input *s3.ListObjectsInput) (*s3.ListObjectsOutput, error) {
	return h.liveS3.ListObjects(input)
}

// HeadObject forwards to the underlying client's HeadObject.
func (h *liveS3Handler) HeadObject(input *s3.HeadObjectInput) (*s3.HeadObjectOutput, error) {
	return h.liveS3.HeadObject(input)
}
// newS3Handler builds a liveS3Handler whose S3 client targets region and uses
// path-style addressing. The client is configured from the Client's provided
// AWS config when there is one, and from a fresh default config otherwise.
func (c *Client) newS3Handler(region string) *liveS3Handler {
	var config *aws.Config
	if c.providedConfig != nil {
		// Work on a copy: the With* setters modify the config they are called
		// on, and the caller's config must not have its region rewritten on
		// every request.
		config = c.providedConfig.Copy()
	} else {
		config = aws.NewConfig()
	}
	config = config.WithRegion(region).WithS3ForcePathStyle(true)

	return &liveS3Handler{liveS3: s3.New(session.New(), config)}
}