/
socket.go
138 lines (115 loc) · 2.89 KB
/
socket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package socket
import (
"encoding/json"
"fmt"
"log"
"net/http"
"sync"
"github.com/gorilla/websocket"
"github.com/teris-io/shortid"
)
// Message is the envelope exchanged with WebSocket clients. Only Type and
// Payload appear on the wire (as the "type" and "payload" JSON fields); From
// is excluded from JSON and is filled in server-side with the ID of the
// client a received message came from.
type Message struct {
	// From is the ID of the sending client. It is never encoded to or
	// decoded from JSON.
	From string `json:"-"`
	// Type is the application-defined kind of the message.
	Type string `json:"type"`
	// Payload holds the message body as decoded by encoding/json.
	Payload interface{} `json:"payload"`
}
var (
	// mutex guards writes to client connections; Broadcast and Send take it
	// before writing.
	mutex sync.Mutex
	// allClients maps a client ID (string) to its *websocket.Conn.
	allClients sync.Map
	// upgrader performs the HTTP-to-WebSocket handshake. Its buffer sizes are
	// configured in init.
	upgrader websocket.Upgrader
	// closeCodes lists the WebSocket close status codes 1000 through 1014.
	closeCodes = []int{
		1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007,
		1008, 1009, 1010, 1011, 1012, 1013, 1014,
	}
	// cMsg hands messages received from clients to Read. It is created in init.
	cMsg chan *Message

	// OnOpen is called whenever a new client is connected
	OnOpen func(clientID string)
	// OnClose is called whenever a client disconnects for any reason
	OnClose func(clientID string)
	// OnError is called whenever an error occurs
	OnError func(clientID string, err error)
	// CheckOrigin is used by Socket when upgrading a WebSocket connection
	CheckOrigin func(r *http.Request) bool
)
func init() {
cMsg = make(chan *Message)
upgrader.ReadBufferSize = 1024
upgrader.WriteBufferSize = 1024
}
// Handler connects a new client. It does the following, in order:
//   - upgrades the HTTP request to a WebSocket connection;
//   - registers the connection under a freshly generated short client ID;
//   - invokes OnOpen;
//   - starts a goroutine that reads the client's messages.
//
// Any error is passed to the OnError callback with "socket.Handler" as the
// client ID. When OnError is nil, the error is logged with the standard log
// package instead.
func Handler(w http.ResponseWriter, r *http.Request) {
	err := func() error {
		// Work on a per-request copy of the package-level upgrader.
		// Assigning CheckOrigin on the shared value from concurrent HTTP
		// handlers would be a data race.
		u := upgrader
		u.CheckOrigin = CheckOrigin
		conn, err := u.Upgrade(w, r, nil)
		if err != nil {
			// Per the gorilla/websocket docs, a failed Upgrade has already
			// replied to the client with an HTTP error.
			return err
		}
		clientID, err := shortid.Generate()
		if err != nil {
			// The connection is already hijacked, so close it rather than
			// leak it. The ID error is the one worth reporting.
			_ = conn.Close()
			return err
		}
		allClients.Store(clientID, conn)
		// Call OnOpen before starting the read loop. This way OnClose (or a
		// received message) can never be observed ahead of OnOpen for the
		// same client.
		if OnOpen != nil {
			OnOpen(clientID)
		}
		go handleClient(clientID, conn)
		return nil
	}()
	if err == nil {
		return
	}
	if OnError != nil {
		OnError("socket.Handler", err)
		return
	}
	log.Printf("socket.Handler error: %v", err)
}
// handleClient runs the read loop for a single connected client.
//
// Each message that decodes as a Message is stamped with clientID in From.
// It is then sent on cMsg, which blocks until a Read call receives it.
// A message that fails to decode is reported to OnError (if set) and
// skipped, because the connection itself is still usable.
//
// Any error from reading the connection ends the loop, and that error is
// reported to OnError (if set). Afterwards the client is removed from
// allClients, its connection is closed, and OnClose (if set) is invoked.
func handleClient(clientID string, conn *websocket.Conn) {
	for {
		// gorilla/websocket read errors are permanent: every later read
		// returns the same error, and repeated reads eventually panic. So the
		// loop must exit on every read error, not only on close frames.
		_, data, err := conn.ReadMessage()
		if err != nil {
			if OnError != nil {
				OnError(clientID, err)
			}
			break
		}
		msg := &Message{}
		if err := json.Unmarshal(data, msg); err != nil {
			// Malformed payload: report it, drop it, keep reading.
			if OnError != nil {
				OnError(clientID, err)
			}
			continue
		}
		msg.From = clientID
		cMsg <- msg
	}
	// Unregister first so Broadcast and Send stop targeting this connection
	// before OnClose runs.
	allClients.Delete(clientID)
	_ = conn.Close()
	if OnClose != nil {
		OnClose(clientID)
	}
}
// Broadcast to all connected clients
func Broadcast(msgType string, msgPayload interface{}) error {
buf, err := json.Marshal(&Message{Type: msgType, Payload: msgPayload})
if err != nil {
return err
}
mutex.Lock()
defer mutex.Unlock()
allClients.Range(func(key, val interface{}) bool {
conn := val.(*websocket.Conn)
conn.WriteMessage(websocket.TextMessage, buf)
return true
})
return nil
}
// Send delivers one message of type msgType carrying msgPayload to the client
// registered under clientID. It returns an error when no such client is
// registered, or the error from writing the JSON message to its connection.
func Send(clientID, msgType string, msgPayload interface{}) error {
	stored, found := allClients.Load(clientID)
	if !found {
		return fmt.Errorf("client id %s is not connected", clientID)
	}
	conn := stored.(*websocket.Conn)
	outgoing := &Message{Type: msgType, Payload: msgPayload}

	// gorilla/websocket allows at most one concurrent writer per connection,
	// and the shared mutex enforces that for this package.
	mutex.Lock()
	defer mutex.Unlock()
	return conn.WriteJSON(outgoing)
}
// Read blocks until the next message from any connected client is available
// and returns it. The returned Message has From set to the sending client's ID.
func Read() *Message {
	next := <-cMsg
	return next
}