diff --git a/README.md b/README.md index b058b67..d75ac8a 100644 --- a/README.md +++ b/README.md @@ -1 +1,3 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/1602/witness)](https://goreportcard.com/report/github.com/1602/witness) +[![Build Status](https://travis-ci.org/1602/witness.svg?branch=master)](https://travis-ci.org/1602/witness) +[![Coverage Status](https://img.shields.io/coveralls/github/1602/witness.svg)](https://coveralls.io/github/1602/witness?branch=master) diff --git a/transport.go b/transport.go index 9f17e87..af3865b 100644 --- a/transport.go +++ b/transport.go @@ -1,8 +1,10 @@ package witness import ( + "context" "encoding/json" "fmt" + "log" "net/http" ) @@ -13,7 +15,8 @@ type sse struct { closingClients chan chan []byte firstClient chan bool firstClientConnected bool - done chan bool + ctx context.Context + startServer func() } func (t *sse) Notify(rtl RoundTripLog) { @@ -29,22 +32,36 @@ func serializeOrDie(stuff interface{}) []byte { return json } -func NewTransport(firstClientReady, done chan bool) (transport *sse) { +func NewSSETransport() (transport *sse) { transport = &sse{ distributor: make(chan []byte), openingClients: make(chan chan []byte), connectedClients: make(map[chan []byte]bool), closingClients: make(chan chan []byte), - firstClient: firstClientReady, firstClientConnected: false, - done: done, + startServer: func() { + log.Fatal("HTTP server error: ", http.ListenAndServe("localhost:1602", transport)) + }, } - go transport.route() - return transport } +func (t *sse) Init(ctx context.Context) { + t.ctx = ctx + t.firstClient = make(chan bool, 1) + go t.route() + go t.startServer() + + // TODO: make configurable + // wait until first client connected + fmt.Println("waiting for first client") + + <-t.firstClient + + fmt.Println("first client connected") +} + func (t *sse) ServeHTTP(rw http.ResponseWriter, req *http.Request) { flusher, flusherSupported := rw.(http.Flusher) @@ -83,7 +100,7 @@ func (t *sse) ServeHTTP(rw http.ResponseWriter, req *http.Request) { case data := <-ch: fmt.Fprintf(rw, "data: %s\n\n", data) flusher.Flush() - case <-t.done: + case <-t.ctx.Done(): return } } diff --git a/transport_test.go b/transport_test.go index 488a783..a722f44 100644 --- a/transport_test.go +++ b/transport_test.go @@ -1,6 +1,8 @@ package witness import ( + "context" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -11,7 +13,7 @@ import ( func TestNotify(t *testing.T) { url := "http://example.com" - tr := NewTransport(nil, nil) + tr := NewSSETransport() go tr.Notify(RoundTripLog{RequestLog: RequestLog{Url: url}}) msg := <-tr.distributor if !strings.Contains(string(msg), url) { @@ -41,37 +43,44 @@ func TestSerializeOrDie(t *testing.T) { func TestServeHTTP(t *testing.T) { t.Run("waiting for the first client", func(t *testing.T) { - done := make(chan bool, 1) - ch := make(chan bool, 1) - tr := NewTransport(ch, done) - ts := httptest.NewServer(tr) + tr := NewSSETransport() + ts := httptest.NewUnstartedServer(tr) defer ts.Close() + ctx, cancel := context.WithCancel(context.TODO()) + tr.startServer = func() { + ts.Start() - go func() { - req, _ := http.NewRequest("GET", ts.URL, nil) - client := &http.Client{Timeout: 10 * time.Millisecond} - res, _ := client.Do(req) - l, _ := ioutil.ReadAll(res.Body) - res.Body.Close() - result := string(l) + go func() { + req, _ := http.NewRequest("GET", ts.URL, nil) + fmt.Println("server url is", ts.URL) + client := &http.Client{Timeout: 10 * time.Millisecond} + res, err := client.Do(req) + if err != nil { + fmt.Println(res, err) + } + l, _ := ioutil.ReadAll(res.Body) + res.Body.Close() + result := string(l) - if !strings.HasPrefix(result, "data:") { - t.Errorf("expected body to have prefix 'data:', got %s", result) - } + if !strings.HasPrefix(result, "data:") { + t.Errorf("expected body to have prefix 'data:', got %s", result) + } - if !strings.Contains(result, "example.com") { - t.Errorf("expected body to contain 'example.com', got %s", result) - } + if !strings.Contains(result, "example.com") { + t.Errorf("expected body to contain 'example.com', got %s", result) + } - done <- true - }() - <-ch + cancel() + }() + } + + tr.Init(ctx) tr.Notify(RoundTripLog{RequestLog: RequestLog{Url: "http://example.com"}}) }) t.Run("flusher not supported", func(t *testing.T) { xx := &x{make(map[string][]string), 0, ""} - tr := NewTransport(nil, nil) + tr := NewSSETransport() req, _ := http.NewRequest("GET", "http://example.com", nil) tr.ServeHTTP(xx, req) if xx.statusCode != 500 { diff --git a/witness.go b/witness.go index edf3388..d62d02e 100644 --- a/witness.go +++ b/witness.go @@ -1,7 +1,7 @@ package witness import ( - "log" + "context" "net/http" "net/http/httptrace" "time" @@ -39,23 +39,15 @@ type ResponseLog struct { // Notifier interface must be implemented by a transport. type Notifier interface { + Init(context.Context) Notify(RoundTripLog) } -func DebugClient(client *http.Client) { - firstClientConnected := make(chan bool, 1) - n := NewTransport(firstClientConnected, nil) +var DefaultTransport Notifier = NewSSETransport() - go (func() { - // TODO: make configurable - log.Fatal("HTTP server error: ", http.ListenAndServe("localhost:1602", n)) - })() - - // TODO: make configurable - // wait until first client connected - <-firstClientConnected - - InstrumentClient(client, n, true) +func DebugClient(client *http.Client, ctx context.Context) { + DefaultTransport.Init(ctx) + InstrumentClient(client, DefaultTransport, true) } func InstrumentClient(client *http.Client, n Notifier, includeBody bool) { diff --git a/witness_test.go b/witness_test.go index d4920d2..ff48eab 100644 --- a/witness_test.go +++ b/witness_test.go @@ -2,18 +2,17 @@ package witness import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" - "log" "net/http" "net/http/httptest" "testing" - "time" ) // uncomment this for manual testing using frontend inspector client -//* +/* func TestDebugClient(t *testing.T) { client := &http.Client{} fmt.Println("haha") @@ -44,12 +43,27 @@ func TestDebugClient(t *testing.T) { type fakeNotifier struct { payload RoundTripLog + ctx context.Context +} + +func (n *fakeNotifier) Init(ctx context.Context) { + n.ctx = ctx } func (n *fakeNotifier) Notify(p RoundTripLog) { n.payload = p } +func TestDebugClient(t *testing.T) { + client := &http.Client{} + dtStashed := DefaultTransport + defer func() { + DefaultTransport = dtStashed + }() + DefaultTransport = &fakeNotifier{} + DebugClient(client, context.Background()) +} + func TestInstrumentClient(t *testing.T) { t.Run("with body", func(t *testing.T) { client := &http.Client{}