Permalink
Browse files

add websockets support to our mochiweb copy, from github.com/RJ/mochiweb

  • Loading branch information...
RJ committed Dec 27, 2009
1 parent f43fdbb commit 49e551153c0cb2f9f418fcb597d5db6d7102cf34
@@ -460,12 +460,7 @@ equiv_object(Props1, Props2) ->
equiv_list([], []) ->
true;
equiv_list([V1 | L1], [V2 | L2]) ->
- case equiv(V1, V2) of
- true ->
- equiv_list(L1, L2);
- false ->
- false
- end.
+ equiv(V1, V2) andalso equiv_list(L1, L2).
test_all() ->
test_issue33(),
@@ -12,7 +12,7 @@
-export([test/0]).
% This is a macro to placate syntax highlighters..
--define(Q, $\").%"
+-define(Q, $\").
-define(ADV_COL(S, N), S#decoder{offset=N+S#decoder.offset,
column=N+S#decoder.column}).
-define(INC_COL(S), S#decoder{offset=1+S#decoder.offset,
@@ -354,10 +354,24 @@ tokenize_string_fast(B, O) ->
case B of
<<_:O/binary, ?Q, _/binary>> ->
O;
- <<_:O/binary, C, _/binary>> when C =/= $\\ ->
+ <<_:O/binary, $\\, _/binary>> ->
+ {escape, O};
+ <<_:O/binary, C1, _/binary>> when C1 < 128 ->
tokenize_string_fast(B, 1 + O);
+ <<_:O/binary, C1, C2, _/binary>> when C1 >= 194, C1 =< 223,
+ C2 >= 128, C2 =< 191 ->
+ tokenize_string_fast(B, 2 + O);
+ <<_:O/binary, C1, C2, C3, _/binary>> when C1 >= 224, C1 =< 239,
+ C2 >= 128, C2 =< 191,
+ C3 >= 128, C3 =< 191 ->
+ tokenize_string_fast(B, 3 + O);
+ <<_:O/binary, C1, C2, C3, C4, _/binary>> when C1 >= 240, C1 =< 244,
+ C2 >= 128, C2 =< 191,
+ C3 >= 128, C3 =< 191,
+ C4 >= 128, C4 =< 191 ->
+ tokenize_string_fast(B, 4 + O);
_ ->
- {escape, O}
+ throw(invalid_utf8)
end.
tokenize_string(B, S=#decoder{offset=O}, Acc) ->
@@ -550,17 +564,13 @@ equiv_object(Props1, Props2) ->
equiv_list([], []) ->
true;
equiv_list([V1 | L1], [V2 | L2]) ->
- case equiv(V1, V2) of
- true ->
- equiv_list(L1, L2);
- false ->
- false
- end.
+ equiv(V1, V2) andalso equiv_list(L1, L2).
test_all() ->
[1199344435545.0, 1] = decode(<<"[1199344435545.0,1]">>),
<<16#F0,16#9D,16#9C,16#95>> = decode([34,"\\ud835","\\udf15",34]),
test_encoder_utf8(),
+ test_input_validation(),
test_one(e2j_test_vec(utf8), 1).
test_one([], _N) ->
@@ -626,3 +636,28 @@ test_encoder_utf8() ->
Enc = mochijson2:encoder([{utf8, true}]),
[34,"\\u0001",[209,130],[208,181],[209,129],[209,130],34] =
Enc(<<1,"\321\202\320\265\321\201\321\202">>).
+
+test_input_validation() ->
+ Good = [
+ {16#00A3, <<?Q, 16#C2, 16#A3, ?Q>>}, % pound
+ {16#20AC, <<?Q, 16#E2, 16#82, 16#AC, ?Q>>}, % euro
+ {16#10196, <<?Q, 16#F0, 16#90, 16#86, 16#96, ?Q>>} % denarius
+ ],
+ lists:foreach(fun({CodePoint, UTF8}) ->
+ Expect = list_to_binary(xmerl_ucs:to_utf8(CodePoint)),
+ Expect = decode(UTF8)
+ end, Good),
+
+ Bad = [
+ % 2nd, 3rd, or 4th byte of a multi-byte sequence w/o leading byte
+ <<?Q, 16#80, ?Q>>,
+ % missing continuations, last byte in each should be 80-BF
+ <<?Q, 16#C2, 16#7F, ?Q>>,
+ <<?Q, 16#E0, 16#80,16#7F, ?Q>>,
+ <<?Q, 16#F0, 16#80, 16#80, 16#7F, ?Q>>,
+ % we don't support code points > 10FFFF per RFC 3629
+ <<?Q, 16#F5, 16#80, 16#80, 16#80, ?Q>>
+ ],
+ lists:foreach(fun(X) ->
+ ok = try decode(X) catch invalid_utf8 -> ok end
+ end, Bad).
@@ -27,11 +27,12 @@ set_defaults(Defaults, PropList) ->
lists:foldl(fun set_default/2, PropList, Defaults).
parse_options(Options) ->
- {loop, HttpLoop} = proplists:lookup(loop, Options),
+ WwwLoop = proplists:get_value(loop, Options),
+ WSLoop = proplists:get_value(wsloop, Options),
Loop = fun (S) ->
- ?MODULE:loop(S, HttpLoop)
+ ?MODULE:loop(S, {WwwLoop,WSLoop})
end,
- Options1 = [{loop, Loop} | proplists:delete(loop, Options)],
+ Options1 = [{loop, Loop}, {wsloop, Loop} | proplists:delete(loop, proplists:delete(wsloop, Options))],
set_defaults(?DEFAULTS, Options1).
stop() ->
@@ -124,16 +125,25 @@ headers(Socket, Request, Headers, _Body, ?MAX_HEADERS) ->
Req:respond({400, [], []}),
gen_tcp:close(Socket),
exit(normal);
-headers(Socket, Request, Headers, Body, HeaderCount) ->
+
+headers(Socket, Request, Headers, {WwwLoop, WSLoop}, HeaderCount) ->
case gen_tcp:recv(Socket, 0, ?IDLE_TIMEOUT) of
{ok, http_eoh} ->
- inet:setopts(Socket, [{packet, raw}]),
- Req = mochiweb:new_request({Socket, Request,
- lists:reverse(Headers)}),
- Body(Req),
- ?MODULE:after_response(Body, Req);
+ {_, {abs_path,Path}, _} = Request,
+ case websocket_check(Socket, Path, Headers) of
+ true -> % a websocket request
+ inet:setopts(Socket, [{packet, raw}]),
+ WSRequest = websocket_request:new(Socket,Path),
+ WSLoop(WSRequest);
+ false -> % normal http request
+ inet:setopts(Socket, [{packet, raw}]),
+ Req = mochiweb:new_request({Socket, Request,
+ lists:reverse(Headers)}),
+ WwwLoop(Req),
+ ?MODULE:after_response({WwwLoop, WSLoop}, Req)
+ end;
{ok, {http_header, _, Name, _, Value}} ->
- headers(Socket, Request, [{Name, Value} | Headers], Body,
+ headers(Socket, Request, [{Name, Value} | Headers], {WwwLoop, WSLoop},
1 + HeaderCount);
_Other ->
gen_tcp:close(Socket),
@@ -150,3 +160,22 @@ after_response(Body, Req) ->
Req:cleanup(),
?MODULE:loop(Socket, Body)
end.
+
+websocket_check(Socket,Path,Headers) ->
+ case proplists:get_value('Upgrade',Headers) of
+ "WebSocket" ->
+ websocket_send_handshake(Socket,Path,Headers),
+ true;
+ _Other ->
+ false
+ end.
+
+websocket_send_handshake(Socket,Path,Headers) ->
+ Origin = proplists:get_value("Origin",Headers),
+ Location = proplists:get_value('Host', Headers),
+ Proto = "HTTP/1.1 101 Web Socket Protocol Handshake\r\nUpgrade: WebSocket\r\nConnection: Upgrade\r\n",
+ Resp = Proto ++
+ "WebSocket-Origin: " ++ Origin ++ "\r\n" ++
+ "WebSocket-Location: ws://" ++ Location ++ Path ++ "\r\n\r\n",
+ gen_tcp:send(Socket, Resp).
+
@@ -76,9 +76,9 @@ parse_multipart_request(Req, Callback) ->
Boundary = iolist_to_binary(
get_boundary(Req:get_header_value("content-type"))),
Prefix = <<"\r\n--", Boundary/binary>>,
- BS = size(Boundary),
+ BS = byte_size(Boundary),
Chunk = read_chunk(Req, Length),
- Length1 = Length - size(Chunk),
+ Length1 = Length - byte_size(Chunk),
<<"--", Boundary:BS/binary, "\r\n", Rest/binary>> = Chunk,
feed_mp(headers, flash_multipart_hack(#mp{boundary=Prefix,
length=Length1,
@@ -117,7 +117,7 @@ read_chunk(Req, Length) when Length > 0 ->
read_more(State=#mp{length=Length, buffer=Buffer, req=Req}) ->
Data = read_chunk(Req, Length),
Buffer1 = <<Buffer/binary, Data/binary>>,
- flash_multipart_hack(State#mp{length=Length - size(Data),
+ flash_multipart_hack(State#mp{length=Length - byte_size(Data),
buffer=Buffer1}).
flash_multipart_hack(State=#mp{length=0, buffer=Buffer, boundary=Prefix}) ->
@@ -285,14 +285,12 @@ test_parse3() ->
eof],
TestCallback = fun (Next) -> test_callback(Next, Expect) end,
ServerFun = fun (Socket) ->
- case gen_tcp:send(Socket, BinContent) of
- ok ->
- exit(normal)
- end
+ ok = gen_tcp:send(Socket, BinContent),
+ exit(normal)
end,
ClientFun = fun (Socket) ->
Req = fake_request(Socket, ContentType,
- size(BinContent)),
+ byte_size(BinContent)),
Res = parse_multipart_request(Req, TestCallback),
{0, <<>>, ok} = Res,
ok
@@ -318,14 +316,12 @@ test_parse2() ->
eof],
TestCallback = fun (Next) -> test_callback(Next, Expect) end,
ServerFun = fun (Socket) ->
- case gen_tcp:send(Socket, BinContent) of
- ok ->
- exit(normal)
- end
+ ok = gen_tcp:send(Socket, BinContent),
+ exit(normal)
end,
ClientFun = fun (Socket) ->
Req = fake_request(Socket, ContentType,
- size(BinContent)),
+ byte_size(BinContent)),
Res = parse_multipart_request(Req, TestCallback),
{0, <<>>, ok} = Res,
ok
@@ -351,14 +347,12 @@ test_parse_form() ->
""], "\r\n"),
BinContent = iolist_to_binary(Content),
ServerFun = fun (Socket) ->
- case gen_tcp:send(Socket, BinContent) of
- ok ->
- exit(normal)
- end
+ ok = gen_tcp:send(Socket, BinContent),
+ exit(normal)
end,
ClientFun = fun (Socket) ->
Req = fake_request(Socket, ContentType,
- size(BinContent)),
+ byte_size(BinContent)),
Res = parse_form(Req),
[{"submit-name", "Larry"},
{"files", {"file1.txt", {"text/plain",[]},
@@ -400,14 +394,12 @@ test_parse() ->
eof],
TestCallback = fun (Next) -> test_callback(Next, Expect) end,
ServerFun = fun (Socket) ->
- case gen_tcp:send(Socket, BinContent) of
- ok ->
- exit(normal)
- end
+ ok = gen_tcp:send(Socket, BinContent),
+ exit(normal)
end,
ClientFun = fun (Socket) ->
Req = fake_request(Socket, ContentType,
- size(BinContent)),
+ byte_size(BinContent)),
Res = parse_multipart_request(Req, TestCallback),
{0, <<>>, ok} = Res,
ok
@@ -471,14 +463,12 @@ test_flash_parse() ->
eof],
TestCallback = fun (Next) -> test_callback(Next, Expect) end,
ServerFun = fun (Socket) ->
- case gen_tcp:send(Socket, BinContent) of
- ok ->
- exit(normal)
- end
+ ok = gen_tcp:send(Socket, BinContent),
+ exit(normal)
end,
ClientFun = fun (Socket) ->
Req = fake_request(Socket, ContentType,
- size(BinContent)),
+ byte_size(BinContent)),
Res = parse_multipart_request(Req, TestCallback),
{0, <<>>, ok} = Res,
ok
@@ -515,14 +505,12 @@ test_flash_parse2() ->
eof],
TestCallback = fun (Next) -> test_callback(Next, Expect) end,
ServerFun = fun (Socket) ->
- case gen_tcp:send(Socket, BinContent) of
- ok ->
- exit(normal)
- end
+ ok = gen_tcp:send(Socket, BinContent),
+ exit(normal)
end,
ClientFun = fun (Socket) ->
Req = fake_request(Socket, ContentType,
- size(BinContent)),
+ byte_size(BinContent)),
Res = parse_multipart_request(Req, TestCallback),
{0, <<>>, ok} = Res,
ok
@@ -219,7 +219,7 @@ stream_body(MaxChunkSize, ChunkFun, FunState, MaxBodyLength) ->
MaxBodyLength when is_integer(MaxBodyLength), MaxBodyLength < Length ->
exit({body_too_large, content_length});
_ ->
- stream_unchunked_body(Length, MaxChunkSize, ChunkFun, FunState)
+ stream_unchunked_body(Length, ChunkFun, FunState)
end;
Length ->
exit({length_not_integer, Length})
@@ -454,16 +454,20 @@ stream_chunked_body(MaxChunkSize, Fun, FunState) ->
stream_chunked_body(MaxChunkSize, Fun, NewState)
end.
-stream_unchunked_body(0, _MaxChunkSize, Fun, FunState) ->
+stream_unchunked_body(0, Fun, FunState) ->
Fun({0, <<>>}, FunState);
-stream_unchunked_body(Length, MaxChunkSize, Fun, FunState) when Length > MaxChunkSize ->
- Bin = recv(MaxChunkSize),
- NewState = Fun({MaxChunkSize, Bin}, FunState),
- stream_unchunked_body(Length - MaxChunkSize, MaxChunkSize, Fun, NewState);
-stream_unchunked_body(Length, MaxChunkSize, Fun, FunState) ->
- Bin = recv(Length),
- NewState = Fun({Length, Bin}, FunState),
- stream_unchunked_body(0, MaxChunkSize, Fun, NewState).
+stream_unchunked_body(Length, Fun, FunState) when Length > 0 ->
+ Bin = recv(0),
+ BinSize = byte_size(Bin),
+ if BinSize > Length ->
+ <<OurBody:Length/binary, Extra/binary>> = Bin,
+ gen_tcp:unrecv(Socket, Extra),
+ NewState = Fun({Length, OurBody}, FunState),
+ stream_unchunked_body(0, Fun, NewState);
+ true ->
+ NewState = Fun({BinSize, Bin}, FunState),
+ stream_unchunked_body(Length - BinSize, Fun, NewState)
+ end.
%% @spec read_chunk_length() -> integer()
@@ -676,7 +680,6 @@ range_parts({file, IoDevice}, Ranges) ->
end,
LocNums, Data),
{Bodies, Size};
-
range_parts(Body0, Ranges) ->
Body = iolist_to_binary(Body0),
Size = size(Body),
@@ -743,7 +746,6 @@ test_range() ->
[{none, 20}] = parse_range_request("bytes=-20"),
io:format(".. ok ~n"),
-
%% invalid, single ranges
io:format("Testing parse_range_request with invalid ranges~n"),
io:format("1"),
@@ -771,7 +773,7 @@ test_range() ->
io:format(".. ok~n"),
Body = <<"012345678901234567890123456789012345678901234567890123456789">>,
- BodySize = size(Body), %% 60
+ BodySize = byte_size(Body), %% 60
BodySize = 60,
%% these values assume BodySize =:= 60
@@ -17,6 +17,7 @@
-record(mochiweb_socket_server,
{port,
loop,
+ wsloop,
name=undefined,
max=2048,
ip=any,
@@ -77,6 +78,8 @@ parse_options([{ip, Ip} | Rest], State) ->
parse_options(Rest, State#mochiweb_socket_server{ip=ParsedIp});
parse_options([{loop, Loop} | Rest], State) ->
parse_options(Rest, State#mochiweb_socket_server{loop=Loop});
+parse_options([{wsloop, Loop} | Rest], State) ->
+ parse_options(Rest, State#mochiweb_socket_server{wsloop=Loop});
parse_options([{backlog, Backlog} | Rest], State) ->
parse_options(Rest, State#mochiweb_socket_server{backlog=Backlog});
parse_options([{max, Max} | Rest], State) ->
@@ -423,9 +423,7 @@ record_to_proplist(Record, Fields) ->
%% Fields should be obtained by calling record_info(fields, record_type)
%% where record_type is the record type of Record
record_to_proplist(Record, Fields, TypeKey)
- when is_tuple(Record),
- is_list(Fields),
- size(Record) - 1 =:= length(Fields) ->
+ when tuple_size(Record) - 1 =:= length(Fields) ->
lists:zip([TypeKey | Fields], tuple_to_list(Record)).
Oops, something went wrong.

0 comments on commit 49e5511

Please sign in to comment.