Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,7 @@ internal partial class DownloadDirectoryCommand : BaseCommand<TransferUtilityDow
long _transferredBytes;
string _currentFile;

internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDirectoryRequest request)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removing unused constructors

: this(s3Client, request, useMultipartDownload: false)
{
}

internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDirectoryRequest request, bool useMultipartDownload)
internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDirectoryRequest request, TransferUtilityConfig config, bool useMultipartDownload)
{
if (s3Client == null)
throw new ArgumentNullException(nameof(s3Client));
Expand All @@ -62,6 +57,7 @@ internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDir

this._s3Client = s3Client;
this._request = request;
this._config = config;
this._skipEncryptionInstructionFiles = s3Client is Amazon.S3.Internal.IAmazonS3Encryption;
_failurePolicy =
request.FailurePolicy == FailurePolicy.AbortOnFailure
Expand All @@ -70,12 +66,6 @@ internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDir
this._useMultipartDownload = useMultipartDownload;
}

internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDirectoryRequest request, TransferUtilityConfig config, bool useMultipartDownload)
: this(s3Client, request, useMultipartDownload)
{
this._config = config;
}

private void downloadedProgressEventCallback(object sender, WriteObjectProgressArgs e)
{
var transferredBytes = Interlocked.Add(ref _transferredBytes, e.IncrementTransferred);
Expand Down
104 changes: 103 additions & 1 deletion sdk/src/Services/S3/Custom/Transfer/Internal/TaskHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
* permissions and limitations under the License.
*/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Amazon.Runtime.Internal.Util;

namespace Amazon.S3.Transfer.Internal
{
Expand All @@ -24,6 +27,11 @@ namespace Amazon.S3.Transfer.Internal
/// </summary>
internal static class TaskHelpers
{
private static Logger Logger
{
get { return Logger.GetLogger(typeof(TaskHelpers)); }
}
Comment on lines +30 to +33
Copy link

Copilot AI Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recursive property definition: the Logger property calls Logger.GetLogger() which references itself. This should use a static field or a different pattern.

Copilot uses AI. Check for mistakes.

/// <summary>
/// Waits for all tasks to complete or till any task fails or is canceled.
/// </summary>
Expand All @@ -33,7 +41,10 @@ internal static class TaskHelpers
internal static async Task WhenAllOrFirstExceptionAsync(List<Task> pendingTasks, CancellationToken cancellationToken)
{
int processed = 0;
int total = pendingTasks.Count;
int total = pendingTasks.Count;

Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync: Starting with TotalTasks={0}", total);

while (processed < total)
{
cancellationToken.ThrowIfCancellationRequested();
Expand All @@ -48,7 +59,12 @@ await completedTask

pendingTasks.Remove(completedTask);
processed++;

Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync: Task completed (Processed={0}/{1}, Remaining={2})",
processed, total, pendingTasks.Count);
}

Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync: All tasks completed (Total={0})", total);
}

/// <summary>
Expand All @@ -64,6 +80,9 @@ internal static async Task<List<T>> WhenAllOrFirstExceptionAsync<T>(List<Task<T>
int processed = 0;
int total = pendingTasks.Count;
var responses = new List<T>();

Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync<T>: Starting with TotalTasks={0}", total);

while (processed < total)
{
cancellationToken.ThrowIfCancellationRequested();
Expand All @@ -79,9 +98,92 @@ internal static async Task<List<T>> WhenAllOrFirstExceptionAsync<T>(List<Task<T>

pendingTasks.Remove(completedTask);
processed++;

Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync<T>: Task completed (Processed={0}/{1}, Remaining={2})",
processed, total, pendingTasks.Count);
}

Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync<T>: All tasks completed (Total={0})", total);

return responses;
}

/// <summary>
/// Executes work items with limited concurrency using a task pool pattern.
/// Creates only as many tasks as the concurrency limit allows, rather than creating
/// all tasks upfront. This reduces memory overhead for large collections.
/// </summary>
/// <remarks>
/// This method provides a clean way to limit concurrent operations without creating
/// all tasks upfront. It maintains a pool of active tasks up to the maxConcurrency limit,
/// replacing completed tasks with new ones until all items are processed.
/// The caller is responsible for implementing failure handling within the processAsync function.
/// </remarks>
/// <typeparam name="T">The type of items to process</typeparam>
/// <param name="items">The collection of items to process</param>
/// <param name="maxConcurrency">Maximum number of concurrent tasks</param>
/// <param name="processAsync">Async function to process each item</param>
/// <param name="cancellationToken">Cancellation token to observe</param>
/// <returns>A task that completes when all items are processed, or throws on first failure</returns>
internal static async Task ForEachWithConcurrencyAsync<T>(
IEnumerable<T> items,
int maxConcurrency,
Func<T, CancellationToken, Task> processAsync,
CancellationToken cancellationToken)
{
var itemList = items as IList<T> ?? items.ToList();
if (itemList.Count == 0)
{
Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: No items to process");
return;
}

Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: Starting with TotalItems={0}, MaxConcurrency={1}",
itemList.Count, maxConcurrency);

int nextIndex = 0;
var activeTasks = new List<Task>();

// Start initial batch up to concurrency limit
int initialBatchSize = Math.Min(maxConcurrency, itemList.Count);
Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: Starting initial batch of {0} tasks", initialBatchSize);

for (int i = 0; i < initialBatchSize; i++)
{
var task = processAsync(itemList[nextIndex++], cancellationToken);
activeTasks.Add(task);
}

// Process completions and start new tasks until all work is done
while (activeTasks.Count > 0)
{
cancellationToken.ThrowIfCancellationRequested();

var completedTask = await Task.WhenAny(activeTasks)
.ConfigureAwait(continueOnCapturedContext: false);

// Propagate exceptions (fail-fast behavior by default)
// Caller's processAsync function should handle failure policy if needed
await completedTask
.ConfigureAwait(continueOnCapturedContext: false);

activeTasks.Remove(completedTask);

int itemsCompleted = nextIndex - activeTasks.Count;
Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: Task completed (Active={0}, Completed={1}/{2}, Remaining={3})",
activeTasks.Count, itemsCompleted, itemList.Count, itemList.Count - itemsCompleted);

// Start next task if more work remains
if (nextIndex < itemList.Count)
{
Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: Starting next task (Index={0}/{1}, Active={2})",
nextIndex + 1, itemList.Count, activeTasks.Count + 1);
var nextTask = processAsync(itemList[nextIndex++], cancellationToken);
activeTasks.Add(nextTask);
}
}

Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: All items processed (Total={0})", itemList.Count);
}
}
}
Loading