Skip to content

Commit

Permalink
Split up BaseWebSocket more into ClientWebSocket/ServerWebSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
UnknownShadow200 committed Feb 16, 2021
1 parent eb4bd39 commit 6d005f4
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 68 deletions.
177 changes: 111 additions & 66 deletions MCGalaxy/Network/BaseWebSocket.cs
Expand Up @@ -21,54 +21,39 @@

namespace MCGalaxy.Network {

/// <summary> Abstracts WebSocket handling. </summary>
/// <summary> Abstracts WebSocket handling </summary>
public abstract class BaseWebSocket : INetSocket, INetProtocol {
bool readingHeaders = true;
bool conn, upgrade, version;
string verKey;
protected bool conn, upgrade;
protected bool readingHeaders = true;

void AcceptConnection() {
const string fmt =
"HTTP/1.1 101 Switching Protocols\r\n" +
"Upgrade: websocket\r\n" +
"Connection: Upgrade\r\n" +
"Sec-WebSocket-Accept: {0}\r\n" +
"Sec-WebSocket-Protocol: ClassiCube\r\n" +
"\r\n";

string key = verKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// <summary> Computes the base64-encoded handshake verification key </summary>
protected static string ComputeKey(string rawKey) {
string key = rawKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
SHA1 sha = SHA1.Create();
byte[] raw = sha.ComputeHash(Encoding.ASCII.GetBytes(key));

string headers = String.Format(fmt, Convert.ToBase64String(raw));
SendRaw(Encoding.ASCII.GetBytes(headers), SendFlags.None);
readingHeaders = false;
return Convert.ToBase64String(raw);
}

protected abstract void OnGotAllHeaders();
protected abstract void OnGotHeader(string key, string val);

void ProcessHeader(string raw) {
// end of headers
if (raw.Length == 0) {
if (conn && upgrade && version && verKey != null) {
AcceptConnection();
} else {
// don't pretend to be a http server (so IP:port isn't marked as one by bots)
Close();
}
}
// end of all headers
if (raw.Length == 0) OnGotAllHeaders();

// check that got a proper header
int sep = raw.IndexOf(':');
if (sep == -1) return; // not a proper header
if (sep == -1) return;

string key = raw.Substring(0, sep);
string val = raw.Substring(sep + 1).Trim();

if (key.CaselessEq("Connection")) {
conn = val.CaselessContains("Upgrade");
} else if (key.CaselessEq("Upgrade")) {
upgrade = val.CaselessEq("websocket");
} else if (key.CaselessEq("Sec-WebSocket-Version")) {
version = val.CaselessEq("13");
} else if (key.CaselessEq("Sec-WebSocket-Key")) {
verKey = val;
} else {
OnGotHeader(key, val);
}
}

Expand Down Expand Up @@ -101,10 +86,10 @@ public abstract class BaseWebSocket : INetSocket, INetProtocol {
const int state_mask = 4;
const int state_data = 5;

const int opcode_continued = 0;
const int opcode_binary = 2;
const int opcode_disconnect = 8;
const int FIN = 0x80;
protected const int OPCODE_CONTINUED = 0;
protected const int OPCODE_BINARY = 2;
protected const int OPCODE_DISCONNECT = 8;
protected const int FIN = 0x80;

void DecodeFrame() {
for (int i = 0; i < frameLen; i++) {
Expand All @@ -113,12 +98,12 @@ public abstract class BaseWebSocket : INetSocket, INetProtocol {

switch (opcode) {
// TODO: reply to ping frames
case opcode_continued:
case opcode_binary:
case OPCODE_CONTINUED:
case OPCODE_BINARY:
if (frameLen == 0) return;
HandleData(frame, frameLen);
break;
case opcode_disconnect:
case OPCODE_DISCONNECT:
// Connection is getting closed
Disconnect(1000); break;
default:
Expand Down Expand Up @@ -205,26 +190,9 @@ public abstract class BaseWebSocket : INetSocket, INetProtocol {
return offset;
}

/// <summary> Wraps the given data in a websocket frame </summary>
protected static byte[] WrapData(byte[] data) {
int headerLen = 2 + (data.Length >= 126 ? 2 : 0);
byte[] packet = new byte[headerLen + data.Length];
packet[0] = opcode_binary | FIN;

if (headerLen > 2) {
packet[1] = 126;
packet[2] = (byte)(data.Length >> 8);
packet[3] = (byte)data.Length;
} else {
packet[1] = (byte)data.Length;
}
Buffer.BlockCopy(data, 0, packet, headerLen, data.Length);
return packet;
}

protected static byte[] WrapDisconnect(int reason) {
byte[] packet = new byte[4];
packet[0] = opcode_disconnect | FIN;
packet[0] = OPCODE_DISCONNECT | FIN;
packet[1] = 2;
packet[2] = (byte)(reason >> 8);
packet[3] = (byte)reason;
Expand All @@ -243,20 +211,97 @@ public abstract class BaseWebSocket : INetSocket, INetProtocol {
public void Disconnect() { Disconnect(1000); }
}

/// <summary> Abstracts a server side WebSocket </summary>
public abstract class ServerWebSocket : BaseWebSocket {
bool version;
string verKey;

void AcceptConnection() {
const string fmt =
"HTTP/1.1 101 Switching Protocols\r\n" +
"Upgrade: websocket\r\n" +
"Connection: Upgrade\r\n" +
"Sec-WebSocket-Accept: {0}\r\n" +
"Sec-WebSocket-Protocol: ClassiCube\r\n" +
"\r\n";

string key = ComputeKey(verKey);
string headers = String.Format(fmt, key);
SendRaw(Encoding.ASCII.GetBytes(headers), SendFlags.None);
readingHeaders = false;
}

protected override void OnGotAllHeaders() {
if (conn && upgrade && version && verKey != null) {
AcceptConnection();
} else {
// don't pretend to be a http server (so IP:port isn't marked as one by bots)
Close();
}
}

protected override void OnGotHeader(string key, string val) {
if (key.CaselessEq("Sec-WebSocket-Version")) {
version = val.CaselessEq("13");
} else if (key.CaselessEq("Sec-WebSocket-Key")) {
verKey = val;
}
}

/// <summary> Wraps the given data in a websocket frame </summary>
protected static byte[] WrapData(byte[] data) {
int headerLen = 2 + (data.Length >= 126 ? 2 : 0);
byte[] packet = new byte[headerLen + data.Length];
packet[0] = OPCODE_BINARY | FIN;

if (headerLen > 2) {
packet[1] = 126;
packet[2] = (byte)(data.Length >> 8);
packet[3] = (byte)data.Length;
} else {
packet[1] = (byte)data.Length;
}
Buffer.BlockCopy(data, 0, packet, headerLen, data.Length);
return packet;
}
}

/// <summary> Abstracts a client side WebSocket </summary>
public abstract class ClientWebSocket : BaseWebSocket {
string verKey;
// TODO: use a random securely generated key
const string key = "xTNDiuZRoMKtxrnJDWyLmA==";

protected void WriteHeader(string value) {
byte[] data = Encoding.ASCII.GetBytes(value + "\r\n");
SendRaw(data, SendFlags.None);
void AcceptConnection() {
readingHeaders = false;
}

protected override void OnGotAllHeaders() {
if (conn && upgrade && verKey == ComputeKey(key)) {
AcceptConnection();
} else {
// don't pretend to be a http server (so IP:port isn't marked as one by bots)
Close();
}
}

protected override void OnGotHeader(string key, string val) {
if (key.CaselessEq("Sec-WebSocket-Accept")) {
verKey = val;
}
}

public override void Init() {
WriteHeader("GET / HTTP/1.1");
WriteHeader("Connection: Upgrade");
WriteHeader("Upgrade: websocket");
WriteHeader("Sec-WebSocket-Version: 13");
WriteHeader("Sec-WebSocket-Key: xTNDiuZRoMKtxrnJDWyLmA==");
WriteHeader("");
const string fmt =
"GET / HTTP/1.1\r\n" +
"Upgrade: websocket\r\n" +
"Connection: Upgrade\r\n" +
"Sec-WebSocket-Version: 13\r\n" +
"Sec-WebSocket-Key: {0}\r\n" +
"\r\n";

string headers = String.Format(fmt, key);
SendRaw(Encoding.ASCII.GetBytes(headers), SendFlags.None);
}
}
}
2 changes: 1 addition & 1 deletion MCGalaxy/Network/Player.Networking.cs
Expand Up @@ -130,7 +130,7 @@ public partial class Player : IDisposable, INetProtocol {


public void Send(byte[] buffer) { Socket.Send(buffer, SendFlags.None); }
public void Send(byte[] buffer, bool sync = false) {
public void Send(byte[] buffer, bool sync) {
Socket.Send(buffer, sync ? SendFlags.Synchronous : SendFlags.None);
}

Expand Down
2 changes: 1 addition & 1 deletion MCGalaxy/Network/Sockets.cs
Expand Up @@ -272,7 +272,7 @@ public sealed class TcpSocket : INetSocket {
}

/// <summary> Abstracts a WebSocket on top of a socket. </summary>
public sealed class WebSocket : BaseWebSocket {
public sealed class WebSocket : ServerWebSocket {
readonly INetSocket s;

public WebSocket(INetSocket socket) { s = socket; }
Expand Down

0 comments on commit 6d005f4

Please sign in to comment.