/
connections.go
165 lines (139 loc) · 3.66 KB
/
connections.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
package caching
import (
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"net"
"github.com/Liphium/station/spacestation/util"
"github.com/dgraph-io/ristretto"
)
// Connection is the cached state of one client connection.
//
// ID is the connection id (key of connectionsCache), Room the room it belongs
// to, ClientID the token that maps back to ID via clientIDCache, and
// CurrentSession the session recorded by JoinSession.
type Connection struct {
	ID, Room, ClientID, CurrentSession string

	UDP    *net.UDPAddr // sender address; nil until VerifyUDP accepts a first packet
	Key    []byte       // key material; also used to build Cipher with aes.NewCipher
	Cipher cipher.Block // block cipher constructed from Key
}

// KeyBase64 returns Key encoded with padded standard base64.
func (c *Connection) KeyBase64() string {
	enc := base64.StdEncoding
	out := make([]byte, enc.EncodedLen(len(c.Key)))
	enc.Encode(out, c.Key)
	return string(out)
}
// ! Always use cost 1 when writing to these caches.
var (
	connectionsCache *ristretto.Cache // ConnectionID -> Connection
	clientIDCache    *ristretto.Cache // ClientID -> ConnectionID
)
// setupConnectionsCache creates connectionsCache and clientIDCache with the
// same ristretto configuration. It panics if either cache cannot be created.
//
// Entries are always written with cost 1, so MaxCost limits the number of
// entries (1<<30, about 1.07 billion), not their size in bytes.
// NumCounters is sized at roughly 10x the expected number of entries, per
// ristretto's configuration guidance.
// Only clientIDCache logs evictions.
func setupConnectionsCache() {
	newCache := func(onEvict func(item *ristretto.Item)) *ristretto.Cache {
		cache, err := ristretto.NewCache(&ristretto.Config{
			NumCounters: 10_000_000, // ~10x the 1 million expected connections
			MaxCost:     1 << 30,    // entry limit, since every entry has cost 1
			BufferItems: 64,
			OnEvict:     onEvict, // nil means no eviction callback
		})
		if err != nil {
			panic(err)
		}
		return cache
	}

	connectionsCache = newCache(nil)
	clientIDCache = newCache(func(item *ristretto.Item) {
		util.Log.Println("[cache] cached client id of connection", item.Value, "was deleted")
	})
}
// VerifyUDP authenticates a UDP packet from the client identified by clientId
// and returns that client's cached Connection.
//
// hash is the hash the client included in the packet. It is compared with
// util.Hash(voice || conn.Key) computed here. voice is not modified.
//
// On the first successful verification of a connection (conn.UDP == nil), the
// sender address udp is resolved, registered with EnterUDP and written back
// to connectionsCache.
//
// It returns (Connection{}, false) in any of these cases:
//   - the client id or connection is unknown
//   - the hashes differ
//   - the address cannot be resolved or registered
func VerifyUDP(clientId string, udp net.Addr, hash []byte, voice []byte) (Connection, bool) {
	// Get connection
	connectionId, found := clientIDCache.Get(clientId)
	if !found {
		return Connection{}, false
	}
	obj, found := connectionsCache.Get(connectionId.(string))
	if !found {
		return Connection{}, false
	}
	conn := obj.(Connection)

	// Verify hash. Build voice||key in a fresh buffer: appending to voice
	// directly would write the key into voice's spare capacity and corrupt
	// the caller's buffer when voice is a sub-slice of a larger packet.
	merged := make([]byte, 0, len(voice)+len(conn.Key))
	merged = append(merged, voice...)
	merged = append(merged, conn.Key...)
	computedHash := util.Hash(merged)
	if !util.CompareHash(computedHash, hash) {
		util.Log.Println("Error: Hashes don't match")
		util.Log.Println("Expected:", computedHash)
		util.Log.Println("Got:", hash)
		return Connection{}, false
	}

	// Set UDP on the first verified packet of this connection.
	if conn.UDP == nil {
		udpAddr, err := net.ResolveUDPAddr("udp", udp.String())
		if err != nil {
			util.Log.Println("Error: Couldn't resolve udp address:", err)
			return Connection{}, false
		}
		conn.UDP = udpAddr
		if !EnterUDP(conn.Room, conn.ID, clientId, udpAddr, &conn.Key) {
			util.Log.Println("Error: Couldn't enter udp")
			return Connection{}, false
		}
		connectionsCache.Set(connectionId, conn, 1)
		connectionsCache.Wait()
		util.Log.Println("Success: UDP set")
	}
	return conn, true
}
// EmptyConnection creates the Connection for connId in room and caches it.
//
// It generates a key with util.GenerateKey, builds an AES cipher from it, and
// picks a client id with util.GenerateToken(10). It then stores two entries:
//   - the connection under connId in connectionsCache
//   - clientId -> connId in clientIDCache
//
// UDP starts out nil; VerifyUDP fills it in later.
//
// It panics if the key or the cipher cannot be created.
func EmptyConnection(connId string, room string) Connection {
	// Generate encryption key
	key, err := util.GenerateKey()
	if err != nil {
		panic(err)
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}

	// Store in cache
	clientId := util.GenerateToken(10)
	conn := Connection{
		ID:       connId,
		Room:     room,
		ClientID: clientId,
		UDP:      nil,
		Key:      key,
		Cipher:   block,
	}
	connectionsCache.Set(connId, conn, 1)
	clientIDCache.Set(clientId, connId, 1)

	// ristretto applies Set asynchronously through an internal buffer. Wait
	// for both writes so that an immediate GetConnection, JoinSession or
	// VerifyUDP for this connection does not miss the new entries.
	connectionsCache.Wait()
	clientIDCache.Wait()
	return conn
}
// GetConnection returns the Connection cached under connId.
// If there is no entry, it returns the zero Connection and false.
func GetConnection(connId string) (Connection, bool) {
	if cached, ok := connectionsCache.Get(connId); ok {
		return cached.(Connection), true
	}
	return Connection{}, false
}
// DeleteConnection removes connId from connectionsCache and removes its
// ClientID mapping from clientIDCache. Unknown ids are ignored.
//
// TODO: Create a test for all deletion functions
func DeleteConnection(connId string) {
	cached, ok := connectionsCache.Get(connId)
	if ok {
		clientID := cached.(Connection).ClientID
		connectionsCache.Del(connId)
		clientIDCache.Del(clientID)
	}
}
// JoinSession sets CurrentSession of the connection cached under connId to
// sessionId and waits for the cache write to be applied.
// It reports false if connId is not cached.
func JoinSession(connId string, sessionId string) bool {
	if cached, found := connectionsCache.Get(connId); found {
		updated := cached.(Connection)
		updated.CurrentSession = sessionId
		connectionsCache.Set(connId, updated, 1)
		connectionsCache.Wait()
		return true
	}
	return false
}