Commit
Fix #1217
The root cause was that ATNConfigSet was not using the required custom hashing strategy for ParserATNSimulator.
The commit includes a number of additional fixes to code that was never executed before because of the root cause.
A similar issue likely exists in the JavaScript runtime; I'll fix it later.
ericvergnaud committed Jun 23, 2016
1 parent 15430d4 commit 26c4091
Showing 12 changed files with 259 additions and 205 deletions.
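As background for the diff below: the "custom hashing strategy" from the commit message means that the set of ATN configurations must not bucket its elements by their ordinary hash and equality, but by a purpose-built hash/equals pair supplied for the simulator. A minimal sketch of such a strategy-driven set follows; StrategyHashSet, custom_hash and custom_equals are illustrative names, not ANTLR APIs.

# Minimal sketch of a hash set driven by a pluggable hashing strategy (illustrative only).
class StrategyHashSet(object):
    def __init__(self, custom_hash, custom_equals):
        self.custom_hash = custom_hash
        self.custom_equals = custom_equals
        self.buckets = {}   # strategy hash -> list of stored elements

    def get_or_add(self, element):
        bucket = self.buckets.setdefault(self.custom_hash(element), [])
        for existing in bucket:
            if self.custom_equals(existing, element):
                return existing          # an equivalent element is already stored
        bucket.append(element)
        return element

# Example strategy: ignore the "context" field when bucketing, so that entries
# differing only in context resolve to a single stored element.
configs = StrategyHashSet(
    custom_hash=lambda c: hash((c["state"], c["alt"])),
    custom_equals=lambda x, y: (x["state"], x["alt"]) == (y["state"], y["alt"]),
)
first = configs.get_or_add({"state": 7, "alt": 1, "context": "A"})
again = configs.get_or_add({"state": 7, "alt": 1, "context": "B"})
assert again is first

The per-config hashCodeForConfigSet/equalsForConfigSet methods added in this commit provide exactly that hook for ATNConfigSet.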
73 changes: 38 additions & 35 deletions runtime/Python2/src/antlr4/PredictionContext.py
@@ -30,6 +30,7 @@
#/
from io import StringIO
from antlr4.RuleContext import RuleContext
from antlr4.atn.ATN import ATN
from antlr4.atn.ATNState import ATNState


@@ -98,7 +99,7 @@ def calculateHashCode(parent, returnState):

def calculateListsHashCode(parents, returnStates ):
h = 0
for parent, returnState in parents, returnStates:
for parent, returnState in zip(parents, returnStates):
h = hash((h, calculateHashCode(parent, returnState)))
return h
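The loop-header fix above is needed because iterating over the bare 2-tuple (parents, returnStates) tries to unpack each list itself, whereas zip() pairs the two lists element by element. A tiny illustration with toy values:

# Toy values only, to show what the changed loop header does.
parents = ["p0", "p1", "p2"]
returnStates = [10, 20, 30]

# Old header, "for parent, returnState in parents, returnStates:", iterates over the
# 2-tuple (parents, returnStates) and tries to unpack each *list* into two names;
# that only works by accident for two-element lists. zip() pairs them element-wise:
pairs = [(parent, returnState) for parent, returnState in zip(parents, returnStates)]
assert pairs == [("p0", 10), ("p1", 20), ("p2", 30)]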

@@ -254,6 +255,10 @@ def __unicode__(self):
buf.write(u"]")
return buf.getvalue()

def __hash__(self):
return self.cachedHashCode
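A likely reason for adding __hash__ here (an assumption, not stated in the commit): prediction contexts end up inside dict keys, such as the (a, b) tuples used for mergeCache later in this file, and dict lookups only hit when __hash__ is consistent with __eq__; returning the precomputed cachedHashCode also keeps repeated hashing cheap. A small stand-alone sketch of that consistency requirement, with a stand-in class:

# Stand-in class; assumption-labelled illustration of hash/eq consistency for dict keys.
class Cached(object):
    def __init__(self, items):
        self.items = tuple(items)
        self.cachedHashCode = hash(self.items)
    def __eq__(self, other):
        return isinstance(other, Cached) and self.items == other.items
    def __hash__(self):
        return self.cachedHashCode   # consistent with __eq__, computed once

cache = {}
cache[(Cached([1, 2]), Cached([3]))] = "merged"
assert (Cached([1, 2]), Cached([3])) in cache   # equal keys hash equally, so the lookup hits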



# Convert a {@link RuleContext} tree to a {@link PredictionContext} graph.
# Return {@link #EMPTY} if {@code outerContext} is empty or null.
@@ -328,18 +333,18 @@ def merge(a, b, rootIsWildcard, mergeCache):
#/
def mergeSingletons(a, b, rootIsWildcard, mergeCache):
if mergeCache is not None:
previous = mergeCache.get(a,b)
previous = mergeCache.get((a,b), None)
if previous is not None:
return previous
previous = mergeCache.get(b,a)
previous = mergeCache.get((b,a), None)
if previous is not None:
return previous

rootMerge = mergeRoot(a, b, rootIsWildcard)
if rootMerge is not None:
merged = mergeRoot(a, b, rootIsWildcard)
if merged is not None:
if mergeCache is not None:
mergeCache.put(a, b, rootMerge)
return rootMerge
mergeCache[(a, b)] = merged
return merged

if a.returnState==b.returnState:
parent = merge(a.parentCtx, b.parentCtx, rootIsWildcard, mergeCache)
@@ -352,10 +357,10 @@ def mergeSingletons(a, b, rootIsWildcard, mergeCache):
# merge parents x and y, giving array node with x,y then remainders
# of those graphs. dup a, a' points at merged array
# new joined parent so create new singleton pointing to it, a'
a_ = SingletonPredictionContext.create(parent, a.returnState)
merged = SingletonPredictionContext.create(parent, a.returnState)
if mergeCache is not None:
mergeCache.put(a, b, a_)
return a_
mergeCache[(a, b)] = merged
return merged
else: # a != b payloads differ
# see if we can collapse parents due to $+x parents if local ctx
singleParent = None
@@ -365,26 +370,24 @@ def mergeSingletons(a, b, rootIsWildcard, mergeCache):
# sort payloads and use same parent
payloads = [ a.returnState, b.returnState ]
if a.returnState > b.returnState:
payloads[0] = b.returnState
payloads[1] = a.returnState
payloads = [ b.returnState, a.returnState ]
parents = [singleParent, singleParent]
a_ = ArrayPredictionContext(parents, payloads)
merged = ArrayPredictionContext(parents, payloads)
if mergeCache is not None:
mergeCache.put(a, b, a_)
return a_
mergeCache[(a, b)] = merged
return merged
# parents differ and can't merge them. Just pack together
# into array; can't merge.
# ax + by = [ax,by]
payloads = [ a.returnState, b.returnState ]
parents = [ a.parentCtx, b.parentCtx ]
if a.returnState > b.returnState: # sort by payload
payloads[0] = b.returnState
payloads[1] = a.returnState
payloads = [ b.returnState, a.returnState ]
parents = [ b.parentCtx, a.parentCtx ]
a_ = ArrayPredictionContext(parents, payloads)
merged = ArrayPredictionContext(parents, payloads)
if mergeCache is not None:
mergeCache.put(a, b, a_)
return a_
mergeCache[(a, b)] = merged
return merged
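The cache handling above also changes shape: mergeCache is now used as a plain dict keyed by the (a, b) pair rather than through two-argument get/put calls, which a dict does not provide. A short sketch of the new access pattern with stand-in values:

# Sketch of the dict-based mergeCache pattern (stand-in keys and result).
mergeCache = {}
a, b = ("ctx-a",), ("ctx-b",)            # stand-ins for PredictionContext objects

previous = mergeCache.get((a, b), None)  # dict.get takes a single key, here a tuple
if previous is None:
    previous = mergeCache.get((b, a), None)
if previous is None:
    merged = ("merged", a, b)            # stand-in for the real merge result
    mergeCache[(a, b)] = merged          # store under the tuple key
else:
    merged = previous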


#
@@ -466,10 +469,10 @@ def mergeRoot(a, b, rootIsWildcard):
#/
def mergeArrays(a, b, rootIsWildcard, mergeCache):
if mergeCache is not None:
previous = mergeCache.get(a,b)
previous = mergeCache.get((a,b), None)
if previous is not None:
return previous
previous = mergeCache.get(b,a)
previous = mergeCache.get((b,a), None)
if previous is not None:
return previous

@@ -478,8 +481,8 @@ def mergeArrays(a, b, rootIsWildcard, mergeCache):
j = 0 # walks b
k = 0 # walks target M array

mergedReturnStates = [] * (len(a.returnState) + len( b.returnStates))
mergedParents = [] * len(mergedReturnStates)
mergedReturnStates = [None] * (len(a.returnStates) + len( b.returnStates))
mergedParents = [None] * len(mergedReturnStates)
# walk and merge to yield mergedParents, mergedReturnStates
while i<len(a.returnStates) and j<len(b.returnStates):
a_parent = a.parents[i]
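The preallocation change earlier in this hunk matters because multiplying an empty list yields an empty list, so the indexed writes into mergedParents and mergedReturnStates further down would raise IndexError; [None] * n actually allocates n slots. A short demonstration:

# Why [] * n had to become [None] * n: only the latter allocates writable slots.
n = 4
assert [] * n == []                      # old form: still empty, nothing to index into
buf = [None] * n                         # new form: n slots
buf[3] = "ok"
assert buf == [None, None, None, "ok"]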
@@ -525,30 +528,30 @@ def mergeArrays(a, b, rootIsWildcard, mergeCache):
# trim merged if we combined a few that had same stack tops
if k < len(mergedParents): # write index < last position; trim
if k == 1: # for just one merged element, return singleton top
a_ = SingletonPredictionContext.create(mergedParents[0], mergedReturnStates[0])
merged = SingletonPredictionContext.create(mergedParents[0], mergedReturnStates[0])
if mergeCache is not None:
mergeCache.put(a,b,a_)
return a_
mergeCache[(a,b)] = merged
return merged
mergedParents = mergedParents[0:k]
mergedReturnStates = mergedReturnStates[0:k]

M = ArrayPredictionContext(mergedParents, mergedReturnStates)
merged = ArrayPredictionContext(mergedParents, mergedReturnStates)

# if we created same array as a or b, return that instead
# TODO: track whether this is possible above during merge sort for speed
if M==a:
if merged==a:
if mergeCache is not None:
mergeCache.put(a,b,a)
mergeCache[(a,b)] = a
return a
if M==b:
if merged==b:
if mergeCache is not None:
mergeCache.put(a,b,b)
mergeCache[(a,b)] = b
return b
combineCommonParents(mergedParents)

if mergeCache is not None:
mergeCache.put(a,b,M)
return M
mergeCache[(a,b)] = merged
return merged


#
@@ -642,6 +645,6 @@ def getAllContextNodes(context, nodes=None, visited=None):
visited.put(context, context)
nodes.add(context)
for i in range(0, len(context)):
getAllContextNodes(context.getParent(i), nodes, visited);
getAllContextNodes(context.getParent(i), nodes, visited)
return nodes

25 changes: 25 additions & 0 deletions runtime/Python2/src/antlr4/atn/ATNConfig.py
@@ -95,6 +95,19 @@ def __eq__(self, other):
def __hash__(self):
return hash((self.state.stateNumber, self.alt, self.context, self.semanticContext))

def hashCodeForConfigSet(self):
return hash((self.state.stateNumber, self.alt, hash(self.semanticContext)))

def equalsForConfigSet(self, other):
if self is other:
return True
elif not isinstance(other, ATNConfig):
return False
else:
return self.state.stateNumber==other.state.stateNumber \
and self.alt==other.alt \
and self.semanticContext==other.semanticContext

def __str__(self):
return unicode(self)

@@ -144,6 +157,18 @@ def __eq__(self, other):
return False
return super(LexerATNConfig, self).__eq__(other)



def hashCodeForConfigSet(self):
return hash(self)



def equalsForConfigSet(self, other):
return self==other



def checkNonGreedyDecision(self, source, target):
return source.passedThroughNonGreedyDecision \
or isinstance(target, DecisionState) and target.nonGreedy
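The two hooks deliberately differ by config kind: ATNConfig drops the prediction context from its config-set identity, while LexerATNConfig (above) simply defers to its full hash and equality. A toy contrast using stand-in classes, not the runtime types:

# Stand-in classes only, to show which fields enter the config-set identity.
class ToyParserConfig(object):
    def __init__(self, state, alt, sem, context):
        self.state, self.alt, self.sem, self.context = state, alt, sem, context
    def __hash__(self):
        return hash((self.state, self.alt, self.context, self.sem))   # full identity
    def hashCodeForConfigSet(self):
        return hash((self.state, self.alt, self.sem))                 # context ignored

class ToyLexerConfig(ToyParserConfig):
    def hashCodeForConfigSet(self):
        return hash(self)                                             # full identity kept

p1 = ToyParserConfig(3, 1, None, "ctx-A")
p2 = ToyParserConfig(3, 1, None, "ctx-B")
assert p1.hashCodeForConfigSet() == p2.hashCodeForConfigSet()   # same bucket despite differing contexts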
8 changes: 4 additions & 4 deletions runtime/Python2/src/antlr4/atn/ATNConfigSet.py
@@ -105,8 +105,8 @@ def add(self, config, mergeCache=None):
rootIsWildcard = not self.fullCtx
merged = merge(existing.context, config.context, rootIsWildcard, mergeCache)
# no need to check for existing.context, config.context in cache
# since only way to create new graphs is "call rule" and here. We
# cache at both places.
# since only way to create new graphs is "call rule" and here.
# We cache at both places.
existing.reachesIntoOuterContext = max(existing.reachesIntoOuterContext, config.reachesIntoOuterContext)
# make sure to preserve the precedence filter suppression during the merge
if config.precedenceFilterSuppressed:
Expand All @@ -115,11 +115,11 @@ def add(self, config, mergeCache=None):
return True

def getOrAdd(self, config):
h = hash(config)
h = config.hashCodeForConfigSet()
l = self.configLookup.get(h, None)
if l is not None:
for c in l:
if c==config:
if config.equalsForConfigSet(c):
return c
if l is None:
l = [config]
36 changes: 18 additions & 18 deletions runtime/Python2/src/antlr4/atn/LexerATNSimulator.py
@@ -130,7 +130,7 @@ def reset(self):
def matchATN(self, input):
startState = self.atn.modeToStartState[self.mode]

if self.debug:
if LexerATNSimulator.debug:
print("matchATN mode " + str(self.mode) + " start: " + str(startState))

old_mode = self.mode
@@ -144,13 +144,13 @@ def matchATN(self, input):

predict = self.execATN(input, next)

if self.debug:
if LexerATNSimulator.debug:
print("DFA after matchATN: " + str(self.decisionToDFA[old_mode].toLexerString()))

return predict

def execATN(self, input, ds0):
if self.debug:
if LexerATNSimulator.debug:
print("start state closure=" + str(ds0.configs))

if ds0.isAcceptState:
@@ -161,8 +161,8 @@ def execATN(self, input, ds0):
s = ds0 # s is current/from DFA state

while True: # while more work
if self.debug:
print("execATN loop starting closure: %s\n", s.configs)
if LexerATNSimulator.debug:
print("execATN loop starting closure:", str(s.configs))

# As we move src->trg, src->trg, we keep track of the previous trg to
# avoid looking up the DFA state again, which is expensive.
@@ -223,8 +223,8 @@ def getExistingTargetState(self, s, t):
return None

target = s.edges[t - self.MIN_DFA_EDGE]
if self.debug and target is not None:
print("reuse state "+s.stateNumber+ " edge to "+target.stateNumber)
if LexerATNSimulator.debug and target is not None:
print("reuse state", str(s.stateNumber), "edge to", str(target.stateNumber))

return target

@@ -280,8 +280,8 @@ def getReachableConfigSet(self, input, closure, reach, t):
if currentAltReachedAcceptState and cfg.passedThroughNonGreedyDecision:
continue

if self.debug:
print("testing %s at %s\n", self.getTokenName(t), cfg.toString(self.recog, True))
if LexerATNSimulator.debug:
print("testing", self.getTokenName(t), "at", str(cfg))

for trans in cfg.state.transitions: # for each transition
target = self.getReachableTarget(trans, t)
@@ -298,8 +298,8 @@ def getReachableConfigSet(self, input, closure, reach, t):
skipAlt = cfg.alt

def accept(self, input, lexerActionExecutor, startIndex, index, line, charPos):
if self.debug:
print("ACTION %s\n", lexerActionExecutor)
if LexerATNSimulator.debug:
print("ACTION", lexerActionExecutor)

# seek to after last char in token
input.seek(index)
@@ -334,15 +334,15 @@ def computeStartState(self, input, p):
# {@code false}.
def closure(self, input, config, configs, currentAltReachedAcceptState,
speculative, treatEofAsEpsilon):
if self.debug:
print("closure("+config.toString(self.recog, True)+")")
if LexerATNSimulator.debug:
print("closure(" + str(config) + ")")

if isinstance( config.state, RuleStopState ):
if self.debug:
if LexerATNSimulator.debug:
if self.recog is not None:
print("closure at %s rule stop %s\n", self.recog.getRuleNames()[config.state.ruleIndex], config)
print("closure at", self.recog.symbolicNames[config.state.ruleIndex], "rule stop", str(config))
else:
print("closure at rule stop %s\n", config)
print("closure at rule stop", str(config))

if config.context is None or config.context.hasEmptyPath():
if config.context is None or config.context.isEmpty():
@@ -404,7 +404,7 @@ def getEpsilonTarget(self, input, config, t, configs, speculative, treatEofAsEps
# states reached by traversing predicates. Since this is when we
# test them, we cannot cash the DFA state target of ID.

if self.debug:
if LexerATNSimulator.debug:
print("EVAL rule "+ str(t.ruleIndex) + ":" + str(t.predIndex))
configs.hasSemanticContext = True
if self.evaluatePredicate(input, t.ruleIndex, t.predIndex, speculative):
@@ -516,7 +516,7 @@ def addDFAEdge(self, from_, tk, to=None, cfgs=None):
# Only track edges within the DFA bounds
return to

if self.debug:
if LexerATNSimulator.debug:
print("EDGE " + str(from_) + " -> " + str(to) + " upon "+ chr(tk))

if from_.edges is None:
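Two recurring changes in this file: the debug flag is now read from the class (LexerATNSimulator.debug) instead of the instance, and debug output no longer passes printf-style format strings to print(), which never interpolates them. A tiny illustration of the formatting point:

# print() does not apply %-formatting to its first argument.
from __future__ import print_function

name = "ID"
print("testing %s\n", name)     # "%s" comes out literally; name is printed as a second argument
print("testing", name)          # prints: testing ID
print("testing %s" % name)      # explicit %-formatting also prints: testing ID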
