Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SyntaxClassifier with SyntaxVisitor #2087

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
255 changes: 122 additions & 133 deletions Sources/SwiftIDEUtils/SyntaxClassifier.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ extension TokenSyntax {
contextualClassification: contextualClassification
)
}

fileprivate func contextFreeClassification() -> SyntaxClassification {
let kind = tokenKind.decomposeToRaw().rawKind
if kind == .unknown, text.hasPrefix("\"") {
return .stringLiteral
}
if kind == .identifier,
text.hasPrefix("<#"),
text.hasSuffix("#>")
{
return .editorPlaceholder
}
return kind.classification
}
}

extension RawTriviaPiece {
Expand Down Expand Up @@ -94,51 +108,52 @@ public struct SyntaxClassifiedRange: Equatable {
public var endOffset: Int { return range.endOffset }
}

private struct ClassificationVisitor {
private enum VisitResult {
case `continue`
case `break`
class ClassificationGenerator: SyntaxVisitor {
var classifications: [SyntaxClassifiedRange] = []

private init() {
super.init(viewMode: .sourceAccurate)
}

private struct Descriptor {
var node: RawSyntax
var byteOffset: Int
var contextualClassification: (SyntaxClassification, Bool)?
public static func classify(for tree: some SyntaxProtocol, in range: ByteSourceRange?) -> [SyntaxClassifiedRange] {
let generator = ClassificationGenerator()
generator.walk(tree)
return generator.classifications
}

/// Only tokens within this absolute range will be classified. No
/// classifications will be reported for tokens out of this range.
private var targetRange: ByteSourceRange

var classifications: [SyntaxClassifiedRange]

/// Only classify tokens in `relativeClassificationRange`, where the start
/// offset is relative to `node`.
init(node: Syntax, relativeClassificationRange: ByteSourceRange) {
let range = ByteSourceRange(
offset: node.position.utf8Offset + relativeClassificationRange.offset,
length: relativeClassificationRange.length
)
self.targetRange = range
self.classifications = []

// `withExtendedLifetime` to make sure ``SyntaxArena`` for the node alive
// during the visit.
withExtendedLifetime(node) {
_ = self.visit(
Descriptor(
node: node.raw,
byteOffset: node.position.utf8Offset,
contextualClassification: node.contextualClassification
)
)

private func classify(triviaPieces: [RawTriviaPiece], at offset: Int) {
var classifiedBytes = 0
for triviaPiece in triviaPieces {
let range = triviaPiece.classify(offset: offset + classifiedBytes)
report(range: range)
classifiedBytes += triviaPiece.byteLength
}
}

private func classify(offset: Int, length: Int, as kind: SyntaxClassification) {
let range = SyntaxClassifiedRange(kind: kind, range: ByteSourceRange(offset: offset, length: length))
report(range: range)
}

private mutating func report(range: SyntaxClassifiedRange) {
if range.kind == .none && range.length == 0 {
public func classify(_ node: some SyntaxProtocol, as kind: SyntaxClassification) {
guard let token = node.as(TokenSyntax.self),
token.tokenKind.decomposeToRaw().rawKind == .identifier else {
return
}
let range = SyntaxClassifiedRange(kind: kind, range: ByteSourceRange(offset: node.positionAfterSkippingLeadingTrivia.utf8Offset, length: node.trimmedLength.utf8Length))
report(range: range)
}

private func report(range: SyntaxClassifiedRange) {
// TODO: we should not report the ranges that have been reported already
//if classifications.contains(where: { $0.range.contains(range.range)}) {
// return
//}

if range.length == 0 {
return
}


// Merge consecutive classified ranges of the same kind.
if let last = classifications.last,
Expand All @@ -151,127 +166,101 @@ private struct ClassificationVisitor {
)
return
}

guard range.offset <= targetRange.endOffset,
range.endOffset >= targetRange.offset
else {
return
}

// TODO: add targetRange for ClassificationGenerator
/*
guard range.offset <= targetRange.endOffset,
range.endOffset >= targetRange.offset
else {
return
}
*/
//classifications.sort(by: {$0.offset < $1.offset})
classifications.append(range)
}

/// Classifies `triviaPieces` starting from `offset` and returns the number of bytes the trivia took up in the source
private mutating func classify(triviaPieces: [RawTriviaPiece], at offset: Int) -> Int {
var classifiedBytes = 0
for triviaPiece in triviaPieces {
let range = triviaPiece.classify(offset: offset + classifiedBytes)
report(range: range)
classifiedBytes += triviaPiece.byteLength
}
return classifiedBytes

override func visit(_ node: TokenSyntax) -> SyntaxVisitorContinueKind {
let tokenView = node.tokenView
// Leading trivia
classify(triviaPieces: tokenView.leadingRawTriviaPieces, at: node.position.utf8Offset)
// Token text
let range = SyntaxClassifiedRange(kind: node.contextFreeClassification(), range: ByteSourceRange(offset: node.positionAfterSkippingLeadingTrivia.utf8Offset, length: node.trimmedLength.utf8Length))
report(range: range)
// Trailing trivia
classify(triviaPieces: tokenView.trailingRawTriviaPieces, at: node.endPositionBeforeTrailingTrivia.utf8Offset)
return .visitChildren
}

override func visit(_ node: AttributeSyntax) -> SyntaxVisitorContinueKind {
let atSignOffset = node.atSign.position.utf8Offset
let length = node.atSign.totalLength - node.atSign.leadingTriviaLength + node.attributeName.totalLength - node.attributeName.trailingTriviaLength
classify(offset: atSignOffset, length: length.utf8Length, as: .attribute)
return .visitChildren
}

// Report classification ranges in `descriptor.node` that is a token.
private mutating func handleToken(_ descriptor: Descriptor) -> VisitResult {
let tokenView = descriptor.node.tokenView!
var byteOffset = descriptor.byteOffset
override func visit(_ node: DeclModifierSyntax) -> SyntaxVisitorContinueKind {
classify(node.name, as: .attribute)
return .visitChildren
}

// Leading trivia.
byteOffset += classify(triviaPieces: tokenView.leadingRawTriviaPieces, at: byteOffset)
// Token text.
do {
let range = TokenKindAndText(kind: tokenView.rawKind, text: tokenView.rawText)
.classify(offset: byteOffset, contextualClassification: descriptor.contextualClassification)
report(range: range)
byteOffset += tokenView.rawText.count
override func visit(_ node: IfConfigClauseSyntax) -> SyntaxVisitorContinueKind {
classify(node.poundKeyword, as: .ifConfigDirective)
if let condition = node.condition {
classify(condition, as: .ifConfigDirective)
}
// Trailing trivia.
byteOffset += classify(triviaPieces: tokenView.trailingRawTriviaPieces, at: byteOffset)
return .visitChildren
}

precondition(byteOffset == descriptor.byteOffset + descriptor.node.byteLength)
return .continue
override func visit(_ node: IfConfigDeclSyntax) -> SyntaxVisitorContinueKind {
classify(node.poundEndif, as: .ifConfigDirective)
return .visitChildren
}

/// Call `visit()` on all `descriptor.node` non-nil children.
private mutating func handleLayout(_ descriptor: Descriptor) -> VisitResult {
let children = descriptor.node.layoutView!.children
var byteOffset = descriptor.byteOffset
override func visit(_ node: MemberTypeSyntax) -> SyntaxVisitorContinueKind {
classify(node.name, as: .type)
return .visitChildren
}

for case (let index, let child?) in children.enumerated() {
override func visit(_ node: OperatorDeclSyntax) -> SyntaxVisitorContinueKind {
classify(node.name, as: .operator)
return .visitChildren
}

let classification: (classification: SyntaxClassification, force: Bool)?
if case .layout(let layout) = descriptor.node.kind.syntaxNodeType.structure {
classification = SyntaxClassification.classify(layout[index])
} else {
classification = nil
}
override func visit(_ node: PlatformVersionSyntax) -> SyntaxVisitorContinueKind {
classify(node.platform, as: .keyword)
return .visitChildren
}

if let classification, classification.force {
// Leading trivia.
if let leadingTriviaPieces = child.leadingTriviaPieces {
byteOffset += classify(triviaPieces: leadingTriviaPieces, at: byteOffset)
}
// Layout node text.
let layoutNodeTextLength = child.byteLength - child.leadingTriviaByteLength - child.trailingTriviaByteLength
let range = SyntaxClassifiedRange(
kind: classification.classification,
range: ByteSourceRange(
offset: byteOffset,
length: layoutNodeTextLength
)
)
report(range: range)
byteOffset += layoutNodeTextLength
override func visit(_ node: PlatformVersionItemSyntax) -> SyntaxVisitorContinueKind {
classify(node.platformVersion, as: .keyword)
return .visitChildren
}

// Trailing trivia.
if let trailingTriviaPieces = child.trailingTriviaPieces {
byteOffset += classify(triviaPieces: trailingTriviaPieces, at: byteOffset)
}
continue
}
override func visit(_ node: PrecedenceGroupAssociativitySyntax) -> SyntaxVisitorContinueKind {
classify(node.associativityLabel, as: .keyword)
return .visitChildren
}

let result = visit(
.init(
node: child,
byteOffset: byteOffset,
contextualClassification: classification ?? descriptor.contextualClassification
)
)
if result == .break {
return .break
}
byteOffset += child.byteLength
}
return .continue
override func visit(_ node: PrecedenceGroupRelationSyntax) -> SyntaxVisitorContinueKind {
classify(node.higherThanOrLowerThanLabel, as: .keyword)
return .visitChildren
}

private mutating func visit(_ descriptor: ClassificationVisitor.Descriptor) -> VisitResult {
guard descriptor.byteOffset < targetRange.endOffset else {
return .break
}
guard descriptor.byteOffset + descriptor.node.byteLength > targetRange.offset else {
return .continue
}
guard SyntaxTreeViewMode.sourceAccurate.shouldTraverse(node: descriptor.node) else {
return .continue
}
if descriptor.node.isToken {
return handleToken(descriptor)
} else {
return handleLayout(descriptor)
}
override func visit(_ node: IdentifierTypeSyntax) -> SyntaxVisitorContinueKind {
classify(node.name, as: .type)
return .visitChildren
}
}


/// Provides a sequence of ``SyntaxClassifiedRange``s for a syntax node.
public struct SyntaxClassifications: Sequence {
public typealias Iterator = Array<SyntaxClassifiedRange>.Iterator

var classifications: [SyntaxClassifiedRange]

public init(_ node: Syntax, in relRange: ByteSourceRange) {
let visitor = ClassificationVisitor(node: node, relativeClassificationRange: relRange)
self.classifications = visitor.classifications
self.classifications = ClassificationGenerator.classify(for: node, in: relRange)
}

public func makeIterator() -> Iterator {
Expand Down
7 changes: 7 additions & 0 deletions Sources/SwiftSyntax/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ public struct ByteSourceRange: Equatable {
return self.endOffset > other.offset && self.offset < other.endOffset
}

public func contains(_ other: ByteSourceRange) -> Bool {
return self.endOffset >= other.endOffset && self.offset <= other.offset
}

public static func == (lhs: ByteSourceRange, rhs: ByteSourceRange) -> Bool {
return lhs.offset == rhs.offset && lhs.endOffset == rhs.endOffset
}
/// Returns the byte range for the overlapping region between two ranges.
public func intersected(_ other: ByteSourceRange) -> ByteSourceRange {
let start = max(self.offset, other.offset)
Expand Down
5 changes: 4 additions & 1 deletion Sources/swift-parser-cli/Commands/PerformanceTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import ArgumentParser
import Foundation
import SwiftParser
import SwiftSyntax
import SwiftIDEUtils

struct PerformanceTest: ParsableCommand {
static var configuration = CommandConfiguration(
Expand Down Expand Up @@ -66,10 +67,12 @@ struct PerformanceTest: ParsableCommand {
for _ in 0..<self.iterations {
for file in files {
file.withUnsafeBytes { buf in
_ = Parser.parseIncrementally(
let (tree, _) = Parser.parseIncrementally(
source: buf.bindMemory(to: UInt8.self),
parseTransition: fileTransition[file]
)
// TODO: temp test
tree.classifications
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion Tests/SwiftIDEUtilsTest/Assertions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func assertClassification(
} else {
classifications = Array(tree.classifications)
}
classifications = classifications.filter { $0.kind != .none }
classifications = classifications.filter { $0.kind != .none }.sorted(by: {$0.offset < $1.offset})

if expected.count != classifications.count {
XCTFail("Expected \(expected.count) re-used nodes but received \(classifications.count)", file: file, line: line)
Expand Down