In [5]:
from dataset.grammar import get_cfg, sample_programs, parse_program

grammar = get_cfg()

programs = []
for program in sample_programs(grammar, n=1000):
    if not parse_program(program):
        program = "# Compiled: ❌\n" + program
        programs.append(program)

print("\n\n=========================\n\n".join(programs))












In [80]:
from dataset.grammar import get_cfg
from nltk.parse import EarleyChartParser, RecursiveDescentParser
from nltk.grammar import Nonterminal
import random
from nltk import Nonterminal
from nltk import CFG

import ast

def parse_program(program):
    try:
        ast.parse(program)
        return True
    except Exception as e:
        return False

def valid_next_tokens(chart, position, grammar):
    """
    Given a chart and the current position in the input,
    return the set of terminals that could appear next.
    """
    candidates = set()
    for edge in chart.select(end=position):
        # edge is something like A -> α • β, [i:j]
        next_sym = edge.nextsym()
        if next_sym is None:
            continue
        if isinstance(next_sym, Nonterminal):
            # Expand FIRST set of this nonterminal
            for prod in grammar.productions(lhs=next_sym):
                first = prod.rhs()[0]
                if isinstance(first, str):  # terminal
                    candidates.add(first)
        else:
            # Directly a terminal
            candidates.add(next_sym)
    return candidates



import random
from nltk import Nonterminal

def generate_random(grammar, symbol=None, max_depth=6):
    """Randomly generate a sentence from a CFG with depth control."""
    if symbol is None:
        symbol = grammar.start()

    # Terminal symbol
    if not isinstance(symbol, Nonterminal):
        return [symbol]

    # Get productions for this symbol
    prods = grammar.productions(lhs=symbol)

    if max_depth <= 0:
        # Only pick productions that do NOT recursively include this Nonterminal
        prods = [
            p for p in prods
            if all(not (isinstance(r, Nonterminal) and r == symbol) for r in p.rhs())
        ]
        if not prods:
            # fallback: pick any production to avoid crash
            prods = grammar.productions(lhs=symbol)

    prod = random.choice(prods)
    result = []
    for r in prod.rhs():
        # decrease depth on recursive non-terminals, else leave depth
        if isinstance(r, Nonterminal):
            result.extend(generate_random(grammar, r, max_depth-1))
        else:
            result.append(r)
    return result


import random
from nltk import Nonterminal

import random
from nltk import Nonterminal

def generate_random_v2(grammar, symbol=None, max_depth=6, defined_vars=None):
    """Randomly generate a sentence from a CFG with depth control and variable tracking."""
    if defined_vars is None:
        defined_vars = set()

    if symbol is None:
        symbol = grammar.start()

    # Terminal symbol
    if not isinstance(symbol, Nonterminal):
        return [symbol]

    # Get productions for this non-terminal
    prods = grammar.productions(lhs=symbol)

    if max_depth <= 0:
        # Filter out self-recursive productions
        prods = [
            p for p in prods
            if all(not (isinstance(r, Nonterminal) and r == symbol) for r in p.rhs())
        ]
        if not prods:
            prods = grammar.productions(lhs=symbol)

    # Pick a random production
    prod = random.choice(prods)
    result = []

    # Track local variable definitions
    local_defined = defined_vars.copy()

    for r in prod.rhs():
        # Special handling: assignments define variables
        if isinstance(r, Nonterminal) and r.symbol() == "ASSIGNMENT":
            stmt_tokens = generate_random_v2(grammar, r, max_depth-1, local_defined)
            # first token is VARIABLE (assume VARIABLE EQUALS ...)
            local_defined.add(stmt_tokens[0])
            result.extend(stmt_tokens)

        elif isinstance(r, Nonterminal) and r.symbol() == "PARAMS":
            params_tokens = generate_random_v2(grammar, r, max_depth-1, local_defined)
            for tok in params_tokens:
                if tok.isalpha():  # simple check for variable name
                    local_defined.add(tok)
            result.extend(params_tokens)

        elif isinstance(r, Nonterminal) and r.symbol() == "VARIABLE":
            # Pick only from defined variables if any
            if local_defined:
                result.append(random.choice(list(local_defined)))
            else:
                # fallback: pick a random single-letter variable
                result.append(random.choice([chr(c) for c in range(ord('a'), ord('z')+1)]))
        else:
            result.extend(generate_random_v2(grammar, r, max_depth-1, local_defined))

    return result




def pretty_print(tokens):
    code = []
    indent = 0
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok == "<NEWLINE>":
            code.append("\n")
        elif tok == "<INDENT>":
            indent += 1
            code.append("    " * indent)
        elif tok == "<DEDENT>":
            indent -= 1
            code.append("    " * indent)
        else:
            # space before normal tokens except after newline/indent
            if code and not code[-1].endswith(("\n", " ", "(", "[")):
                code.append(" ")
            code.append(tok)
        i += 1
    return "".join(code)




# Define a minimal CFG that can generate the 3 given programs

variables = [chr(c) for c in range(ord('a'), ord('z') + 1)]
digits = [str(i) for i in range(0, 21)]
binary_operators = ["+", "-", "*", "/", "<", ">", "<=", ">=", "!=", "==", "and", "or"]
unary_operators = ["not"]

terminal_rules = {
    "VARIABLE": variables,
    "DIGIT": digits,

    # keywords
    "DEF": ["def"],
    "PROGRAM_NAME": ["program"],
    "RETURN": ["return"],
    "IF": ["if"],
    "ELSE": ["else"],

    # syntax
    "LPAREN": ["("],
    "RPAREN": [")"],
    "COMMA": [","],
    "COLON": [":"],
    "EQUALS": ["="],
    "NEWLINE": ["<NEWLINE>"],
    "INDENT": ["<INDENT>"],
    "DEDENT": ["<DEDENT>"],

    # operators
    "ADDOP": ["+", "-"],
    "MULOP": ["*", "/"],
    "BINARY_CMP": ["<", ">", "<=", ">=", "==", "!="],
    "AND": ["and"],
    "OR": ["or"],
    "NOT": ["not"],
}

non_terminal_rules = {
    # Start
    "S": ["FUNC_DEF"],

    # Function definition
    "FUNC_DEF": ["DEF PROGRAM_NAME LPAREN PARAMS RPAREN COLON NEWLINE INDENT BODY DEDENT"],

    # Parameters
    "PARAMS": ["VARIABLE", "VARIABLE COMMA PARAMS"],

    # Function body: either a statement or an if-block
    "BODY": ["STMT", "IF_BLOCK", "ASSIGNMENT"],

    # Assignment and return
    "ASSIGNMENT": ["VARIABLE EQUALS EXPR NEWLINE"],
    "STMT": ["RETURN EXPR NEWLINE"],

    # If/else branching
    "IF_BLOCK": ["IF COND COLON NEWLINE INDENT STMT DEDENT ELSE_BLOCK"],
    "ELSE_BLOCK": ["ELSE COLON NEWLINE INDENT STMT DEDENT"],

    # Conditions (just expressions now, since precedence is layered)
    "COND": ["EXPR"],

    # ---- Expressions with operator precedence ----
    "EXPR": ["OR_EXPR"],

    "OR_EXPR": ["AND_EXPR", "OR_EXPR OR AND_EXPR"],

    "AND_EXPR": ["NOT_EXPR", "AND_EXPR AND NOT_EXPR"],

    "NOT_EXPR": ["NOT NOT_EXPR", "COMPARISON"],

    # Comparisons like a < b <= c
    "COMPARISON": ["ARITH_EXPR", "ARITH_EXPR COMP_CHAIN"],
    "COMP_CHAIN": ["BINARY_CMP ARITH_EXPR", "BINARY_CMP ARITH_EXPR COMP_CHAIN"],

    # Arithmetic expressions
    "ARITH_EXPR": ["TERM", "ARITH_EXPR ADDOP TERM"],
    "TERM": ["FACTOR", "TERM MULOP FACTOR"],

    # Atoms
    "FACTOR": ["VARIABLE", "DIGIT", "LPAREN EXPR RPAREN"],
}



lines = [
]
for lhs, rhs_list in terminal_rules.items():
    for rhs_item in rhs_list:
        lines.append(f"{lhs} -> '{rhs_item}'")

for lhs, rhs in non_terminal_rules.items():
    rhs_str = " | ".join(rhs)
    lines.append(f"{lhs} -> {rhs_str}")

grammar_text = "\n".join(lines)
grammar = CFG.fromstring(grammar_text)
grammar._start = Nonterminal("S")

for i in range(1000):
    tokens = generate_random_v2(grammar, max_depth=6)
    code = pretty_print(tokens)
    print("✅" if parse_program(code) else "❌")

    print(code)
    print("=========================================")


✅
def program (n ) :
    if n >= (n ) :
        return 12
    else :
        return 20 > ((4 != n ) )
    
✅
def program (b ) :
    b = 1 or 18 > b

✅
def program (p , y , q , i ) :
    if y :
        return (((0 == y ) < 12 ) != 17 ) != y
    else :
        return i > 11
    
✅
def program (w ) :
    w = w <= 15 or 14

✅
def program (c ) :
    c = c > c

✅
def program (b , y , l ) :
    if b :
        return 17 <= (15 < y )
    else :
        return 9 > 20
    
✅
def program (a ) :
    a = (16 < a ) or (3 >= (20 > a ) ) != 7

✅
def program (b ) :
    if b :
        return b < ((b ) != 4 )
    else :
        return ((9 ) > 16 )
    
✅
def program (h , x ) :
    return ((x ) ) >= 20

✅
def program (r ) :
    if r > 10 :
        return r
    else :
        return 20
    
✅
def program (u , n ) :
    return 8

✅
def program (u , w , b ) :
    if w <= 20 :
        return 18 == 0
    else :
        return (5 )
    
✅
def program (r , d , k , y ) :
    return k or d == 19

✅
def program (i )

In [None]:
for edge in chart.select(start=0, end=2, lhs="ALL"):
    print(edge)
    print(edge.child_pointer_lists())

In [None]:
for edge in chart.select(start=0, end=2, lhs="ALL"):
    print(edge)
    print(edge.child_pointer_lists())

In [None]:
for edge in chart.select(start=0, end=2, lhs="ALL"):
    print(edge)
    print(edge.child_pointer_lists())

In [49]:
"  " * 2

'    '

In [47]:
pretty_print(tokens)

'def program (o , p ) :\n  if k + not 3 :\n        return 11\n          else :\n        return q\n          '

In [31]:
print(pretty_print(tokens))

def program (x ) :
return c


In [25]:
for edge in chart.select(start=0, end=2, lhs="ALL"):
    print(edge)
    print(edge.child_pointer_lists())

In [2]:
from dataset.grammar import sample_programs

sample_programs(n=1, level="LEVEL4_1", max_depth=10)

['a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\nwhilea<0:\\n\\tprint(a)\\n\\ta=e+9']

In [6]:
grammar = get_cfg()
grammar._start = Nonterminal("ALL")

parser = RecursiveDescentParser(grammar)

tokens = ["x", "=", "10", "\\n"]
# tokens = [special_tokens.get(token, token) for token in tokens]

for tree in parser.parse(tokens):
    print(tree)

(ALL
  (LEVEL0_1
    (INITIALIZATION (VARIABLE x) (EQUALS =) (DIGIT 10) (NEW_LINE \n))))


KeyboardInterrupt: 

In [None]:
from dataset.grammar import sample_programs

sample_programs(n=10, level="LEVEL1_1")

['a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=a+a\\nprint(a*g)',
 'a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=a+a\\nprint(t)',
 'a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=a+a\\nprint(v)',
 'a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=a+a\\nprint(e)',
 'a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=a+a\\nprint(a-u)',
 'a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na

In [None]:
from dataset.grammar import sample_programs

sample_programs(n=10, level="LEVEL1_1")

['a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=a+a\\nprint(a*g)',
 'a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=a+a\\nprint(t)',
 'a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=a+a\\nprint(v)',
 'a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=a+a\\nprint(e)',
 'a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=a+a\\nprint(a-u)',
 'a=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na=0\\na

In [None]:
from dataset.grammar import get_cfg
from pprint import pprint
grammar = get_cfg()

pprint(grammar.productions())

[VARIABLE -> 'a',
 VARIABLE -> 'b',
 VARIABLE -> 'c',
 VARIABLE -> 'd',
 VARIABLE -> 'e',
 VARIABLE -> 'f',
 VARIABLE -> 'g',
 VARIABLE -> 'h',
 VARIABLE -> 'i',
 VARIABLE -> 'j',
 VARIABLE -> 'k',
 VARIABLE -> 'l',
 VARIABLE -> 'm',
 VARIABLE -> 'n',
 VARIABLE -> 'o',
 VARIABLE -> 'p',
 VARIABLE -> 'q',
 VARIABLE -> 'r',
 VARIABLE -> 's',
 VARIABLE -> 't',
 VARIABLE -> 'u',
 VARIABLE -> 'v',
 VARIABLE -> 'w',
 VARIABLE -> 'x',
 VARIABLE -> 'y',
 VARIABLE -> 'z',
 DIGIT -> '0',
 DIGIT -> '1',
 DIGIT -> '2',
 DIGIT -> '3',
 DIGIT -> '4',
 DIGIT -> '5',
 DIGIT -> '6',
 DIGIT -> '7',
 DIGIT -> '8',
 DIGIT -> '9',
 DIGIT -> '10',
 DIGIT -> '11',
 DIGIT -> '12',
 DIGIT -> '13',
 DIGIT -> '14',
 DIGIT -> '15',
 DIGIT -> '16',
 DIGIT -> '17',
 DIGIT -> '18',
 DIGIT -> '19',
 DIGIT -> '20',
 ARITHMETIC_OPERATOR -> '+',
 ARITHMETIC_OPERATOR -> '-',
 ARITHMETIC_OPERATOR -> '*',
 ARITHMETIC_OPERATOR -> '/',
 RELATIONAL_OPERATOR -> '<',
 RELATIONAL_OPERATOR -> '>',
 RELATIONAL_OPERATOR -> '<=',
 R