In [None]:
#| default_exp core

In [None]:
#| export

class Parser:
    """Cursor over an indexable input (e.g. a string) consumed by the combinators.

    ``head`` is the index of the next unread element. Combinators save and
    reassign it to backtrack after a failed attempt.
    """

    def __init__(self, input):
        self.input = input
        self.head = 0
        self.length = len(input)  # cached once at construction

    def peek(self):
        """Return the element at ``head`` without consuming it.

        Raises IndexError when ``head`` is at or past the end of the input.
        """
        return self.input[self.head]

    def next(self):
        """Consume and return the element at ``head``.

        Raises IndexError at end of input; ``head`` is then left unchanged.
        """
        ret = self.peek()
        self.head += 1
        return ret

    def nextn(self, n):
        """Consume and return the next ``n`` elements as a single slice.

        Raises IndexError("not enough input") if fewer than ``n`` elements
        remain, leaving ``head`` unchanged. IndexError matches what ``peek``
        and ``next`` raise at end of input, and it subclasses Exception, so
        existing ``except Exception`` handlers still catch it.
        """
        ret = self.input[self.head:self.head + n]
        if len(ret) != n:
            raise IndexError("not enough input")
        self.head += n
        return ret

    def r(self):
        """Rewind the cursor to the start of the input."""
        self.head = 0

    def step(self):
        """Advance the cursor by one element without reading it."""
        self.head += 1

    def __len__(self):
        return self.length

    def __repr__(self):
        # Consumed part, then "|", then the unread remainder.
        return f"{self.input[:self.head]}|{self.input[self.head:]}"


In [None]:
# Show a fresh cursor; its repr marks the read position with "|".
Parser(input="asdf")

|asdf

In [None]:
#| export
class PRV():
    """Parse result: success flag ``s``, parsed value ``v``, error message ``e``.

    An instance is truthy exactly when the parse succeeded, so combinators
    can test results with ``if res:``.
    """
    def __init__(self, v='', s=True, e=''):
        self.s = s  # success flag
        self.v = v  # parsed value (meaningful on success)
        self.e = e  # error description (set on failure)

    def __bool__(self):
        # __bool__ must return a real bool; coerce so e.g. s=1 does not raise.
        return bool(self.s)

    def __repr__(self): return f"({self.s},{self.v},{self.e})"

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of raising AttributeError on `other.v`.
        if not isinstance(other, PRV):
            return NotImplemented
        return self.v == other.v and self.s == other.s and self.e == other.e

    # Results are mutable (combinators reassign .v), so keep them unhashable;
    # this matches what defining __eq__ does implicitly.
    __hash__ = None

In [None]:
#| export
def char():
    """Return a parser that consumes exactly one element of the input.

    On success the consumed element is the result value. At end of input the
    parser fails with ``e='end of input'`` and consumes nothing.
    """
    def p(input):
        try:
            n = input.next()
        except IndexError:  # raised by indexing past the end of the input
            return PRV(s=False, e='end of input')
        return PRV(n)

    return p

In [None]:
#| export
def nchar(n):
    """Return a parser that consumes exactly ``n`` elements as one slice.

    If fewer than ``n`` elements remain, it fails with ``e='end of input'``.
    """
    def p(input):
        try:
            return PRV(input.nextn(n))
        # nextn may signal a short read with a plain Exception, so catch
        # Exception rather than using a bare except that would also swallow
        # KeyboardInterrupt/SystemExit.
        except Exception:
            return PRV(s=False, e='end of input')
    return p
                    

In [None]:
# nchar(3) takes the first three characters as one value.
assert nchar(3)(Parser(input="abcdef")) == PRV(v="abc")

True

In [None]:
# char() yields the single available character.
assert char()(Parser(input="1")) == PRV(v='1')

In [None]:
#| export
def satisfy(parser, acceptor):
    """Keep ``parser``'s result only if ``acceptor(result.v)`` is truthy.

    A failure from ``parser`` is passed through unchanged. If the acceptor
    rejects the value, the input position is rewound to where this parser
    started and a ``'satisfy failed'`` result is returned.
    """
    def p(input):
        start = input.head
        outcome = parser(input)
        if not outcome:
            return outcome
        if not acceptor(outcome.v):
            input.head = start  # undo what the inner parser consumed
            return PRV(s=False, e='satisfy failed')
        return outcome

    return p

In [None]:
#| export
def digit():
    """Return a parser matching one decimal digit character, 0 through 9."""
    decimal_digits = "0123456789"
    return satisfy(char(), decimal_digits.__contains__)


def ascii_letter():
    """Return a parser matching one ASCII letter, a-z or A-Z."""
    lowercase = "abcdefghijklmnopqrstuvwxyz"
    letters = lowercase + lowercase.upper()
    return satisfy(char(), letters.__contains__)

In [None]:
# digit() accepts a digit and rejects a letter.
assert digit()(Parser(input="1")) == PRV(v='1')
assert digit()(Parser(input="a")) == PRV(s=False, e='satisfy failed')

In [None]:
# ascii_letter() rejects a digit and accepts a letter.
assert ascii_letter()(Parser(input="1")) == PRV(s=False, e='satisfy failed')
assert ascii_letter()(Parser(input="a")) == PRV(v='a')

In [None]:
#| export
def many(parser):
    """Apply ``parser`` zero or more times; always succeeds with a list of values.

    Repetition stops at the first failure, and the input position is rewound
    to just after the last successful match. A success that consumes no input
    also ends the repetition, and its value is not recorded. Without that
    guard, a zero-width parser such as ``ws()`` (which always succeeds) would
    loop forever.
    """
    def p(input):
        acc = []
        head = input.head
        while res := parser(input):
            if input.head == head:  # no progress: stop instead of spinning
                break
            head = input.head
            acc.append(res.v)
        input.head = head  # discard anything consumed by the failed attempt
        return PRV(acc)

    return p

In [None]:
# many() collects every leading match, and succeeds with [] when none match.
assert many(digit())(Parser(input="123abc")) == PRV(v=['1', '2', '3'])
assert many(digit())(Parser(input="abc")) == PRV(v=[])

In [None]:
#| export
def sequence(parsers):
    """Run ``parsers`` in order; succeed only if every one of them succeeds.

    On success the value is the list of sub-result values with placeholders
    removed. Placeholders are ``None`` and empty sized values such as ``''``
    (from ``drop``) and ``[]`` (from ``ws_``/``nl_`` or an empty ``many``).
    Other values are kept, including falsy ones like ``0`` or ``False``. On the
    first failure the input position is rewound to where the sequence started
    and a ``'sequence failed'`` result is returned.
    """
    def is_placeholder(value):
        if value is None:
            return True
        try:
            return len(value) == 0
        except TypeError:  # unsized value such as an int: always keep it
            return False

    def p(input):
        head = input.head
        acc = []
        for parser in parsers:
            res = parser(input)
            if not res:
                input.head = head
                return PRV(s=False, e="sequence failed")
            acc.append(res.v)
        return PRV([x for x in acc if not is_placeholder(x)])

    return p

In [None]:
# sequence() succeeds only when every parser matches in order.
assert sequence([digit(), digit(), digit()])(Parser(input="123")) == PRV(v=['1', '2', '3'])
assert sequence([digit(), ascii_letter(), digit()])(Parser(input="1a3")) == PRV(v=['1', 'a', '3'])

In [None]:
#| export
def mapper(parser, funcs):
    """Run ``parser`` and, on success, pass its value through ``funcs`` in order.

    The result object returned by ``parser`` is updated in place (its ``v`` is
    reassigned after each function) and then returned. A failed result is
    returned untouched.
    """
    def p(input):
        outcome = parser(input)
        if not outcome:
            return outcome
        for fn in funcs:
            outcome.v = fn(outcome.v)
        return outcome

    return p

In [None]:
#| export
def accumulator(parser, acc):
    """Apply ``parser`` repeatedly, calling ``acc.add`` with each parsed value.

    ``acc`` must provide an ``add`` method (presumably a set). It is mutated in
    place and returned as the value of an always-successful result.
    Repetition stops at the first failure, with the input position rewound to
    before that attempt. It also stops at the first success that consumes no
    input, whose value is not added; otherwise that case would loop forever.

    NOTE(review): the same ``acc`` object is reused by every call of the
    returned parser, so values accumulate across calls. Confirm this is
    intended.
    """
    def p(input):
        while True:
            h = input.head
            r = parser(input)
            if not r:
                input.head = h
                break
            if input.head == h:  # zero-width success: no progress possible
                break
            acc.add(r.v)
        return PRV(acc)
    return p

In [None]:
#| export
def digits():
    """Return a parser for one or more decimal digits, yielding them as one string."""
    def glue(parts):
        # parts is [first] or [first, [more, ...]], depending on whether
        # the many() part matched anything.
        if len(parts) != 1:
            return ''.join([parts[0]] + parts[1])
        return parts[0]

    one_or_more = sequence([digit(), many(digit())])
    return mapper(one_or_more, [glue])

In [None]:
# digits() joins one or more digits into a string and fails on a non-digit.
assert digits()(Parser(input="1")) == PRV(v='1')
assert digits()(Parser(input="123")) == PRV(v='123')
assert digits()(Parser(input="abc")) == PRV(s=False, e='sequence failed')

In [None]:
#| export
def integer():
    """Return a parser for an unsigned decimal integer, yielding an ``int``."""
    to_int = int
    return mapper(digits(), [to_int])

In [None]:
# integer() converts the matched digits to an int.
assert integer()(Parser(input="123")) == PRV(v=123)

In [None]:
#| export
def ws():
    """Return a parser for zero or more spaces/tabs; always succeeds with a list."""
    blanks = " \t"

    def is_blank(token):
        return token[0] in blanks

    return many(satisfy(char(), is_blank))

def ws_():
    """Return a parser like ``ws()`` whose result value is always set to ``[]``.

    The whitespace is still consumed. The empty value lets the match be
    discarded by callers that filter out empty values.
    """
    whitespace = ws()

    def p(input):
        outcome = whitespace(input)
        outcome.v = []
        return outcome
    return p
    

In [None]:
#| export
def choice(parsers):
    """Try ``parsers`` in order and return the first successful result.

    Every alternative starts from the same input position: after a failed
    alternative the position is rewound before the next one is tried. If
    all of them fail, the position is left where it started and a
    ``'choice failed'`` result is returned.
    """
    def p(input):
        h = input.head
        for parser in parsers:
            r = parser(input)
            if r:
                return r
            input.head = h  # undo anything the failed alternative consumed
        return PRV(s=False, e="choice failed")
    return p

In [None]:
# The first alternative (digit) matches.
assert choice([digit(), ascii_letter()])(Parser(input="1")) == PRV(v="1")

In [None]:
# The digit alternative fails, so the letter alternative is used.
assert choice([digit(), ascii_letter()])(Parser(input="a")) == PRV(v="a")

In [None]:
#| export
def nl():
    """Return a parser that matches a single newline character."""
    def is_newline(token):
        return token == "\n"

    return satisfy(char(), is_newline)

def nl_():
    """Return a parser like ``nl()`` whose result value is always set to ``[]``.

    The value is replaced even when the match fails; the success flag and
    error message of the ``nl()`` result are kept.
    """
    newline = nl()

    def p(input):
        outcome = newline(input)
        outcome.v = []
        return outcome
    return p


In [None]:
#| export
def find(parser):
    """Scan forward for the first position where ``parser`` succeeds.

    Each candidate position is tried with ``input.head`` set to it, from the
    current position up to and including ``len(input)``, so a zero-width
    parser can still match at end of input. Each candidate starts fresh,
    regardless of how much a failed attempt consumed. On success the parser's
    result is returned and the position is left after the match. On failure
    the position is restored to where the scan began.
    """
    def p(input):
        start = input.head
        pos = start
        while pos <= len(input):
            input.head = pos
            r = parser(input)
            if r:
                return r
            pos += 1
        input.head = start
        return PRV(s=False, e="find failed")
    return p

In [None]:
# find() skips the leading letters until digit() matches.
assert find(digit())(Parser(input="asdcc1")) == PRV(v="1")

In [None]:
#| export

def drop(parser):
    """Run ``parser`` but throw away its value on success.

    Success yields ``PRV('')``, an empty value that ``sequence`` filters out.
    A failure is returned unchanged.
    """
    def p(input):
        result = parser(input)
        return PRV('') if result else result
    return p

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()  # nbdev: write the `#| export` cells out as library code (module set by `#| default_exp core` above)