diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs
index 8e8c099c661e..fd0217213db5 100644
--- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs
+++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/BasicHtmlWebResponseObject.Common.cs
@@ -8,6 +8,7 @@
using System.Net.Http;
using System.Text;
using System.Text.RegularExpressions;
+using System.Threading;
namespace Microsoft.PowerShell.Commands
{
@@ -32,19 +33,21 @@ public class BasicHtmlWebResponseObject : WebResponseObject
///
/// Initializes a new instance of the class.
///
- ///
- public BasicHtmlWebResponseObject(HttpResponseMessage response) : this(response, null) { }
+ /// The response.
+ /// Cancellation token.
+ public BasicHtmlWebResponseObject(HttpResponseMessage response, CancellationToken cancellationToken) : this(response, null, cancellationToken) { }
///
/// Initializes a new instance of the class
/// with the specified .
///
- ///
- ///
- public BasicHtmlWebResponseObject(HttpResponseMessage response, Stream contentStream) : base(response, contentStream)
+ /// The response.
+ /// The content stream associated with the response.
+ /// Cancellation token.
+ public BasicHtmlWebResponseObject(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken) : base(response, contentStream, cancellationToken)
{
EnsureHtmlParser();
- InitializeContent();
+ InitializeContent(cancellationToken);
InitializeRawContent(response);
}
@@ -159,7 +162,8 @@ public WebCmdletElementCollection Images
///
/// Reads the response content from the web response.
///
- protected void InitializeContent()
+ /// The cancellation token.
+ protected void InitializeContent(CancellationToken cancellationToken)
{
string contentType = ContentHelper.GetContentType(BaseResponse);
if (ContentHelper.IsText(contentType))
@@ -167,7 +171,7 @@ protected void InitializeContent()
// Fill the Content buffer
string characterSet = WebResponseHelper.GetCharacterSet(BaseResponse);
- Content = StreamHelper.DecodeStream(RawContentStream, characterSet, out Encoding encoding);
+ Content = StreamHelper.DecodeStream(RawContentStream, characterSet, out Encoding encoding, cancellationToken);
Encoding = encoding;
}
else
diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs
index 0c97211faea7..9efbce8a7fa8 100644
--- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs
+++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs
@@ -6,6 +6,7 @@
using System.Management.Automation;
using System.Net.Http;
using System.Text;
+using System.Threading;
using System.Xml;
using Newtonsoft.Json;
@@ -74,12 +75,13 @@ public int MaximumFollowRelLink
internal override void ProcessResponse(HttpResponseMessage response)
{
ArgumentNullException.ThrowIfNull(response);
+ ArgumentNullException.ThrowIfNull(_cancelToken);
- Stream baseResponseStream = StreamHelper.GetResponseStream(response);
+ Stream baseResponseStream = StreamHelper.GetResponseStream(response, _cancelToken.Token);
if (ShouldWriteToPipeline)
{
- using var responseStream = new BufferingStreamReader(baseResponseStream);
+ using var responseStream = new BufferingStreamReader(baseResponseStream, _cancelToken.Token);
// First see if it is an RSS / ATOM feed, in which case we can
// stream it - unless the user has overridden it with a return type of "XML"
@@ -95,7 +97,7 @@ internal override void ProcessResponse(HttpResponseMessage response)
// Try to get the response encoding from the ContentType header.
string charSet = WebResponseHelper.GetCharacterSet(response);
- string str = StreamHelper.DecodeStream(responseStream, charSet, out Encoding encoding);
+ string str = StreamHelper.DecodeStream(responseStream, charSet, out Encoding encoding, _cancelToken.Token);
object obj = null;
Exception ex = null;
@@ -112,7 +114,7 @@ internal override void ProcessResponse(HttpResponseMessage response)
// NOTE: Tests use this verbose output to verify the encoding.
WriteVerbose(string.Create(System.Globalization.CultureInfo.InvariantCulture, $"Content encoding: {encodingVerboseName}"));
-
+
bool convertSuccess = false;
if (returnType == RestReturnType.Json)
@@ -227,7 +229,7 @@ private bool TryProcessFeedStream(Stream responseStream)
}
}
}
- catch (XmlException)
+ catch (XmlException)
{
// Catch XmlException
}
@@ -345,17 +347,19 @@ public enum RestReturnType
internal class BufferingStreamReader : Stream
{
- internal BufferingStreamReader(Stream baseStream)
+ internal BufferingStreamReader(Stream baseStream, CancellationToken cancellationToken)
{
_baseStream = baseStream;
_streamBuffer = new MemoryStream();
_length = long.MaxValue;
_copyBuffer = new byte[4096];
+ _cancellationToken = cancellationToken;
}
private readonly Stream _baseStream;
private readonly MemoryStream _streamBuffer;
private readonly byte[] _copyBuffer;
+ private readonly CancellationToken _cancellationToken;
public override bool CanRead => true;
@@ -389,7 +393,7 @@ public override int Read(byte[] buffer, int offset, int count)
// If we don't have enough data to fill this from memory, cache more.
// We try to read 4096 bytes from base stream every time, so at most we
// may cache 4095 bytes more than what is required by the Read operation.
- int bytesRead = _baseStream.Read(_copyBuffer, 0, _copyBuffer.Length);
+ int bytesRead = _baseStream.ReadAsync(_copyBuffer, 0, _copyBuffer.Length, _cancellationToken).GetAwaiter().GetResult();
if (_streamBuffer.Position < _streamBuffer.Length)
{
diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs
index c1df22296e7b..a15209ff7224 100644
--- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs
+++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs
@@ -611,7 +611,7 @@ protected override void ProcessRecord()
string detailMsg = string.Empty;
try
{
- string error = StreamHelper.GetResponseString(response);
+ string error = StreamHelper.GetResponseString(response, _cancelToken.Token);
detailMsg = FormatErrorMessage(error, contentType);
}
catch
@@ -656,6 +656,11 @@ protected override void ProcessRecord()
ThrowTerminatingError(er);
}
+ finally
+ {
+ _cancelToken?.Dispose();
+ _cancelToken = null;
+ }
if (_followRelLink)
{
diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs
index 5a55ccc2e732..8bfdeed2da4b 100644
--- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs
+++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebResponseObject.Common.cs
@@ -7,6 +7,7 @@
using System.IO;
using System.Net.Http;
using System.Text;
+using System.Threading;
namespace Microsoft.PowerShell.Commands
{
@@ -74,19 +75,21 @@ public class WebResponseObject
///
/// Initializes a new instance of the class.
///
- ///
- public WebResponseObject(HttpResponseMessage response) : this(response, null)
+ /// The Http response.
+ /// The cancellation token.
+ public WebResponseObject(HttpResponseMessage response, CancellationToken cancellationToken) : this(response, null, cancellationToken)
{ }
///
/// Initializes a new instance of the class
/// with the specified .
///
- ///
- ///
- public WebResponseObject(HttpResponseMessage response, Stream contentStream)
+ /// Http response.
+ /// The http content stream.
+ /// The cancellation token.
+ public WebResponseObject(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken)
{
- SetResponse(response, contentStream);
+ SetResponse(response, contentStream, cancellationToken);
InitializeContent();
InitializeRawContent(response);
}
@@ -116,13 +119,13 @@ private void InitializeRawContent(HttpResponseMessage baseResponse)
RawContent = raw.ToString();
}
- private static bool IsPrintable(char c) => char.IsLetterOrDigit(c)
- || char.IsPunctuation(c)
- || char.IsSeparator(c)
- || char.IsSymbol(c)
+ private static bool IsPrintable(char c) => char.IsLetterOrDigit(c)
+ || char.IsPunctuation(c)
+ || char.IsSeparator(c)
+ || char.IsSymbol(c)
|| char.IsWhiteSpace(c);
- private void SetResponse(HttpResponseMessage response, Stream contentStream)
+ private void SetResponse(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(response);
@@ -138,7 +141,7 @@ private void SetResponse(HttpResponseMessage response, Stream contentStream)
Stream st = contentStream;
if (contentStream is null)
{
- st = StreamHelper.GetResponseStream(response);
+ st = StreamHelper.GetResponseStream(response, cancellationToken);
}
long contentLength = response.Content.Headers.ContentLength.Value;
@@ -148,7 +151,7 @@ private void SetResponse(HttpResponseMessage response, Stream contentStream)
}
int initialCapacity = (int)Math.Min(contentLength, StreamHelper.DefaultReadBuffer);
- RawContentStream = new WebResponseContentMemoryStream(st, initialCapacity, cmdlet: null, response.Content.Headers.ContentLength.GetValueOrDefault());
+ RawContentStream = new WebResponseContentMemoryStream(st, initialCapacity, cmdlet: null, response.Content.Headers.ContentLength.GetValueOrDefault(), cancellationToken);
}
// Set the position of the content stream to the beginning
diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs
index e101bcc65b7e..2ad31a9590f0 100644
--- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs
+++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs
@@ -34,7 +34,7 @@ internal override void ProcessResponse(HttpResponseMessage response)
{
ArgumentNullException.ThrowIfNull(response);
- Stream responseStream = StreamHelper.GetResponseStream(response);
+ Stream responseStream = StreamHelper.GetResponseStream(response, _cancelToken.Token);
if (ShouldWriteToPipeline)
{
// creating a MemoryStream wrapper to response stream here to support IsStopping.
@@ -42,8 +42,9 @@ internal override void ProcessResponse(HttpResponseMessage response)
responseStream,
StreamHelper.ChunkSize,
this,
- response.Content.Headers.ContentLength.GetValueOrDefault());
- WebResponseObject ro = WebResponseObjectFactory.GetResponseObject(response, responseStream, this.Context);
+ response.Content.Headers.ContentLength.GetValueOrDefault(),
+ _cancelToken.Token);
+ WebResponseObject ro = WebResponseObjectFactory.GetResponseObject(response, responseStream, this.Context, _cancelToken.Token);
ro.RelationLink = _relationLink;
WriteObject(ro);
diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/WebResponseObjectFactory.CoreClr.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/WebResponseObjectFactory.CoreClr.cs
index 4c8f1c403219..d5e3b6e6b3b5 100644
--- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/WebResponseObjectFactory.CoreClr.cs
+++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/WebResponseObjectFactory.CoreClr.cs
@@ -9,9 +9,9 @@ namespace Microsoft.PowerShell.Commands
{
internal static class WebResponseObjectFactory
{
- internal static WebResponseObject GetResponseObject(HttpResponseMessage response, Stream responseStream, ExecutionContext executionContext)
+ internal static WebResponseObject GetResponseObject(HttpResponseMessage response, Stream responseStream, ExecutionContext executionContext, System.Threading.CancellationToken cancellationToken)
{
- WebResponseObject output = WebResponseHelper.IsText(response) ? new BasicHtmlWebResponseObject(response, responseStream) : new WebResponseObject(response, responseStream);
+ WebResponseObject output = WebResponseHelper.IsText(response) ? new BasicHtmlWebResponseObject(response, responseStream, cancellationToken) : new WebResponseObject(response, responseStream, cancellationToken);
return output;
}
diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs
index 1ddaf5a57944..5b0387c44862 100644
--- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs
+++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs
@@ -25,8 +25,9 @@ internal class WebResponseContentMemoryStream : MemoryStream
private readonly long? _contentLength;
private readonly Stream _originalStreamToProxy;
- private bool _isInitialized = false;
private readonly Cmdlet _ownerCmdlet;
+ private readonly CancellationToken _cancellationToken;
+ private bool _isInitialized = false;
#endregion Data
@@ -34,15 +35,17 @@ internal class WebResponseContentMemoryStream : MemoryStream
///
/// Initializes a new instance of the class.
///
- ///
- ///
+ /// Response stream.
+ /// Presize the memory stream.
/// Owner cmdlet if any.
/// Expected download size in Bytes.
- internal WebResponseContentMemoryStream(Stream stream, int initialCapacity, Cmdlet cmdlet, long? contentLength) : base(initialCapacity)
+ /// Cancellation token.
+ internal WebResponseContentMemoryStream(Stream stream, int initialCapacity, Cmdlet cmdlet, long? contentLength, CancellationToken cancellationToken) : base(initialCapacity)
{
this._contentLength = contentLength;
_originalStreamToProxy = stream;
_ownerCmdlet = cmdlet;
+ _cancellationToken = cancellationToken;
}
#endregion Constructors
@@ -77,7 +80,7 @@ public override long Length
///
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
- Initialize();
+ Initialize(cancellationToken);
return base.CopyToAsync(destination, bufferSize, cancellationToken);
}
@@ -102,7 +105,7 @@ public override int Read(byte[] buffer, int offset, int count)
///
public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
- Initialize();
+ Initialize(cancellationToken);
return base.ReadAsync(buffer, offset, count, cancellationToken);
}
@@ -153,7 +156,7 @@ public override void Write(byte[] buffer, int offset, int count)
///
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
- Initialize();
+ Initialize(cancellationToken);
return base.WriteAsync(buffer, offset, count, cancellationToken);
}
@@ -182,15 +185,18 @@ protected override void Dispose(bool disposing)
base.Dispose(disposing);
}
- ///
- ///
- private void Initialize()
+ private void Initialize(CancellationToken cancellationToken = default)
{
- if (_isInitialized)
+ if (_isInitialized)
{
return;
}
+ if (cancellationToken == default)
+ {
+ cancellationToken = _cancellationToken;
+ }
+
_isInitialized = true;
try
{
@@ -220,7 +226,7 @@ private void Initialize()
}
}
- read = _originalStreamToProxy.Read(buffer, 0, buffer.Length);
+ read = _originalStreamToProxy.ReadAsync(buffer, 0, buffer.Length, cancellationToken).GetAwaiter().GetResult();
if (read > 0)
{
@@ -237,11 +243,11 @@ private void Initialize()
// Make sure the length is set appropriately
base.SetLength(totalRead);
- base.Seek(0, SeekOrigin.Begin);
+ Seek(0, SeekOrigin.Begin);
}
catch (Exception)
{
- base.Dispose();
+ Dispose();
throw;
}
}
@@ -329,7 +335,7 @@ internal static void SaveStreamToFile(Stream stream, string filePath, PSCmdlet c
WriteToStream(stream, output, cmdlet, contentLength, cancellationToken);
}
- private static string StreamToString(Stream stream, Encoding encoding)
+ private static string StreamToString(Stream stream, Encoding encoding, CancellationToken cancellationToken)
{
StringBuilder result = new(capacity: ChunkSize);
Decoder decoder = encoding.GetDecoder();
@@ -347,7 +353,7 @@ private static string StreamToString(Stream stream, Encoding encoding)
{
// Read at most the number of bytes that will fit in the input buffer. The
// return value is the actual number of bytes read, or zero if no bytes remain.
- bytesRead = stream.Read(bytes, 0, useBufferSize * 4);
+ bytesRead = stream.ReadAsync(bytes, 0, useBufferSize * 4, cancellationToken).GetAwaiter().GetResult();
bool completed = false;
int byteIndex = 0;
@@ -355,10 +361,8 @@ private static string StreamToString(Stream stream, Encoding encoding)
while (!completed)
{
// If this is the last input data, flush the decoder's internal buffer and state.
- bool flush = (bytesRead == 0);
- decoder.Convert(bytes, byteIndex, bytesRead - byteIndex,
- chars, 0, useBufferSize, flush,
- out int bytesUsed, out int charsUsed, out completed);
+ bool flush = bytesRead == 0;
+ decoder.Convert(bytes, byteIndex, bytesRead - byteIndex, chars, 0, useBufferSize, flush, out int bytesUsed, out int charsUsed, out completed);
// The conversion produced the number of characters indicated by charsUsed. Write that number
// of characters to our result buffer
@@ -376,12 +380,13 @@ private static string StreamToString(Stream stream, Encoding encoding)
break;
}
}
- } while (bytesRead != 0);
+ }
+ while (bytesRead != 0);
return result.ToString();
}
- internal static string DecodeStream(Stream stream, string characterSet, out Encoding encoding)
+ internal static string DecodeStream(Stream stream, string characterSet, out Encoding encoding, CancellationToken cancellationToken)
{
bool isDefaultEncoding = false;
if (!TryGetEncoding(characterSet, out encoding))
@@ -391,7 +396,7 @@ internal static string DecodeStream(Stream stream, string characterSet, out Enco
isDefaultEncoding = true;
}
- string content = StreamToString(stream, encoding);
+ string content = StreamToString(stream, encoding, cancellationToken);
if (isDefaultEncoding)
{
// We only look within the first 1k characters as the meta element and
@@ -400,13 +405,13 @@ internal static string DecodeStream(Stream stream, string characterSet, out Enco
// Check for a charset attribute on the meta element to override the default
Match match = s_metaRegex.Match(substring);
-
+
// Check for a encoding attribute on the xml declaration to override the default
if (!match.Success)
{
match = s_xmlRegex.Match(substring);
}
-
+
if (match.Success)
{
characterSet = match.Groups["charset"].Value;
@@ -414,7 +419,7 @@ internal static string DecodeStream(Stream stream, string characterSet, out Enco
if (TryGetEncoding(characterSet, out Encoding localEncoding))
{
stream.Seek(0, SeekOrigin.Begin);
- content = StreamToString(stream, localEncoding);
+ content = StreamToString(stream, localEncoding, cancellationToken);
encoding = localEncoding;
}
}
@@ -443,11 +448,11 @@ internal static bool TryGetEncoding(string characterSet, out Encoding encoding)
@"<]*charset\s*=\s*[""'\n]?(?[A-Za-z].[^\s""'\n<>]*)[\s""'\n>]",
RegexOptions.Compiled | RegexOptions.Singleline | RegexOptions.ExplicitCapture | RegexOptions.CultureInvariant | RegexOptions.IgnoreCase | RegexOptions.NonBacktracking
);
-
+
private static readonly Regex s_xmlRegex = new(
@"<\?xml\s.*[^.><]*encoding\s*=\s*[""'\n]?(?[A-Za-z].[^\s""'\n<>]*)[\s""'\n>]",
RegexOptions.Compiled | RegexOptions.Singleline | RegexOptions.ExplicitCapture | RegexOptions.CultureInvariant | RegexOptions.IgnoreCase | RegexOptions.NonBacktracking
- );
+ );
internal static byte[] EncodeToBytes(string str, Encoding encoding)
{
@@ -457,9 +462,9 @@ internal static byte[] EncodeToBytes(string str, Encoding encoding)
return encoding.GetBytes(str);
}
- internal static string GetResponseString(HttpResponseMessage response) => response.Content.ReadAsStringAsync().GetAwaiter().GetResult();
+ internal static string GetResponseString(HttpResponseMessage response, CancellationToken cancellationToken) => response.Content.ReadAsStringAsync(cancellationToken).GetAwaiter().GetResult();
- internal static Stream GetResponseStream(HttpResponseMessage response) => response.Content.ReadAsStreamAsync().GetAwaiter().GetResult();
+ internal static Stream GetResponseStream(HttpResponseMessage response, CancellationToken cancellationToken) => response.Content.ReadAsStreamAsync(cancellationToken).GetAwaiter().GetResult();
#endregion Static Methods
}