Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions internal/api/api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,8 @@ func (c *Client) GetAPIKey() (*APIKeyResponse, error) {
}
defer resp.Body.Close()

if resp.StatusCode == http.StatusUnauthorized {
return nil, fmt.Errorf("invalid API key")
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
return nil, errorFromResponse(resp)
}

var result APIKeyResponse
Expand Down
55 changes: 30 additions & 25 deletions internal/api/api_key_test.go
Original file line number Diff line number Diff line change
@@ -1,42 +1,45 @@
package api

import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

func TestGetAPIKey(t *testing.T) {
tests := []struct {
name string
statusCode int
body string
wantErr string
wantTeam string
name string
statusCode int
body string
wantAPIErr *APIError
wantErrMsg string
wantTeam string
}{
{
name: "success",
statusCode: http.StatusOK,
body: `{"success":true,"teamName":"Acme"}`,
body: `{"teamName":"Acme"}`,
wantTeam: "Acme",
},
{
name: "unauthorized",
statusCode: http.StatusUnauthorized,
body: `{"success":false,"error":"Invalid API key"}`,
wantErr: "invalid API key",
wantAPIErr: &APIError{StatusCode: http.StatusUnauthorized, Message: "Invalid API key"},
},
{
name: "unexpected status",
statusCode: http.StatusInternalServerError,
body: ``,
wantErr: "unexpected status: 500",
wantAPIErr: &APIError{StatusCode: http.StatusInternalServerError},
},
{
name: "invalid json",
statusCode: http.StatusOK,
body: `not json`,
wantErr: "failed to decode response",
wantErrMsg: "failed to decode response",
},
}

Expand All @@ -51,12 +54,26 @@ func TestGetAPIKey(t *testing.T) {
client := NewClient(server.URL, "test-key")
result, err := client.GetAPIKey()

if tt.wantErr != "" {
if tt.wantAPIErr != nil {
var apiErr *APIError
if !errors.As(err, &apiErr) {
t.Fatalf("expected *APIError, got %T: %v", err, err)
}
if apiErr.StatusCode != tt.wantAPIErr.StatusCode {
t.Errorf("StatusCode = %d, want %d", apiErr.StatusCode, tt.wantAPIErr.StatusCode)
}
if tt.wantAPIErr.Message != "" && apiErr.Message != tt.wantAPIErr.Message {
t.Errorf("Message = %q, want %q", apiErr.Message, tt.wantAPIErr.Message)
}
return
}

if tt.wantErrMsg != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErr)
t.Fatalf("expected error containing %q, got nil", tt.wantErrMsg)
}
if !contains(err.Error(), tt.wantErr) {
t.Errorf("error = %q, want it to contain %q", err.Error(), tt.wantErr)
if !strings.Contains(err.Error(), tt.wantErrMsg) {
t.Errorf("error = %q, want it to contain %q", err.Error(), tt.wantErrMsg)
}
return
}
Expand All @@ -70,15 +87,3 @@ func TestGetAPIKey(t *testing.T) {
})
}
}

func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
func() bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}())
}
20 changes: 20 additions & 0 deletions internal/api/client.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
package api

import (
"encoding/json"
"fmt"
"net/http"
"time"
)

type APIError struct {
StatusCode int
Message string
}

func (e *APIError) Error() string {
return e.Message
}

type Client struct {
baseURL string
apiKey string
Expand All @@ -20,6 +30,16 @@ func NewClient(baseURL, apiKey string) *Client {
}
}

func errorFromResponse(resp *http.Response) *APIError {
var body struct {
Error string `json:"error"`
}
if err := json.NewDecoder(resp.Body).Decode(&body); err == nil && body.Error != "" {
return &APIError{StatusCode: resp.StatusCode, Message: body.Error}
}
return &APIError{StatusCode: resp.StatusCode, Message: fmt.Sprintf("unexpected status: %d", resp.StatusCode)}
}

func (c *Client) newRequest(method, path string) (*http.Request, error) {
url := fmt.Sprintf("%s%s", c.baseURL, path)
req, err := http.NewRequest(method, url, nil)
Expand Down
Loading