## Tokens
##### Keywords: if, else, elif, while, for, in, def, return, True, False, None
##### Operators: +, -, *, /, %, **, =, ==, !=, <, >, <=, >=, and, or, not
##### Delimiters: (, ), {, }, [, ], ',' , '.', ':', ';'
##### Literals: INTEGER, FLOAT, STRING
##### Other: IDENTIFIER, INDENT, DEDENT, NEWLINE, EOF

## Token and Error Class

In [2]:
import re
from typing import Generator, List, Optional, Tuple

class Token:
    def __init__(self, type: str, value: str, line: int, column: int):
        self.type = type
        self.value = value
        self.line = line
        self.column = column
    
    def __str__(self) -> str:
        return f"Token({self.type}, '{self.value}', line={self.line}, col={self.column})"
    
    def __repr__(self) -> str:
        return self.__str__()

class Error:
    def __init__(self, message: str, line: int, column: int):
        self.message = message
        self.line = line
        self.column = column
    
    def __str__(self) -> str:
        return f"Error: {self.message} at line {self.line}, column {self.column}"


## Lexer

In [3]:
class Lexer:
    # Defining regular expressions for token
    TOKEN_SPECS = [
        ('COMMENT', r'#.*'),                                  # Comments
        ('STRING', r'\"([^\\\"]|\\.)*\"|\'([^\\\']|\\.)*\''), # String literals
        ('FLOAT', r'\d+\.\d+'),                               # Float literals
        ('INTEGER', r'\d+'),                                  # Integer literals
        ('KEYWORD', r'(if|else|elif|while|for|in|def|return|True|False|None)\b'),  # Keywords
        ('IDENTIFIER', r'[a-zA-Z_]\w*'),                      # Identifiers
        # Operators (multi-character ones first)
        ('OP_EQ', r'=='),
        ('OP_NE', r'!='),
        ('OP_LE', r'<='),
        ('OP_GE', r'>='),
        ('OP_ASSIGN', r'='),
        ('OP_PLUS', r'\+'),
        ('OP_MINUS', r'-'),
        ('OP_MULT', r'\*\*|\/\/|\*|\/'),  # ** (power), // (floor div), * (mult), / (div)
        ('OP_MOD', r'%'),
        ('OP_LT', r'<'),
        ('OP_GT', r'>'),
        # Delimiters
        ('LPAREN', r'\('),
        ('RPAREN', r'\)'),
        ('LBRACKET', r'\['),
        ('RBRACKET', r'\]'),
        ('LBRACE', r'\{'),
        ('RBRACE', r'\}'),
        ('COMMA', r','),
        ('DOT', r'\.'),
        ('COLON', r':'),
        ('SEMICOLON', r';'),
        ('NEWLINE', r'\n'),
        ('WHITESPACE', r'[ \t]+'),                            # Whitespace
        # INDENT and DEDENT tokens are not matched by regex but generated based on whitespace
    ]
    
    def __init__(self, source_code: str):
        self.source_code = source_code
        self.tokens = []
        self.errors = []
        self.line = 1
        self.column = 1
        self.indent_levels = [0]  # Start with indent level 0
    
    def tokenize(self) -> Generator[Token, None, None]:
        """Tokenize the source code and yield tokens."""
        # Process the source code line by line to handle indentation properly
        lines = self.source_code.split('\n')
        
        for line_num, line in enumerate(lines):
            self.line = line_num + 1
            self.column = 1
            
            #calculating the amount of leading whitespace (indentation) at the beginning of the current line
            if line.strip():  # Non-empty line      #strip() is equivalent of trim() in JS
                indent_size = len(line) - len(line.lstrip())    #lstrip() will remove leading spaces in the line i.e. remove any indentation; and the difference in length tells you how much whitespace there was
                indent_tokens = self._handle_indentation(indent_size)
                for token in indent_tokens:
                    yield token
            else:   # Skip empty lines but still count them for line numbers
                yield Token('NEWLINE', '\\n', self.line, self.column)
                continue
            
            # Process the rest of the line
            i = indent_size if 'indent_size' in locals() else 0
            line_content = line
            
            while i < len(line_content):    #traversing a line character by character
                match = None
                
                ## Skip spaces and tabs (already handled indentation)
                ##if line_content[i].isspace():   #isspace() checks if the entire line is whitespace or not; here it checks whether a single character is whitespace or not since its applied on line_content[i]
                ##    i += 1
                ##    self.column += 1
                ##    continue
                
                # Try to match each token pattern
                for token_type, pattern in self.TOKEN_SPECS:
                    regex = re.compile(pattern) #compiling the pattern into a regex so that we can use it to find tokens in the line
                    match = regex.match(line_content, i)    #match() will check if the "starting" of the string matches the given regex or not, line_content gives the line whose starting it has to compare with the regex with i telling from which point in the line to assume as thhe starting of the line
                    #if a match is found, 'match' will contain an object with info about the found match (like start and end, its value and so on), otherwise it'll contain "none"
                    
                    if match:
                        value = match.group(0)  #group will give the value of the match object
                        
                        if token_type == 'WHITESPACE':
                            #whitespace is not yielded becuz parser receives stream of tokens without space and doesnt care about space
                            # Just update column and continue
                            self.column += len(value)
                            i += len(value)
                            continue
                        elif token_type == 'COMMENT':
                            # Ignore comments
                            i += len(value)
                            self.column += len(value)
                            break   # Break out of the current character processing loop after handling comments
                        else:
                            # For all other tokens
                            token = Token(token_type, value, self.line, self.column)
                            yield token
                        
                        i += len(value)
                        self.column += len(value)
                        break
                
                if not match:
                    # No token matched, raise an error
                    error_msg = f"Invalid character: '{line_content[i]}'"
                    error = Error(error_msg, self.line, self.column)
                    self.errors.append(error)
                    print(error)  # Print the error but continue
                    i += 1
                    self.column += 1
            
            # Add NEWLINE at the end of each line except the last one
            if line_num < len(lines) - 1 or self.source_code.endswith('\n'):
                yield Token('NEWLINE', '\\n', self.line, self.column)
        
        # Output any pending dedents at the end of the file
        #end of file means all indented blocks have been dedented so pop all elements of the indent_levels stack
        while len(self.indent_levels) > 1:
            self.indent_levels.pop()    
            yield Token('DEDENT', '', self.line, self.column)
        
        # End of file token
        yield Token('EOF', '', self.line, self.column)
    
    def _handle_indentation(self, indent_size: int) -> List[Token]:
        """Handle Python's indentation-based block structure. Returns a list of INDENT or DEDENT tokens as needed."""
        tokens = []
        previous_line_indent = self.indent_levels[-1]
        
        if indent_size > previous_line_indent:
            # This is an indentation (start of a new block)
            self.indent_levels.append(indent_size)  #when a new block is recognized through indentation its indentation level is pushed in the stack so that w ecan later check whether that block was dedented properly or not
            tokens.append(Token('INDENT', ' ' * (indent_size - previous_line_indent), self.line, 1))
        
        elif indent_size < previous_line_indent:
            # This is a dedentation (end of one or more blocks)
            while self.indent_levels and indent_size < self.indent_levels[-1]: # ensuring that the latest code block is being dedented
                self.indent_levels.pop()
                tokens.append(Token('DEDENT', '', self.line, 1))
            
            if indent_size != self.indent_levels[-1]:
                # Invalid indentation
                error_msg = f"Inconsistent indentation"
                error = Error(error_msg, self.line, 1)
                self.errors.append(error)
                print(error)  # Print the error but continue
                
                
                # and what about handling indent_size == self.indent_levels[-1]
        return tokens
  
def test_lexer(source_code):
    print(f"Source code:\n{source_code}")
    print("\nTokens:")
    
    lexer = Lexer(source_code)
    token_count = 0
    
    try:
        for token in lexer.tokenize():
            print(token)
            token_count += 1
        
        print(f"\nTotal tokens: {token_count}")
        if lexer.errors:
            print(f"\nErrors ({len(lexer.errors)}):")
            for error in lexer.errors:
                print(f"  {error}")
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        import traceback
        traceback.print_exc()
    
if __name__ == "__main__":
    test_lexer("""
#INSERT TEST CODE HERE
    """, 
    )
    
    # Interactive testing option
    # do_interactive = input("\nDo you want to run an interactive test? (y/n): ")
    # if do_interactive.lower() == 'y':
    #     print("\n=== Interactive Lexer Test ===")
    #     print("Enter Python code (type 'exit()' on a new line to finish):")
        
    #     lines = []
    #     while True:
    #         line = input("> ")
    #         if line.strip() == "exit()":
    #             break
    #         lines.append(line)
        
    #     source_code = "\n".join(lines)
    #     test_lexer(source_code, "Interactive Input")

Source code:

#INSERT TEST CODE HERE
    

Tokens:
Token(NEWLINE, '\n', line=1, col=1)
Token(NEWLINE, '\n', line=2, col=23)
Token(NEWLINE, '\n', line=3, col=1)
Token(EOF, '', line=3, col=1)

Total tokens: 4


## ***ABSTRACT SYNTAX TREE***

In [4]:
#This file is of abstract syntax tree (AST) node definitions.

from dataclasses import dataclass
from typing import List, Optional, Union

# Base class for all AST nodes ASTNode is the base (parent) class for all other AST node types.
#It doesn’t do anything itself, but helps us group all node types together under one type.

class ASTNode:
    pass

# Root node; body is a list of things in the program like function definitions, statements, etc.
@dataclass
class Program(ASTNode):
    body: List[ASTNode]
    line: int
    column: int

# Statements
@dataclass
class FunctionDef(ASTNode):
    name: str
    params: List[str]
    body: List[ASTNode]
    line: int
    column: int

@dataclass
class IfStatement(ASTNode):
    condition: ASTNode
    then_branch: List[ASTNode]
    elif_branches: List[tuple[ASTNode, List[ASTNode]]]
    else_branch: Optional[List[ASTNode]]
    line: int
    column: int

@dataclass
class WhileLoop(ASTNode):
    condition: ASTNode
    body: List[ASTNode]
    line: int
    column: int

@dataclass
class ForLoop(ASTNode):
    var: str
    iterable: ASTNode
    body: List[ASTNode]
    line: int
    column: int

@dataclass
class Return(ASTNode):
    value: Optional[ASTNode]
    line: int
    column: int

@dataclass
class Assignment(ASTNode):
    target: str
    value: ASTNode
    line: int
    column: int

@dataclass
class ExpressionStatement(ASTNode):
    expression: ASTNode
    line: int
    column: int

# Expressions
@dataclass
class BinaryOp(ASTNode):
    left: ASTNode
    operator: str
    right: ASTNode
    line: int
    column: int

@dataclass
class UnaryOp(ASTNode):
    operator: str
    operand: ASTNode
    line: int
    column: int

@dataclass
class Call(ASTNode):
    func: ASTNode
    args: List[ASTNode]
    line: int
    column: int

@dataclass
class Identifier(ASTNode):
    name: str
    line: int
    column: int

@dataclass
class Literal(ASTNode):
    value: Union[str, int, float, bool, None]
    line: int
    column: int



## ***AST VISUALIZER***

In [5]:
# ast_visualizer.py

from graphviz import Digraph
from typing import Union, List
def print_ast(node: ASTNode, indent: str = ''):
    info = []
    if hasattr(node, "inferred_type") and node.inferred_type:
        info.append(f"type={node.inferred_type}")
    if hasattr(node, "constant_value") and node.constant_value is not None:
        info.append(f"const={node.constant_value}")
    info_str = f" [{', '.join(info)}]" if info else ""

    if isinstance(node, Program):
        print(indent + "Program" + info_str)
        for stmt in node.body:
            print_ast(stmt, indent + "  ")
    elif isinstance(node, FunctionDef):
        info = []
        if hasattr(node, "return_type") and node.return_type:
            info.append(f"return_type={node.return_type}")
        info_str = f" [{', '.join(info)}]" if info else ""
        print(indent + f"FunctionDef({node.name})" + info_str)
        for stmt in node.body:
            print_ast(stmt, indent + "  ")
    elif isinstance(node, IfStatement):
        print(indent + "IfStatement" + info_str)
        print(indent + "  Condition:")
        print_ast(node.condition, indent + "    ")
        print(indent + "  Then:")
        for stmt in node.then_branch:
            print_ast(stmt, indent + "    ")
        for cond, branch in node.elif_branches:
            print(indent + "  Elif:")
            print_ast(cond, indent + "    ")
            for stmt in branch:
                print_ast(stmt, indent + "      ")
        if node.else_branch:
            print(indent + "  Else:")
            for stmt in node.else_branch:
                print_ast(stmt, indent + "    ")
    elif isinstance(node, WhileLoop):
        print(indent + "WhileLoop" + info_str)
        print(indent + "  Condition:")
        print_ast(node.condition, indent + "    ")
        print(indent + "  Body:")
        for stmt in node.body:
            print_ast(stmt, indent + "    ")
    elif isinstance(node, ForLoop):
        print(indent + f"ForLoop(var={node.var})" + info_str)
        print(indent + "  Iterable:")
        print_ast(node.iterable, indent + "    ")
        print(indent + "  Body:")
        for stmt in node.body:
            print_ast(stmt, indent + "    ")
    elif isinstance(node, Return):
        print(indent + "Return" + info_str)
        if node.value:
            print_ast(node.value, indent + "  ")
    elif isinstance(node, Assignment):
        print(indent + f"Assignment(target={node.target})" + info_str)
        print_ast(node.value, indent + "  ")
    elif isinstance(node, ExpressionStatement):
        print(indent + "ExpressionStatement" + info_str)
        print_ast(node.expression, indent + "  ")
    elif isinstance(node, BinaryOp):
        print(indent + f"BinaryOp({node.operator})" + info_str)
        print_ast(node.left, indent + "  ")
        print_ast(node.right, indent + "  ")
    elif isinstance(node, UnaryOp):
        print(indent + f"UnaryOp({node.operator})" + info_str)
        print_ast(node.operand, indent + "  ")
    elif isinstance(node, Call):
        print(indent + "Call" + info_str)
        print_ast(node.func, indent + "  ")
        for arg in node.args:
            print_ast(arg, indent + "    ")
    elif isinstance(node, Identifier):
        print(indent + f"Identifier({node.name})" + info_str)
    elif isinstance(node, Literal):
        print(indent + f"Literal({repr(node.value)})" + info_str)
    else:
        print(indent + f"UnknownNode({node})" + info_str)
        
# Optional: Graphviz visualizer
def _build_graph(node: ASTNode, dot: Digraph, parent: str = None, counter=[0]):
    node_id = f"node{counter[0]}"
    counter[0] += 1

    # Build label with annotations
    label = type(node).__name__
    # Add main info
    if isinstance(node, Identifier):
        label += f"({node.name})"
    elif isinstance(node, Literal):
        label += f"({repr(node.value)})"
    elif isinstance(node, Assignment):
        label += f"({node.target})"
    elif isinstance(node, BinaryOp):
        label += f"({node.operator})"
    elif isinstance(node, UnaryOp):
        label += f"({node.operator})"
    elif isinstance(node, FunctionDef):
        label += f"({node.name})"
        if hasattr(node, "return_type") and node.return_type:
            label += f"\\nreturn_type={node.return_type}"
    elif isinstance(node, ForLoop):
        label += f"({node.var})"

    # Add annotation info
    info = []
    if hasattr(node, "inferred_type") and node.inferred_type:
        info.append(f"type={node.inferred_type}")
    if hasattr(node, "constant_value") and node.constant_value is not None:
        info.append(f"const={node.constant_value}")
    if info:
        label += "\\n" + ", ".join(info)

    dot.node(node_id, label)

    if parent:
        dot.edge(parent, node_id)

    for field in getattr(node, '__dataclass_fields__', {}):
        child = getattr(node, field)
        if isinstance(child, ASTNode):
            _build_graph(child, dot, node_id, counter)
        elif isinstance(child, list):
            for item in child:
                if isinstance(item, ASTNode):
                    _build_graph(item, dot, node_id, counter)
        elif isinstance(child, tuple):
            for item in child:
                if isinstance(item, ASTNode):
                    _build_graph(item, dot, node_id, counter)

    return dot

def visualize_ast(node: ASTNode):
    dot = Digraph(comment="AST")
    _build_graph(node, dot)
    return dot

## ***PARSER***

In [6]:
# Custom error class for parser errors
class ParserError(Exception):
    pass

# Parser class converts tokens into an AST
class Parser:
    def __init__(self, lexer: Lexer):
        self.tokens = list(lexer.tokenize())  # Convert lexer generator into list of tokens
        self.pos = 0  # Current position in token list
        self.current_token = self.tokens[self.pos]  # Currently looked-at token

    def error(self, msg: str):
        # Raise a ParserError with current token's position
        raise ParserError(f"{msg} at line {self.current_token.line}, column {self.current_token.column}")
    
    def skip_newlines(self):
        # Ignore NEWLINE tokens before/after blocks or statements
        while self.current_token.type == 'NEWLINE':
            self.advance()

    def advance(self):
        # Move to the next token
        self.pos += 1
        if self.pos < len(self.tokens):
            self.current_token = self.tokens[self.pos]

    def expect(self, token_type: str):
        # Ensure the current token is of the expected type, or throw error
        if self.current_token.type == token_type:
            self.advance()
        else:
            self.error(f"Expected token {token_type}, got {self.current_token.type}")

    def match(self, token_type: str):
        # Match a token and advance if matched
        if self.current_token.type == token_type:
            self.advance()
            return True
        return False

    def parse(self) -> Program:
        # Parse a full program (list of statements)
        body = []
        self.skip_newlines()
        while self.current_token.type != 'EOF':
            stmt = self.parse_statement()
            if stmt:
                body.append(stmt)
            self.skip_newlines()
        return Program(body=body, line=0, column=0) 

    def parse_statement(self) -> ASTNode:
        # Decide which type of statement to parse
        if self.current_token.type == 'KEYWORD':
            if self.current_token.value == 'def':
                return self.parse_function_def()
            elif self.current_token.value == 'return':
                return self.parse_return()
            elif self.current_token.value == 'if':
                return self.parse_if()
            elif self.current_token.value == 'while':
                return self.parse_while()
            elif self.current_token.value == 'for':
                return self.parse_for()
            else:
                print(f"DEBUG: Unknown keyword: {self.current_token.value}")
                self.error(f"Unknown keyword: {self.current_token.value}")
            
        #######COULDNT WORK DURING SEMANTIC ANALYSIS SO CHNAGED THE LOGIC HERE##########
        # If not a keyword, it's either an assignment or expression
        # expr = self.parse_expression()
        # if isinstance(expr, Identifier) and self.match('OP_ASSIGN'):
        #     value = self.parse_expression()
        #     return Assignment(target=expr.name, value=value)
        # return ExpressionStatement(expr)
        ################################################################################
        
        # If not a keyword, it could be assignment or expression
        if self.current_token.type == 'IDENTIFIER':
            # Lookahead to check for assignment
            next_token = self.tokens[self.pos + 1] if self.pos + 1 < len(self.tokens) else None
            if next_token and next_token.type == 'OP_ASSIGN':
                assign_token = self.current_token
                identifier = self.current_token.value
                self.advance()  # consume identifier
                self.advance()  # consume '='
                value = self.parse_expression()
                return Assignment(target=identifier, value=value, line=assign_token.line, column=assign_token.column)

        # Otherwise, treat it as an expression statement
        expr_token = self.current_token
        expr = self.parse_expression()
        return ExpressionStatement(expression=expr, line=expr_token.line, column=expr_token.column)

    def parse_function_def(self) -> FunctionDef:
        # Parse function definition: def name(params):\n  <body>
        def_token = self.current_token
        self.expect('KEYWORD')  # 'def'
        if self.current_token.type != 'IDENTIFIER':
            self.error("Expected function name")
        name = self.current_token.value
        self.advance()
        self.expect('LPAREN')

        # Parse parameter list
        params = []
        if self.current_token.type != 'RPAREN':
            while True:
                if self.current_token.type != 'IDENTIFIER':
                    self.error("Expected parameter name")
                params.append(self.current_token.value)
                self.advance()
                if not self.match('COMMA'):
                    break

        self.expect('RPAREN')
        self.expect('COLON')
        self.expect('NEWLINE')
        self.expect('INDENT')
        body = self.parse_block()
        return FunctionDef(name=name, params=params, body=body, line=def_token.line, column=def_token.column)

    def parse_return(self) -> Return:
        token = self.current_token
        # Parse return statement
        self.expect('KEYWORD')  # 'return'
        if self.current_token.type == 'NEWLINE':
            return Return(value=None, line=token.line, column=token.column)
        value = self.parse_expression()
        return Return(value=value, line=token.line, column=token.column)

    def parse_if(self) -> IfStatement:
        if_token = self.current_token
        # Parse if-elif-else statement
        self.expect('KEYWORD')  # 'if'
        condition = self.parse_expression()
        self.expect('COLON')
        self.expect('NEWLINE')
        self.expect('INDENT')
        then_branch = self.parse_block()

        elif_branches = []
        # Handle optional elif blocks
        while self.current_token.type == 'KEYWORD' and self.current_token.value == 'elif':
            self.advance()
            cond = self.parse_expression()
            self.expect('COLON')
            self.expect('NEWLINE')
            self.expect('INDENT')
            body = self.parse_block()
            elif_branches.append((cond, body))

        else_branch = None
        # Handle optional else block
        if self.current_token.type == 'KEYWORD' and self.current_token.value == 'else':
            self.advance()
            self.expect('COLON')
            self.expect('NEWLINE')
            self.expect('INDENT')
            else_branch = self.parse_block()

        return IfStatement(condition, then_branch, elif_branches, else_branch, line=if_token.line, column=if_token.column)

    def parse_while(self) -> WhileLoop:
        while_token = self.current_token
        # Parse while loop
        self.expect('KEYWORD')  # 'while'
        condition = self.parse_expression()
        self.expect('COLON')
        self.expect('NEWLINE')
        self.expect('INDENT')
        body = self.parse_block()
        return WhileLoop(condition=condition, body=body, line=while_token.line, column=while_token.column)

    def parse_for(self) -> ForLoop:
        for_token = self.current_token
        # Parse for loop: for x in iterable:
        self.expect('KEYWORD')  # 'for'
        if self.current_token.type != 'IDENTIFIER':
            self.error("Expected variable name")
        var = self.current_token.value
        self.advance()
        self.expect('KEYWORD')  # 'in'
        iterable = self.parse_expression()
        self.expect('COLON')
        self.expect('NEWLINE')
        self.expect('INDENT')
        body = self.parse_block()
        return ForLoop(var=var, iterable=iterable, body=body, line=for_token.line, column=for_token.column)

    def parse_block(self) -> List[ASTNode]:
        # Parse an indented block of statements
        body = []
        self.skip_newlines()
        while self.current_token.type not in ('DEDENT', 'EOF'):
            stmt = self.parse_statement()
            if stmt:
                body.append(stmt)
            self.skip_newlines()
        self.expect('DEDENT')
        return body

    def parse_expression(self, precedence=0) -> ASTNode:
        # Parse binary expressions with precedence (e.g., 1 + 2 * 3)
        left = self.parse_primary()
        while self.is_operator(self.current_token) and self.get_precedence(self.current_token) >= precedence:
            op_token = self.current_token
            self.advance()
            right = self.parse_expression(self.get_precedence(op_token) + 1)
            left = BinaryOp(left=left, operator=op_token.value, right=right, line=op_token.line, column=op_token.column)
        return left

    def parse_primary(self) -> ASTNode:
        # Parse literals, variables, function calls, unary ops, and parenthesis
        token = self.current_token
        if token.type == 'INTEGER':
            self.advance()
            return Literal(value=int(token.value), line=token.line, column=token.column)
        elif token.type == 'FLOAT':
            self.advance()
            return Literal(value=float(token.value), line=token.line, column=token.column)
        elif token.type == 'STRING':
            self.advance()
            return Literal(value=token.value[1:-1], line=token.line, column=token.column)
        elif token.type == 'KEYWORD':
            if token.value in ('True', 'False', 'None'):
                self.advance()
                val = {'True': True, 'False': False, 'None': None}[token.value]
                return Literal(value=val, line=token.line, column=token.column)
        elif token.type == 'IDENTIFIER':
            self.advance()
            if self.match('LPAREN'):  # Function call
                args = []
                if self.current_token.type != 'RPAREN':
                    while True:
                        args.append(self.parse_expression())
                        if not self.match('COMMA'):
                            break
                self.expect('RPAREN')
                return Call(func=Identifier(name=token.value, line=token.line, column=token.column), args=args, line=token.line, column=token.column)
            return Identifier(name=token.value, line=token.line, column=token.column)
        elif token.type == 'OP_MINUS':
            self.advance()
            operand = self.parse_primary()
            return UnaryOp(operator='-', operand=operand)
        elif token.type == 'LPAREN':
            self.advance()
            expr = self.parse_expression()
            self.expect('RPAREN')
            return expr
        else:
            self.error(f"Unexpected token {token}")

    def is_operator(self, token: Token) -> bool:
        # Check if token is an operator
        return token.type.startswith('OP_')

    def get_precedence(self, token: Token) -> int:
        # Define operator precedence levels (higher = tighter binding)
        precedence = {
            'OP_OR': 1,
            'OP_AND': 2,
            'OP_EQ': 3, 'OP_NE': 3,
            'OP_LT': 4, 'OP_LE': 4, 'OP_GT': 4, 'OP_GE': 4,
            'OP_PLUS': 5, 'OP_MINUS': 5,
            'OP_MULT': 6, 'OP_MOD': 6,
            'OP_ASSIGN': 0,  # Assignment is handled separately
        }
        return precedence.get(token.type, -1)
 

testCode = """

"""
#Test code
lexer = Lexer(testCode)
parser = Parser(lexer)
ast = parser.parse()

#Console view
print_ast(ast)

    #graph
dot = visualize_ast(ast)
dot.render('ast_tree', view=True, format='png')


Program


'ast_tree.png'

## ***SEMANTIC ANALYSIS***

In [7]:
from colorama import Fore, Style, init
from IPython.display import Markdown, display
# Enable ANSI color in terminal
init(autoreset=True)

# Custom exception class used to indicate semantic errors in the code being analyzed
class SemanticError(Exception):
    pass

# Set of built-in function names that user-defined functions/variables should not override
BUILTINS = {'print', 'len', 'range', 'input', 'int', 'float', 'str', 'bool'}  # Extend as needed

# Scope class manages variables/functions within a block or function, supports nested scoping
class Scope:
    def __init__(self, parent=None):
        self.parent = parent             # Reference to the parent scope (for nested scopes)
        self.variables = {}              # declared variable names in current scope
        self.functions = {}              # Dictionary of declared functions: {func_name: param_list}
        self.in_function = False         # Flag to check if this scope is inside a function

    # Declare a variable in the current scope; return False if it already exists
    def declare_var(self, name, var_type=None):
        if name in self.variables:
            return False  # Duplicate in same scope
        self.variables[name] = {"type": var_type}
        return True

    # Check whether a variable is declared in current or any parent scope
    def is_var_declared(self, name):
        if name in self.variables:
            return True
        if self.parent:
            return self.parent.is_var_declared(name)
        return False

    # Declare a function with a list of parameters; returns False if already declared
    def declare_func(self, name, func_def):
        if name in self.functions:
            return False
        self.functions[name] = func_def
        return True

    # Recursively retrieve the parameter list of a declared function
    def get_func(self, name):
        if name in self.functions:
            return self.functions[name]
        if self.parent:
            return self.parent.get_func(name)
        return None

    # Check if a function is declared by trying to get it
    def is_func_declared(self, name):
        return self.get_func(name) is not None
    
    #to show symbol table 
    def print_symbols(self, indent=0):
        prefix = "  " * indent
        print(f"{prefix}Scope:")
        if self.variables:
            for name, info in self.variables.items():
                print(f"{prefix}  Variable: {name}, Type: {info.get('type', 'unknown')}")
        if self.functions:
            for fname, func_def in self.functions.items():
                param_strs = [f"{p}: {t}" for p, t in func_def.param_types.items()]
                return_type = getattr(func_def, 'return_type', None)
                return_type_str = f", Return: {return_type}" if return_type else ""
                print(f"{prefix}  Function: {fname}({', '.join(param_strs)}){return_type_str}")
        if self.parent:
            self.parent.print_symbols(indent + 1)

# SemanticAnalyzer walks through the AST to validate semantics and detect errors
class SemanticAnalyzer:
    def __init__(self):
        self.global_scope = Scope()        # Top-level scope for global variables/functions
        self.current_scope = self.global_scope  # Pointer to current active scope
        self.errors = []                   # List to store error messages


    # Report a semantic error with line, column, and node context
    def error(self, msg, node):
        line = getattr(node, 'line', '?')
        col = getattr(node, 'column', '?')
        node_name = getattr(node, 'name', type(node).__name__)
        self.errors.append(f"[Line {line}, Column {col}] Error at '{node_name}': {msg}")
        # self.errors.append(formatted)

    # Entry point to start analyzing the AST
    def analyze(self, node):
        self.visit(node)  # Start visiting from the root node
        if self.errors:
            raise SemanticError("\n".join(self.errors))

    # Dispatch method that calls the appropriate visit_* method based on node type
    def visit(self, node):
        method = 'visit_' + type(node).__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    # Generic visitor for composite nodes (e.g., containing other nodes in fields/lists/tuples)
    def generic_visit(self, node):
        for field in getattr(node, '__dataclass_fields__', {}):
            value = getattr(node, field)
            if isinstance(value, ASTNode):
                self.visit(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, ASTNode):
                        self.visit(item)
            elif isinstance(value, tuple):
                for item in value:
                    if isinstance(item, ASTNode):
                        self.visit(item)

    # Visit a program node which consists of multiple statements
    def visit_Program(self, node: Program):
        for stmt in node.body:
            self.visit(stmt)

    # Visit a function definition and validate function name, parameters, and body
    def visit_FunctionDef(self, node: FunctionDef):
        if not self.current_scope.declare_func(node.name, node):
            self.error(f"Duplicate function definition '{node.name}'", node)
        func_scope = Scope(parent=self.current_scope)
        func_scope.in_function = True

        node.param_types = {}
        for param in node.params:
            p_name = param if isinstance(param, str) else param[0]
            p_type = None if isinstance(param, str) else param[1]
            func_scope.declare_var(p_name, p_type)
            node.param_types[p_name] = p_type

        prev_scope = self.current_scope
        self.current_scope = func_scope

        return_types = set()
        for stmt in node.body:
            self.visit(stmt)
            if isinstance(stmt, Return) and hasattr(stmt, 'inferred_type') and stmt.inferred_type is not None:
                return_types.add(stmt.inferred_type)

        self.current_scope = prev_scope
        node.return_type = return_types.pop() if len(return_types) == 1 else 'multiple' if return_types else 'None'

    # Visit an assignment: validate left-hand name, check for name conflicts, and analyze RHS
    def visit_Assignment(self, node: Assignment):
        if node.target in BUILTINS:
            self.error(f"Cannot assign to built-in name '{node.target}'", node)
        if self.current_scope.is_func_declared(node.target):
            self.error(f"Cannot assign to function name '{node.target}'", node)
        self.visit(node.value)
        var_type = getattr(node.value, 'inferred_type', None)
        node.inferred_type = var_type
        self.current_scope.declare_var(node.target, var_type)

    # Visit an identifier node to check if it was declared
    def visit_Identifier(self, node: Identifier):
        scope = self.current_scope
        while scope:
            if node.name in scope.variables:
                node.inferred_type = scope.variables[node.name]["type"]
                return
            scope = scope.parent
        if self.current_scope.is_func_declared(node.name):
            self.error(f"Function '{node.name}' used as a variable", node)
        else:
            self.error(f"Undeclared variable '{node.name}'", node)

    # Visit a function call and validate function existence and argument count
    def visit_Call(self, node: Call):
        func_def = self.current_scope.get_func(node.func.name)
        func_name = node.func.name
        func_def = self.current_scope.get_func(func_name)
        if func_def is None:
            if func_name in BUILTINS:
                node.inferred_type = 'None'  # assume built-ins return nothing for now
                for arg in node.args:
                    self.visit(arg)
                return
            else:
                self.error(f"Call to undefined function '{func_name}'", node)
                return
        else:
            # Infer parameter types if not set
            param_types_changed = False
            for (pname, exp_type), arg in zip(func_def.param_types.items(), node.args):
                self.visit(arg)
                actual = getattr(arg, 'inferred_type', None)
                if exp_type is None:
                    func_def.param_types[pname] = actual  # Infer type from first call
                    param_types_changed = True
                elif actual != exp_type:
                    self.error(f"Type mismatch in argument: expected {exp_type}, got {actual}", arg)
            if len(node.args) != len(func_def.param_types):
                self.error(f"Function '{node.func.name}' expects {len(func_def.param_types)} arguments, got {len(node.args)}", node)
            # Re-analyze function body if parameter types were updated
            if param_types_changed:
                prev_scope = self.current_scope
                func_scope = Scope(parent=self.global_scope)
                func_scope.in_function = True
                for pname, ptype in func_def.param_types.items():
                    func_scope.declare_var(pname, ptype)
                self.current_scope = func_scope
                return_types = set()
                for stmt in func_def.body:
                    self.visit(stmt)
                    if isinstance(stmt, Return) and hasattr(stmt, 'inferred_type') and stmt.inferred_type is not None:
                        return_types.add(stmt.inferred_type)
                self.current_scope = prev_scope
                func_def.return_type = return_types.pop() if len(return_types) == 1 else 'multiple' if return_types else 'None'
        if func_def and hasattr(func_def, 'return_type'):
            node.inferred_type = func_def.return_type


    # Visit a return statement and ensure it's inside a function
    def visit_Return(self, node: Return):
        scope = self.current_scope
        while scope:
            if scope.in_function:
                break
            scope = scope.parent
        else:
            self.error("Return statement outside function", node)

        if node.value:
            self.visit(node.value)
            node.inferred_type = getattr(node.value, 'inferred_type', None)

    # Visit if/elif/else blocks and analyze each conditional branch
    def visit_IfStatement(self, node: IfStatement):
        self.visit(node.condition)
        for stmt in node.then_branch:
            self.visit(stmt)
        for cond, body in node.elif_branches:
            self.visit(cond)
            for stmt in body:
                self.visit(stmt)
        if node.else_branch:
            for stmt in node.else_branch:
                self.visit(stmt)

    # Visit while loop and analyze condition and body statements
    def visit_WhileLoop(self, node: WhileLoop):
        self.visit(node.condition)
        for stmt in node.body:
            self.visit(stmt)

    # Visit for loop: declare loop variable in new scope and analyze loop body
    def visit_ForLoop(self, node: ForLoop):
        self.visit(node.iterable)
        loop_scope = Scope(parent=self.current_scope)
        if not loop_scope.declare_var(node.var):
            self.error(f"Loop variable '{node.var}' already declared", node)
        prev_scope = self.current_scope
        self.current_scope = loop_scope
        for stmt in node.body:
            self.visit(stmt)
        self.current_scope = prev_scope

    # Visit an expression statement (e.g., standalone function call)
    def visit_ExpressionStatement(self, node: ExpressionStatement):
        self.visit(node.expression)

    # Visit a binary operation node and analyze both left and right operands
    def visit_BinaryOp(self, node: BinaryOp):
        self.visit(node.left)
        self.visit(node.right)
        
        ltype = getattr(node.left, 'inferred_type', None)
        rtype = getattr(node.right, 'inferred_type', None)

        # Handle type inference for numeric operations
        numeric_ops = {'+', '-', '*', '/', '%', '**'}

        if ltype and rtype:
            # if node.operator in numeric_ops:
            #     if ltype in {'int', 'float'} and rtype in {'int', 'float'}:
            #         # Type promotion rule: int + float => float
            #         if ltype == 'float' or rtype == 'float':
            #             node.inferred_type = 'float'
            #         else:
            #             node.inferred_type = 'int'
            #     else:
            #         self.error(f"Unsupported operand types for {node.operator}: '{ltype}' and '{rtype}'", node)
            if node.operator == '+':
                if ltype == rtype == 'str':
                    node.inferred_type = 'str'
                elif ltype in {'int', 'float'} and rtype in {'int', 'float'}:
                    node.inferred_type = 'float' if 'float' in (ltype, rtype) else 'int'
                else:
                    self.error(f"Unsupported operand types for +: '{ltype}' and '{rtype}'", node)
            elif node.operator in {'-', '*', '/', '%', '**'}:
                if ltype in {'int', 'float'} and rtype in {'int', 'float'}:
                    node.inferred_type = 'float' if 'float' in (ltype, rtype) else 'int'
                else:
                    self.error(f"Unsupported operand types for {node.operator}: '{ltype}' and '{rtype}'", node)
            elif node.operator in {'==', '!=', '<', '>', '<=', '>='}:
                node.inferred_type = 'bool'
            elif node.operator in {'and', 'or'}:
                if ltype == rtype == 'bool':
                    node.inferred_type = 'bool'
                else:
                    self.error(f"Logical operators require boolean operands, got '{ltype}' and '{rtype}'", node)
            else:
                self.error(f"Unknown binary operator '{node.operator}'", node)

        # Constant folding (optional)
        if hasattr(node.left, 'constant_value') and hasattr(node.right, 'constant_value'):
            try:
                node.constant_value = eval(f"{repr(node.left.constant_value)} {node.operator} {repr(node.right.constant_value)}")
            except Exception:
                pass

            
    # Visit a unary operation node and analyze its operand
    def visit_UnaryOp(self, node: UnaryOp):
        self.visit(node.operand)
        if hasattr(node.operand, 'inferred_type'):
            node.inferred_type = node.operand.inferred_type
        if hasattr(node.operand, 'constant_value'):
            try:
                node.constant_value = eval(f"{node.operator}{node.operand.constant_value}")
            except:
                pass

    # Visit a literal node (e.g., number, string); literals are always valid
    def visit_Literal(self, node: Literal):
        node.inferred_type = type(node.value).__name__
        node.constant_value = node.value
        

source_code = """

"""
lexer = Lexer(source_code)
parser = Parser(lexer)
ast = parser.parse()

display(Markdown("### AST Before Semantic Analysis"))
print_ast(ast)

analyzer = SemanticAnalyzer()

try:
    analyzer.analyze(ast)

    display(Markdown(f"<span style='color:green; font-weight:bold'>✅ Semantic analysis passed!</span>"))

    display(Markdown("### Symbol Table"))
    analyzer.global_scope.print_symbols()
        
    display(Markdown("### AST After Semantic Analysis (With Annotations)"))
    print_ast(ast)

    dot = visualize_ast(ast)
    dot.render('ast_tree', view=True, format='png')

except Exception as e:
    display(Markdown(f"<span style='color:red; font-weight:bold'>❌ Semantic analysis failed:</span> {e}"))



### AST Before Semantic Analysis

Program


<span style='color:green; font-weight:bold'>✅ Semantic analysis passed!</span>

### Symbol Table

Scope:


### AST After Semantic Analysis (With Annotations)

Program


## Code Optimization

In [10]:
#Note: These optimizations are performed: Constant folding and propagation, dead code elimination, control flow optimization, expression simplification
from typing import List, Optional
from dataclasses import dataclass

class Optimizer:
    def __init__(self):
        self.optimizations_applied = 0
        self.modified = True  # Track if any optimizations were made
        self.constants = {}   # track constant values
        self.scope_stack = [{}]  # Stack of scopes for tracking variables in different scopes
        self.function_params = set()  # Track current function parameters
    
    def visit_Identifier(self, node: Identifier) -> ASTNode:
        """Propagate known constant values"""
        # Don't optimize function parameters
        if node.name in self.function_params:
            return node
            
        # Look for constant value in all scopes, from innermost to outermost
        for scope in reversed(self.scope_stack):
            if node.name in scope:
                self.modified = True
                self.optimizations_applied += 1
                # Create new literal and preserve any existing attributes
                new_literal = Literal(value=scope[node.name], line=node.line, column=node.column)
                # Copy any additional attributes that might exist (like type annotations)
                for attr in dir(node):
                    if not attr.startswith('_') and attr not in ['name', 'line', 'column']:
                        if hasattr(node, attr) and not callable(getattr(node, attr)):
                            try:
                                setattr(new_literal, attr, getattr(node, attr))
                            except:
                                pass
                return new_literal
        return node
    
    def optimize(self, node: ASTNode) -> ASTNode:
        """Main optimization entry point"""
        if node is None:
            return None
            
        # Keep optimizing until no more changes can be made
        max_iterations = 10  # Prevent infinite loops
        iteration = 0
        
        while self.modified and iteration < max_iterations:
            self.modified = False
            old_optimizations = self.optimizations_applied
            node = self.visit(node)
            iteration += 1
            
            # If no new optimizations were applied, we're done
            if self.optimizations_applied == old_optimizations:
                break
                
        return node
    
    def visit(self, node: ASTNode) -> Optional[ASTNode]:
        """Visit and potentially transform a node"""
        if node is None:
            return None
            
        method = 'visit_' + type(node).__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)
    
    def generic_visit(self, node: ASTNode) -> ASTNode:
        """Default visitor that traverses all fields"""
        for field in getattr(node, '__dataclass_fields__', {}):
            value = getattr(node, field)
            if isinstance(value, ASTNode):
                optimized = self.visit(value)
                if optimized is not None:
                    setattr(node, field, optimized)
            elif isinstance(value, list):
                new_list = []
                for item in value:
                    if isinstance(item, ASTNode):
                        optimized = self.visit(item)
                        if optimized is not None:
                            new_list.append(optimized)
                    else:
                        new_list.append(item)
                setattr(node, field, new_list)
        return node

    def visit_Program(self, node: Program) -> Program:
        """Optimize the program's body"""
        optimized_body = []
        for stmt in node.body:
            result = self.visit(stmt)
            if result is not None:
                if isinstance(result, list):
                    optimized_body.extend(result)
                else:
                    optimized_body.append(result)
        node.body = optimized_body
        return node

    def visit_BinaryOp(self, node: BinaryOp) -> ASTNode:
        """Enhanced constant folding for binary operations"""
        # Visit operands first to ensure constant propagation happens
        node.left = self.visit(node.left)
        node.right = self.visit(node.right)
        
        # Try to evaluate if both operands are literals
        left_val = getattr(node.left, 'value', None) if isinstance(node.left, Literal) else None
        right_val = getattr(node.right, 'value', None) if isinstance(node.right, Literal) else None
        
        if left_val is not None and right_val is not None:
            try:
                if node.operator == '+':
                    value = left_val + right_val
                elif node.operator == '-':
                    value = left_val - right_val
                elif node.operator == '*':
                    value = left_val * right_val
                elif node.operator == '/':
                    value = left_val / right_val
                elif node.operator == '%':
                    value = left_val % right_val
                elif node.operator == '==':
                    value = left_val == right_val
                elif node.operator == '!=':
                    value = left_val != right_val
                elif node.operator == '<':
                    value = left_val < right_val
                elif node.operator == '>':
                    value = left_val > right_val
                elif node.operator == '<=':
                    value = left_val <= right_val
                elif node.operator == '>=':
                    value = left_val >= right_val
                else:
                    return node
                
                self.modified = True
                self.optimizations_applied += 1
                # Create new literal and preserve any existing attributes
                new_literal = Literal(value=value, line=node.line, column=node.column)
                # Copy any additional attributes that might exist (like type annotations)
                for attr in dir(node):
                    if not attr.startswith('_') and attr not in ['left', 'right', 'operator', 'line', 'column', 'value']:
                        if hasattr(node, attr) and not callable(getattr(node, attr)):
                            try:
                                setattr(new_literal, attr, getattr(node, attr))
                            except:
                                pass
                return new_literal
            except (ZeroDivisionError, TypeError, ValueError):
                # Don't optimize if operation would cause an error
                return node
        return node

    def visit_IfStatement(self, node: IfStatement) -> Optional[List[ASTNode]]:
        """Optimize if statements"""
        # First optimize the condition
        node.condition = self.visit(node.condition)
        
        # If condition is a constant literal
        if isinstance(node.condition, Literal):
            self.modified = True
            self.optimizations_applied += 1
            if node.condition.value:
                # True condition - only keep then branch
                optimized_then = []
                for stmt in node.then_branch:
                    result = self.visit(stmt)
                    if result is not None:
                        if isinstance(result, list):
                            optimized_then.extend(result)
                        else:
                            optimized_then.append(result)
                return optimized_then
            elif node.else_branch:
                # False condition - only keep else branch
                optimized_else = []
                for stmt in node.else_branch:
                    result = self.visit(stmt)
                    if result is not None:
                        if isinstance(result, list):
                            optimized_else.extend(result)
                        else:
                            optimized_else.append(result)
                return optimized_else
            else:
                # False condition with no else - remove entirely
                return []
        
        # Normal case: optimize all branches without removing them
        # Optimize then branch
        if node.then_branch:  # Only process if then_branch exists and is not empty
            optimized_then = []
            for stmt in node.then_branch:
                result = self.visit(stmt)
                if result is not None:
                    if isinstance(result, list):
                        optimized_then.extend(result)
                    else:
                        optimized_then.append(result)
            node.then_branch = optimized_then
                    
        # Optimize elif branches
        if node.elif_branches:  # Only process if elif_branches exists and is not empty
            optimized_elif = []
            for cond, body in node.elif_branches:
                optimized_cond = self.visit(cond)
                optimized_body = []
                for stmt in body:
                    result = self.visit(stmt)
                    if result is not None:
                        if isinstance(result, list):
                            optimized_body.extend(result)
                        else:
                            optimized_body.append(result)
                optimized_elif.append((optimized_cond, optimized_body))
            node.elif_branches = optimized_elif
            
        # Optimize else branch
        if node.else_branch:  # Only process if else_branch exists and is not empty
            optimized_else = []
            for stmt in node.else_branch:
                result = self.visit(stmt)
                if result is not None:
                    if isinstance(result, list):
                        optimized_else.extend(result)
                    else:
                        optimized_else.append(result)
            node.else_branch = optimized_else
        
        return node

    def visit_WhileLoop(self, node: WhileLoop) -> Optional[WhileLoop]:
        """Optimize while loops"""
        node.condition = self.visit(node.condition)
        
        # If condition is a constant false, remove the loop
        if isinstance(node.condition, Literal) and not node.condition.value:
            self.modified = True
            self.optimizations_applied += 1
            return None
                
        # Handle loop body optimization differently based on loop type
        if isinstance(node.condition, Literal) and node.condition.value is True:
            # For infinite loops, don't optimize variable assignments
            optimized_body = []
            for stmt in node.body:
                if isinstance(stmt, Assignment):
                    optimized_body.append(stmt)
                else:
                    result = self.visit(stmt)
                    if result is not None:
                        optimized_body.append(result)
            node.body = optimized_body
        else:
            # For normal loops, optimize everything
            optimized_body = []
            for stmt in node.body:
                result = self.visit(stmt)
                if result is not None:
                    if isinstance(result, list):
                        optimized_body.extend(result)
                    else:
                        optimized_body.append(result)
            node.body = optimized_body
        
        return node
        

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Handle function scope - FIX: Don't optimize based on parameters"""
        # Save current function parameters
        old_params = self.function_params.copy()
        
        # Add current function parameters to avoid optimizing them
        self.function_params.update(node.params)
        
        # Push new scope but don't add parameters as constants
        self.scope_stack.append({})
        
        optimized_body = []
        for stmt in node.body:
            result = self.visit(stmt)
            if result is not None:
                if isinstance(result, list):
                    optimized_body.extend(result)
                else:
                    optimized_body.append(result)
        node.body = optimized_body
        
        # Restore previous state
        self.scope_stack.pop()
        self.function_params = old_params
        return node

    def visit_Return(self, node: Return) -> Return:
        """Optimize return value expressions"""
        if node.value:
            node.value = self.visit(node.value)
        return node

    def visit_Assignment(self, node: Assignment) -> Assignment:
        """Optimize assignment values and track constants"""
        # Don't track assignments to function parameters as constants
        if node.target in self.function_params:
            node.value = self.visit(node.value)
            return node
            
        # First optimize the value expression
        node.value = self.visit(node.value)
        
        # Only track as constant if we're in the global scope (not inside a function)
        # Inside functions, variables can have different values in different control flow paths
        if len(self.scope_stack) == 1:  # Global scope only
            # If the result is a literal, track it as a constant
            if isinstance(node.value, Literal):
                self.scope_stack[-1][node.target] = node.value.value
                # Mark as modified to ensure another pass happens
                self.modified = True
            else:
                # If assigning a non-constant, remove from constants if it exists
                if node.target in self.scope_stack[-1]:
                    del self.scope_stack[-1][node.target]
        else:
            # Inside function scope, don't track any assignments as constants
            # because they might be in different control flow branches
            if node.target in self.scope_stack[-1]:
                del self.scope_stack[-1][node.target]
        
        return node

    def visit_ExpressionStatement(self, node: ExpressionStatement) -> ExpressionStatement:
        """Optimize expressions in expression statements"""
        node.expression = self.visit(node.expression)
        return node

    def visit_Call(self, node: Call) -> Call:
        """Optimize function calls (visit arguments)"""
        if hasattr(node, 'args'):
            node.args = [self.visit(arg) for arg in node.args]
        return node

def optimize_ast(ast: ASTNode) -> ASTNode:
    """
    Optimize the AST and print detailed optimization information.
    
    Args:
        ast: The AST to optimize
    Returns:
        The optimized AST
    """
    
    optimizer = Optimizer()
    optimized_ast = optimizer.optimize(ast)
    
    print(f"\nApplied {optimizer.optimizations_applied} optimizations")
   
    try:
        dot_after = visualize_ast(optimized_ast)
        dot_after.render('ast_after_opt', view=True, format='png')
    except Exception as e:
        print(f"Visualization error: {e}")
    
    return optimized_ast
        
if __name__ == "__main__":
    source_code = """
def calculate():
    a = 5 + 3 * 2
    b = 10 / 2
    c = a + b - (2 * 3)
    return c

result = calculate()
"""

    lexer = Lexer(source_code)
    parser = Parser(lexer)
    ast = parser.parse()
    analyzer = SemanticAnalyzer()

    try:
        analyzer.analyze(ast)
    except SemanticError as e:
        print(f"\nSemantic Error: {e}")

    print("\nBefore optimization:")
    print_ast(ast)
    
    optimized_ast = optimize_ast(ast)
    
    print("\nAfter optimization:")
    print_ast(optimized_ast)


Before optimization:
Program
  FunctionDef(calculate) [return_type=int]
    Assignment(target=a) [type=int]
      BinaryOp(+) [type=int, const=11]
        Literal(5) [type=int, const=5]
        BinaryOp(*) [type=int, const=6]
          Literal(3) [type=int, const=3]
          Literal(2) [type=int, const=2]
    Assignment(target=b) [type=int]
      BinaryOp(/) [type=int, const=5.0]
        Literal(10) [type=int, const=10]
        Literal(2) [type=int, const=2]
    Assignment(target=c) [type=int]
      BinaryOp(-) [type=int]
        BinaryOp(+) [type=int]
          Identifier(a) [type=int]
          Identifier(b) [type=int]
        BinaryOp(*) [type=int, const=6]
          Literal(2) [type=int, const=2]
          Literal(3) [type=int, const=3]
    Return [type=int]
      Identifier(c) [type=int]
  Assignment(target=result) [type=int]
    Call [type=int]
      Identifier(calculate)

Applied 4 optimizations

After optimization:
Program
  FunctionDef(calculate) [return_type=int]
    Assign

## Code Generation

In [None]:
#Note: let is used for all variable decalarations; Python's // becomes Math.floor(a / b) in JavaScript. and Converts range() calls to JavaScript array creation.

class JavaScriptCodeGenerator:
    def __init__(self):
        self.indent_level = 0
        self.indent_size = 4
        self.declared_variables = set()  # Track declared variables to avoid re-declaring
    
    def indent(self):
        """Get current indentation string"""
        return " " * (self.indent_level * self.indent_size)
    
    def increase_indent(self):
        """Increase indentation level"""
        self.indent_level += 1
    
    def decrease_indent(self):
        """Decrease indentation level"""
        self.indent_level = max(0, self.indent_level - 1)
    
    def generate(self, node):
        """Main entry point for code generation"""
        if node is None:
            return ""
        
        # Dispatch to appropriate method based on node type
        method_name = f"generate_{type(node).__name__}"
        if hasattr(self, method_name):
            return getattr(self, method_name)(node)
        else:
            raise NotImplementedError(f"Code generation not implemented for {type(node).__name__}")
    
    def generate_Program(self, node):
        """Generate code for the entire program"""
        lines = []
        for stmt in node.body:
            generated = self.generate(stmt)
            if generated.strip():  # Only add non-empty lines
                lines.append(generated)
        return "\n".join(lines)
    
    def generate_FunctionDef(self, node):
        """Generate JavaScript function definition"""
        params = ", ".join(node.params)
        lines = [f"{self.indent()}function {node.name}({params}) {{"]
        
        self.increase_indent()
        for stmt in node.body:
            generated = self.generate(stmt)
            if generated.strip():
                lines.append(generated)
        self.decrease_indent()
        
        lines.append(f"{self.indent()}}}")
        return "\n".join(lines)
    
    def generate_IfStatement(self, node):
        """Generate JavaScript if statement"""
        condition = self.generate(node.condition)
        lines = [f"{self.indent()}if ({condition}) {{"]
        
        # Generate then branch
        self.increase_indent()
        for stmt in node.then_branch:
            generated = self.generate(stmt)
            if generated.strip():
                lines.append(generated)
        self.decrease_indent()
        
        # Generate elif branches
        for elif_condition, elif_body in node.elif_branches:
            elif_cond = self.generate(elif_condition)
            lines.append(f"{self.indent()}}} else if ({elif_cond}) {{")
            
            self.increase_indent()
            for stmt in elif_body:
                generated = self.generate(stmt)
                if generated.strip():
                    lines.append(generated)
            self.decrease_indent()
        
        # Generate else branch
        if node.else_branch:
            lines.append(f"{self.indent()}}} else {{")
            
            self.increase_indent()
            for stmt in node.else_branch:
                generated = self.generate(stmt)
                if generated.strip():
                    lines.append(generated)
            self.decrease_indent()
        
        lines.append(f"{self.indent()}}}")
        return "\n".join(lines)
    
    def generate_WhileLoop(self, node):
        """Generate JavaScript while loop"""
        condition = self.generate(node.condition)
        lines = [f"{self.indent()}while ({condition}) {{"]
        
        self.increase_indent()
        for stmt in node.body:
            generated = self.generate(stmt)
            if generated.strip():
                lines.append(generated)
        self.decrease_indent()
        
        lines.append(f"{self.indent()}}}")
        return "\n".join(lines)
    
    def generate_ForLoop(self, node):
        """Generate JavaScript for loop (converts Python for to JS for...of)"""
        iterable = self.generate(node.iterable)
        
        # For loop variables are block-scoped, so we can always use let
        lines = [f"{self.indent()}for (let {node.var} of {iterable}) {{"]
        
        self.increase_indent()
        for stmt in node.body:
            generated = self.generate(stmt)
            if generated.strip():
                lines.append(generated)
        self.decrease_indent()
        
        lines.append(f"{self.indent()}}}")
        return "\n".join(lines)
    
    def generate_Return(self, node):
        """Generate JavaScript return statement"""
        if node.value:
            value = self.generate(node.value)
            return f"{self.indent()}return {value};"
        else:
            return f"{self.indent()}return;"
    
    def generate_Assignment(self, node):
        """Generate JavaScript variable assignment"""
        value = self.generate(node.value)
        
        # Check if variable was already declared to avoid re-declaration
        if node.target in self.declared_variables:
            # Just assignment, no declaration
            return f"{self.indent()}{node.target} = {value};"
        else:
            # First declaration - use let
            self.declared_variables.add(node.target)
            return f"{self.indent()}let {node.target} = {value};"
    
    def generate_ExpressionStatement(self, node):
        """Generate JavaScript expression statement"""
        expr = self.generate(node.expression)
        return f"{self.indent()}{expr};"
    
    def generate_BinaryOp(self, node):
        """Generate JavaScript binary operation"""
        left = self.generate(node.left)
        right = self.generate(node.right)
        
        # Map Python operators to JavaScript operators
        operator_map = {
            '+': '+',
            '-': '-',
            '*': '*',
            '/': '/',
            '//': 'Math.floor(/ )',  # Floor division needs special handling
            '%': '%',
            '**': '**',  # Exponentiation
            '==': '===',  # Use strict equality
            '!=': '!==',  # Use strict inequality
            '<': '<',
            '>': '>',
            '<=': '<=',
            '>=': '>=',
            'and': '&&',
            'or': '||',
            'in': 'in',
        }
        
        js_operator = operator_map.get(node.operator, node.operator)
        
        # Special handling for floor division
        if node.operator == '//':
            return f"Math.floor({left} / {right})"
        
        return f"({left} {js_operator} {right})"
    
    def generate_UnaryOp(self, node):
        """Generate JavaScript unary operation"""
        operand = self.generate(node.operand)
        
        # Map Python unary operators to JavaScript
        operator_map = {
            '+': '+',
            '-': '-',
            'not': '!',
        }
        
        js_operator = operator_map.get(node.operator, node.operator)
        return f"({js_operator}{operand})"
    
    def generate_Call(self, node):
        """Generate JavaScript function call"""
        func_name = self.generate(node.func)
        args = [self.generate(arg) for arg in node.args]
        
        # Handle built-in Python functions
        if isinstance(node.func, Identifier):
            if node.func.name == 'print':
                # Convert print() to console.log()
                return f"console.log({', '.join(args)})"
            elif node.func.name == 'len':
                # Convert len() to .length property
                if args:
                    return f"{args[0]}.length"
            elif node.func.name == 'range':
                # Convert range() to array creation
                if len(args) == 1:
                    return f"Array.from({{length: {args[0]}}}, (_, i) => i)"
                elif len(args) == 2:
                    return f"Array.from({{length: {args[1]} - {args[0]}}}, (_, i) => i + {args[0]})"
                elif len(args) == 3:
                    start, stop, step = args
                    return f"Array.from({{length: Math.ceil(({stop} - {start}) / {step})}}, (_, i) => {start} + i * {step})"
        
        # Regular function call
        return f"{func_name}({', '.join(args)})"
    
    def generate_Identifier(self, node):
        """Generate JavaScript identifier"""
        return node.name
    
    def generate_Literal(self, node):
        """Generate JavaScript literal"""
        if isinstance(node.value, str):
            # Escape quotes and return as string literal
            escaped = node.value.replace('\\', '\\\\').replace('"', '\\"')
            return f'"{escaped}"'
        elif isinstance(node.value, bool):
            # Convert Python True/False to JavaScript true/false
            return str(node.value).lower()
        elif node.value is None:
            # Convert Python None to JavaScript null
            return "null"
        else:
            # Numbers remain the same
            return str(node.value)


def generate_javascript_code(ast_node):
    """
    Main function to generate JavaScript code from an optimized AST
    
    Args:
        ast_node: The root AST node (usually a Program node)
    
    Returns:
        str: Generated JavaScript code
    """
    generator = JavaScriptCodeGenerator()
    return generator.generate(ast_node)


if __name__ == "__main__":
    test1_python = """
def calculate():
    a = 5 + 3 * 2
    b = 10 / 2
    c = a + b - (2 * 3)
    return c

result = calculate()
"""

    lexer = Lexer(test1_python)
    parser = Parser(lexer)
    ast = parser.parse()
    analyzer = SemanticAnalyzer()
    analyzer.analyze(ast)    
    
    print("Original Python Code:")
    print(test1_python)
    print("\n======================================================================================================================\n")
    generator = JavaScriptCodeGenerator()
    unoptimized_js_code = generator.generate(ast)
    print("Javascript Code (Without Optimizations):")
    print(unoptimized_js_code)
    print("\n======================================================================================================================\n")
    optimized_ast = optimize_ast(ast)
    generator = JavaScriptCodeGenerator()
    js_code = generator.generate(optimized_ast)
    print("Javascript Code After Optimizations:")
    print(js_code)

Original Python Code:

def calculate():
    a = 5 + 3 * 2
    b = 10 / 2
    c = a + b - (2 * 3)
    return c

result = calculate()



Javascript Code (Without Optimizations):
function calculate() {
    let a = (5 + (3 * 2));
    let b = (10 / 2);
    let c = ((a + b) - (2 * 3));
    return c;
}
let result = calculate();



Applied 4 optimizations

Javascript Code After Optimizations:
function calculate() {
    let a = 11;
    let b = 5.0;
    let c = ((a + b) - 6);
    return c;
}
let result = calculate();
