-
Notifications
You must be signed in to change notification settings - Fork 34
/
PublicKeySSLServer.java
303 lines (262 loc) · 10.4 KB
/
PublicKeySSLServer.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
package edu.washington.cs.publickey.ssl.server;
import java.io.BufferedWriter;
import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.security.KeyManagementException;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.GZIPOutputStream;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLSocket;
import edu.washington.cs.publickey.PublicKeyFriend;
import edu.washington.cs.publickey.Tools;
import edu.washington.cs.publickey.storage.PersistentStorage;
/**
 * SSL server that accepts client-authenticated connections, throttles them
 * per remote IP and hands each accepted connection to a worker from a fixed
 * thread pool.
 *
 * <p>
 * Each worker reads an "ignore list" of 20-byte key hashes from the client,
 * looks up the friends of the client's public key in the
 * {@link PersistentStorage} and writes a gzip-compressed serialized
 * {@link PublicKeyFriend} array back.
 *
 * <p>
 * The accept thread is started by the constructor; call {@link #shutdown()}
 * to stop it.
 */
public class PublicKeySSLServer {
    /** Property key holding the path of the server key store file. */
    static final String KEY_SSL_SERVER_KEYSTORE = "ssl_server_keystore";
    /** Property key holding the TCP port to listen on. */
    final static String KEY_SSL_PORT = "ssl_server_port";

    /** Number of worker threads serving accepted connections. */
    private static final int NUM_THREADS = 500;
    /**
     * When the storage reports a queue longer than this, requests are
     * answered with an empty list without touching the database.
     */
    private static final int MAX_DB_QUEUE = 20;

    /*
     * To protect against DoS, a new connection from an IP is only admitted
     * if at least MIN_MS_BETWEEN_CONNECT_ATTEMPTS_PER_IP milliseconds have
     * passed since the last admitted connection from that IP (500 ms, i.e.
     * at most 2 per second), and if that IP has fewer than
     * MAX_CONNECTION_PER_IP connections currently open.
     */
    private static final long MIN_MS_BETWEEN_CONNECT_ATTEMPTS_PER_IP = 500;
    private static final int MAX_CONNECTION_PER_IP = 10;

    private final PersistentStorage storage;
    private final int serverPort;
    /** Set by {@link #shutdown()} to make the accept loop terminate. */
    private volatile boolean quit = false;
    private final ExecutorService threadPool;
    private final SSLServerSocket serverSocket;

    /**
     * Number of currently open (admitted) connections per remote IP.
     * Guarded by the monitor of {@link #lastConnectAttempt}.
     */
    private final Map<String, Integer> activeConnections = new HashMap<String, Integer>();
    /**
     * Time (ms since epoch) of the last admitted connection per remote IP.
     * Also serves as the lock object for both maps.
     *
     * NOTE(review): entries are added in initiatingConnection and never
     * removed by this class, so the map grows with the number of distinct
     * client IPs seen.
     */
    private final Map<String, Long> lastConnectAttempt = new HashMap<String, Long>();

    /**
     * Number of admitted connections that have not finished yet. Incremented
     * on the accept thread and decremented on worker threads, hence atomic.
     */
    private final AtomicInteger queueLength = new AtomicInteger();

    /**
     * Creates the server socket and starts the accept thread.
     *
     * @param props
     *            configuration; must contain {@link #KEY_SSL_PORT} (an
     *            integer) and {@link #KEY_SSL_SERVER_KEYSTORE} (a file path)
     * @param storage
     *            storage used to look up friends and record "last seen"
     * @param keystorePassword
     *            password handed to the {@code SSLKeyManager}
     * @throws NumberFormatException
     *             if the port property is missing or not an integer
     * @throws NullPointerException
     *             if the key store property is missing
     */
    public PublicKeySSLServer(Properties props, PersistentStorage storage, char[] keystorePassword) throws IOException, KeyManagementException, NoSuchAlgorithmException, KeyStoreException, CertificateException, UnrecoverableKeyException, InterruptedException {
        this.serverPort = Integer.parseInt(props.getProperty(KEY_SSL_PORT));
        this.storage = storage;
        this.threadPool = Executors.newFixedThreadPool(NUM_THREADS);

        File keyStoreFile = new File(props.getProperty(KEY_SSL_SERVER_KEYSTORE));
        SSLKeyManager sslManager = new SSLKeyManager(keyStoreFile, keystorePassword);
        serverSocket = sslManager.createServerSocket(serverPort);
        // clients identify themselves by their certificate's public key
        serverSocket.setNeedClientAuth(true);

        Thread t = new Thread(new Runnable() {
            public void run() {
                acceptLoop();
            }
        });
        t.setName("SSL Server accept thread");
        t.start();
    }

    /**
     * Accepts connections until {@link #quit} is set. Admitted connections
     * are handed to the thread pool; throttled ones are closed right away.
     */
    private void acceptLoop() {
        System.out.println("SSL server: listening on port " + serverPort);
        while (!quit) {
            try {
                SSLSocket csocket = (SSLSocket) serverSocket.accept();
                String remoteIp = csocket.getInetAddress().getHostAddress();
                if (!isConnectionAllowed(remoteIp)) {
                    // a throttled connection must still be closed, otherwise
                    // every rejected attempt leaks an open socket
                    closeQuietly(csocket);
                    continue;
                }
                queueLength.incrementAndGet();
                initiatingConnection(remoteIp);
                try {
                    threadPool.execute(new PublicKeySSLServerProtocol(csocket));
                } catch (RejectedExecutionException e) {
                    // the worker will never run, so undo the bookkeeping its
                    // finally-block would otherwise have done
                    closingConnection(remoteIp);
                    queueLength.decrementAndGet();
                    closeQuietly(csocket);
                    if (!quit) {
                        e.printStackTrace();
                    }
                }
            } catch (IOException e) {
                // constant-first comparison: getMessage() may be null
                if (e instanceof java.net.SocketException && "Socket closed".equals(e.getMessage())) {
                    System.out.println("SSL Server: closing socket");
                } else {
                    e.printStackTrace();
                }
            }
        }
    }

    /**
     * Closes the socket, ignoring any error raised while doing so.
     *
     * @param socket
     *            socket to close, may be {@code null}
     */
    private static void closeQuietly(SSLSocket socket) {
        if (socket == null) {
            return;
        }
        try {
            socket.close();
        } catch (IOException ignored) {
            // best effort: the connection is being dropped anyway
        }
    }

    /**
     * Records an admitted connection: stamps the connect time for the IP and
     * increments its active connection count.
     */
    private void initiatingConnection(String remoteIp) {
        synchronized (lastConnectAttempt) {
            lastConnectAttempt.put(remoteIp, System.currentTimeMillis());
            Integer active = activeConnections.get(remoteIp);
            activeConnections.put(remoteIp, active == null ? 1 : active + 1);
        }
    }

    /**
     * Records that a connection from the IP has finished: decrements its
     * active connection count, dropping the entry when it reaches zero.
     */
    private void closingConnection(String remoteIp) {
        synchronized (lastConnectAttempt) {
            Integer active = activeConnections.get(remoteIp);
            // a missing entry is treated like the last connection closing
            // instead of failing on unboxing null
            if (active == null || active <= 1) {
                activeConnections.remove(remoteIp);
            } else {
                activeConnections.put(remoteIp, active - 1);
            }
        }
    }

    /**
     * Decides whether a new connection from the IP may be served.
     *
     * @return {@code false} if the previous admitted connection from this IP
     *         was less than {@link #MIN_MS_BETWEEN_CONNECT_ATTEMPTS_PER_IP}
     *         ms ago, or if the IP already has
     *         {@link #MAX_CONNECTION_PER_IP} connections open; {@code true}
     *         otherwise
     */
    private boolean isConnectionAllowed(String remoteIp) {
        synchronized (lastConnectAttempt) {
            Long lastAttempt = lastConnectAttempt.get(remoteIp);
            if (lastAttempt != null) {
                long timeSince = System.currentTimeMillis() - lastAttempt;
                if (timeSince < MIN_MS_BETWEEN_CONNECT_ATTEMPTS_PER_IP) {
                    System.err.println(new Date() + ": connection from '" + remoteIp + "' denied, " + " to high connect frequency (" + timeSince + "ms<" + MIN_MS_BETWEEN_CONNECT_ATTEMPTS_PER_IP + "ms)");
                    return false;
                }
            }
            Integer numActiveConnections = activeConnections.get(remoteIp);
            // ">=": this check runs before the new connection is counted, so
            // ">" would admit one connection more than the limit
            if (numActiveConnections != null && numActiveConnections >= MAX_CONNECTION_PER_IP) {
                System.err.println(new Date() + ": connection from '" + remoteIp + "' denied, " + " to many connections (" + numActiveConnections + ">=" + MAX_CONNECTION_PER_IP + ")");
                return false;
            }
            return true;
        }
    }

    /**
     * Serves one accepted connection: reads the client's ignore list and
     * answers with a gzip-compressed serialized friend list. Always closes
     * the socket and releases the per-IP bookkeeping when done.
     */
    class PublicKeySSLServerProtocol implements Runnable {
        private final SSLSocket socket;
        private final String remoteIp;
        /** Creation timestamp; replaced by the total elapsed ms on success. */
        long timeInclNet;

        public PublicKeySSLServerProtocol(SSLSocket socket) {
            this.socket = socket;
            this.remoteIp = socket.getInetAddress().getHostAddress();
            timeInclNet = System.currentTimeMillis();
        }

        public void run() {
            try {
                int dbQueueLength = storage.getDbQueueLength();
                if (dbQueueLength > MAX_DB_QUEUE) {
                    // database is overloaded: answer with an empty list
                    // without querying it
                    DataInputStream in = new DataInputStream(socket.getInputStream());
                    BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new GZIPOutputStream(socket.getOutputStream())));
                    // read to not confuse the client
                    readIgnoreList(in);
                    out.write(PublicKeyFriend.serialize(new PublicKeyFriend[0]));
                    out.close();
                    in.close();
                } else {
                    Certificate[] remoteCerts = socket.getSession().getPeerCertificates();
                    if (remoteCerts.length > 0) {
                        byte[] remoteKey = remoteCerts[0].getPublicKey().getEncoded();
                        DataInputStream in = new DataInputStream(socket.getInputStream());
                        BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new GZIPOutputStream(socket.getOutputStream())));
                        HashSet<Integer> keysToIgnore = readIgnoreList(in);
                        long timeExclNet = System.currentTimeMillis();

                        // ok, now we know what to ignore
                        // lets try to find new friends
                        PublicKeyFriend f = new PublicKeyFriend();
                        f.setPublicKey(remoteKey);
                        f.setPublicKeySha1(Tools.getSha1(remoteKey));

                        long dbStartTime = System.currentTimeMillis();
                        List<PublicKeyFriend> allFriends = storage.getFriendsUsingPublicKey(f);
                        long preLastSeenUpdate = System.currentTimeMillis();
                        storage.updateUserLastSeen(f);
                        long dbTime = System.currentTimeMillis() - dbStartTime;
                        long lastSeenOverhead = System.currentTimeMillis() - preLastSeenUpdate;

                        // keyed by hash of the friend's SHA-1 so duplicates
                        // collapse into one entry
                        Map<Integer, PublicKeyFriend> newFriends = new HashMap<Integer, PublicKeyFriend>();
                        for (PublicKeyFriend friend : allFriends) {
                            int friendHash = Arrays.hashCode(friend.getPublicKeySha1());
                            // ignore friends the user already knows of, and
                            // ones that were already added
                            if (!keysToIgnore.contains(friendHash) && !newFriends.containsKey(friendHash)) {
                                newFriends.put(friendHash, friend);
                            }
                        }
                        List<PublicKeyFriend> friendsArray = new LinkedList<PublicKeyFriend>(newFriends.values());

                        /*
                         * for debugging, just return an empty list
                         *
                         * NOTE(review): the computed friendsArray is NOT
                         * sent; the "returned" count in the log line below
                         * is the computed size, not what the client gets.
                         * Confirm whether this stub is still intended.
                         */
                        String serialized = PublicKeyFriend.serialize(new PublicKeyFriend[0]);
                        // String serialized =
                        // PublicKeyFriend.serialize(friendsArray.toArray(new
                        // PublicKeyFriend[newFriends.size()]));
                        timeExclNet = System.currentTimeMillis() - timeExclNet;
                        out.write(serialized);
                        out.close();
                        in.close();
                        timeInclNet = System.currentTimeMillis() - timeInclNet;
                        log("done, returned: " + friendsArray.size() + " ignored: " + keysToIgnore.size() + " time: (queries: " + timeExclNet + " lastSeen: " + lastSeenOverhead + " total: " + timeInclNet + " db=" + dbTime + ")");
                    }
                }
            } catch (java.io.EOFException e) {
                System.err.println(remoteIp + ": " + "EOF to early");
            } catch (SSLPeerUnverifiedException e) {
                System.err.println(remoteIp + ": no cert, closing conn");
            } catch (java.net.SocketException e) {
                if ("Connection timed out".equals(e.getMessage())) {
                    System.err.println(remoteIp + ": other side timed out");
                } else if ("Connection reset".equals(e.getMessage())) {
                    System.err.println(remoteIp + ": other side closed socket");
                } else if ("Broken pipe".equals(e.getMessage())) {
                    System.err.println(remoteIp + ": broken pipe");
                } else {
                    e.printStackTrace();
                }
            } catch (IOException e) {
                e.printStackTrace();
            } catch (Exception e) {
                // a worker must never die with an uncaught exception
                e.printStackTrace();
            } finally {
                closeQuietly(socket);
                closingConnection(remoteIp);
                queueLength.decrementAndGet();
            }
        }

        /**
         * Reads the ignore list sent by the client: a 4-byte count followed
         * by that many 20-byte key hashes.
         *
         * @return the {@link Arrays#hashCode(byte[])} of every 20-byte entry
         * @throws IOException
         *             if the stream ends early, or if the client announces
         *             more than 100000 entries (the socket is closed first)
         */
        private HashSet<Integer> readIgnoreList(DataInputStream in) throws IOException {
            int numToIgnore = in.readInt();
            if (numToIgnore > 100000) {
                System.err.println("warning: user specified more that 100000 friends (!!!???), closing conn");
                socket.close();
                throw new IOException("user specified invalid data");
            }
            HashSet<Integer> keysToIgnore = new HashSet<Integer>();
            // buffer is reused; only its hash is kept per entry
            byte[] pubKeySha = new byte[20];
            for (int i = 0; i < numToIgnore; i++) {
                in.readFully(pubKeySha);
                keysToIgnore.add(Arrays.hashCode(pubKeySha));
            }
            return keysToIgnore;
        }

        /** Prints a message to stdout, prefixed with the client's IP. */
        private void log(String msg) {
            System.out.println(remoteIp + ": " + msg);
        }
    }

    /**
     * Stops accepting connections, shuts the worker pool down (already
     * submitted work is not cancelled here) and shuts the storage down.
     */
    public void shutdown() {
        quit = true;
        try {
            // closing the server socket unblocks the accept() call
            serverSocket.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        threadPool.shutdown();
        storage.shutdown();
    }
}