diff --git a/extensions/store-s3/cmd/root.go b/extensions/store-s3/cmd/root.go new file mode 100644 index 00000000..df808da4 --- /dev/null +++ b/extensions/store-s3/cmd/root.go @@ -0,0 +1,54 @@ +package cmd + +import ( + "fmt" + "net" + + "github.com/linuxsuren/api-testing/extensions/store-s3/pkg" + "github.com/linuxsuren/api-testing/pkg/testing/remote" + "github.com/spf13/cobra" + "google.golang.org/grpc" +) + +func NewRootCmd(s3Creator pkg.S3Creator) (cmd *cobra.Command) { + opt := &option{ + S3Creator: s3Creator, + } + cmd = &cobra.Command{ + Use: "store-s3", + Short: "S3 storage extension of api-testing", + RunE: opt.runE, + } + flags := cmd.Flags() + flags.IntVarP(&opt.port, "port", "p", 7072, "The port of gRPC server") + return cmd +} + +func (o *option) runE(cmd *cobra.Command, args []string) (err error) { + removeServer := pkg.NewRemoteServer(o.S3Creator) + + var lis net.Listener + lis, err = net.Listen("tcp", fmt.Sprintf(":%d", o.port)) + if err != nil { + return + } + + gRPCServer := grpc.NewServer() + remote.RegisterLoaderServer(gRPCServer, removeServer) + cmd.Println("S3 storage extension is running at port", o.port) + + go func() { + <-cmd.Context().Done() + gRPCServer.Stop() + }() + + err = gRPCServer.Serve(lis) + return +} + +type option struct { + port int + + // inner fields + S3Creator pkg.S3Creator +} diff --git a/extensions/store-s3/cmd/root_test.go b/extensions/store-s3/cmd/root_test.go new file mode 100644 index 00000000..7df92b18 --- /dev/null +++ b/extensions/store-s3/cmd/root_test.go @@ -0,0 +1,44 @@ +package cmd + +import ( + "context" + "io" + "testing" + "time" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" +) + +func TestNewRootCmd(t *testing.T) { + t.Run("not run", func(t *testing.T) { + cmd := newRootCmdForTest() + assert.NotNil(t, cmd) + assert.Equal(t, "store-s3", cmd.Use) + assert.Equal(t, "7072", cmd.Flags().Lookup("port").Value.String()) + }) + + t.Run("invalid port", func(t *testing.T) { + cmd := newRootCmdForTest() + cmd.SetArgs([]string{"--port", "-1"}) + err := cmd.Execute() + assert.Error(t, err) + }) + + t.Run("stop the command", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + + cmd := newRootCmdForTest() + cmd.SetContext(ctx) + cmd.SetArgs([]string{"--port", "0"}) + err := cmd.Execute() + assert.NoError(t, err) + }) +} + +func newRootCmdForTest() *cobra.Command { + cmd := NewRootCmd(nil) + cmd.SetOut(io.Discard) + return cmd +} diff --git a/extensions/store-s3/main.go b/extensions/store-s3/main.go index 24bee912..89fbce70 100644 --- a/extensions/store-s3/main.go +++ b/extensions/store-s3/main.go @@ -1,49 +1,15 @@ package main import ( - "fmt" - "net" "os" + "github.com/linuxsuren/api-testing/extensions/store-s3/cmd" "github.com/linuxsuren/api-testing/extensions/store-s3/pkg" - "github.com/linuxsuren/api-testing/pkg/testing/remote" - "github.com/spf13/cobra" - "google.golang.org/grpc" ) func main() { - opt := &option{} - cmd := &cobra.Command{ - Use: "store-s3", - Short: "S3 storage extension of api-testing", - RunE: opt.runE, - } - flags := cmd.Flags() - flags.IntVarP(&opt.port, "port", "p", 7072, "The port of gRPC server") + cmd := cmd.NewRootCmd(&pkg.DefaultS3Creator{}) if err := cmd.Execute(); err != nil { os.Exit(1) } } - -func (o *option) runE(cmd *cobra.Command, args []string) (err error) { - var removeServer remote.LoaderServer - if removeServer, err = pkg.NewRemoteServer(); err != nil { - return - } - - var lis net.Listener - lis, err = net.Listen("tcp", fmt.Sprintf(":%d", o.port)) - if err != nil { - return - } - - gRPCServer := grpc.NewServer() - remote.RegisterLoaderServer(gRPCServer, removeServer) - cmd.Println("S3 storage extension is running at port", o.port) - err = gRPCServer.Serve(lis) - return -} - -type option struct { - port int -} diff --git a/extensions/store-s3/pkg/fake_s3.go b/extensions/store-s3/pkg/fake_s3.go new file mode 100644 index 00000000..97e7fb16 --- /dev/null +++ b/extensions/store-s3/pkg/fake_s3.go @@ -0,0 +1,71 @@ +package pkg + +import ( + "bytes" + "io" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/client" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/s3" +) + +type S3API interface { + ListObjectsWithContext(ctx aws.Context, input *s3.ListObjectsInput, opts ...request.Option) (*s3.ListObjectsOutput, error) + PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) + GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) + DeleteObjectWithContext(ctx aws.Context, input *s3.DeleteObjectInput, opts ...request.Option) (*s3.DeleteObjectOutput, error) +} + +type S3Creator interface { + New(p client.ConfigProvider, cfgs ...*aws.Config) S3API +} + +type DefaultS3Creator struct{} + +func (d *DefaultS3Creator) New(p client.ConfigProvider, cfgs ...*aws.Config) S3API { + return s3.New(p, cfgs...) +} + +type fakeS3 struct { + data map[*string][]byte +} + +func (f *fakeS3) New(p client.ConfigProvider, cfgs ...*aws.Config) S3API { + return f +} + +func (f *fakeS3) ListObjectsWithContext(ctx aws.Context, input *s3.ListObjectsInput, opts ...request.Option) (output *s3.ListObjectsOutput, err error) { + output = &s3.ListObjectsOutput{} + for k := range f.data { + output.Contents = append(output.Contents, &s3.Object{ + Key: k, + }) + } + return +} +func (f *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) { + data, err := io.ReadAll(input.Body) + f.data[input.Key] = data + return nil, err +} +func (f *fakeS3) GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (output *s3.GetObjectOutput, err error) { + for k := range f.data { + if *input.Key == *k { + output = &s3.GetObjectOutput{ + Body: io.NopCloser(bytes.NewReader(f.data[k])), + } + break + } + } + return +} +func (f *fakeS3) DeleteObjectWithContext(ctx aws.Context, input *s3.DeleteObjectInput, opts ...request.Option) (*s3.DeleteObjectOutput, error) { + for k := range f.data { + if *input.Key == *k { + delete(f.data, k) + break + } + } + return nil, nil +} diff --git a/extensions/store-s3/pkg/s3_server.go b/extensions/store-s3/pkg/s3_server.go index 50b5e5c7..aaa57d64 100644 --- a/extensions/store-s3/pkg/s3_server.go +++ b/extensions/store-s3/pkg/s3_server.go @@ -17,11 +17,12 @@ import ( ) type s3Client struct { + S3Creator S3Creator remote.UnimplementedLoaderServer } -func NewRemoteServer() (remote.LoaderServer, error) { - return &s3Client{}, nil +func NewRemoteServer(S3Creator S3Creator) remote.LoaderServer { + return &s3Client{S3Creator: S3Creator} } func (s *s3Client) ListTestSuite(ctx context.Context, _ *server.Empty) (suites *remote.TestSuites, err error) { @@ -34,22 +35,28 @@ func (s *s3Client) ListTestSuite(ctx context.Context, _ *server.Empty) (suites * var list *s3.ListObjectsOutput if list, err = client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ Bucket: aws.String(client.bucket), - }); err == nil { - var suite *testing.TestSuite - for _, obj := range list.Contents { - if !strings.HasSuffix(*obj.Key, ".yaml") { - continue - } + }); err == nil && list != nil { + suites, err = listObjectsOutputToTestSuite(ctx, list, client) + } + return +} +func listObjectsOutputToTestSuite(ctx context.Context, list *s3.ListObjectsOutput, client *s3WithBucket) ( + suites *remote.TestSuites, err error) { + var suite *testing.TestSuite + suites = &remote.TestSuites{} + for _, obj := range list.Contents { + if !strings.HasSuffix(*obj.Key, ".yaml") { + continue + } - var objOutput *s3.GetObjectOutput - if objOutput, err = client.GetObjectWithContext(ctx, &s3.GetObjectInput{ - Bucket: aws.String(client.bucket), - Key: obj.Key, - }); err == nil { - data := objOutput.Body - if suite, err = testing.ParseFromStream(data); err == nil { - suites.Data = append(suites.Data, remote.ConvertToGRPCTestSuite(suite)) - } + var objOutput *s3.GetObjectOutput + if objOutput, err = client.GetObjectWithContext(ctx, &s3.GetObjectInput{ + Bucket: aws.String(client.bucket), + Key: obj.Key, + }); err == nil { + data := objOutput.Body + if suite, err = testing.ParseFromStream(data); err == nil { + suites.Data = append(suites.Data, remote.ConvertToGRPCTestSuite(suite)) } } } @@ -60,39 +67,34 @@ func (s *s3Client) CreateTestSuite(ctx context.Context, testSuite *remote.TestSu reply = &server.Empty{} var data []byte - if data, err = yaml.Marshal(suite); err != nil { - return - } - - var client *s3WithBucket - if client, err = s.getClient(ctx); err != nil { - return + if data, err = yaml.Marshal(suite); err == nil { + var client *s3WithBucket + if client, err = s.getClient(ctx); err == nil { + _, err = client.PutObjectWithContext(ctx, &s3.PutObjectInput{ + Bucket: aws.String(client.bucket), + Key: generateKey(suite.Name), + Body: bytes.NewReader(data), + }) + } } - _, err = client.PutObjectWithContext(ctx, &s3.PutObjectInput{ - Bucket: aws.String(client.bucket), - Key: aws.String(suite.Name + ".yaml"), - Body: bytes.NewReader(data), - }) return } func (s *s3Client) GetTestSuite(ctx context.Context, suite *remote.TestSuite) (reply *remote.TestSuite, err error) { reply = &remote.TestSuite{} var client *s3WithBucket - if client, err = s.getClient(ctx); err != nil || client == nil { - return - } + if client, err = s.getClient(ctx); err == nil && client != nil { + var objOutput *s3.GetObjectOutput + if objOutput, err = client.GetObjectWithContext(ctx, &s3.GetObjectInput{ + Bucket: aws.String(client.bucket), + Key: generateKey(suite.Name), + }); err == nil && objOutput != nil { + data := objOutput.Body - var objOutput *s3.GetObjectOutput - if objOutput, err = client.GetObjectWithContext(ctx, &s3.GetObjectInput{ - Bucket: aws.String(client.bucket), - Key: aws.String(suite.Name + ".yaml"), - }); err == nil { - data := objOutput.Body - - var suite *testing.TestSuite - if suite, err = testing.ParseFromStream(data); err == nil { - reply = remote.ConvertToGRPCTestSuite(suite) + var suite *testing.TestSuite + if suite, err = testing.ParseFromStream(data); err == nil { + reply = remote.ConvertToGRPCTestSuite(suite) + } } } return @@ -100,34 +102,28 @@ func (s *s3Client) GetTestSuite(ctx context.Context, suite *remote.TestSuite) (r func (s *s3Client) UpdateTestSuite(ctx context.Context, suite *remote.TestSuite) (reply *remote.TestSuite, err error) { reply = &remote.TestSuite{} var oldSuite *remote.TestSuite - if oldSuite, err = s.GetTestSuite(ctx, suite); err != nil { - return + if oldSuite, err = s.GetTestSuite(ctx, suite); err == nil { + suite.Items = oldSuite.Items + _, err = s.CreateTestSuite(ctx, suite) } - - suite.Items = oldSuite.Items - _, err = s.CreateTestSuite(ctx, suite) return } func (s *s3Client) DeleteTestSuite(ctx context.Context, suite *remote.TestSuite) (reply *server.Empty, err error) { reply = &server.Empty{} var client *s3WithBucket - if client, err = s.getClient(ctx); err != nil || client == nil { - return + if client, err = s.getClient(ctx); err == nil && client != nil { + _, err = client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{ + Bucket: aws.String(client.bucket), + Key: generateKey(suite.Name), + }) } - - _, err = client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{ - Bucket: aws.String(client.bucket), - Key: aws.String(suite.Name + ".yaml"), - }) return } func (s *s3Client) ListTestCases(ctx context.Context, suite *remote.TestSuite) (result *server.TestCases, err error) { - if suite, err = s.GetTestSuite(ctx, suite); err != nil { - return - } - - result = &server.TestCases{ - Data: suite.Items, + if suite, err = s.GetTestSuite(ctx, suite); err == nil { + result = &server.TestCases{ + Data: suite.Items, + } } return } @@ -137,27 +133,18 @@ func (s *s3Client) CreateTestCase(ctx context.Context, testcase *server.TestCase var suite *remote.TestSuite if suite, err = s.GetTestSuite(ctx, &remote.TestSuite{ Name: testcase.SuiteName, - }); err != nil { - return + }); err == nil { + suite.Items = append(suite.Items, testcase) + _, err = s.CreateTestSuite(ctx, suite) } - - suite.Items = append(suite.Items, testcase) - _, err = s.CreateTestSuite(ctx, suite) return } func (s *s3Client) GetTestCase(ctx context.Context, testcase *server.TestCase) (result *server.TestCase, err error) { var suite *remote.TestSuite if suite, err = s.GetTestSuite(ctx, &remote.TestSuite{ Name: testcase.SuiteName, - }); err != nil { - return - } - - for _, item := range suite.Items { - if item.Name == testcase.Name { - result = item - break - } + }); err == nil { + result = getTestCaseByName(suite, testcase.Name) } return } @@ -166,36 +153,20 @@ func (s *s3Client) UpdateTestCase(ctx context.Context, testcase *server.TestCase var suite *remote.TestSuite if suite, err = s.GetTestSuite(ctx, &remote.TestSuite{ Name: testcase.SuiteName, - }); err != nil { - return - } - - for i, item := range suite.Items { - if item.Name == testcase.Name { - suite.Items[i] = testcase - break - } + }); err == nil { + suite = updateTestCase(suite, testcase) + _, err = s.CreateTestSuite(ctx, suite) } - - _, err = s.CreateTestSuite(ctx, suite) return } func (s *s3Client) DeleteTestCase(ctx context.Context, testcase *server.TestCase) (reply *server.Empty, err error) { var suite *remote.TestSuite if suite, err = s.GetTestSuite(ctx, &remote.TestSuite{ Name: testcase.SuiteName, - }); err != nil { - return - } - - for i, item := range suite.Items { - if item.Name == testcase.Name { - suite.Items = append(suite.Items[:i], suite.Items[i+1:]...) - break - } + }); err == nil { + suite = removeTestCaseByName(suite, testcase.Name) + _, err = s.UpdateTestSuite(ctx, suite) } - - _, err = s.UpdateTestSuite(ctx, suite) return } func (s *s3Client) getClient(ctx context.Context) (db *s3WithBucket, err error) { @@ -221,13 +192,11 @@ func (s *s3Client) getClient(ctx context.Context) (db *s3WithBucket, err error) var sess *session.Session sess, err = session.NewSession(&config) - if err != nil { - return + if err == nil { + svc := s.S3Creator.New(sess) // s3.New(sess) + db = &s3WithBucket{S3API: svc, bucket: options.Bucket} + clientCache[store.Name] = db } - - svc := s3.New(sess) - db = &s3WithBucket{S3: svc, bucket: options.Bucket} - clientCache[store.Name] = db } return } @@ -243,6 +212,40 @@ func mapToS3Options(data map[string]string) (opt s3Options) { return } +func generateKey(name string) *string { + return aws.String(name + ".yaml") +} + +func removeTestCaseByName(suite *remote.TestSuite, name string) *remote.TestSuite { + for i, item := range suite.Items { + if item.Name == name { + suite.Items = append(suite.Items[:i], suite.Items[i+1:]...) + break + } + } + return suite +} + +func updateTestCase(suite *remote.TestSuite, testcase *server.TestCase) *remote.TestSuite { + for i, item := range suite.Items { + if item.Name == testcase.Name { + suite.Items[i] = testcase + break + } + } + return suite +} + +func getTestCaseByName(suite *remote.TestSuite, name string) (result *server.TestCase) { + for _, item := range suite.Items { + if item.Name == name { + result = item + break + } + } + return +} + type s3Options struct { // AWS Access key ID AccessKeyID string `yaml:"accessKeyID"` @@ -258,7 +261,7 @@ type s3Options struct { } type s3WithBucket struct { - *s3.S3 + S3API bucket string } diff --git a/extensions/store-s3/pkg/s3_server_test.go b/extensions/store-s3/pkg/s3_server_test.go index b654e201..ba97f681 100644 --- a/extensions/store-s3/pkg/s3_server_test.go +++ b/extensions/store-s3/pkg/s3_server_test.go @@ -1,13 +1,201 @@ package pkg import ( + "context" "testing" + "github.com/aws/aws-sdk-go/aws" + "github.com/linuxsuren/api-testing/pkg/server" + atest "github.com/linuxsuren/api-testing/pkg/testing" + "github.com/linuxsuren/api-testing/pkg/testing/remote" "github.com/stretchr/testify/assert" ) +func newRemoteServer(t *testing.T) remote.LoaderServer { + remoteServer := NewRemoteServer(&fakeS3{data: map[*string][]byte{ + aws.String("invalid"): []byte("invalid"), + }}) + assert.NotNil(t, remoteServer) + return remoteServer +} + func TestNewRemoteServer(t *testing.T) { - server, err := NewRemoteServer() - assert.NotNil(t, server) - assert.NoError(t, err) + emptyCtx := context.Background() + defaultCtx := remote.WithIncomingStoreContext(emptyCtx, &atest.Store{}) + + t.Run("ListTestSuite, no required info in context", func(t *testing.T) { + _, err := newRemoteServer(t).ListTestSuite(emptyCtx, nil) + assert.Error(t, err) + }) + + t.Run("ListTestSuite", func(t *testing.T) { + _, err := newRemoteServer(t).ListTestSuite(defaultCtx, nil) + assert.NoError(t, err) + }) + + t.Run("CreateTestSuite", func(t *testing.T) { + server := newRemoteServer(t) + _, err := server.CreateTestSuite(defaultCtx, &remote.TestSuite{ + Name: "fake", + }) + assert.NoError(t, err) + + var suites *remote.TestSuites + suites, err = server.ListTestSuite(defaultCtx, nil) + if assert.NoError(t, err) { + assert.Equal(t, "fake", suites.Data[0].Name) + } + + var suite *remote.TestSuite + suite, err = server.GetTestSuite(defaultCtx, &remote.TestSuite{Name: "fake"}) + if assert.NoError(t, err) { + assert.Equal(t, "fake", suite.Name) + } + }) + + t.Run("GetTestSuite", func(t *testing.T) { + _, err := newRemoteServer(t).GetTestSuite(defaultCtx, &remote.TestSuite{ + Name: "fake", + }) + assert.NoError(t, err) + }) + + t.Run("UpdateTestSuite", func(t *testing.T) { + _, err := newRemoteServer(t).UpdateTestSuite(defaultCtx, &remote.TestSuite{ + Name: "fake", + }) + assert.NoError(t, err) + }) + + t.Run("DeleteTestSuite", func(t *testing.T) { + server := newRemoteServer(t) + _, err := server.CreateTestSuite(defaultCtx, &remote.TestSuite{ + Name: "fake", + }) + assert.NoError(t, err) + + _, err = server.DeleteTestSuite(defaultCtx, &remote.TestSuite{ + Name: "fake", + }) + assert.NoError(t, err) + }) + + t.Run("ListTestCases", func(t *testing.T) { + _, err := newRemoteServer(t).ListTestCases(defaultCtx, &remote.TestSuite{ + Name: "fake", + }) + assert.NoError(t, err) + }) + + t.Run("CreateTestCase", func(t *testing.T) { + _, err := newRemoteServer(t).CreateTestCase(defaultCtx, &server.TestCase{ + Name: "fake", + SuiteName: "fake", + }) + assert.NoError(t, err) + }) + + t.Run("GetTestCase", func(t *testing.T) { + _, err := newRemoteServer(t).GetTestCase(defaultCtx, &server.TestCase{ + Name: "fake", + SuiteName: "fake", + }) + assert.NoError(t, err) + }) + + t.Run("UpdateTestCase", func(t *testing.T) { + _, err := newRemoteServer(t).UpdateTestCase(defaultCtx, &server.TestCase{ + Name: "fake", + SuiteName: "fake", + }) + assert.NoError(t, err) + }) + + t.Run("DeleteTestCase", func(t *testing.T) { + _, err := newRemoteServer(t).DeleteTestCase(defaultCtx, &server.TestCase{ + Name: "fake", + SuiteName: "fake", + }) + assert.NoError(t, err) + }) +} + +func TestCommonFuns(t *testing.T) { + t.Run("generateKey", func(t *testing.T) { + assert.Equal(t, "test.yaml", *generateKey("test")) + }) + + t.Run("mapToS3Options", func(t *testing.T) { + assert.Equal(t, s3Options{ + AccessKeyID: "id", + SecretAccessKey: "secret", + SessionToken: "token", + Region: "region", + DisableSSL: true, + ForcePathStyle: true, + Bucket: "bucket", + }, mapToS3Options(map[string]string{ + "accesskeyid": "id", + "secretaccesskey": "secret", + "sessiontoken": "token", + "region": "region", + "disablessl": "true", + "forcepathstyle": "true", + "bucket": "bucket", + })) + }) + + t.Run("removeTestCaseByName, an empty TestSuite", func(t *testing.T) { + suite := &remote.TestSuite{ + Items: []*server.TestCase{{ + Name: "fake", + }}, + } + + assert.Equal(t, suite, removeTestCaseByName(suite, "test")) + }) + + t.Run("removeTestCaseByName, a normal TestSuite", func(t *testing.T) { + suite := &remote.TestSuite{ + Items: []*server.TestCase{{ + Name: "fake", + }}, + } + + assert.Empty(t, removeTestCaseByName(suite, "fake").Items) + }) + + t.Run("updateTestCase", func(t *testing.T) { + suite := &remote.TestSuite{ + Items: []*server.TestCase{{ + Name: "fake", + Request: &server.Request{ + Method: "GET", + }, + }}, + } + + suite = updateTestCase(suite, &server.TestCase{ + Name: "fake", + Request: &server.Request{ + Method: "POST", + }, + }) + assert.Equal(t, "POST", suite.Items[0].Request.Method) + }) + + t.Run("getTestCaseByName", func(t *testing.T) { + testCase := &server.TestCase{ + Name: "fake", + Request: &server.Request{ + Api: "http://fake.com", + }, + } + sampleTestSuite := &remote.TestSuite{ + Items: []*server.TestCase{testCase}, + } + + testcase := getTestCaseByName(sampleTestSuite, "fake") + assert.Equal(t, testCase, testcase) + }) } diff --git a/pkg/testing/remote/context.go b/pkg/testing/remote/context.go index bf6a169b..a2fb9373 100644 --- a/pkg/testing/remote/context.go +++ b/pkg/testing/remote/context.go @@ -42,6 +42,10 @@ func GetStoreFromContext(ctx context.Context) (store *testing.Store) { return } +func WithIncomingStoreContext(ctx context.Context, store *testing.Store) context.Context { + return metadata.NewIncomingContext(ctx, metadata.New(store.ToMap())) +} + func MDToStore(md metadata.MD) *testing.Store { data := make(map[string]string) for key, val := range md {