Skip to content

Commit e4804d5

Browse files
.Net: Updates to SessionsPythonPlugin (#11872)
### Motivation, Context and Description This PR provides control over which domains requests can be sent to. Additionally, it moves functionality that is common to all operations of the plugin to one private method to remove duplication. Contributes to: #10070
1 parent 46e5744 commit e4804d5

File tree

3 files changed

+86
-17
lines changed

3 files changed

+86
-17
lines changed

dotnet/src/Plugins/Plugins.Core/CodeInterpreter/SessionsPythonPlugin.cs

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Collections.Generic;
55
using System.ComponentModel;
66
using System.IO;
7+
using System.Linq;
78
using System.Net.Http;
89
using System.Text;
910
using System.Text.Json;
@@ -94,12 +95,9 @@ public async Task<string> ExecuteCodeAsync([Description("The valid Python code t
9495

9596
var requestBody = new SessionsPythonCodeExecutionProperties(this._settings, code);
9697

97-
using var request = new HttpRequestMessage(HttpMethod.Post, $"{this._poolManagementEndpoint}/executions?identifier={this._settings.SessionId}&api-version={ApiVersion}")
98-
{
99-
Content = new StringContent(JsonSerializer.Serialize(requestBody), Encoding.UTF8, "application/json")
100-
};
98+
using var content = new StringContent(JsonSerializer.Serialize(requestBody), Encoding.UTF8, "application/json");
10199

102-
using var response = await httpClient.SendWithSuccessCheckAsync(request, CancellationToken.None).ConfigureAwait(false);
100+
using var response = await this.SendAsync(httpClient, HttpMethod.Post, "executions", content).ConfigureAwait(false);
103101

104102
var responseContent = JsonSerializer.Deserialize<JsonElement>(await response.Content.ReadAsStringAsync().ConfigureAwait(false));
105103

@@ -139,15 +137,13 @@ public async Task<SessionsRemoteFileMetadata> UploadFileAsync(
139137
await this.AddHeadersAsync(httpClient).ConfigureAwait(false);
140138

141139
using var fileContent = new ByteArrayContent(File.ReadAllBytes(localFilePath));
142-
using var request = new HttpRequestMessage(HttpMethod.Post, $"{this._poolManagementEndpoint}files?identifier={this._settings.SessionId}&api-version={ApiVersion}")
140+
141+
using var multipartFormDataContent = new MultipartFormDataContent()
143142
{
144-
Content = new MultipartFormDataContent
145-
{
146-
{ fileContent, "file", remoteFileName },
147-
}
143+
{ fileContent, "file", remoteFileName },
148144
};
149145

150-
using var response = await httpClient.SendWithSuccessCheckAsync(request, CancellationToken.None).ConfigureAwait(false);
146+
using var response = await this.SendAsync(httpClient, HttpMethod.Post, "files", multipartFormDataContent).ConfigureAwait(false);
151147

152148
var stringContent = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
153149

@@ -172,9 +168,7 @@ public async Task<byte[]> DownloadFileAsync(
172168
using var httpClient = this._httpClientFactory.CreateClient();
173169
await this.AddHeadersAsync(httpClient).ConfigureAwait(false);
174170

175-
using var request = new HttpRequestMessage(HttpMethod.Get, $"{this._poolManagementEndpoint}/files/{Uri.EscapeDataString(remoteFileName)}/content?identifier={this._settings.SessionId}&api-version={ApiVersion}");
176-
177-
using var response = await httpClient.SendWithSuccessCheckAsync(request, CancellationToken.None).ConfigureAwait(false);
171+
using var response = await this.SendAsync(httpClient, HttpMethod.Get, $"files/{Uri.EscapeDataString(remoteFileName)}/content").ConfigureAwait(false);
178172

179173
var fileContent = await response.Content.ReadAsByteArrayAsync().ConfigureAwait(false);
180174

@@ -205,9 +199,7 @@ public async Task<IReadOnlyList<SessionsRemoteFileMetadata>> ListFilesAsync()
205199
using var httpClient = this._httpClientFactory.CreateClient();
206200
await this.AddHeadersAsync(httpClient).ConfigureAwait(false);
207201

208-
using var request = new HttpRequestMessage(HttpMethod.Get, $"{this._poolManagementEndpoint}/files?identifier={this._settings.SessionId}&api-version={ApiVersion}");
209-
210-
using var response = await httpClient.SendWithSuccessCheckAsync(request, CancellationToken.None).ConfigureAwait(false);
202+
using var response = await this.SendAsync(httpClient, HttpMethod.Get, "files").ConfigureAwait(false);
211203

212204
var jsonElementResult = JsonSerializer.Deserialize<JsonElement>(await response.Content.ReadAsStringAsync().ConfigureAwait(false));
213205

@@ -262,6 +254,36 @@ private async Task AddHeadersAsync(HttpClient httpClient)
262254
}
263255
}
264256

257+
/// <summary>
258+
/// Sends an HTTP request to the specified path with the specified method and content.
259+
/// </summary>
260+
/// <param name="httpClient">The HTTP client to use.</param>
261+
/// <param name="method">The HTTP method to use.</param>
262+
/// <param name="path">The path to send the request to.</param>
263+
/// <param name="httpContent">The content to send with the request.</param>
264+
/// <returns>The HTTP response message.</returns>
265+
private async Task<HttpResponseMessage> SendAsync(HttpClient httpClient, HttpMethod method, string path, HttpContent? httpContent = null)
266+
{
267+
// The query string is the same for all operations
268+
var pathWithQueryString = $"{path}?identifier={this._settings.SessionId}&api-version={ApiVersion}";
269+
270+
var uri = new Uri(this._poolManagementEndpoint, pathWithQueryString);
271+
272+
// If a list of allowed domains has been provided, the host of the provided
273+
// uri is checked to verify it is in the allowed domain list.
274+
if (!this._settings.AllowedDomains?.Contains(uri.Host) ?? false)
275+
{
276+
throw new InvalidOperationException("Sending requests to the provided location is not allowed.");
277+
}
278+
279+
using var request = new HttpRequestMessage(method, uri)
280+
{
281+
Content = httpContent,
282+
};
283+
284+
return await httpClient.SendWithSuccessCheckAsync(request, CancellationToken.None).ConfigureAwait(false);
285+
}
286+
265287
#if NET
266288
[GeneratedRegex(@"^(\s|`)*(?i:python)?\s*", RegexOptions.ExplicitCapture)]
267289
private static partial Regex RemoveLeadingWhitespaceBackticksPython();

dotnet/src/Plugins/Plugins.Core/CodeInterpreter/SessionsPythonSettings.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using System;
4+
using System.Collections.Generic;
45
using System.ComponentModel;
56
using System.Text.Json.Serialization;
67

@@ -23,6 +24,11 @@ public class SessionsPythonSettings
2324
[JsonIgnore]
2425
public Uri Endpoint { get; set; }
2526

27+
/// <summary>
28+
/// List of allowed domains to download from.
29+
/// </summary>
30+
public IEnumerable<string>? AllowedDomains { get; set; }
31+
2632
/// <summary>
2733
/// The session identifier.
2834
/// </summary>

dotnet/src/Plugins/Plugins.UnitTests/Core/SessionsPythonPluginTests.cs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,47 @@ public async Task ItShouldDownloadFileSavingInDiskAsync()
291291
Assert.Equal(responseContent, await File.ReadAllBytesAsync(downloadDiskPath));
292292
}
293293

294+
/// <summary>
295+
/// Test the allowed domains for the endpoint.
296+
/// </summary>
297+
/// <remarks>
298+
/// Considering that the functionality which verifies endpoints against the allowed domains is located in one private method,
299+
/// and the method is reused for all operations of the plugin, we test it only for one operation (ListFilesAsync).
300+
/// </remarks>
301+
[Theory]
302+
[InlineData("fake-test-host.io", "https://fake-test-host.io/subscriptions/123/rg/456/sps/test-pool", true)]
303+
[InlineData("prod.fake-test-host.io", "https://prod.fake-test-host.io/subscriptions/123/rg/456/sps/test-pool", true)]
304+
[InlineData("www.fake-test-host.io", "https://www.fake-test-host.io/subscriptions/123/rg/456/sps/test-pool", true)]
305+
[InlineData("www.prod.fake-test-host.io", "https://www.prod.fake-test-host.io/subscriptions/123/rg/456/sps/test-pool", true)]
306+
[InlineData("fake-test-host.io", "https://fake-test-host-1.io/subscriptions/123/rg/456/sps/test-pool", false)]
307+
[InlineData("fake-test-host.io", "https://www.fake-test-host.io/subscriptions/123/rg/456/sps/test-pool", false)]
308+
[InlineData("www.fake-test-host.io", "https://fake-test-host.io/subscriptions/123/rg/456/sps/test-pool", false)]
309+
public async Task ItShouldRespectAllowedDomainsAsync(string allowedDomain, string actualEndpoint, bool isAllowed)
310+
{
311+
// Arrange
312+
this._defaultSettings.AllowedDomains = [allowedDomain];
313+
this._defaultSettings.Endpoint = new Uri(actualEndpoint);
314+
315+
this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK)
316+
{
317+
Content = new StringContent(File.ReadAllText(ListFilesTestDataFilePath)),
318+
};
319+
320+
var sut = new SessionsPythonPlugin(this._defaultSettings, this._httpClientFactory);
321+
322+
// Act
323+
#pragma warning disable CA1031 // Do not catch general exception types
324+
try
325+
{
326+
await sut.ListFilesAsync();
327+
}
328+
catch when (!isAllowed)
329+
{
330+
// Ignore exception if the endpoint is not allowed since we expect it
331+
}
332+
#pragma warning restore CA1031 // Do not catch general exception types
333+
}
334+
294335
public void Dispose()
295336
{
296337
this._httpClient.Dispose();

0 commit comments

Comments
 (0)