Skip to content

Commit b4ead92

Browse files
CopilotAdamFrisby
andcommitted
Add call site update functionality to ParametersToParameterObject refactoring
Co-authored-by: AdamFrisby <114041+AdamFrisby@users.noreply.github.com>
1 parent 09d6c63 commit b4ead92

File tree

2 files changed

+295
-3
lines changed

2 files changed

+295
-3
lines changed

Cast.Tool.Tests/ParametersToParameterObjectCommandTests.cs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,4 +270,129 @@ public void ProcessData(bool isActive, string message, List<int> numbers)
270270
File.Delete(csFile);
271271
}
272272
}
273+
274+
[Fact]
275+
public async Task ParametersToParameterObject_WithCallSites_ShouldUpdateCallers()
276+
{
277+
// Arrange
278+
var testCode = @"using System;
279+
280+
namespace MyProject
281+
{
282+
public class Calculator
283+
{
284+
public int Add(int a, int b)
285+
{
286+
return a + b;
287+
}
288+
289+
public void TestMethod()
290+
{
291+
var result1 = Add(5, 10);
292+
var result2 = Add(3, 7);
293+
Console.WriteLine(result1 + result2);
294+
}
295+
}
296+
}";
297+
298+
var tempFile = Path.GetTempFileName();
299+
var csFile = Path.ChangeExtension(tempFile, ".cs");
300+
File.Move(tempFile, csFile);
301+
await File.WriteAllTextAsync(csFile, testCode);
302+
303+
try
304+
{
305+
// Act
306+
var command = new ParametersToParameterObjectCommand();
307+
var settings = new ParametersToParameterObjectCommand.Settings
308+
{
309+
FilePath = csFile,
310+
LineNumber = 7, // Line with public int Add
311+
ParameterObjectName = "AddParams",
312+
ParameterObjectType = "class",
313+
UpdateCallers = true,
314+
DryRun = false
315+
};
316+
317+
var result = await command.ExecuteAsync(null!, settings);
318+
Assert.Equal(0, result);
319+
320+
// Verify the transformation
321+
var modifiedCode = await File.ReadAllTextAsync(csFile);
322+
Assert.Contains("public int Add(AddParams args)", modifiedCode);
323+
Assert.Contains("return args.a + args.b;", modifiedCode);
324+
Assert.Contains("public class AddParams", modifiedCode);
325+
Assert.Contains("var result1 = Add(new AddParams(5, 10));", modifiedCode);
326+
Assert.Contains("var result2 = Add(new AddParams(3, 7));", modifiedCode);
327+
Assert.Contains("this.a = a;", modifiedCode);
328+
Assert.Contains("this.b = b;", modifiedCode);
329+
}
330+
finally
331+
{
332+
// Cleanup
333+
if (File.Exists(csFile))
334+
File.Delete(csFile);
335+
}
336+
}
337+
338+
[Fact]
339+
public async Task ParametersToParameterObject_WithCallSitesDisabled_ShouldNotUpdateCallers()
340+
{
341+
// Arrange
342+
var testCode = @"using System;
343+
344+
namespace MyProject
345+
{
346+
public class Calculator
347+
{
348+
public int Add(int a, int b)
349+
{
350+
return a + b;
351+
}
352+
353+
public void TestMethod()
354+
{
355+
var result = Add(5, 10);
356+
Console.WriteLine(result);
357+
}
358+
}
359+
}";
360+
361+
var tempFile = Path.GetTempFileName();
362+
var csFile = Path.ChangeExtension(tempFile, ".cs");
363+
File.Move(tempFile, csFile);
364+
await File.WriteAllTextAsync(csFile, testCode);
365+
366+
try
367+
{
368+
// Act
369+
var command = new ParametersToParameterObjectCommand();
370+
var settings = new ParametersToParameterObjectCommand.Settings
371+
{
372+
FilePath = csFile,
373+
LineNumber = 7, // Line with public int Add
374+
ParameterObjectName = "AddParams",
375+
ParameterObjectType = "class",
376+
UpdateCallers = false,
377+
DryRun = false
378+
};
379+
380+
var result = await command.ExecuteAsync(null!, settings);
381+
Assert.Equal(0, result);
382+
383+
// Verify the transformation
384+
var modifiedCode = await File.ReadAllTextAsync(csFile);
385+
Assert.Contains("public int Add(AddParams args)", modifiedCode);
386+
Assert.Contains("return args.a + args.b;", modifiedCode);
387+
Assert.Contains("public class AddParams", modifiedCode);
388+
// Call site should remain unchanged
389+
Assert.Contains("var result = Add(5, 10);", modifiedCode);
390+
}
391+
finally
392+
{
393+
// Cleanup
394+
if (File.Exists(csFile))
395+
File.Delete(csFile);
396+
}
397+
}
273398
}

Cast.Tool/Commands/ParametersToParameterObjectCommand.cs

Lines changed: 170 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ public class Settings : CommandSettings
4444
[Description("Show what changes would be made without applying them")]
4545
[DefaultValue(false)]
4646
public bool DryRun { get; init; } = false;
47+
48+
[CommandOption("--update-callers")]
49+
[Description("Automatically update call sites to use the new parameter object")]
50+
[DefaultValue(true)]
51+
public bool UpdateCallers { get; init; } = true;
4752
}
4853

4954
public override int Execute(CommandContext context, Settings settings)
@@ -118,8 +123,21 @@ public async Task<int> ExecuteAsync(CommandContext context, Settings settings)
118123

119124
var newRoot = root.ReplaceNode(containingType, updatedContainingType);
120125

121-
// TODO: Find and update all call sites (for now, just show a warning)
122-
AnsiConsole.WriteLine("[yellow]Note: Call sites will need to be updated manually in this version[/]");
126+
// Find and update call sites if requested
127+
if (settings.UpdateCallers)
128+
{
129+
var methodSymbol = model.GetDeclaredSymbol(method);
130+
if (methodSymbol != null)
131+
{
132+
// We need to work with the original semantic model to find call sites
133+
// because the new syntax tree doesn't have a semantic model yet
134+
newRoot = UpdateCallSites(root, model, methodSymbol, parameterObjectName, settings.ParameterObjectType, newRoot);
135+
}
136+
}
137+
else
138+
{
139+
AnsiConsole.WriteLine("[yellow]Note: Call sites will need to be updated manually (use --update-callers to enable automatic updates)[/]");
140+
}
123141

124142
var result = newRoot.NormalizeWhitespace().ToFullString();
125143

@@ -255,7 +273,10 @@ private ConstructorDeclarationSyntax CreateConstructor(string className, Separat
255273
SyntaxFactory.ExpressionStatement(
256274
SyntaxFactory.AssignmentExpression(
257275
SyntaxKind.SimpleAssignmentExpression,
258-
SyntaxFactory.IdentifierName(p.Identifier.ValueText),
276+
SyntaxFactory.MemberAccessExpression(
277+
SyntaxKind.SimpleMemberAccessExpression,
278+
SyntaxFactory.ThisExpression(),
279+
SyntaxFactory.IdentifierName(p.Identifier.ValueText)),
259280
SyntaxFactory.IdentifierName(p.Identifier.ValueText))));
260281

261282
var body = SyntaxFactory.Block(assignments);
@@ -329,4 +350,150 @@ private bool IsInParameterContext(IdentifierNameSyntax identifier)
329350
// Check if this identifier is in a parameter declaration context
330351
return identifier.Ancestors().OfType<ParameterSyntax>().Any();
331352
}
353+
354+
private SyntaxNode UpdateCallSites(SyntaxNode originalRoot, SemanticModel model, IMethodSymbol methodSymbol, string parameterObjectName, string parameterObjectType, SyntaxNode newRoot)
355+
{
356+
var callSiteUpdates = new Dictionary<SyntaxNode, SyntaxNode>();
357+
var invocations = originalRoot.DescendantNodes().OfType<InvocationExpressionSyntax>().ToList();
358+
359+
foreach (var invocation in invocations)
360+
{
361+
var symbolInfo = model.GetSymbolInfo(invocation);
362+
if (symbolInfo.Symbol is IMethodSymbol invokedMethod &&
363+
SymbolEqualityComparer.Default.Equals(invokedMethod.OriginalDefinition, methodSymbol.OriginalDefinition))
364+
{
365+
// This is a call to our refactored method
366+
var updatedCall = CreateUpdatedMethodCall(invocation, parameterObjectName, parameterObjectType);
367+
if (updatedCall != null)
368+
{
369+
callSiteUpdates[invocation] = updatedCall;
370+
}
371+
}
372+
}
373+
374+
if (callSiteUpdates.Any())
375+
{
376+
// Apply call site updates to the new root by finding equivalent nodes
377+
var finalUpdates = new Dictionary<SyntaxNode, SyntaxNode>();
378+
foreach (var kvp in callSiteUpdates)
379+
{
380+
var originalCall = kvp.Key;
381+
var updatedCall = kvp.Value;
382+
383+
// Find the equivalent node in the new root
384+
var equivalentNode = FindEquivalentNode(newRoot, originalCall);
385+
if (equivalentNode != null)
386+
{
387+
finalUpdates[equivalentNode] = updatedCall;
388+
}
389+
}
390+
391+
if (finalUpdates.Any())
392+
{
393+
newRoot = newRoot.ReplaceNodes(finalUpdates.Keys, (original, rewritten) => finalUpdates[original]);
394+
AnsiConsole.WriteLine($"[green]Updated {finalUpdates.Count} call site(s) to use parameter object[/]");
395+
}
396+
}
397+
else
398+
{
399+
AnsiConsole.WriteLine("[yellow]No call sites found to update[/]");
400+
}
401+
402+
return newRoot;
403+
}
404+
405+
private SyntaxNode? FindEquivalentNode(SyntaxNode newRoot, SyntaxNode originalNode)
406+
{
407+
// Find a node in the new tree that corresponds to the original node
408+
// by comparing the structure rather than text spans (which may have changed)
409+
if (originalNode is InvocationExpressionSyntax originalInvocation)
410+
{
411+
return newRoot.DescendantNodes()
412+
.OfType<InvocationExpressionSyntax>()
413+
.FirstOrDefault(inv => AreInvocationsEquivalent(originalInvocation, inv));
414+
}
415+
416+
return null;
417+
}
418+
419+
private bool AreInvocationsEquivalent(InvocationExpressionSyntax original, InvocationExpressionSyntax candidate)
420+
{
421+
// Compare the method name
422+
var originalName = GetMethodName(original.Expression);
423+
var candidateName = GetMethodName(candidate.Expression);
424+
425+
if (originalName != candidateName)
426+
return false;
427+
428+
// Compare argument count and structure
429+
if (original.ArgumentList.Arguments.Count != candidate.ArgumentList.Arguments.Count)
430+
return false;
431+
432+
// Compare each argument (basic comparison of text representation)
433+
for (int i = 0; i < original.ArgumentList.Arguments.Count; i++)
434+
{
435+
var originalArg = original.ArgumentList.Arguments[i].ToString().Trim();
436+
var candidateArg = candidate.ArgumentList.Arguments[i].ToString().Trim();
437+
if (originalArg != candidateArg)
438+
return false;
439+
}
440+
441+
return true;
442+
}
443+
444+
private string GetMethodName(ExpressionSyntax expression)
445+
{
446+
return expression switch
447+
{
448+
IdentifierNameSyntax identifier => identifier.Identifier.ValueText,
449+
MemberAccessExpressionSyntax memberAccess => memberAccess.Name.Identifier.ValueText,
450+
_ => ""
451+
};
452+
}
453+
454+
private InvocationExpressionSyntax? CreateUpdatedMethodCall(InvocationExpressionSyntax originalCall, string parameterObjectName, string parameterObjectType)
455+
{
456+
if (originalCall.ArgumentList.Arguments.Count == 0)
457+
return null;
458+
459+
// Create the parameter object instantiation
460+
var objectCreation = parameterObjectType.ToLower() switch
461+
{
462+
"record" => CreateRecordInstantiation(parameterObjectName, originalCall.ArgumentList.Arguments),
463+
"struct" => CreateStructInstantiation(parameterObjectName, originalCall.ArgumentList.Arguments),
464+
_ => CreateClassInstantiation(parameterObjectName, originalCall.ArgumentList.Arguments)
465+
};
466+
467+
// Create new argument list with the parameter object
468+
var newArgument = SyntaxFactory.Argument(objectCreation);
469+
var newArgumentList = SyntaxFactory.ArgumentList(
470+
SyntaxFactory.SingletonSeparatedList(newArgument));
471+
472+
// Return the updated invocation
473+
return originalCall.WithArgumentList(newArgumentList);
474+
}
475+
476+
private ExpressionSyntax CreateClassInstantiation(string parameterObjectName, SeparatedSyntaxList<ArgumentSyntax> arguments)
477+
{
478+
return SyntaxFactory.ObjectCreationExpression(
479+
SyntaxFactory.IdentifierName(parameterObjectName))
480+
.WithArgumentList(SyntaxFactory.ArgumentList(arguments))
481+
.NormalizeWhitespace();
482+
}
483+
484+
private ExpressionSyntax CreateStructInstantiation(string parameterObjectName, SeparatedSyntaxList<ArgumentSyntax> arguments)
485+
{
486+
return SyntaxFactory.ObjectCreationExpression(
487+
SyntaxFactory.IdentifierName(parameterObjectName))
488+
.WithArgumentList(SyntaxFactory.ArgumentList(arguments))
489+
.NormalizeWhitespace();
490+
}
491+
492+
private ExpressionSyntax CreateRecordInstantiation(string parameterObjectName, SeparatedSyntaxList<ArgumentSyntax> arguments)
493+
{
494+
return SyntaxFactory.ObjectCreationExpression(
495+
SyntaxFactory.IdentifierName(parameterObjectName))
496+
.WithArgumentList(SyntaxFactory.ArgumentList(arguments))
497+
.NormalizeWhitespace();
498+
}
332499
}

0 commit comments

Comments
 (0)