-
Notifications
You must be signed in to change notification settings - Fork 0
/
auth.go
127 lines (100 loc) · 2.94 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package main
import (
"crypto/md5"
"fmt"
"io"
"log"
"net/http"
"strings"
"github.com/stretchr/gomniauth"
gomniauthcommon "github.com/stretchr/gomniauth/common"
"github.com/stretchr/objx"
)
// ChatUser describes a connected chat client by the two pieces of
// information the chat needs about it: an identifier and an avatar location.
type ChatUser interface {
	// UniqueID returns the identifier that tells this client apart from others.
	UniqueID() string
	// AvatarURL returns the URL of the image to show for this client.
	AvatarURL() string
}
// chatUser pairs a user obtained from a gomniauth provider with an
// identifier computed by this application.
//
// chatUser itself only adds UniqueID; any further methods it exposes (for
// example the AvatarURL that ChatUser requires) are presumably promoted from
// the embedded gomniauthcommon.User — confirm against the gomniauth API.
type chatUser struct {
	gomniauthcommon.User
	uniqueID string
}

// UniqueID reports the identifier that was stored in the uniqueID field.
func (u chatUser) UniqueID() string {
	return u.uniqueID
}
// authHandler is middleware that forwards a request to next only when the
// request carries a non-empty "auth" cookie; otherwise it redirects the
// client to /login with a 307 Temporary Redirect.
type authHandler struct {
	next http.Handler
}

// ServeHTTP implements http.Handler.
//
// A missing cookie and an empty cookie are both treated as "not logged in":
// an empty value cannot hold the user data that the login callback stores in
// the "auth" cookie, so passing it through would only hand downstream
// handlers a cookie they cannot use.
//
// NOTE(review): only presence and non-emptiness are checked; the cookie's
// contents are not signed or verified here, so a client can forge one.
func (a *authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// r.Cookie's only error is http.ErrNoCookie, so any error means "absent".
	cookie, err := r.Cookie("auth")
	if err != nil || cookie.Value == "" {
		w.Header().Set("Location", "/login")
		w.WriteHeader(http.StatusTemporaryRedirect)
		return
	}
	a.next.ServeHTTP(w, r)
}
// MustAuth returns handler wrapped in an authHandler, so handler is only
// reached by requests that pass authHandler's "auth" cookie check.
func MustAuth(handler http.Handler) http.Handler {
	guard := &authHandler{}
	guard.next = handler
	return guard
}
// loginHandler handles the third-party login process.
// format: /auth/{action}/{provider}
//
// Actions:
//   - "login" looks up the named gomniauth provider and redirects (307) to
//     the URL returned by its GetBeginAuthURL.
//   - "callback" completes authentication with the provider using the
//     request's query parameters, fetches the user, stores the user's ID,
//     name and avatar URL in a base64-encoded "auth" cookie, and redirects
//     (307) to /chat.
//   - any other action gets a 404.
//
// Paths with fewer than three segments after the leading slash also get a
// 404. A failure while handling one request is answered with an HTTP error
// rather than stopping the server.
func loginHandler(w http.ResponseWriter, r *http.Request) {
	// "/auth/login/google" splits into ["", "auth", "login", "google"]; guard
	// the length so "/auth" or "/auth/login" cannot cause an index panic.
	segments := strings.Split(r.URL.Path, "/")
	if len(segments) < 4 {
		http.Error(w, "Auth path must be /auth/{action}/{provider}", http.StatusNotFound)
		return
	}
	action := segments[2]
	// Kept as its own variable (never shadowed below) so error messages can
	// name the requested provider even when the lookup itself fails.
	providerName := segments[3]

	switch action {
	case "login":
		provider, err := gomniauth.Provider(providerName)
		if err != nil {
			http.Error(w, fmt.Sprintf("Error when trying to get provider %s: %s", providerName, err), http.StatusBadRequest)
			return
		}
		loginURL, err := provider.GetBeginAuthURL(nil, nil)
		if err != nil {
			http.Error(w, fmt.Sprintf("Error when trying to GetBeginAuthURL for %s:%s", providerName, err), http.StatusInternalServerError)
			return
		}
		w.Header().Set("Location", loginURL)
		w.WriteHeader(http.StatusTemporaryRedirect)

	case "callback":
		provider, err := gomniauth.Provider(providerName)
		if err != nil {
			http.Error(w, fmt.Sprintf("Error when trying to get provider %s: %s", providerName, err), http.StatusBadRequest)
			return
		}
		// FromURLQuery (unlike MustFromURLQuery) reports a malformed query as
		// an error instead of panicking.
		params, err := objx.FromURLQuery(r.URL.RawQuery)
		if err != nil {
			http.Error(w, fmt.Sprintf("Error when trying to parse callback query for %s: %s", providerName, err), http.StatusBadRequest)
			return
		}
		creds, err := provider.CompleteAuth(params)
		if err != nil {
			http.Error(w, fmt.Sprintf("Error when trying to complete auth for %s: %s", providerName, err), http.StatusInternalServerError)
			return
		}
		user, err := provider.GetUser(creds)
		if err != nil {
			// Log and answer 500 rather than exiting: one failed login must
			// not take down the whole server.
			log.Println("Error when trying to get user from", providerName, "-", err)
			http.Error(w, fmt.Sprintf("Error when trying to get user from %s: %s", providerName, err), http.StatusInternalServerError)
			return
		}

		// The unique ID is the hex MD5 digest of the lower-cased email.
		// NOTE(review): users without an email all hash to the same ID —
		// confirm every configured provider returns one.
		cu := &chatUser{User: user}
		hash := md5.New()
		io.WriteString(hash, strings.ToLower(user.Email())) // hash.Hash.Write never returns an error
		cu.uniqueID = fmt.Sprintf("%x", hash.Sum(nil))

		avatarURL, err := avatars.GetAvatarURL(cu)
		if err != nil {
			log.Println("Error when trying to GetAvatarURL", "-", err)
			http.Error(w, fmt.Sprintf("Error when trying to GetAvatarURL: %s", err), http.StatusInternalServerError)
			return
		}

		// NOTE(review): the cookie is plain base64-encoded JSON, neither
		// signed nor HttpOnly; consider both if it gates access to anything.
		authCookieValue := objx.New(map[string]interface{}{
			"userid":     cu.uniqueID,
			"name":       user.Name(),
			"avatar_url": avatarURL,
		}).MustBase64()
		http.SetCookie(w, &http.Cookie{
			Name:  "auth",
			Value: authCookieValue,
			Path:  "/",
		})
		w.Header().Set("Location", "/chat")
		w.WriteHeader(http.StatusTemporaryRedirect)

	default:
		w.WriteHeader(http.StatusNotFound)
		fmt.Fprintf(w, "Auth action %s not supported", action)
	}
}