Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invoke-RestMethod/WebRequest: Support CTRL-C when reading data using cancellation token #19315

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Net.Http;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;

namespace Microsoft.PowerShell.Commands
{
Expand All @@ -32,19 +33,21 @@ public class BasicHtmlWebResponseObject : WebResponseObject
/// <summary>
/// Initializes a new instance of the <see cref="BasicHtmlWebResponseObject"/> class.
/// </summary>
/// <param name="response"></param>
public BasicHtmlWebResponseObject(HttpResponseMessage response) : this(response, null) { }
/// <param name="response">The response.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public BasicHtmlWebResponseObject(HttpResponseMessage response, CancellationToken cancellationToken) : this(response, null, cancellationToken) { }

/// <summary>
/// Initializes a new instance of the <see cref="BasicHtmlWebResponseObject"/> class
/// with the specified <paramref name="contentStream"/>.
/// </summary>
/// <param name="response"></param>
/// <param name="contentStream"></param>
public BasicHtmlWebResponseObject(HttpResponseMessage response, Stream contentStream) : base(response, contentStream)
/// <param name="response">The response.</param>
/// <param name="contentStream">The content stream associated with the response.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public BasicHtmlWebResponseObject(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken) : base(response, contentStream, cancellationToken)
{
EnsureHtmlParser();
InitializeContent();
InitializeContent(cancellationToken);
InitializeRawContent(response);
}

Expand Down Expand Up @@ -159,15 +162,16 @@ public WebCmdletElementCollection Images
/// <summary>
/// Reads the response content from the web response.
/// </summary>
protected void InitializeContent()
/// <param name="cancellationToken">The cancellation token.</param>
protected void InitializeContent(CancellationToken cancellationToken)
{
string contentType = ContentHelper.GetContentType(BaseResponse);
if (ContentHelper.IsText(contentType))
{
// Fill the Content buffer
string characterSet = WebResponseHelper.GetCharacterSet(BaseResponse);

Content = StreamHelper.DecodeStream(RawContentStream, characterSet, out Encoding encoding);
Content = StreamHelper.DecodeStream(RawContentStream, characterSet, out Encoding encoding, cancellationToken);
Encoding = encoding;
}
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Management.Automation;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Xml;

using Newtonsoft.Json;
Expand Down Expand Up @@ -74,12 +75,13 @@ public int MaximumFollowRelLink
internal override void ProcessResponse(HttpResponseMessage response)
{
ArgumentNullException.ThrowIfNull(response);
ArgumentNullException.ThrowIfNull(_cancelToken);

Stream baseResponseStream = StreamHelper.GetResponseStream(response);
Stream baseResponseStream = StreamHelper.GetResponseStream(response, _cancelToken.Token);

if (ShouldWriteToPipeline)
{
using var responseStream = new BufferingStreamReader(baseResponseStream);
using var responseStream = new BufferingStreamReader(baseResponseStream, _cancelToken.Token);

// First see if it is an RSS / ATOM feed, in which case we can
// stream it - unless the user has overridden it with a return type of "XML"
Expand All @@ -95,7 +97,7 @@ internal override void ProcessResponse(HttpResponseMessage response)
// Try to get the response encoding from the ContentType header.
string charSet = WebResponseHelper.GetCharacterSet(response);

string str = StreamHelper.DecodeStream(responseStream, charSet, out Encoding encoding);
string str = StreamHelper.DecodeStream(responseStream, charSet, out Encoding encoding, _cancelToken.Token);

object obj = null;
Exception ex = null;
Expand All @@ -112,7 +114,7 @@ internal override void ProcessResponse(HttpResponseMessage response)

// NOTE: Tests use this verbose output to verify the encoding.
WriteVerbose(string.Create(System.Globalization.CultureInfo.InvariantCulture, $"Content encoding: {encodingVerboseName}"));

bool convertSuccess = false;

if (returnType == RestReturnType.Json)
Expand Down Expand Up @@ -227,7 +229,7 @@ private bool TryProcessFeedStream(Stream responseStream)
}
}
}
catch (XmlException)
catch (XmlException)
{
// Catch XmlException
}
Expand Down Expand Up @@ -345,17 +347,19 @@ public enum RestReturnType

internal class BufferingStreamReader : Stream
{
internal BufferingStreamReader(Stream baseStream)
internal BufferingStreamReader(Stream baseStream, CancellationToken cancellationToken)
{
_baseStream = baseStream;
_streamBuffer = new MemoryStream();
_length = long.MaxValue;
_copyBuffer = new byte[4096];
_cancellationToken = cancellationToken;
}

private readonly Stream _baseStream;
private readonly MemoryStream _streamBuffer;
private readonly byte[] _copyBuffer;
private readonly CancellationToken _cancellationToken;

public override bool CanRead => true;

Expand Down Expand Up @@ -389,7 +393,7 @@ public override int Read(byte[] buffer, int offset, int count)
// If we don't have enough data to fill this from memory, cache more.
// We try to read 4096 bytes from base stream every time, so at most we
// may cache 4095 bytes more than what is required by the Read operation.
int bytesRead = _baseStream.Read(_copyBuffer, 0, _copyBuffer.Length);
int bytesRead = _baseStream.ReadAsync(_copyBuffer, 0, _copyBuffer.Length, _cancellationToken).GetAwaiter().GetResult();

if (_streamBuffer.Position < _streamBuffer.Length)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ protected override void ProcessRecord()
string detailMsg = string.Empty;
try
{
string error = StreamHelper.GetResponseString(response);
string error = StreamHelper.GetResponseString(response, _cancelToken.Token);
detailMsg = FormatErrorMessage(error, contentType);
}
catch
Expand Down Expand Up @@ -656,6 +656,11 @@ protected override void ProcessRecord()

ThrowTerminatingError(er);
}
finally
{
_cancelToken?.Dispose();
_cancelToken = null;
}

if (_followRelLink)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.IO;
using System.Net.Http;
using System.Text;
using System.Threading;

namespace Microsoft.PowerShell.Commands
{
Expand Down Expand Up @@ -74,19 +75,21 @@ public class WebResponseObject
/// <summary>
/// Initializes a new instance of the <see cref="WebResponseObject"/> class.
/// </summary>
/// <param name="response"></param>
public WebResponseObject(HttpResponseMessage response) : this(response, null)
/// <param name="response">The Http response.</param>
/// <param name="cancellationToken">The cancellation token.</param>
public WebResponseObject(HttpResponseMessage response, CancellationToken cancellationToken) : this(response, null, cancellationToken)
{ }

/// <summary>
/// Initializes a new instance of the <see cref="WebResponseObject"/> class
/// with the specified <paramref name="contentStream"/>.
/// </summary>
/// <param name="response"></param>
/// <param name="contentStream"></param>
public WebResponseObject(HttpResponseMessage response, Stream contentStream)
/// <param name="response">Http response.</param>
/// <param name="contentStream">The http content stream.</param>
/// <param name="cancellationToken">The cancellation token.</param>
public WebResponseObject(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken)
{
SetResponse(response, contentStream);
SetResponse(response, contentStream, cancellationToken);
InitializeContent();
InitializeRawContent(response);
}
Expand Down Expand Up @@ -116,13 +119,13 @@ private void InitializeRawContent(HttpResponseMessage baseResponse)
RawContent = raw.ToString();
}

private static bool IsPrintable(char c) => char.IsLetterOrDigit(c)
|| char.IsPunctuation(c)
|| char.IsSeparator(c)
|| char.IsSymbol(c)
private static bool IsPrintable(char c) => char.IsLetterOrDigit(c)
|| char.IsPunctuation(c)
|| char.IsSeparator(c)
|| char.IsSymbol(c)
|| char.IsWhiteSpace(c);

private void SetResponse(HttpResponseMessage response, Stream contentStream)
private void SetResponse(HttpResponseMessage response, Stream contentStream, CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(response);

Expand All @@ -138,7 +141,7 @@ private void SetResponse(HttpResponseMessage response, Stream contentStream)
Stream st = contentStream;
if (contentStream is null)
{
st = StreamHelper.GetResponseStream(response);
st = StreamHelper.GetResponseStream(response, cancellationToken);
}

long contentLength = response.Content.Headers.ContentLength.Value;
Expand All @@ -148,7 +151,7 @@ private void SetResponse(HttpResponseMessage response, Stream contentStream)
}

int initialCapacity = (int)Math.Min(contentLength, StreamHelper.DefaultReadBuffer);
RawContentStream = new WebResponseContentMemoryStream(st, initialCapacity, cmdlet: null, response.Content.Headers.ContentLength.GetValueOrDefault());
RawContentStream = new WebResponseContentMemoryStream(st, initialCapacity, cmdlet: null, response.Content.Headers.ContentLength.GetValueOrDefault(), cancellationToken);
}

// Set the position of the content stream to the beginning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,17 @@ internal override void ProcessResponse(HttpResponseMessage response)
{
ArgumentNullException.ThrowIfNull(response);

Stream responseStream = StreamHelper.GetResponseStream(response);
Stream responseStream = StreamHelper.GetResponseStream(response, _cancelToken.Token);
if (ShouldWriteToPipeline)
{
// creating a MemoryStream wrapper to response stream here to support IsStopping.
responseStream = new WebResponseContentMemoryStream(
responseStream,
StreamHelper.ChunkSize,
this,
response.Content.Headers.ContentLength.GetValueOrDefault());
WebResponseObject ro = WebResponseObjectFactory.GetResponseObject(response, responseStream, this.Context);
response.Content.Headers.ContentLength.GetValueOrDefault(),
_cancelToken.Token);
WebResponseObject ro = WebResponseObjectFactory.GetResponseObject(response, responseStream, this.Context, _cancelToken.Token);
ro.RelationLink = _relationLink;
WriteObject(ro);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ namespace Microsoft.PowerShell.Commands
{
internal static class WebResponseObjectFactory
{
internal static WebResponseObject GetResponseObject(HttpResponseMessage response, Stream responseStream, ExecutionContext executionContext)
internal static WebResponseObject GetResponseObject(HttpResponseMessage response, Stream responseStream, ExecutionContext executionContext, System.Threading.CancellationToken cancellationToken)
{
WebResponseObject output = WebResponseHelper.IsText(response) ? new BasicHtmlWebResponseObject(response, responseStream) : new WebResponseObject(response, responseStream);
WebResponseObject output = WebResponseHelper.IsText(response) ? new BasicHtmlWebResponseObject(response, responseStream, cancellationToken) : new WebResponseObject(response, responseStream, cancellationToken);

return output;
}
Expand Down