@@ -44,6 +44,11 @@ public class Settings : CommandSettings
4444 [ Description ( "Show what changes would be made without applying them" ) ]
4545 [ DefaultValue ( false ) ]
4646 public bool DryRun { get ; init ; } = false ;
47+
48+ [ CommandOption ( "--update-callers" ) ]
49+ [ Description ( "Automatically update call sites to use the new parameter object" ) ]
50+ [ DefaultValue ( true ) ]
51+ public bool UpdateCallers { get ; init ; } = true ;
4752 }
4853
4954 public override int Execute ( CommandContext context , Settings settings )
@@ -118,8 +123,21 @@ public async Task<int> ExecuteAsync(CommandContext context, Settings settings)
118123
119124 var newRoot = root . ReplaceNode ( containingType , updatedContainingType ) ;
120125
121- // TODO: Find and update all call sites (for now, just show a warning)
122- AnsiConsole . WriteLine ( "[yellow]Note: Call sites will need to be updated manually in this version[/]" ) ;
126+ // Find and update call sites if requested
127+ if ( settings . UpdateCallers )
128+ {
129+ var methodSymbol = model . GetDeclaredSymbol ( method ) ;
130+ if ( methodSymbol != null )
131+ {
132+ // We need to work with the original semantic model to find call sites
133+ // because the new syntax tree doesn't have a semantic model yet
134+ newRoot = UpdateCallSites ( root , model , methodSymbol , parameterObjectName , settings . ParameterObjectType , newRoot ) ;
135+ }
136+ }
137+ else
138+ {
139+ AnsiConsole . WriteLine ( "[yellow]Note: Call sites will need to be updated manually (use --update-callers to enable automatic updates)[/]" ) ;
140+ }
123141
124142 var result = newRoot . NormalizeWhitespace ( ) . ToFullString ( ) ;
125143
@@ -255,7 +273,10 @@ private ConstructorDeclarationSyntax CreateConstructor(string className, Separat
255273 SyntaxFactory . ExpressionStatement (
256274 SyntaxFactory . AssignmentExpression (
257275 SyntaxKind . SimpleAssignmentExpression ,
258- SyntaxFactory . IdentifierName ( p . Identifier . ValueText ) ,
276+ SyntaxFactory . MemberAccessExpression (
277+ SyntaxKind . SimpleMemberAccessExpression ,
278+ SyntaxFactory . ThisExpression ( ) ,
279+ SyntaxFactory . IdentifierName ( p . Identifier . ValueText ) ) ,
259280 SyntaxFactory . IdentifierName ( p . Identifier . ValueText ) ) ) ) ;
260281
261282 var body = SyntaxFactory . Block ( assignments ) ;
@@ -329,4 +350,150 @@ private bool IsInParameterContext(IdentifierNameSyntax identifier)
329350 // Check if this identifier is in a parameter declaration context
330351 return identifier . Ancestors ( ) . OfType < ParameterSyntax > ( ) . Any ( ) ;
331352 }
353+
354+ private SyntaxNode UpdateCallSites ( SyntaxNode originalRoot , SemanticModel model , IMethodSymbol methodSymbol , string parameterObjectName , string parameterObjectType , SyntaxNode newRoot )
355+ {
356+ var callSiteUpdates = new Dictionary < SyntaxNode , SyntaxNode > ( ) ;
357+ var invocations = originalRoot . DescendantNodes ( ) . OfType < InvocationExpressionSyntax > ( ) . ToList ( ) ;
358+
359+ foreach ( var invocation in invocations )
360+ {
361+ var symbolInfo = model . GetSymbolInfo ( invocation ) ;
362+ if ( symbolInfo . Symbol is IMethodSymbol invokedMethod &&
363+ SymbolEqualityComparer . Default . Equals ( invokedMethod . OriginalDefinition , methodSymbol . OriginalDefinition ) )
364+ {
365+ // This is a call to our refactored method
366+ var updatedCall = CreateUpdatedMethodCall ( invocation , parameterObjectName , parameterObjectType ) ;
367+ if ( updatedCall != null )
368+ {
369+ callSiteUpdates [ invocation ] = updatedCall ;
370+ }
371+ }
372+ }
373+
374+ if ( callSiteUpdates . Any ( ) )
375+ {
376+ // Apply call site updates to the new root by finding equivalent nodes
377+ var finalUpdates = new Dictionary < SyntaxNode , SyntaxNode > ( ) ;
378+ foreach ( var kvp in callSiteUpdates )
379+ {
380+ var originalCall = kvp . Key ;
381+ var updatedCall = kvp . Value ;
382+
383+ // Find the equivalent node in the new root
384+ var equivalentNode = FindEquivalentNode ( newRoot , originalCall ) ;
385+ if ( equivalentNode != null )
386+ {
387+ finalUpdates [ equivalentNode ] = updatedCall ;
388+ }
389+ }
390+
391+ if ( finalUpdates . Any ( ) )
392+ {
393+ newRoot = newRoot . ReplaceNodes ( finalUpdates . Keys , ( original , rewritten ) => finalUpdates [ original ] ) ;
394+ AnsiConsole . WriteLine ( $ "[green]Updated { finalUpdates . Count } call site(s) to use parameter object[/]") ;
395+ }
396+ }
397+ else
398+ {
399+ AnsiConsole . WriteLine ( "[yellow]No call sites found to update[/]" ) ;
400+ }
401+
402+ return newRoot ;
403+ }
404+
405+ private SyntaxNode ? FindEquivalentNode ( SyntaxNode newRoot , SyntaxNode originalNode )
406+ {
407+ // Find a node in the new tree that corresponds to the original node
408+ // by comparing the structure rather than text spans (which may have changed)
409+ if ( originalNode is InvocationExpressionSyntax originalInvocation )
410+ {
411+ return newRoot . DescendantNodes ( )
412+ . OfType < InvocationExpressionSyntax > ( )
413+ . FirstOrDefault ( inv => AreInvocationsEquivalent ( originalInvocation , inv ) ) ;
414+ }
415+
416+ return null ;
417+ }
418+
419+ private bool AreInvocationsEquivalent ( InvocationExpressionSyntax original , InvocationExpressionSyntax candidate )
420+ {
421+ // Compare the method name
422+ var originalName = GetMethodName ( original . Expression ) ;
423+ var candidateName = GetMethodName ( candidate . Expression ) ;
424+
425+ if ( originalName != candidateName )
426+ return false ;
427+
428+ // Compare argument count and structure
429+ if ( original . ArgumentList . Arguments . Count != candidate . ArgumentList . Arguments . Count )
430+ return false ;
431+
432+ // Compare each argument (basic comparison of text representation)
433+ for ( int i = 0 ; i < original . ArgumentList . Arguments . Count ; i ++ )
434+ {
435+ var originalArg = original . ArgumentList . Arguments [ i ] . ToString ( ) . Trim ( ) ;
436+ var candidateArg = candidate . ArgumentList . Arguments [ i ] . ToString ( ) . Trim ( ) ;
437+ if ( originalArg != candidateArg )
438+ return false ;
439+ }
440+
441+ return true ;
442+ }
443+
444+ private string GetMethodName ( ExpressionSyntax expression )
445+ {
446+ return expression switch
447+ {
448+ IdentifierNameSyntax identifier => identifier . Identifier . ValueText ,
449+ MemberAccessExpressionSyntax memberAccess => memberAccess . Name . Identifier . ValueText ,
450+ _ => ""
451+ } ;
452+ }
453+
454+ private InvocationExpressionSyntax ? CreateUpdatedMethodCall ( InvocationExpressionSyntax originalCall , string parameterObjectName , string parameterObjectType )
455+ {
456+ if ( originalCall . ArgumentList . Arguments . Count == 0 )
457+ return null ;
458+
459+ // Create the parameter object instantiation
460+ var objectCreation = parameterObjectType . ToLower ( ) switch
461+ {
462+ "record" => CreateRecordInstantiation ( parameterObjectName , originalCall . ArgumentList . Arguments ) ,
463+ "struct" => CreateStructInstantiation ( parameterObjectName , originalCall . ArgumentList . Arguments ) ,
464+ _ => CreateClassInstantiation ( parameterObjectName , originalCall . ArgumentList . Arguments )
465+ } ;
466+
467+ // Create new argument list with the parameter object
468+ var newArgument = SyntaxFactory . Argument ( objectCreation ) ;
469+ var newArgumentList = SyntaxFactory . ArgumentList (
470+ SyntaxFactory . SingletonSeparatedList ( newArgument ) ) ;
471+
472+ // Return the updated invocation
473+ return originalCall . WithArgumentList ( newArgumentList ) ;
474+ }
475+
476+ private ExpressionSyntax CreateClassInstantiation ( string parameterObjectName , SeparatedSyntaxList < ArgumentSyntax > arguments )
477+ {
478+ return SyntaxFactory . ObjectCreationExpression (
479+ SyntaxFactory . IdentifierName ( parameterObjectName ) )
480+ . WithArgumentList ( SyntaxFactory . ArgumentList ( arguments ) )
481+ . NormalizeWhitespace ( ) ;
482+ }
483+
484+ private ExpressionSyntax CreateStructInstantiation ( string parameterObjectName , SeparatedSyntaxList < ArgumentSyntax > arguments )
485+ {
486+ return SyntaxFactory . ObjectCreationExpression (
487+ SyntaxFactory . IdentifierName ( parameterObjectName ) )
488+ . WithArgumentList ( SyntaxFactory . ArgumentList ( arguments ) )
489+ . NormalizeWhitespace ( ) ;
490+ }
491+
492+ private ExpressionSyntax CreateRecordInstantiation ( string parameterObjectName , SeparatedSyntaxList < ArgumentSyntax > arguments )
493+ {
494+ return SyntaxFactory . ObjectCreationExpression (
495+ SyntaxFactory . IdentifierName ( parameterObjectName ) )
496+ . WithArgumentList ( SyntaxFactory . ArgumentList ( arguments ) )
497+ . NormalizeWhitespace ( ) ;
498+ }
332499}
0 commit comments