-
-
Notifications
You must be signed in to change notification settings - Fork 232
/
proxy.go
231 lines (203 loc) · 5.32 KB
/
proxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
package proxy
import (
"bytes"
"compress/gzip"
"fmt"
"io"
stdlog "log"
"log/slog"
"math"
"net/http"
"net/http/httputil"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/a-h/templ/cmd/templ/generatecmd/sse"
"github.com/andybalholm/brotli"
_ "embed"
)
// script holds the contents of script.js, embedded into the binary at build
// time. ServeHTTP serves it at /_templ/reload/script.js.
//
//go:embed script.js
var script string
// scriptTag is the HTML element that loads the reload script from the proxy.
// It is inserted into proxied HTML responses by insertScriptTagIntoBody.
const scriptTag = `<script src="/_templ/reload/script.js"></script>`
// Handler is the hot reload proxy. Its ServeHTTP method answers requests for
// the reload script and the reload events endpoint itself, and forwards every
// other request to the target through a reverse proxy.
type Handler struct {
// log receives warnings raised while modifying proxied responses.
log *slog.Logger
// URL is the address of the proxy itself; New builds it from the bind
// address and port.
URL string
// Target is the URL that requests are proxied to.
Target *url.URL
// p is the reverse proxy that forwards requests to Target.
p *httputil.ReverseProxy
// sse serves GET requests to /_templ/reload/events and is used to send
// events such as the reload message.
sse *sse.Handler
}
// insertScriptTagIntoBody returns body with the hot reload script tag inserted
// immediately before the document's closing body tag.
//
// The closing tag is matched case-insensitively because HTML tag names are not
// case-sensitive (</BODY> is as valid as </body>). Only the last occurrence is
// used, so a literal "</body>" that appears earlier in the document (for
// example inside an inline script or a comment) is left untouched and the
// reload script is loaded exactly once. If body contains no closing body tag
// it is returned unchanged.
func insertScriptTagIntoBody(body string) (updated string) {
	const closingTag = "</body>"
	// Scan backwards over byte offsets rather than lower-casing the whole
	// document: lower-casing can change the byte length of non-ASCII text,
	// which would make the found index invalid for the original string.
	for i := len(body) - len(closingTag); i >= 0; i-- {
		if body[i] == '<' && strings.EqualFold(body[i:i+len(closingTag)], closingTag) {
			return body[:i] + scriptTag + body[i:]
		}
	}
	return body
}
// passthroughWriteCloser adapts a plain io.Writer to the io.WriteCloser
// interface. Writes go straight to the embedded writer.
type passthroughWriteCloser struct {
	io.Writer
}

// Close does nothing and always reports success; the embedded writer is left
// untouched.
func (passthroughWriteCloser) Close() error {
	return nil
}
// unsupportedContentEncoding is the warning that modifyResponse logs when a
// response's Content-Encoding is something other than gzip, br, or empty.
const unsupportedContentEncoding = "Unsupported content encoding, hot reload script not inserted."
// modifyResponse is installed as the reverse proxy's ModifyResponse hook. It
// inserts the hot reload script tag into HTML responses, transparently
// decoding and re-encoding gzip and brotli bodies.
//
// The response is left untouched when:
//   - the templ-skip-modify header is "true",
//   - the Content-Type does not start with text/html,
//   - the response carries no body (HEAD requests, 204 and 304 statuses),
//   - the Content-Encoding is not one of gzip, br or empty (a warning is
//     logged in that case).
//
// On success the body is replaced with the fully buffered, updated content and
// Content-Length is set to match. Any error reading or encoding the body is
// returned so that the proxy can report it.
func (h *Handler) modifyResponse(r *http.Response) error {
	if r.Header.Get("templ-skip-modify") == "true" {
		return nil
	}
	if contentType := r.Header.Get("Content-Type"); !strings.HasPrefix(contentType, "text/html") {
		return nil
	}
	// Bodiless responses still carry the headers of the full response. Decoding
	// their empty body would fail for compressed encodings, and rewriting
	// Content-Length would replace the real value with 0.
	if r.Request != nil && r.Request.Method == http.MethodHead {
		return nil
	}
	if r.StatusCode == http.StatusNoContent || r.StatusCode == http.StatusNotModified {
		return nil
	}

	// Set up readers and writers for the response's content encoding.
	newReader := func(in io.Reader) (io.Reader, error) {
		return in, nil
	}
	newWriter := func(out io.Writer) io.WriteCloser {
		return passthroughWriteCloser{out}
	}
	switch encoding := r.Header.Get("Content-Encoding"); encoding {
	case "gzip":
		newReader = func(in io.Reader) (io.Reader, error) {
			return gzip.NewReader(in)
		}
		newWriter = func(out io.Writer) io.WriteCloser {
			return gzip.NewWriter(out)
		}
	case "br":
		newReader = func(in io.Reader) (io.Reader, error) {
			return brotli.NewReader(in), nil
		}
		newWriter = func(out io.Writer) io.WriteCloser {
			return brotli.NewWriter(out)
		}
	case "":
		// No content encoding.
	default:
		// The body cannot be decoded, so pass it through exactly as received
		// instead of searching the encoded bytes for a closing body tag.
		h.log.Warn(unsupportedContentEncoding, slog.String("encoding", encoding))
		return nil
	}

	// The original body is always closed, including when decoding fails.
	original := r.Body
	defer original.Close()

	// Read the encoded body.
	decoder, err := newReader(original)
	if err != nil {
		return fmt.Errorf("creating response body decoder: %w", err)
	}
	body, err := io.ReadAll(decoder)
	if err != nil {
		return fmt.Errorf("reading response body: %w", err)
	}

	// Update it.
	updated := insertScriptTagIntoBody(string(body))

	// Encode the response. Close flushes any data buffered by the encoder.
	var buf bytes.Buffer
	encoder := newWriter(&buf)
	if _, err = encoder.Write([]byte(updated)); err != nil {
		return fmt.Errorf("encoding response body: %w", err)
	}
	if err = encoder.Close(); err != nil {
		return fmt.Errorf("finishing response body encoding: %w", err)
	}

	// Update the response.
	r.Body = io.NopCloser(&buf)
	r.ContentLength = int64(buf.Len())
	r.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
	return nil
}
// New creates a Handler that proxies requests to target.
//
// log is stored on the handler for warnings, and bind and port are combined
// into the handler's URL field. The underlying reverse proxy writes its own
// errors to stderr, uses a retrying transport (10 attempts, starting at a
// 100ms delay with a 1.5 backoff exponent), and passes every response through
// the handler's modifyResponse hook.
func New(log *slog.Logger, bind string, port int, target *url.URL) (h *Handler) {
	h = &Handler{
		log:    log,
		URL:    fmt.Sprintf("http://%s:%d", bind, port),
		Target: target,
		sse:    sse.New(),
	}

	h.p = httputil.NewSingleHostReverseProxy(target)
	h.p.ErrorLog = stdlog.New(os.Stderr, "Proxy to target error: ", 0)
	h.p.Transport = &roundTripper{
		maxRetries:      10,
		initialDelay:    100 * time.Millisecond,
		backoffExponent: 1.5,
	}
	h.p.ModifyResponse = h.modifyResponse

	return h
}
// ServeHTTP implements http.Handler.
//
// Two paths are handled by the proxy itself:
//   - /_templ/reload/script.js returns the embedded reload script.
//   - /_templ/reload/events serves the server-sent event stream on GET and
//     sends a reload message on POST; any other method receives a 405
//     response with an Allow header listing the supported methods.
//
// Every other request is forwarded to the target by the reverse proxy.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch r.URL.Path {
	case "/_templ/reload/script.js":
		// Provides a script that reloads the page.
		w.Header().Add("Content-Type", "text/javascript")
		if _, err := io.WriteString(w, script); err != nil {
			// Report through the handler's logger, like the rest of the
			// handler does, rather than printing to stdout.
			h.log.Error("failed to write script", slog.Any("error", err))
		}
		return
	case "/_templ/reload/events":
		switch r.Method {
		case http.MethodGet:
			// Provides a list of messages including a reload message.
			h.sse.ServeHTTP(w, r)
		case http.MethodPost:
			// Send a reload message to all connected clients.
			h.sse.Send("message", "reload")
		default:
			// A 405 response should tell the client which methods are allowed.
			w.Header().Set("Allow", "GET, POST")
			http.Error(w, "only GET or POST method allowed", http.StatusMethodNotAllowed)
		}
		return
	}
	h.p.ServeHTTP(w, r)
}
// SendSSE passes an event with the given type and data to the handler's
// server-sent events handler for delivery.
func (h *Handler) SendSSE(eventType string, data string) {
	h.sse.Send(eventType, data)
}
// roundTripper is the transport used by the proxy. Its RoundTrip method
// retries failed requests to the target with an increasing delay between
// attempts.
type roundTripper struct {
	// maxRetries is the maximum number of attempts made for one request.
	maxRetries int
	// initialDelay is the base delay between attempts.
	initialDelay time.Duration
	// backoffExponent is raised to the power of the attempt number to scale
	// initialDelay.
	backoffExponent float64
}

// setShouldSkipResponseModificationHeader marks resp with the
// templ-skip-modify header when the request r was made by HTMX (its
// HX-Request header is "true"), which tells modifyResponse to leave the
// response alone. Responses to all other requests are not changed.
func (rt *roundTripper) setShouldSkipResponseModificationHeader(r *http.Request, resp *http.Response) {
	if fromHTMX := r.Header.Get("HX-Request") == "true"; fromHTMX {
		resp.Header.Set("templ-skip-modify", "true")
	}
}
// RoundTrip implements http.RoundTripper. It sends r to the target using
// http.DefaultTransport and retries up to rt.maxRetries times when the
// transport returns an error.
//
// The request body is read into memory up front so that it can be replayed on
// every attempt; r itself is never sent, only clones of it. Between attempts
// the method waits initialDelay * backoffExponent^attempt, and gives up early
// if the request's context is cancelled. On success the response is tagged via
// setShouldSkipResponseModificationHeader before it is returned. When every
// attempt fails, the returned error wraps the last transport error.
func (rt *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
	// Read and buffer the body.
	var bodyBytes []byte
	if r.Body != nil && r.Body != http.NoBody {
		var err error
		bodyBytes, err = io.ReadAll(r.Body)
		// A RoundTripper must always close the request body, including when
		// it returns an error. The close error is deliberately ignored: the
		// data needed for the request has already been read (or has failed).
		_ = r.Body.Close()
		if err != nil {
			return nil, err
		}
	}

	// Retry logic.
	ctx := r.Context()
	var lastErr error
	for attempt := 0; attempt < rt.maxRetries; attempt++ {
		// Clone the request and set the body.
		req := r.Clone(ctx)
		if bodyBytes != nil {
			req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		}

		// Execute the request.
		resp, err := http.DefaultTransport.RoundTrip(req)
		if err == nil {
			rt.setShouldSkipResponseModificationHeader(r, resp)
			return resp, nil
		}
		lastErr = err

		// Retrying cannot succeed once the caller has gone away.
		if ctx.Err() != nil {
			return nil, ctx.Err()
		}
		// There is nothing to wait for after the final attempt.
		if attempt == rt.maxRetries-1 {
			break
		}

		// Scale the delay in floating point; converting the multiplier to a
		// Duration first would truncate it to a whole number.
		delay := time.Duration(float64(rt.initialDelay) * math.Pow(rt.backoffExponent, float64(attempt)))
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		case <-timer.C:
		}
	}
	if lastErr == nil {
		// No attempt was made (maxRetries <= 0).
		return nil, fmt.Errorf("max retries reached")
	}
	return nil, fmt.Errorf("max retries reached: %w", lastErr)
}
// NotifyProxy sends a POST request with no body to the /_templ/reload/events
// endpoint of the proxy listening on host and port, which asks the proxy to
// send a reload message to its clients.
//
// It returns an error if the request cannot be built or sent, or if the server
// answers with a status outside the 2xx range.
func NotifyProxy(host string, port int) error {
	urlStr := fmt.Sprintf("http://%s:%d/_templ/reload/events", host, port)
	req, err := http.NewRequest(http.MethodPost, urlStr, nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	// The body must be closed to release the underlying connection.
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("notifying proxy at %s: unexpected status %s", urlStr, resp.Status)
	}
	return nil
}