/
s3_mock.go
88 lines (73 loc) · 1.75 KB
/
s3_mock.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
package stratus
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"reflect"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/mock"
)
// S3Mock is a test double for an S3 client built on testify's mock.Mock.
// The embedded mock.Mock supplies On, Return, AssertExpectations, etc., which
// are used to configure and verify the calls made to S3Mock's methods.
type S3Mock struct {
	mock.Mock
}
// NewS3Mock returns a fresh S3Mock with no expectations registered.
func NewS3Mock() *S3Mock {
	return &S3Mock{}
}
// PutObjectWithContext records a call to the mocked S3 PutObjectWithContext.
//
// Only input is forwarded to the mock; the context and request options are
// ignored, so expectations are registered with a single argument, e.g.
// On("PutObjectWithContext", input) or a matcher. The configured return
// values must be a *s3.PutObjectOutput (or nil) followed by an error (or nil);
// any other types cause a panic.
func (client *S3Mock) PutObjectWithContext(
	_ aws.Context,
	input *s3.PutObjectInput,
	_ ...request.Option,
) (*s3.PutObjectOutput, error) {
	ret := client.Called(input)

	var out *s3.PutObjectOutput
	if first := ret.Get(0); first != nil {
		out = first.(*s3.PutObjectOutput)
	}
	return out, ret.Error(1)
}
// PutObjectWithContextMatcher builds a predicate over *s3.PutObjectInput,
// suitable for use with mock.MatchedBy.
//
// The predicate reports true when both inputs are nil, or when Bucket and Key
// are deeply equal and the Body contents are equal (compared with
// compareReaders, which reads both bodies). All other fields are ignored.
// A nil on only one side never matches. On a mismatch the predicate prints
// a diagnostic line to stdout before returning false.
func (client *S3Mock) PutObjectWithContextMatcher(
	expected *s3.PutObjectInput,
) func(*s3.PutObjectInput) bool {
	return func(actual *s3.PutObjectInput) bool {
		// Case lists short-circuit left to right, so the field accesses
		// below are only reached once both pointers are known non-nil.
		switch {
		case expected == nil && actual == nil:
			return true
		case expected == nil,
			actual == nil,
			!reflect.DeepEqual(expected.Bucket, actual.Bucket),
			!reflect.DeepEqual(expected.Key, actual.Key):
			fmt.Printf(
				"S3Mock.PutObjectWithContextMatcher: expected '%+v', received '%+v'\n",
				expected,
				actual,
			)
			return false
		}

		if bodyErr := compareReaders(expected.Body, actual.Body); bodyErr != nil {
			fmt.Printf("S3Mock.PutObjectWithContextMatcher.Body: %+v\n", bodyErr)
			return false
		}
		return true
	}
}
// compareReaders reads the remaining contents of a and b and returns a
// non-nil error if they differ or if either cannot be read.
//
// A nil reader is treated as an empty stream, so it compares equal to an
// empty body instead of panicking. Readers that implement io.Seeker are
// rewound to their original offset after reading. This lets the same input
// be compared more than once (mock matchers may be evaluated repeatedly)
// and leaves bodies readable for later code. Non-seekable readers are
// consumed.
func compareReaders(a, b io.Reader) error {
	dataA, err := readAllRewinding(a)
	if err != nil {
		return fmt.Errorf("error reading a: %w", err)
	}
	dataB, err := readAllRewinding(b)
	if err != nil {
		return fmt.Errorf("error reading b: %w", err)
	}
	if !bytes.Equal(dataA, dataB) {
		return fmt.Errorf("expected '%s', received '%s'", dataA, dataB)
	}
	return nil
}

// readAllRewinding returns everything left in r (nil for a nil reader) and,
// when r is seekable, restores r's read offset to where it started.
func readAllRewinding(r io.Reader) ([]byte, error) {
	if r == nil {
		return nil, nil
	}
	seeker, ok := r.(io.Seeker)
	if !ok {
		return ioutil.ReadAll(r)
	}
	start, err := seeker.Seek(0, io.SeekCurrent)
	if err != nil {
		// Some types implement Seek but cannot actually seek (e.g. pipes);
		// fall back to a plain, consuming read.
		return ioutil.ReadAll(r)
	}
	data, err := ioutil.ReadAll(r)
	if err != nil {
		return nil, err
	}
	if _, err := seeker.Seek(start, io.SeekStart); err != nil {
		return nil, fmt.Errorf("rewinding reader: %w", err)
	}
	return data, nil
}