diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs index 8e8c099c661e..fd0217213db5 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs @@ -8,6 +8,7 @@ using System.Net.Http; using System.Text; using System.Text.RegularExpressions; +using System.Threading; namespace Microsoft.PowerShell.Commands { @@ -32,19 +33,21 @@ public class BasicHtmlWebResponseObject : WebResponseObject /// /// Initializes a new instance of the class. /// - /// - public BasicHtmlWebResponseObject(HttpResponseMessage response) : this(response, null) { } + /// The response. + /// Cancellation token. + public BasicHtmlWebResponseObject(HttpResponseMessage response, CancellationToken cancellationToken) : this(response, null, cancellationToken) { } /// /// Initializes a new instance of the class /// with the specified . /// - /// - /// - public BasicHtmlWebResponseObject(HttpResponseMessage response, Stream contentStream) : base(response, contentStream) + /// The response. + /// The content stream associated with the response. + /// Cancellation token. + public BasicHtmlWebResponseObject(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken) : base(response, contentStream, cancellationToken) { EnsureHtmlParser(); - InitializeContent(); + InitializeContent(cancellationToken); InitializeRawContent(response); } @@ -159,7 +162,8 @@ public WebCmdletElementCollection Images /// /// Reads the response content from the web response. /// - protected void InitializeContent() + /// The cancellation token. + protected void InitializeContent(CancellationToken cancellationToken) { string contentType = ContentHelper.GetContentType(BaseResponse); if (ContentHelper.IsText(contentType)) @@ -167,7 +171,7 @@ protected void InitializeContent() // Fill the Content buffer string characterSet = WebResponseHelper.GetCharacterSet(BaseResponse); - Content = StreamHelper.DecodeStream(RawContentStream, characterSet, out Encoding encoding); + Content = StreamHelper.DecodeStream(RawContentStream, characterSet, out Encoding encoding, cancellationToken); Encoding = encoding; } else diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs index 0c97211faea7..9efbce8a7fa8 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs @@ -6,6 +6,7 @@ using System.Management.Automation; using System.Net.Http; using System.Text; +using System.Threading; using System.Xml; using Newtonsoft.Json; @@ -74,12 +75,13 @@ public int MaximumFollowRelLink internal override void ProcessResponse(HttpResponseMessage response) { ArgumentNullException.ThrowIfNull(response); + ArgumentNullException.ThrowIfNull(_cancelToken); - Stream baseResponseStream = StreamHelper.GetResponseStream(response); + Stream baseResponseStream = StreamHelper.GetResponseStream(response, _cancelToken.Token); if (ShouldWriteToPipeline) { - using var responseStream = new BufferingStreamReader(baseResponseStream); + using var responseStream = new BufferingStreamReader(baseResponseStream, _cancelToken.Token); // First see if it is an RSS / ATOM feed, in which case we can // stream it - unless the user has overridden it with a return type of "XML" @@ -95,7 +97,7 @@ internal override void ProcessResponse(HttpResponseMessage response) // Try to get the response encoding from the ContentType header. string charSet = WebResponseHelper.GetCharacterSet(response); - string str = StreamHelper.DecodeStream(responseStream, charSet, out Encoding encoding); + string str = StreamHelper.DecodeStream(responseStream, charSet, out Encoding encoding, _cancelToken.Token); object obj = null; Exception ex = null; @@ -112,7 +114,7 @@ internal override void ProcessResponse(HttpResponseMessage response) // NOTE: Tests use this verbose output to verify the encoding. WriteVerbose(string.Create(System.Globalization.CultureInfo.InvariantCulture, $"Content encoding: {encodingVerboseName}")); - + bool convertSuccess = false; if (returnType == RestReturnType.Json) @@ -227,7 +229,7 @@ private bool TryProcessFeedStream(Stream responseStream) } } } - catch (XmlException) + catch (XmlException) { // Catch XmlException } @@ -345,17 +347,19 @@ public enum RestReturnType internal class BufferingStreamReader : Stream { - internal BufferingStreamReader(Stream baseStream) + internal BufferingStreamReader(Stream baseStream, CancellationToken cancellationToken) { _baseStream = baseStream; _streamBuffer = new MemoryStream(); _length = long.MaxValue; _copyBuffer = new byte[4096]; + _cancellationToken = cancellationToken; } private readonly Stream _baseStream; private readonly MemoryStream _streamBuffer; private readonly byte[] _copyBuffer; + private readonly CancellationToken _cancellationToken; public override bool CanRead => true; @@ -389,7 +393,7 @@ public override int Read(byte[] buffer, int offset, int count) // If we don't have enough data to fill this from memory, cache more. // We try to read 4096 bytes from base stream every time, so at most we // may cache 4095 bytes more than what is required by the Read operation. - int bytesRead = _baseStream.Read(_copyBuffer, 0, _copyBuffer.Length); + int bytesRead = _baseStream.ReadAsync(_copyBuffer, 0, _copyBuffer.Length, _cancellationToken).GetAwaiter().GetResult(); if (_streamBuffer.Position < _streamBuffer.Length) { diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs index c1df22296e7b..a15209ff7224 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs @@ -611,7 +611,7 @@ protected override void ProcessRecord() string detailMsg = string.Empty; try { - string error = StreamHelper.GetResponseString(response); + string error = StreamHelper.GetResponseString(response, _cancelToken.Token); detailMsg = FormatErrorMessage(error, contentType); } catch @@ -656,6 +656,11 @@ protected override void ProcessRecord() ThrowTerminatingError(er); } + finally + { + _cancelToken?.Dispose(); + _cancelToken = null; + } if (_followRelLink) { diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs index 5a55ccc2e732..8bfdeed2da4b 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs @@ -7,6 +7,7 @@ using System.IO; using System.Net.Http; using System.Text; +using System.Threading; namespace Microsoft.PowerShell.Commands { @@ -74,19 +75,21 @@ public class WebResponseObject /// /// Initializes a new instance of the class. /// - /// - public WebResponseObject(HttpResponseMessage response) : this(response, null) + /// The Http response. + /// The cancellation token. + public WebResponseObject(HttpResponseMessage response, CancellationToken cancellationToken) : this(response, null, cancellationToken) { } /// /// Initializes a new instance of the class /// with the specified . /// - /// - /// - public WebResponseObject(HttpResponseMessage response, Stream contentStream) + /// Http response. + /// The http content stream. + /// The cancellation token. + public WebResponseObject(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken) { - SetResponse(response, contentStream); + SetResponse(response, contentStream, cancellationToken); InitializeContent(); InitializeRawContent(response); } @@ -116,13 +119,13 @@ private void InitializeRawContent(HttpResponseMessage baseResponse) RawContent = raw.ToString(); } - private static bool IsPrintable(char c) => char.IsLetterOrDigit(c) - || char.IsPunctuation(c) - || char.IsSeparator(c) - || char.IsSymbol(c) + private static bool IsPrintable(char c) => char.IsLetterOrDigit(c) + || char.IsPunctuation(c) + || char.IsSeparator(c) + || char.IsSymbol(c) || char.IsWhiteSpace(c); - private void SetResponse(HttpResponseMessage response, Stream contentStream) + private void SetResponse(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken) { ArgumentNullException.ThrowIfNull(response); @@ -138,7 +141,7 @@ private void SetResponse(HttpResponseMessage response, Stream contentStream) Stream st = contentStream; if (contentStream is null) { - st = StreamHelper.GetResponseStream(response); + st = StreamHelper.GetResponseStream(response, cancellationToken); } long contentLength = response.Content.Headers.ContentLength.Value; @@ -148,7 +151,7 @@ private void SetResponse(HttpResponseMessage response, Stream contentStream) } int initialCapacity = (int)Math.Min(contentLength, StreamHelper.DefaultReadBuffer); - RawContentStream = new WebResponseContentMemoryStream(st, initialCapacity, cmdlet: null, response.Content.Headers.ContentLength.GetValueOrDefault()); + RawContentStream = new WebResponseContentMemoryStream(st, initialCapacity, cmdlet: null, response.Content.Headers.ContentLength.GetValueOrDefault(), cancellationToken); } // Set the position of the content stream to the beginning diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs index e101bcc65b7e..2ad31a9590f0 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs @@ -34,7 +34,7 @@ internal override void ProcessResponse(HttpResponseMessage response) { ArgumentNullException.ThrowIfNull(response); - Stream responseStream = StreamHelper.GetResponseStream(response); + Stream responseStream = StreamHelper.GetResponseStream(response, _cancelToken.Token); if (ShouldWriteToPipeline) { // creating a MemoryStream wrapper to response stream here to support IsStopping. @@ -42,8 +42,9 @@ internal override void ProcessResponse(HttpResponseMessage response) responseStream, StreamHelper.ChunkSize, this, - response.Content.Headers.ContentLength.GetValueOrDefault()); - WebResponseObject ro = WebResponseObjectFactory.GetResponseObject(response, responseStream, this.Context); + response.Content.Headers.ContentLength.GetValueOrDefault(), + _cancelToken.Token); + WebResponseObject ro = WebResponseObjectFactory.GetResponseObject(response, responseStream, this.Context, _cancelToken.Token); ro.RelationLink = _relationLink; WriteObject(ro); diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/WebResponseObjectFactory.CoreClr.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/WebResponseObjectFactory.CoreClr.cs index 4c8f1c403219..d5e3b6e6b3b5 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/WebResponseObjectFactory.CoreClr.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/WebResponseObjectFactory.CoreClr.cs @@ -9,9 +9,9 @@ namespace Microsoft.PowerShell.Commands { internal static class WebResponseObjectFactory { - internal static WebResponseObject GetResponseObject(HttpResponseMessage response, Stream responseStream, ExecutionContext executionContext) + internal static WebResponseObject GetResponseObject(HttpResponseMessage response, Stream responseStream, ExecutionContext executionContext, System.Threading.CancellationToken cancellationToken) { - WebResponseObject output = WebResponseHelper.IsText(response) ? new BasicHtmlWebResponseObject(response, responseStream) : new WebResponseObject(response, responseStream); + WebResponseObject output = WebResponseHelper.IsText(response) ? new BasicHtmlWebResponseObject(response, responseStream, cancellationToken) : new WebResponseObject(response, responseStream, cancellationToken); return output; } diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs index 1ddaf5a57944..5b0387c44862 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs @@ -25,8 +25,9 @@ internal class WebResponseContentMemoryStream : MemoryStream private readonly long? _contentLength; private readonly Stream _originalStreamToProxy; - private bool _isInitialized = false; private readonly Cmdlet _ownerCmdlet; + private readonly CancellationToken _cancellationToken; + private bool _isInitialized = false; #endregion Data @@ -34,15 +35,17 @@ internal class WebResponseContentMemoryStream : MemoryStream /// /// Initializes a new instance of the class. /// - /// - /// + /// Response stream. + /// Presize the memory stream. /// Owner cmdlet if any. /// Expected download size in Bytes. - internal WebResponseContentMemoryStream(Stream stream, int initialCapacity, Cmdlet cmdlet, long? contentLength) : base(initialCapacity) + /// Cancellation token. + internal WebResponseContentMemoryStream(Stream stream, int initialCapacity, Cmdlet cmdlet, long? contentLength, CancellationToken cancellationToken) : base(initialCapacity) { this._contentLength = contentLength; _originalStreamToProxy = stream; _ownerCmdlet = cmdlet; + _cancellationToken = cancellationToken; } #endregion Constructors @@ -77,7 +80,7 @@ public override long Length /// public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) { - Initialize(); + Initialize(cancellationToken); return base.CopyToAsync(destination, bufferSize, cancellationToken); } @@ -102,7 +105,7 @@ public override int Read(byte[] buffer, int offset, int count) /// public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - Initialize(); + Initialize(cancellationToken); return base.ReadAsync(buffer, offset, count, cancellationToken); } @@ -153,7 +156,7 @@ public override void Write(byte[] buffer, int offset, int count) /// public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - Initialize(); + Initialize(cancellationToken); return base.WriteAsync(buffer, offset, count, cancellationToken); } @@ -182,15 +185,18 @@ protected override void Dispose(bool disposing) base.Dispose(disposing); } - /// - /// - private void Initialize() + private void Initialize(CancellationToken cancellationToken = default) { - if (_isInitialized) + if (_isInitialized) { return; } + if (cancellationToken == default) + { + cancellationToken = _cancellationToken; + } + _isInitialized = true; try { @@ -220,7 +226,7 @@ private void Initialize() } } - read = _originalStreamToProxy.Read(buffer, 0, buffer.Length); + read = _originalStreamToProxy.ReadAsync(buffer, 0, buffer.Length, cancellationToken).GetAwaiter().GetResult(); if (read > 0) { @@ -237,11 +243,11 @@ private void Initialize() // Make sure the length is set appropriately base.SetLength(totalRead); - base.Seek(0, SeekOrigin.Begin); + Seek(0, SeekOrigin.Begin); } catch (Exception) { - base.Dispose(); + Dispose(); throw; } } @@ -329,7 +335,7 @@ internal static void SaveStreamToFile(Stream stream, string filePath, PSCmdlet c WriteToStream(stream, output, cmdlet, contentLength, cancellationToken); } - private static string StreamToString(Stream stream, Encoding encoding) + private static string StreamToString(Stream stream, Encoding encoding, CancellationToken cancellationToken) { StringBuilder result = new(capacity: ChunkSize); Decoder decoder = encoding.GetDecoder(); @@ -347,7 +353,7 @@ private static string StreamToString(Stream stream, Encoding encoding) { // Read at most the number of bytes that will fit in the input buffer. The // return value is the actual number of bytes read, or zero if no bytes remain. - bytesRead = stream.Read(bytes, 0, useBufferSize * 4); + bytesRead = stream.ReadAsync(bytes, 0, useBufferSize * 4, cancellationToken).GetAwaiter().GetResult(); bool completed = false; int byteIndex = 0; @@ -355,10 +361,8 @@ private static string StreamToString(Stream stream, Encoding encoding) while (!completed) { // If this is the last input data, flush the decoder's internal buffer and state. - bool flush = (bytesRead == 0); - decoder.Convert(bytes, byteIndex, bytesRead - byteIndex, - chars, 0, useBufferSize, flush, - out int bytesUsed, out int charsUsed, out completed); + bool flush = bytesRead == 0; + decoder.Convert(bytes, byteIndex, bytesRead - byteIndex, chars, 0, useBufferSize, flush, out int bytesUsed, out int charsUsed, out completed); // The conversion produced the number of characters indicated by charsUsed. Write that number // of characters to our result buffer @@ -376,12 +380,13 @@ private static string StreamToString(Stream stream, Encoding encoding) break; } } - } while (bytesRead != 0); + } + while (bytesRead != 0); return result.ToString(); } - internal static string DecodeStream(Stream stream, string characterSet, out Encoding encoding) + internal static string DecodeStream(Stream stream, string characterSet, out Encoding encoding, CancellationToken cancellationToken) { bool isDefaultEncoding = false; if (!TryGetEncoding(characterSet, out encoding)) @@ -391,7 +396,7 @@ internal static string DecodeStream(Stream stream, string characterSet, out Enco isDefaultEncoding = true; } - string content = StreamToString(stream, encoding); + string content = StreamToString(stream, encoding, cancellationToken); if (isDefaultEncoding) { // We only look within the first 1k characters as the meta element and @@ -400,13 +405,13 @@ internal static string DecodeStream(Stream stream, string characterSet, out Enco // Check for a charset attribute on the meta element to override the default Match match = s_metaRegex.Match(substring); - + // Check for a encoding attribute on the xml declaration to override the default if (!match.Success) { match = s_xmlRegex.Match(substring); } - + if (match.Success) { characterSet = match.Groups["charset"].Value; @@ -414,7 +419,7 @@ internal static string DecodeStream(Stream stream, string characterSet, out Enco if (TryGetEncoding(characterSet, out Encoding localEncoding)) { stream.Seek(0, SeekOrigin.Begin); - content = StreamToString(stream, localEncoding); + content = StreamToString(stream, localEncoding, cancellationToken); encoding = localEncoding; } } @@ -443,11 +448,11 @@ internal static bool TryGetEncoding(string characterSet, out Encoding encoding) @"<]*charset\s*=\s*[""'\n]?(?[A-Za-z].[^\s""'\n<>]*)[\s""'\n>]", RegexOptions.Compiled | RegexOptions.Singleline | RegexOptions.ExplicitCapture | RegexOptions.CultureInvariant | RegexOptions.IgnoreCase | RegexOptions.NonBacktracking ); - + private static readonly Regex s_xmlRegex = new( @"<\?xml\s.*[^.><]*encoding\s*=\s*[""'\n]?(?[A-Za-z].[^\s""'\n<>]*)[\s""'\n>]", RegexOptions.Compiled | RegexOptions.Singleline | RegexOptions.ExplicitCapture | RegexOptions.CultureInvariant | RegexOptions.IgnoreCase | RegexOptions.NonBacktracking - ); + ); internal static byte[] EncodeToBytes(string str, Encoding encoding) { @@ -457,9 +462,9 @@ internal static byte[] EncodeToBytes(string str, Encoding encoding) return encoding.GetBytes(str); } - internal static string GetResponseString(HttpResponseMessage response) => response.Content.ReadAsStringAsync().GetAwaiter().GetResult(); + internal static string GetResponseString(HttpResponseMessage response, CancellationToken cancellationToken) => response.Content.ReadAsStringAsync(cancellationToken).GetAwaiter().GetResult(); - internal static Stream GetResponseStream(HttpResponseMessage response) => response.Content.ReadAsStreamAsync().GetAwaiter().GetResult(); + internal static Stream GetResponseStream(HttpResponseMessage response, CancellationToken cancellationToken) => response.Content.ReadAsStreamAsync(cancellationToken).GetAwaiter().GetResult(); #endregion Static Methods }