diff --git a/src/Microsoft.AspNet.SignalR.Core/Infrastructure/BinaryTextWriter.cs b/src/Microsoft.AspNet.SignalR.Core/Infrastructure/BinaryTextWriter.cs new file mode 100644 index 0000000000..a5f958b471 --- /dev/null +++ b/src/Microsoft.AspNet.SignalR.Core/Infrastructure/BinaryTextWriter.cs @@ -0,0 +1,34 @@ +using System; +using Microsoft.AspNet.SignalR.Hosting; + +namespace Microsoft.AspNet.SignalR.Infrastructure +{ + /// + /// A buffering text writer that supports writing binary directly as well + /// + internal unsafe class BinaryTextWriter : BufferTextWriter, IBinaryWriter + { + public BinaryTextWriter(IResponse response) : + base((data, state) => ((IResponse)state).Write(data), response, reuseBuffers: true, bufferSize: 128) + { + + } + + public BinaryTextWriter(IWebSocket socket) : + base((data, state) => ((IWebSocket)state).SendChunk(data), socket, reuseBuffers: false, bufferSize: 1024) + { + + } + + + public BinaryTextWriter(Action, object> write, object state, bool reuseBuffers, int bufferSize) : + base(write, state, reuseBuffers, bufferSize) + { + } + + public void Write(ArraySegment data) + { + Writer.Write(data); + } + } +} diff --git a/src/Microsoft.AspNet.SignalR.Core/Infrastructure/BufferTextWriter.cs b/src/Microsoft.AspNet.SignalR.Core/Infrastructure/BufferTextWriter.cs index 2cf39d536e..dbf48b7f52 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Infrastructure/BufferTextWriter.cs +++ b/src/Microsoft.AspNet.SignalR.Core/Infrastructure/BufferTextWriter.cs @@ -13,7 +13,7 @@ namespace Microsoft.AspNet.SignalR.Infrastructure /// we don't need to write to a long lived buffer. This saves massive amounts of memory /// as the number of connections grows. /// - internal unsafe class BufferTextWriter : TextWriter, IBinaryWriter + internal abstract unsafe class BufferTextWriter : TextWriter { private readonly Encoding _encoding; @@ -37,7 +37,7 @@ internal unsafe class BufferTextWriter : TextWriter, IBinaryWriter } [SuppressMessage("Microsoft.Globalization", "CA1305:SpecifyIFormatProvider", MessageId = "System.IO.TextWriter.#ctor", Justification = "It won't be used")] - public BufferTextWriter(Action, object> write, object state, bool reuseBuffers, int bufferSize) + protected BufferTextWriter(Action, object> write, object state, bool reuseBuffers, int bufferSize) { _write = write; _writeState = state; @@ -46,7 +46,7 @@ public BufferTextWriter(Action, object> write, object state, _bufferSize = bufferSize; } - private ChunkedWriter Writer + protected internal ChunkedWriter Writer { get { @@ -79,17 +79,12 @@ public override void Write(char value) Writer.Write(value); } - public void Write(ArraySegment data) - { - Writer.Write(data); - } - public override void Flush() { Writer.Flush(); } - private class ChunkedWriter + internal class ChunkedWriter { private int _charPos; private int _charLen; diff --git a/src/Microsoft.AspNet.SignalR.Core/Infrastructure/Connection.cs b/src/Microsoft.AspNet.SignalR.Core/Infrastructure/Connection.cs index bb0782de69..e0eb39ca53 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Infrastructure/Connection.cs +++ b/src/Microsoft.AspNet.SignalR.Core/Infrastructure/Connection.cs @@ -145,7 +145,7 @@ private ArraySegment GetMessageBuffer(object value) { using (var stream = new MemoryStream(128)) { - var bufferWriter = new BufferTextWriter((buffer, state) => + var bufferWriter = new BinaryTextWriter((buffer, state) => { ((MemoryStream)state).Write(buffer.Array, buffer.Offset, buffer.Count); }, diff --git a/src/Microsoft.AspNet.SignalR.Core/Microsoft.AspNet.SignalR.Core.csproj b/src/Microsoft.AspNet.SignalR.Core/Microsoft.AspNet.SignalR.Core.csproj index 732f2bab64..20038e7219 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Microsoft.AspNet.SignalR.Core.csproj +++ b/src/Microsoft.AspNet.SignalR.Core/Microsoft.AspNet.SignalR.Core.csproj @@ -67,6 +67,7 @@ + diff --git a/src/Microsoft.AspNet.SignalR.Core/Transports/PersistentResponse.cs b/src/Microsoft.AspNet.SignalR.Core/Transports/PersistentResponse.cs index 17c1bd2d1d..aba666fa20 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Transports/PersistentResponse.cs +++ b/src/Microsoft.AspNet.SignalR.Core/Transports/PersistentResponse.cs @@ -20,7 +20,7 @@ public sealed class PersistentResponse : IJsonWritable private readonly Action _writeCursor; public PersistentResponse() - : this(message => true, writer => { }) + : this(message => false, writer => { }) { } diff --git a/src/Microsoft.AspNet.SignalR.Core/Transports/TransportDisconnectBase.cs b/src/Microsoft.AspNet.SignalR.Core/Transports/TransportDisconnectBase.cs index 4b2c939c7e..f5f343e096 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Transports/TransportDisconnectBase.cs +++ b/src/Microsoft.AspNet.SignalR.Core/Transports/TransportDisconnectBase.cs @@ -186,7 +186,7 @@ public Uri Url protected virtual TextWriter CreateResponseWriter() { - return new BufferTextWriter(Context.Response); + return new BinaryTextWriter(Context.Response); } protected void IncrementErrors() diff --git a/src/Microsoft.AspNet.SignalR.Core/Transports/WebSocketTransport.cs b/src/Microsoft.AspNet.SignalR.Core/Transports/WebSocketTransport.cs index 66b4838e3b..872e0f63dc 100644 --- a/src/Microsoft.AspNet.SignalR.Core/Transports/WebSocketTransport.cs +++ b/src/Microsoft.AspNet.SignalR.Core/Transports/WebSocketTransport.cs @@ -95,7 +95,7 @@ public override Task ProcessRequest(ITransportConnection connection) protected override TextWriter CreateResponseWriter() { - return new BufferTextWriter(_socket); + return new BinaryTextWriter(_socket); } public override Task Send(object value) diff --git a/tests/Microsoft.AspNet.SignalR.Tests/BufferTextWriterFacts.cs b/tests/Microsoft.AspNet.SignalR.Tests/BufferTextWriterFacts.cs index d1444d13f4..73df0b0de6 100644 --- a/tests/Microsoft.AspNet.SignalR.Tests/BufferTextWriterFacts.cs +++ b/tests/Microsoft.AspNet.SignalR.Tests/BufferTextWriterFacts.cs @@ -13,7 +13,7 @@ public class BufferTextWriterFacts public void CanEncodingSurrogatePairsCorrectly() { var bytes = new List(); - var writer = new BufferTextWriter((buffer, state) => + var writer = new BinaryTextWriter((buffer, state) => { for (int i = buffer.Offset; i < buffer.Count; i++) { @@ -34,7 +34,7 @@ public void CanEncodingSurrogatePairsCorrectly() public void WriteNewBufferIsUsedForWritingChunksIfReuseBuffersFalse() { var buffers = new List>(); - var writer = new BufferTextWriter((buffer, state) => + var writer = new BinaryTextWriter((buffer, state) => { buffers.Add(buffer); }, @@ -55,7 +55,7 @@ public void WriteNewBufferIsUsedForWritingChunksIfReuseBuffersFalse() public void WriteSameBufferIsUsedForWritingChunksIfReuseBuffersTrue() { var buffers = new List>(); - var writer = new BufferTextWriter((buffer, state) => + var writer = new BinaryTextWriter((buffer, state) => { buffers.Add(buffer); }, @@ -79,7 +79,7 @@ public void WritesInChunks() int size = 3000; var buffers = new List>(); - var writer = new BufferTextWriter((buffer, state) => + var writer = new BinaryTextWriter((buffer, state) => { buffers.Add(buffer); }, @@ -121,7 +121,7 @@ private IEnumerable GetChunks(int size, int bufferSize) public void CanInterleaveStringsAndRawBinary() { var buffers = new List>(); - var writer = new BufferTextWriter((buffer, state) => + var writer = new BinaryTextWriter((buffer, state) => { buffers.Add(buffer); }, diff --git a/tests/Microsoft.AspNet.SignalR.Tests/Core/Transports/ForeverFrameTransportFacts.cs b/tests/Microsoft.AspNet.SignalR.Tests/Core/Transports/ForeverFrameTransportFacts.cs index e5a4e20db7..e4b94cf87a 100644 --- a/tests/Microsoft.AspNet.SignalR.Tests/Core/Transports/ForeverFrameTransportFacts.cs +++ b/tests/Microsoft.AspNet.SignalR.Tests/Core/Transports/ForeverFrameTransportFacts.cs @@ -1,9 +1,12 @@ using System; +using System.Collections.Generic; using System.Collections.Specialized; using System.IO; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNet.SignalR.Hosting; +using Microsoft.AspNet.SignalR.Messaging; using Microsoft.AspNet.SignalR.Transports; using Moq; using Xunit; @@ -13,17 +16,31 @@ namespace Microsoft.AspNet.SignalR.Tests.Core { public class ForeverFrameTransportFacts { - [Fact] - public void ForeverFrameTransportEscapesTags() + [Theory] + [InlineData("", "\\u003c/sCRiPT\\u003e")] + [InlineData("", "\\u003c/SCRIPT dosomething='false'\\u003e")] + [InlineData("

ELLO

", "\\u003cp\\u003eELLO\\u003c/p\\u003e")] + public void ForeverFrameTransportEscapesTags(string data, string expected) + { + var request = new Mock(); + var response = new CustomResponse(); + var context = new HostContext(request.Object, response); + var fft = new ForeverFrameTransport(context, new DefaultDependencyResolver()); + + AssertEscaped(fft, response, data, expected); + } + + [Theory] + [InlineData("", "\\u003cscript type=\"\"\\u003e\\u003c/script\\u003e")] + [InlineData("", "\\u003cscript type=''\\u003e\\u003c/script\\u003e")] + public void ForeverFrameTransportEscapesTagsWithPersistentResponse(string data, string expected) { var request = new Mock(); var response = new CustomResponse(); var context = new HostContext(request.Object, response); var fft = new ForeverFrameTransport(context, new DefaultDependencyResolver()); - AssertEscaped(fft, response, "", "\\u003c/sCRiPT\\u003e"); - AssertEscaped(fft, response, "", "\\u003c/SCRIPT dosomething='false'\\u003e"); - AssertEscaped(fft, response, "

ELLO

", "\\u003cp\\u003eELLO\\u003c/p\\u003e"); + AssertEscaped(fft, response, GetWrappedResponse(data), expected); } [Theory] @@ -42,7 +59,7 @@ public void ForeverFrameTransportThrowsOnInvalidFrameId(string frameId) var context = new HostContext(request.Object, response); var connection = new Mock(); var fft = new ForeverFrameTransport(context, new DefaultDependencyResolver()); - + Assert.Throws(typeof(InvalidOperationException), () => fft.InitializeResponse(connection.Object)); } @@ -62,7 +79,7 @@ public void ForeverFrameTransportSetsCorrectContentType() Assert.Equal("text/html; charset=UTF-8", response.ContentType); } - private static void AssertEscaped(ForeverFrameTransport fft, CustomResponse response, string input, string expectedOutput) + private static void AssertEscaped(ForeverFrameTransport fft, CustomResponse response, object input, string expectedOutput) { fft.Send(input).Wait(); @@ -73,6 +90,22 @@ private static void AssertEscaped(ForeverFrameTransport fft, CustomResponse resp Assert.True(rawResponse.Contains(expectedOutput)); } + private static PersistentResponse GetWrappedResponse(string raw) + { + var data = Encoding.Default.GetBytes(raw); + var message = new Message("foo", "key", new ArraySegment(data)); + + var response = new PersistentResponse + { + Messages = new List> + { + new ArraySegment(new Message[] { message }) + } + }; + + return response; + } + private class CustomResponse : IResponse { private MemoryStream _stream;