<a href="https://colab.research.google.com/github/ShashankShorya0211/MIMDPU/blob/main/MIMDPU1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import re
import traceback
from enum import Enum
import math

class TokenType(Enum):
    NUMBER = 'NUMBER'
    KEYWORD = 'KEYWORD'
    IDENTIFIER = 'IDENTIFIER'
    OPERATOR = 'OPERATOR'
    PUNCTUATOR = 'PUNCTUATOR'
    EOF = 'EOF'

class Token:
    def __init__(self, type, value, line, column):
        self.type = type
        self.value = value
        self.line = line
        self.column = column

    def __str__(self):
        return f'Token({self.type}, {self.value}, line={self.line}, col={self.column})'

class Lexer:
    def __init__(self, source_code):
        self.source_code = source_code
        self.position = 0
        self.line = 1
        self.column = 1
        self.current_char = self.source_code[self.position] if self.position < len(self.source_code) else None

    def advance(self):
        if self.current_char == '\n':
            self.line += 1
            self.column = 1
        else:
            self.column += 1
        self.position += 1
        self.current_char = self.source_code[self.position] if self.position < len(self.source_code) else None

    def peek(self):
        peek_pos = self.position + 1
        return self.source_code[peek_pos] if peek_pos < len(self.source_code) else None

    def skip_whitespace(self):
        while self.current_char is not None and self.current_char.isspace():
            self.advance()

    def get_number(self):
        result = ''
        while self.current_char is not None and self.current_char.isdigit():
            result += self.current_char
            self.advance()
        if self.current_char == '.':
            result += self.current_char
            self.advance()
            while self.current_char is not None and self.current_char.isdigit():
                result += self.current_char
                self.advance()
        return Token(TokenType.NUMBER, result, self.line, self.column)

    def get_identifier(self):
        result = ''
        while self.current_char is not None and (self.current_char.isalnum() or self.current_char == '_'):
            result += self.current_char
            self.advance()
        if result in ['int', 'float', 'if', 'else', 'while', 'for', 'return', 'void']:
            return Token(TokenType.KEYWORD, result, self.line, self.column)
        return Token(TokenType.IDENTIFIER, result, self.line, self.column)

    def get_operator(self):
        operators = {
            '+': 'PLUS', '-': 'MINUS', '*': 'MULTIPLY', '/': 'DIVIDE',
            '=': 'ASSIGN', '==': 'EQUALS', '!=': 'NOT_EQUALS',
            '<': 'LESS_THAN', '>': 'GREATER_THAN', '<=': 'LESS_EQUAL', '>=': 'GREATER_EQUAL',
            '&&': 'AND', '||': 'OR', '!': 'NOT',
            '&': 'BITWISE_AND', '|': 'BITWISE_OR', '^': 'BITWISE_XOR', '~': 'BITWISE_NOT'
        }
        op = self.current_char
        self.advance()
        if op + self.current_char in operators:
            op += self.current_char
            self.advance()
        return Token(TokenType.OPERATOR, operators.get(op, op), self.line, self.column)

    def get_punctuator(self):
        punctuators = {'{': 'LBRACE', '}': 'RBRACE', '(': 'LPAREN', ')': 'RPAREN', ';': 'SEMI', ',': 'COMMA'}
        punct = self.current_char
        self.advance()
        return Token(TokenType.PUNCTUATOR, punctuators.get(punct, punct), self.line, self.column)

    def tokenize(self):
        tokens = []
        while self.current_char is not None:
            if self.current_char.isspace():
                self.skip_whitespace()
                continue
            if self.current_char.isdigit():
                tokens.append(self.get_number())
            elif self.current_char.isalpha() or self.current_char == '_':
                tokens.append(self.get_identifier())
            elif self.current_char in '+-*/=!<>&|^~':
                tokens.append(self.get_operator())
            elif self.current_char in '{}();,':
                tokens.append(self.get_punctuator())
            else:
                raise Exception(f'Invalid character: {self.current_char} at line {self.line}, column {self.column}')
        tokens.append(Token(TokenType.EOF, None, self.line, self.column))
        return tokens

class ASTNode:
    def __init__(self, type, children=None, value=None):
        self.type = type
        self.children = children if children else []
        self.value = value

    def __str__(self):
        return f"{self.type}: {self.value if self.value is not None else ''}"

class Parser:
    def __init__(self, tokens):
        self.tokens = tokens
        self.current_token = None
        self.token_index = -1
        self.advance()

    def advance(self):
        self.token_index += 1
        if self.token_index < len(self.tokens):
            self.current_token = self.tokens[self.token_index]
        else:
            self.current_token = Token(TokenType.EOF, None, -1, -1)

    def eat(self, token_type):
        if self.current_token.type == token_type:
            token = self.current_token
            self.advance()
            return token
        else:
            raise Exception(f'Expected {token_type}, got {self.current_token.type} at line {self.current_token.line}, column {self.current_token.column}')

    def parse(self):
        return self.program()


    def program(self):
        node = ASTNode('Program')
        while self.current_token.type != TokenType.EOF:
            node.children.append(self.function_definition())
        return node

    def function_definition(self):
        return_type = self.eat(TokenType.KEYWORD).value
        name = self.eat(TokenType.IDENTIFIER).value
        self.eat(TokenType.PUNCTUATOR)  # (
        params = self.parameter_list()
        self.eat(TokenType.PUNCTUATOR)  # )
        body = self.compound_statement()
        return ASTNode('FunctionDefinition', [ASTNode('ReturnType', value=return_type), ASTNode('FunctionName', value=name), params, body])

    def parameter_list(self):
        params = []
        while self.current_token.type != TokenType.PUNCTUATOR or self.current_token.value != 'RPAREN':
            param_type = self.eat(TokenType.KEYWORD).value
            param_name = self.eat(TokenType.IDENTIFIER).value
            params.append(ASTNode('Parameter', [ASTNode('Type', value=param_type), ASTNode('Name', value=param_name)]))
            if self.current_token.type == TokenType.PUNCTUATOR and self.current_token.value == 'COMMA':
                self.advance()
        return ASTNode('ParameterList', params)

    def compound_statement(self):
        self.eat(TokenType.PUNCTUATOR)  # {
        statements = []
        while self.current_token.type != TokenType.PUNCTUATOR or self.current_token.value != 'RBRACE':
            statements.append(self.statement())
        self.eat(TokenType.PUNCTUATOR)  # }
        return ASTNode('CompoundStatement', statements)


    def statement(self):
        if self.current_token.type == TokenType.KEYWORD:
            if self.current_token.value in ['int', 'float']:
                return self.variable_declaration()
            elif self.current_token.value == 'if':
                return self.if_statement()
            elif self.current_token.value == 'while':
                return self.while_statement()
            elif self.current_token.value == 'for':
                return self.for_statement()
            elif self.current_token.value == 'return':
                return self.return_statement()
        elif self.current_token.type == TokenType.IDENTIFIER:
            return self.assignment_statement()
        else:
            return self.expression_statement()

    def variable_declaration(self):
        var_type = self.eat(TokenType.KEYWORD).value
        var_name = self.eat(TokenType.IDENTIFIER).value
        if self.current_token.type == TokenType.OPERATOR and self.current_token.value == 'ASSIGN':
            self.eat(TokenType.OPERATOR)  # Consume the '='
            init_expr = self.expression()
            self.eat(TokenType.PUNCTUATOR)  # Consume the ';'
            return ASTNode('VariableDeclaration', [
                ASTNode('Type', value=var_type),
                ASTNode('Name', value=var_name),
                init_expr
            ])
        else:
            self.eat(TokenType.PUNCTUATOR)  # Consume the ';'
            return ASTNode('VariableDeclaration', [
                ASTNode('Type', value=var_type),
                ASTNode('Name', value=var_name)
            ])

    def if_statement(self):
        self.eat(TokenType.KEYWORD)  # if
        self.eat(TokenType.PUNCTUATOR)  # (
        condition = self.expression()
        self.eat(TokenType.PUNCTUATOR)  # )
        if_body = self.compound_statement()
        else_body = None
        if self.current_token.type == TokenType.KEYWORD and self.current_token.value == 'else':
            self.advance()
            else_body = self.compound_statement()
        return ASTNode('IfStatement', [condition, if_body, else_body] if else_body else [condition, if_body])

    def while_statement(self):
        self.eat(TokenType.KEYWORD)  # while
        self.eat(TokenType.PUNCTUATOR)  # (
        condition = self.expression()
        self.eat(TokenType.PUNCTUATOR)  # )
        body = self.compound_statement()
        return ASTNode('WhileStatement', [condition, body])

    def for_statement(self):
        self.eat(TokenType.KEYWORD)  # for
        self.eat(TokenType.PUNCTUATOR)  # (
        init = self.statement()
        condition = self.expression()
        self.eat(TokenType.PUNCTUATOR)  # ;
        update = self.expression()
        self.eat(TokenType.PUNCTUATOR)  # )
        body = self.compound_statement()
        return ASTNode('ForStatement', [init, condition, update, body])

    def return_statement(self):
        self.eat(TokenType.KEYWORD)  # return
        expr = self.expression()
        self.eat(TokenType.PUNCTUATOR)  # ;
        return ASTNode('ReturnStatement', [expr])

    def assignment_statement(self):
        var_name = self.eat(TokenType.IDENTIFIER).value
        self.eat(TokenType.OPERATOR)  # =
        expr = self.expression()
        self.eat(TokenType.PUNCTUATOR)  # ;
        return ASTNode('Assignment', [ASTNode('Variable', value=var_name), expr])

    def expression_statement(self):
        expr = self.expression()
        self.eat(TokenType.PUNCTUATOR)  # ;
        return ASTNode('ExpressionStatement', [expr])

    def expression(self):
        return self.logical_or()

    def logical_or(self):
        node = self.logical_and()
        while self.current_token.type == TokenType.OPERATOR and self.current_token.value == 'OR':
            op = self.eat(TokenType.OPERATOR).value
            node = ASTNode('BinaryOp', [node, self.logical_and()], value=op)
        return node

    def logical_and(self):
        node = self.equality()
        while self.current_token.type == TokenType.OPERATOR and self.current_token.value == 'AND':
            op = self.eat(TokenType.OPERATOR).value
            node = ASTNode('BinaryOp', [node, self.equality()], value=op)
        return node

    def equality(self):
        node = self.relational()
        while self.current_token.type == TokenType.OPERATOR and self.current_token.value in ['EQUALS', 'NOT_EQUALS']:
            op = self.eat(TokenType.OPERATOR).value
            node = ASTNode('BinaryOp', [node, self.relational()], value=op)
        return node

    def relational(self):
        node = self.bitwise()
        while self.current_token.type == TokenType.OPERATOR and self.current_token.value in ['LESS_THAN', 'GREATER_THAN', 'LESS_EQUAL', 'GREATER_EQUAL']:
            op = self.eat(TokenType.OPERATOR).value
            node = ASTNode('BinaryOp', [node, self.bitwise()], value=op)
        return node

    def bitwise(self):
        node = self.additive()
        while self.current_token.type == TokenType.OPERATOR and self.current_token.value in ['BITWISE_AND', 'BITWISE_OR', 'BITWISE_XOR']:
            op = self.eat(TokenType.OPERATOR).value
            node = ASTNode('BinaryOp', [node, self.additive()], value=op)
        return node

    def additive(self):
        node = self.multiplicative()
        while self.current_token.type == TokenType.OPERATOR and self.current_token.value in ['PLUS', 'MINUS']:
            op = self.eat(TokenType.OPERATOR).value
            node = ASTNode('BinaryOp', [node, self.multiplicative()], value=op)
        return node

    def multiplicative(self):
        node = self.unary()
        while self.current_token.type == TokenType.OPERATOR and self.current_token.value in ['MULTIPLY', 'DIVIDE']:
            op = self.eat(TokenType.OPERATOR).value
            node = ASTNode('BinaryOp', [node, self.unary()], value=op)
        return node

    def unary(self):
        if self.current_token.type == TokenType.OPERATOR and self.current_token.value in ['PLUS', 'MINUS', 'NOT', 'BITWISE_NOT']:
            op = self.eat(TokenType.OPERATOR).value
            return ASTNode('UnaryOp', [self.unary()], value=op)
        return self.primary()

    def primary(self):
        token = self.current_token
        if token.type == TokenType.NUMBER:
            self.advance()
            return ASTNode('Literal', value=token.value)
        elif token.type == TokenType.IDENTIFIER:
            self.advance()
            if self.current_token.type == TokenType.PUNCTUATOR and self.current_token.value == 'LPAREN':
                return self.function_call()
            return ASTNode('Variable', value=token.value)
        elif token.type == TokenType.PUNCTUATOR and token.value == 'LPAREN':
            self.advance()
            expr = self.expression()
            self.eat(TokenType.PUNCTUATOR)  # )
            return expr
        else:
            raise Exception(f'Unexpected token: {token.type} at line {token.line}, column {token.column}')

    def function_call(self):
        function_name = self.current_token.value
        self.eat(TokenType.PUNCTUATOR)  # (
        args = []
        if self.current_token.type != TokenType.PUNCTUATOR or self.current_token.value != 'RPAREN':
            args.append(self.expression())
            while self.current_token.type == TokenType.PUNCTUATOR and self.current_token.value == 'COMMA':
                self.advance()
                args.append(self.expression())
        self.eat(TokenType.PUNCTUATOR)  # )
        return ASTNode('FunctionCall', [ASTNode('FunctionName', value=function_name)] + args)


    def compound_statement(self):
        self.eat(TokenType.PUNCTUATOR)  # {
        statements = []
        while self.current_token.type != TokenType.PUNCTUATOR or self.current_token.value != 'RBRACE':
            statements.append(self.statement())
        self.eat(TokenType.PUNCTUATOR)  # }
        return ASTNode('CompoundStatement', statements)


class Symbol:
    def __init__(self, name, type):
        self.name = name
        self.type = type

class SymbolTable:
    def __init__(self):
        self.symbols = {}
        self.parent = None

    def define(self, name, type):
        self.symbols[name] = Symbol(name, type)

    def lookup(self, name):
        if name in self.symbols:
            return self.symbols[name]
        elif self.parent:
            return self.parent.lookup(name)
        return None

class SemanticAnalyzer:
    def __init__(self, ast):
        self.ast = ast
        self.current_scope = SymbolTable()
        self.function_table = {}
        self.current_function = None

    def analyze(self):
        self.visit(self.ast)

    def visit(self, node):
        method_name = f'visit_{node.type}'
        method = getattr(self, method_name, self.generic_visit)
        return method(node)

    def generic_visit(self, node):
        for child in node.children:
            self.visit(child)

    def visit_Program(self, node):
        for child in node.children:
            self.visit(child)

    def visit_FunctionDefinition(self, node):
        return_type = node.children[0].value
        function_name = node.children[1].value
        params = node.children[2]
        body = node.children[3]

        if function_name in self.function_table:
            raise Exception(f"Function {function_name} already defined")

        self.function_table[function_name] = {
            'return_type': return_type,
            'params': self.visit(params)
        }

        # Set the current function
        self.current_function = function_name

        # Create a new scope for the function
        function_scope = SymbolTable()
        function_scope.parent = self.current_scope
        self.current_scope = function_scope

        # Add parameters to the function scope
        for param in self.function_table[function_name]['params']:
            self.current_scope.define(param['name'], param['type'])

        self.visit(body)

        # Restore the previous scope
        self.current_scope = self.current_scope.parent

        # Clear the current function
        self.current_function = None

    def visit_ParameterList(self, node):
        params = []
        for param in node.children:
            param_type = param.children[0].value
            param_name = param.children[1].value
            params.append({'name': param_name, 'type': param_type})
        return params

    def visit_CompoundStatement(self, node):
        for child in node.children:
            self.visit(child)

    def visit_CompoundStatement(self, node):
        for child in node.children:
            self.visit(child)

    def visit_VariableDeclaration(self, node):
        var_type = node.children[0].value
        var_name = node.children[1].value
        if self.current_scope.lookup(var_name):
            raise Exception(f"Variable {var_name} already declared in this scope")
        self.current_scope.define(var_name, var_type)

        if len(node.children) > 2:
            init_expr_type = self.visit(node.children[2])
            if init_expr_type != var_type:
                raise Exception(f"Type mismatch in initialization of {var_name}: expected {var_type}, got {init_expr_type}")

    def visit_Assignment(self, node):
        var_name = node.children[0].value
        expr_type = self.visit(node.children[1])
        var_symbol = self.current_scope.lookup(var_name)
        if not var_symbol:
            raise Exception(f"Variable {var_name} not declared")
        if var_symbol.type != expr_type:
            raise Exception(f"Type mismatch in assignment to {var_name}: expected {var_symbol.type}, got {expr_type}")

    def visit_IfStatement(self, node):
        condition_type = self.visit(node.children[0])
        if condition_type != 'int':  # In C, conditions are typically integers
            raise Exception("Condition in if statement must be of type int")
        self.visit(node.children[1])  # if body
        if len(node.children) > 2:
            self.visit(node.children[2])  # else body

    def visit_WhileStatement(self, node):
        condition_type = self.visit(node.children[0])
        if condition_type != 'int':  # In C, conditions are typically integers
            raise Exception("Condition in while statement must be of type int")
        self.visit(node.children[1])  # body

    def visit_ForStatement(self, node):
        self.visit(node.children[0])  # initialization
        condition_type = self.visit(node.children[1])
        if condition_type != 'int':  # In C, conditions are typically integers
            raise Exception("Condition in for statement must be of type int")
        self.visit(node.children[2])  # update
        self.visit(node.children[3])  # body

    def visit_ReturnStatement(self, node):
        expr_type = self.visit(node.children[0])
        if self.current_function is None:
            raise Exception("Return statement outside of function")
        if expr_type != self.function_table[self.current_function]['return_type']:
            raise Exception(f"Return type mismatch: expected {self.function_table[self.current_function]['return_type']}, got {expr_type}")

    def visit_BinaryOp(self, node):
        left_type = self.visit(node.children[0])
        right_type = self.visit(node.children[1])
        if left_type != right_type:
            raise Exception(f"Type mismatch in binary operation: {left_type} {node.value} {right_type}")
        if node.value in ['PLUS', 'MINUS', 'MULTIPLY', 'DIVIDE', 'BITWISE_AND', 'BITWISE_OR', 'BITWISE_XOR']:
            return left_type
        elif node.value in ['LESS_THAN', 'GREATER_THAN', 'LESS_EQUAL', 'GREATER_EQUAL', 'EQUALS', 'NOT_EQUALS', 'AND', 'OR']:
            return 'int'  # In C, these operations return an integer
        else:
            raise Exception(f"Unknown binary operator: {node.value}")

    def visit_UnaryOp(self, node):
        operand_type = self.visit(node.children[0])
        if node.value in ['PLUS', 'MINUS', 'BITWISE_NOT']:
            return operand_type
        elif node.value == 'NOT':
            return 'int'  # In C, logical NOT returns an integer
        else:
            raise Exception(f"Unknown unary operator: {node.value}")

    def visit_Literal(self, node):
        if '.' in node.value:
            return 'float'
        else:
            return 'int'

    def visit_Variable(self, node):
        var_symbol = self.current_scope.lookup(node.value)
        if not var_symbol:
            raise Exception(f"Variable {node.value} not declared")
        return var_symbol.type

    def visit_FunctionCall(self, node):
        function_name = node.children[0].value
        if function_name not in self.function_table:
            raise Exception(f"Function {function_name} not defined")
        expected_params = self.function_table[function_name]['params']
        actual_params = node.children[1:]
        if len(expected_params) != len(actual_params):
            raise Exception(f"Function {function_name} called with wrong number of arguments")
        for i, (expected, actual) in enumerate(zip(expected_params, actual_params)):
            actual_type = self.visit(actual)
            if actual_type != expected['type']:
                raise Exception(f"Type mismatch in argument {i+1} of function {function_name}: expected {expected['type']}, got {actual_type}")
        return self.function_table[function_name]['return_type']

    def visit_ExpressionStatement(self, node):
        self.visit(node.children[0])


class IRInstruction:
    def __init__(self, op, args):
        self.op = op
        self.args = args

    def __str__(self):
        return f"{self.op} {' '.join(map(str, self.args))}"

class IRGenerator:
    def __init__(self, ast):
        self.ast = ast
        self.instructions = []
        self.temp_counter = 0
        self.label_counter = 0
        self.current_function = None

    def generate(self):
        self.visit(self.ast)
        return self.instructions

    def emit(self, op, args):
        self.instructions.append(IRInstruction(op, args))

    def visit(self, node):
        method_name = f'visit_{node.type}'
        method = getattr(self, method_name, self.generic_visit)
        return method(node)

    def generic_visit(self, node):
        raise Exception(f'No visit_{node.type} method')

    def visit_Program(self, node):
        for child in node.children:
            self.visit(child)

    def visit_FunctionDefinition(self, node):
        function_name = node.children[1].value
        self.current_function = function_name
        self.emit('FUNCTION_BEGIN', [function_name])
        self.visit(node.children[3])  # Visit function body
        self.emit('FUNCTION_END', [function_name])
        self.current_function = None

    def visit_CompoundStatement(self, node):
        for child in node.children:
            self.visit(child)

    def visit_VariableDeclaration(self, node):
        var_name = node.children[1].value
        self.emit('ALLOC', [var_name])
        if len(node.children) > 2:
            init_value = self.visit(node.children[2])
            self.emit('STORE', [var_name, init_value])

    def visit_Assignment(self, node):
        var_name = node.children[0].value
        value = self.visit(node.children[1])
        self.emit('STORE', [var_name, value])

    def visit_BinaryOp(self, node):
        left = self.visit(node.children[0])
        right = self.visit(node.children[1])
        result = self.new_temp()
        operation = {
            'PLUS': 'ADD',
            'MINUS': 'SUB',
            'MULTIPLY': 'MUL',
            'DIVIDE': 'DIV',
            'BITWISE_AND': 'AND',
            'BITWISE_OR': 'OR',
            'BITWISE_XOR': 'XOR',
            'LESS_THAN': 'LT',
            'GREATER_THAN': 'GT',
            'LESS_EQUAL': 'LE',
            'GREATER_EQUAL': 'GE',
            'EQUALS': 'EQ',
            'NOT_EQUALS': 'NE',
            'AND': 'LAND',
            'OR': 'LOR'
        }.get(node.value, node.value)
        self.emit(operation, [result, left, right])
        return result

    def visit_UnaryOp(self, node):
        operand = self.visit(node.children[0])
        result = self.new_temp()
        operation = {
            'PLUS': 'POS',
            'MINUS': 'NEG',
            'NOT': 'NOT',
            'BITWISE_NOT': 'BNOT'
        }.get(node.value, node.value)
        self.emit(operation, [result, operand])
        return result

    def visit_Literal(self, node):
        result = self.new_temp()
        self.emit('LOAD_CONST', [result, node.value])
        return result

    def visit_Variable(self, node):
        result = self.new_temp()
        self.emit('LOAD', [result, node.value])
        return result

    def visit_ReturnStatement(self, node):
        value = self.visit(node.children[0])
        self.emit('RETURN', [value])

    def new_temp(self):
        self.temp_counter += 1
        return f't{self.temp_counter}'

    def new_label(self):
        self.label_counter += 1
        return f'L{self.label_counter}'


class IROptimizer:
    def __init__(self, ir):
        self.ir = ir

    def optimize(self):
        self.constant_folding()
        self.dead_code_elimination()
        self.common_subexpression_elimination()
        return self.ir

    def constant_folding(self):
        constants = {}
        new_ir = []
        for inst in self.ir:
            if inst.op == 'LOAD_CONST':
                constants[inst.args[0]] = inst.args[1]
            elif inst.op in ['ADD', 'SUB', 'MUL', 'DIV', 'AND', 'OR', 'XOR']:
                result, op1, op2 = inst.args
                if op1 in constants and op2 in constants:
                    value = self.compute_constant(inst.op, constants[op1], constants[op2])
                    new_ir.append(IRInstruction('LOAD_CONST', [result, value]))
                    constants[result] = value
                else:
                    new_ir.append(inst)
                    constants.pop(result, None)
            else:
                new_ir.append(inst)
                if inst.op != 'STORE':
                    for arg in inst.args:
                        constants.pop(arg, None)
        self.ir = new_ir

    def compute_constant(self, op, val1, val2):
        if op == 'ADD':
            return val1 + val2
        elif op == 'SUB':
            return val1 - val2
        elif op == 'MUL':
            return val1 * val2
        elif op == 'DIV':
            return val1 // val2 if isinstance(val1, int) and isinstance(val2, int) else val1 / val2
        elif op == 'AND':
            return val1 & val2
        elif op == 'OR':
            return val1 | val2
        elif op == 'XOR':
            return val1 ^ val2

    def dead_code_elimination(self):
        used_vars = set()
        new_ir = []
        for inst in reversed(self.ir):
            if inst.op in ['STORE', 'RETURN'] or inst.op.startswith('JUMP'):
                new_ir.insert(0, inst)
                used_vars.update(inst.args)
            elif any(arg in used_vars for arg in inst.args):
                new_ir.insert(0, inst)
                used_vars.update(inst.args[1:])
                used_vars.add(inst.args[0])
        self.ir = new_ir

    def common_subexpression_elimination(self):
        expr_to_var = {}
        new_ir = []
        for inst in self.ir:
            if inst.op in ['ADD', 'SUB', 'MUL', 'DIV', 'AND', 'OR', 'XOR']:
                expr = (inst.op, inst.args[1], inst.args[2])
                if expr in expr_to_var:
                    new_ir.append(IRInstruction('MOVE', [inst.args[0], expr_to_var[expr]]))
                else:
                    new_ir.append(inst)
                    expr_to_var[expr] = inst.args[0]
            else:
                new_ir.append(inst)
                if inst.op == 'STORE':
                    expr_to_var = {k: v for k, v in expr_to_var.items() if v != inst.args[0]}
        self.ir = new_ir

class MIMDPUInstruction:
    def __init__(self, op, args):
        self.op = op
        self.args = args

    def __str__(self):
        return f"{self.op} {', '.join(map(str, self.args))}"

# New instruction types
CONFIGURE_TREE = 'CONFIGURE_TREE'
SET_WEIGHTS = 'SET_WEIGHTS'
SET_ACTIVATION = 'SET_ACTIVATION'
PROCESS = 'PROCESS'
GET_OUTPUT = 'GET_OUTPUT'

class MIMDPUCodeGenerator:
    def __init__(self, ir):
        self.ir = ir
        self.code = []
        self.variable_map = {}
        self.next_variable_id = 0
        self.num_inputs = self.count_inputs(ir)
        self.tree_depth = math.ceil(math.log2(self.num_inputs))

    def count_inputs(self, ir):
        input_vars = set()
        for inst in ir:
            if inst.op in ['LOAD', 'LOAD_CONST']:
                input_vars.add(inst.args[1])
        return len(input_vars)

    def generate(self):
        self.configure_processing_unit()
        for inst in self.ir:
            self.generate_instruction(inst)
        return self.code

    def configure_processing_unit(self):
        self.code.append(MIMDPUInstruction(CONFIGURE_TREE, [self.tree_depth]))
        weights = [1.0] * (2**self.tree_depth)
        self.code.append(MIMDPUInstruction(SET_WEIGHTS, weights))
        self.code.append(MIMDPUInstruction(SET_ACTIVATION, ['RELU']))

    def generate_instruction(self, inst):
        if inst.op == 'ALLOC':
            self.generate_alloc(inst)
        elif inst.op in ['ADD', 'SUB', 'MUL', 'DIV', 'AND', 'OR', 'XOR']:
            self.generate_binary_op(inst)
        elif inst.op in ['NEG', 'NOT']:
            self.generate_unary_op(inst)
        elif inst.op == 'LOAD_CONST':
            self.generate_load_const(inst)
        elif inst.op == 'LOAD':
            self.generate_load(inst)
        elif inst.op == 'STORE':
            self.generate_store(inst)
        elif inst.op.startswith('JUMP'):
            self.generate_jump(inst)
        elif inst.op == 'CALL':
            self.generate_function_call(inst)
        elif inst.op == 'RETURN':
            self.generate_return(inst)
        else:
            raise Exception(f"Unsupported operation: {inst.op}")

    def generate_alloc(self, inst):
        var_name = inst.args[0]
        var_id = self.get_variable_id(var_name)
        self.code.append(MIMDPUInstruction('ALLOC', [var_id]))


    def generate_binary_op(self, inst):
        result, op1, op2 = inst.args
        op1_id = self.get_variable_id(op1)
        op2_id = self.get_variable_id(op2)
        result_id = self.get_variable_id(result)

        # Set weights for this operation
        weights = [1.0, 1.0] + [0.0] * (2**self.tree_depth - 2)
        self.code.append(MIMDPUInstruction(SET_WEIGHTS, weights))

        # Set the operator in the first layer
        self.code.append(MIMDPUInstruction(CONFIGURE_TREE, [self.tree_depth, 0, inst.op]))

        # Process
        self.code.append(MIMDPUInstruction(PROCESS, [op1_id, op2_id]))

        # Get result
        self.code.append(MIMDPUInstruction(GET_OUTPUT, [result_id]))


    def generate_unary_op(self, inst):
        result, operand = inst.args
        operand_id = self.get_variable_id(operand)
        result_id = self.get_variable_id(result)
        self.code.append(MIMDPUInstruction('SET_INPUT', [operand_id]))
        self.code.append(MIMDPUInstruction('SET_OPERATOR', [inst.op]))
        self.code.append(MIMDPUInstruction('COMPUTE', []))
        self.code.append(MIMDPUInstruction('GET_RESULT', [result_id]))

    def generate_load_const(self, inst):
        result, value = inst.args
        result_id = self.get_variable_id(result)
        self.code.append(MIMDPUInstruction('LOAD_CONST', [result_id, value]))

    def generate_load(self, inst):
        result, var_name = inst.args
        var_id = self.get_variable_id(var_name)
        result_id = self.get_variable_id(result)
        self.code.append(MIMDPUInstruction('LOAD', [result_id, var_id]))

    def generate_store(self, inst):
        var_name, value = inst.args
        var_id = self.get_variable_id(var_name)
        value_id = self.get_variable_id(value)
        self.code.append(MIMDPUInstruction('STORE', [var_id, value_id]))

    def generate_jump(self, inst):
        if inst.op == 'JUMP':
            self.code.append(MIMDPUInstruction('JUMP', inst.args))
        elif inst.op == 'JUMP_IF_FALSE':
            condition, label = inst.args
            condition_id = self.get_variable_id(condition)
            self.code.append(MIMDPUInstruction('JUMP_IF_FALSE', [condition_id, label]))

    def generate_function_call(self, inst):
        result, func_name, *args = inst.args
        arg_ids = [self.get_variable_id(arg) for arg in args]
        result_id = self.get_variable_id(result)
        self.code.append(MIMDPUInstruction('CALL', [func_name, result_id] + arg_ids))

    def generate_return(self, inst):
        value = inst.args[0]
        value_id = self.get_variable_id(value)
        self.code.append(MIMDPUInstruction('RETURN', [value_id]))

    def generate_alloc(self, inst):
        var_name = inst.args[0]
        var_id = self.get_variable_id(var_name)
        self.code.append(MIMDPUInstruction('ALLOC', [var_id]))

    def get_variable_id(self, var_name):
        if var_name not in self.variable_map:
            self.variable_map[var_name] = f'v{self.next_variable_id}'
            self.next_variable_id += 1
        return self.variable_map[var_name]

class BitSerialProcessor:
    def __init__(self, bit_width=32):
        self.bit_width = bit_width

    def process(self, a, b, op):
        result = 0
        carry = 0
        for i in range(self.bit_width):
            bit_a = (a >> i) & 1
            bit_b = (b >> i) & 1
            if op == 'ADD':
                sum_bit = bit_a ^ bit_b ^ carry
                carry = (bit_a & bit_b) | (bit_a & carry) | (bit_b & carry)
                result |= sum_bit << i
            elif op == 'MUL':
                partial_product = bit_a & bit_b
                result |= partial_product << i
            elif op == 'AND':
                result |= (bit_a & bit_b) << i
            elif op == 'OR':
                result |= (bit_a | bit_b) << i
            elif op == 'XOR':
                result |= (bit_a ^ bit_b) << i
            # Add more operations as needed
        return result

    def apply_activation(self, x, activation_type):
        if activation_type == 'RELU':
            return max(0, x)
        # Add more activation functions as needed
        return x

class CompilationError(Exception):
    pass

class Compiler:
    def __init__(self, source_code):
        self.source_code = source_code
        self.lexer = Lexer(source_code)
        self.parser = None
        self.semantic_analyzer = None
        self.ir_generator = None
        self.ir_optimizer = None
        self.code_generator = None
        self.bit_serial_processor = BitSerialProcessor()

    def compile(self):
        try:
            print("Starting compilation process...")

            # Lexical Analysis
            print("Performing lexical analysis...")
            tokens = self.lexer.tokenize()
            print("Tokens generated:")
            for token in tokens:
                print(f"  {token}")

            # Syntax Analysis
            print("\nPerforming syntax analysis...")
            self.parser = Parser(tokens)
            ast = self.parser.parse()
            print("Abstract Syntax Tree generated:")
            self.print_ast(ast)

            # Semantic Analysis
            print("\nPerforming semantic analysis...")
            self.semantic_analyzer = SemanticAnalyzer(ast)
            self.semantic_analyzer.analyze()
            print("Semantic analysis completed successfully.")

            # IR Generation
            print("\nGenerating Intermediate Representation...")
            self.ir_generator = IRGenerator(ast)
            ir = self.ir_generator.generate()
            print("Intermediate Representation generated:")
            for inst in ir:
                print(f"  {inst}")

            # IR Optimization
            print("\nOptimizing Intermediate Representation...")
            self.ir_optimizer = IROptimizer(ir)
            optimized_ir = self.ir_optimizer.optimize()
            print("Optimized Intermediate Representation:")
            for inst in optimized_ir:
                print(f"  {inst}")

            # MIMD-PU Code Generation
            print("\nGenerating MIMD-PU code...")
            self.code_generator = MIMDPUCodeGenerator(optimized_ir)
            mimdpu_code = self.code_generator.generate()
            print("MIMD-PU code generated:")
            for inst in mimdpu_code:
                print(f"  {inst}")

            print("\nCompilation completed successfully.")
            return mimdpu_code

        except Exception as e:
            raise Exception(f"Compilation error: {str(e)}\n{traceback.format_exc()}")


    def print_ast(self, node, level=0):
        print('  ' * level + f"{node.type}: {node.value if hasattr(node, 'value') else ''}")
        for child in node.children:
            self.print_ast(child, level + 1)

    def print_mimdpu_code(self, mimdpu_code):
        print("\nGenerated MIMD-PU code:")
        for inst in mimdpu_code:
            print(inst)

# Example usage
source_code = """
int main() {
    int a1 = 1;
    int a2 = 2;
    int a3 = 3;
    int a4 = 4;

    int b1 = 8;
    int b2 = 7;
    int b3 = 6;
    int b4 = 5;

    int r1 = a1 + b1;
    int r2 = a2 + b2;
    int r3 = a3 + b3;
    int r4 = a4 + b4;

    int result = r1 + r2 + r3 + r4;

    return result;
}
"""


compiler = Compiler(source_code)
try:
    mimdpu_code = compiler.compile()
    compiler.print_mimdpu_code(mimdpu_code)
except CompilationError as e:
    print(e)

Starting compilation process...
Performing lexical analysis...
Tokens generated:
  Token(TokenType.KEYWORD, int, line=2, col=4)
  Token(TokenType.IDENTIFIER, main, line=2, col=9)
  Token(TokenType.PUNCTUATOR, LPAREN, line=2, col=10)
  Token(TokenType.PUNCTUATOR, RPAREN, line=2, col=11)
  Token(TokenType.PUNCTUATOR, LBRACE, line=2, col=13)
  Token(TokenType.KEYWORD, int, line=3, col=8)
  Token(TokenType.IDENTIFIER, a1, line=3, col=11)
  Token(TokenType.OPERATOR, ASSIGN, line=3, col=13)
  Token(TokenType.NUMBER, 1, line=3, col=15)
  Token(TokenType.PUNCTUATOR, SEMI, line=3, col=16)
  Token(TokenType.KEYWORD, int, line=4, col=8)
  Token(TokenType.IDENTIFIER, a2, line=4, col=11)
  Token(TokenType.OPERATOR, ASSIGN, line=4, col=13)
  Token(TokenType.NUMBER, 2, line=4, col=15)
  Token(TokenType.PUNCTUATOR, SEMI, line=4, col=16)
  Token(TokenType.KEYWORD, int, line=5, col=8)
  Token(TokenType.IDENTIFIER, a3, line=5, col=11)
  Token(TokenType.OPERATOR, ASSIGN, line=5, col=13)
  Token(TokenType.