Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions Sources/SwiftParser/Core.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,31 @@ public struct CodeContext {
self.errors = errors
self.input = input
}

/// Snapshot represents a parser state that can be restored later.
public struct Snapshot {
fileprivate let index: Int
fileprivate let node: CodeNode
fileprivate let childCount: Int
fileprivate let errorCount: Int
}

/// Capture the current parser state so it can be restored on demand.
public func snapshot() -> Snapshot {
Snapshot(index: index, node: currentNode, childCount: currentNode.children.count, errorCount: errors.count)
}

/// Restore the parser to a previously captured state, discarding any new nodes or errors.
public mutating func restore(_ snapshot: Snapshot) {
index = snapshot.index
currentNode = snapshot.node
if currentNode.children.count > snapshot.childCount {
currentNode.children.removeLast(currentNode.children.count - snapshot.childCount)
}
if errors.count > snapshot.errorCount {
errors.removeLast(errors.count - snapshot.errorCount)
}
}
}

public protocol CodeLanguage {
Expand Down
20 changes: 16 additions & 4 deletions Sources/SwiftParser/Languages/PythonLanguage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public struct PythonLanguage: CodeLanguage {
case identifier(String, Range<String.Index>)
case number(String, Range<String.Index>)
case string(String, Range<String.Index>)
case unterminatedString(String, Range<String.Index>)
case keyword(String, Range<String.Index>)
case equal(Range<String.Index>)
case colon(Range<String.Index>)
Expand All @@ -36,6 +37,7 @@ public struct PythonLanguage: CodeLanguage {
case .identifier: return "identifier"
case .number: return "number"
case .string: return "string"
case .unterminatedString: return "unterminatedString"
case .keyword(let k, _): return "keyword(\(k))"
case .equal: return "="
case .colon: return ":"
Expand All @@ -55,6 +57,8 @@ public struct PythonLanguage: CodeLanguage {
switch self {
case let .identifier(s, _), let .number(s, _), let .string(s, _), let .keyword(s, _):
return s
case let .unterminatedString(s, _):
return s
case .equal: return "="
case .colon: return ":"
case .comma: return ","
Expand All @@ -71,7 +75,7 @@ public struct PythonLanguage: CodeLanguage {

public var range: Range<String.Index> {
switch self {
case .identifier(_, let r), .number(_, let r), .string(_, let r), .keyword(_, let r), .equal(let r),
case .identifier(_, let r), .number(_, let r), .string(_, let r), .unterminatedString(_, let r), .keyword(_, let r), .equal(let r),
.colon(let r), .comma(let r), .plus(let r), .minus(let r), .star(let r), .slash(let r),
.lparen(let r), .rparen(let r), .newline(let r), .eof(let r):
return r
Expand Down Expand Up @@ -112,9 +116,14 @@ public struct PythonLanguage: CodeLanguage {
while index < input.endIndex && input[index] != quote {
advance()
}
if index < input.endIndex { advance() }
let text = String(input[start..<index])
add(.string(text, start..<index))
if index < input.endIndex {
advance()
let text = String(input[start..<index])
add(.string(text, start..<index))
} else {
let text = String(input[start..<index])
add(.unterminatedString(text, start..<index))
}
} else if ch.isLetter || ch == "_" {
let start = index
while index < input.endIndex && (input[index].isLetter || input[index].isNumber || input[index] == "_") {
Expand Down Expand Up @@ -176,6 +185,9 @@ public struct PythonLanguage: CodeLanguage {
return CodeNode(type: Element.number, value: text, range: range)
case .identifier(let text, let range):
return CodeNode(type: Element.identifier, value: text, range: range)
case .unterminatedString(let text, let range):
context.errors.append(CodeError("Unterminated string", range: range))
return CodeNode(type: Element.string, value: text, range: range)
case .lparen:
let node = parse(context: &context, minBP: 0)
if context.index < context.tokens.count, let r = context.tokens[context.index] as? Token, case .rparen = r {
Expand Down
22 changes: 22 additions & 0 deletions Tests/SwiftParserTests/SwiftParserTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,26 @@ final class SwiftParserTests: XCTestCase {

XCTAssertEqual(n1.id, n2.id)
}

func testUnterminatedStringError() {
let parser = SwiftParser()
let source = "x = \"hello"
let result = parser.parse(source, language: PythonLanguage())
XCTAssertEqual(result.errors.count, 1)
}

func testContextSnapshotRestore() {
let tokenizer = PythonLanguage.Tokenizer()
let tokens = tokenizer.tokenize("x = 1")
let root = CodeNode(type: PythonLanguage.Element.root, value: "")
var ctx = CodeContext(tokens: tokens, index: 0, currentNode: root, errors: [], input: "x = 1")
let snap = ctx.snapshot()
ctx.index = 2
ctx.errors.append(CodeError("err"))
ctx.currentNode.addChild(CodeNode(type: PythonLanguage.Element.number, value: "1"))
ctx.restore(snap)
XCTAssertEqual(ctx.index, 0)
XCTAssertEqual(ctx.errors.count, 0)
XCTAssertEqual(root.children.count, 0)
}
}