Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections.Concurrent;
using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.FindSymbols;
Expand All @@ -27,6 +28,8 @@ public static async Task ModifyAllReferencesAsync(
ILogger? logger = null,
CancellationToken ct = default)
{
using var _ = new ProfilingScope($"LocationTransformationUtils.ModifyAllReferencesAsync");

// Convert to lists
// The parameters being IEnumerables is for convenience
var symbolList = symbols.ToList();
Expand Down Expand Up @@ -55,80 +58,105 @@ public static async Task ModifyAllReferencesAsync(
}
}

// Find all locations where the symbols are referenced
// TODO this needs parallelisation config & be sensitive to the environment (future src generator form factor?)
ImmutableHashSet<Document> documents = [..ctx.SourceProject.Documents, ..ctx.TestProject?.Documents ?? []];
await Parallel.ForEachAsync(
symbolList,
ct,
async (symbol, _) => {
var references = await SymbolFinder.FindReferencesAsync(symbol, ctx.SourceProject.Solution, documents, ct);
foreach (var reference in references)
{
foreach (var location in reference.Locations)
var sync = new object();
var findReferencesAsyncTimeElapsed = TimeSpan.Zero;
var forEachAsyncBodyTimeElapsed = TimeSpan.Zero;

using (new ProfilingScope($"LocationTransformationUtils.ModifyAllReferencesAsync - Find references"))
{
// Find all locations where the symbols are referenced
// TODO this needs parallelisation config & be sensitive to the environment (future src generator form factor?)
ImmutableHashSet<Document> documents = [..ctx.SourceProject.Documents, ..ctx.TestProject?.Documents ?? []];
await Parallel.ForEachAsync(
symbolList,
ct,
async (symbol, _) => {
var timestamp = Stopwatch.GetTimestamp();
var references = await SymbolFinder.FindReferencesAsync(symbol, ctx.SourceProject.Solution, documents, ct);
var elapsed = Stopwatch.GetElapsedTime(timestamp);
lock (sync)
{
locations.Add((location.Location,
new LocationTransformerContext(symbol, false,
location.IsCandidateLocation || location.IsImplicit)));
findReferencesAsyncTimeElapsed += elapsed;
}
foreach (var reference in references)
{
foreach (var location in reference.Locations)
{
locations.Add((location.Location,
new LocationTransformerContext(symbol, false,
location.IsCandidateLocation || location.IsImplicit)));
}
}
elapsed = Stopwatch.GetElapsedTime(timestamp);
lock (sync)
{
forEachAsyncBodyTimeElapsed += elapsed;
}
}
}
);
);
}

// Group the locations by source tree. This will be used to prevent accidentally overwriting changes.
var solution = ctx.SourceProject.Solution;
var locationsBySourcetree = locations.GroupBy(l => l.Location.SourceTree);
foreach (var group in locationsBySourcetree)
{
var syntaxTree = group.Key;
if (syntaxTree == null)
{
continue;
}
Console.WriteLine("Symbol count: {0}", symbolList.Count);
Console.WriteLine("Scope '{0}' took {1:F3} MILLISECONDS ON AVERAGE", "LocationTransformationUtils.ModifyAllReferencesAsync - Find references - FindReferencesAsync", findReferencesAsyncTimeElapsed.TotalMilliseconds / symbolList.Count);
Console.WriteLine("Scope '{0}' took {1:F3} MILLISECONDS ON AVERAGE", "LocationTransformationUtils.ModifyAllReferencesAsync - Find references - ForEachAsync Body", forEachAsyncBodyTimeElapsed.TotalMilliseconds / symbolList.Count);

var document = solution.GetDocument(syntaxTree);
if (document == null)
using (new ProfilingScope($"LocationTransformationUtils.ModifyAllReferencesAsync - Rewrite"))
{
// Group the locations by source tree. This will be used to prevent accidentally overwriting changes.
var solution = ctx.SourceProject.Solution;
var locationsBySourcetree = locations.GroupBy(l => l.Location.SourceTree);
foreach (var group in locationsBySourcetree)
{
continue;
}

var syntaxRoot = await syntaxTree.GetRootAsync(ct);
var syntaxTree = group.Key;
if (syntaxTree == null)
{
continue;
}

// Modify each location
// We order the locations so that we modify starting from the end of the file
// This way we prevent changes from being accidentally overwriting changes
foreach (var (location, context) in group.OrderByDescending(l => l.Location.SourceSpan))
{
foreach (var transformer in transformersList)
var document = solution.GetDocument(syntaxTree);
if (document == null)
{
var syntaxNode = syntaxRoot.FindNode(location.SourceSpan);
var nodeToModify = transformer.GetNodeToModify(syntaxNode, context);
if (nodeToModify == null)
{
continue;
}
continue;
}

var newNode = transformer.Visit(nodeToModify);
var originalLength = syntaxNode.FullSpan.Length;
var newLength = newNode.FullSpan.Length;
var syntaxRoot = await syntaxTree.GetRootAsync(ct);

// Ensure that the new node's length is at least the original node's length
// This is because the last few nodes processed usually make up the entire document
// If the document's length has been reduced, then an ArgumentOutOfRangeException will be thrown
if (originalLength - newLength > 0)
// Modify each location
// We order the locations so that we modify starting from the end of the file
// This way we prevent changes from being accidentally overwriting changes
foreach (var (location, context) in group.OrderByDescending(l => l.Location.SourceSpan))
{
foreach (var transformer in transformersList)
{
newNode = newNode.WithTrailingTrivia(TriviaList([..newNode.GetTrailingTrivia(), Whitespace(new string(' ', originalLength - newLength))]));
var syntaxNode = syntaxRoot.FindNode(location.SourceSpan);
var nodeToModify = transformer.GetNodeToModify(syntaxNode, context);
if (nodeToModify == null)
{
continue;
}

var newNode = transformer.Visit(nodeToModify);
var originalLength = syntaxNode.FullSpan.Length;
var newLength = newNode.FullSpan.Length;

// Ensure that the new node's length is at least the original node's length
// This is because the last few nodes processed usually make up the entire document
// If the document's length has been reduced, then an ArgumentOutOfRangeException will be thrown
if (originalLength - newLength > 0)
{
newNode = newNode.WithTrailingTrivia(TriviaList([..newNode.GetTrailingTrivia(), Whitespace(new string(' ', originalLength - newLength))]));
}

syntaxRoot = syntaxRoot.ReplaceNode(nodeToModify, newNode);
}

syntaxRoot = syntaxRoot.ReplaceNode(nodeToModify, newNode);
}

// Commit the changes to the solution
var newDocument = document.WithSyntaxRoot(syntaxRoot.NormalizeWhitespace());
solution = newDocument.Project.Solution;
}

// Commit the changes to the solution
var newDocument = document.WithSyntaxRoot(syntaxRoot.NormalizeWhitespace());
solution = newDocument.Project.Solution;
ctx.SourceProject = solution.GetProject(ctx.SourceProject.Id);
}

ctx.SourceProject = solution.GetProject(ctx.SourceProject.Id);
}
}
156 changes: 152 additions & 4 deletions sources/SilkTouch/SilkTouch/Naming/NameUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,157 @@ public static async Task RenameAllAsync(
bool includeCandidateLocations = false
)
{
var toRenameList = toRename.ToList();
await LocationTransformationUtils.ModifyAllReferencesAsync(ctx, toRenameList.Select(t => t.Symbol), [
new IdentifierRenamingTransformer(toRenameList, includeDeclarations, includeCandidateLocations)
], logger, ct);
using var _ = new ProfilingScope($"NameUtils.RenameAllAsync");

if (ctx.SourceProject is null)
{
return;
}

var sync = new object();
var findReferencesAsyncTimeElapsed = TimeSpan.Zero;
var forEachAsyncBodyTimeElapsed = TimeSpan.Zero;
var symbolCount = 0;

var locations = new ConcurrentDictionary<Location, string>();
using (new ProfilingScope($"NameUtils.RenameAllAsync - Find references"))
{
// TODO this needs parallelisation config & be sensitive to the environment (future src generator form factor?)
await Parallel.ForEachAsync(
toRename,
ct,
async (tuple, _) =>
{
var startTimestamp = Stopwatch.GetTimestamp();
// First, let's add all of the locations of the declaration identifiers.
var (symbol, newName) = tuple;
if (includeDeclarations)
{
foreach (var syntaxRef in symbol.DeclaringSyntaxReferences)
{
var identifierLocation = IdentifierLocation(
await syntaxRef.GetSyntaxAsync(ct)
);
if (identifierLocation is not null)
{
locations.TryAdd(identifierLocation, newName);
}
}
}

// Next, let's find all the references of the symbols.
var findReferencesStartTimestamp = Stopwatch.GetTimestamp();
var references = await SymbolFinder.FindReferencesAsync(
symbol,
ctx.SourceProject?.Solution
?? throw new ArgumentException("SourceProject is null"),
ct
);

var elapsed = Stopwatch.GetElapsedTime(findReferencesStartTimestamp);
lock (sync)
{
findReferencesAsyncTimeElapsed += elapsed;
symbolCount++;
}

foreach (var referencedSymbol in references)
{
foreach (var referencedSymbolLocation in referencedSymbol.Locations)
{
if (
!includeCandidateLocations
&& (
referencedSymbolLocation.IsCandidateLocation
|| referencedSymbolLocation.IsImplicit
)
)
{
continue;
}

locations.TryAdd(referencedSymbolLocation.Location, newName);
}
}

elapsed = Stopwatch.GetElapsedTime(startTimestamp);
lock (sync)
{
forEachAsyncBodyTimeElapsed += elapsed;
}
}
);
}

Console.WriteLine("Symbol count: {0}", symbolCount);
Console.WriteLine("Scope '{0}' took {1:F3} MILLISECONDS ON AVERAGE", "NameUtils.RenameAllAsync - Find references - FindReferencesAsync", findReferencesAsyncTimeElapsed.TotalMilliseconds / symbolCount);
Console.WriteLine("Scope '{0}' took {1:F3} MILLISECONDS ON AVERAGE", "NameUtils.RenameAllAsync - Find references - ForEachAsync Body", forEachAsyncBodyTimeElapsed.TotalMilliseconds / symbolCount);

logger?.LogDebug(
"{} referencing locations for renames for {}",
locations.Count,
ctx.JobKey
);

using (new ProfilingScope($"NameUtils.RenameAllAsync - Rewrite"))
{
// Now it's just a simple find and replace.
var sln = ctx.SourceProject.Solution;
var srcProjId = ctx.SourceProject.Id;
var testProjId = ctx.TestProject?.Id;
foreach (
var (syntaxTree, renameLocations) in locations
.GroupBy(x => x.Key.SourceTree)
.Select(x => (x.Key, x.OrderByDescending(y => y.Key.SourceSpan)))
)
{
if (
syntaxTree is null
|| sln.GetDocument(syntaxTree) is not { } doc
|| (doc.Project.Id != srcProjId && doc.Project.Id != testProjId)
|| await syntaxTree.GetTextAsync(ct) is not { } text
)
{
continue;
}

var ogText = text;
foreach (var (location, newName) in renameLocations)
{
var contents = ogText.GetSubText(location.SourceSpan).ToString();
if (contents.Contains(' '))
{
logger?.LogWarning(
"Refusing to do unsafe rename/replacement of \"{}\" to \"{}\" at {}",
contents,
newName,
location.GetLineSpan()
);
continue;
}

if (contents == "this" || contents == "base")
{
continue;
}

if (logger?.IsEnabled(LogLevel.Trace) ?? false)
{
logger?.LogTrace(
"\"{}\" -> \"{}\" at {}",
contents,
newName,
location.GetLineSpan()
);
}

text = text.Replace(location.SourceSpan, newName);
}

sln = doc.WithText(text).Project.Solution;
}

ctx.SourceProject = sln.GetProject(srcProjId);
}
}
}
21 changes: 21 additions & 0 deletions sources/SilkTouch/SilkTouch/ProfilingScope.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System.Diagnostics;

namespace Silk.NET.SilkTouch;

internal readonly struct ProfilingScope : IDisposable
{
private readonly string name;
private readonly long timestamp;

public ProfilingScope(string name)
{
this.name = name;
timestamp = Stopwatch.GetTimestamp();
}

public void Dispose()
{
var elapsed = Stopwatch.GetElapsedTime(timestamp);
Console.WriteLine("Scope '{0}' took {1:F3} seconds", name, elapsed.TotalSeconds);
}
}
4 changes: 4 additions & 0 deletions sources/SilkTouch/SilkTouch/SilkTouchGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ public async Task RunAsync(
// Initialize the mods
foreach (var jobMod in jobMods)
{
using var _ = new ProfilingScope($"{jobMod.GetType().Name} Initialize");

logger.LogDebug("Using mod {0} for {1}", jobMod.GetType().Name, key);
await jobMod.InitializeAsync(ctx, ct);
if (ctx.SourceProject != srcProj || ctx.TestProject != testProj)
Expand All @@ -87,6 +89,8 @@ public async Task RunAsync(

foreach (var jobMod in jobMods)
{
using var _ = new ProfilingScope($"{jobMod.GetType().Name} Execute");

logger.LogInformation("Executing {} for {}...", jobMod.GetType().Name, key);
await jobMod.ExecuteAsync(ctx, ct);
}
Expand Down