diff --git a/flash.go b/flash.go index c6926f8..0c65a6e 100644 --- a/flash.go +++ b/flash.go @@ -5,6 +5,7 @@ import ( "fmt" "html/template" "net/http" + "strings" "time" "zgo.at/zlog" @@ -24,7 +25,12 @@ func Flash(w http.ResponseWriter, msg string, v ...interface{}) { func ReadFlash(w http.ResponseWriter, r *http.Request) template.HTML { c, err := r.Cookie(cookieFlash) if err != nil || c.Value == "" { - return "" + // The value won't be read if we set the flash on the same + // request. + c = readSetCookie(w) + if c == nil { + return "" + } } b, err := base64.StdEncoding.DecodeString(c.Value) @@ -37,3 +43,21 @@ func ReadFlash(w http.ResponseWriter, r *http.Request) template.HTML { }) return template.HTML(b) } + +func readSetCookie(w http.ResponseWriter) *http.Cookie { + sk := w.Header().Get("Set-Cookie") + if sk == "" { + return nil + } + + e := strings.Index(sk, "=") + if e == -1 || sk[:e] != cookieFlash { + return nil + } + s := strings.Index(sk, ";") + if s == -1 { + return nil + } + + return &http.Cookie{Value: sk[e+1 : s]} +} diff --git a/flash_test.go b/flash_test.go new file mode 100644 index 0000000..bb77928 --- /dev/null +++ b/flash_test.go @@ -0,0 +1,18 @@ +package zhttp + +import ( + "net/http/httptest" + "testing" +) + +func TestFlash(t *testing.T) { + r := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + + Flash(rr, "w00t") + + out := ReadFlash(rr, r) + if out != "w00t" { + t.Errorf("out: %#v", out) + } +}