diff --git a/Taskfile.yml b/Taskfile.yml index 8b5ee9f..0fe7ac7 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -5,9 +5,11 @@ version: "3" tasks: default: dotenv: + - .env - .env.{{.ENV}} cmds: - go run ./... {{.CLI_ARGS}} + silent: true deps: cmds: @@ -21,7 +23,7 @@ tasks: test: cmds: - - go test ./... + - go test ./... -cover vuln: cmds: diff --git a/cmd/pagination.go b/cmd/pagination.go new file mode 100644 index 0000000..4db376a --- /dev/null +++ b/cmd/pagination.go @@ -0,0 +1,20 @@ +package cmd + +import ( + "github.com/loops-so/cli/internal/api" + "github.com/spf13/cobra" +) + +func addPaginationFlags(cmd *cobra.Command) { + cmd.Flags().String("per-page", "", "Results per page (10-50, default 20)") + cmd.Flags().String("cursor", "", "Pagination cursor for a specific page") +} + +func paginationParams(cmd *cobra.Command) api.PaginationParams { + perPage, _ := cmd.Flags().GetString("per-page") + cursor, _ := cmd.Flags().GetString("cursor") + return api.PaginationParams{ + PerPage: perPage, + Cursor: cursor, + } +} diff --git a/cmd/transactional.go b/cmd/transactional.go new file mode 100644 index 0000000..791750e --- /dev/null +++ b/cmd/transactional.go @@ -0,0 +1,155 @@ +package cmd + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "os" + "path/filepath" + "strings" + "text/tabwriter" + + "github.com/loops-so/cli/internal/api" + "github.com/loops-so/cli/internal/config" + "github.com/spf13/cobra" +) + +func attachmentFromPath(path string) (api.Attachment, error) { + info, err := os.Stat(path) + if err != nil { + return api.Attachment{}, fmt.Errorf("attachment %q: %w", path, err) + } + if info.IsDir() { + return api.Attachment{}, fmt.Errorf("attachment %q: is a directory", path) + } + + data, err := os.ReadFile(path) + if err != nil { + return api.Attachment{}, fmt.Errorf("attachment %q: %w", path, err) + } + + sniff := data + if len(sniff) > 512 { + sniff = sniff[:512] + } + contentType := http.DetectContentType(sniff) + + return api.Attachment{ + Filename: filepath.Base(path), + ContentType: contentType, + Data: base64.StdEncoding.EncodeToString(data), + }, nil +} + +var transactionalCmd = &cobra.Command{ + Use: "transactional", + Short: "Manage transactional emails", +} + +var transactionalListCmd = &cobra.Command{ + Use: "list", + Short: "List published transactional emails", + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load() + if err != nil { + return err + } + + params := paginationParams(cmd) + client := api.NewClient(cfg.EndpointURL, cfg.APIKey) + + var emails []api.TransactionalEmail + if params.Cursor != "" { + emails, _, err = client.ListTransactional(params) + } else { + emails, err = api.Paginate(func(cursor string) ([]api.TransactionalEmail, *api.Pagination, error) { + return client.ListTransactional(api.PaginationParams{ + PerPage: params.PerPage, + Cursor: cursor, + }) + }) + } + if err != nil { + return err + } + + if len(emails) == 0 { + fmt.Println("No transactional emails found.") + return nil + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "ID\tNAME\tLAST UPDATED\tVARIABLES") + for _, e := range emails { + fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", e.ID, e.Name, e.LastUpdated, strings.Join(e.DataVariables, ", ")) + } + w.Flush() + + return nil + }, +} + +var transactionalSendCmd = &cobra.Command{ + Use: "send", + Short: "Send a transactional email", + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load() + if err != nil { + return err + } + + email, _ := cmd.Flags().GetString("email") + id, _ := cmd.Flags().GetString("id") + dataRaw, _ := cmd.Flags().GetString("data") + + req := api.SendTransactionalRequest{ + Email: email, + TransactionalID: id, + } + + if cmd.Flags().Changed("add-to-audience") { + v, _ := cmd.Flags().GetBool("add-to-audience") + req.AddToAudience = &v + } + + if dataRaw != "" { + if err := json.Unmarshal([]byte(dataRaw), &req.DataVariables); err != nil { + return fmt.Errorf("--data must be a valid JSON object: %w", err) + } + } + + paths, _ := cmd.Flags().GetStringArray("attachment") + for _, path := range paths { + a, err := attachmentFromPath(path) + if err != nil { + return err + } + req.Attachments = append(req.Attachments, a) + } + + client := api.NewClient(cfg.EndpointURL, cfg.APIKey) + if err := client.SendTransactional(req); err != nil { + return err + } + + fmt.Println("Sent.") + return nil + }, +} + +func init() { + addPaginationFlags(transactionalListCmd) + transactionalCmd.AddCommand(transactionalListCmd) + + transactionalSendCmd.Flags().String("email", "", "Recipient email address") + transactionalSendCmd.Flags().String("id", "", "Transactional email ID") + transactionalSendCmd.Flags().BoolP("add-to-audience", "a", false, "Create a contact if one doesn't exist") + transactionalSendCmd.Flags().String("data", "", "Data variables as a JSON object") + transactionalSendCmd.Flags().StringArrayP("attachment", "A", nil, "Path to a file to attach (repeatable)") + transactionalSendCmd.MarkFlagRequired("email") + transactionalSendCmd.MarkFlagRequired("id") + transactionalCmd.AddCommand(transactionalSendCmd) + + rootCmd.AddCommand(transactionalCmd) +} diff --git a/cmd/transactional_test.go b/cmd/transactional_test.go new file mode 100644 index 0000000..ecd7a88 --- /dev/null +++ b/cmd/transactional_test.go @@ -0,0 +1,55 @@ +package cmd + +import ( + "encoding/base64" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestAttachmentFromPath(t *testing.T) { + t.Run("valid file", func(t *testing.T) { + f, err := os.CreateTemp(t.TempDir(), "test-*.txt") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + content := []byte("hello attachment") + f.Write(content) + f.Close() + + a, err := attachmentFromPath(f.Name()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if a.Filename != filepath.Base(f.Name()) { + t.Errorf("Filename = %q, want %q", a.Filename, filepath.Base(f.Name())) + } + if !strings.HasPrefix(a.ContentType, "text/plain") { + t.Errorf("ContentType = %q, want text/plain prefix", a.ContentType) + } + if a.Data != base64.StdEncoding.EncodeToString(content) { + t.Errorf("Data = %q, want base64 of content", a.Data) + } + }) + + t.Run("file not found", func(t *testing.T) { + _, err := attachmentFromPath("/nonexistent/file.pdf") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "/nonexistent/file.pdf") { + t.Errorf("error %q should mention the path", err.Error()) + } + }) + + t.Run("directory", func(t *testing.T) { + _, err := attachmentFromPath(t.TempDir()) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "is a directory") { + t.Errorf("error %q should mention 'is a directory'", err.Error()) + } + }) +} diff --git a/internal/api/api_key.go b/internal/api/api_key.go index 6e04984..da97a15 100644 --- a/internal/api/api_key.go +++ b/internal/api/api_key.go @@ -11,7 +11,7 @@ type APIKeyResponse struct { } func (c *Client) GetAPIKey() (*APIKeyResponse, error) { - req, err := c.newRequest(http.MethodGet, "/api-key") + req, err := c.newRequest(http.MethodGet, "/api-key", nil) if err != nil { return nil, err } diff --git a/internal/api/client.go b/internal/api/client.go index 053129e..3ad92be 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -3,6 +3,7 @@ package api import ( "encoding/json" "fmt" + "io" "math/rand/v2" "net/http" "time" @@ -44,10 +45,16 @@ func NewClient(baseURL, apiKey string) *Client { func errorFromResponse(resp *http.Response) *APIError { var body struct { - Error string `json:"error"` + Error string `json:"error"` + Message string `json:"message"` } - if err := json.NewDecoder(resp.Body).Decode(&body); err == nil && body.Error != "" { - return &APIError{StatusCode: resp.StatusCode, Message: body.Error} + if err := json.NewDecoder(resp.Body).Decode(&body); err == nil { + if body.Error != "" { + return &APIError{StatusCode: resp.StatusCode, Message: body.Error} + } + if body.Message != "" { + return &APIError{StatusCode: resp.StatusCode, Message: body.Message} + } } return &APIError{StatusCode: resp.StatusCode, Message: fmt.Sprintf("unexpected status: %d", resp.StatusCode)} } @@ -83,12 +90,15 @@ func (c *Client) do(req *http.Request) (*http.Response, error) { return resp, nil } -func (c *Client) newRequest(method, path string) (*http.Request, error) { +func (c *Client) newRequest(method, path string, body io.Reader) (*http.Request, error) { url := fmt.Sprintf("%s%s", c.baseURL, path) - req, err := http.NewRequest(method, url, nil) + req, err := http.NewRequest(method, url, body) if err != nil { return nil, err } req.Header.Set("Authorization", "Bearer "+c.apiKey) + if body != nil { + req.Header.Set("Content-Type", "application/json") + } return req, nil } diff --git a/internal/api/client_test.go b/internal/api/client_test.go index d76d0d7..208f0d9 100644 --- a/internal/api/client_test.go +++ b/internal/api/client_test.go @@ -22,7 +22,7 @@ func TestNewRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req, err := client.newRequest(tt.method, tt.path) + req, err := client.newRequest(tt.method, tt.path, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -46,12 +46,73 @@ func TestNewRequest(t *testing.T) { func TestNewRequest_InvalidURL(t *testing.T) { client := NewClient("://bad-url", "test-key") - _, err := client.newRequest(http.MethodGet, "/path") + _, err := client.newRequest(http.MethodGet, "/path", nil) if err == nil { t.Error("expected error for invalid URL, got nil") } } +func TestErrorFromResponse(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantMessage string + }{ + { + name: "reads error field", + statusCode: http.StatusBadRequest, + body: `{"error":"something went wrong"}`, + wantMessage: "something went wrong", + }, + { + name: "falls back to message field", + statusCode: http.StatusBadRequest, + body: `{"message":"something went wrong"}`, + wantMessage: "something went wrong", + }, + { + name: "prefers error over message", + statusCode: http.StatusBadRequest, + body: `{"error":"error field","message":"message field"}`, + wantMessage: "error field", + }, + { + name: "falls back to generic when body is empty", + statusCode: http.StatusBadRequest, + body: ``, + wantMessage: "unexpected status: 400", + }, + { + name: "falls back to generic when fields are absent", + statusCode: http.StatusBadRequest, + body: `{"success":false}`, + wantMessage: "unexpected status: 400", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + w.Write([]byte(tt.body)) + })) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + apiErr := errorFromResponse(resp) + if apiErr.Message != tt.wantMessage { + t.Errorf("Message = %q, want %q", apiErr.Message, tt.wantMessage) + } + }) + } +} + func TestDo_Retries(t *testing.T) { origSleep := sleep sleep = func(time.Duration) {} @@ -107,7 +168,7 @@ func TestDo_Retries(t *testing.T) { defer server.Close() client := NewClient(server.URL, "test-key") - req, _ := client.newRequest(http.MethodGet, "/") + req, _ := client.newRequest(http.MethodGet, "/", nil) resp, err := client.do(req) if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/internal/api/pagination.go b/internal/api/pagination.go new file mode 100644 index 0000000..6d9bb3a --- /dev/null +++ b/internal/api/pagination.go @@ -0,0 +1,31 @@ +package api + +type Pagination struct { + TotalResults int `json:"totalResults"` + ReturnedResults int `json:"returnedResults"` + PerPage int `json:"perPage"` + TotalPages int `json:"totalPages"` + NextCursor string `json:"nextCursor"` + NextPage int `json:"nextPage"` +} + +type PaginationParams struct { + PerPage string + Cursor string +} + +func Paginate[T any](fetch func(cursor string) ([]T, *Pagination, error)) ([]T, error) { + var all []T + cursor := "" + for { + items, page, err := fetch(cursor) + if err != nil { + return nil, err + } + all = append(all, items...) + if page.NextCursor == "" { + return all, nil + } + cursor = page.NextCursor + } +} diff --git a/internal/api/pagination_test.go b/internal/api/pagination_test.go new file mode 100644 index 0000000..4ed21c2 --- /dev/null +++ b/internal/api/pagination_test.go @@ -0,0 +1,84 @@ +package api + +import ( + "errors" + "testing" +) + +func TestPaginate(t *testing.T) { + t.Run("single page", func(t *testing.T) { + calls := 0 + fetch := func(cursor string) ([]string, *Pagination, error) { + calls++ + return []string{"a", "b"}, &Pagination{NextCursor: ""}, nil + } + + items, err := Paginate(fetch) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(items) != 2 { + t.Errorf("expected 2 items, got %d", len(items)) + } + if calls != 1 { + t.Errorf("expected 1 fetch call, got %d", calls) + } + }) + + t.Run("multiple pages", func(t *testing.T) { + pages := []struct { + items []string + cursor string + }{ + {[]string{"a", "b"}, "cursor1"}, + {[]string{"c", "d"}, "cursor2"}, + {[]string{"e"}, ""}, + } + call := 0 + fetch := func(cursor string) ([]string, *Pagination, error) { + p := pages[call] + call++ + return p.items, &Pagination{NextCursor: p.cursor}, nil + } + + items, err := Paginate(fetch) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(items) != 5 { + t.Errorf("expected 5 items, got %d", len(items)) + } + if call != 3 { + t.Errorf("expected 3 fetch calls, got %d", call) + } + }) + + t.Run("error on first fetch", func(t *testing.T) { + fetchErr := errors.New("api error") + fetch := func(cursor string) ([]string, *Pagination, error) { + return nil, nil, fetchErr + } + + _, err := Paginate(fetch) + if !errors.Is(err, fetchErr) { + t.Errorf("expected fetch error, got %v", err) + } + }) + + t.Run("error mid-pagination", func(t *testing.T) { + fetchErr := errors.New("api error") + call := 0 + fetch := func(cursor string) ([]string, *Pagination, error) { + call++ + if call == 2 { + return nil, nil, fetchErr + } + return []string{"a"}, &Pagination{NextCursor: "cursor1"}, nil + } + + _, err := Paginate(fetch) + if !errors.Is(err, fetchErr) { + t.Errorf("expected fetch error, got %v", err) + } + }) +} diff --git a/internal/api/transactional.go b/internal/api/transactional.go new file mode 100644 index 0000000..fe66423 --- /dev/null +++ b/internal/api/transactional.go @@ -0,0 +1,94 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/url" +) + +type TransactionalEmail struct { + ID string `json:"id"` + Name string `json:"name"` + LastUpdated string `json:"lastUpdated"` + DataVariables []string `json:"dataVariables"` +} + +type Attachment struct { + Filename string `json:"filename"` + ContentType string `json:"contentType"` + Data string `json:"data"` +} + +type SendTransactionalRequest struct { + Email string `json:"email"` + TransactionalID string `json:"transactionalId"` + AddToAudience *bool `json:"addToAudience,omitempty"` + DataVariables map[string]any `json:"dataVariables,omitempty"` + Attachments []Attachment `json:"attachments,omitempty"` +} + +func (c *Client) SendTransactional(req SendTransactionalRequest) error { + b, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to encode request: %w", err) + } + + httpReq, err := c.newRequest(http.MethodPost, "/transactional", bytes.NewReader(b)) + if err != nil { + return err + } + + resp, err := c.do(httpReq) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errorFromResponse(resp) + } + + return nil +} + +func (c *Client) ListTransactional(params PaginationParams) ([]TransactionalEmail, *Pagination, error) { + q := url.Values{} + if params.PerPage != "" { + q.Set("perPage", params.PerPage) + } + if params.Cursor != "" { + q.Set("cursor", params.Cursor) + } + + path := "/transactional" + if len(q) > 0 { + path += "?" + q.Encode() + } + + req, err := c.newRequest(http.MethodGet, path, nil) + if err != nil { + return nil, nil, err + } + + resp, err := c.do(req) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, nil, errorFromResponse(resp) + } + + var result struct { + Pagination Pagination `json:"pagination"` + Data []TransactionalEmail `json:"data"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, nil, fmt.Errorf("failed to decode response: %w", err) + } + + return result.Data, &result.Pagination, nil +} diff --git a/internal/api/transactional_test.go b/internal/api/transactional_test.go new file mode 100644 index 0000000..d3a661c --- /dev/null +++ b/internal/api/transactional_test.go @@ -0,0 +1,352 @@ +package api + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +const listTransactionalResponse = `{ + "pagination": { + "totalResults": 2, + "returnedResults": 2, + "perPage": 20, + "totalPages": 1, + "nextCursor": "", + "nextPage": 0 + }, + "data": [ + { + "id": "abc123", + "name": "Welcome Email", + "lastUpdated": "2024-01-15T10:30:00Z", + "dataVariables": ["name", "email"] + }, + { + "id": "def456", + "name": "Password Reset", + "lastUpdated": "2024-01-14T08:15:00Z", + "dataVariables": ["resetLink"] + } + ] +}` + +func TestListTransactional(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantAPIErr *APIError + wantErrMsg string + wantCount int + }{ + { + name: "success", + statusCode: http.StatusOK, + body: listTransactionalResponse, + wantCount: 2, + }, + { + name: "empty list", + statusCode: http.StatusOK, + body: `{"pagination":{"totalResults":0},"data":[]}`, + wantCount: 0, + }, + { + name: "unauthorized", + statusCode: http.StatusUnauthorized, + body: `{"success":false,"error":"Invalid API key"}`, + wantAPIErr: &APIError{StatusCode: http.StatusUnauthorized, Message: "Invalid API key"}, + }, + { + name: "bad request", + statusCode: http.StatusBadRequest, + body: `{"success":false,"error":"Invalid perPage value"}`, + wantAPIErr: &APIError{StatusCode: http.StatusBadRequest, Message: "Invalid perPage value"}, + }, + { + name: "invalid json", + statusCode: http.StatusOK, + body: `not json`, + wantErrMsg: "failed to decode response", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + w.Write([]byte(tt.body)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-key") + emails, pagination, err := client.ListTransactional(PaginationParams{}) + + if tt.wantAPIErr != nil { + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("expected *APIError, got %T: %v", err, err) + } + if apiErr.StatusCode != tt.wantAPIErr.StatusCode { + t.Errorf("StatusCode = %d, want %d", apiErr.StatusCode, tt.wantAPIErr.StatusCode) + } + if tt.wantAPIErr.Message != "" && apiErr.Message != tt.wantAPIErr.Message { + t.Errorf("Message = %q, want %q", apiErr.Message, tt.wantAPIErr.Message) + } + return + } + + if tt.wantErrMsg != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrMsg) + } + if !strings.Contains(err.Error(), tt.wantErrMsg) { + t.Errorf("error = %q, want it to contain %q", err.Error(), tt.wantErrMsg) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(emails) != tt.wantCount { + t.Errorf("len(emails) = %d, want %d", len(emails), tt.wantCount) + } + if pagination == nil { + t.Fatal("expected pagination, got nil") + } + }) + } +} + +func TestListTransactional_ResponseData(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(listTransactionalResponse)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-key") + emails, pagination, err := client.ListTransactional(PaginationParams{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if emails[0].ID != "abc123" { + t.Errorf("ID = %q, want %q", emails[0].ID, "abc123") + } + if emails[0].Name != "Welcome Email" { + t.Errorf("Name = %q, want %q", emails[0].Name, "Welcome Email") + } + if emails[0].LastUpdated != "2024-01-15T10:30:00Z" { + t.Errorf("LastUpdated = %q, want %q", emails[0].LastUpdated, "2024-01-15T10:30:00Z") + } + if len(emails[0].DataVariables) != 2 || emails[0].DataVariables[0] != "name" { + t.Errorf("DataVariables = %v, want [name email]", emails[0].DataVariables) + } + if pagination.TotalResults != 2 { + t.Errorf("TotalResults = %d, want 2", pagination.TotalResults) + } +} + +func TestListTransactional_QueryParams(t *testing.T) { + tests := []struct { + name string + params PaginationParams + wantPerPage string + wantCursor string + }{ + { + name: "no params", + params: PaginationParams{}, + }, + { + name: "per-page only", + params: PaginationParams{PerPage: "50"}, + wantPerPage: "50", + }, + { + name: "cursor only", + params: PaginationParams{Cursor: "abc"}, + wantCursor: "abc", + }, + { + name: "both params", + params: PaginationParams{PerPage: "10", Cursor: "xyz"}, + wantPerPage: "10", + wantCursor: "xyz", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotPerPage, gotCursor string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPerPage = r.URL.Query().Get("perPage") + gotCursor = r.URL.Query().Get("cursor") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"pagination":{},"data":[]}`)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-key") + client.ListTransactional(tt.params) + + if gotPerPage != tt.wantPerPage { + t.Errorf("perPage = %q, want %q", gotPerPage, tt.wantPerPage) + } + if gotCursor != tt.wantCursor { + t.Errorf("cursor = %q, want %q", gotCursor, tt.wantCursor) + } + }) + } +} + +func TestSendTransactional(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantAPIErr *APIError + wantErrMsg string + }{ + { + name: "success", + statusCode: http.StatusOK, + body: `{"success":true}`, + }, + { + name: "unauthorized", + statusCode: http.StatusUnauthorized, + body: `{"message":"Invalid API key"}`, + wantAPIErr: &APIError{StatusCode: http.StatusUnauthorized, Message: "Invalid API key"}, + }, + { + name: "not found", + statusCode: http.StatusNotFound, + body: `{"message":"Transactional email not found"}`, + wantAPIErr: &APIError{StatusCode: http.StatusNotFound, Message: "Transactional email not found"}, + }, + { + name: "bad request", + statusCode: http.StatusBadRequest, + body: `{"message":"Recipient email is required"}`, + wantAPIErr: &APIError{StatusCode: http.StatusBadRequest, Message: "Recipient email is required"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + w.Write([]byte(tt.body)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-key") + err := client.SendTransactional(SendTransactionalRequest{ + Email: "test@example.com", + TransactionalID: "abc123", + }) + + if tt.wantAPIErr != nil { + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("expected *APIError, got %T: %v", err, err) + } + if apiErr.StatusCode != tt.wantAPIErr.StatusCode { + t.Errorf("StatusCode = %d, want %d", apiErr.StatusCode, tt.wantAPIErr.StatusCode) + } + if tt.wantAPIErr.Message != "" && apiErr.Message != tt.wantAPIErr.Message { + t.Errorf("Message = %q, want %q", apiErr.Message, tt.wantAPIErr.Message) + } + return + } + + if tt.wantErrMsg != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrMsg) + } + if !strings.Contains(err.Error(), tt.wantErrMsg) { + t.Errorf("error = %q, want it to contain %q", err.Error(), tt.wantErrMsg) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestSendTransactional_RequestBody(t *testing.T) { + addToAudience := true + tests := []struct { + name string + req SendTransactionalRequest + wantEmail string + wantID string + wantAudience *bool + wantVars map[string]any + }{ + { + name: "required fields only", + req: SendTransactionalRequest{Email: "a@b.com", TransactionalID: "abc"}, + wantEmail: "a@b.com", + wantID: "abc", + }, + { + name: "with add-to-audience", + req: SendTransactionalRequest{Email: "a@b.com", TransactionalID: "abc", AddToAudience: &addToAudience}, + wantEmail: "a@b.com", + wantID: "abc", + wantAudience: &addToAudience, + }, + { + name: "with data variables", + req: SendTransactionalRequest{Email: "a@b.com", TransactionalID: "abc", DataVariables: map[string]any{"name": "Alice"}}, + wantEmail: "a@b.com", + wantID: "abc", + wantVars: map[string]any{"name": "Alice"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got SendTransactionalRequest + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + json.Unmarshal(b, &got) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"success":true}`)) + })) + defer server.Close() + + client := NewClient(server.URL, "test-key") + client.SendTransactional(tt.req) + + if got.Email != tt.wantEmail { + t.Errorf("email = %q, want %q", got.Email, tt.wantEmail) + } + if got.TransactionalID != tt.wantID { + t.Errorf("transactionalId = %q, want %q", got.TransactionalID, tt.wantID) + } + if tt.wantAudience != nil { + if got.AddToAudience == nil || *got.AddToAudience != *tt.wantAudience { + t.Errorf("addToAudience = %v, want %v", got.AddToAudience, tt.wantAudience) + } + } + if tt.wantVars != nil { + if got.DataVariables["name"] != tt.wantVars["name"] { + t.Errorf("dataVariables = %v, want %v", got.DataVariables, tt.wantVars) + } + } + }) + } +}