diff --git a/air_example.toml b/air_example.toml index 8a38ff85..b2235e68 100644 --- a/air_example.toml +++ b/air_example.toml @@ -73,3 +73,9 @@ clean_on_exit = true [screen] clear_on_rebuild = true keep_scroll = true + +# Enable live-reloading on the browser. +[proxy] + enabled = true + proxy_port = 8090 + app_port = 8080 diff --git a/runner/config.go b/runner/config.go index a5af959c..78a4eaa5 100644 --- a/runner/config.go +++ b/runner/config.go @@ -31,6 +31,7 @@ type Config struct { Log cfgLog `toml:"log"` Misc cfgMisc `toml:"misc"` Screen cfgScreen `toml:"screen"` + Proxy cfgProxy `toml:"proxy"` } type cfgBuild struct { @@ -96,6 +97,12 @@ type cfgScreen struct { KeepScroll bool `toml:"keep_scroll"` } +type cfgProxy struct { + Enabled bool `toml:"enabled"` + ProxyPort int `toml:"proxy_port"` + AppPort int `toml:"app_port"` +} + type sliceTransformer struct{} func (t sliceTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error { diff --git a/runner/engine.go b/runner/engine.go index 69db92d4..c652dda1 100644 --- a/runner/engine.go +++ b/runner/engine.go @@ -18,6 +18,7 @@ import ( // Engine ... type Engine struct { config *Config + proxy *Proxy logger *logger watcher filenotify.FileWatcher debugMode bool @@ -48,6 +49,7 @@ func NewEngineWithConfig(cfg *Config, debugMode bool) (*Engine, error) { } e := Engine{ config: cfg, + proxy: NewProxy(&cfg.Proxy), logger: logger, watcher: watcher, debugMode: debugMode, @@ -310,6 +312,11 @@ func (e *Engine) isModified(filename string) bool { // Endless loop and never return func (e *Engine) start() { + if e.config.Proxy.Enabled { + go e.proxy.Run() + e.mainLog("Proxy server listening on http://localhost%s", e.proxy.server.Addr) + } + e.running = true firstRunCh := make(chan bool, 1) firstRunCh <- true @@ -535,6 +542,9 @@ func (e *Engine) runBin() error { cmd, stdout, stderr, _ := e.startCmd(command) processExit := make(chan struct{}) e.mainDebug("running process pid %v", cmd.Process.Pid) + if e.config.Proxy.Enabled { + e.proxy.Reload() + } wg.Add(1) atomic.AddUint64(&e.round, 1) @@ -579,6 +589,11 @@ func (e *Engine) cleanup() { e.mainLog("cleaning...") defer e.mainLog("see you again~") + if e.config.Proxy.Enabled { + e.mainDebug("powering down the proxy...") + e.proxy.Stop() + } + e.withLock(func() { close(e.binStopCh) e.binStopCh = make(chan bool) diff --git a/runner/engine_test.go b/runner/engine_test.go index 2521e308..10498ade 100644 --- a/runner/engine_test.go +++ b/runner/engine_test.go @@ -927,6 +927,9 @@ func Test(t *testing.T) { t.Log("testing") } `) + if err != nil { + t.Fatal(err) + } // run sed // check the file is exist if _, err := os.Stat(dftTOML); err != nil { diff --git a/runner/proxy.go b/runner/proxy.go new file mode 100644 index 00000000..282f2f0a --- /dev/null +++ b/runner/proxy.go @@ -0,0 +1,169 @@ +package runner + +import ( + "bytes" + "errors" + "fmt" + "io" + "log" + "net/http" + "strconv" + "strings" + "syscall" + "time" +) + +type Reloader interface { + AddSubscriber() *Subscriber + RemoveSubscriber(id int) + Reload() + Stop() +} + +type Proxy struct { + server *http.Server + client *http.Client + config *cfgProxy + stream Reloader +} + +func NewProxy(cfg *cfgProxy) *Proxy { + p := &Proxy{ + config: cfg, + server: &http.Server{ + Addr: fmt.Sprintf(":%d", cfg.ProxyPort), + }, + client: &http.Client{}, + stream: NewProxyStream(), + } + return p +} + +func (p *Proxy) Run() { + http.HandleFunc("/", p.proxyHandler) + http.HandleFunc("/internal/reload", p.reloadHandler) + if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("failed to start proxy server: %v", err) + } +} + +func (p *Proxy) Stop() { + p.server.Close() + p.stream.Stop() +} + +func (p *Proxy) Reload() { + p.stream.Reload() +} + +func (p *Proxy) injectLiveReload(respBody io.ReadCloser) string { + buf := new(bytes.Buffer) + if _, err := buf.ReadFrom(respBody); err != nil { + log.Fatalf("failed to convert request body to bytes buffer, err: %+v\n", err) + } + original := buf.String() + + // the script will be injected before the end of the body tag. In case the tag is missing, the injection will be skipped without an error to ensure that a page with partial reloads only has at most one injected script. + body := strings.LastIndex(original, "") + if body == -1 { + return original + } + + script := fmt.Sprintf( + ``, + p.config.ProxyPort, + ) + return original[:body] + script + original[body:] +} + +func (p *Proxy) proxyHandler(w http.ResponseWriter, r *http.Request) { + appURL := r.URL + appURL.Scheme = "http" + appURL.Host = fmt.Sprintf("localhost:%d", p.config.AppPort) + + if err := r.ParseForm(); err != nil { + log.Fatalf("failed to read form data from request, err: %+v\n", err) + } + var body io.Reader + if len(r.Form) > 0 { + body = strings.NewReader(r.Form.Encode()) + } else { + body = r.Body + } + req, err := http.NewRequest(r.Method, appURL.String(), body) + if err != nil { + log.Fatalf("proxy could not create request, err: %+v\n", err) + } + + // Copy the headers from the original request + for name, values := range r.Header { + for _, value := range values { + req.Header.Add(name, value) + } + } + req.Header.Set("X-Forwarded-For", r.RemoteAddr) + + // retry on connection refused error since after a file change air will restart the server and it may take a few milliseconds for the server to be up-and-running. + var resp *http.Response + for i := 0; i < 10; i++ { + resp, err = p.client.Do(req) + if err == nil { + break + } + if !errors.Is(err, syscall.ECONNREFUSED) { + log.Fatalf("proxy failed to call %s, err: %+v\n", appURL.String(), err) + } + time.Sleep(100 * time.Millisecond) + } + defer resp.Body.Close() + + // Copy the headers from the proxy response except Content-Length + for k, vv := range resp.Header { + for _, v := range vv { + if k == "Content-Length" { + continue + } + w.Header().Add(k, v) + } + } + w.WriteHeader(resp.StatusCode) + + if strings.Contains(resp.Header.Get("Content-Type"), "text/html") { + newPage := p.injectLiveReload(resp.Body) + w.Header().Set("Content-Length", strconv.Itoa((len([]byte(newPage))))) + if _, err := io.WriteString(w, newPage); err != nil { + log.Fatalf("proxy failed injected live reloading script, err: %+v\n", err) + } + } else { + w.Header().Set("Content-Length", resp.Header.Get("Content-Length")) + if _, err := io.Copy(w, resp.Body); err != nil { + log.Fatalf("proxy failed to forward the response body, err: %+v\n", err) + } + } +} + +func (p *Proxy) reloadHandler(w http.ResponseWriter, r *http.Request) { + flusher, err := w.(http.Flusher) + if !err { + http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + sub := p.stream.AddSubscriber() + go func() { + <-r.Context().Done() + p.stream.RemoveSubscriber(sub.id) + }() + + w.WriteHeader(http.StatusOK) + flusher.Flush() + + for range sub.reloadCh { + fmt.Fprintf(w, "data: reload\n\n") + flusher.Flush() + } +} diff --git a/runner/proxy_stream.go b/runner/proxy_stream.go new file mode 100644 index 00000000..412bee3f --- /dev/null +++ b/runner/proxy_stream.go @@ -0,0 +1,50 @@ +package runner + +import ( + "sync" +) + +type ProxyStream struct { + sync.Mutex + subscribers map[int]*Subscriber + count int +} + +type Subscriber struct { + id int + reloadCh chan struct{} +} + +func NewProxyStream() *ProxyStream { + return &ProxyStream{subscribers: make(map[int]*Subscriber)} +} + +func (stream *ProxyStream) Stop() { + for id := range stream.subscribers { + stream.RemoveSubscriber(id) + } + stream.count = 0 +} + +func (stream *ProxyStream) AddSubscriber() *Subscriber { + stream.Lock() + defer stream.Unlock() + stream.count++ + + sub := &Subscriber{id: stream.count, reloadCh: make(chan struct{})} + stream.subscribers[stream.count] = sub + return sub +} + +func (stream *ProxyStream) RemoveSubscriber(id int) { + stream.Lock() + defer stream.Unlock() + close(stream.subscribers[id].reloadCh) + delete(stream.subscribers, id) +} + +func (stream *ProxyStream) Reload() { + for _, sub := range stream.subscribers { + sub.reloadCh <- struct{}{} + } +} diff --git a/runner/proxy_stream_test.go b/runner/proxy_stream_test.go new file mode 100644 index 00000000..daf536e3 --- /dev/null +++ b/runner/proxy_stream_test.go @@ -0,0 +1,66 @@ +package runner + +import ( + "sync" + "testing" +) + +func find(s map[int]*Subscriber, id int) bool { + for _, sub := range s { + if sub.id == id { + return true + } + } + return false +} + +func TestProxyStream(t *testing.T) { + stream := NewProxyStream() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + _ = stream.AddSubscriber() + }(i) + } + wg.Wait() + + if got, exp := len(stream.subscribers), 10; got != exp { + t.Errorf("expected %d but got %d", exp, got) + } + + go func() { + stream.Reload() + }() + + reloadCount := 0 + for _, sub := range stream.subscribers { + wg.Add(1) + go func(sub *Subscriber) { + defer wg.Done() + <-sub.reloadCh + reloadCount++ + }(sub) + } + wg.Wait() + + if got, exp := reloadCount, 10; got != exp { + t.Errorf("expected %d but got %d", exp, got) + } + + stream.RemoveSubscriber(2) + stream.AddSubscriber() + if got, exp := find(stream.subscribers, 2), false; got != exp { + t.Errorf("expected subscriber found to be %t but got %t", exp, got) + } + if got, exp := find(stream.subscribers, 11), true; got != exp { + t.Errorf("expected subscriber found to be %t but got %t", exp, got) + } + + stream.Stop() + if got, exp := len(stream.subscribers), 0; got != exp { + t.Errorf("expected %d but got %d", exp, got) + } +} diff --git a/runner/proxy_test.go b/runner/proxy_test.go new file mode 100644 index 00000000..59d59e81 --- /dev/null +++ b/runner/proxy_test.go @@ -0,0 +1,194 @@ +package runner + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strconv" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +type reloader struct { + subCh chan struct{} + reloadCh chan struct{} +} + +func (r *reloader) AddSubscriber() *Subscriber { + r.subCh <- struct{}{} + return &Subscriber{reloadCh: r.reloadCh} +} + +func (r *reloader) RemoveSubscriber(_ int) { + close(r.subCh) +} + +func (r *reloader) Reload() {} +func (r *reloader) Stop() {} + +var proxyPort = 8090 + +func getServerPort(t *testing.T, srv *httptest.Server) int { + mockURL, err := url.Parse(srv.URL) + if err != nil { + t.Fatal(err) + } + port, err := strconv.Atoi(mockURL.Port()) + if err != nil { + t.Fatal(err) + } + return port +} + +func TestNewProxy(t *testing.T) { + _ = os.Unsetenv(airWd) + cfg := &cfgProxy{ + Enabled: true, + ProxyPort: 1111, + AppPort: 2222, + } + proxy := NewProxy(cfg) + if proxy.config == nil { + t.Fatal("config should not be nil") + } + if proxy.server.Addr == "" { + t.Fatal("server address should not be nil") + } +} + +func TestProxy_proxyHandler(t *testing.T) { + tests := []struct { + name string + req func() *http.Request + assert func(*http.Request) + }{ + { + name: "get_request_with_headers", + assert: func(resp *http.Request) { + assert.Equal(t, "bar", resp.Header.Get("foo")) + }, + req: func() *http.Request { + req := httptest.NewRequest("GET", fmt.Sprintf("http://localhost:%d", proxyPort), nil) + req.Header.Set("foo", "bar") + return req + }, + }, + { + name: "post_form_request", + req: func() *http.Request { + formData := url.Values{} + formData.Add("foo", "bar") + req := httptest.NewRequest("POST", fmt.Sprintf("http://localhost:%d", proxyPort), strings.NewReader(formData.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req + }, + assert: func(resp *http.Request) { + assert.NoError(t, resp.ParseForm()) + assert.Equal(t, resp.Form.Get("foo"), "bar") + }, + }, + { + name: "get_request_with_query_string", + req: func() *http.Request { + return httptest.NewRequest("GET", fmt.Sprintf("http://localhost:%d?q=%s", proxyPort, "air"), nil) + }, + assert: func(resp *http.Request) { + q := resp.URL.Query() + assert.Equal(t, q.Encode(), "q=air") + }, + }, + { + name: "put_json_request", + req: func() *http.Request { + body := []byte(`{"foo": "bar"}`) + req := httptest.NewRequest("PUT", fmt.Sprintf("http://localhost:%d/a/b/c", proxyPort), bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json; charset=UTF-8") + return req + }, + assert: func(resp *http.Request) { + type Response struct { + Foo string `json:"foo"` + } + var r Response + assert.NoError(t, json.NewDecoder(resp.Body).Decode(&r)) + assert.Equal(t, resp.URL.Path, "/a/b/c") + assert.Equal(t, r.Foo, "bar") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tt.assert(r) + })) + defer srv.Close() + srvPort := getServerPort(t, srv) + proxy := NewProxy(&cfgProxy{ + Enabled: true, + ProxyPort: proxyPort, + AppPort: srvPort, + }) + proxy.proxyHandler(httptest.NewRecorder(), tt.req()) + }) + } +} + +func TestProxy_reloadHandler(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "thin air") + })) + srvPort := getServerPort(t, srv) + defer srv.Close() + + reloader := &reloader{subCh: make(chan struct{}), reloadCh: make(chan struct{})} + cfg := &cfgProxy{ + Enabled: true, + ProxyPort: proxyPort, + AppPort: srvPort, + } + proxy := &Proxy{ + config: cfg, + server: &http.Server{ + Addr: fmt.Sprintf("localhost:%d", proxyPort), + }, + stream: reloader, + } + + req := httptest.NewRequest("GET", fmt.Sprintf("http://localhost:%d/internal/reload", proxyPort), nil) + rec := httptest.NewRecorder() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + proxy.reloadHandler(rec, req) + }() + + // wait for subscriber to be added + <-reloader.subCh + + // send a reload event and wait for http response + reloader.reloadCh <- struct{}{} + close(reloader.reloadCh) + wg.Wait() + + if !rec.Flushed { + t.Errorf("request should have been flushed") + } + + resp := rec.Result() + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("reading body: %v", err) + } + if got, exp := string(bodyBytes), "data: reload\n\n"; got != exp { + t.Errorf("expected %q but got %q", exp, got) + } +}