From 7486b073d936097151a5641db66276edabe93676 Mon Sep 17 00:00:00 2001 From: rick Date: Mon, 4 Aug 2025 13:49:01 +0800 Subject: [PATCH 1/3] feat: support read mock response from local file --- docs/api-testing-mock-schema.json | 10 +- docs/site/content/zh/latest/tasks/mock.md | 17 +- pkg/mock/in_memory.go | 1216 +++++----- pkg/mock/in_memory_test.go | 399 ++-- pkg/mock/server.go | 2 +- pkg/mock/testdata/api.yaml | 6 + pkg/mock/types.go | 11 +- pkg/server/remote_server.go | 2450 +++++++++++---------- 8 files changed, 2083 insertions(+), 2028 deletions(-) diff --git a/docs/api-testing-mock-schema.json b/docs/api-testing-mock-schema.json index 7af68faa..6b2b86fd 100644 --- a/docs/api-testing-mock-schema.json +++ b/docs/api-testing-mock-schema.json @@ -60,11 +60,19 @@ "type": "object", "properties": { "encoder": { - "type": "string" + "type": "string", + "enum": [ + "base64", + "url", + "raw" + ] }, "body": { "type": "string" }, + "bodyFromFile": { + "type": "string" + }, "header": { "type": "object", "description": "HTTP response headers. Common headers include 'Content-Type', 'Cache-Control', 'Set-Cookie', etc.", diff --git a/docs/site/content/zh/latest/tasks/mock.md b/docs/site/content/zh/latest/tasks/mock.md index 37f894fd..f415e886 100644 --- a/docs/site/content/zh/latest/tasks/mock.md +++ b/docs/site/content/zh/latest/tasks/mock.md @@ -102,7 +102,9 @@ items: curl http://localhost:6060/mock/api/v1/repos/atest/prs -v ``` -另外,为了满足复杂的场景,还可以对 Response Body 做特定的解码,目前支持:`base64`、`url`: +另外,为了满足复杂的场景,还可以对 Response Body 做特定的解码,目前支持:`base64`、`url`、`raw`: + +> encoder 为 `raw` 时,表示不进行处理 ```yaml #!api-testing-mock @@ -136,6 +138,19 @@ items: encoder: url ``` +如果你的响应内容比较大,或者保存在一个本地文件中,那么你可以这么写: + +```yaml +#!api-testing-mock +# yaml-language-server: $schema=https://linuxsuren.github.io/api-testing/api-testing-mock-schema.json +items: + - name: baidu + request: + path: /api/v1/baidu + response: + bodyFromFile: /tmp/baidu.html +``` + 在实际情况中,往往是向已有系统或平台添加新的 API,此时要 Mock 所有已经存在的 API 就既没必要也需要很多工作量。因此,我们提供了一种简单的方式,即可以增加**代理**的方式把已有的 API 请求转发到实际的地址,只对新增的 API 进行 Mock 处理。如下所示: ```yaml diff --git a/pkg/mock/in_memory.go b/pkg/mock/in_memory.go index 7296aad2..804c2e18 100644 --- a/pkg/mock/in_memory.go +++ b/pkg/mock/in_memory.go @@ -16,697 +16,709 @@ limitations under the License. package mock import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "sort" - "strings" - "sync" - "time" - - jsonpatch "github.com/evanphx/json-patch" - "github.com/swaggest/openapi-go/openapi3" - "github.com/swaggest/rest/gorillamux" - - "github.com/linuxsuren/api-testing/pkg/version" - - "github.com/linuxsuren/api-testing/pkg/logging" - "github.com/linuxsuren/api-testing/pkg/render" - "github.com/linuxsuren/api-testing/pkg/util" - - "github.com/gorilla/mux" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "os" + "sort" + "strings" + "sync" + "time" + + jsonpatch "github.com/evanphx/json-patch" + "github.com/swaggest/openapi-go/openapi3" + "github.com/swaggest/rest/gorillamux" + + "github.com/linuxsuren/api-testing/pkg/version" + + "github.com/linuxsuren/api-testing/pkg/logging" + "github.com/linuxsuren/api-testing/pkg/render" + "github.com/linuxsuren/api-testing/pkg/util" + + "github.com/gorilla/mux" ) var ( - memLogger = logging.DefaultLogger(logging.LogLevelInfo).WithName("memory") + memLogger = logging.DefaultLogger(logging.LogLevelInfo).WithName("memory") ) type inMemoryServer struct { - data map[string][]map[string]interface{} - mux *mux.Router - listener net.Listener - certFile, keyFile string - port int - prefix string - wg sync.WaitGroup - ctx context.Context - cancelFunc context.CancelFunc - reader Reader - metrics RequestMetrics + data map[string][]map[string]interface{} + mux *mux.Router + listener net.Listener + certFile, keyFile string + port int + prefix string + wg sync.WaitGroup + ctx context.Context + cancelFunc context.CancelFunc + reader Reader + metrics RequestMetrics } func NewInMemoryServer(ctx context.Context, port int) DynamicServer { - ctx, cancel := context.WithCancel(ctx) - return &inMemoryServer{ - port: port, - wg: sync.WaitGroup{}, - ctx: ctx, - cancelFunc: cancel, - metrics: NewNoopMetrics(), - } + ctx, cancel := context.WithCancel(ctx) + return &inMemoryServer{ + port: port, + wg: sync.WaitGroup{}, + ctx: ctx, + cancelFunc: cancel, + metrics: NewNoopMetrics(), + } } func (s *inMemoryServer) SetupHandler(reader Reader, prefix string) (handler http.Handler, err error) { - s.reader = reader - // init the data - s.data = make(map[string][]map[string]interface{}) - s.mux = mux.NewRouter().PathPrefix(prefix).Subrouter() - s.prefix = prefix - handler = s.mux - s.metrics.AddMetricsHandler(s.mux) - err = s.Load() - return + s.reader = reader + // init the data + s.data = make(map[string][]map[string]interface{}) + s.mux = mux.NewRouter().PathPrefix(prefix).Subrouter() + s.prefix = prefix + handler = s.mux + s.metrics.AddMetricsHandler(s.mux) + err = s.Load() + return } func (s *inMemoryServer) WithTLS(certFile, keyFile string) DynamicServer { - s.certFile = certFile - s.keyFile = keyFile - return s + s.certFile = certFile + s.keyFile = keyFile + return s } func (s *inMemoryServer) WithLogWriter(writer io.Writer) DynamicServer { - if writer != nil { - memLogger = memLogger.WithNameAndWriter("stream", writer) - } - return s + if writer != nil { + memLogger = memLogger.WithNameAndWriter("stream", writer) + } + return s } func (s *inMemoryServer) GetTLS() (string, string) { - return s.certFile, s.keyFile + return s.certFile, s.keyFile } func (s *inMemoryServer) Load() (err error) { - var server *Server - if server, err = s.reader.Parse(); err != nil { - return - } - - memLogger.Info("start to run all the APIs from objects", "count", len(server.Objects)) - for _, obj := range server.Objects { - memLogger.Info("start mock server from object", "name", obj.Name) - s.startObject(obj) - s.initObjectData(obj) - } - - memLogger.Info("start to run all the APIs from items", "count", len(server.Items)) - for _, item := range server.Items { - s.startItem(item) - } - - memLogger.Info("start webhook servers", "count", len(server.Webhooks)) - for _, item := range server.Webhooks { - if err = s.startWebhook(&item); err != nil { - continue - } - } - - s.handleOpenAPI() - - for i, proxy := range server.Proxies { - memLogger.Info("start to proxy", "target", proxy.Target) - switch proxy.Protocol { - case "http", "": - s.httpProxy(&proxy) - case "tcp": - s.tcpProxy(&server.Proxies[i]) - default: - memLogger.Error(fmt.Errorf("unsupported protocol: %s", proxy.Protocol), "failed to start proxy") - } - } - return + var server *Server + if server, err = s.reader.Parse(); err != nil { + return + } + + memLogger.Info("start to run all the APIs from objects", "count", len(server.Objects)) + for _, obj := range server.Objects { + memLogger.Info("start mock server from object", "name", obj.Name) + s.startObject(obj) + s.initObjectData(obj) + } + + memLogger.Info("start to run all the APIs from items", "count", len(server.Items)) + for _, item := range server.Items { + s.startItem(item) + } + + memLogger.Info("start webhook servers", "count", len(server.Webhooks)) + for _, item := range server.Webhooks { + if err = s.startWebhook(&item); err != nil { + continue + } + } + + s.handleOpenAPI() + + for i, proxy := range server.Proxies { + memLogger.Info("start to proxy", "target", proxy.Target) + switch proxy.Protocol { + case "http", "": + s.httpProxy(&proxy) + case "tcp": + s.tcpProxy(&server.Proxies[i]) + default: + memLogger.Error(fmt.Errorf("unsupported protocol: %s", proxy.Protocol), "failed to start proxy") + } + } + return } func (s *inMemoryServer) httpProxy(proxy *Proxy) { - s.mux.HandleFunc(proxy.Path, func(w http.ResponseWriter, req *http.Request) { - if !strings.HasSuffix(proxy.Target, "/") { - proxy.Target += "/" - } - targetPath := strings.TrimPrefix(req.URL.Path, s.prefix) - targetPath = strings.TrimPrefix(targetPath, "/") - - apiRaw := fmt.Sprintf("%s%s", proxy.Target, targetPath) - api, err := render.Render("proxy api", apiRaw, s) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - memLogger.Error(err, "failed to render proxy api", "api", apiRaw) - return - } - memLogger.Info("redirect to", "target", api) - - var requestBody []byte - if requestBody, err = io.ReadAll(req.Body); err != nil { - w.WriteHeader(http.StatusInternalServerError) - } - - if proxy.RequestAmend.BodyPatch != "" && len(requestBody) > 0 { - var patch jsonpatch.Patch - if patch, err = jsonpatch.DecodePatch([]byte(proxy.RequestAmend.BodyPatch)); err != nil { - return - } - - fmt.Println("before patch:", string(requestBody)) - if requestBody, err = patch.Apply(requestBody); err != nil { - fmt.Println(err) - return - } - fmt.Println("after patch:", string(requestBody)) - } - - targetReq, err := http.NewRequestWithContext(req.Context(), req.Method, api, bytes.NewBuffer(requestBody)) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - memLogger.Error(err, "failed to create proxy request") - return - } - - for k, v := range req.Header { - targetReq.Header.Add(k, v[0]) - } - - resp, err := http.DefaultClient.Do(targetReq) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - memLogger.Error(err, "failed to do proxy request") - return - } - - data, err := io.ReadAll(resp.Body) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - memLogger.Error(err, "failed to read response body") - return - } - - for k, v := range resp.Header { - w.Header().Add(k, v[0]) - } - w.Write(data) - }) + s.mux.HandleFunc(proxy.Path, func(w http.ResponseWriter, req *http.Request) { + if !strings.HasSuffix(proxy.Target, "/") { + proxy.Target += "/" + } + targetPath := strings.TrimPrefix(req.URL.Path, s.prefix) + targetPath = strings.TrimPrefix(targetPath, "/") + + apiRaw := fmt.Sprintf("%s%s", proxy.Target, targetPath) + api, err := render.Render("proxy api", apiRaw, s) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + memLogger.Error(err, "failed to render proxy api", "api", apiRaw) + return + } + memLogger.Info("redirect to", "target", api) + + var requestBody []byte + if requestBody, err = io.ReadAll(req.Body); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + + if proxy.RequestAmend.BodyPatch != "" && len(requestBody) > 0 { + var patch jsonpatch.Patch + if patch, err = jsonpatch.DecodePatch([]byte(proxy.RequestAmend.BodyPatch)); err != nil { + return + } + + fmt.Println("before patch:", string(requestBody)) + if requestBody, err = patch.Apply(requestBody); err != nil { + fmt.Println(err) + return + } + fmt.Println("after patch:", string(requestBody)) + } + + targetReq, err := http.NewRequestWithContext(req.Context(), req.Method, api, bytes.NewBuffer(requestBody)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + memLogger.Error(err, "failed to create proxy request") + return + } + + for k, v := range req.Header { + targetReq.Header.Add(k, v[0]) + } + + resp, err := http.DefaultClient.Do(targetReq) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + memLogger.Error(err, "failed to do proxy request") + return + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + memLogger.Error(err, "failed to read response body") + return + } + + for k, v := range resp.Header { + w.Header().Add(k, v[0]) + } + w.Write(data) + }) } func (s *inMemoryServer) tcpProxy(proxy *Proxy) { - fmt.Println("start to proxy", proxy.Port) - lisener, err := net.Listen("tcp", fmt.Sprintf(":%d", proxy.Port)) - if err != nil { - memLogger.Error(err, "failed to listen") - return - } - fmt.Printf("proxy local: %d, target: %s\n", proxy.Port, proxy.Target) - defer lisener.Close() - - for { - conn, err := lisener.Accept() - if err != nil { - memLogger.Error(err, "failed to accept") - continue - } - - fmt.Println("accept connection") - go handleConnection(conn, proxy.Target) - } + fmt.Println("start to proxy", proxy.Port) + lisener, err := net.Listen("tcp", fmt.Sprintf(":%d", proxy.Port)) + if err != nil { + memLogger.Error(err, "failed to listen") + return + } + fmt.Printf("proxy local: %d, target: %s\n", proxy.Port, proxy.Target) + defer lisener.Close() + + for { + conn, err := lisener.Accept() + if err != nil { + memLogger.Error(err, "failed to accept") + continue + } + + fmt.Println("accept connection") + go handleConnection(conn, proxy.Target) + } } func handleConnection(clientConn net.Conn, targetAddr string) { - defer clientConn.Close() + defer clientConn.Close() - targetConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second) - if err != nil { - fmt.Printf("Failed to connect to target server: %v\n", err) - return - } - defer targetConn.Close() + targetConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second) + if err != nil { + fmt.Printf("Failed to connect to target server: %v\n", err) + return + } + defer targetConn.Close() - fmt.Printf("Connection established between %s and %s\n", clientConn.RemoteAddr(), targetConn.RemoteAddr()) + fmt.Printf("Connection established between %s and %s\n", clientConn.RemoteAddr(), targetConn.RemoteAddr()) - go io.Copy(clientConn, targetConn) - go io.Copy(targetConn, clientConn) + go io.Copy(clientConn, targetConn) + go io.Copy(targetConn, clientConn) - select {} + select {} } func (s *inMemoryServer) Start(reader Reader, prefix string) (err error) { - var handler http.Handler - if handler, err = s.SetupHandler(reader, prefix); err == nil { - if s.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", s.port)); err == nil { - go func() { - if s.certFile != "" && s.keyFile != "" { - if err = http.ServeTLS(s.listener, handler, s.certFile, s.keyFile); err != nil { - memLogger.Error(err, "failed to start TLS mock server") - } - } else { - memLogger.Info("start HTTP mock server") - err = http.Serve(s.listener, handler) - } - }() - } - } - return + var handler http.Handler + if handler, err = s.SetupHandler(reader, prefix); err == nil { + if s.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", s.port)); err == nil { + go func() { + if s.certFile != "" && s.keyFile != "" { + if err = http.ServeTLS(s.listener, handler, s.certFile, s.keyFile); err != nil { + memLogger.Error(err, "failed to start TLS mock server") + } + } else { + memLogger.Info("start HTTP mock server") + err = http.Serve(s.listener, handler) + } + }() + } + } + return } func (s *inMemoryServer) EnableMetrics() { - s.metrics = NewInMemoryMetrics() + s.metrics = NewInMemoryMetrics() } func (s *inMemoryServer) startObject(obj Object) { - // create a simple CRUD server - s.mux.HandleFunc("/"+obj.Name, func(w http.ResponseWriter, req *http.Request) { - memLogger.Info("mock server received request", "path", req.URL.Path) - s.metrics.RecordRequest(req.URL.Path) - method := req.Method - w.Header().Set(util.ContentType, util.JSON) - - switch method { - case http.MethodGet: - // list all items - allItems := s.data[obj.Name] - filteredItems := make([]map[string]interface{}, 0) - - for i, item := range allItems { - exclude := false - - for k, v := range req.URL.Query() { - if len(v) == 0 { - continue - } - - if val, ok := item[k]; ok && val != v[0] { - exclude = true - break - } - } - - if !exclude { - filteredItems = append(filteredItems, allItems[i]) - } - } - - if len(filteredItems) != len(allItems) { - allItems = filteredItems - } - - data, err := json.Marshal(allItems) - writeResponse(w, data, err) - case http.MethodPost: - // create an item - if data, err := io.ReadAll(req.Body); err == nil { - objData := map[string]interface{}{} - - jsonErr := json.Unmarshal(data, &objData) - if jsonErr != nil { - memLogger.Info(jsonErr.Error()) - return - } - - s.data[obj.Name] = append(s.data[obj.Name], objData) - - _, _ = w.Write(data) - } else { - memLogger.Info("failed to read from body", "error", err) - } - default: - w.WriteHeader(http.StatusMethodNotAllowed) - } - }) - - // handle a single object - s.mux.HandleFunc(fmt.Sprintf("/%s/{name}", obj.Name), func(w http.ResponseWriter, req *http.Request) { - s.metrics.RecordRequest(req.URL.Path) - w.Header().Set(util.ContentType, util.JSON) - objects := s.data[obj.Name] - if objects != nil { - name := mux.Vars(req)["name"] - var data []byte - for _, obj := range objects { - if obj["name"] == name { - - data, _ = json.Marshal(obj) - break - } - } - - if len(data) == 0 { - w.WriteHeader(http.StatusNotFound) - return - } - - method := req.Method - switch method { - case http.MethodGet: - writeResponse(w, data, nil) - case http.MethodPut: - objData := map[string]interface{}{} - if data, err := io.ReadAll(req.Body); err == nil { - - jsonErr := json.Unmarshal(data, &objData) - if jsonErr != nil { - memLogger.Info(jsonErr.Error()) - return - } - for i, item := range s.data[obj.Name] { - if item["name"] == name { - s.data[obj.Name][i] = objData - break - } - } - _, _ = w.Write(data) - } - case http.MethodDelete: - for i, item := range s.data[obj.Name] { - if item["name"] == name { - if len(s.data[obj.Name]) == i+1 { - s.data[obj.Name] = s.data[obj.Name][:i] - } else { - s.data[obj.Name] = append(s.data[obj.Name][:i], s.data[obj.Name][i+1]) - } - - writeResponse(w, []byte(`{"msg": "deleted"}`), nil) - } - } - default: - w.WriteHeader(http.StatusMethodNotAllowed) - } - - } - }) + // create a simple CRUD server + s.mux.HandleFunc("/"+obj.Name, func(w http.ResponseWriter, req *http.Request) { + memLogger.Info("mock server received request", "path", req.URL.Path) + s.metrics.RecordRequest(req.URL.Path) + method := req.Method + w.Header().Set(util.ContentType, util.JSON) + + switch method { + case http.MethodGet: + // list all items + allItems := s.data[obj.Name] + filteredItems := make([]map[string]interface{}, 0) + + for i, item := range allItems { + exclude := false + + for k, v := range req.URL.Query() { + if len(v) == 0 { + continue + } + + if val, ok := item[k]; ok && val != v[0] { + exclude = true + break + } + } + + if !exclude { + filteredItems = append(filteredItems, allItems[i]) + } + } + + if len(filteredItems) != len(allItems) { + allItems = filteredItems + } + + data, err := json.Marshal(allItems) + writeResponse(w, data, err) + case http.MethodPost: + // create an item + if data, err := io.ReadAll(req.Body); err == nil { + objData := map[string]interface{}{} + + jsonErr := json.Unmarshal(data, &objData) + if jsonErr != nil { + memLogger.Info(jsonErr.Error()) + return + } + + s.data[obj.Name] = append(s.data[obj.Name], objData) + + _, _ = w.Write(data) + } else { + memLogger.Info("failed to read from body", "error", err) + } + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } + }) + + // handle a single object + s.mux.HandleFunc(fmt.Sprintf("/%s/{name}", obj.Name), func(w http.ResponseWriter, req *http.Request) { + s.metrics.RecordRequest(req.URL.Path) + w.Header().Set(util.ContentType, util.JSON) + objects := s.data[obj.Name] + if objects != nil { + name := mux.Vars(req)["name"] + var data []byte + for _, obj := range objects { + if obj["name"] == name { + + data, _ = json.Marshal(obj) + break + } + } + + if len(data) == 0 { + w.WriteHeader(http.StatusNotFound) + return + } + + method := req.Method + switch method { + case http.MethodGet: + writeResponse(w, data, nil) + case http.MethodPut: + objData := map[string]interface{}{} + if data, err := io.ReadAll(req.Body); err == nil { + + jsonErr := json.Unmarshal(data, &objData) + if jsonErr != nil { + memLogger.Info(jsonErr.Error()) + return + } + for i, item := range s.data[obj.Name] { + if item["name"] == name { + s.data[obj.Name][i] = objData + break + } + } + _, _ = w.Write(data) + } + case http.MethodDelete: + for i, item := range s.data[obj.Name] { + if item["name"] == name { + if len(s.data[obj.Name]) == i+1 { + s.data[obj.Name] = s.data[obj.Name][:i] + } else { + s.data[obj.Name] = append(s.data[obj.Name][:i], s.data[obj.Name][i+1]) + } + + writeResponse(w, []byte(`{"msg": "deleted"}`), nil) + } + } + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } + + } + }) } func (s *inMemoryServer) startItem(item Item) { - method := util.EmptyThenDefault(item.Request.Method, http.MethodGet) - memLogger.Info("register mock service", "method", method, "path", item.Request.Path, "encoder", item.Response.Encoder) - - var headerSlices []string - for k, v := range item.Request.Header { - headerSlices = append(headerSlices, k, v) - } - - adHandler := &advanceHandler{ - item: &item, - metrics: s.metrics, - mu: sync.Mutex{}, - } - existedRoute := s.mux.GetRoute(item.Name) - if existedRoute == nil { - s.mux.NewRoute().Name(item.Name).Methods(strings.Split(method, ",")...).Headers(headerSlices...).Path(item.Request.Path).HandlerFunc(adHandler.handle) - } else { - existedRoute.HandlerFunc(adHandler.handle) - } + method := util.EmptyThenDefault(item.Request.Method, http.MethodGet) + memLogger.Info("register mock service", "method", method, "path", item.Request.Path, "encoder", item.Response.Encoder) + + var headerSlices []string + for k, v := range item.Request.Header { + headerSlices = append(headerSlices, k, v) + } + + adHandler := &advanceHandler{ + item: &item, + metrics: s.metrics, + mu: sync.Mutex{}, + } + existedRoute := s.mux.GetRoute(item.Name) + if existedRoute == nil { + s.mux.NewRoute().Name(item.Name).Methods(strings.Split(method, ",")...).Headers(headerSlices...).Path(item.Request.Path).HandlerFunc(adHandler.handle) + } else { + existedRoute.HandlerFunc(adHandler.handle) + } } type advanceHandler struct { - item *Item - metrics RequestMetrics - mu sync.Mutex + item *Item + metrics RequestMetrics + mu sync.Mutex } func (h *advanceHandler) handle(w http.ResponseWriter, req *http.Request) { - h.mu.Lock() - defer h.mu.Unlock() - - h.metrics.RecordRequest(req.URL.Path) - memLogger.Info("receiving mock request", "name", h.item.Name, "method", req.Method, "path", req.URL.Path, - "encoder", h.item.Response.Encoder) - - h.item.Param = mux.Vars(req) - if h.item.Param == nil { - h.item.Param = make(map[string]string) - } - h.item.Param["Host"] = req.Host - if h.item.Response.Header == nil { - h.item.Response.Header = make(map[string]string) - } - h.item.Response.Header[headerMockServer] = fmt.Sprintf("api-testing: %s", version.GetVersion()) - for k, v := range h.item.Response.Header { - hv, hErr := render.Render("mock-server-header", v, &h.item) - if hErr != nil { - hv = v - memLogger.Error(hErr, "failed render mock-server-header", "value", v) - } - - w.Header().Set(k, hv) - } - - var err error - if h.item.Response.Encoder == "base64" { - h.item.Response.BodyData, err = base64.StdEncoding.DecodeString(h.item.Response.Body) - } else if h.item.Response.Encoder == "url" { - var resp *http.Response - if resp, err = http.Get(h.item.Response.Body); err == nil { - h.item.Response.BodyData, err = io.ReadAll(resp.Body) - } - } else { - if h.item.Response.BodyData, err = render.RenderAsBytes("start-item", h.item.Response.Body, h.item); err != nil { - fmt.Printf("failed to render body: %v", err) - } - } - - if err == nil { - h.item.Response.Header[util.ContentLength] = fmt.Sprintf("%d", len(h.item.Response.BodyData)) - w.Header().Set(util.ContentLength, h.item.Response.Header[util.ContentLength]) - } - - writeResponse(w, h.item.Response.BodyData, err) + h.mu.Lock() + defer h.mu.Unlock() + + h.metrics.RecordRequest(req.URL.Path) + memLogger.Info("receiving mock request", "name", h.item.Name, "method", req.Method, "path", req.URL.Path, + "encoder", h.item.Response.Encoder) + + h.item.Param = mux.Vars(req) + if h.item.Param == nil { + h.item.Param = make(map[string]string) + } + h.item.Param["Host"] = req.Host + if h.item.Response.Header == nil { + h.item.Response.Header = make(map[string]string) + } + h.item.Response.Header[headerMockServer] = fmt.Sprintf("api-testing: %s", version.GetVersion()) + for k, v := range h.item.Response.Header { + hv, hErr := render.Render("mock-server-header", v, &h.item) + if hErr != nil { + hv = v + memLogger.Error(hErr, "failed render mock-server-header", "value", v) + } + + w.Header().Set(k, hv) + } + + if h.item.Response.BodyFromFile != "" { + // read from file + if data, readErr := os.ReadFile(h.item.Response.BodyFromFile); readErr != nil { + memLogger.Error(readErr, "failed to read file", "file", h.item.Response.BodyFromFile) + } else { + h.item.Response.Body = string(data) + } + } + + var err error + if h.item.Response.Encoder == "base64" { + h.item.Response.BodyData, err = base64.StdEncoding.DecodeString(h.item.Response.Body) + } else if h.item.Response.Encoder == "url" { + var resp *http.Response + if resp, err = http.Get(h.item.Response.Body); err == nil { + h.item.Response.BodyData, err = io.ReadAll(resp.Body) + } + } else if h.item.Response.Encoder == "raw" { + h.item.Response.BodyData = []byte(h.item.Response.Body) + } else { + if h.item.Response.BodyData, err = render.RenderAsBytes("start-item", h.item.Response.Body, h.item); err != nil { + memLogger.Error(err, "failed to render body") + } + } + + if err == nil { + h.item.Response.Header[util.ContentLength] = fmt.Sprintf("%d", len(h.item.Response.BodyData)) + w.Header().Set(util.ContentLength, h.item.Response.Header[util.ContentLength]) + } + + writeResponse(w, h.item.Response.BodyData, err) } func writeResponse(w http.ResponseWriter, data []byte, err error) { - if err == nil { - w.Write(data) - } else { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(err.Error())) - } + if err == nil { + w.Write(data) + } else { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + } } func (s *inMemoryServer) initObjectData(obj Object) { - if obj.Sample == "" { - return - } - - defaultCount := 1 - if obj.InitCount == nil { - obj.InitCount = &defaultCount - } - - for i := 0; i < *obj.InitCount; i++ { - objData, jsonErr := jsonStrToInterface(obj.Sample) - if jsonErr == nil { - s.data[obj.Name] = append(s.data[obj.Name], objData) - } else { - memLogger.Info(jsonErr.Error()) - } - } + if obj.Sample == "" { + return + } + + defaultCount := 1 + if obj.InitCount == nil { + obj.InitCount = &defaultCount + } + + for i := 0; i < *obj.InitCount; i++ { + objData, jsonErr := jsonStrToInterface(obj.Sample) + if jsonErr == nil { + s.data[obj.Name] = append(s.data[obj.Name], objData) + } else { + memLogger.Info(jsonErr.Error()) + } + } } func (s *inMemoryServer) startWebhook(webhook *Webhook) (err error) { - if webhook.Timer == "" || webhook.Name == "" { - return - } - - var duration time.Duration - duration, err = time.ParseDuration(webhook.Timer) - if err != nil { - memLogger.Error(err, "Error parsing webhook timer") - return - } - - s.wg.Add(1) - go func(wh *Webhook) { - defer s.wg.Done() - - memLogger.Info("start webhook server", "name", wh.Name) - timer := time.NewTimer(duration) - for { - timer.Reset(duration) - select { - case <-s.ctx.Done(): - memLogger.Info("stop webhook server", "name", wh.Name) - return - case <-timer.C: - if err = runWebhook(s.ctx, s, wh); err != nil { - memLogger.Error(err, "Error when run webhook") - } - } - } - }(webhook) - return + if webhook.Timer == "" || webhook.Name == "" { + return + } + + var duration time.Duration + duration, err = time.ParseDuration(webhook.Timer) + if err != nil { + memLogger.Error(err, "Error parsing webhook timer") + return + } + + s.wg.Add(1) + go func(wh *Webhook) { + defer s.wg.Done() + + memLogger.Info("start webhook server", "name", wh.Name) + timer := time.NewTimer(duration) + for { + timer.Reset(duration) + select { + case <-s.ctx.Done(): + memLogger.Info("stop webhook server", "name", wh.Name) + return + case <-timer.C: + if err = runWebhook(s.ctx, s, wh); err != nil { + memLogger.Error(err, "Error when run webhook") + } + } + } + }(webhook) + return } func runWebhook(ctx context.Context, objCtx interface{}, wh *Webhook) (err error) { - rawParams := make(map[string]string, len(wh.Param)) - paramKeys := make([]string, 0, len(wh.Param)) - for k, v := range wh.Param { - paramKeys = append(paramKeys, k) - rawParams[k] = v - } - sort.Strings(paramKeys) - - for _, k := range paramKeys { - v, vErr := render.Render("mock webhook server param", wh.Param[k], wh) - if vErr == nil { - wh.Param[k] = v - } - } - - var payload io.Reader - payload, err = render.RenderAsReader("mock webhook server payload", wh.Request.Body, wh) - if err != nil { - err = fmt.Errorf("error when render payload: %w", err) - return - } - wh.Param = rawParams - - var api string - api, err = render.Render("webhook request api", wh.Request.Path, objCtx) - if err != nil { - err = fmt.Errorf("error when render api: %w, template: %s", err, wh.Request.Path) - return - } - - switch wh.Request.Protocol { - case "syslog": - err = sendSyslogWebhookRequest(ctx, wh, api, payload) - default: - err = sendHTTPWebhookRequest(ctx, wh, api, payload) - } - return + rawParams := make(map[string]string, len(wh.Param)) + paramKeys := make([]string, 0, len(wh.Param)) + for k, v := range wh.Param { + paramKeys = append(paramKeys, k) + rawParams[k] = v + } + sort.Strings(paramKeys) + + for _, k := range paramKeys { + v, vErr := render.Render("mock webhook server param", wh.Param[k], wh) + if vErr == nil { + wh.Param[k] = v + } + } + + var payload io.Reader + payload, err = render.RenderAsReader("mock webhook server payload", wh.Request.Body, wh) + if err != nil { + err = fmt.Errorf("error when render payload: %w", err) + return + } + wh.Param = rawParams + + var api string + api, err = render.Render("webhook request api", wh.Request.Path, objCtx) + if err != nil { + err = fmt.Errorf("error when render api: %w, template: %s", err, wh.Request.Path) + return + } + + switch wh.Request.Protocol { + case "syslog": + err = sendSyslogWebhookRequest(ctx, wh, api, payload) + default: + err = sendHTTPWebhookRequest(ctx, wh, api, payload) + } + return } func sendSyslogWebhookRequest(ctx context.Context, wh *Webhook, api string, payload io.Reader) (err error) { - var conn net.Conn - if conn, err = net.Dial("udp", api); err == nil { - _, err = io.Copy(conn, payload) - } - return + var conn net.Conn + if conn, err = net.Dial("udp", api); err == nil { + _, err = io.Copy(conn, payload) + } + return } func sendHTTPWebhookRequest(ctx context.Context, wh *Webhook, api string, payload io.Reader) (err error) { - method := util.EmptyThenDefault(wh.Request.Method, http.MethodPost) - client := http.DefaultClient - - var bearerToken string - bearerToken, err = getBearerToken(ctx, wh.Request) - if err != nil { - memLogger.Error(err, "Error when render bearer token") - return - } - - var req *http.Request - req, err = http.NewRequestWithContext(ctx, method, api, payload) - if err != nil { - memLogger.Error(err, "Error when create request") - return - } - - if bearerToken != "" { - memLogger.V(7).Info("set bearer token", "token", bearerToken) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken)) - } - - for k, v := range wh.Request.Header { - req.Header.Set(k, v) - } - - memLogger.Info("send webhook request", "api", api) - resp, err := client.Do(req) - if err != nil { - err = fmt.Errorf("error when sending webhook: %v", err) - } else { - if resp.StatusCode != http.StatusOK { - memLogger.Info("unexpected status", "code", resp.StatusCode) - } - - data, _ := io.ReadAll(resp.Body) - memLogger.V(7).Info("received from webhook", "code", resp.StatusCode, "response", string(data)) - } - return + method := util.EmptyThenDefault(wh.Request.Method, http.MethodPost) + client := http.DefaultClient + + var bearerToken string + bearerToken, err = getBearerToken(ctx, wh.Request) + if err != nil { + memLogger.Error(err, "Error when render bearer token") + return + } + + var req *http.Request + req, err = http.NewRequestWithContext(ctx, method, api, payload) + if err != nil { + memLogger.Error(err, "Error when create request") + return + } + + if bearerToken != "" { + memLogger.V(7).Info("set bearer token", "token", bearerToken) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken)) + } + + for k, v := range wh.Request.Header { + req.Header.Set(k, v) + } + + memLogger.Info("send webhook request", "api", api) + resp, err := client.Do(req) + if err != nil { + err = fmt.Errorf("error when sending webhook: %v", err) + } else { + if resp.StatusCode != http.StatusOK { + memLogger.Info("unexpected status", "code", resp.StatusCode) + } + + data, _ := io.ReadAll(resp.Body) + memLogger.V(7).Info("received from webhook", "code", resp.StatusCode, "response", string(data)) + } + return } type bearerToken struct { - Token string `json:"token"` + Token string `json:"token"` } func getBearerToken(ctx context.Context, request RequestWithAuth) (token string, err error) { - if request.BearerAPI == "" { - return - } - - if request.BearerAPI, err = render.Render("bearer token request", request.BearerAPI, &request); err != nil { - return - } - - var data []byte - if data, err = json.Marshal(&request); err == nil { - client := http.DefaultClient - var req *http.Request - if req, err = http.NewRequestWithContext(ctx, http.MethodPost, request.BearerAPI, bytes.NewBuffer(data)); err == nil { - req.Header.Set(util.ContentType, util.JSON) - - var resp *http.Response - if resp, err = client.Do(req); err == nil && resp.StatusCode == http.StatusOK { - if data, err = io.ReadAll(resp.Body); err == nil { - var tokenObj bearerToken - if err = json.Unmarshal(data, &tokenObj); err == nil { - token = tokenObj.Token - } - } - } - } - } - - return + if request.BearerAPI == "" { + return + } + + if request.BearerAPI, err = render.Render("bearer token request", request.BearerAPI, &request); err != nil { + return + } + + var data []byte + if data, err = json.Marshal(&request); err == nil { + client := http.DefaultClient + var req *http.Request + if req, err = http.NewRequestWithContext(ctx, http.MethodPost, request.BearerAPI, bytes.NewBuffer(data)); err == nil { + req.Header.Set(util.ContentType, util.JSON) + + var resp *http.Response + if resp, err = client.Do(req); err == nil && resp.StatusCode == http.StatusOK { + if data, err = io.ReadAll(resp.Body); err == nil { + var tokenObj bearerToken + if err = json.Unmarshal(data, &tokenObj); err == nil { + token = tokenObj.Token + } + } + } + } + } + + return } func (s *inMemoryServer) handleOpenAPI() { - s.mux.HandleFunc("/api.json", func(w http.ResponseWriter, req *http.Request) { - // Setup OpenAPI schema - reflector := openapi3.NewReflector() - reflector.SpecSchema().SetTitle("Mock Server API") - reflector.SpecSchema().SetVersion(version.GetVersion()) - reflector.SpecSchema().SetDescription("Powered by https://github.com/linuxsuren/api-testing") - - // Walk the router with OpenAPI collector - c := gorillamux.NewOpenAPICollector(reflector) - - _ = s.mux.Walk(c.Walker) - - // Get the resulting schema - if jsonData, err := reflector.Spec.MarshalJSON(); err == nil { - w.Header().Set(util.ContentType, util.JSON) - _, _ = w.Write(jsonData) - } else { - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte(err.Error())) - } - }) + s.mux.HandleFunc("/api.json", func(w http.ResponseWriter, req *http.Request) { + // Setup OpenAPI schema + reflector := openapi3.NewReflector() + reflector.SpecSchema().SetTitle("Mock Server API") + reflector.SpecSchema().SetVersion(version.GetVersion()) + reflector.SpecSchema().SetDescription("Powered by https://github.com/linuxsuren/api-testing") + + // Walk the router with OpenAPI collector + c := gorillamux.NewOpenAPICollector(reflector) + + _ = s.mux.Walk(c.Walker) + + // Get the resulting schema + if jsonData, err := reflector.Spec.MarshalJSON(); err == nil { + w.Header().Set(util.ContentType, util.JSON) + _, _ = w.Write(jsonData) + } else { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) + } + }) } func jsonStrToInterface(jsonStr string) (objData map[string]interface{}, err error) { - if jsonStr, err = render.Render("init object", jsonStr, nil); err == nil { - objData = map[string]interface{}{} - err = json.Unmarshal([]byte(jsonStr), &objData) - } - return + if jsonStr, err = render.Render("init object", jsonStr, nil); err == nil { + objData = map[string]interface{}{} + err = json.Unmarshal([]byte(jsonStr), &objData) + } + return } func (s *inMemoryServer) GetPort() string { - return util.GetPort(s.listener) + return util.GetPort(s.listener) } func (s *inMemoryServer) Stop() (err error) { - if s.listener != nil { - if err = s.listener.Close(); err != nil { - memLogger.Error(err, "failed to close listener") - } - } else { - memLogger.Info("listener is nil") - } - if s.cancelFunc != nil { - s.cancelFunc() - } - s.wg.Wait() - return + if s.listener != nil { + if err = s.listener.Close(); err != nil { + memLogger.Error(err, "failed to close listener") + } + } else { + memLogger.Info("listener is nil") + } + if s.cancelFunc != nil { + s.cancelFunc() + } + s.wg.Wait() + return } diff --git a/pkg/mock/in_memory_test.go b/pkg/mock/in_memory_test.go index f6977859..74a5dc68 100644 --- a/pkg/mock/in_memory_test.go +++ b/pkg/mock/in_memory_test.go @@ -16,227 +16,238 @@ limitations under the License. package mock import ( - "bytes" - "context" - "fmt" - "io" - "net/http" - "strings" - "testing" - - "github.com/linuxsuren/api-testing/pkg/util" - "github.com/stretchr/testify/assert" + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" + + _ "embed" + "github.com/linuxsuren/api-testing/pkg/util" + "github.com/stretchr/testify/assert" ) +//go:embed testdata/api.yaml +var mockFile []byte + func TestInMemoryServer(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - server.EnableMetrics() + server := NewInMemoryServer(context.Background(), 0) + server.EnableMetrics() - err := server.Start(NewLocalFileReader("testdata/api.yaml"), "/mock") - assert.NoError(t, err) - defer func() { - server.Stop() - }() + err := server.Start(NewLocalFileReader("testdata/api.yaml"), "/mock") + assert.NoError(t, err) + defer func() { + server.Stop() + }() - api := "http://localhost:" + server.GetPort() + "/mock" + api := "http://localhost:" + server.GetPort() + "/mock" - _, err = http.Post(api+"/team", "", bytes.NewBufferString(`{ + _, err = http.Post(api+"/team", "", bytes.NewBufferString(`{ "name": "test", "members": [] }`)) - assert.NoError(t, err) - - var resp *http.Response - resp, err = http.Get(api + "/team") - if assert.NoError(t, err) { - data, err := io.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, `[{"name":"someone"},{"members":[],"name":"test"}]`, string(data)) - } - - t.Run("check the /api.json", func(t *testing.T) { - var resp *http.Response - resp, err = http.Get(api + "/api.json") - if assert.NoError(t, err) { - data, err := io.ReadAll(resp.Body) - assert.NoError(t, err) - assert.NotEmpty(t, string(data)) - } - }) - - t.Run("list with filter", func(t *testing.T) { - var resp *http.Response - resp, err = http.Get(api + "/team?name=someone") - if assert.NoError(t, err) { - data, err := io.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, `[{"name":"someone"}]`, string(data)) - } - }) - - t.Run("update object", func(t *testing.T) { - updateReq, err := http.NewRequest(http.MethodPut, api+"/team/test", bytes.NewBufferString(`{ + assert.NoError(t, err) + + var resp *http.Response + resp, err = http.Get(api + "/team") + if assert.NoError(t, err) { + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, `[{"name":"someone"},{"members":[],"name":"test"}]`, string(data)) + } + + t.Run("check the /api.json", func(t *testing.T) { + var resp *http.Response + resp, err = http.Get(api + "/api.json") + if assert.NoError(t, err) { + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.NotEmpty(t, string(data)) + } + }) + + t.Run("list with filter", func(t *testing.T) { + var resp *http.Response + resp, err = http.Get(api + "/team?name=someone") + if assert.NoError(t, err) { + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, `[{"name":"someone"}]`, string(data)) + } + }) + + t.Run("update object", func(t *testing.T) { + updateReq, err := http.NewRequest(http.MethodPut, api+"/team/test", bytes.NewBufferString(`{ "name": "test", "members": [{ "name": "rick" }] }`)) - assert.NoError(t, err) - resp, err = http.DefaultClient.Do(updateReq) - assert.NoError(t, err) - }) - - t.Run("get a single object", func(t *testing.T) { - resp, err = http.Get(api + "/team/test") - assert.NoError(t, err) - - var data []byte - data, err = io.ReadAll(resp.Body) - assert.NoError(t, err) - - assert.Equal(t, `{"members":[{"name":"rick"}],"name":"test"}`, string(data)) - }) - - // delete object - delReq, err := http.NewRequest(http.MethodDelete, api+"/team/test", nil) - assert.NoError(t, err) - resp, err = http.DefaultClient.Do(delReq) - assert.NoError(t, err) - - t.Run("check if deleted", func(t *testing.T) { - var resp *http.Response - resp, err = http.Get(api + "/team") - if assert.NoError(t, err) { - data, err := io.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, `[{"name":"someone"}]`, string(data)) - } - - resp, err = http.Get(api + "/team/test") - if assert.NoError(t, err) { - assert.Equal(t, http.StatusNotFound, resp.StatusCode) - } - }) - - t.Run("invalid request method", func(t *testing.T) { - delReq, err := http.NewRequest("fake", api+"/team", nil) - assert.NoError(t, err) - resp, err = http.DefaultClient.Do(delReq) - assert.NoError(t, err) - assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) - }) - - t.Run("only accept GET method in getting a single object", func(t *testing.T) { - wrongMethodReq, err := http.NewRequest(http.MethodPut, api+"/team", nil) - assert.NoError(t, err) - resp, err = http.DefaultClient.Do(wrongMethodReq) - assert.NoError(t, err) - assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) - }) - - t.Run("mock item", func(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, api+"/v1/repos/test/prs", nil) - assert.NoError(t, err) - req.Header.Set("name", "rick") - - resp, err = http.DefaultClient.Do(req) - assert.NoError(t, err) - - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, "194", resp.Header.Get(util.ContentLength)) - assert.Equal(t, "mock", resp.Header.Get("Server")) - assert.NotEmpty(t, resp.Header.Get(headerMockServer)) - - data, _ := io.ReadAll(resp.Body) - assert.True(t, strings.Contains(string(data), `"message": "mock"`), string(data)) - }) - - t.Run("miss match header", func(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, api+"/v1/repos/test/prs", nil) - assert.NoError(t, err) - - resp, err = http.DefaultClient.Do(req) - assert.NoError(t, err) - assert.Equal(t, http.StatusNotFound, resp.StatusCode) - }) - - t.Run("base64 encoder", func(t *testing.T) { - resp, err = http.Get(api + "/v1/base64") - assert.NoError(t, err) - data, _ := io.ReadAll(resp.Body) - assert.Equal(t, "hello", string(data)) - }) - - t.Run("not found config file", func(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - err := server.Start(NewLocalFileReader("fake"), "/") - assert.Error(t, err) - }) - - t.Run("invalid webhook", func(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - err := server.Start(NewInMemoryReader(`webhooks: + assert.NoError(t, err) + resp, err = http.DefaultClient.Do(updateReq) + assert.NoError(t, err) + }) + + t.Run("get a single object", func(t *testing.T) { + resp, err = http.Get(api + "/team/test") + assert.NoError(t, err) + + var data []byte + data, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + + assert.Equal(t, `{"members":[{"name":"rick"}],"name":"test"}`, string(data)) + }) + + // delete object + delReq, err := http.NewRequest(http.MethodDelete, api+"/team/test", nil) + assert.NoError(t, err) + resp, err = http.DefaultClient.Do(delReq) + assert.NoError(t, err) + + t.Run("check if deleted", func(t *testing.T) { + var resp *http.Response + resp, err = http.Get(api + "/team") + if assert.NoError(t, err) { + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, `[{"name":"someone"}]`, string(data)) + } + + resp, err = http.Get(api + "/team/test") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + } + }) + + t.Run("invalid request method", func(t *testing.T) { + delReq, err := http.NewRequest("fake", api+"/team", nil) + assert.NoError(t, err) + resp, err = http.DefaultClient.Do(delReq) + assert.NoError(t, err) + assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) + }) + + t.Run("only accept GET method in getting a single object", func(t *testing.T) { + wrongMethodReq, err := http.NewRequest(http.MethodPut, api+"/team", nil) + assert.NoError(t, err) + resp, err = http.DefaultClient.Do(wrongMethodReq) + assert.NoError(t, err) + assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) + }) + + t.Run("mock item", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, api+"/v1/repos/test/prs", nil) + assert.NoError(t, err) + req.Header.Set("name", "rick") + + resp, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "194", resp.Header.Get(util.ContentLength)) + assert.Equal(t, "mock", resp.Header.Get("Server")) + assert.NotEmpty(t, resp.Header.Get(headerMockServer)) + + data, _ := io.ReadAll(resp.Body) + assert.True(t, strings.Contains(string(data), `"message": "mock"`), string(data)) + }) + + t.Run("miss match header", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, api+"/v1/repos/test/prs", nil) + assert.NoError(t, err) + + resp, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) + + t.Run("base64 encoder", func(t *testing.T) { + resp, err = http.Get(api + "/v1/base64") + assert.NoError(t, err) + data, _ := io.ReadAll(resp.Body) + assert.Equal(t, "hello", string(data)) + }) + + t.Run("read response from file", func(t *testing.T) { + resp, err = http.Get(api + "/v1/readResponseFromFile") + assert.NoError(t, err) + data, _ := io.ReadAll(resp.Body) + assert.Equal(t, mockFile, data) + }) + + t.Run("not found config file", func(t *testing.T) { + server := NewInMemoryServer(context.Background(), 0) + err := server.Start(NewLocalFileReader("fake"), "/") + assert.Error(t, err) + }) + + t.Run("invalid webhook", func(t *testing.T) { + server := NewInMemoryServer(context.Background(), 0) + err := server.Start(NewInMemoryReader(`webhooks: - timer: aa name: fake`), "/") - assert.Error(t, err) - }) + assert.Error(t, err) + }) - t.Run("missing name or timer in webhook", func(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - err := server.Start(NewInMemoryReader(`webhooks: + t.Run("missing name or timer in webhook", func(t *testing.T) { + server := NewInMemoryServer(context.Background(), 0) + err := server.Start(NewInMemoryReader(`webhooks: - timer: 1s`), "/") - assert.Error(t, err) - }) + assert.Error(t, err) + }) - t.Run("invalid webhook payload", func(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - err := server.Start(NewInMemoryReader(`webhooks: + t.Run("invalid webhook payload", func(t *testing.T) { + server := NewInMemoryServer(context.Background(), 0) + err := server.Start(NewInMemoryReader(`webhooks: - name: invalid timer: 1ms request: body: "{{.fake"`), "/") - assert.Error(t, err) - }) + assert.Error(t, err) + }) - t.Run("invalid webhook api template", func(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - err := server.Start(NewInMemoryReader(`webhooks: + t.Run("invalid webhook api template", func(t *testing.T) { + server := NewInMemoryServer(context.Background(), 0) + err := server.Start(NewInMemoryReader(`webhooks: - name: invalid timer: 1ms request: body: "{}" path: "{{.fake"`), "/") - assert.NoError(t, err) - }) - - t.Run("proxy", func(t *testing.T) { - resp, err = http.Get(api + "/v1/myProjects") - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - resp, err = http.Get(api + "/v1/invalid-template") - assert.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) - }) - - t.Run("metrics", func(t *testing.T) { - resp, err = http.Get(api + "/metrics") - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("go template support in response body", func(t *testing.T) { - repoName := "myRepo" - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/v1/repos/%s/prs", api, repoName), nil) - assert.NoError(t, err) - - var resp *http.Response - req.Header.Set("name", "rick") - resp, err = http.DefaultClient.Do(req) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - data, _ := io.ReadAll(resp.Body) - assert.Contains(t, string(data), repoName) - }) + assert.NoError(t, err) + }) + + t.Run("proxy", func(t *testing.T) { + resp, err = http.Get(api + "/v1/myProjects") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + resp, err = http.Get(api + "/v1/invalid-template") + assert.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + }) + + t.Run("metrics", func(t *testing.T) { + resp, err = http.Get(api + "/metrics") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("go template support in response body", func(t *testing.T) { + repoName := "myRepo" + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/v1/repos/%s/prs", api, repoName), nil) + assert.NoError(t, err) + + var resp *http.Response + req.Header.Set("name", "rick") + resp, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + data, _ := io.ReadAll(resp.Body) + assert.Contains(t, string(data), repoName) + }) } diff --git a/pkg/mock/server.go b/pkg/mock/server.go index 1ddb05a1..1b034f70 100644 --- a/pkg/mock/server.go +++ b/pkg/mock/server.go @@ -1,5 +1,5 @@ /* -Copyright 2024 API Testing Authors. +Copyright 2024-2025 API Testing Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/pkg/mock/testdata/api.yaml b/pkg/mock/testdata/api.yaml index 2b9758ad..ac9f6107 100644 --- a/pkg/mock/testdata/api.yaml +++ b/pkg/mock/testdata/api.yaml @@ -30,6 +30,12 @@ items: response: body: aGVsbG8= encoder: base64 + - name: readResponseFromFile + request: + path: /v1/readResponseFromFile + response: + encoder: raw + bodyFromFile: testdata/api.yaml - name: prList request: path: /v1/repos/{repo}/prs diff --git a/pkg/mock/types.go b/pkg/mock/types.go index 5b9f456f..5b0f365b 100644 --- a/pkg/mock/types.go +++ b/pkg/mock/types.go @@ -44,11 +44,12 @@ type RequestWithAuth struct { } type Response struct { - Encoder string `yaml:"encoder" json:"encoder"` - Body string `yaml:"body" json:"body"` - Header map[string]string `yaml:"header" json:"header"` - StatusCode int `yaml:"statusCode" json:"statusCode"` - BodyData []byte + Encoder string `yaml:"encoder" json:"encoder"` + Body string `yaml:"body" json:"body"` + BodyFromFile string `yaml:"bodyFromFile" json:"bodyFromFile"` + Header map[string]string `yaml:"header" json:"header"` + StatusCode int `yaml:"statusCode" json:"statusCode"` + BodyData []byte } type Webhook struct { diff --git a/pkg/server/remote_server.go b/pkg/server/remote_server.go index 16788328..28ae8bcb 100644 --- a/pkg/server/remote_server.go +++ b/pkg/server/remote_server.go @@ -17,78 +17,78 @@ limitations under the License. package server import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "mime" - "net/http" - "os" - "path/filepath" - reflect "reflect" - "regexp" - "slices" - "strconv" - "strings" - "sync" - "time" - - "github.com/expr-lang/expr/builtin" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - - "github.com/linuxsuren/api-testing/docs" - "github.com/linuxsuren/api-testing/pkg/util/home" - - "github.com/linuxsuren/api-testing/pkg/mock" - - _ "embed" - - "github.com/linuxsuren/api-testing/pkg/generator" - "github.com/linuxsuren/api-testing/pkg/logging" - "github.com/linuxsuren/api-testing/pkg/oauth" - "github.com/linuxsuren/api-testing/pkg/render" - "github.com/linuxsuren/api-testing/pkg/runner" - "github.com/linuxsuren/api-testing/pkg/testing" - "github.com/linuxsuren/api-testing/pkg/util" - "github.com/linuxsuren/api-testing/pkg/version" - "github.com/linuxsuren/api-testing/sample" - - "google.golang.org/grpc/metadata" - "gopkg.in/yaml.v3" + "bytes" + "context" + "errors" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + reflect "reflect" + "regexp" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/expr-lang/expr/builtin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/linuxsuren/api-testing/docs" + "github.com/linuxsuren/api-testing/pkg/util/home" + + "github.com/linuxsuren/api-testing/pkg/mock" + + _ "embed" + + "github.com/linuxsuren/api-testing/pkg/generator" + "github.com/linuxsuren/api-testing/pkg/logging" + "github.com/linuxsuren/api-testing/pkg/oauth" + "github.com/linuxsuren/api-testing/pkg/render" + "github.com/linuxsuren/api-testing/pkg/runner" + "github.com/linuxsuren/api-testing/pkg/testing" + "github.com/linuxsuren/api-testing/pkg/util" + "github.com/linuxsuren/api-testing/pkg/version" + "github.com/linuxsuren/api-testing/sample" + + "google.golang.org/grpc/metadata" + "gopkg.in/yaml.v3" ) var ( - remoteServerLogger = logging.DefaultLogger(logging.LogLevelInfo).WithName("remote_server") - GrpcMaxRecvMsgSize int + remoteServerLogger = logging.DefaultLogger(logging.LogLevelInfo).WithName("remote_server") + GrpcMaxRecvMsgSize int ) type server struct { - UnimplementedRunnerServer - UnimplementedDataServerServer - UnimplementedThemeExtensionServer + UnimplementedRunnerServer + UnimplementedDataServerServer + UnimplementedThemeExtensionServer - loader testing.Writer - storeWriterFactory testing.StoreWriterFactory - configDir string - storeExtMgr ExtManager + loader testing.Writer + storeWriterFactory testing.StoreWriterFactory + configDir string + storeExtMgr ExtManager - secretServer SecretServiceServer + secretServer SecretServiceServer - grpcMaxRecvMsgSize int + grpcMaxRecvMsgSize int } type SecretServiceServer interface { - GetSecrets(context.Context, *Empty) (*Secrets, error) - CreateSecret(context.Context, *Secret) (*CommonResult, error) - DeleteSecret(context.Context, *Secret) (*CommonResult, error) - UpdateSecret(context.Context, *Secret) (*CommonResult, error) + GetSecrets(context.Context, *Empty) (*Secrets, error) + CreateSecret(context.Context, *Secret) (*CommonResult, error) + DeleteSecret(context.Context, *Secret) (*CommonResult, error) + UpdateSecret(context.Context, *Secret) (*CommonResult, error) } type SecertServiceGetable interface { - GetSecret(context.Context, *Secret) (*Secret, error) + GetSecret(context.Context, *Secret) (*Secret, error) } type fakeSecretServer struct{} @@ -96,1451 +96,1453 @@ type fakeSecretServer struct{} var errNoSecretService = errors.New("no secret service found") func (f *fakeSecretServer) GetSecrets(ctx context.Context, in *Empty) (reply *Secrets, err error) { - err = errNoSecretService - return + err = errNoSecretService + return } func (f *fakeSecretServer) CreateSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - err = errNoSecretService - return + err = errNoSecretService + return } func (f *fakeSecretServer) DeleteSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - err = errNoSecretService - return + err = errNoSecretService + return } func (f *fakeSecretServer) UpdateSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - err = errNoSecretService - return + err = errNoSecretService + return } // NewRemoteServer creates a remote server instance func NewRemoteServer(loader testing.Writer, storeWriterFactory testing.StoreWriterFactory, secretServer SecretServiceServer, storeExtMgr ExtManager, configDir string, grpcMaxRecvMsgSize int) RunnerServer { - if secretServer == nil { - secretServer = &fakeSecretServer{} - } - GrpcMaxRecvMsgSize = grpcMaxRecvMsgSize - return &server{ - loader: loader, - storeWriterFactory: storeWriterFactory, - configDir: configDir, - secretServer: secretServer, - storeExtMgr: storeExtMgr, - grpcMaxRecvMsgSize: grpcMaxRecvMsgSize, - } + if secretServer == nil { + secretServer = &fakeSecretServer{} + } + GrpcMaxRecvMsgSize = grpcMaxRecvMsgSize + return &server{ + loader: loader, + storeWriterFactory: storeWriterFactory, + configDir: configDir, + secretServer: secretServer, + storeExtMgr: storeExtMgr, + grpcMaxRecvMsgSize: grpcMaxRecvMsgSize, + } } func withDefaultValue(old, defVal any) any { - if old == "" || old == nil { - old = defVal - } - return old + if old == "" || old == nil { + old = defVal + } + return old } func parseSuiteWithItems(data []byte) (suite *testing.TestSuite, err error) { - suite, err = testing.ParseFromData(data) - if err == nil && (suite == nil || suite.Items == nil) { - err = errNoTestSuiteFound - } - return + suite, err = testing.ParseFromData(data) + if err == nil && (suite == nil || suite.Items == nil) { + err = errNoTestSuiteFound + } + return } func (s *server) getSuiteFromTestTask(task *TestTask) (suite *testing.TestSuite, err error) { - switch task.Kind { - case "suite": - suite, err = parseSuiteWithItems([]byte(task.Data)) - case "testcase": - var testCase *testing.TestCase - if testCase, err = testing.ParseTestCaseFromData([]byte(task.Data)); err != nil { - return - } - suite = &testing.TestSuite{ - Items: []testing.TestCase{*testCase}, - } - case "testcaseInSuite": - suite, err = parseSuiteWithItems([]byte(task.Data)) - if err != nil { - return - } - - var targetTestcase *testing.TestCase - for _, item := range suite.Items { - if item.Name == task.CaseName { - targetTestcase = &item - break - } - } - - if targetTestcase != nil { - parentCases := findParentTestCases(targetTestcase, suite) - remoteServerLogger.Info("find parent cases", "num", len(parentCases)) - suite.Items = append(parentCases, *targetTestcase) - } else { - err = fmt.Errorf("cannot found testcase %s", task.CaseName) - } - default: - err = fmt.Errorf("not support '%s'", task.Kind) - } - return + switch task.Kind { + case "suite": + suite, err = parseSuiteWithItems([]byte(task.Data)) + case "testcase": + var testCase *testing.TestCase + if testCase, err = testing.ParseTestCaseFromData([]byte(task.Data)); err != nil { + return + } + suite = &testing.TestSuite{ + Items: []testing.TestCase{*testCase}, + } + case "testcaseInSuite": + suite, err = parseSuiteWithItems([]byte(task.Data)) + if err != nil { + return + } + + var targetTestcase *testing.TestCase + for _, item := range suite.Items { + if item.Name == task.CaseName { + targetTestcase = &item + break + } + } + + if targetTestcase != nil { + parentCases := findParentTestCases(targetTestcase, suite) + remoteServerLogger.Info("find parent cases", "num", len(parentCases)) + suite.Items = append(parentCases, *targetTestcase) + } else { + err = fmt.Errorf("cannot found testcase %s", task.CaseName) + } + default: + err = fmt.Errorf("not support '%s'", task.Kind) + } + return } func resetEnv(oldEnv map[string]string) { - for key, val := range oldEnv { - os.Setenv(key, val) - } + for key, val := range oldEnv { + os.Setenv(key, val) + } } func (s *server) getLoader(ctx context.Context) (loader testing.Writer) { - var ok bool - loader = s.loader - - var mdd metadata.MD - if mdd, ok = metadata.FromIncomingContext(ctx); ok { - storeNameMeta := mdd.Get(HeaderKeyStoreName) - if len(storeNameMeta) > 0 { - storeName := strings.TrimSpace(storeNameMeta[0]) - if storeName == "local" || storeName == "" { - return - } - - var err error - if loader, err = s.getLoaderByStoreName(storeName); err != nil { - remoteServerLogger.Info("failed to get loader", "name", storeName, "error", err) - loader = testing.NewNonWriter() - } - } - } - return + var ok bool + loader = s.loader + + var mdd metadata.MD + if mdd, ok = metadata.FromIncomingContext(ctx); ok { + storeNameMeta := mdd.Get(HeaderKeyStoreName) + if len(storeNameMeta) > 0 { + storeName := strings.TrimSpace(storeNameMeta[0]) + if storeName == "local" || storeName == "" { + return + } + + var err error + if loader, err = s.getLoaderByStoreName(storeName); err != nil { + remoteServerLogger.Info("failed to get loader", "name", storeName, "error", err) + loader = testing.NewNonWriter() + } + } + } + return } // Run start to run the test task func (s *server) Run(ctx context.Context, task *TestTask) (reply *TestResult, err error) { - task.Level = withDefaultValue(task.Level, "info").(string) - task.Env = withDefaultValue(task.Env, map[string]string{}).(map[string]string) - - var suite *testing.TestSuite - // TODO may not safe in multiple threads - oldEnv := map[string]string{} - for key, val := range task.Env { - oldEnv[key] = os.Getenv(key) - os.Setenv(key, val) - } - - defer func() { - resetEnv(oldEnv) - }() - - if suite, err = s.getSuiteFromTestTask(task); err != nil { - return - } - - remoteServerLogger.Info("prepare to run", "name", suite.Name, " with level: ", task.Level) - remoteServerLogger.Info("task kind to run", "kind", task.Kind, "lens", len(suite.Items)) - dataContext := map[string]interface{}{} - - if err = suite.Render(dataContext); err != nil { - reply.Error = err.Error() - err = nil - return - } - // inject the parameters from input - if len(task.Parameters) > 0 { - dataContext[testing.ContextKeyGlobalParam] = pairToMap(task.Parameters) - } - - buf := new(bytes.Buffer) - reply = &TestResult{} - - for _, testCase := range suite.Items { - suiteRunner := runner.GetTestSuiteRunner(suite) - suiteRunner.WithOutputWriter(buf) - suiteRunner.WithWriteLevel(task.Level) - suiteRunner.WithSecure(suite.Spec.Secure) - suiteRunner.WithSuite(suite) - - // reuse the API prefix - testCase.Request.RenderAPI(suite.API) - historyHeader := make(map[string]string) - for k, v := range testCase.Request.Header { - historyHeader[k] = v - } - - output, testErr := suiteRunner.RunTestCase(&testCase, dataContext, ctx) - if getter, ok := suiteRunner.(runner.ResponseRecord); ok { - resp := getter.GetResponseRecord() - //resp, err = runner.HandleLargeResponseBody(resp, suite.Name, testCase.Name) - reply.TestCaseResult = append(reply.TestCaseResult, &TestCaseResult{ - StatusCode: int32(resp.StatusCode), - Body: resp.Body, - Header: mapToPair(resp.Header), - Id: testCase.ID, - Output: buf.String(), - }) - } - - if testErr == nil { - dataContext[testCase.Name] = output - } else { - reply.Error = testErr.Error() - break - } - // create history record - go func(historyHeader map[string]string) { - loader := s.getLoader(ctx) - defer loader.Close() - for _, testCaseResult := range reply.TestCaseResult { - err = loader.CreateHistoryTestCase(ToNormalTestCaseResult(testCaseResult), suite, historyHeader) - if err != nil { - remoteServerLogger.Info("error create history") - } - } - }(historyHeader) - } - - if reply.Error != "" { - fmt.Fprintln(buf, reply.Error) - } - reply.Message = buf.String() - return + task.Level = withDefaultValue(task.Level, "info").(string) + task.Env = withDefaultValue(task.Env, map[string]string{}).(map[string]string) + + var suite *testing.TestSuite + // TODO may not safe in multiple threads + oldEnv := map[string]string{} + for key, val := range task.Env { + oldEnv[key] = os.Getenv(key) + os.Setenv(key, val) + } + + defer func() { + resetEnv(oldEnv) + }() + + if suite, err = s.getSuiteFromTestTask(task); err != nil { + return + } + + remoteServerLogger.Info("prepare to run", "name", suite.Name, " with level: ", task.Level) + remoteServerLogger.Info("task kind to run", "kind", task.Kind, "lens", len(suite.Items)) + dataContext := map[string]interface{}{} + + if err = suite.Render(dataContext); err != nil { + reply.Error = err.Error() + err = nil + return + } + // inject the parameters from input + if len(task.Parameters) > 0 { + dataContext[testing.ContextKeyGlobalParam] = pairToMap(task.Parameters) + } + + buf := new(bytes.Buffer) + reply = &TestResult{} + + for _, testCase := range suite.Items { + suiteRunner := runner.GetTestSuiteRunner(suite) + suiteRunner.WithOutputWriter(buf) + suiteRunner.WithWriteLevel(task.Level) + suiteRunner.WithSecure(suite.Spec.Secure) + suiteRunner.WithSuite(suite) + + // reuse the API prefix + testCase.Request.RenderAPI(suite.API) + historyHeader := make(map[string]string) + for k, v := range testCase.Request.Header { + historyHeader[k] = v + } + + output, testErr := suiteRunner.RunTestCase(&testCase, dataContext, ctx) + if getter, ok := suiteRunner.(runner.ResponseRecord); ok { + resp := getter.GetResponseRecord() + //resp, err = runner.HandleLargeResponseBody(resp, suite.Name, testCase.Name) + reply.TestCaseResult = append(reply.TestCaseResult, &TestCaseResult{ + StatusCode: int32(resp.StatusCode), + Body: resp.Body, + Header: mapToPair(resp.Header), + Id: testCase.ID, + Output: buf.String(), + }) + } + + if testErr == nil { + dataContext[testCase.Name] = output + } else { + reply.Error = testErr.Error() + break + } + // create history record + go func(historyHeader map[string]string) { + loader := s.getLoader(ctx) + defer loader.Close() + for _, testCaseResult := range reply.TestCaseResult { + err = loader.CreateHistoryTestCase(ToNormalTestCaseResult(testCaseResult), suite, historyHeader) + if err != nil { + remoteServerLogger.Info("error create history") + } + } + }(historyHeader) + } + + if reply.Error != "" { + fmt.Fprintln(buf, reply.Error) + } + reply.Message = buf.String() + return } func (s *server) BatchRun(srv Runner_BatchRunServer) (err error) { - ctx := srv.Context() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - var in *BatchTestTask - in, err = srv.Recv() - if err != nil { - if err == io.EOF { - return nil - } - return err - } - - for i := 0; i < int(in.Count); i++ { - var reply *TestCaseResult - if reply, err = s.RunTestCase(ctx, &TestCaseIdentity{ - Suite: in.SuiteName, - Testcase: in.CaseName, - }); err != nil { - return - } - - if err = srv.Send(&TestResult{ - TestCaseResult: []*TestCaseResult{reply}, - Error: reply.Error, - }); err != nil { - return err - } - - var interval string - if interval, err = render.Render("batch run interval", in.Interval, nil); err != nil { - return - } - - var duration time.Duration - if duration, err = time.ParseDuration(interval); err != nil { - return - } - time.Sleep(duration) - } - } - } + ctx := srv.Context() + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + var in *BatchTestTask + in, err = srv.Recv() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + for i := 0; i < int(in.Count); i++ { + var reply *TestCaseResult + if reply, err = s.RunTestCase(ctx, &TestCaseIdentity{ + Suite: in.SuiteName, + Testcase: in.CaseName, + }); err != nil { + return + } + + if err = srv.Send(&TestResult{ + TestCaseResult: []*TestCaseResult{reply}, + Error: reply.Error, + }); err != nil { + return err + } + + var interval string + if interval, err = render.Render("batch run interval", in.Interval, nil); err != nil { + return + } + + var duration time.Duration + if duration, err = time.ParseDuration(interval); err != nil { + return + } + time.Sleep(duration) + } + } + } } func (s *server) DownloadResponseFile(ctx context.Context, in *TestCase) (reply *FileData, err error) { - if in.Response != nil { - tempFileName := in.Response.Body - if tempFileName == "" { - return nil, errors.New("file name is empty") - } - - tempDir := os.TempDir() - filePath := filepath.Join(tempDir, tempFileName) - if filepath.Clean(filePath) != filepath.Join(tempDir, filepath.Base(tempFileName)) { - return nil, errors.New("invalid file path") - } - - fmt.Println("get file from", filePath) - fileContent, err := os.ReadFile(filePath) - if err != nil { - return nil, fmt.Errorf("failed to read file: %s", filePath) - } - - mimeType := mime.TypeByExtension(filepath.Ext(filePath)) - if mimeType == "" { - mimeType = "application/octet-stream" - } - - filename := filepath.Base(filePath) - // try to get the original filename - var originalFileName []byte - if originalFileName, err = os.ReadFile(filePath + "name"); err == nil && len(originalFileName) > 0 { - filename = string(originalFileName) - } - - reply = &FileData{ - Data: fileContent, - ContentType: mimeType, - Filename: filename, - } - - return reply, nil - } else { - return reply, errors.New("response is empty") - } + if in.Response != nil { + tempFileName := in.Response.Body + if tempFileName == "" { + return nil, errors.New("file name is empty") + } + + tempDir := os.TempDir() + filePath := filepath.Join(tempDir, tempFileName) + if filepath.Clean(filePath) != filepath.Join(tempDir, filepath.Base(tempFileName)) { + return nil, errors.New("invalid file path") + } + + fmt.Println("get file from", filePath) + fileContent, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %s", filePath) + } + + mimeType := mime.TypeByExtension(filepath.Ext(filePath)) + if mimeType == "" { + mimeType = "application/octet-stream" + } + + filename := filepath.Base(filePath) + // try to get the original filename + var originalFileName []byte + if originalFileName, err = os.ReadFile(filePath + "name"); err == nil && len(originalFileName) > 0 { + filename = string(originalFileName) + } + + reply = &FileData{ + Data: fileContent, + ContentType: mimeType, + Filename: filename, + } + + return reply, nil + } else { + return reply, errors.New("response is empty") + } } func (s *server) RunTestSuite(srv Runner_RunTestSuiteServer) (err error) { - ctx := srv.Context() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - var in *TestSuiteIdentity - in, err = srv.Recv() - if err != nil { - if err == io.EOF { - return nil - } - return err - } - - var suite *Suite - if suite, err = s.ListTestCase(ctx, in); err != nil { - return - } - - for _, item := range suite.Items { - var reply *TestCaseResult - if reply, err = s.RunTestCase(ctx, &TestCaseIdentity{ - Suite: in.Name, - Testcase: item.Name, - }); err != nil { - return - } - - if err = srv.Send(&TestResult{ - TestCaseResult: []*TestCaseResult{reply}, - Error: reply.Error, - }); err != nil { - return err - } - } - } - } + ctx := srv.Context() + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + var in *TestSuiteIdentity + in, err = srv.Recv() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + var suite *Suite + if suite, err = s.ListTestCase(ctx, in); err != nil { + return + } + + for _, item := range suite.Items { + var reply *TestCaseResult + if reply, err = s.RunTestCase(ctx, &TestCaseIdentity{ + Suite: in.Name, + Testcase: item.Name, + }); err != nil { + return + } + + if err = srv.Send(&TestResult{ + TestCaseResult: []*TestCaseResult{reply}, + Error: reply.Error, + }); err != nil { + return err + } + } + } + } } func (s *server) GetSchema(ctx context.Context, in *SimpleQuery) (result *CommonResult, err error) { - result = &CommonResult{ - Success: true, - } - switch in.Name { - case "core": - result.Message = docs.Schema - case "mock": - result.Message = docs.MockSchema - } - return + result = &CommonResult{ + Success: true, + } + switch in.Name { + case "core": + result.Message = docs.Schema + case "mock": + result.Message = docs.MockSchema + } + return } // GetVersion returns the version func (s *server) GetVersion(ctx context.Context, in *Empty) (reply *Version, err error) { - reply = &Version{ - Version: version.GetVersion(), - Date: version.GetDate(), - Commit: version.GetCommit(), - } - return + reply = &Version{ + Version: version.GetVersion(), + Date: version.GetDate(), + Commit: version.GetCommit(), + } + return } func (s *server) GetSuites(ctx context.Context, in *Empty) (reply *Suites, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &Suites{ - Data: make(map[string]*Items), - } - - var suites []testing.TestSuite - if suites, err = loader.ListTestSuite(); err == nil && suites != nil { - for _, suite := range suites { - items := &Items{} - for _, item := range suite.Items { - items.Data = append(items.Data, item.Name) - } - items.Kind = suite.Spec.Kind - reply.Data[suite.Name] = items - } - } - - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &Suites{ + Data: make(map[string]*Items), + } + + var suites []testing.TestSuite + if suites, err = loader.ListTestSuite(); err == nil && suites != nil { + for _, suite := range suites { + items := &Items{} + for _, item := range suite.Items { + items.Data = append(items.Data, item.Name) + } + items.Kind = suite.Spec.Kind + reply.Data[suite.Name] = items + } + } + + return } func (s *server) GetHistorySuites(ctx context.Context, in *Empty) (reply *HistorySuites, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &HistorySuites{ - Data: make(map[string]*HistoryItems), - } - - var suites []testing.HistoryTestSuite - if suites, err = loader.ListHistoryTestSuite(); err == nil && suites != nil { - for _, suite := range suites { - items := &HistoryItems{} - for _, item := range suite.Items { - data := &HistoryCaseIdentity{ - ID: item.ID, - HistorySuiteName: item.HistorySuiteName, - Kind: item.SuiteSpec.Kind, - Suite: item.SuiteName, - Testcase: item.CaseName, - } - items.Data = append(items.Data, data) - } - reply.Data[suite.HistorySuiteName] = items - } - } - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &HistorySuites{ + Data: make(map[string]*HistoryItems), + } + + var suites []testing.HistoryTestSuite + if suites, err = loader.ListHistoryTestSuite(); err == nil && suites != nil { + for _, suite := range suites { + items := &HistoryItems{} + for _, item := range suite.Items { + data := &HistoryCaseIdentity{ + ID: item.ID, + HistorySuiteName: item.HistorySuiteName, + Kind: item.SuiteSpec.Kind, + Suite: item.SuiteName, + Testcase: item.CaseName, + } + items.Data = append(items.Data, data) + } + reply.Data[suite.HistorySuiteName] = items + } + } + return } func (s *server) CreateTestSuite(ctx context.Context, in *TestSuiteIdentity) (reply *HelloReply, err error) { - reply = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - if loader == nil { - reply.Error = "no loader found" - } else { - if err = loader.CreateSuite(in.Name, in.Api); err == nil { - toUpdate := testing.TestSuite{ - Name: in.Name, - API: in.Api, - Spec: testing.APISpec{ - Kind: in.Kind, - }, - } - - switch strings.ToLower(in.Kind) { - case "grpc", "trpc": - toUpdate.Spec.RPC = &testing.RPCDesc{} - } - - err = loader.UpdateSuite(toUpdate) - } - } - return + reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + if loader == nil { + reply.Error = "no loader found" + } else { + if err = loader.CreateSuite(in.Name, in.Api); err == nil { + toUpdate := testing.TestSuite{ + Name: in.Name, + API: in.Api, + Spec: testing.APISpec{ + Kind: in.Kind, + }, + } + + switch strings.ToLower(in.Kind) { + case "grpc", "trpc": + toUpdate.Spec.RPC = &testing.RPCDesc{} + } + + err = loader.UpdateSuite(toUpdate) + } + } + return } func (s *server) ImportTestSuite(ctx context.Context, in *TestSuiteSource) (result *CommonResult, err error) { - result = &CommonResult{} - var dataImporter generator.Importer - switch in.Kind { - case "postman": - dataImporter = generator.NewPostmanImporter() - case "native", "": - dataImporter = generator.NewNativeImporter() - default: - result.Success = false - result.Message = fmt.Sprintf("not support kind: %s", in.Kind) - return - } - - remoteServerLogger.Logger.Info("import test suite", "kind", in.Kind, "url", in.Url) - var suite *testing.TestSuite - if in.Url != "" { - suite, err = dataImporter.ConvertFromURL(in.Url) - } else if in.Data != "" { - suite, err = dataImporter.Convert([]byte(in.Data)) - } else { - err = errors.New("url or data is required") - } - - if err != nil { - result.Success = false - result.Message = err.Error() - return - } - - loader := s.getLoader(ctx) - defer loader.Close() - - if err = loader.CreateSuite(suite.Name, suite.API); err != nil { - return - } - - for _, item := range suite.Items { - if err = loader.CreateTestCase(suite.Name, item); err != nil { - break - } - } - result.Success = true - return + result = &CommonResult{} + var dataImporter generator.Importer + switch in.Kind { + case "postman": + dataImporter = generator.NewPostmanImporter() + case "native", "": + dataImporter = generator.NewNativeImporter() + default: + result.Success = false + result.Message = fmt.Sprintf("not support kind: %s", in.Kind) + return + } + + remoteServerLogger.Logger.Info("import test suite", "kind", in.Kind, "url", in.Url) + var suite *testing.TestSuite + if in.Url != "" { + suite, err = dataImporter.ConvertFromURL(in.Url) + } else if in.Data != "" { + suite, err = dataImporter.Convert([]byte(in.Data)) + } else { + err = errors.New("url or data is required") + } + + if err != nil { + result.Success = false + result.Message = err.Error() + return + } + + loader := s.getLoader(ctx) + defer loader.Close() + + if err = loader.CreateSuite(suite.Name, suite.API); err != nil { + return + } + + for _, item := range suite.Items { + if err = loader.CreateTestCase(suite.Name, item); err != nil { + break + } + } + result.Success = true + return } func (s *server) GetTestSuite(ctx context.Context, in *TestSuiteIdentity) (result *TestSuite, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - var suite *testing.TestSuite - if suite, _, err = loader.GetSuite(in.Name); err == nil && suite != nil { - result = ToGRPCSuite(suite) - } - return + loader := s.getLoader(ctx) + defer loader.Close() + var suite *testing.TestSuite + if suite, _, err = loader.GetSuite(in.Name); err == nil && suite != nil { + result = ToGRPCSuite(suite) + } + return } func (s *server) UpdateTestSuite(ctx context.Context, in *TestSuite) (reply *HelloReply, err error) { - reply = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.UpdateSuite(*ToNormalSuite(in)) - return + reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.UpdateSuite(*ToNormalSuite(in)) + return } func (s *server) DeleteTestSuite(ctx context.Context, in *TestSuiteIdentity) (reply *HelloReply, err error) { - reply = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.DeleteSuite(in.Name) - return + reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.DeleteSuite(in.Name) + return } func (s *server) DuplicateTestSuite(ctx context.Context, in *TestSuiteDuplicate) (reply *HelloReply, err error) { - reply = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - - if in.SourceSuiteName == in.TargetSuiteName { - reply.Error = "source and target suite name should be different" - return - } - - var suite testing.TestSuite - if suite, err = loader.GetTestSuite(in.SourceSuiteName, true); err == nil { - suite.Name = in.TargetSuiteName - if err = loader.CreateSuite(suite.Name, suite.API); err == nil { - for _, testCase := range suite.Items { - if err = loader.CreateTestCase(suite.Name, testCase); err != nil { - break - } - } - } - } - return + reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + + if in.SourceSuiteName == in.TargetSuiteName { + reply.Error = "source and target suite name should be different" + return + } + + var suite testing.TestSuite + if suite, err = loader.GetTestSuite(in.SourceSuiteName, true); err == nil { + suite.Name = in.TargetSuiteName + if err = loader.CreateSuite(suite.Name, suite.API); err == nil { + for _, testCase := range suite.Items { + if err = loader.CreateTestCase(suite.Name, testCase); err != nil { + break + } + } + } + } + return } func (s *server) RenameTestSuite(ctx context.Context, in *TestSuiteDuplicate) (reply *HelloReply, err error) { - reply = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.RenameTestSuite(in.SourceSuiteName, in.TargetSuiteName) - return + reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.RenameTestSuite(in.SourceSuiteName, in.TargetSuiteName) + return } func (s *server) ListTestCase(ctx context.Context, in *TestSuiteIdentity) (result *Suite, err error) { - var items []testing.TestCase - loader := s.getLoader(ctx) - defer loader.Close() - if items, err = loader.ListTestCase(in.Name); err == nil { - result = &Suite{} - for _, item := range items { - result.Items = append(result.Items, ToGRPCTestCase(item)) - } - } - return + var items []testing.TestCase + loader := s.getLoader(ctx) + defer loader.Close() + if items, err = loader.ListTestCase(in.Name); err == nil { + result = &Suite{} + for _, item := range items { + result.Items = append(result.Items, ToGRPCTestCase(item)) + } + } + return } func (s *server) GetTestSuiteYaml(ctx context.Context, in *TestSuiteIdentity) (reply *YamlData, err error) { - var data []byte - loader := s.getLoader(ctx) - defer loader.Close() - if data, err = loader.GetTestSuiteYaml(in.Name); err == nil { - reply = &YamlData{ - Data: data, - } - } - return + var data []byte + loader := s.getLoader(ctx) + defer loader.Close() + if data, err = loader.GetTestSuiteYaml(in.Name); err == nil { + reply = &YamlData{ + Data: data, + } + } + return } func (s *server) GetTestCase(ctx context.Context, in *TestCaseIdentity) (reply *TestCase, err error) { - var result testing.TestCase - loader := s.getLoader(ctx) - defer loader.Close() - if result, err = loader.GetTestCase(in.Suite, in.Testcase); err == nil { - reply = ToGRPCTestCase(result) + var result testing.TestCase + loader := s.getLoader(ctx) + defer loader.Close() + if result, err = loader.GetTestCase(in.Suite, in.Testcase); err == nil { + reply = ToGRPCTestCase(result) - var suite testing.TestSuite - if suite, err = loader.GetTestSuite(in.Suite, false); err == nil { - reply.Server = suite.API - } - } - return + var suite testing.TestSuite + if suite, err = loader.GetTestSuite(in.Suite, false); err == nil { + reply.Server = suite.API + } + } + return } func (s *server) GetHistoryTestCaseWithResult(ctx context.Context, in *HistoryTestCase) (reply *HistoryTestResult, err error) { - var result testing.HistoryTestResult - loader := s.getLoader(ctx) - defer loader.Close() - if result, err = loader.GetHistoryTestCaseWithResult(in.ID); err == nil { - reply = ToGRPCHistoryTestCaseResult(result) - } - return + var result testing.HistoryTestResult + loader := s.getLoader(ctx) + defer loader.Close() + if result, err = loader.GetHistoryTestCaseWithResult(in.ID); err == nil { + reply = ToGRPCHistoryTestCaseResult(result) + } + return } func (s *server) GetHistoryTestCase(ctx context.Context, in *HistoryTestCase) (reply *HistoryTestCase, err error) { - var result testing.HistoryTestCase - loader := s.getLoader(ctx) - defer loader.Close() - if result, err = loader.GetHistoryTestCase(in.ID); err == nil { - reply = ConvertToGRPCHistoryTestCase(result) - } - return + var result testing.HistoryTestCase + loader := s.getLoader(ctx) + defer loader.Close() + if result, err = loader.GetHistoryTestCase(in.ID); err == nil { + reply = ConvertToGRPCHistoryTestCase(result) + } + return } var ExecutionCountNum = promauto.NewCounter(prometheus.CounterOpts{ - Name: "atest_execution_count", - Help: "The total number of request execution", + Name: "atest_execution_count", + Help: "The total number of request execution", }) var ExecutionSuccessNum = promauto.NewCounter(prometheus.CounterOpts{ - Name: "atest_execution_success", - Help: "The total number of request execution success", + Name: "atest_execution_success", + Help: "The total number of request execution success", }) var ExecutionFailNum = promauto.NewCounter(prometheus.CounterOpts{ - Name: "atest_execution_fail", - Help: "The total number of request execution fail", + Name: "atest_execution_fail", + Help: "The total number of request execution fail", }) func (s *server) GetTestCaseAllHistory(ctx context.Context, in *TestCase) (result *HistoryTestCases, err error) { - var items []testing.HistoryTestCase - loader := s.getLoader(ctx) - defer loader.Close() - if items, err = loader.GetTestCaseAllHistory(in.SuiteName, in.Name); err == nil { - result = &HistoryTestCases{} - for _, item := range items { - result.Data = append(result.Data, ConvertToGRPCHistoryTestCase(item)) - } - } - return + var items []testing.HistoryTestCase + loader := s.getLoader(ctx) + defer loader.Close() + if items, err = loader.GetTestCaseAllHistory(in.SuiteName, in.Name); err == nil { + result = &HistoryTestCases{} + for _, item := range items { + result.Data = append(result.Data, ConvertToGRPCHistoryTestCase(item)) + } + } + return } func (s *server) RunTestCase(ctx context.Context, in *TestCaseIdentity) (result *TestCaseResult, err error) { - var targetTestSuite testing.TestSuite - ExecutionCountNum.Inc() - defer func() { - if result.Error == "" { - ExecutionSuccessNum.Inc() - } else { - ExecutionFailNum.Inc() - } - }() - - result = &TestCaseResult{} - loader := s.getLoader(ctx) - defer loader.Close() - targetTestSuite, err = loader.GetTestSuite(in.Suite, true) - if err != nil || targetTestSuite.Name == "" { - err = nil - result.Error = fmt.Sprintf("not found suite: %s", in.Suite) - return - } - - var data []byte - if data, err = yaml.Marshal(targetTestSuite); err == nil { - task := &TestTask{ - Kind: "testcaseInSuite", - Data: string(data), - CaseName: in.Testcase, - Level: "debug", - Parameters: in.Parameters, - } - - var reply *TestResult - var lastItem *TestCaseResult - if reply, err = s.Run(ctx, task); err == nil && len(reply.TestCaseResult) > 0 { - lastIndex := len(reply.TestCaseResult) - 1 - lastItem = reply.TestCaseResult[lastIndex] - - if len(lastItem.Body) > GrpcMaxRecvMsgSize { - e := "the HTTP response body exceeded the maximum message size limit received by the gRPC client" - result = &TestCaseResult{ - Output: reply.Message, - Error: e, - Body: "", - Header: lastItem.Header, - StatusCode: http.StatusOK, - } - return - } - - result = &TestCaseResult{ - Output: reply.Message, - Error: reply.Error, - Body: lastItem.Body, - Header: lastItem.Header, - StatusCode: lastItem.StatusCode, - } - } else if err != nil { - result.Error = err.Error() - } else { - result = &TestCaseResult{ - Output: reply.Message, - Error: reply.Error, - } - } - - if reply != nil { - result.Output = reply.Message - result.Error = reply.Error - } - if lastItem != nil { - result.Body = lastItem.Body - result.Header = lastItem.Header - result.StatusCode = lastItem.StatusCode - } - } - return + var targetTestSuite testing.TestSuite + ExecutionCountNum.Inc() + defer func() { + if result.Error == "" { + ExecutionSuccessNum.Inc() + } else { + ExecutionFailNum.Inc() + } + }() + + result = &TestCaseResult{} + loader := s.getLoader(ctx) + defer loader.Close() + targetTestSuite, err = loader.GetTestSuite(in.Suite, true) + if err != nil || targetTestSuite.Name == "" { + err = nil + result.Error = fmt.Sprintf("not found suite: %s", in.Suite) + return + } + + var data []byte + if data, err = yaml.Marshal(targetTestSuite); err == nil { + task := &TestTask{ + Kind: "testcaseInSuite", + Data: string(data), + CaseName: in.Testcase, + Level: "debug", + Parameters: in.Parameters, + } + + var reply *TestResult + var lastItem *TestCaseResult + if reply, err = s.Run(ctx, task); err == nil && len(reply.TestCaseResult) > 0 { + lastIndex := len(reply.TestCaseResult) - 1 + lastItem = reply.TestCaseResult[lastIndex] + + if len(lastItem.Body) > GrpcMaxRecvMsgSize { + e := "the HTTP response body exceeded the maximum message size limit received by the gRPC client" + result = &TestCaseResult{ + Output: reply.Message, + Error: e, + Body: "", + Header: lastItem.Header, + StatusCode: http.StatusOK, + } + return + } + + result = &TestCaseResult{ + Output: reply.Message, + Error: reply.Error, + Body: lastItem.Body, + Header: lastItem.Header, + StatusCode: lastItem.StatusCode, + } + } else if err != nil { + result.Error = err.Error() + } else { + result = &TestCaseResult{ + Output: reply.Message, + Error: reply.Error, + } + } + + if reply != nil { + result.Output = reply.Message + result.Error = reply.Error + } + if lastItem != nil { + result.Body = lastItem.Body + result.Header = lastItem.Header + result.StatusCode = lastItem.StatusCode + } + } + return } func mapInterToPair(data map[string]interface{}) (pairs []*Pair) { - pairs = make([]*Pair, 0) - for k, v := range data { - pairs = append(pairs, &Pair{ - Key: k, - Value: fmt.Sprintf("%v", v), - }) - } - return + pairs = make([]*Pair, 0) + for k, v := range data { + pairs = append(pairs, &Pair{ + Key: k, + Value: fmt.Sprintf("%v", v), + }) + } + return } func mapToPair(data map[string]string) (pairs []*Pair) { - pairs = make([]*Pair, 0) - for k, v := range data { - pairs = append(pairs, &Pair{ - Key: k, - Value: v, - }) - } - return + pairs = make([]*Pair, 0) + for k, v := range data { + pairs = append(pairs, &Pair{ + Key: k, + Value: v, + }) + } + return } func pairToInterMap(pairs []*Pair) (data map[string]interface{}) { - data = make(map[string]interface{}) - for _, pair := range pairs { - if pair.Key == "" { - continue - } - data[pair.Key] = pair.Value - } - return + data = make(map[string]interface{}) + for _, pair := range pairs { + if pair.Key == "" { + continue + } + data[pair.Key] = pair.Value + } + return } func pairToMap(pairs []*Pair) (data map[string]string) { - data = make(map[string]string) - for _, pair := range pairs { - if pair.Key == "" { - continue - } - data[pair.Key] = pair.Value - } - return + data = make(map[string]string) + for _, pair := range pairs { + if pair.Key == "" { + continue + } + data[pair.Key] = pair.Value + } + return } func convertConditionalVerify(verify []*ConditionalVerify) (result []testing.ConditionalVerify) { - if verify != nil { - result = make([]testing.ConditionalVerify, 0) + if verify != nil { + result = make([]testing.ConditionalVerify, 0) - for _, item := range verify { - result = append(result, testing.ConditionalVerify{ - Condition: item.Condition, - Verify: item.Verify, - }) - } - } - return + for _, item := range verify { + result = append(result, testing.ConditionalVerify{ + Condition: item.Condition, + Verify: item.Verify, + }) + } + } + return } func (s *server) CreateTestCase(ctx context.Context, in *TestCaseWithSuite) (reply *HelloReply, err error) { - reply = &HelloReply{} - if in.Data == nil { - err = errors.New("data is required") - } else { - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.CreateTestCase(in.SuiteName, ToNormalTestCase(in.Data)) - } - return + reply = &HelloReply{} + if in.Data == nil { + err = errors.New("data is required") + } else { + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.CreateTestCase(in.SuiteName, ToNormalTestCase(in.Data)) + } + return } func (s *server) UpdateTestCase(ctx context.Context, in *TestCaseWithSuite) (reply *HelloReply, err error) { - reply = &HelloReply{} - if in.Data == nil { - err = errors.New("data is required") - return - } - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.UpdateTestCase(in.SuiteName, ToNormalTestCase(in.Data)) - return + reply = &HelloReply{} + if in.Data == nil { + err = errors.New("data is required") + return + } + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.UpdateTestCase(in.SuiteName, ToNormalTestCase(in.Data)) + return } func (s *server) DeleteTestCase(ctx context.Context, in *TestCaseIdentity) (reply *HelloReply, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &HelloReply{} - err = loader.DeleteTestCase(in.Suite, in.Testcase) - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &HelloReply{} + err = loader.DeleteTestCase(in.Suite, in.Testcase) + return } func (s *server) DeleteHistoryTestCase(ctx context.Context, in *HistoryTestCase) (reply *HelloReply, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &HelloReply{} - err = loader.DeleteHistoryTestCase(in.ID) - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &HelloReply{} + err = loader.DeleteHistoryTestCase(in.ID) + return } func (s *server) DeleteAllHistoryTestCase(ctx context.Context, in *HistoryTestCase) (reply *HelloReply, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &HelloReply{} - err = loader.DeleteAllHistoryTestCase(in.SuiteName, in.CaseName) - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &HelloReply{} + err = loader.DeleteAllHistoryTestCase(in.SuiteName, in.CaseName) + return } func (s *server) DuplicateTestCase(ctx context.Context, in *TestCaseDuplicate) (reply *HelloReply, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + reply = &HelloReply{} - if in.SourceCaseName == in.TargetCaseName { - reply.Error = "source and target case name should be different" - return - } + if in.SourceCaseName == in.TargetCaseName { + reply.Error = "source and target case name should be different" + return + } - var testcase testing.TestCase - if testcase, err = loader.GetTestCase(in.SourceSuiteName, in.SourceCaseName); err == nil { - testcase.Name = in.TargetCaseName - err = loader.CreateTestCase(in.TargetSuiteName, testcase) - } - return + var testcase testing.TestCase + if testcase, err = loader.GetTestCase(in.SourceSuiteName, in.SourceCaseName); err == nil { + testcase.Name = in.TargetCaseName + err = loader.CreateTestCase(in.TargetSuiteName, testcase) + } + return } func (s *server) RenameTestCase(ctx context.Context, in *TestCaseDuplicate) (result *HelloReply, err error) { - result = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.RenameTestCase(in.SourceSuiteName, in.SourceCaseName, in.TargetCaseName) - return + result = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.RenameTestCase(in.SourceSuiteName, in.SourceCaseName, in.TargetCaseName) + return } // code generator func (s *server) ListCodeGenerator(ctx context.Context, in *Empty) (reply *SimpleList, err error) { - reply = &SimpleList{} + reply = &SimpleList{} - generators := generator.GetCodeGenerators() - for name := range generators { - reply.Data = append(reply.Data, &Pair{ - Key: name, - }) - } - return + generators := generator.GetCodeGenerators() + for name := range generators { + reply.Data = append(reply.Data, &Pair{ + Key: name, + }) + } + return } func (s *server) GenerateCode(ctx context.Context, in *CodeGenerateRequest) (reply *CommonResult, err error) { - reply = &CommonResult{} - instance := generator.GetCodeGenerator(in.Generator) - if instance == nil { - reply.Success = false - reply.Message = fmt.Sprintf("generator '%s' not found", in.Generator) - } else { - var result testing.TestCase - var suite testing.TestSuite - - loader := s.getLoader(ctx) - if suite, err = loader.GetTestSuite(in.TestSuite, true); err != nil { - return - } - - dataContext := map[string]interface{}{} - if err = suite.Render(dataContext); err != nil { - return - } - - var output string - var genErr error - if in.TestCase == "" { - output, genErr = instance.Generate(&suite, nil) - } else { - if result, err = loader.GetTestCase(in.TestSuite, in.TestCase); err == nil { - result.Request.RenderAPI(suite.API) - - output, genErr = instance.Generate(&suite, &result) - } - } - reply.Success = genErr == nil - reply.Message = util.OrErrorMessage(genErr, output) - } - return + reply = &CommonResult{} + instance := generator.GetCodeGenerator(in.Generator) + if instance == nil { + reply.Success = false + reply.Message = fmt.Sprintf("generator '%s' not found", in.Generator) + } else { + var result testing.TestCase + var suite testing.TestSuite + + loader := s.getLoader(ctx) + if suite, err = loader.GetTestSuite(in.TestSuite, true); err != nil { + return + } + + dataContext := map[string]interface{}{} + if err = suite.Render(dataContext); err != nil { + return + } + + var output string + var genErr error + if in.TestCase == "" { + output, genErr = instance.Generate(&suite, nil) + } else { + if result, err = loader.GetTestCase(in.TestSuite, in.TestCase); err == nil { + result.Request.RenderAPI(suite.API) + + output, genErr = instance.Generate(&suite, &result) + } + } + reply.Success = genErr == nil + reply.Message = util.OrErrorMessage(genErr, output) + } + return } func (s *server) HistoryGenerateCode(ctx context.Context, in *CodeGenerateRequest) (reply *CommonResult, err error) { - reply = &CommonResult{} - instance := generator.GetCodeGenerator(in.Generator) - if instance == nil { - reply.Success = false - reply.Message = fmt.Sprintf("generator '%s' not found", in.Generator) - } else { - loader := s.getLoader(ctx) - var result testing.HistoryTestCase - result, err = loader.GetHistoryTestCase(in.ID) - var testCase testing.TestCase - var suite testing.TestSuite - testCase = result.Data - suite.Name = result.SuiteName - suite.API = result.SuiteAPI - suite.Spec = result.SuiteSpec - suite.Param = result.SuiteParam - - output, genErr := instance.Generate(&suite, &testCase) - reply.Success = genErr == nil - reply.Message = util.OrErrorMessage(genErr, output) - } - return + reply = &CommonResult{} + instance := generator.GetCodeGenerator(in.Generator) + if instance == nil { + reply.Success = false + reply.Message = fmt.Sprintf("generator '%s' not found", in.Generator) + } else { + loader := s.getLoader(ctx) + var result testing.HistoryTestCase + result, err = loader.GetHistoryTestCase(in.ID) + var testCase testing.TestCase + var suite testing.TestSuite + testCase = result.Data + suite.Name = result.SuiteName + suite.API = result.SuiteAPI + suite.Spec = result.SuiteSpec + suite.Param = result.SuiteParam + + output, genErr := instance.Generate(&suite, &testCase) + reply.Success = genErr == nil + reply.Message = util.OrErrorMessage(genErr, output) + } + return } // converter func (s *server) ListConverter(ctx context.Context, in *Empty) (reply *SimpleList, err error) { - reply = &SimpleList{} - converters := generator.GetTestSuiteConverters() - for name := range converters { - reply.Data = append(reply.Data, &Pair{ - Key: name, - }) - } - return + reply = &SimpleList{} + converters := generator.GetTestSuiteConverters() + for name := range converters { + reply.Data = append(reply.Data, &Pair{ + Key: name, + }) + } + return } func (s *server) ConvertTestSuite(ctx context.Context, in *CodeGenerateRequest) (reply *CommonResult, err error) { - reply = &CommonResult{} - - instance := generator.GetTestSuiteConverter(in.Generator) - if instance == nil { - reply.Success = false - reply.Message = fmt.Sprintf("converter '%s' not found", in.Generator) - } else { - var result testing.TestSuite - loader := s.getLoader(ctx) - defer loader.Close() - if result, err = loader.GetTestSuite(in.TestSuite, true); err == nil { - output, genErr := instance.Convert(&result) - reply.Success = genErr == nil - reply.Message = util.OrErrorMessage(genErr, output) - } - } - return + reply = &CommonResult{} + + instance := generator.GetTestSuiteConverter(in.Generator) + if instance == nil { + reply.Success = false + reply.Message = fmt.Sprintf("converter '%s' not found", in.Generator) + } else { + var result testing.TestSuite + loader := s.getLoader(ctx) + defer loader.Close() + if result, err = loader.GetTestSuite(in.TestSuite, true); err == nil { + output, genErr := instance.Convert(&result) + reply.Success = genErr == nil + reply.Message = util.OrErrorMessage(genErr, output) + } + } + return } // Sample returns a sample of the test task func (s *server) Sample(ctx context.Context, in *Empty) (reply *HelloReply, err error) { - reply = &HelloReply{Message: sample.TestSuiteGitLab} - return + reply = &HelloReply{Message: sample.TestSuiteGitLab} + return } // PopularHeaders returns a list of popular headers func (s *server) PopularHeaders(ctx context.Context, in *Empty) (pairs *Pairs, err error) { - pairs = &Pairs{ - Data: []*Pair{}, - } + pairs = &Pairs{ + Data: []*Pair{}, + } - err = yaml.Unmarshal(popularHeaders, &pairs.Data) - return + err = yaml.Unmarshal(popularHeaders, &pairs.Data) + return } // GetSuggestedAPIs returns a list of suggested APIs func (s *server) GetSuggestedAPIs(ctx context.Context, in *TestSuiteIdentity) (reply *TestCases, err error) { - reply = &TestCases{} + reply = &TestCases{} - var suite *testing.TestSuite - loader := s.getLoader(ctx) - defer loader.Close() - if suite, _, err = loader.GetSuite(in.Name); err != nil || suite == nil { - return - } + var suite *testing.TestSuite + loader := s.getLoader(ctx) + defer loader.Close() + if suite, _, err = loader.GetSuite(in.Name); err != nil || suite == nil { + return + } - remoteServerLogger.Info("Finding APIs from", "name", in.Name, "with loader", reflect.TypeOf(loader)) + remoteServerLogger.Info("Finding APIs from", "name", in.Name, "with loader", reflect.TypeOf(loader)) - suiteRunner := runner.GetTestSuiteRunner(suite) - var result []*testing.TestCase - if result, err = suiteRunner.GetSuggestedAPIs(suite, in.Api); err == nil && result != nil { - for i := range result { - reply.Data = append(reply.Data, ToGRPCTestCase(*result[i])) - } - } - return + suiteRunner := runner.GetTestSuiteRunner(suite) + var result []*testing.TestCase + if result, err = suiteRunner.GetSuggestedAPIs(suite, in.Api); err == nil && result != nil { + for i := range result { + reply.Data = append(reply.Data, ToGRPCTestCase(*result[i])) + } + } + return } // FunctionsQuery returns a list of functions func (s *server) FunctionsQuery(ctx context.Context, in *SimpleQuery) (reply *Pairs, err error) { - reply = &Pairs{} - in.Name = strings.ToLower(in.Name) - - if in.Kind == "verify" { - for _, fn := range builtin.Builtins { - lowerName := strings.ToLower(fn.Name) - if in.Name == "" || strings.Contains(lowerName, in.Name) { - reply.Data = append(reply.Data, &Pair{ - Key: fn.Name, - Value: fmt.Sprintf("%v", reflect.TypeOf(fn.Func)), - }) - } - } - } else { - for name, fn := range render.FuncMap() { - lowerName := strings.ToLower(name) - if in.Name == "" || strings.Contains(lowerName, in.Name) { - reply.Data = append(reply.Data, &Pair{ - Key: name, - Value: fmt.Sprintf("%v", reflect.TypeOf(fn)), - Description: render.FuncUsage(name), - }) - } - } - } - return + reply = &Pairs{} + in.Name = strings.ToLower(in.Name) + + if in.Kind == "verify" { + for _, fn := range builtin.Builtins { + lowerName := strings.ToLower(fn.Name) + if in.Name == "" || strings.Contains(lowerName, in.Name) { + reply.Data = append(reply.Data, &Pair{ + Key: fn.Name, + Value: fmt.Sprintf("%v", reflect.TypeOf(fn.Func)), + }) + } + } + } else { + for name, fn := range render.FuncMap() { + lowerName := strings.ToLower(name) + if in.Name == "" || strings.Contains(lowerName, in.Name) { + reply.Data = append(reply.Data, &Pair{ + Key: name, + Value: fmt.Sprintf("%v", reflect.TypeOf(fn)), + Description: render.FuncUsage(name), + }) + } + } + } + return } // FunctionsQueryStream works like FunctionsQuery but is implemented in bidirectional streaming func (s *server) FunctionsQueryStream(srv Runner_FunctionsQueryStreamServer) error { - ctx := srv.Context() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - in, err := srv.Recv() - if err != nil { - if err == io.EOF { - return nil - } - return err - } - reply := &Pairs{} - in.Name = strings.ToLower(in.Name) - - for name, fn := range render.FuncMap() { - lowerCaseName := strings.ToLower(name) - if in.Name == "" || strings.Contains(lowerCaseName, in.Name) { - reply.Data = append(reply.Data, &Pair{ - Key: name, - Value: fmt.Sprintf("%v", reflect.TypeOf(fn)), - }) - } - } - if err := srv.Send(reply); err != nil { - return err - } - } - } + ctx := srv.Context() + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + in, err := srv.Recv() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + reply := &Pairs{} + in.Name = strings.ToLower(in.Name) + + for name, fn := range render.FuncMap() { + lowerCaseName := strings.ToLower(name) + if in.Name == "" || strings.Contains(lowerCaseName, in.Name) { + reply.Data = append(reply.Data, &Pair{ + Key: name, + Value: fmt.Sprintf("%v", reflect.TypeOf(fn)), + }) + } + } + if err := srv.Send(reply); err != nil { + return err + } + } + } } func (s *server) GetStoreKinds(context.Context, *Empty) (kinds *StoreKinds, err error) { - storeFactory := testing.NewStoreFactory(s.configDir) - var stores []testing.StoreKind - if stores, err = storeFactory.GetStoreKinds(); err == nil { - kinds = &StoreKinds{} - for _, store := range stores { - kinds.Data = append(kinds.Data, &StoreKind{ - Name: store.Name, - Enabled: store.Enabled, - Url: store.URL, - }) - } - } - return + storeFactory := testing.NewStoreFactory(s.configDir) + var stores []testing.StoreKind + if stores, err = storeFactory.GetStoreKinds(); err == nil { + kinds = &StoreKinds{} + for _, store := range stores { + kinds.Data = append(kinds.Data, &StoreKind{ + Name: store.Name, + Enabled: store.Enabled, + Url: store.URL, + }) + } + } + return } func (s *server) GetStores(ctx context.Context, in *Empty) (reply *Stores, err error) { - user := oauth.GetUserFromContext(ctx) - storeFactory := testing.NewStoreFactory(s.configDir) - var stores []testing.Store - var owner string - if user != nil { - owner = user.Name - } - if stores, err = storeFactory.GetStoresByOwner(owner); err == nil { - reply = &Stores{ - Data: make([]*Store, 0), - } - wg := sync.WaitGroup{} - mu := sync.Mutex{} - for _, item := range stores { - wg.Add(1) - go func() { - defer wg.Done() - - grpcStore := ToGRPCStore(item) - if item.Disabled { - return - } - - storeStatus, sErr := s.VerifyStore(ctx, &SimpleQuery{Name: item.Name}) - grpcStore.Ready = sErr == nil && storeStatus.Ready - grpcStore.ReadOnly = storeStatus.ReadOnly - grpcStore.Password = util.PasswordPlaceholder - - mu.Lock() - reply.Data = append(reply.Data, grpcStore) - mu.Unlock() - }() - } - wg.Wait() - slices.SortFunc(reply.Data, func(a, b *Store) int { - return strings.Compare(a.Name, b.Name) - }) - reply.Data = append(reply.Data, &Store{ - Name: "local", - Kind: &StoreKind{}, - Ready: true, - }) - } - return + user := oauth.GetUserFromContext(ctx) + storeFactory := testing.NewStoreFactory(s.configDir) + var stores []testing.Store + var owner string + if user != nil { + owner = user.Name + } + if stores, err = storeFactory.GetStoresByOwner(owner); err == nil { + reply = &Stores{ + Data: make([]*Store, 0), + } + wg := sync.WaitGroup{} + mu := sync.Mutex{} + for _, item := range stores { + wg.Add(1) + go func() { + defer wg.Done() + + grpcStore := ToGRPCStore(item) + if item.Disabled { + return + } + + storeStatus, sErr := s.VerifyStore(ctx, &SimpleQuery{Name: item.Name}) + grpcStore.Ready = sErr == nil && storeStatus.Ready + grpcStore.ReadOnly = storeStatus.ReadOnly + grpcStore.Password = util.PasswordPlaceholder + + mu.Lock() + reply.Data = append(reply.Data, grpcStore) + mu.Unlock() + }() + } + wg.Wait() + slices.SortFunc(reply.Data, func(a, b *Store) int { + return strings.Compare(a.Name, b.Name) + }) + reply.Data = append(reply.Data, &Store{ + Name: "local", + Kind: &StoreKind{}, + Ready: true, + }) + } + return } func (s *server) CreateStore(ctx context.Context, in *Store) (reply *Store, err error) { - reply = &Store{} - user := oauth.GetUserFromContext(ctx) - if user != nil { - in.Owner = user.Name - } + reply = &Store{} + user := oauth.GetUserFromContext(ctx) + if user != nil { + in.Owner = user.Name + } - storeFactory := testing.NewStoreFactory(s.configDir) - store := ToNormalStore(in) + storeFactory := testing.NewStoreFactory(s.configDir) + store := ToNormalStore(in) - if store.Kind.URL == "" { - store.Kind.URL = fmt.Sprintf("unix://%s", home.GetExtensionSocketPath(store.Kind.Name)) - } + if store.Kind.URL == "" { + store.Kind.URL = fmt.Sprintf("unix://%s", home.GetExtensionSocketPath(store.Kind.Name)) + } - if err = storeFactory.CreateStore(store); err == nil && s.storeExtMgr != nil { - err = s.storeExtMgr.Start(store.Kind.Name, store.Kind.URL) - } - return + if err = storeFactory.CreateStore(store); err == nil && s.storeExtMgr != nil { + err = s.storeExtMgr.Start(store.Kind.Name, store.Kind.URL) + } + return } func (s *server) UpdateStore(ctx context.Context, in *Store) (reply *Store, err error) { - reply = &Store{} - storeFactory := testing.NewStoreFactory(s.configDir) - store := ToNormalStore(in) - if err = storeFactory.UpdateStore(store); err == nil && s.storeExtMgr != nil { - // TODO need to restart extension if config was changed - err = s.storeExtMgr.Start(store.Kind.Name, store.Kind.URL) - } - return + reply = &Store{} + storeFactory := testing.NewStoreFactory(s.configDir) + store := ToNormalStore(in) + if err = storeFactory.UpdateStore(store); err == nil && s.storeExtMgr != nil { + // TODO need to restart extension if config was changed + err = s.storeExtMgr.Start(store.Kind.Name, store.Kind.URL) + } + return } func (s *server) DeleteStore(ctx context.Context, in *Store) (reply *Store, err error) { - reply = &Store{} - storeFactory := testing.NewStoreFactory(s.configDir) - err = storeFactory.DeleteStore(in.Name) - return + reply = &Store{} + storeFactory := testing.NewStoreFactory(s.configDir) + err = storeFactory.DeleteStore(in.Name) + return } func (s *server) VerifyStore(ctx context.Context, in *SimpleQuery) (reply *ExtensionStatus, err error) { - reply = &ExtensionStatus{} - var loader testing.Writer - if loader, err = s.getLoaderByStoreName(in.Name); err == nil && loader != nil { - readOnly, verifyErr := loader.Verify() - reply.Ready = verifyErr == nil - reply.ReadOnly = readOnly - reply.Message = util.OKOrErrorMessage(verifyErr) - } - return + reply = &ExtensionStatus{} + var loader testing.Writer + if loader, err = s.getLoaderByStoreName(in.Name); err == nil && loader != nil { + readOnly, verifyErr := loader.Verify() + reply.Ready = verifyErr == nil + reply.ReadOnly = readOnly + reply.Message = util.OKOrErrorMessage(verifyErr) + } + return } // secret related interfaces func (s *server) GetSecrets(ctx context.Context, in *Empty) (reply *Secrets, err error) { - return s.secretServer.GetSecrets(ctx, in) + return s.secretServer.GetSecrets(ctx, in) } func (s *server) CreateSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - return s.secretServer.CreateSecret(ctx, in) + return s.secretServer.CreateSecret(ctx, in) } func (s *server) DeleteSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - return s.secretServer.DeleteSecret(ctx, in) + return s.secretServer.DeleteSecret(ctx, in) } func (s *server) UpdateSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - return s.secretServer.UpdateSecret(ctx, in) + return s.secretServer.UpdateSecret(ctx, in) } func (s *server) PProf(ctx context.Context, in *PProfRequest) (reply *PProfData, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &PProfData{ - Data: loader.PProf(in.Name), - } - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &PProfData{ + Data: loader.PProf(in.Name), + } + return } func (s *server) Query(ctx context.Context, query *DataQuery) (result *DataQueryResult, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - - // render the SQL query - var sql string - if sql, err = render.Render("sql render", query.Sql, nil); err != nil { - return nil, fmt.Errorf("failed to render SQL query: %w", err) - } - - var dataResult testing.DataResult - if dataResult, err = loader.Query(map[string]string{ - "sql": sql, - "key": query.Key, - "offset": fmt.Sprintf("%d", query.Offset), - "limit": fmt.Sprintf("%d", query.Limit), - }); err == nil { - result = &DataQueryResult{ - Data: mapToPair(dataResult.Pairs), - } - for _, item := range dataResult.Rows { - result.Items = append(result.Items, &Pairs{ - Data: mapToPair(item), - }) - } - result.Meta = &DataMeta{ - CurrentDatabase: dataResult.CurrentDatabase, - Databases: dataResult.Databases, - Tables: dataResult.Tables, - Duration: dataResult.Duration, - Labels: mapToPair(dataResult.Labels), - } - } - return + loader := s.getLoader(ctx) + defer loader.Close() + + // render the SQL query + var sql string + if sql, err = render.Render("sql render", query.Sql, nil); err != nil { + return nil, fmt.Errorf("failed to render SQL query: %w", err) + } + + var dataResult testing.DataResult + if dataResult, err = loader.Query(map[string]string{ + "sql": sql, + "key": query.Key, + "offset": fmt.Sprintf("%d", query.Offset), + "limit": fmt.Sprintf("%d", query.Limit), + }); err == nil { + result = &DataQueryResult{ + Data: mapToPair(dataResult.Pairs), + } + for _, item := range dataResult.Rows { + result.Items = append(result.Items, &Pairs{ + Data: mapToPair(item), + }) + } + result.Meta = &DataMeta{ + CurrentDatabase: dataResult.CurrentDatabase, + Databases: dataResult.Databases, + Tables: dataResult.Tables, + Duration: dataResult.Duration, + Labels: mapToPair(dataResult.Labels), + } + } + return } func (s *server) GetThemes(ctx context.Context, _ *Empty) (result *SimpleList, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - - result = &SimpleList{} - var themes []string - if themes, err = loader.GetThemes(); err == nil { - for _, theme := range themes { - result.Data = append(result.Data, &Pair{ - Key: theme, - Value: "", - }) - } - } - return + loader := s.getLoader(ctx) + defer loader.Close() + + result = &SimpleList{} + var themes []string + if themes, err = loader.GetThemes(); err == nil { + for _, theme := range themes { + result.Data = append(result.Data, &Pair{ + Key: theme, + Value: "", + }) + } + } + return } func (s *server) GetTheme(ctx context.Context, in *SimpleName) (result *CommonResult, err error) { - loader := s.getLoader(ctx) - defer loader.Close() + loader := s.getLoader(ctx) + defer loader.Close() - result = &CommonResult{} - result.Message, err = loader.GetTheme(in.Name) - if err != nil { - result.Message = fmt.Sprintf("failed to get theme: %v", err) - } - return + result = &CommonResult{} + result.Message, err = loader.GetTheme(in.Name) + if err != nil { + result.Message = fmt.Sprintf("failed to get theme: %v", err) + } + return } func (s *server) GetBindings(ctx context.Context, _ *Empty) (result *SimpleList, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - - result = &SimpleList{} - var bindings []string - if bindings, err = loader.GetBindings(); err == nil { - for _, theme := range bindings { - result.Data = append(result.Data, &Pair{ - Key: theme, - Value: "", - }) - } - } - return + loader := s.getLoader(ctx) + defer loader.Close() + + result = &SimpleList{} + var bindings []string + if bindings, err = loader.GetBindings(); err == nil { + for _, theme := range bindings { + result.Data = append(result.Data, &Pair{ + Key: theme, + Value: "", + }) + } + } + return } func (s *server) GetBinding(ctx context.Context, in *SimpleName) (result *CommonResult, err error) { - loader := s.getLoader(ctx) - defer loader.Close() + loader := s.getLoader(ctx) + defer loader.Close() - result = &CommonResult{} - result.Message, err = loader.GetBinding(in.Name) - if err != nil { - result.Message = fmt.Sprintf("failed to get binding: %v", err) - } - return + result = &CommonResult{} + result.Message, err = loader.GetBinding(in.Name) + if err != nil { + result.Message = fmt.Sprintf("failed to get binding: %v", err) + } + return } // implement the mock server // Start starts the mock server type mockServerController struct { - UnimplementedMockServer - mockWriter mock.ReaderAndWriter - loader mock.Loadable - reader mock.Reader - logData chan string - prefix string - combinePort int + UnimplementedMockServer + mockWriter mock.ReaderAndWriter + loader mock.Loadable + reader mock.Reader + logData chan string + prefix string + combinePort int } func NewMockServerController(mockWriter mock.ReaderAndWriter, loader mock.Loadable, combinePort int) MockServer { - return &mockServerController{ - mockWriter: mockWriter, - loader: loader, - prefix: "/mock/server", - logData: make(chan string, 100), - combinePort: combinePort, - } + return &mockServerController{ + mockWriter: mockWriter, + loader: loader, + prefix: "/mock/server", + logData: make(chan string, 100), + combinePort: combinePort, + } } func (s *mockServerController) Reload(ctx context.Context, in *MockConfig) (reply *Empty, err error) { - s.mockWriter.Write([]byte(in.Config)) - s.prefix = in.Prefix - if dServer, ok := s.loader.(mock.DynamicServer); ok && dServer.GetPort() != strconv.Itoa(int(in.GetPort())) { - if strconv.Itoa(s.combinePort) != dServer.GetPort() { - if stopErr := dServer.Stop(); stopErr != nil { - remoteServerLogger.Info("failed to stop old server", "error", stopErr) - } else { - remoteServerLogger.Info("old server stopped", "port", dServer.GetPort()) - } - } - - server := mock.NewInMemoryServer(ctx, int(in.GetPort())).WithTLS(dServer.GetTLS()) - server.Start(s.mockWriter, in.Prefix) - server.WithLogWriter(s) - s.loader = server - } - err = s.loader.Load() - return + s.mockWriter.Write([]byte(in.Config)) + s.prefix = in.Prefix + if dServer, ok := s.loader.(mock.DynamicServer); ok && dServer.GetPort() != strconv.Itoa(int(in.GetPort())) { + if strconv.Itoa(s.combinePort) != dServer.GetPort() { + if stopErr := dServer.Stop(); stopErr != nil { + remoteServerLogger.Info("failed to stop old server", "error", stopErr) + } else { + remoteServerLogger.Info("old server stopped", "port", dServer.GetPort()) + } + } + + server := mock.NewInMemoryServer(ctx, int(in.GetPort())).WithTLS(dServer.GetTLS()) + if err = server.Start(s.mockWriter, in.Prefix); err != nil { + return + } + server.WithLogWriter(s) + s.loader = server + } + err = s.loader.Load() + return } func (s *mockServerController) GetConfig(ctx context.Context, in *Empty) (reply *MockConfig, err error) { - reply = &MockConfig{ - Prefix: s.prefix, - Config: string(s.mockWriter.GetData()), - } - if dServer, ok := s.loader.(mock.DynamicServer); ok { - if port, pErr := strconv.ParseInt(dServer.GetPort(), 10, 32); pErr == nil { - reply.Port = int32(port) - } - } - return + reply = &MockConfig{ + Prefix: s.prefix, + Config: string(s.mockWriter.GetData()), + } + if dServer, ok := s.loader.(mock.DynamicServer); ok { + if port, pErr := strconv.ParseInt(dServer.GetPort(), 10, 32); pErr == nil { + reply.Port = int32(port) + } + } + return } func (s *mockServerController) LogWatch(e *Empty, logServer Mock_LogWatchServer) (err error) { - logServer.Send(&CommonResult{ - Success: true, - Message: "Mock server log watch started\n", - }) - for msg := range s.logData { - logServer.Send(&CommonResult{ - Success: true, - Message: msg, - }) - } - return + logServer.Send(&CommonResult{ + Success: true, + Message: "Mock server log watch started\n", + }) + for msg := range s.logData { + logServer.Send(&CommonResult{ + Success: true, + Message: msg, + }) + } + return } func (s *mockServerController) Write(p []byte) (n int, err error) { - select { - case s.logData <- fmt.Sprintf("%s: %s", time.Now().Format(time.RFC3339), string(p)): - default: - } - return + select { + case s.logData <- fmt.Sprintf("%s: %s", time.Now().Format(time.RFC3339), string(p)): + default: + } + return } func (s *server) getLoaderByStoreName(storeName string) (loader testing.Writer, err error) { - var store *testing.Store - store, err = testing.NewStoreFactory(s.configDir).GetStore(storeName) - if err == nil && store != nil { - loader, err = s.storeWriterFactory.NewInstance(*store) - if err != nil { - err = fmt.Errorf("failed to new grpc loader from store %s, err: %v", store.Name, err) - } - } else { - err = fmt.Errorf("failed to get store %s, err: %v", storeName, err) - } - return + var store *testing.Store + store, err = testing.NewStoreFactory(s.configDir).GetStore(storeName) + if err == nil && store != nil { + loader, err = s.storeWriterFactory.NewInstance(*store) + if err != nil { + err = fmt.Errorf("failed to new grpc loader from store %s, err: %v", store.Name, err) + } + } else { + err = fmt.Errorf("failed to get store %s, err: %v", storeName, err) + } + return } //go:embed data/headers.yaml var popularHeaders []byte func findParentTestCases(testcase *testing.TestCase, suite *testing.TestSuite) (testcases []testing.TestCase) { - reg, matchErr := regexp.Compile(`(.*?\{\{.*\.\w*.*?\}\})`) - targetReg, targetErr := regexp.Compile(`\.\w*`) - - expectNames := new(UniqueSlice[string]) - if matchErr == nil && targetErr == nil { - var expectName string - for _, val := range testcase.Request.Header { - if matched := reg.MatchString(val); matched { - expectName = targetReg.FindString(val) - expectName = strings.TrimPrefix(expectName, ".") - expectNames.Push(expectName) - } - } - - findExpectNames(testcase.Request.API, expectNames) - findExpectNames(testcase.Request.Body.String(), expectNames) - - remoteServerLogger.Info("expect test case names", "name", expectNames.GetAll()) - for _, item := range suite.Items { - if expectNames.Exist(item.Name) { - testcases = append(testcases, item) - } - } - } - return + reg, matchErr := regexp.Compile(`(.*?\{\{.*\.\w*.*?\}\})`) + targetReg, targetErr := regexp.Compile(`\.\w*`) + + expectNames := new(UniqueSlice[string]) + if matchErr == nil && targetErr == nil { + var expectName string + for _, val := range testcase.Request.Header { + if matched := reg.MatchString(val); matched { + expectName = targetReg.FindString(val) + expectName = strings.TrimPrefix(expectName, ".") + expectNames.Push(expectName) + } + } + + findExpectNames(testcase.Request.API, expectNames) + findExpectNames(testcase.Request.Body.String(), expectNames) + + remoteServerLogger.Info("expect test case names", "name", expectNames.GetAll()) + for _, item := range suite.Items { + if expectNames.Exist(item.Name) { + testcases = append(testcases, item) + } + } + } + return } func findExpectNames(target string, expectNames *UniqueSlice[string]) { - reg, _ := regexp.Compile(`(.*?\{\{.*\.\w*.*?\}\})`) - targetReg, _ := regexp.Compile(`\.\w*`) + reg, _ := regexp.Compile(`(.*?\{\{.*\.\w*.*?\}\})`) + targetReg, _ := regexp.Compile(`\.\w*`) - for _, sub := range reg.FindStringSubmatch(target) { - // remove {{ and }} - if left, leftErr := regexp.Compile(`.*\{\{`); leftErr == nil { - body := left.ReplaceAllString(sub, "") + for _, sub := range reg.FindStringSubmatch(target) { + // remove {{ and }} + if left, leftErr := regexp.Compile(`.*\{\{`); leftErr == nil { + body := left.ReplaceAllString(sub, "") - expectName := targetReg.FindString(body) - expectName = strings.TrimPrefix(expectName, ".") - expectNames.Push(expectName) - } - } + expectName := targetReg.FindString(body) + expectName = strings.TrimPrefix(expectName, ".") + expectNames.Push(expectName) + } + } } // UniqueSlice represents an unique slice type UniqueSlice[T comparable] struct { - data []T + data []T } // Push pushes an item if it's not exist func (s *UniqueSlice[T]) Push(item T) *UniqueSlice[T] { - if s.data == nil { - s.data = []T{item} - } else { - for _, it := range s.data { - if it == item { - return s - } - } - s.data = append(s.data, item) - } - return s + if s.data == nil { + s.data = []T{item} + } else { + for _, it := range s.data { + if it == item { + return s + } + } + s.data = append(s.data, item) + } + return s } // Exist checks if the item exist, return true it exists func (s *UniqueSlice[T]) Exist(item T) bool { - if s.data != nil { - for _, it := range s.data { - if it == item { - return true - } - } - } - return false + if s.data != nil { + for _, it := range s.data { + if it == item { + return true + } + } + } + return false } // GetAll returns all the items func (s *UniqueSlice[T]) GetAll() []T { - return s.data + return s.data } var errNoTestSuiteFound = errors.New("no test suite found") From bf212a144966fdfffee22fe7259b01f0d8480f5b Mon Sep 17 00:00:00 2001 From: rick Date: Mon, 4 Aug 2025 17:33:22 +0800 Subject: [PATCH 2/3] docs: add more instructors about the mock conditional response --- docs/site/content/zh/latest/tasks/mock.md | 36 +- pkg/mock/in_memory.go | 1228 +++++------ pkg/mock/in_memory_test.go | 404 ++-- pkg/server/remote_server.go | 2452 ++++++++++----------- 4 files changed, 2076 insertions(+), 2044 deletions(-) diff --git a/docs/site/content/zh/latest/tasks/mock.md b/docs/site/content/zh/latest/tasks/mock.md index f415e886..a8aa4822 100644 --- a/docs/site/content/zh/latest/tasks/mock.md +++ b/docs/site/content/zh/latest/tasks/mock.md @@ -102,6 +102,8 @@ items: curl http://localhost:6060/mock/api/v1/repos/atest/prs -v ``` +#### 编码器 + 另外,为了满足复杂的场景,还可以对 Response Body 做特定的解码,目前支持:`base64`、`url`、`raw`: > encoder 为 `raw` 时,表示不进行处理 @@ -151,6 +153,36 @@ items: bodyFromFile: /tmp/baidu.html ``` +#### 条件判断 + +对于查询类的 API,通常会接收参数,并根据参数的不同,返回相应的数据。这时候,可以用到条件判断的表达式: + +```yaml +items: + - name: cats + request: + path: /api/v1/cats/{size} + response: + header: + Content-Type: application/json + body: | + {{if eq .Param.size "big"}} + { + "name": "big cat" + } + {{else if eq .Param.size "middle"}} + { + "name": "middle cat" + } + {{else if eq .Param.size "small"}} + { + "name": "small cat" + } + {{end}} +``` + +## 代理 + 在实际情况中,往往是向已有系统或平台添加新的 API,此时要 Mock 所有已经存在的 API 就既没必要也需要很多工作量。因此,我们提供了一种简单的方式,即可以增加**代理**的方式把已有的 API 请求转发到实际的地址,只对新增的 API 进行 Mock 处理。如下所示: ```yaml @@ -175,7 +207,7 @@ proxies: target: http://192.168.123.58:9200 ``` -## TCP 协议代理 +### TCP 协议代理 ```yaml proxies: @@ -185,7 +217,7 @@ proxies: target: 192.168.123.58:33060 ``` -## 代理多个服务 +### 代理多个服务 ```shell atest mock-compose bin/compose.yaml diff --git a/pkg/mock/in_memory.go b/pkg/mock/in_memory.go index 804c2e18..fceed94f 100644 --- a/pkg/mock/in_memory.go +++ b/pkg/mock/in_memory.go @@ -16,709 +16,709 @@ limitations under the License. package mock import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "os" - "sort" - "strings" - "sync" - "time" - - jsonpatch "github.com/evanphx/json-patch" - "github.com/swaggest/openapi-go/openapi3" - "github.com/swaggest/rest/gorillamux" - - "github.com/linuxsuren/api-testing/pkg/version" - - "github.com/linuxsuren/api-testing/pkg/logging" - "github.com/linuxsuren/api-testing/pkg/render" - "github.com/linuxsuren/api-testing/pkg/util" - - "github.com/gorilla/mux" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "os" + "sort" + "strings" + "sync" + "time" + + jsonpatch "github.com/evanphx/json-patch" + "github.com/swaggest/openapi-go/openapi3" + "github.com/swaggest/rest/gorillamux" + + "github.com/linuxsuren/api-testing/pkg/version" + + "github.com/linuxsuren/api-testing/pkg/logging" + "github.com/linuxsuren/api-testing/pkg/render" + "github.com/linuxsuren/api-testing/pkg/util" + + "github.com/gorilla/mux" ) var ( - memLogger = logging.DefaultLogger(logging.LogLevelInfo).WithName("memory") + memLogger = logging.DefaultLogger(logging.LogLevelInfo).WithName("memory") ) type inMemoryServer struct { - data map[string][]map[string]interface{} - mux *mux.Router - listener net.Listener - certFile, keyFile string - port int - prefix string - wg sync.WaitGroup - ctx context.Context - cancelFunc context.CancelFunc - reader Reader - metrics RequestMetrics + data map[string][]map[string]interface{} + mux *mux.Router + listener net.Listener + certFile, keyFile string + port int + prefix string + wg sync.WaitGroup + ctx context.Context + cancelFunc context.CancelFunc + reader Reader + metrics RequestMetrics } func NewInMemoryServer(ctx context.Context, port int) DynamicServer { - ctx, cancel := context.WithCancel(ctx) - return &inMemoryServer{ - port: port, - wg: sync.WaitGroup{}, - ctx: ctx, - cancelFunc: cancel, - metrics: NewNoopMetrics(), - } + ctx, cancel := context.WithCancel(ctx) + return &inMemoryServer{ + port: port, + wg: sync.WaitGroup{}, + ctx: ctx, + cancelFunc: cancel, + metrics: NewNoopMetrics(), + } } func (s *inMemoryServer) SetupHandler(reader Reader, prefix string) (handler http.Handler, err error) { - s.reader = reader - // init the data - s.data = make(map[string][]map[string]interface{}) - s.mux = mux.NewRouter().PathPrefix(prefix).Subrouter() - s.prefix = prefix - handler = s.mux - s.metrics.AddMetricsHandler(s.mux) - err = s.Load() - return + s.reader = reader + // init the data + s.data = make(map[string][]map[string]interface{}) + s.mux = mux.NewRouter().PathPrefix(prefix).Subrouter() + s.prefix = prefix + handler = s.mux + s.metrics.AddMetricsHandler(s.mux) + err = s.Load() + return } func (s *inMemoryServer) WithTLS(certFile, keyFile string) DynamicServer { - s.certFile = certFile - s.keyFile = keyFile - return s + s.certFile = certFile + s.keyFile = keyFile + return s } func (s *inMemoryServer) WithLogWriter(writer io.Writer) DynamicServer { - if writer != nil { - memLogger = memLogger.WithNameAndWriter("stream", writer) - } - return s + if writer != nil { + memLogger = memLogger.WithNameAndWriter("stream", writer) + } + return s } func (s *inMemoryServer) GetTLS() (string, string) { - return s.certFile, s.keyFile + return s.certFile, s.keyFile } func (s *inMemoryServer) Load() (err error) { - var server *Server - if server, err = s.reader.Parse(); err != nil { - return - } - - memLogger.Info("start to run all the APIs from objects", "count", len(server.Objects)) - for _, obj := range server.Objects { - memLogger.Info("start mock server from object", "name", obj.Name) - s.startObject(obj) - s.initObjectData(obj) - } - - memLogger.Info("start to run all the APIs from items", "count", len(server.Items)) - for _, item := range server.Items { - s.startItem(item) - } - - memLogger.Info("start webhook servers", "count", len(server.Webhooks)) - for _, item := range server.Webhooks { - if err = s.startWebhook(&item); err != nil { - continue - } - } - - s.handleOpenAPI() - - for i, proxy := range server.Proxies { - memLogger.Info("start to proxy", "target", proxy.Target) - switch proxy.Protocol { - case "http", "": - s.httpProxy(&proxy) - case "tcp": - s.tcpProxy(&server.Proxies[i]) - default: - memLogger.Error(fmt.Errorf("unsupported protocol: %s", proxy.Protocol), "failed to start proxy") - } - } - return + var server *Server + if server, err = s.reader.Parse(); err != nil { + return + } + + memLogger.Info("start to run all the APIs from objects", "count", len(server.Objects)) + for _, obj := range server.Objects { + memLogger.Info("start mock server from object", "name", obj.Name) + s.startObject(obj) + s.initObjectData(obj) + } + + memLogger.Info("start to run all the APIs from items", "count", len(server.Items)) + for _, item := range server.Items { + s.startItem(item) + } + + memLogger.Info("start webhook servers", "count", len(server.Webhooks)) + for _, item := range server.Webhooks { + if err = s.startWebhook(&item); err != nil { + continue + } + } + + s.handleOpenAPI() + + for i, proxy := range server.Proxies { + memLogger.Info("start to proxy", "target", proxy.Target) + switch proxy.Protocol { + case "http", "": + s.httpProxy(&proxy) + case "tcp": + s.tcpProxy(&server.Proxies[i]) + default: + memLogger.Error(fmt.Errorf("unsupported protocol: %s", proxy.Protocol), "failed to start proxy") + } + } + return } func (s *inMemoryServer) httpProxy(proxy *Proxy) { - s.mux.HandleFunc(proxy.Path, func(w http.ResponseWriter, req *http.Request) { - if !strings.HasSuffix(proxy.Target, "/") { - proxy.Target += "/" - } - targetPath := strings.TrimPrefix(req.URL.Path, s.prefix) - targetPath = strings.TrimPrefix(targetPath, "/") - - apiRaw := fmt.Sprintf("%s%s", proxy.Target, targetPath) - api, err := render.Render("proxy api", apiRaw, s) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - memLogger.Error(err, "failed to render proxy api", "api", apiRaw) - return - } - memLogger.Info("redirect to", "target", api) - - var requestBody []byte - if requestBody, err = io.ReadAll(req.Body); err != nil { - w.WriteHeader(http.StatusInternalServerError) - } - - if proxy.RequestAmend.BodyPatch != "" && len(requestBody) > 0 { - var patch jsonpatch.Patch - if patch, err = jsonpatch.DecodePatch([]byte(proxy.RequestAmend.BodyPatch)); err != nil { - return - } - - fmt.Println("before patch:", string(requestBody)) - if requestBody, err = patch.Apply(requestBody); err != nil { - fmt.Println(err) - return - } - fmt.Println("after patch:", string(requestBody)) - } - - targetReq, err := http.NewRequestWithContext(req.Context(), req.Method, api, bytes.NewBuffer(requestBody)) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - memLogger.Error(err, "failed to create proxy request") - return - } - - for k, v := range req.Header { - targetReq.Header.Add(k, v[0]) - } - - resp, err := http.DefaultClient.Do(targetReq) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - memLogger.Error(err, "failed to do proxy request") - return - } - - data, err := io.ReadAll(resp.Body) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - memLogger.Error(err, "failed to read response body") - return - } - - for k, v := range resp.Header { - w.Header().Add(k, v[0]) - } - w.Write(data) - }) + s.mux.HandleFunc(proxy.Path, func(w http.ResponseWriter, req *http.Request) { + if !strings.HasSuffix(proxy.Target, "/") { + proxy.Target += "/" + } + targetPath := strings.TrimPrefix(req.URL.Path, s.prefix) + targetPath = strings.TrimPrefix(targetPath, "/") + + apiRaw := fmt.Sprintf("%s%s", proxy.Target, targetPath) + api, err := render.Render("proxy api", apiRaw, s) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + memLogger.Error(err, "failed to render proxy api", "api", apiRaw) + return + } + memLogger.Info("redirect to", "target", api) + + var requestBody []byte + if requestBody, err = io.ReadAll(req.Body); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + + if proxy.RequestAmend.BodyPatch != "" && len(requestBody) > 0 { + var patch jsonpatch.Patch + if patch, err = jsonpatch.DecodePatch([]byte(proxy.RequestAmend.BodyPatch)); err != nil { + return + } + + fmt.Println("before patch:", string(requestBody)) + if requestBody, err = patch.Apply(requestBody); err != nil { + fmt.Println(err) + return + } + fmt.Println("after patch:", string(requestBody)) + } + + targetReq, err := http.NewRequestWithContext(req.Context(), req.Method, api, bytes.NewBuffer(requestBody)) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + memLogger.Error(err, "failed to create proxy request") + return + } + + for k, v := range req.Header { + targetReq.Header.Add(k, v[0]) + } + + resp, err := http.DefaultClient.Do(targetReq) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + memLogger.Error(err, "failed to do proxy request") + return + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + memLogger.Error(err, "failed to read response body") + return + } + + for k, v := range resp.Header { + w.Header().Add(k, v[0]) + } + w.Write(data) + }) } func (s *inMemoryServer) tcpProxy(proxy *Proxy) { - fmt.Println("start to proxy", proxy.Port) - lisener, err := net.Listen("tcp", fmt.Sprintf(":%d", proxy.Port)) - if err != nil { - memLogger.Error(err, "failed to listen") - return - } - fmt.Printf("proxy local: %d, target: %s\n", proxy.Port, proxy.Target) - defer lisener.Close() - - for { - conn, err := lisener.Accept() - if err != nil { - memLogger.Error(err, "failed to accept") - continue - } - - fmt.Println("accept connection") - go handleConnection(conn, proxy.Target) - } + fmt.Println("start to proxy", proxy.Port) + lisener, err := net.Listen("tcp", fmt.Sprintf(":%d", proxy.Port)) + if err != nil { + memLogger.Error(err, "failed to listen") + return + } + fmt.Printf("proxy local: %d, target: %s\n", proxy.Port, proxy.Target) + defer lisener.Close() + + for { + conn, err := lisener.Accept() + if err != nil { + memLogger.Error(err, "failed to accept") + continue + } + + fmt.Println("accept connection") + go handleConnection(conn, proxy.Target) + } } func handleConnection(clientConn net.Conn, targetAddr string) { - defer clientConn.Close() + defer clientConn.Close() - targetConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second) - if err != nil { - fmt.Printf("Failed to connect to target server: %v\n", err) - return - } - defer targetConn.Close() + targetConn, err := net.DialTimeout("tcp", targetAddr, 10*time.Second) + if err != nil { + fmt.Printf("Failed to connect to target server: %v\n", err) + return + } + defer targetConn.Close() - fmt.Printf("Connection established between %s and %s\n", clientConn.RemoteAddr(), targetConn.RemoteAddr()) + fmt.Printf("Connection established between %s and %s\n", clientConn.RemoteAddr(), targetConn.RemoteAddr()) - go io.Copy(clientConn, targetConn) - go io.Copy(targetConn, clientConn) + go io.Copy(clientConn, targetConn) + go io.Copy(targetConn, clientConn) - select {} + select {} } func (s *inMemoryServer) Start(reader Reader, prefix string) (err error) { - var handler http.Handler - if handler, err = s.SetupHandler(reader, prefix); err == nil { - if s.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", s.port)); err == nil { - go func() { - if s.certFile != "" && s.keyFile != "" { - if err = http.ServeTLS(s.listener, handler, s.certFile, s.keyFile); err != nil { - memLogger.Error(err, "failed to start TLS mock server") - } - } else { - memLogger.Info("start HTTP mock server") - err = http.Serve(s.listener, handler) - } - }() - } - } - return + var handler http.Handler + if handler, err = s.SetupHandler(reader, prefix); err == nil { + if s.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", s.port)); err == nil { + go func() { + if s.certFile != "" && s.keyFile != "" { + if err = http.ServeTLS(s.listener, handler, s.certFile, s.keyFile); err != nil { + memLogger.Error(err, "failed to start TLS mock server") + } + } else { + memLogger.Info("start HTTP mock server") + err = http.Serve(s.listener, handler) + } + }() + } + } + return } func (s *inMemoryServer) EnableMetrics() { - s.metrics = NewInMemoryMetrics() + s.metrics = NewInMemoryMetrics() } func (s *inMemoryServer) startObject(obj Object) { - // create a simple CRUD server - s.mux.HandleFunc("/"+obj.Name, func(w http.ResponseWriter, req *http.Request) { - memLogger.Info("mock server received request", "path", req.URL.Path) - s.metrics.RecordRequest(req.URL.Path) - method := req.Method - w.Header().Set(util.ContentType, util.JSON) - - switch method { - case http.MethodGet: - // list all items - allItems := s.data[obj.Name] - filteredItems := make([]map[string]interface{}, 0) - - for i, item := range allItems { - exclude := false - - for k, v := range req.URL.Query() { - if len(v) == 0 { - continue - } - - if val, ok := item[k]; ok && val != v[0] { - exclude = true - break - } - } - - if !exclude { - filteredItems = append(filteredItems, allItems[i]) - } - } - - if len(filteredItems) != len(allItems) { - allItems = filteredItems - } - - data, err := json.Marshal(allItems) - writeResponse(w, data, err) - case http.MethodPost: - // create an item - if data, err := io.ReadAll(req.Body); err == nil { - objData := map[string]interface{}{} - - jsonErr := json.Unmarshal(data, &objData) - if jsonErr != nil { - memLogger.Info(jsonErr.Error()) - return - } - - s.data[obj.Name] = append(s.data[obj.Name], objData) - - _, _ = w.Write(data) - } else { - memLogger.Info("failed to read from body", "error", err) - } - default: - w.WriteHeader(http.StatusMethodNotAllowed) - } - }) - - // handle a single object - s.mux.HandleFunc(fmt.Sprintf("/%s/{name}", obj.Name), func(w http.ResponseWriter, req *http.Request) { - s.metrics.RecordRequest(req.URL.Path) - w.Header().Set(util.ContentType, util.JSON) - objects := s.data[obj.Name] - if objects != nil { - name := mux.Vars(req)["name"] - var data []byte - for _, obj := range objects { - if obj["name"] == name { - - data, _ = json.Marshal(obj) - break - } - } - - if len(data) == 0 { - w.WriteHeader(http.StatusNotFound) - return - } - - method := req.Method - switch method { - case http.MethodGet: - writeResponse(w, data, nil) - case http.MethodPut: - objData := map[string]interface{}{} - if data, err := io.ReadAll(req.Body); err == nil { - - jsonErr := json.Unmarshal(data, &objData) - if jsonErr != nil { - memLogger.Info(jsonErr.Error()) - return - } - for i, item := range s.data[obj.Name] { - if item["name"] == name { - s.data[obj.Name][i] = objData - break - } - } - _, _ = w.Write(data) - } - case http.MethodDelete: - for i, item := range s.data[obj.Name] { - if item["name"] == name { - if len(s.data[obj.Name]) == i+1 { - s.data[obj.Name] = s.data[obj.Name][:i] - } else { - s.data[obj.Name] = append(s.data[obj.Name][:i], s.data[obj.Name][i+1]) - } - - writeResponse(w, []byte(`{"msg": "deleted"}`), nil) - } - } - default: - w.WriteHeader(http.StatusMethodNotAllowed) - } - - } - }) + // create a simple CRUD server + s.mux.HandleFunc("/"+obj.Name, func(w http.ResponseWriter, req *http.Request) { + memLogger.Info("mock server received request", "path", req.URL.Path) + s.metrics.RecordRequest(req.URL.Path) + method := req.Method + w.Header().Set(util.ContentType, util.JSON) + + switch method { + case http.MethodGet: + // list all items + allItems := s.data[obj.Name] + filteredItems := make([]map[string]interface{}, 0) + + for i, item := range allItems { + exclude := false + + for k, v := range req.URL.Query() { + if len(v) == 0 { + continue + } + + if val, ok := item[k]; ok && val != v[0] { + exclude = true + break + } + } + + if !exclude { + filteredItems = append(filteredItems, allItems[i]) + } + } + + if len(filteredItems) != len(allItems) { + allItems = filteredItems + } + + data, err := json.Marshal(allItems) + writeResponse(w, data, err) + case http.MethodPost: + // create an item + if data, err := io.ReadAll(req.Body); err == nil { + objData := map[string]interface{}{} + + jsonErr := json.Unmarshal(data, &objData) + if jsonErr != nil { + memLogger.Info(jsonErr.Error()) + return + } + + s.data[obj.Name] = append(s.data[obj.Name], objData) + + _, _ = w.Write(data) + } else { + memLogger.Info("failed to read from body", "error", err) + } + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } + }) + + // handle a single object + s.mux.HandleFunc(fmt.Sprintf("/%s/{name}", obj.Name), func(w http.ResponseWriter, req *http.Request) { + s.metrics.RecordRequest(req.URL.Path) + w.Header().Set(util.ContentType, util.JSON) + objects := s.data[obj.Name] + if objects != nil { + name := mux.Vars(req)["name"] + var data []byte + for _, obj := range objects { + if obj["name"] == name { + + data, _ = json.Marshal(obj) + break + } + } + + if len(data) == 0 { + w.WriteHeader(http.StatusNotFound) + return + } + + method := req.Method + switch method { + case http.MethodGet: + writeResponse(w, data, nil) + case http.MethodPut: + objData := map[string]interface{}{} + if data, err := io.ReadAll(req.Body); err == nil { + + jsonErr := json.Unmarshal(data, &objData) + if jsonErr != nil { + memLogger.Info(jsonErr.Error()) + return + } + for i, item := range s.data[obj.Name] { + if item["name"] == name { + s.data[obj.Name][i] = objData + break + } + } + _, _ = w.Write(data) + } + case http.MethodDelete: + for i, item := range s.data[obj.Name] { + if item["name"] == name { + if len(s.data[obj.Name]) == i+1 { + s.data[obj.Name] = s.data[obj.Name][:i] + } else { + s.data[obj.Name] = append(s.data[obj.Name][:i], s.data[obj.Name][i+1]) + } + + writeResponse(w, []byte(`{"msg": "deleted"}`), nil) + } + } + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } + + } + }) } func (s *inMemoryServer) startItem(item Item) { - method := util.EmptyThenDefault(item.Request.Method, http.MethodGet) - memLogger.Info("register mock service", "method", method, "path", item.Request.Path, "encoder", item.Response.Encoder) - - var headerSlices []string - for k, v := range item.Request.Header { - headerSlices = append(headerSlices, k, v) - } - - adHandler := &advanceHandler{ - item: &item, - metrics: s.metrics, - mu: sync.Mutex{}, - } - existedRoute := s.mux.GetRoute(item.Name) - if existedRoute == nil { - s.mux.NewRoute().Name(item.Name).Methods(strings.Split(method, ",")...).Headers(headerSlices...).Path(item.Request.Path).HandlerFunc(adHandler.handle) - } else { - existedRoute.HandlerFunc(adHandler.handle) - } + method := util.EmptyThenDefault(item.Request.Method, http.MethodGet) + memLogger.Info("register mock service", "method", method, "path", item.Request.Path, "encoder", item.Response.Encoder) + + var headerSlices []string + for k, v := range item.Request.Header { + headerSlices = append(headerSlices, k, v) + } + + adHandler := &advanceHandler{ + item: &item, + metrics: s.metrics, + mu: sync.Mutex{}, + } + existedRoute := s.mux.GetRoute(item.Name) + if existedRoute == nil { + s.mux.NewRoute().Name(item.Name).Methods(strings.Split(method, ",")...).Headers(headerSlices...).Path(item.Request.Path).HandlerFunc(adHandler.handle) + } else { + existedRoute.HandlerFunc(adHandler.handle) + } } type advanceHandler struct { - item *Item - metrics RequestMetrics - mu sync.Mutex + item *Item + metrics RequestMetrics + mu sync.Mutex } func (h *advanceHandler) handle(w http.ResponseWriter, req *http.Request) { - h.mu.Lock() - defer h.mu.Unlock() - - h.metrics.RecordRequest(req.URL.Path) - memLogger.Info("receiving mock request", "name", h.item.Name, "method", req.Method, "path", req.URL.Path, - "encoder", h.item.Response.Encoder) - - h.item.Param = mux.Vars(req) - if h.item.Param == nil { - h.item.Param = make(map[string]string) - } - h.item.Param["Host"] = req.Host - if h.item.Response.Header == nil { - h.item.Response.Header = make(map[string]string) - } - h.item.Response.Header[headerMockServer] = fmt.Sprintf("api-testing: %s", version.GetVersion()) - for k, v := range h.item.Response.Header { - hv, hErr := render.Render("mock-server-header", v, &h.item) - if hErr != nil { - hv = v - memLogger.Error(hErr, "failed render mock-server-header", "value", v) - } - - w.Header().Set(k, hv) - } - - if h.item.Response.BodyFromFile != "" { - // read from file - if data, readErr := os.ReadFile(h.item.Response.BodyFromFile); readErr != nil { - memLogger.Error(readErr, "failed to read file", "file", h.item.Response.BodyFromFile) - } else { - h.item.Response.Body = string(data) - } - } - - var err error - if h.item.Response.Encoder == "base64" { - h.item.Response.BodyData, err = base64.StdEncoding.DecodeString(h.item.Response.Body) - } else if h.item.Response.Encoder == "url" { - var resp *http.Response - if resp, err = http.Get(h.item.Response.Body); err == nil { - h.item.Response.BodyData, err = io.ReadAll(resp.Body) - } - } else if h.item.Response.Encoder == "raw" { - h.item.Response.BodyData = []byte(h.item.Response.Body) - } else { - if h.item.Response.BodyData, err = render.RenderAsBytes("start-item", h.item.Response.Body, h.item); err != nil { - memLogger.Error(err, "failed to render body") - } - } - - if err == nil { - h.item.Response.Header[util.ContentLength] = fmt.Sprintf("%d", len(h.item.Response.BodyData)) - w.Header().Set(util.ContentLength, h.item.Response.Header[util.ContentLength]) - } - - writeResponse(w, h.item.Response.BodyData, err) + h.mu.Lock() + defer h.mu.Unlock() + + h.metrics.RecordRequest(req.URL.Path) + memLogger.Info("receiving mock request", "name", h.item.Name, "method", req.Method, "path", req.URL.Path, + "encoder", h.item.Response.Encoder) + + h.item.Param = mux.Vars(req) + if h.item.Param == nil { + h.item.Param = make(map[string]string) + } + h.item.Param["Host"] = req.Host + if h.item.Response.Header == nil { + h.item.Response.Header = make(map[string]string) + } + h.item.Response.Header[headerMockServer] = fmt.Sprintf("api-testing: %s", version.GetVersion()) + for k, v := range h.item.Response.Header { + hv, hErr := render.Render("mock-server-header", v, &h.item) + if hErr != nil { + hv = v + memLogger.Error(hErr, "failed render mock-server-header", "value", v) + } + + w.Header().Set(k, hv) + } + + if h.item.Response.BodyFromFile != "" { + // read from file + if data, readErr := os.ReadFile(h.item.Response.BodyFromFile); readErr != nil { + memLogger.Error(readErr, "failed to read file", "file", h.item.Response.BodyFromFile) + } else { + h.item.Response.Body = string(data) + } + } + + var err error + if h.item.Response.Encoder == "base64" { + h.item.Response.BodyData, err = base64.StdEncoding.DecodeString(h.item.Response.Body) + } else if h.item.Response.Encoder == "url" { + var resp *http.Response + if resp, err = http.Get(h.item.Response.Body); err == nil { + h.item.Response.BodyData, err = io.ReadAll(resp.Body) + } + } else if h.item.Response.Encoder == "raw" { + h.item.Response.BodyData = []byte(h.item.Response.Body) + } else { + if h.item.Response.BodyData, err = render.RenderAsBytes("start-item", h.item.Response.Body, h.item); err != nil { + memLogger.Error(err, "failed to render body") + } + } + + if err == nil { + h.item.Response.Header[util.ContentLength] = fmt.Sprintf("%d", len(h.item.Response.BodyData)) + w.Header().Set(util.ContentLength, h.item.Response.Header[util.ContentLength]) + } + + writeResponse(w, h.item.Response.BodyData, err) } func writeResponse(w http.ResponseWriter, data []byte, err error) { - if err == nil { - w.Write(data) - } else { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(err.Error())) - } + if err == nil { + w.Write(data) + } else { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) + } } func (s *inMemoryServer) initObjectData(obj Object) { - if obj.Sample == "" { - return - } - - defaultCount := 1 - if obj.InitCount == nil { - obj.InitCount = &defaultCount - } - - for i := 0; i < *obj.InitCount; i++ { - objData, jsonErr := jsonStrToInterface(obj.Sample) - if jsonErr == nil { - s.data[obj.Name] = append(s.data[obj.Name], objData) - } else { - memLogger.Info(jsonErr.Error()) - } - } + if obj.Sample == "" { + return + } + + defaultCount := 1 + if obj.InitCount == nil { + obj.InitCount = &defaultCount + } + + for i := 0; i < *obj.InitCount; i++ { + objData, jsonErr := jsonStrToInterface(obj.Sample) + if jsonErr == nil { + s.data[obj.Name] = append(s.data[obj.Name], objData) + } else { + memLogger.Info(jsonErr.Error()) + } + } } func (s *inMemoryServer) startWebhook(webhook *Webhook) (err error) { - if webhook.Timer == "" || webhook.Name == "" { - return - } - - var duration time.Duration - duration, err = time.ParseDuration(webhook.Timer) - if err != nil { - memLogger.Error(err, "Error parsing webhook timer") - return - } - - s.wg.Add(1) - go func(wh *Webhook) { - defer s.wg.Done() - - memLogger.Info("start webhook server", "name", wh.Name) - timer := time.NewTimer(duration) - for { - timer.Reset(duration) - select { - case <-s.ctx.Done(): - memLogger.Info("stop webhook server", "name", wh.Name) - return - case <-timer.C: - if err = runWebhook(s.ctx, s, wh); err != nil { - memLogger.Error(err, "Error when run webhook") - } - } - } - }(webhook) - return + if webhook.Timer == "" || webhook.Name == "" { + return + } + + var duration time.Duration + duration, err = time.ParseDuration(webhook.Timer) + if err != nil { + memLogger.Error(err, "Error parsing webhook timer") + return + } + + s.wg.Add(1) + go func(wh *Webhook) { + defer s.wg.Done() + + memLogger.Info("start webhook server", "name", wh.Name) + timer := time.NewTimer(duration) + for { + timer.Reset(duration) + select { + case <-s.ctx.Done(): + memLogger.Info("stop webhook server", "name", wh.Name) + return + case <-timer.C: + if err = runWebhook(s.ctx, s, wh); err != nil { + memLogger.Error(err, "Error when run webhook") + } + } + } + }(webhook) + return } func runWebhook(ctx context.Context, objCtx interface{}, wh *Webhook) (err error) { - rawParams := make(map[string]string, len(wh.Param)) - paramKeys := make([]string, 0, len(wh.Param)) - for k, v := range wh.Param { - paramKeys = append(paramKeys, k) - rawParams[k] = v - } - sort.Strings(paramKeys) - - for _, k := range paramKeys { - v, vErr := render.Render("mock webhook server param", wh.Param[k], wh) - if vErr == nil { - wh.Param[k] = v - } - } - - var payload io.Reader - payload, err = render.RenderAsReader("mock webhook server payload", wh.Request.Body, wh) - if err != nil { - err = fmt.Errorf("error when render payload: %w", err) - return - } - wh.Param = rawParams - - var api string - api, err = render.Render("webhook request api", wh.Request.Path, objCtx) - if err != nil { - err = fmt.Errorf("error when render api: %w, template: %s", err, wh.Request.Path) - return - } - - switch wh.Request.Protocol { - case "syslog": - err = sendSyslogWebhookRequest(ctx, wh, api, payload) - default: - err = sendHTTPWebhookRequest(ctx, wh, api, payload) - } - return + rawParams := make(map[string]string, len(wh.Param)) + paramKeys := make([]string, 0, len(wh.Param)) + for k, v := range wh.Param { + paramKeys = append(paramKeys, k) + rawParams[k] = v + } + sort.Strings(paramKeys) + + for _, k := range paramKeys { + v, vErr := render.Render("mock webhook server param", wh.Param[k], wh) + if vErr == nil { + wh.Param[k] = v + } + } + + var payload io.Reader + payload, err = render.RenderAsReader("mock webhook server payload", wh.Request.Body, wh) + if err != nil { + err = fmt.Errorf("error when render payload: %w", err) + return + } + wh.Param = rawParams + + var api string + api, err = render.Render("webhook request api", wh.Request.Path, objCtx) + if err != nil { + err = fmt.Errorf("error when render api: %w, template: %s", err, wh.Request.Path) + return + } + + switch wh.Request.Protocol { + case "syslog": + err = sendSyslogWebhookRequest(ctx, wh, api, payload) + default: + err = sendHTTPWebhookRequest(ctx, wh, api, payload) + } + return } func sendSyslogWebhookRequest(ctx context.Context, wh *Webhook, api string, payload io.Reader) (err error) { - var conn net.Conn - if conn, err = net.Dial("udp", api); err == nil { - _, err = io.Copy(conn, payload) - } - return + var conn net.Conn + if conn, err = net.Dial("udp", api); err == nil { + _, err = io.Copy(conn, payload) + } + return } func sendHTTPWebhookRequest(ctx context.Context, wh *Webhook, api string, payload io.Reader) (err error) { - method := util.EmptyThenDefault(wh.Request.Method, http.MethodPost) - client := http.DefaultClient - - var bearerToken string - bearerToken, err = getBearerToken(ctx, wh.Request) - if err != nil { - memLogger.Error(err, "Error when render bearer token") - return - } - - var req *http.Request - req, err = http.NewRequestWithContext(ctx, method, api, payload) - if err != nil { - memLogger.Error(err, "Error when create request") - return - } - - if bearerToken != "" { - memLogger.V(7).Info("set bearer token", "token", bearerToken) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken)) - } - - for k, v := range wh.Request.Header { - req.Header.Set(k, v) - } - - memLogger.Info("send webhook request", "api", api) - resp, err := client.Do(req) - if err != nil { - err = fmt.Errorf("error when sending webhook: %v", err) - } else { - if resp.StatusCode != http.StatusOK { - memLogger.Info("unexpected status", "code", resp.StatusCode) - } - - data, _ := io.ReadAll(resp.Body) - memLogger.V(7).Info("received from webhook", "code", resp.StatusCode, "response", string(data)) - } - return + method := util.EmptyThenDefault(wh.Request.Method, http.MethodPost) + client := http.DefaultClient + + var bearerToken string + bearerToken, err = getBearerToken(ctx, wh.Request) + if err != nil { + memLogger.Error(err, "Error when render bearer token") + return + } + + var req *http.Request + req, err = http.NewRequestWithContext(ctx, method, api, payload) + if err != nil { + memLogger.Error(err, "Error when create request") + return + } + + if bearerToken != "" { + memLogger.V(7).Info("set bearer token", "token", bearerToken) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken)) + } + + for k, v := range wh.Request.Header { + req.Header.Set(k, v) + } + + memLogger.Info("send webhook request", "api", api) + resp, err := client.Do(req) + if err != nil { + err = fmt.Errorf("error when sending webhook: %v", err) + } else { + if resp.StatusCode != http.StatusOK { + memLogger.Info("unexpected status", "code", resp.StatusCode) + } + + data, _ := io.ReadAll(resp.Body) + memLogger.V(7).Info("received from webhook", "code", resp.StatusCode, "response", string(data)) + } + return } type bearerToken struct { - Token string `json:"token"` + Token string `json:"token"` } func getBearerToken(ctx context.Context, request RequestWithAuth) (token string, err error) { - if request.BearerAPI == "" { - return - } - - if request.BearerAPI, err = render.Render("bearer token request", request.BearerAPI, &request); err != nil { - return - } - - var data []byte - if data, err = json.Marshal(&request); err == nil { - client := http.DefaultClient - var req *http.Request - if req, err = http.NewRequestWithContext(ctx, http.MethodPost, request.BearerAPI, bytes.NewBuffer(data)); err == nil { - req.Header.Set(util.ContentType, util.JSON) - - var resp *http.Response - if resp, err = client.Do(req); err == nil && resp.StatusCode == http.StatusOK { - if data, err = io.ReadAll(resp.Body); err == nil { - var tokenObj bearerToken - if err = json.Unmarshal(data, &tokenObj); err == nil { - token = tokenObj.Token - } - } - } - } - } - - return + if request.BearerAPI == "" { + return + } + + if request.BearerAPI, err = render.Render("bearer token request", request.BearerAPI, &request); err != nil { + return + } + + var data []byte + if data, err = json.Marshal(&request); err == nil { + client := http.DefaultClient + var req *http.Request + if req, err = http.NewRequestWithContext(ctx, http.MethodPost, request.BearerAPI, bytes.NewBuffer(data)); err == nil { + req.Header.Set(util.ContentType, util.JSON) + + var resp *http.Response + if resp, err = client.Do(req); err == nil && resp.StatusCode == http.StatusOK { + if data, err = io.ReadAll(resp.Body); err == nil { + var tokenObj bearerToken + if err = json.Unmarshal(data, &tokenObj); err == nil { + token = tokenObj.Token + } + } + } + } + } + + return } func (s *inMemoryServer) handleOpenAPI() { - s.mux.HandleFunc("/api.json", func(w http.ResponseWriter, req *http.Request) { - // Setup OpenAPI schema - reflector := openapi3.NewReflector() - reflector.SpecSchema().SetTitle("Mock Server API") - reflector.SpecSchema().SetVersion(version.GetVersion()) - reflector.SpecSchema().SetDescription("Powered by https://github.com/linuxsuren/api-testing") - - // Walk the router with OpenAPI collector - c := gorillamux.NewOpenAPICollector(reflector) - - _ = s.mux.Walk(c.Walker) - - // Get the resulting schema - if jsonData, err := reflector.Spec.MarshalJSON(); err == nil { - w.Header().Set(util.ContentType, util.JSON) - _, _ = w.Write(jsonData) - } else { - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte(err.Error())) - } - }) + s.mux.HandleFunc("/api.json", func(w http.ResponseWriter, req *http.Request) { + // Setup OpenAPI schema + reflector := openapi3.NewReflector() + reflector.SpecSchema().SetTitle("Mock Server API") + reflector.SpecSchema().SetVersion(version.GetVersion()) + reflector.SpecSchema().SetDescription("Powered by https://github.com/linuxsuren/api-testing") + + // Walk the router with OpenAPI collector + c := gorillamux.NewOpenAPICollector(reflector) + + _ = s.mux.Walk(c.Walker) + + // Get the resulting schema + if jsonData, err := reflector.Spec.MarshalJSON(); err == nil { + w.Header().Set(util.ContentType, util.JSON) + _, _ = w.Write(jsonData) + } else { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(err.Error())) + } + }) } func jsonStrToInterface(jsonStr string) (objData map[string]interface{}, err error) { - if jsonStr, err = render.Render("init object", jsonStr, nil); err == nil { - objData = map[string]interface{}{} - err = json.Unmarshal([]byte(jsonStr), &objData) - } - return + if jsonStr, err = render.Render("init object", jsonStr, nil); err == nil { + objData = map[string]interface{}{} + err = json.Unmarshal([]byte(jsonStr), &objData) + } + return } func (s *inMemoryServer) GetPort() string { - return util.GetPort(s.listener) + return util.GetPort(s.listener) } func (s *inMemoryServer) Stop() (err error) { - if s.listener != nil { - if err = s.listener.Close(); err != nil { - memLogger.Error(err, "failed to close listener") - } - } else { - memLogger.Info("listener is nil") - } - if s.cancelFunc != nil { - s.cancelFunc() - } - s.wg.Wait() - return + if s.listener != nil { + if err = s.listener.Close(); err != nil { + memLogger.Error(err, "failed to close listener") + } + } else { + memLogger.Info("listener is nil") + } + if s.cancelFunc != nil { + s.cancelFunc() + } + s.wg.Wait() + return } diff --git a/pkg/mock/in_memory_test.go b/pkg/mock/in_memory_test.go index 74a5dc68..591ac5f4 100644 --- a/pkg/mock/in_memory_test.go +++ b/pkg/mock/in_memory_test.go @@ -16,238 +16,238 @@ limitations under the License. package mock import ( - "bytes" - "context" - "fmt" - "io" - "net/http" - "strings" - "testing" - - _ "embed" - "github.com/linuxsuren/api-testing/pkg/util" - "github.com/stretchr/testify/assert" + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" + + _ "embed" + "github.com/linuxsuren/api-testing/pkg/util" + "github.com/stretchr/testify/assert" ) //go:embed testdata/api.yaml var mockFile []byte func TestInMemoryServer(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - server.EnableMetrics() + server := NewInMemoryServer(context.Background(), 0) + server.EnableMetrics() - err := server.Start(NewLocalFileReader("testdata/api.yaml"), "/mock") - assert.NoError(t, err) - defer func() { - server.Stop() - }() + err := server.Start(NewLocalFileReader("testdata/api.yaml"), "/mock") + assert.NoError(t, err) + defer func() { + server.Stop() + }() - api := "http://localhost:" + server.GetPort() + "/mock" + api := "http://localhost:" + server.GetPort() + "/mock" - _, err = http.Post(api+"/team", "", bytes.NewBufferString(`{ + _, err = http.Post(api+"/team", "", bytes.NewBufferString(`{ "name": "test", "members": [] }`)) - assert.NoError(t, err) - - var resp *http.Response - resp, err = http.Get(api + "/team") - if assert.NoError(t, err) { - data, err := io.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, `[{"name":"someone"},{"members":[],"name":"test"}]`, string(data)) - } - - t.Run("check the /api.json", func(t *testing.T) { - var resp *http.Response - resp, err = http.Get(api + "/api.json") - if assert.NoError(t, err) { - data, err := io.ReadAll(resp.Body) - assert.NoError(t, err) - assert.NotEmpty(t, string(data)) - } - }) - - t.Run("list with filter", func(t *testing.T) { - var resp *http.Response - resp, err = http.Get(api + "/team?name=someone") - if assert.NoError(t, err) { - data, err := io.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, `[{"name":"someone"}]`, string(data)) - } - }) - - t.Run("update object", func(t *testing.T) { - updateReq, err := http.NewRequest(http.MethodPut, api+"/team/test", bytes.NewBufferString(`{ + assert.NoError(t, err) + + var resp *http.Response + resp, err = http.Get(api + "/team") + if assert.NoError(t, err) { + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, `[{"name":"someone"},{"members":[],"name":"test"}]`, string(data)) + } + + t.Run("check the /api.json", func(t *testing.T) { + var resp *http.Response + resp, err = http.Get(api + "/api.json") + if assert.NoError(t, err) { + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.NotEmpty(t, string(data)) + } + }) + + t.Run("list with filter", func(t *testing.T) { + var resp *http.Response + resp, err = http.Get(api + "/team?name=someone") + if assert.NoError(t, err) { + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, `[{"name":"someone"}]`, string(data)) + } + }) + + t.Run("update object", func(t *testing.T) { + updateReq, err := http.NewRequest(http.MethodPut, api+"/team/test", bytes.NewBufferString(`{ "name": "test", "members": [{ "name": "rick" }] }`)) - assert.NoError(t, err) - resp, err = http.DefaultClient.Do(updateReq) - assert.NoError(t, err) - }) - - t.Run("get a single object", func(t *testing.T) { - resp, err = http.Get(api + "/team/test") - assert.NoError(t, err) - - var data []byte - data, err = io.ReadAll(resp.Body) - assert.NoError(t, err) - - assert.Equal(t, `{"members":[{"name":"rick"}],"name":"test"}`, string(data)) - }) - - // delete object - delReq, err := http.NewRequest(http.MethodDelete, api+"/team/test", nil) - assert.NoError(t, err) - resp, err = http.DefaultClient.Do(delReq) - assert.NoError(t, err) - - t.Run("check if deleted", func(t *testing.T) { - var resp *http.Response - resp, err = http.Get(api + "/team") - if assert.NoError(t, err) { - data, err := io.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, `[{"name":"someone"}]`, string(data)) - } - - resp, err = http.Get(api + "/team/test") - if assert.NoError(t, err) { - assert.Equal(t, http.StatusNotFound, resp.StatusCode) - } - }) - - t.Run("invalid request method", func(t *testing.T) { - delReq, err := http.NewRequest("fake", api+"/team", nil) - assert.NoError(t, err) - resp, err = http.DefaultClient.Do(delReq) - assert.NoError(t, err) - assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) - }) - - t.Run("only accept GET method in getting a single object", func(t *testing.T) { - wrongMethodReq, err := http.NewRequest(http.MethodPut, api+"/team", nil) - assert.NoError(t, err) - resp, err = http.DefaultClient.Do(wrongMethodReq) - assert.NoError(t, err) - assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) - }) - - t.Run("mock item", func(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, api+"/v1/repos/test/prs", nil) - assert.NoError(t, err) - req.Header.Set("name", "rick") - - resp, err = http.DefaultClient.Do(req) - assert.NoError(t, err) - - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, "194", resp.Header.Get(util.ContentLength)) - assert.Equal(t, "mock", resp.Header.Get("Server")) - assert.NotEmpty(t, resp.Header.Get(headerMockServer)) - - data, _ := io.ReadAll(resp.Body) - assert.True(t, strings.Contains(string(data), `"message": "mock"`), string(data)) - }) - - t.Run("miss match header", func(t *testing.T) { - req, err := http.NewRequest(http.MethodGet, api+"/v1/repos/test/prs", nil) - assert.NoError(t, err) - - resp, err = http.DefaultClient.Do(req) - assert.NoError(t, err) - assert.Equal(t, http.StatusNotFound, resp.StatusCode) - }) - - t.Run("base64 encoder", func(t *testing.T) { - resp, err = http.Get(api + "/v1/base64") - assert.NoError(t, err) - data, _ := io.ReadAll(resp.Body) - assert.Equal(t, "hello", string(data)) - }) - - t.Run("read response from file", func(t *testing.T) { - resp, err = http.Get(api + "/v1/readResponseFromFile") - assert.NoError(t, err) - data, _ := io.ReadAll(resp.Body) - assert.Equal(t, mockFile, data) - }) - - t.Run("not found config file", func(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - err := server.Start(NewLocalFileReader("fake"), "/") - assert.Error(t, err) - }) - - t.Run("invalid webhook", func(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - err := server.Start(NewInMemoryReader(`webhooks: + assert.NoError(t, err) + resp, err = http.DefaultClient.Do(updateReq) + assert.NoError(t, err) + }) + + t.Run("get a single object", func(t *testing.T) { + resp, err = http.Get(api + "/team/test") + assert.NoError(t, err) + + var data []byte + data, err = io.ReadAll(resp.Body) + assert.NoError(t, err) + + assert.Equal(t, `{"members":[{"name":"rick"}],"name":"test"}`, string(data)) + }) + + // delete object + delReq, err := http.NewRequest(http.MethodDelete, api+"/team/test", nil) + assert.NoError(t, err) + resp, err = http.DefaultClient.Do(delReq) + assert.NoError(t, err) + + t.Run("check if deleted", func(t *testing.T) { + var resp *http.Response + resp, err = http.Get(api + "/team") + if assert.NoError(t, err) { + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, `[{"name":"someone"}]`, string(data)) + } + + resp, err = http.Get(api + "/team/test") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + } + }) + + t.Run("invalid request method", func(t *testing.T) { + delReq, err := http.NewRequest("fake", api+"/team", nil) + assert.NoError(t, err) + resp, err = http.DefaultClient.Do(delReq) + assert.NoError(t, err) + assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) + }) + + t.Run("only accept GET method in getting a single object", func(t *testing.T) { + wrongMethodReq, err := http.NewRequest(http.MethodPut, api+"/team", nil) + assert.NoError(t, err) + resp, err = http.DefaultClient.Do(wrongMethodReq) + assert.NoError(t, err) + assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) + }) + + t.Run("mock item", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, api+"/v1/repos/test/prs", nil) + assert.NoError(t, err) + req.Header.Set("name", "rick") + + resp, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "194", resp.Header.Get(util.ContentLength)) + assert.Equal(t, "mock", resp.Header.Get("Server")) + assert.NotEmpty(t, resp.Header.Get(headerMockServer)) + + data, _ := io.ReadAll(resp.Body) + assert.True(t, strings.Contains(string(data), `"message": "mock"`), string(data)) + }) + + t.Run("miss match header", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, api+"/v1/repos/test/prs", nil) + assert.NoError(t, err) + + resp, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) + }) + + t.Run("base64 encoder", func(t *testing.T) { + resp, err = http.Get(api + "/v1/base64") + assert.NoError(t, err) + data, _ := io.ReadAll(resp.Body) + assert.Equal(t, "hello", string(data)) + }) + + t.Run("read response from file", func(t *testing.T) { + resp, err = http.Get(api + "/v1/readResponseFromFile") + assert.NoError(t, err) + data, _ := io.ReadAll(resp.Body) + assert.Equal(t, mockFile, data) + }) + + t.Run("not found config file", func(t *testing.T) { + server := NewInMemoryServer(context.Background(), 0) + err := server.Start(NewLocalFileReader("fake"), "/") + assert.Error(t, err) + }) + + t.Run("invalid webhook", func(t *testing.T) { + server := NewInMemoryServer(context.Background(), 0) + err := server.Start(NewInMemoryReader(`webhooks: - timer: aa name: fake`), "/") - assert.Error(t, err) - }) + assert.Error(t, err) + }) - t.Run("missing name or timer in webhook", func(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - err := server.Start(NewInMemoryReader(`webhooks: + t.Run("missing name or timer in webhook", func(t *testing.T) { + server := NewInMemoryServer(context.Background(), 0) + err := server.Start(NewInMemoryReader(`webhooks: - timer: 1s`), "/") - assert.Error(t, err) - }) + assert.Error(t, err) + }) - t.Run("invalid webhook payload", func(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - err := server.Start(NewInMemoryReader(`webhooks: + t.Run("invalid webhook payload", func(t *testing.T) { + server := NewInMemoryServer(context.Background(), 0) + err := server.Start(NewInMemoryReader(`webhooks: - name: invalid timer: 1ms request: body: "{{.fake"`), "/") - assert.Error(t, err) - }) + assert.Error(t, err) + }) - t.Run("invalid webhook api template", func(t *testing.T) { - server := NewInMemoryServer(context.Background(), 0) - err := server.Start(NewInMemoryReader(`webhooks: + t.Run("invalid webhook api template", func(t *testing.T) { + server := NewInMemoryServer(context.Background(), 0) + err := server.Start(NewInMemoryReader(`webhooks: - name: invalid timer: 1ms request: body: "{}" path: "{{.fake"`), "/") - assert.NoError(t, err) - }) - - t.Run("proxy", func(t *testing.T) { - resp, err = http.Get(api + "/v1/myProjects") - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - resp, err = http.Get(api + "/v1/invalid-template") - assert.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) - }) - - t.Run("metrics", func(t *testing.T) { - resp, err = http.Get(api + "/metrics") - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) - }) - - t.Run("go template support in response body", func(t *testing.T) { - repoName := "myRepo" - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/v1/repos/%s/prs", api, repoName), nil) - assert.NoError(t, err) - - var resp *http.Response - req.Header.Set("name", "rick") - resp, err = http.DefaultClient.Do(req) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - data, _ := io.ReadAll(resp.Body) - assert.Contains(t, string(data), repoName) - }) + assert.NoError(t, err) + }) + + t.Run("proxy", func(t *testing.T) { + resp, err = http.Get(api + "/v1/myProjects") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + resp, err = http.Get(api + "/v1/invalid-template") + assert.NoError(t, err) + assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + }) + + t.Run("metrics", func(t *testing.T) { + resp, err = http.Get(api + "/metrics") + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("go template support in response body", func(t *testing.T) { + repoName := "myRepo" + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/v1/repos/%s/prs", api, repoName), nil) + assert.NoError(t, err) + + var resp *http.Response + req.Header.Set("name", "rick") + resp, err = http.DefaultClient.Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + data, _ := io.ReadAll(resp.Body) + assert.Contains(t, string(data), repoName) + }) } diff --git a/pkg/server/remote_server.go b/pkg/server/remote_server.go index 28ae8bcb..8f6d55e6 100644 --- a/pkg/server/remote_server.go +++ b/pkg/server/remote_server.go @@ -17,78 +17,78 @@ limitations under the License. package server import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "mime" - "net/http" - "os" - "path/filepath" - reflect "reflect" - "regexp" - "slices" - "strconv" - "strings" - "sync" - "time" - - "github.com/expr-lang/expr/builtin" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - - "github.com/linuxsuren/api-testing/docs" - "github.com/linuxsuren/api-testing/pkg/util/home" - - "github.com/linuxsuren/api-testing/pkg/mock" - - _ "embed" - - "github.com/linuxsuren/api-testing/pkg/generator" - "github.com/linuxsuren/api-testing/pkg/logging" - "github.com/linuxsuren/api-testing/pkg/oauth" - "github.com/linuxsuren/api-testing/pkg/render" - "github.com/linuxsuren/api-testing/pkg/runner" - "github.com/linuxsuren/api-testing/pkg/testing" - "github.com/linuxsuren/api-testing/pkg/util" - "github.com/linuxsuren/api-testing/pkg/version" - "github.com/linuxsuren/api-testing/sample" - - "google.golang.org/grpc/metadata" - "gopkg.in/yaml.v3" + "bytes" + "context" + "errors" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + reflect "reflect" + "regexp" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/expr-lang/expr/builtin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/linuxsuren/api-testing/docs" + "github.com/linuxsuren/api-testing/pkg/util/home" + + "github.com/linuxsuren/api-testing/pkg/mock" + + _ "embed" + + "github.com/linuxsuren/api-testing/pkg/generator" + "github.com/linuxsuren/api-testing/pkg/logging" + "github.com/linuxsuren/api-testing/pkg/oauth" + "github.com/linuxsuren/api-testing/pkg/render" + "github.com/linuxsuren/api-testing/pkg/runner" + "github.com/linuxsuren/api-testing/pkg/testing" + "github.com/linuxsuren/api-testing/pkg/util" + "github.com/linuxsuren/api-testing/pkg/version" + "github.com/linuxsuren/api-testing/sample" + + "google.golang.org/grpc/metadata" + "gopkg.in/yaml.v3" ) var ( - remoteServerLogger = logging.DefaultLogger(logging.LogLevelInfo).WithName("remote_server") - GrpcMaxRecvMsgSize int + remoteServerLogger = logging.DefaultLogger(logging.LogLevelInfo).WithName("remote_server") + GrpcMaxRecvMsgSize int ) type server struct { - UnimplementedRunnerServer - UnimplementedDataServerServer - UnimplementedThemeExtensionServer + UnimplementedRunnerServer + UnimplementedDataServerServer + UnimplementedThemeExtensionServer - loader testing.Writer - storeWriterFactory testing.StoreWriterFactory - configDir string - storeExtMgr ExtManager + loader testing.Writer + storeWriterFactory testing.StoreWriterFactory + configDir string + storeExtMgr ExtManager - secretServer SecretServiceServer + secretServer SecretServiceServer - grpcMaxRecvMsgSize int + grpcMaxRecvMsgSize int } type SecretServiceServer interface { - GetSecrets(context.Context, *Empty) (*Secrets, error) - CreateSecret(context.Context, *Secret) (*CommonResult, error) - DeleteSecret(context.Context, *Secret) (*CommonResult, error) - UpdateSecret(context.Context, *Secret) (*CommonResult, error) + GetSecrets(context.Context, *Empty) (*Secrets, error) + CreateSecret(context.Context, *Secret) (*CommonResult, error) + DeleteSecret(context.Context, *Secret) (*CommonResult, error) + UpdateSecret(context.Context, *Secret) (*CommonResult, error) } type SecertServiceGetable interface { - GetSecret(context.Context, *Secret) (*Secret, error) + GetSecret(context.Context, *Secret) (*Secret, error) } type fakeSecretServer struct{} @@ -96,1453 +96,1453 @@ type fakeSecretServer struct{} var errNoSecretService = errors.New("no secret service found") func (f *fakeSecretServer) GetSecrets(ctx context.Context, in *Empty) (reply *Secrets, err error) { - err = errNoSecretService - return + err = errNoSecretService + return } func (f *fakeSecretServer) CreateSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - err = errNoSecretService - return + err = errNoSecretService + return } func (f *fakeSecretServer) DeleteSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - err = errNoSecretService - return + err = errNoSecretService + return } func (f *fakeSecretServer) UpdateSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - err = errNoSecretService - return + err = errNoSecretService + return } // NewRemoteServer creates a remote server instance func NewRemoteServer(loader testing.Writer, storeWriterFactory testing.StoreWriterFactory, secretServer SecretServiceServer, storeExtMgr ExtManager, configDir string, grpcMaxRecvMsgSize int) RunnerServer { - if secretServer == nil { - secretServer = &fakeSecretServer{} - } - GrpcMaxRecvMsgSize = grpcMaxRecvMsgSize - return &server{ - loader: loader, - storeWriterFactory: storeWriterFactory, - configDir: configDir, - secretServer: secretServer, - storeExtMgr: storeExtMgr, - grpcMaxRecvMsgSize: grpcMaxRecvMsgSize, - } + if secretServer == nil { + secretServer = &fakeSecretServer{} + } + GrpcMaxRecvMsgSize = grpcMaxRecvMsgSize + return &server{ + loader: loader, + storeWriterFactory: storeWriterFactory, + configDir: configDir, + secretServer: secretServer, + storeExtMgr: storeExtMgr, + grpcMaxRecvMsgSize: grpcMaxRecvMsgSize, + } } func withDefaultValue(old, defVal any) any { - if old == "" || old == nil { - old = defVal - } - return old + if old == "" || old == nil { + old = defVal + } + return old } func parseSuiteWithItems(data []byte) (suite *testing.TestSuite, err error) { - suite, err = testing.ParseFromData(data) - if err == nil && (suite == nil || suite.Items == nil) { - err = errNoTestSuiteFound - } - return + suite, err = testing.ParseFromData(data) + if err == nil && (suite == nil || suite.Items == nil) { + err = errNoTestSuiteFound + } + return } func (s *server) getSuiteFromTestTask(task *TestTask) (suite *testing.TestSuite, err error) { - switch task.Kind { - case "suite": - suite, err = parseSuiteWithItems([]byte(task.Data)) - case "testcase": - var testCase *testing.TestCase - if testCase, err = testing.ParseTestCaseFromData([]byte(task.Data)); err != nil { - return - } - suite = &testing.TestSuite{ - Items: []testing.TestCase{*testCase}, - } - case "testcaseInSuite": - suite, err = parseSuiteWithItems([]byte(task.Data)) - if err != nil { - return - } - - var targetTestcase *testing.TestCase - for _, item := range suite.Items { - if item.Name == task.CaseName { - targetTestcase = &item - break - } - } - - if targetTestcase != nil { - parentCases := findParentTestCases(targetTestcase, suite) - remoteServerLogger.Info("find parent cases", "num", len(parentCases)) - suite.Items = append(parentCases, *targetTestcase) - } else { - err = fmt.Errorf("cannot found testcase %s", task.CaseName) - } - default: - err = fmt.Errorf("not support '%s'", task.Kind) - } - return + switch task.Kind { + case "suite": + suite, err = parseSuiteWithItems([]byte(task.Data)) + case "testcase": + var testCase *testing.TestCase + if testCase, err = testing.ParseTestCaseFromData([]byte(task.Data)); err != nil { + return + } + suite = &testing.TestSuite{ + Items: []testing.TestCase{*testCase}, + } + case "testcaseInSuite": + suite, err = parseSuiteWithItems([]byte(task.Data)) + if err != nil { + return + } + + var targetTestcase *testing.TestCase + for _, item := range suite.Items { + if item.Name == task.CaseName { + targetTestcase = &item + break + } + } + + if targetTestcase != nil { + parentCases := findParentTestCases(targetTestcase, suite) + remoteServerLogger.Info("find parent cases", "num", len(parentCases)) + suite.Items = append(parentCases, *targetTestcase) + } else { + err = fmt.Errorf("cannot found testcase %s", task.CaseName) + } + default: + err = fmt.Errorf("not support '%s'", task.Kind) + } + return } func resetEnv(oldEnv map[string]string) { - for key, val := range oldEnv { - os.Setenv(key, val) - } + for key, val := range oldEnv { + os.Setenv(key, val) + } } func (s *server) getLoader(ctx context.Context) (loader testing.Writer) { - var ok bool - loader = s.loader - - var mdd metadata.MD - if mdd, ok = metadata.FromIncomingContext(ctx); ok { - storeNameMeta := mdd.Get(HeaderKeyStoreName) - if len(storeNameMeta) > 0 { - storeName := strings.TrimSpace(storeNameMeta[0]) - if storeName == "local" || storeName == "" { - return - } - - var err error - if loader, err = s.getLoaderByStoreName(storeName); err != nil { - remoteServerLogger.Info("failed to get loader", "name", storeName, "error", err) - loader = testing.NewNonWriter() - } - } - } - return + var ok bool + loader = s.loader + + var mdd metadata.MD + if mdd, ok = metadata.FromIncomingContext(ctx); ok { + storeNameMeta := mdd.Get(HeaderKeyStoreName) + if len(storeNameMeta) > 0 { + storeName := strings.TrimSpace(storeNameMeta[0]) + if storeName == "local" || storeName == "" { + return + } + + var err error + if loader, err = s.getLoaderByStoreName(storeName); err != nil { + remoteServerLogger.Info("failed to get loader", "name", storeName, "error", err) + loader = testing.NewNonWriter() + } + } + } + return } // Run start to run the test task func (s *server) Run(ctx context.Context, task *TestTask) (reply *TestResult, err error) { - task.Level = withDefaultValue(task.Level, "info").(string) - task.Env = withDefaultValue(task.Env, map[string]string{}).(map[string]string) - - var suite *testing.TestSuite - // TODO may not safe in multiple threads - oldEnv := map[string]string{} - for key, val := range task.Env { - oldEnv[key] = os.Getenv(key) - os.Setenv(key, val) - } - - defer func() { - resetEnv(oldEnv) - }() - - if suite, err = s.getSuiteFromTestTask(task); err != nil { - return - } - - remoteServerLogger.Info("prepare to run", "name", suite.Name, " with level: ", task.Level) - remoteServerLogger.Info("task kind to run", "kind", task.Kind, "lens", len(suite.Items)) - dataContext := map[string]interface{}{} - - if err = suite.Render(dataContext); err != nil { - reply.Error = err.Error() - err = nil - return - } - // inject the parameters from input - if len(task.Parameters) > 0 { - dataContext[testing.ContextKeyGlobalParam] = pairToMap(task.Parameters) - } - - buf := new(bytes.Buffer) - reply = &TestResult{} - - for _, testCase := range suite.Items { - suiteRunner := runner.GetTestSuiteRunner(suite) - suiteRunner.WithOutputWriter(buf) - suiteRunner.WithWriteLevel(task.Level) - suiteRunner.WithSecure(suite.Spec.Secure) - suiteRunner.WithSuite(suite) - - // reuse the API prefix - testCase.Request.RenderAPI(suite.API) - historyHeader := make(map[string]string) - for k, v := range testCase.Request.Header { - historyHeader[k] = v - } - - output, testErr := suiteRunner.RunTestCase(&testCase, dataContext, ctx) - if getter, ok := suiteRunner.(runner.ResponseRecord); ok { - resp := getter.GetResponseRecord() - //resp, err = runner.HandleLargeResponseBody(resp, suite.Name, testCase.Name) - reply.TestCaseResult = append(reply.TestCaseResult, &TestCaseResult{ - StatusCode: int32(resp.StatusCode), - Body: resp.Body, - Header: mapToPair(resp.Header), - Id: testCase.ID, - Output: buf.String(), - }) - } - - if testErr == nil { - dataContext[testCase.Name] = output - } else { - reply.Error = testErr.Error() - break - } - // create history record - go func(historyHeader map[string]string) { - loader := s.getLoader(ctx) - defer loader.Close() - for _, testCaseResult := range reply.TestCaseResult { - err = loader.CreateHistoryTestCase(ToNormalTestCaseResult(testCaseResult), suite, historyHeader) - if err != nil { - remoteServerLogger.Info("error create history") - } - } - }(historyHeader) - } - - if reply.Error != "" { - fmt.Fprintln(buf, reply.Error) - } - reply.Message = buf.String() - return + task.Level = withDefaultValue(task.Level, "info").(string) + task.Env = withDefaultValue(task.Env, map[string]string{}).(map[string]string) + + var suite *testing.TestSuite + // TODO may not safe in multiple threads + oldEnv := map[string]string{} + for key, val := range task.Env { + oldEnv[key] = os.Getenv(key) + os.Setenv(key, val) + } + + defer func() { + resetEnv(oldEnv) + }() + + if suite, err = s.getSuiteFromTestTask(task); err != nil { + return + } + + remoteServerLogger.Info("prepare to run", "name", suite.Name, " with level: ", task.Level) + remoteServerLogger.Info("task kind to run", "kind", task.Kind, "lens", len(suite.Items)) + dataContext := map[string]interface{}{} + + if err = suite.Render(dataContext); err != nil { + reply.Error = err.Error() + err = nil + return + } + // inject the parameters from input + if len(task.Parameters) > 0 { + dataContext[testing.ContextKeyGlobalParam] = pairToMap(task.Parameters) + } + + buf := new(bytes.Buffer) + reply = &TestResult{} + + for _, testCase := range suite.Items { + suiteRunner := runner.GetTestSuiteRunner(suite) + suiteRunner.WithOutputWriter(buf) + suiteRunner.WithWriteLevel(task.Level) + suiteRunner.WithSecure(suite.Spec.Secure) + suiteRunner.WithSuite(suite) + + // reuse the API prefix + testCase.Request.RenderAPI(suite.API) + historyHeader := make(map[string]string) + for k, v := range testCase.Request.Header { + historyHeader[k] = v + } + + output, testErr := suiteRunner.RunTestCase(&testCase, dataContext, ctx) + if getter, ok := suiteRunner.(runner.ResponseRecord); ok { + resp := getter.GetResponseRecord() + //resp, err = runner.HandleLargeResponseBody(resp, suite.Name, testCase.Name) + reply.TestCaseResult = append(reply.TestCaseResult, &TestCaseResult{ + StatusCode: int32(resp.StatusCode), + Body: resp.Body, + Header: mapToPair(resp.Header), + Id: testCase.ID, + Output: buf.String(), + }) + } + + if testErr == nil { + dataContext[testCase.Name] = output + } else { + reply.Error = testErr.Error() + break + } + // create history record + go func(historyHeader map[string]string) { + loader := s.getLoader(ctx) + defer loader.Close() + for _, testCaseResult := range reply.TestCaseResult { + err = loader.CreateHistoryTestCase(ToNormalTestCaseResult(testCaseResult), suite, historyHeader) + if err != nil { + remoteServerLogger.Info("error create history") + } + } + }(historyHeader) + } + + if reply.Error != "" { + fmt.Fprintln(buf, reply.Error) + } + reply.Message = buf.String() + return } func (s *server) BatchRun(srv Runner_BatchRunServer) (err error) { - ctx := srv.Context() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - var in *BatchTestTask - in, err = srv.Recv() - if err != nil { - if err == io.EOF { - return nil - } - return err - } - - for i := 0; i < int(in.Count); i++ { - var reply *TestCaseResult - if reply, err = s.RunTestCase(ctx, &TestCaseIdentity{ - Suite: in.SuiteName, - Testcase: in.CaseName, - }); err != nil { - return - } - - if err = srv.Send(&TestResult{ - TestCaseResult: []*TestCaseResult{reply}, - Error: reply.Error, - }); err != nil { - return err - } - - var interval string - if interval, err = render.Render("batch run interval", in.Interval, nil); err != nil { - return - } - - var duration time.Duration - if duration, err = time.ParseDuration(interval); err != nil { - return - } - time.Sleep(duration) - } - } - } + ctx := srv.Context() + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + var in *BatchTestTask + in, err = srv.Recv() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + for i := 0; i < int(in.Count); i++ { + var reply *TestCaseResult + if reply, err = s.RunTestCase(ctx, &TestCaseIdentity{ + Suite: in.SuiteName, + Testcase: in.CaseName, + }); err != nil { + return + } + + if err = srv.Send(&TestResult{ + TestCaseResult: []*TestCaseResult{reply}, + Error: reply.Error, + }); err != nil { + return err + } + + var interval string + if interval, err = render.Render("batch run interval", in.Interval, nil); err != nil { + return + } + + var duration time.Duration + if duration, err = time.ParseDuration(interval); err != nil { + return + } + time.Sleep(duration) + } + } + } } func (s *server) DownloadResponseFile(ctx context.Context, in *TestCase) (reply *FileData, err error) { - if in.Response != nil { - tempFileName := in.Response.Body - if tempFileName == "" { - return nil, errors.New("file name is empty") - } - - tempDir := os.TempDir() - filePath := filepath.Join(tempDir, tempFileName) - if filepath.Clean(filePath) != filepath.Join(tempDir, filepath.Base(tempFileName)) { - return nil, errors.New("invalid file path") - } - - fmt.Println("get file from", filePath) - fileContent, err := os.ReadFile(filePath) - if err != nil { - return nil, fmt.Errorf("failed to read file: %s", filePath) - } - - mimeType := mime.TypeByExtension(filepath.Ext(filePath)) - if mimeType == "" { - mimeType = "application/octet-stream" - } - - filename := filepath.Base(filePath) - // try to get the original filename - var originalFileName []byte - if originalFileName, err = os.ReadFile(filePath + "name"); err == nil && len(originalFileName) > 0 { - filename = string(originalFileName) - } - - reply = &FileData{ - Data: fileContent, - ContentType: mimeType, - Filename: filename, - } - - return reply, nil - } else { - return reply, errors.New("response is empty") - } + if in.Response != nil { + tempFileName := in.Response.Body + if tempFileName == "" { + return nil, errors.New("file name is empty") + } + + tempDir := os.TempDir() + filePath := filepath.Join(tempDir, tempFileName) + if filepath.Clean(filePath) != filepath.Join(tempDir, filepath.Base(tempFileName)) { + return nil, errors.New("invalid file path") + } + + fmt.Println("get file from", filePath) + fileContent, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %s", filePath) + } + + mimeType := mime.TypeByExtension(filepath.Ext(filePath)) + if mimeType == "" { + mimeType = "application/octet-stream" + } + + filename := filepath.Base(filePath) + // try to get the original filename + var originalFileName []byte + if originalFileName, err = os.ReadFile(filePath + "name"); err == nil && len(originalFileName) > 0 { + filename = string(originalFileName) + } + + reply = &FileData{ + Data: fileContent, + ContentType: mimeType, + Filename: filename, + } + + return reply, nil + } else { + return reply, errors.New("response is empty") + } } func (s *server) RunTestSuite(srv Runner_RunTestSuiteServer) (err error) { - ctx := srv.Context() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - var in *TestSuiteIdentity - in, err = srv.Recv() - if err != nil { - if err == io.EOF { - return nil - } - return err - } - - var suite *Suite - if suite, err = s.ListTestCase(ctx, in); err != nil { - return - } - - for _, item := range suite.Items { - var reply *TestCaseResult - if reply, err = s.RunTestCase(ctx, &TestCaseIdentity{ - Suite: in.Name, - Testcase: item.Name, - }); err != nil { - return - } - - if err = srv.Send(&TestResult{ - TestCaseResult: []*TestCaseResult{reply}, - Error: reply.Error, - }); err != nil { - return err - } - } - } - } + ctx := srv.Context() + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + var in *TestSuiteIdentity + in, err = srv.Recv() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + + var suite *Suite + if suite, err = s.ListTestCase(ctx, in); err != nil { + return + } + + for _, item := range suite.Items { + var reply *TestCaseResult + if reply, err = s.RunTestCase(ctx, &TestCaseIdentity{ + Suite: in.Name, + Testcase: item.Name, + }); err != nil { + return + } + + if err = srv.Send(&TestResult{ + TestCaseResult: []*TestCaseResult{reply}, + Error: reply.Error, + }); err != nil { + return err + } + } + } + } } func (s *server) GetSchema(ctx context.Context, in *SimpleQuery) (result *CommonResult, err error) { - result = &CommonResult{ - Success: true, - } - switch in.Name { - case "core": - result.Message = docs.Schema - case "mock": - result.Message = docs.MockSchema - } - return + result = &CommonResult{ + Success: true, + } + switch in.Name { + case "core": + result.Message = docs.Schema + case "mock": + result.Message = docs.MockSchema + } + return } // GetVersion returns the version func (s *server) GetVersion(ctx context.Context, in *Empty) (reply *Version, err error) { - reply = &Version{ - Version: version.GetVersion(), - Date: version.GetDate(), - Commit: version.GetCommit(), - } - return + reply = &Version{ + Version: version.GetVersion(), + Date: version.GetDate(), + Commit: version.GetCommit(), + } + return } func (s *server) GetSuites(ctx context.Context, in *Empty) (reply *Suites, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &Suites{ - Data: make(map[string]*Items), - } - - var suites []testing.TestSuite - if suites, err = loader.ListTestSuite(); err == nil && suites != nil { - for _, suite := range suites { - items := &Items{} - for _, item := range suite.Items { - items.Data = append(items.Data, item.Name) - } - items.Kind = suite.Spec.Kind - reply.Data[suite.Name] = items - } - } - - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &Suites{ + Data: make(map[string]*Items), + } + + var suites []testing.TestSuite + if suites, err = loader.ListTestSuite(); err == nil && suites != nil { + for _, suite := range suites { + items := &Items{} + for _, item := range suite.Items { + items.Data = append(items.Data, item.Name) + } + items.Kind = suite.Spec.Kind + reply.Data[suite.Name] = items + } + } + + return } func (s *server) GetHistorySuites(ctx context.Context, in *Empty) (reply *HistorySuites, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &HistorySuites{ - Data: make(map[string]*HistoryItems), - } - - var suites []testing.HistoryTestSuite - if suites, err = loader.ListHistoryTestSuite(); err == nil && suites != nil { - for _, suite := range suites { - items := &HistoryItems{} - for _, item := range suite.Items { - data := &HistoryCaseIdentity{ - ID: item.ID, - HistorySuiteName: item.HistorySuiteName, - Kind: item.SuiteSpec.Kind, - Suite: item.SuiteName, - Testcase: item.CaseName, - } - items.Data = append(items.Data, data) - } - reply.Data[suite.HistorySuiteName] = items - } - } - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &HistorySuites{ + Data: make(map[string]*HistoryItems), + } + + var suites []testing.HistoryTestSuite + if suites, err = loader.ListHistoryTestSuite(); err == nil && suites != nil { + for _, suite := range suites { + items := &HistoryItems{} + for _, item := range suite.Items { + data := &HistoryCaseIdentity{ + ID: item.ID, + HistorySuiteName: item.HistorySuiteName, + Kind: item.SuiteSpec.Kind, + Suite: item.SuiteName, + Testcase: item.CaseName, + } + items.Data = append(items.Data, data) + } + reply.Data[suite.HistorySuiteName] = items + } + } + return } func (s *server) CreateTestSuite(ctx context.Context, in *TestSuiteIdentity) (reply *HelloReply, err error) { - reply = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - if loader == nil { - reply.Error = "no loader found" - } else { - if err = loader.CreateSuite(in.Name, in.Api); err == nil { - toUpdate := testing.TestSuite{ - Name: in.Name, - API: in.Api, - Spec: testing.APISpec{ - Kind: in.Kind, - }, - } - - switch strings.ToLower(in.Kind) { - case "grpc", "trpc": - toUpdate.Spec.RPC = &testing.RPCDesc{} - } - - err = loader.UpdateSuite(toUpdate) - } - } - return + reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + if loader == nil { + reply.Error = "no loader found" + } else { + if err = loader.CreateSuite(in.Name, in.Api); err == nil { + toUpdate := testing.TestSuite{ + Name: in.Name, + API: in.Api, + Spec: testing.APISpec{ + Kind: in.Kind, + }, + } + + switch strings.ToLower(in.Kind) { + case "grpc", "trpc": + toUpdate.Spec.RPC = &testing.RPCDesc{} + } + + err = loader.UpdateSuite(toUpdate) + } + } + return } func (s *server) ImportTestSuite(ctx context.Context, in *TestSuiteSource) (result *CommonResult, err error) { - result = &CommonResult{} - var dataImporter generator.Importer - switch in.Kind { - case "postman": - dataImporter = generator.NewPostmanImporter() - case "native", "": - dataImporter = generator.NewNativeImporter() - default: - result.Success = false - result.Message = fmt.Sprintf("not support kind: %s", in.Kind) - return - } - - remoteServerLogger.Logger.Info("import test suite", "kind", in.Kind, "url", in.Url) - var suite *testing.TestSuite - if in.Url != "" { - suite, err = dataImporter.ConvertFromURL(in.Url) - } else if in.Data != "" { - suite, err = dataImporter.Convert([]byte(in.Data)) - } else { - err = errors.New("url or data is required") - } - - if err != nil { - result.Success = false - result.Message = err.Error() - return - } - - loader := s.getLoader(ctx) - defer loader.Close() - - if err = loader.CreateSuite(suite.Name, suite.API); err != nil { - return - } - - for _, item := range suite.Items { - if err = loader.CreateTestCase(suite.Name, item); err != nil { - break - } - } - result.Success = true - return + result = &CommonResult{} + var dataImporter generator.Importer + switch in.Kind { + case "postman": + dataImporter = generator.NewPostmanImporter() + case "native", "": + dataImporter = generator.NewNativeImporter() + default: + result.Success = false + result.Message = fmt.Sprintf("not support kind: %s", in.Kind) + return + } + + remoteServerLogger.Logger.Info("import test suite", "kind", in.Kind, "url", in.Url) + var suite *testing.TestSuite + if in.Url != "" { + suite, err = dataImporter.ConvertFromURL(in.Url) + } else if in.Data != "" { + suite, err = dataImporter.Convert([]byte(in.Data)) + } else { + err = errors.New("url or data is required") + } + + if err != nil { + result.Success = false + result.Message = err.Error() + return + } + + loader := s.getLoader(ctx) + defer loader.Close() + + if err = loader.CreateSuite(suite.Name, suite.API); err != nil { + return + } + + for _, item := range suite.Items { + if err = loader.CreateTestCase(suite.Name, item); err != nil { + break + } + } + result.Success = true + return } func (s *server) GetTestSuite(ctx context.Context, in *TestSuiteIdentity) (result *TestSuite, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - var suite *testing.TestSuite - if suite, _, err = loader.GetSuite(in.Name); err == nil && suite != nil { - result = ToGRPCSuite(suite) - } - return + loader := s.getLoader(ctx) + defer loader.Close() + var suite *testing.TestSuite + if suite, _, err = loader.GetSuite(in.Name); err == nil && suite != nil { + result = ToGRPCSuite(suite) + } + return } func (s *server) UpdateTestSuite(ctx context.Context, in *TestSuite) (reply *HelloReply, err error) { - reply = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.UpdateSuite(*ToNormalSuite(in)) - return + reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.UpdateSuite(*ToNormalSuite(in)) + return } func (s *server) DeleteTestSuite(ctx context.Context, in *TestSuiteIdentity) (reply *HelloReply, err error) { - reply = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.DeleteSuite(in.Name) - return + reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.DeleteSuite(in.Name) + return } func (s *server) DuplicateTestSuite(ctx context.Context, in *TestSuiteDuplicate) (reply *HelloReply, err error) { - reply = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - - if in.SourceSuiteName == in.TargetSuiteName { - reply.Error = "source and target suite name should be different" - return - } - - var suite testing.TestSuite - if suite, err = loader.GetTestSuite(in.SourceSuiteName, true); err == nil { - suite.Name = in.TargetSuiteName - if err = loader.CreateSuite(suite.Name, suite.API); err == nil { - for _, testCase := range suite.Items { - if err = loader.CreateTestCase(suite.Name, testCase); err != nil { - break - } - } - } - } - return + reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + + if in.SourceSuiteName == in.TargetSuiteName { + reply.Error = "source and target suite name should be different" + return + } + + var suite testing.TestSuite + if suite, err = loader.GetTestSuite(in.SourceSuiteName, true); err == nil { + suite.Name = in.TargetSuiteName + if err = loader.CreateSuite(suite.Name, suite.API); err == nil { + for _, testCase := range suite.Items { + if err = loader.CreateTestCase(suite.Name, testCase); err != nil { + break + } + } + } + } + return } func (s *server) RenameTestSuite(ctx context.Context, in *TestSuiteDuplicate) (reply *HelloReply, err error) { - reply = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.RenameTestSuite(in.SourceSuiteName, in.TargetSuiteName) - return + reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.RenameTestSuite(in.SourceSuiteName, in.TargetSuiteName) + return } func (s *server) ListTestCase(ctx context.Context, in *TestSuiteIdentity) (result *Suite, err error) { - var items []testing.TestCase - loader := s.getLoader(ctx) - defer loader.Close() - if items, err = loader.ListTestCase(in.Name); err == nil { - result = &Suite{} - for _, item := range items { - result.Items = append(result.Items, ToGRPCTestCase(item)) - } - } - return + var items []testing.TestCase + loader := s.getLoader(ctx) + defer loader.Close() + if items, err = loader.ListTestCase(in.Name); err == nil { + result = &Suite{} + for _, item := range items { + result.Items = append(result.Items, ToGRPCTestCase(item)) + } + } + return } func (s *server) GetTestSuiteYaml(ctx context.Context, in *TestSuiteIdentity) (reply *YamlData, err error) { - var data []byte - loader := s.getLoader(ctx) - defer loader.Close() - if data, err = loader.GetTestSuiteYaml(in.Name); err == nil { - reply = &YamlData{ - Data: data, - } - } - return + var data []byte + loader := s.getLoader(ctx) + defer loader.Close() + if data, err = loader.GetTestSuiteYaml(in.Name); err == nil { + reply = &YamlData{ + Data: data, + } + } + return } func (s *server) GetTestCase(ctx context.Context, in *TestCaseIdentity) (reply *TestCase, err error) { - var result testing.TestCase - loader := s.getLoader(ctx) - defer loader.Close() - if result, err = loader.GetTestCase(in.Suite, in.Testcase); err == nil { - reply = ToGRPCTestCase(result) + var result testing.TestCase + loader := s.getLoader(ctx) + defer loader.Close() + if result, err = loader.GetTestCase(in.Suite, in.Testcase); err == nil { + reply = ToGRPCTestCase(result) - var suite testing.TestSuite - if suite, err = loader.GetTestSuite(in.Suite, false); err == nil { - reply.Server = suite.API - } - } - return + var suite testing.TestSuite + if suite, err = loader.GetTestSuite(in.Suite, false); err == nil { + reply.Server = suite.API + } + } + return } func (s *server) GetHistoryTestCaseWithResult(ctx context.Context, in *HistoryTestCase) (reply *HistoryTestResult, err error) { - var result testing.HistoryTestResult - loader := s.getLoader(ctx) - defer loader.Close() - if result, err = loader.GetHistoryTestCaseWithResult(in.ID); err == nil { - reply = ToGRPCHistoryTestCaseResult(result) - } - return + var result testing.HistoryTestResult + loader := s.getLoader(ctx) + defer loader.Close() + if result, err = loader.GetHistoryTestCaseWithResult(in.ID); err == nil { + reply = ToGRPCHistoryTestCaseResult(result) + } + return } func (s *server) GetHistoryTestCase(ctx context.Context, in *HistoryTestCase) (reply *HistoryTestCase, err error) { - var result testing.HistoryTestCase - loader := s.getLoader(ctx) - defer loader.Close() - if result, err = loader.GetHistoryTestCase(in.ID); err == nil { - reply = ConvertToGRPCHistoryTestCase(result) - } - return + var result testing.HistoryTestCase + loader := s.getLoader(ctx) + defer loader.Close() + if result, err = loader.GetHistoryTestCase(in.ID); err == nil { + reply = ConvertToGRPCHistoryTestCase(result) + } + return } var ExecutionCountNum = promauto.NewCounter(prometheus.CounterOpts{ - Name: "atest_execution_count", - Help: "The total number of request execution", + Name: "atest_execution_count", + Help: "The total number of request execution", }) var ExecutionSuccessNum = promauto.NewCounter(prometheus.CounterOpts{ - Name: "atest_execution_success", - Help: "The total number of request execution success", + Name: "atest_execution_success", + Help: "The total number of request execution success", }) var ExecutionFailNum = promauto.NewCounter(prometheus.CounterOpts{ - Name: "atest_execution_fail", - Help: "The total number of request execution fail", + Name: "atest_execution_fail", + Help: "The total number of request execution fail", }) func (s *server) GetTestCaseAllHistory(ctx context.Context, in *TestCase) (result *HistoryTestCases, err error) { - var items []testing.HistoryTestCase - loader := s.getLoader(ctx) - defer loader.Close() - if items, err = loader.GetTestCaseAllHistory(in.SuiteName, in.Name); err == nil { - result = &HistoryTestCases{} - for _, item := range items { - result.Data = append(result.Data, ConvertToGRPCHistoryTestCase(item)) - } - } - return + var items []testing.HistoryTestCase + loader := s.getLoader(ctx) + defer loader.Close() + if items, err = loader.GetTestCaseAllHistory(in.SuiteName, in.Name); err == nil { + result = &HistoryTestCases{} + for _, item := range items { + result.Data = append(result.Data, ConvertToGRPCHistoryTestCase(item)) + } + } + return } func (s *server) RunTestCase(ctx context.Context, in *TestCaseIdentity) (result *TestCaseResult, err error) { - var targetTestSuite testing.TestSuite - ExecutionCountNum.Inc() - defer func() { - if result.Error == "" { - ExecutionSuccessNum.Inc() - } else { - ExecutionFailNum.Inc() - } - }() - - result = &TestCaseResult{} - loader := s.getLoader(ctx) - defer loader.Close() - targetTestSuite, err = loader.GetTestSuite(in.Suite, true) - if err != nil || targetTestSuite.Name == "" { - err = nil - result.Error = fmt.Sprintf("not found suite: %s", in.Suite) - return - } - - var data []byte - if data, err = yaml.Marshal(targetTestSuite); err == nil { - task := &TestTask{ - Kind: "testcaseInSuite", - Data: string(data), - CaseName: in.Testcase, - Level: "debug", - Parameters: in.Parameters, - } - - var reply *TestResult - var lastItem *TestCaseResult - if reply, err = s.Run(ctx, task); err == nil && len(reply.TestCaseResult) > 0 { - lastIndex := len(reply.TestCaseResult) - 1 - lastItem = reply.TestCaseResult[lastIndex] - - if len(lastItem.Body) > GrpcMaxRecvMsgSize { - e := "the HTTP response body exceeded the maximum message size limit received by the gRPC client" - result = &TestCaseResult{ - Output: reply.Message, - Error: e, - Body: "", - Header: lastItem.Header, - StatusCode: http.StatusOK, - } - return - } - - result = &TestCaseResult{ - Output: reply.Message, - Error: reply.Error, - Body: lastItem.Body, - Header: lastItem.Header, - StatusCode: lastItem.StatusCode, - } - } else if err != nil { - result.Error = err.Error() - } else { - result = &TestCaseResult{ - Output: reply.Message, - Error: reply.Error, - } - } - - if reply != nil { - result.Output = reply.Message - result.Error = reply.Error - } - if lastItem != nil { - result.Body = lastItem.Body - result.Header = lastItem.Header - result.StatusCode = lastItem.StatusCode - } - } - return + var targetTestSuite testing.TestSuite + ExecutionCountNum.Inc() + defer func() { + if result.Error == "" { + ExecutionSuccessNum.Inc() + } else { + ExecutionFailNum.Inc() + } + }() + + result = &TestCaseResult{} + loader := s.getLoader(ctx) + defer loader.Close() + targetTestSuite, err = loader.GetTestSuite(in.Suite, true) + if err != nil || targetTestSuite.Name == "" { + err = nil + result.Error = fmt.Sprintf("not found suite: %s", in.Suite) + return + } + + var data []byte + if data, err = yaml.Marshal(targetTestSuite); err == nil { + task := &TestTask{ + Kind: "testcaseInSuite", + Data: string(data), + CaseName: in.Testcase, + Level: "debug", + Parameters: in.Parameters, + } + + var reply *TestResult + var lastItem *TestCaseResult + if reply, err = s.Run(ctx, task); err == nil && len(reply.TestCaseResult) > 0 { + lastIndex := len(reply.TestCaseResult) - 1 + lastItem = reply.TestCaseResult[lastIndex] + + if len(lastItem.Body) > GrpcMaxRecvMsgSize { + e := "the HTTP response body exceeded the maximum message size limit received by the gRPC client" + result = &TestCaseResult{ + Output: reply.Message, + Error: e, + Body: "", + Header: lastItem.Header, + StatusCode: http.StatusOK, + } + return + } + + result = &TestCaseResult{ + Output: reply.Message, + Error: reply.Error, + Body: lastItem.Body, + Header: lastItem.Header, + StatusCode: lastItem.StatusCode, + } + } else if err != nil { + result.Error = err.Error() + } else { + result = &TestCaseResult{ + Output: reply.Message, + Error: reply.Error, + } + } + + if reply != nil { + result.Output = reply.Message + result.Error = reply.Error + } + if lastItem != nil { + result.Body = lastItem.Body + result.Header = lastItem.Header + result.StatusCode = lastItem.StatusCode + } + } + return } func mapInterToPair(data map[string]interface{}) (pairs []*Pair) { - pairs = make([]*Pair, 0) - for k, v := range data { - pairs = append(pairs, &Pair{ - Key: k, - Value: fmt.Sprintf("%v", v), - }) - } - return + pairs = make([]*Pair, 0) + for k, v := range data { + pairs = append(pairs, &Pair{ + Key: k, + Value: fmt.Sprintf("%v", v), + }) + } + return } func mapToPair(data map[string]string) (pairs []*Pair) { - pairs = make([]*Pair, 0) - for k, v := range data { - pairs = append(pairs, &Pair{ - Key: k, - Value: v, - }) - } - return + pairs = make([]*Pair, 0) + for k, v := range data { + pairs = append(pairs, &Pair{ + Key: k, + Value: v, + }) + } + return } func pairToInterMap(pairs []*Pair) (data map[string]interface{}) { - data = make(map[string]interface{}) - for _, pair := range pairs { - if pair.Key == "" { - continue - } - data[pair.Key] = pair.Value - } - return + data = make(map[string]interface{}) + for _, pair := range pairs { + if pair.Key == "" { + continue + } + data[pair.Key] = pair.Value + } + return } func pairToMap(pairs []*Pair) (data map[string]string) { - data = make(map[string]string) - for _, pair := range pairs { - if pair.Key == "" { - continue - } - data[pair.Key] = pair.Value - } - return + data = make(map[string]string) + for _, pair := range pairs { + if pair.Key == "" { + continue + } + data[pair.Key] = pair.Value + } + return } func convertConditionalVerify(verify []*ConditionalVerify) (result []testing.ConditionalVerify) { - if verify != nil { - result = make([]testing.ConditionalVerify, 0) + if verify != nil { + result = make([]testing.ConditionalVerify, 0) - for _, item := range verify { - result = append(result, testing.ConditionalVerify{ - Condition: item.Condition, - Verify: item.Verify, - }) - } - } - return + for _, item := range verify { + result = append(result, testing.ConditionalVerify{ + Condition: item.Condition, + Verify: item.Verify, + }) + } + } + return } func (s *server) CreateTestCase(ctx context.Context, in *TestCaseWithSuite) (reply *HelloReply, err error) { - reply = &HelloReply{} - if in.Data == nil { - err = errors.New("data is required") - } else { - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.CreateTestCase(in.SuiteName, ToNormalTestCase(in.Data)) - } - return + reply = &HelloReply{} + if in.Data == nil { + err = errors.New("data is required") + } else { + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.CreateTestCase(in.SuiteName, ToNormalTestCase(in.Data)) + } + return } func (s *server) UpdateTestCase(ctx context.Context, in *TestCaseWithSuite) (reply *HelloReply, err error) { - reply = &HelloReply{} - if in.Data == nil { - err = errors.New("data is required") - return - } - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.UpdateTestCase(in.SuiteName, ToNormalTestCase(in.Data)) - return + reply = &HelloReply{} + if in.Data == nil { + err = errors.New("data is required") + return + } + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.UpdateTestCase(in.SuiteName, ToNormalTestCase(in.Data)) + return } func (s *server) DeleteTestCase(ctx context.Context, in *TestCaseIdentity) (reply *HelloReply, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &HelloReply{} - err = loader.DeleteTestCase(in.Suite, in.Testcase) - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &HelloReply{} + err = loader.DeleteTestCase(in.Suite, in.Testcase) + return } func (s *server) DeleteHistoryTestCase(ctx context.Context, in *HistoryTestCase) (reply *HelloReply, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &HelloReply{} - err = loader.DeleteHistoryTestCase(in.ID) - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &HelloReply{} + err = loader.DeleteHistoryTestCase(in.ID) + return } func (s *server) DeleteAllHistoryTestCase(ctx context.Context, in *HistoryTestCase) (reply *HelloReply, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &HelloReply{} - err = loader.DeleteAllHistoryTestCase(in.SuiteName, in.CaseName) - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &HelloReply{} + err = loader.DeleteAllHistoryTestCase(in.SuiteName, in.CaseName) + return } func (s *server) DuplicateTestCase(ctx context.Context, in *TestCaseDuplicate) (reply *HelloReply, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + reply = &HelloReply{} - if in.SourceCaseName == in.TargetCaseName { - reply.Error = "source and target case name should be different" - return - } + if in.SourceCaseName == in.TargetCaseName { + reply.Error = "source and target case name should be different" + return + } - var testcase testing.TestCase - if testcase, err = loader.GetTestCase(in.SourceSuiteName, in.SourceCaseName); err == nil { - testcase.Name = in.TargetCaseName - err = loader.CreateTestCase(in.TargetSuiteName, testcase) - } - return + var testcase testing.TestCase + if testcase, err = loader.GetTestCase(in.SourceSuiteName, in.SourceCaseName); err == nil { + testcase.Name = in.TargetCaseName + err = loader.CreateTestCase(in.TargetSuiteName, testcase) + } + return } func (s *server) RenameTestCase(ctx context.Context, in *TestCaseDuplicate) (result *HelloReply, err error) { - result = &HelloReply{} - loader := s.getLoader(ctx) - defer loader.Close() - err = loader.RenameTestCase(in.SourceSuiteName, in.SourceCaseName, in.TargetCaseName) - return + result = &HelloReply{} + loader := s.getLoader(ctx) + defer loader.Close() + err = loader.RenameTestCase(in.SourceSuiteName, in.SourceCaseName, in.TargetCaseName) + return } // code generator func (s *server) ListCodeGenerator(ctx context.Context, in *Empty) (reply *SimpleList, err error) { - reply = &SimpleList{} + reply = &SimpleList{} - generators := generator.GetCodeGenerators() - for name := range generators { - reply.Data = append(reply.Data, &Pair{ - Key: name, - }) - } - return + generators := generator.GetCodeGenerators() + for name := range generators { + reply.Data = append(reply.Data, &Pair{ + Key: name, + }) + } + return } func (s *server) GenerateCode(ctx context.Context, in *CodeGenerateRequest) (reply *CommonResult, err error) { - reply = &CommonResult{} - instance := generator.GetCodeGenerator(in.Generator) - if instance == nil { - reply.Success = false - reply.Message = fmt.Sprintf("generator '%s' not found", in.Generator) - } else { - var result testing.TestCase - var suite testing.TestSuite - - loader := s.getLoader(ctx) - if suite, err = loader.GetTestSuite(in.TestSuite, true); err != nil { - return - } - - dataContext := map[string]interface{}{} - if err = suite.Render(dataContext); err != nil { - return - } - - var output string - var genErr error - if in.TestCase == "" { - output, genErr = instance.Generate(&suite, nil) - } else { - if result, err = loader.GetTestCase(in.TestSuite, in.TestCase); err == nil { - result.Request.RenderAPI(suite.API) - - output, genErr = instance.Generate(&suite, &result) - } - } - reply.Success = genErr == nil - reply.Message = util.OrErrorMessage(genErr, output) - } - return + reply = &CommonResult{} + instance := generator.GetCodeGenerator(in.Generator) + if instance == nil { + reply.Success = false + reply.Message = fmt.Sprintf("generator '%s' not found", in.Generator) + } else { + var result testing.TestCase + var suite testing.TestSuite + + loader := s.getLoader(ctx) + if suite, err = loader.GetTestSuite(in.TestSuite, true); err != nil { + return + } + + dataContext := map[string]interface{}{} + if err = suite.Render(dataContext); err != nil { + return + } + + var output string + var genErr error + if in.TestCase == "" { + output, genErr = instance.Generate(&suite, nil) + } else { + if result, err = loader.GetTestCase(in.TestSuite, in.TestCase); err == nil { + result.Request.RenderAPI(suite.API) + + output, genErr = instance.Generate(&suite, &result) + } + } + reply.Success = genErr == nil + reply.Message = util.OrErrorMessage(genErr, output) + } + return } func (s *server) HistoryGenerateCode(ctx context.Context, in *CodeGenerateRequest) (reply *CommonResult, err error) { - reply = &CommonResult{} - instance := generator.GetCodeGenerator(in.Generator) - if instance == nil { - reply.Success = false - reply.Message = fmt.Sprintf("generator '%s' not found", in.Generator) - } else { - loader := s.getLoader(ctx) - var result testing.HistoryTestCase - result, err = loader.GetHistoryTestCase(in.ID) - var testCase testing.TestCase - var suite testing.TestSuite - testCase = result.Data - suite.Name = result.SuiteName - suite.API = result.SuiteAPI - suite.Spec = result.SuiteSpec - suite.Param = result.SuiteParam - - output, genErr := instance.Generate(&suite, &testCase) - reply.Success = genErr == nil - reply.Message = util.OrErrorMessage(genErr, output) - } - return + reply = &CommonResult{} + instance := generator.GetCodeGenerator(in.Generator) + if instance == nil { + reply.Success = false + reply.Message = fmt.Sprintf("generator '%s' not found", in.Generator) + } else { + loader := s.getLoader(ctx) + var result testing.HistoryTestCase + result, err = loader.GetHistoryTestCase(in.ID) + var testCase testing.TestCase + var suite testing.TestSuite + testCase = result.Data + suite.Name = result.SuiteName + suite.API = result.SuiteAPI + suite.Spec = result.SuiteSpec + suite.Param = result.SuiteParam + + output, genErr := instance.Generate(&suite, &testCase) + reply.Success = genErr == nil + reply.Message = util.OrErrorMessage(genErr, output) + } + return } // converter func (s *server) ListConverter(ctx context.Context, in *Empty) (reply *SimpleList, err error) { - reply = &SimpleList{} - converters := generator.GetTestSuiteConverters() - for name := range converters { - reply.Data = append(reply.Data, &Pair{ - Key: name, - }) - } - return + reply = &SimpleList{} + converters := generator.GetTestSuiteConverters() + for name := range converters { + reply.Data = append(reply.Data, &Pair{ + Key: name, + }) + } + return } func (s *server) ConvertTestSuite(ctx context.Context, in *CodeGenerateRequest) (reply *CommonResult, err error) { - reply = &CommonResult{} - - instance := generator.GetTestSuiteConverter(in.Generator) - if instance == nil { - reply.Success = false - reply.Message = fmt.Sprintf("converter '%s' not found", in.Generator) - } else { - var result testing.TestSuite - loader := s.getLoader(ctx) - defer loader.Close() - if result, err = loader.GetTestSuite(in.TestSuite, true); err == nil { - output, genErr := instance.Convert(&result) - reply.Success = genErr == nil - reply.Message = util.OrErrorMessage(genErr, output) - } - } - return + reply = &CommonResult{} + + instance := generator.GetTestSuiteConverter(in.Generator) + if instance == nil { + reply.Success = false + reply.Message = fmt.Sprintf("converter '%s' not found", in.Generator) + } else { + var result testing.TestSuite + loader := s.getLoader(ctx) + defer loader.Close() + if result, err = loader.GetTestSuite(in.TestSuite, true); err == nil { + output, genErr := instance.Convert(&result) + reply.Success = genErr == nil + reply.Message = util.OrErrorMessage(genErr, output) + } + } + return } // Sample returns a sample of the test task func (s *server) Sample(ctx context.Context, in *Empty) (reply *HelloReply, err error) { - reply = &HelloReply{Message: sample.TestSuiteGitLab} - return + reply = &HelloReply{Message: sample.TestSuiteGitLab} + return } // PopularHeaders returns a list of popular headers func (s *server) PopularHeaders(ctx context.Context, in *Empty) (pairs *Pairs, err error) { - pairs = &Pairs{ - Data: []*Pair{}, - } + pairs = &Pairs{ + Data: []*Pair{}, + } - err = yaml.Unmarshal(popularHeaders, &pairs.Data) - return + err = yaml.Unmarshal(popularHeaders, &pairs.Data) + return } // GetSuggestedAPIs returns a list of suggested APIs func (s *server) GetSuggestedAPIs(ctx context.Context, in *TestSuiteIdentity) (reply *TestCases, err error) { - reply = &TestCases{} + reply = &TestCases{} - var suite *testing.TestSuite - loader := s.getLoader(ctx) - defer loader.Close() - if suite, _, err = loader.GetSuite(in.Name); err != nil || suite == nil { - return - } + var suite *testing.TestSuite + loader := s.getLoader(ctx) + defer loader.Close() + if suite, _, err = loader.GetSuite(in.Name); err != nil || suite == nil { + return + } - remoteServerLogger.Info("Finding APIs from", "name", in.Name, "with loader", reflect.TypeOf(loader)) + remoteServerLogger.Info("Finding APIs from", "name", in.Name, "with loader", reflect.TypeOf(loader)) - suiteRunner := runner.GetTestSuiteRunner(suite) - var result []*testing.TestCase - if result, err = suiteRunner.GetSuggestedAPIs(suite, in.Api); err == nil && result != nil { - for i := range result { - reply.Data = append(reply.Data, ToGRPCTestCase(*result[i])) - } - } - return + suiteRunner := runner.GetTestSuiteRunner(suite) + var result []*testing.TestCase + if result, err = suiteRunner.GetSuggestedAPIs(suite, in.Api); err == nil && result != nil { + for i := range result { + reply.Data = append(reply.Data, ToGRPCTestCase(*result[i])) + } + } + return } // FunctionsQuery returns a list of functions func (s *server) FunctionsQuery(ctx context.Context, in *SimpleQuery) (reply *Pairs, err error) { - reply = &Pairs{} - in.Name = strings.ToLower(in.Name) - - if in.Kind == "verify" { - for _, fn := range builtin.Builtins { - lowerName := strings.ToLower(fn.Name) - if in.Name == "" || strings.Contains(lowerName, in.Name) { - reply.Data = append(reply.Data, &Pair{ - Key: fn.Name, - Value: fmt.Sprintf("%v", reflect.TypeOf(fn.Func)), - }) - } - } - } else { - for name, fn := range render.FuncMap() { - lowerName := strings.ToLower(name) - if in.Name == "" || strings.Contains(lowerName, in.Name) { - reply.Data = append(reply.Data, &Pair{ - Key: name, - Value: fmt.Sprintf("%v", reflect.TypeOf(fn)), - Description: render.FuncUsage(name), - }) - } - } - } - return + reply = &Pairs{} + in.Name = strings.ToLower(in.Name) + + if in.Kind == "verify" { + for _, fn := range builtin.Builtins { + lowerName := strings.ToLower(fn.Name) + if in.Name == "" || strings.Contains(lowerName, in.Name) { + reply.Data = append(reply.Data, &Pair{ + Key: fn.Name, + Value: fmt.Sprintf("%v", reflect.TypeOf(fn.Func)), + }) + } + } + } else { + for name, fn := range render.FuncMap() { + lowerName := strings.ToLower(name) + if in.Name == "" || strings.Contains(lowerName, in.Name) { + reply.Data = append(reply.Data, &Pair{ + Key: name, + Value: fmt.Sprintf("%v", reflect.TypeOf(fn)), + Description: render.FuncUsage(name), + }) + } + } + } + return } // FunctionsQueryStream works like FunctionsQuery but is implemented in bidirectional streaming func (s *server) FunctionsQueryStream(srv Runner_FunctionsQueryStreamServer) error { - ctx := srv.Context() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - in, err := srv.Recv() - if err != nil { - if err == io.EOF { - return nil - } - return err - } - reply := &Pairs{} - in.Name = strings.ToLower(in.Name) - - for name, fn := range render.FuncMap() { - lowerCaseName := strings.ToLower(name) - if in.Name == "" || strings.Contains(lowerCaseName, in.Name) { - reply.Data = append(reply.Data, &Pair{ - Key: name, - Value: fmt.Sprintf("%v", reflect.TypeOf(fn)), - }) - } - } - if err := srv.Send(reply); err != nil { - return err - } - } - } + ctx := srv.Context() + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + in, err := srv.Recv() + if err != nil { + if err == io.EOF { + return nil + } + return err + } + reply := &Pairs{} + in.Name = strings.ToLower(in.Name) + + for name, fn := range render.FuncMap() { + lowerCaseName := strings.ToLower(name) + if in.Name == "" || strings.Contains(lowerCaseName, in.Name) { + reply.Data = append(reply.Data, &Pair{ + Key: name, + Value: fmt.Sprintf("%v", reflect.TypeOf(fn)), + }) + } + } + if err := srv.Send(reply); err != nil { + return err + } + } + } } func (s *server) GetStoreKinds(context.Context, *Empty) (kinds *StoreKinds, err error) { - storeFactory := testing.NewStoreFactory(s.configDir) - var stores []testing.StoreKind - if stores, err = storeFactory.GetStoreKinds(); err == nil { - kinds = &StoreKinds{} - for _, store := range stores { - kinds.Data = append(kinds.Data, &StoreKind{ - Name: store.Name, - Enabled: store.Enabled, - Url: store.URL, - }) - } - } - return + storeFactory := testing.NewStoreFactory(s.configDir) + var stores []testing.StoreKind + if stores, err = storeFactory.GetStoreKinds(); err == nil { + kinds = &StoreKinds{} + for _, store := range stores { + kinds.Data = append(kinds.Data, &StoreKind{ + Name: store.Name, + Enabled: store.Enabled, + Url: store.URL, + }) + } + } + return } func (s *server) GetStores(ctx context.Context, in *Empty) (reply *Stores, err error) { - user := oauth.GetUserFromContext(ctx) - storeFactory := testing.NewStoreFactory(s.configDir) - var stores []testing.Store - var owner string - if user != nil { - owner = user.Name - } - if stores, err = storeFactory.GetStoresByOwner(owner); err == nil { - reply = &Stores{ - Data: make([]*Store, 0), - } - wg := sync.WaitGroup{} - mu := sync.Mutex{} - for _, item := range stores { - wg.Add(1) - go func() { - defer wg.Done() - - grpcStore := ToGRPCStore(item) - if item.Disabled { - return - } - - storeStatus, sErr := s.VerifyStore(ctx, &SimpleQuery{Name: item.Name}) - grpcStore.Ready = sErr == nil && storeStatus.Ready - grpcStore.ReadOnly = storeStatus.ReadOnly - grpcStore.Password = util.PasswordPlaceholder - - mu.Lock() - reply.Data = append(reply.Data, grpcStore) - mu.Unlock() - }() - } - wg.Wait() - slices.SortFunc(reply.Data, func(a, b *Store) int { - return strings.Compare(a.Name, b.Name) - }) - reply.Data = append(reply.Data, &Store{ - Name: "local", - Kind: &StoreKind{}, - Ready: true, - }) - } - return + user := oauth.GetUserFromContext(ctx) + storeFactory := testing.NewStoreFactory(s.configDir) + var stores []testing.Store + var owner string + if user != nil { + owner = user.Name + } + if stores, err = storeFactory.GetStoresByOwner(owner); err == nil { + reply = &Stores{ + Data: make([]*Store, 0), + } + wg := sync.WaitGroup{} + mu := sync.Mutex{} + for _, item := range stores { + wg.Add(1) + go func() { + defer wg.Done() + + grpcStore := ToGRPCStore(item) + if item.Disabled { + return + } + + storeStatus, sErr := s.VerifyStore(ctx, &SimpleQuery{Name: item.Name}) + grpcStore.Ready = sErr == nil && storeStatus.Ready + grpcStore.ReadOnly = storeStatus.ReadOnly + grpcStore.Password = util.PasswordPlaceholder + + mu.Lock() + reply.Data = append(reply.Data, grpcStore) + mu.Unlock() + }() + } + wg.Wait() + slices.SortFunc(reply.Data, func(a, b *Store) int { + return strings.Compare(a.Name, b.Name) + }) + reply.Data = append(reply.Data, &Store{ + Name: "local", + Kind: &StoreKind{}, + Ready: true, + }) + } + return } func (s *server) CreateStore(ctx context.Context, in *Store) (reply *Store, err error) { - reply = &Store{} - user := oauth.GetUserFromContext(ctx) - if user != nil { - in.Owner = user.Name - } + reply = &Store{} + user := oauth.GetUserFromContext(ctx) + if user != nil { + in.Owner = user.Name + } - storeFactory := testing.NewStoreFactory(s.configDir) - store := ToNormalStore(in) + storeFactory := testing.NewStoreFactory(s.configDir) + store := ToNormalStore(in) - if store.Kind.URL == "" { - store.Kind.URL = fmt.Sprintf("unix://%s", home.GetExtensionSocketPath(store.Kind.Name)) - } + if store.Kind.URL == "" { + store.Kind.URL = fmt.Sprintf("unix://%s", home.GetExtensionSocketPath(store.Kind.Name)) + } - if err = storeFactory.CreateStore(store); err == nil && s.storeExtMgr != nil { - err = s.storeExtMgr.Start(store.Kind.Name, store.Kind.URL) - } - return + if err = storeFactory.CreateStore(store); err == nil && s.storeExtMgr != nil { + err = s.storeExtMgr.Start(store.Kind.Name, store.Kind.URL) + } + return } func (s *server) UpdateStore(ctx context.Context, in *Store) (reply *Store, err error) { - reply = &Store{} - storeFactory := testing.NewStoreFactory(s.configDir) - store := ToNormalStore(in) - if err = storeFactory.UpdateStore(store); err == nil && s.storeExtMgr != nil { - // TODO need to restart extension if config was changed - err = s.storeExtMgr.Start(store.Kind.Name, store.Kind.URL) - } - return + reply = &Store{} + storeFactory := testing.NewStoreFactory(s.configDir) + store := ToNormalStore(in) + if err = storeFactory.UpdateStore(store); err == nil && s.storeExtMgr != nil { + // TODO need to restart extension if config was changed + err = s.storeExtMgr.Start(store.Kind.Name, store.Kind.URL) + } + return } func (s *server) DeleteStore(ctx context.Context, in *Store) (reply *Store, err error) { - reply = &Store{} - storeFactory := testing.NewStoreFactory(s.configDir) - err = storeFactory.DeleteStore(in.Name) - return + reply = &Store{} + storeFactory := testing.NewStoreFactory(s.configDir) + err = storeFactory.DeleteStore(in.Name) + return } func (s *server) VerifyStore(ctx context.Context, in *SimpleQuery) (reply *ExtensionStatus, err error) { - reply = &ExtensionStatus{} - var loader testing.Writer - if loader, err = s.getLoaderByStoreName(in.Name); err == nil && loader != nil { - readOnly, verifyErr := loader.Verify() - reply.Ready = verifyErr == nil - reply.ReadOnly = readOnly - reply.Message = util.OKOrErrorMessage(verifyErr) - } - return + reply = &ExtensionStatus{} + var loader testing.Writer + if loader, err = s.getLoaderByStoreName(in.Name); err == nil && loader != nil { + readOnly, verifyErr := loader.Verify() + reply.Ready = verifyErr == nil + reply.ReadOnly = readOnly + reply.Message = util.OKOrErrorMessage(verifyErr) + } + return } // secret related interfaces func (s *server) GetSecrets(ctx context.Context, in *Empty) (reply *Secrets, err error) { - return s.secretServer.GetSecrets(ctx, in) + return s.secretServer.GetSecrets(ctx, in) } func (s *server) CreateSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - return s.secretServer.CreateSecret(ctx, in) + return s.secretServer.CreateSecret(ctx, in) } func (s *server) DeleteSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - return s.secretServer.DeleteSecret(ctx, in) + return s.secretServer.DeleteSecret(ctx, in) } func (s *server) UpdateSecret(ctx context.Context, in *Secret) (reply *CommonResult, err error) { - return s.secretServer.UpdateSecret(ctx, in) + return s.secretServer.UpdateSecret(ctx, in) } func (s *server) PProf(ctx context.Context, in *PProfRequest) (reply *PProfData, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - reply = &PProfData{ - Data: loader.PProf(in.Name), - } - return + loader := s.getLoader(ctx) + defer loader.Close() + reply = &PProfData{ + Data: loader.PProf(in.Name), + } + return } func (s *server) Query(ctx context.Context, query *DataQuery) (result *DataQueryResult, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - - // render the SQL query - var sql string - if sql, err = render.Render("sql render", query.Sql, nil); err != nil { - return nil, fmt.Errorf("failed to render SQL query: %w", err) - } - - var dataResult testing.DataResult - if dataResult, err = loader.Query(map[string]string{ - "sql": sql, - "key": query.Key, - "offset": fmt.Sprintf("%d", query.Offset), - "limit": fmt.Sprintf("%d", query.Limit), - }); err == nil { - result = &DataQueryResult{ - Data: mapToPair(dataResult.Pairs), - } - for _, item := range dataResult.Rows { - result.Items = append(result.Items, &Pairs{ - Data: mapToPair(item), - }) - } - result.Meta = &DataMeta{ - CurrentDatabase: dataResult.CurrentDatabase, - Databases: dataResult.Databases, - Tables: dataResult.Tables, - Duration: dataResult.Duration, - Labels: mapToPair(dataResult.Labels), - } - } - return + loader := s.getLoader(ctx) + defer loader.Close() + + // render the SQL query + var sql string + if sql, err = render.Render("sql render", query.Sql, nil); err != nil { + return nil, fmt.Errorf("failed to render SQL query: %w", err) + } + + var dataResult testing.DataResult + if dataResult, err = loader.Query(map[string]string{ + "sql": sql, + "key": query.Key, + "offset": fmt.Sprintf("%d", query.Offset), + "limit": fmt.Sprintf("%d", query.Limit), + }); err == nil { + result = &DataQueryResult{ + Data: mapToPair(dataResult.Pairs), + } + for _, item := range dataResult.Rows { + result.Items = append(result.Items, &Pairs{ + Data: mapToPair(item), + }) + } + result.Meta = &DataMeta{ + CurrentDatabase: dataResult.CurrentDatabase, + Databases: dataResult.Databases, + Tables: dataResult.Tables, + Duration: dataResult.Duration, + Labels: mapToPair(dataResult.Labels), + } + } + return } func (s *server) GetThemes(ctx context.Context, _ *Empty) (result *SimpleList, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - - result = &SimpleList{} - var themes []string - if themes, err = loader.GetThemes(); err == nil { - for _, theme := range themes { - result.Data = append(result.Data, &Pair{ - Key: theme, - Value: "", - }) - } - } - return + loader := s.getLoader(ctx) + defer loader.Close() + + result = &SimpleList{} + var themes []string + if themes, err = loader.GetThemes(); err == nil { + for _, theme := range themes { + result.Data = append(result.Data, &Pair{ + Key: theme, + Value: "", + }) + } + } + return } func (s *server) GetTheme(ctx context.Context, in *SimpleName) (result *CommonResult, err error) { - loader := s.getLoader(ctx) - defer loader.Close() + loader := s.getLoader(ctx) + defer loader.Close() - result = &CommonResult{} - result.Message, err = loader.GetTheme(in.Name) - if err != nil { - result.Message = fmt.Sprintf("failed to get theme: %v", err) - } - return + result = &CommonResult{} + result.Message, err = loader.GetTheme(in.Name) + if err != nil { + result.Message = fmt.Sprintf("failed to get theme: %v", err) + } + return } func (s *server) GetBindings(ctx context.Context, _ *Empty) (result *SimpleList, err error) { - loader := s.getLoader(ctx) - defer loader.Close() - - result = &SimpleList{} - var bindings []string - if bindings, err = loader.GetBindings(); err == nil { - for _, theme := range bindings { - result.Data = append(result.Data, &Pair{ - Key: theme, - Value: "", - }) - } - } - return + loader := s.getLoader(ctx) + defer loader.Close() + + result = &SimpleList{} + var bindings []string + if bindings, err = loader.GetBindings(); err == nil { + for _, theme := range bindings { + result.Data = append(result.Data, &Pair{ + Key: theme, + Value: "", + }) + } + } + return } func (s *server) GetBinding(ctx context.Context, in *SimpleName) (result *CommonResult, err error) { - loader := s.getLoader(ctx) - defer loader.Close() + loader := s.getLoader(ctx) + defer loader.Close() - result = &CommonResult{} - result.Message, err = loader.GetBinding(in.Name) - if err != nil { - result.Message = fmt.Sprintf("failed to get binding: %v", err) - } - return + result = &CommonResult{} + result.Message, err = loader.GetBinding(in.Name) + if err != nil { + result.Message = fmt.Sprintf("failed to get binding: %v", err) + } + return } // implement the mock server // Start starts the mock server type mockServerController struct { - UnimplementedMockServer - mockWriter mock.ReaderAndWriter - loader mock.Loadable - reader mock.Reader - logData chan string - prefix string - combinePort int + UnimplementedMockServer + mockWriter mock.ReaderAndWriter + loader mock.Loadable + reader mock.Reader + logData chan string + prefix string + combinePort int } func NewMockServerController(mockWriter mock.ReaderAndWriter, loader mock.Loadable, combinePort int) MockServer { - return &mockServerController{ - mockWriter: mockWriter, - loader: loader, - prefix: "/mock/server", - logData: make(chan string, 100), - combinePort: combinePort, - } + return &mockServerController{ + mockWriter: mockWriter, + loader: loader, + prefix: "/mock/server", + logData: make(chan string, 100), + combinePort: combinePort, + } } func (s *mockServerController) Reload(ctx context.Context, in *MockConfig) (reply *Empty, err error) { - s.mockWriter.Write([]byte(in.Config)) - s.prefix = in.Prefix - if dServer, ok := s.loader.(mock.DynamicServer); ok && dServer.GetPort() != strconv.Itoa(int(in.GetPort())) { - if strconv.Itoa(s.combinePort) != dServer.GetPort() { - if stopErr := dServer.Stop(); stopErr != nil { - remoteServerLogger.Info("failed to stop old server", "error", stopErr) - } else { - remoteServerLogger.Info("old server stopped", "port", dServer.GetPort()) - } - } - - server := mock.NewInMemoryServer(ctx, int(in.GetPort())).WithTLS(dServer.GetTLS()) - if err = server.Start(s.mockWriter, in.Prefix); err != nil { - return - } - server.WithLogWriter(s) - s.loader = server - } - err = s.loader.Load() - return + s.mockWriter.Write([]byte(in.Config)) + s.prefix = in.Prefix + if dServer, ok := s.loader.(mock.DynamicServer); ok && dServer.GetPort() != strconv.Itoa(int(in.GetPort())) { + if strconv.Itoa(s.combinePort) != dServer.GetPort() { + if stopErr := dServer.Stop(); stopErr != nil { + remoteServerLogger.Info("failed to stop old server", "error", stopErr) + } else { + remoteServerLogger.Info("old server stopped", "port", dServer.GetPort()) + } + } + + server := mock.NewInMemoryServer(ctx, int(in.GetPort())).WithTLS(dServer.GetTLS()) + if err = server.Start(s.mockWriter, in.Prefix); err != nil { + return + } + server.WithLogWriter(s) + s.loader = server + } + err = s.loader.Load() + return } func (s *mockServerController) GetConfig(ctx context.Context, in *Empty) (reply *MockConfig, err error) { - reply = &MockConfig{ - Prefix: s.prefix, - Config: string(s.mockWriter.GetData()), - } - if dServer, ok := s.loader.(mock.DynamicServer); ok { - if port, pErr := strconv.ParseInt(dServer.GetPort(), 10, 32); pErr == nil { - reply.Port = int32(port) - } - } - return + reply = &MockConfig{ + Prefix: s.prefix, + Config: string(s.mockWriter.GetData()), + } + if dServer, ok := s.loader.(mock.DynamicServer); ok { + if port, pErr := strconv.ParseInt(dServer.GetPort(), 10, 32); pErr == nil { + reply.Port = int32(port) + } + } + return } func (s *mockServerController) LogWatch(e *Empty, logServer Mock_LogWatchServer) (err error) { - logServer.Send(&CommonResult{ - Success: true, - Message: "Mock server log watch started\n", - }) - for msg := range s.logData { - logServer.Send(&CommonResult{ - Success: true, - Message: msg, - }) - } - return + logServer.Send(&CommonResult{ + Success: true, + Message: "Mock server log watch started\n", + }) + for msg := range s.logData { + logServer.Send(&CommonResult{ + Success: true, + Message: msg, + }) + } + return } func (s *mockServerController) Write(p []byte) (n int, err error) { - select { - case s.logData <- fmt.Sprintf("%s: %s", time.Now().Format(time.RFC3339), string(p)): - default: - } - return + select { + case s.logData <- fmt.Sprintf("%s: %s", time.Now().Format(time.RFC3339), string(p)): + default: + } + return } func (s *server) getLoaderByStoreName(storeName string) (loader testing.Writer, err error) { - var store *testing.Store - store, err = testing.NewStoreFactory(s.configDir).GetStore(storeName) - if err == nil && store != nil { - loader, err = s.storeWriterFactory.NewInstance(*store) - if err != nil { - err = fmt.Errorf("failed to new grpc loader from store %s, err: %v", store.Name, err) - } - } else { - err = fmt.Errorf("failed to get store %s, err: %v", storeName, err) - } - return + var store *testing.Store + store, err = testing.NewStoreFactory(s.configDir).GetStore(storeName) + if err == nil && store != nil { + loader, err = s.storeWriterFactory.NewInstance(*store) + if err != nil { + err = fmt.Errorf("failed to new grpc loader from store %s, err: %v", store.Name, err) + } + } else { + err = fmt.Errorf("failed to get store %s, err: %v", storeName, err) + } + return } //go:embed data/headers.yaml var popularHeaders []byte func findParentTestCases(testcase *testing.TestCase, suite *testing.TestSuite) (testcases []testing.TestCase) { - reg, matchErr := regexp.Compile(`(.*?\{\{.*\.\w*.*?\}\})`) - targetReg, targetErr := regexp.Compile(`\.\w*`) - - expectNames := new(UniqueSlice[string]) - if matchErr == nil && targetErr == nil { - var expectName string - for _, val := range testcase.Request.Header { - if matched := reg.MatchString(val); matched { - expectName = targetReg.FindString(val) - expectName = strings.TrimPrefix(expectName, ".") - expectNames.Push(expectName) - } - } - - findExpectNames(testcase.Request.API, expectNames) - findExpectNames(testcase.Request.Body.String(), expectNames) - - remoteServerLogger.Info("expect test case names", "name", expectNames.GetAll()) - for _, item := range suite.Items { - if expectNames.Exist(item.Name) { - testcases = append(testcases, item) - } - } - } - return + reg, matchErr := regexp.Compile(`(.*?\{\{.*\.\w*.*?\}\})`) + targetReg, targetErr := regexp.Compile(`\.\w*`) + + expectNames := new(UniqueSlice[string]) + if matchErr == nil && targetErr == nil { + var expectName string + for _, val := range testcase.Request.Header { + if matched := reg.MatchString(val); matched { + expectName = targetReg.FindString(val) + expectName = strings.TrimPrefix(expectName, ".") + expectNames.Push(expectName) + } + } + + findExpectNames(testcase.Request.API, expectNames) + findExpectNames(testcase.Request.Body.String(), expectNames) + + remoteServerLogger.Info("expect test case names", "name", expectNames.GetAll()) + for _, item := range suite.Items { + if expectNames.Exist(item.Name) { + testcases = append(testcases, item) + } + } + } + return } func findExpectNames(target string, expectNames *UniqueSlice[string]) { - reg, _ := regexp.Compile(`(.*?\{\{.*\.\w*.*?\}\})`) - targetReg, _ := regexp.Compile(`\.\w*`) + reg, _ := regexp.Compile(`(.*?\{\{.*\.\w*.*?\}\})`) + targetReg, _ := regexp.Compile(`\.\w*`) - for _, sub := range reg.FindStringSubmatch(target) { - // remove {{ and }} - if left, leftErr := regexp.Compile(`.*\{\{`); leftErr == nil { - body := left.ReplaceAllString(sub, "") + for _, sub := range reg.FindStringSubmatch(target) { + // remove {{ and }} + if left, leftErr := regexp.Compile(`.*\{\{`); leftErr == nil { + body := left.ReplaceAllString(sub, "") - expectName := targetReg.FindString(body) - expectName = strings.TrimPrefix(expectName, ".") - expectNames.Push(expectName) - } - } + expectName := targetReg.FindString(body) + expectName = strings.TrimPrefix(expectName, ".") + expectNames.Push(expectName) + } + } } // UniqueSlice represents an unique slice type UniqueSlice[T comparable] struct { - data []T + data []T } // Push pushes an item if it's not exist func (s *UniqueSlice[T]) Push(item T) *UniqueSlice[T] { - if s.data == nil { - s.data = []T{item} - } else { - for _, it := range s.data { - if it == item { - return s - } - } - s.data = append(s.data, item) - } - return s + if s.data == nil { + s.data = []T{item} + } else { + for _, it := range s.data { + if it == item { + return s + } + } + s.data = append(s.data, item) + } + return s } // Exist checks if the item exist, return true it exists func (s *UniqueSlice[T]) Exist(item T) bool { - if s.data != nil { - for _, it := range s.data { - if it == item { - return true - } - } - } - return false + if s.data != nil { + for _, it := range s.data { + if it == item { + return true + } + } + } + return false } // GetAll returns all the items func (s *UniqueSlice[T]) GetAll() []T { - return s.data + return s.data } var errNoTestSuiteFound = errors.New("no test suite found") From ecf921e611b0236e211d948ad3b0a4b87504b819 Mon Sep 17 00:00:00 2001 From: rick Date: Mon, 4 Aug 2025 20:48:11 +0800 Subject: [PATCH 3/3] support response as a random image --- docs/site/content/zh/latest/tasks/mock.md | 14 ++++++++++++++ pkg/mock/in_memory.go | 10 ++++++++++ pkg/mock/testdata/api.yaml | 8 ++++++++ 3 files changed, 32 insertions(+) diff --git a/docs/site/content/zh/latest/tasks/mock.md b/docs/site/content/zh/latest/tasks/mock.md index a8aa4822..7952175b 100644 --- a/docs/site/content/zh/latest/tasks/mock.md +++ b/docs/site/content/zh/latest/tasks/mock.md @@ -153,6 +153,20 @@ items: bodyFromFile: /tmp/baidu.html ``` +通过下面的方式也可以生成图片: + +```yaml +items: +- name: image + request: + path: /v1/image + response: + header: + Content-Type: image/png + body: | + {{ randImage 300 300 }} +``` + #### 条件判断 对于查询类的 API,通常会接收参数,并根据参数的不同,返回相应的数据。这时候,可以用到条件判断的表达式: diff --git a/pkg/mock/in_memory.go b/pkg/mock/in_memory.go index fceed94f..0315f5e6 100644 --- a/pkg/mock/in_memory.go +++ b/pkg/mock/in_memory.go @@ -476,6 +476,16 @@ func (h *advanceHandler) handle(w http.ResponseWriter, req *http.Request) { } } + if strings.HasPrefix(h.item.Response.Header[util.ContentType], "image/") { + if strings.HasPrefix(string(h.item.Response.BodyData), util.ImageBase64Prefix) { + // decode base64 image data + imgData := strings.TrimPrefix(string(h.item.Response.BodyData), util.ImageBase64Prefix) + if h.item.Response.BodyData, err = base64.StdEncoding.DecodeString(imgData); err != nil { + memLogger.Error(err, "failed to decode base64 image data") + } + } + } + if err == nil { h.item.Response.Header[util.ContentLength] = fmt.Sprintf("%d", len(h.item.Response.BodyData)) w.Header().Set(util.ContentLength, h.item.Response.Header[util.ContentLength]) diff --git a/pkg/mock/testdata/api.yaml b/pkg/mock/testdata/api.yaml index ac9f6107..1e2978fe 100644 --- a/pkg/mock/testdata/api.yaml +++ b/pkg/mock/testdata/api.yaml @@ -56,6 +56,14 @@ items: "status": "success" }] } + - name: image + request: + path: /v1/image + response: + header: + Content-Type: image/png + body: | + {{ randImage 300 300 }} proxies: - path: /v1/myProjects target: http://localhost:{{.GetPort}}