# Lab 6: Trusted Execution Environments
In this lab we will work with a simulated TEE, experimenting with the functionality it provides and learning how to use it.

## Components
- Enums - CPUMode, Opcode, ENCLSLeafs, ENCLULeafs, Reg, ArgType: These are used to pass data and define constants
- Assembler: static class which encodes (and decodes) instructions for the processor
- Memory: the CPU's memory; it enforces access privileges and contains the memory encryption engine (MEE)
- Processor: CPU, executes instructions which are given to it
- SecureBankApp: An example application which will be executed inside the enclave, performing simple banking operations like deposits and transfers for two users
- (Compromised)OS: interacts with the CPU to set up the enclave and execute the banking app inside the enclave

## Tasks of the lab
In this lab you will implement some functionality of the TEE which is currently missing, leaving it vulnerable. To highlight the flaws, the CompromisedOS class tries to execute 4 different attacks on the banking app running in the enclave. You should add the required functionality to thwart all 4 attack attempts. It is recommended that you implement the fixes in attack order (i.e., first implement the security fix that thwarts attack 1).

After fixing the security flaws of the current implementation, you will implement a missing function in the banking app: account withdrawals. This lets you get familiar with how the different components interact and how to use them.

<b> NOTE </b> You only need to write code inside the `Memory` class to protect against the attacks, the other code is provided for your reference but should not be altered. The locations for you to write code have been marked with TODOs.

In [None]:
import hashlib
import hmac
import json
import struct
import os
from enum import IntEnum

In [None]:
class CPUMode(IntEnum):
    """Privilege level the simulated CPU is currently executing at."""
    NORMAL = 1  # untrusted code; the Processor starts here
    SECURE = 2  # enclave code, entered via ENCLU_EENTER
    KERNEL = 3  # set while an ENCLS leaf is executing


class Opcode(IntEnum):
    """Instruction opcodes understood by Processor.exec_instruction."""
    MOV = 1
    ADD = 2
    SUB = 3
    MUL = 4
    CMP = 5     # sets the equal / less-than flags
    JE = 6      # jump (set RIP) if the equal flag is set
    JL = 7      # jump (set RIP) if the less-than flag is set
    LOAD = 8    # 4-byte read from memory into a register
    STORE = 9   # 4-byte write of a value to memory
    ENCLS = 10  # privileged enclave leaf, selected by RAX
    ENCLU = 11  # user enclave leaf, selected by RAX


class ENCLSLeafs(IntEnum):
    """ENCLS leaf numbers (passed in RAX)."""
    ENCLS_ECREATE = 0
    ENCLS_EADD = 1
    ENCLS_EINIT = 2
    ENCLS_EEXTEND = 6


class ENCLULeafs(IntEnum):
    """ENCLU leaf numbers (passed in RAX)."""
    ENCLU_EENTER = 2
    ENCLU_EEXIT = 4


class Reg(IntEnum):
    """Register indices into Processor.reg_file."""
    RAX = 0
    RBX = 1
    RCX = 2
    RDX = 3
    RSI = 4
    RDI = 5
    RIP = 6  # instruction index inside the enclave code
    RSP = 7


class ArgType(IntEnum):
    """Encoding tag describing how an instruction operand is interpreted."""
    NONE = 0  # operand absent
    REG = 1   # operand is a register index
    IMM = 2   # operand is an immediate value

class Assembler:
    """Stateless encoder/decoder between instruction tuples and fixed-width machine code.

    An instruction tuple is ``(opcode[, arg1[, arg2]])``. Each argument is either a
    ``Reg`` member (encoded as a register operand) or any other integer (encoded as an
    immediate). Every instruction occupies ``get_instr_size()`` bytes.
    """

    @staticmethod
    def get_instr_fmt():
        # little-endian: opcode(B) type1(B) value1(I) type2(B) value2(I) + 5 padding bytes
        return "<" + "BBIBI" + "x" * 5

    @staticmethod
    def get_instr_size():
        return struct.Struct(Assembler.get_instr_fmt()).size

    @staticmethod
    def get_arg(instr, i):
        """Return ``(ArgType, value)`` for operand *i* of *instr* (NONE/0 when absent)."""
        if i >= len(instr):
            return ArgType.NONE, 0
        operand = instr[i]
        if isinstance(operand, Reg):
            return ArgType.REG, operand.value
        return ArgType.IMM, operand

    @staticmethod
    def encode(instr):
        """Flatten an instruction tuple into ``(opcode, t1, v1, t2, v2)``."""
        (kind1, val1), (kind2, val2) = (Assembler.get_arg(instr, pos) for pos in (1, 2))
        return instr[0].value, kind1, val1, kind2, val2

    @staticmethod
    def assemble(instructions):
        """Encode a sequence of instruction tuples into a contiguous bytearray."""
        fmt = Assembler.get_instr_fmt()
        return bytearray(b"".join(struct.pack(fmt, *Assembler.encode(ins)) for ins in instructions))

    @staticmethod
    def disassemble(binary_chunk):
        """Decode exactly one encoded instruction back into ``(opcode, t1, v1, t2, v2)``."""
        return struct.Struct(Assembler.get_instr_fmt()).unpack(binary_chunk)

class SecurityViolation(Exception):
    """Raised when an operation breaks the simulated TEE's security policy."""


class MemoryError(Exception):  # NOTE: shadows the builtin MemoryError in this module
    """Raised for invalid memory accesses and allocation failures (e.g. out of EPC memory)."""

In [None]:
class Memory:
    """Simulated physical memory with an SGX-style protected region.

    The secure region is ``[SECURE_REGION_START, SECURE_REGION_END)``: it starts at
    ~80% of ``memsize`` (rounded to a block boundary) and runs to the end of memory.
    Accesses to it get two protections:

    * Access control: only a CPU in SECURE (enclave) or KERNEL (ENCLS) mode may read or
      write it; other modes raise SecurityViolation.
    * Memory encryption engine (MEE): each BLOCK_SIZE-byte block is stored in
      ``phys_mem`` XOR-ed with a keystream derived from (block address, per-block version
      counter) and is authenticated with an HMAC tag. Counters and tags live in separate
      private lists (``__sram_counters`` / ``__tags_ram``) that ``probe``/``inject`` cannot
      touch. Physical probing therefore only sees ciphertext, and tampering with or
      replaying old ciphertext is detected on the next read.
    """

    def __init__(self, memsize = 64 * 1024):
        self.BLOCK_SIZE = 32
        assert (memsize % self.BLOCK_SIZE) == 0, f"memory size needs to be a multiple of block size ({self.BLOCK_SIZE})"

        self.phys_mem = bytearray(memsize)
        # Secure region = [START, END), START ~ 80% of memory rounded to a block boundary.
        self.SECURE_REGION_START, self.SECURE_REGION_END = self.BLOCK_SIZE * round(0.8 * memsize / self.BLOCK_SIZE), memsize
        NUM_SECURE_BLOCKS = (self.SECURE_REGION_END - self.SECURE_REGION_START) // self.BLOCK_SIZE

        self.cpu = None
        self.__root_key = b'\xcc\x9f\xd9\xc8\xdd\xccS~'  # fixed key, used only by sign_data
        self.__mek = os.urandom(16)  # memory encryption key, fresh for every Memory instance

        # Per-block metadata for the secure region, not reachable via probe/inject.
        self.__sram_counters = [0] * NUM_SECURE_BLOCKS  # write version of each block
        self.__tags_ram = [None] * NUM_SECURE_BLOCKS    # HMAC tag of each block; None = never written

    def connect_cpu(self, cpu):
        self.cpu = cpu

    def sign_data(self, data):
        """Return a hex HMAC-SHA256 of *data* under the root key."""
        return hmac.new(self.__root_key, data, hashlib.sha256).hexdigest()

    def _assert_address(self, address):
        # The region is half-open: SECURE_REGION_END == len(phys_mem) is not a valid address.
        if not (self.SECURE_REGION_START <= address < self.SECURE_REGION_END):
            raise MemoryError(f"MEE Panic: Address {hex(address)} outside Secure Region [{hex(self.SECURE_REGION_START)}, {hex(self.SECURE_REGION_END)})")

    def _assert_cpu(self):
        if not self.cpu: raise Exception("Panic: CPU not connected to memory")

    def _get_block_index(self, address):
        self._assert_address(address)
        offset = address - self.SECURE_REGION_START
        return offset // self.BLOCK_SIZE

    def _check_access(self, address, length):
        """Validate an access to ``[address, address + length)``; return True if it is secure.

        Raises:
            MemoryError: the range lies outside physical memory, or it straddles the
                boundary between normal and secure memory.
            SecurityViolation: the range touches the secure region while the CPU is in
                neither SECURE nor KERNEL mode.
        """
        end = address + length
        if address < 0 or length < 0 or end > len(self.phys_mem):
            raise MemoryError(f"Access [{hex(address)}, {hex(end)}) outside physical memory")
        if end <= self.SECURE_REGION_START:
            return False
        # Default-deny: anything other than enclave or ENCLS execution is untrusted.
        if self.cpu.mode not in (CPUMode.SECURE, CPUMode.KERNEL):
            raise SecurityViolation(f"Access to secure address {hex(max(address, self.SECURE_REGION_START))} denied in {self.cpu.mode.name} mode")
        if address < self.SECURE_REGION_START:
            # Otherwise the plain-memory path would expose/overwrite raw secure ciphertext.
            raise MemoryError(f"Access [{hex(address)}, {hex(end)}) straddles the secure region boundary {hex(self.SECURE_REGION_START)}")
        return True

    def _keystream(self, block_address, version):
        """Return BLOCK_SIZE pseudo-random bytes bound to (block_address, version)."""
        stream = bytearray()
        chunk = 0
        while len(stream) < self.BLOCK_SIZE:
            counter_block = struct.pack("II", block_address, version) + struct.pack("Q", chunk)
            stream.extend(hmac.new(self.__mek, counter_block, hashlib.sha256).digest())
            chunk += 1
        return stream[:self.BLOCK_SIZE]

    def _mee_encrypt_decrypt(self, block_address, data_block, is_write):
        """Encrypt (is_write=True) or verify-and-decrypt (is_write=False) one secure block.

        Writes bump the block's version counter first, so every write uses a fresh
        keystream/nonce. Then they store an HMAC over ``ciphertext || address || version``
        (encrypt-then-MAC). Reads recompute that HMAC with the current version and
        compare it against the stored tag before decrypting.

        Returns:
            The processed BLOCK_SIZE bytes (ciphertext on write, plaintext on read). A
            block that was never written reads as zeros.

        Raises:
            SecurityViolation: on read, when the stored tag does not match (tampering or
                replay of stale ciphertext).
        """
        if len(data_block) != self.BLOCK_SIZE:
            raise Exception(f"MEE received {len(data_block)} bytes, expected {self.BLOCK_SIZE}")

        block_idx = self._get_block_index(block_address)

        if is_write:
            self.__sram_counters[block_idx] += 1
        version = self.__sram_counters[block_idx]
        base_nonce = struct.pack("II", block_address, version)

        if not is_write:
            stored_tag = self.__tags_ram[block_idx]
            if stored_tag is None:
                # Never written: behave like zero-initialised memory.
                return bytearray(self.BLOCK_SIZE)
            expected_tag = hmac.new(self.__mek, bytes(data_block) + base_nonce, hashlib.sha256).hexdigest()
            if not hmac.compare_digest(stored_tag, expected_tag):
                raise SecurityViolation(f"MEE integrity check failed for block at {hex(block_address)}")

        keystream = self._keystream(block_address, version)
        processed_data = bytearray(d ^ k for d, k in zip(data_block, keystream))

        if is_write:
            tag = hmac.new(self.__mek, bytes(processed_data) + base_nonce, hashlib.sha256).hexdigest()
            self.__tags_ram[block_idx] = tag

        return processed_data

    def read(self, address, length):
        """Read *length* bytes at *address* on behalf of the connected CPU.

        Normal memory is returned as-is; secure memory is decrypted block by block
        through the MEE.

        Raises:
            MemoryError / SecurityViolation: see ``_check_access`` and ``_mee_encrypt_decrypt``.
        """
        self._assert_cpu()

        if not self._check_access(address, length):
            return self.phys_mem[address : address + length]

        result = bytearray()
        bytes_read = 0

        while bytes_read < length:
            curr_addr = address + bytes_read

            block_offset = curr_addr % self.BLOCK_SIZE
            block_start_addr = curr_addr - block_offset

            encrypted_block = self.phys_mem[block_start_addr : block_start_addr + self.BLOCK_SIZE]
            plaintext_block = self._mee_encrypt_decrypt(block_start_addr, encrypted_block, is_write=False)

            bytes_from_this_block = min(length - bytes_read, self.BLOCK_SIZE - block_offset)

            result.extend(plaintext_block[block_offset : block_offset + bytes_from_this_block])
            bytes_read += bytes_from_this_block

        return result

    def write(self, address, data):
        """Write *data* at *address* on behalf of the connected CPU.

        Secure writes are read-modify-write per block: the block is verified and
        decrypted, patched, then re-encrypted under a new version.

        Raises:
            MemoryError / SecurityViolation: see ``_check_access`` and ``_mee_encrypt_decrypt``.
        """
        self._assert_cpu()

        if not self._check_access(address, len(data)):
            self.phys_mem[address : address + len(data)] = data
            return

        bytes_written = 0
        total_len = len(data)

        while bytes_written < total_len:
            curr_addr = address + bytes_written

            block_offset = curr_addr % self.BLOCK_SIZE
            block_start_addr = curr_addr - block_offset

            encrypted_existing = self.phys_mem[block_start_addr : block_start_addr + self.BLOCK_SIZE]
            plaintext_block = self._mee_encrypt_decrypt(block_start_addr, encrypted_existing, is_write=False)

            bytes_to_write = min(total_len - bytes_written, self.BLOCK_SIZE - block_offset)

            for i in range(bytes_to_write):
                plaintext_block[block_offset + i] = data[bytes_written + i]

            new_ciphertext = self._mee_encrypt_decrypt(block_start_addr, plaintext_block, is_write=True)
            self.phys_mem[block_start_addr : block_start_addr + self.BLOCK_SIZE] = new_ciphertext

            bytes_written += bytes_to_write

    def probe(self, address, length):
        # simulate direct probing of an address (physical attack: bypasses CPU and MEE)
        return self.phys_mem[address : address + length]

    def inject(self, address, data):
        # simulate direct injection of data to an address (physical attack: bypasses CPU and MEE)
        self.phys_mem[address : address + len(data)] = data

In [None]:
class Processor:
    """Simulated CPU with SGX-like ENCLS (privileged) and ENCLU (user) leaf instructions.

    ``mode`` tracks the current privilege level: NORMAL for untrusted code, KERNEL while
    an ENCLS leaf executes, and SECURE while enclave code runs in ``_run_enclave_loop``.
    All memory traffic goes through ``bus`` (a Memory instance).
    """
    def __init__(self, bus):
        self.mode = CPUMode.NORMAL
        self.bus = bus
        self.reg_file = [0] * 8
        self.flags_eq = False
        self.flags_lt = False

        self.enclave_measurement = hashlib.sha256()
        self.enclave_initialized = False
        self.enclave_code_start = 0
        self.enclave_code_size = 0

    def get_reg(self, reg_id): return self.reg_file[reg_id]
    def set_reg(self, reg_id, val): self.reg_file[reg_id] = val

    def get_val(self, type_flag, val):
        """Resolve an operand: register contents for REG, the value itself for IMM, else 0."""
        if type_flag == ArgType.REG: return self.reg_file[val]
        elif type_flag == ArgType.IMM: return val
        else: return 0

    def exec_instruction(self, op, t1, v1, t2, v2):
        """Execute one decoded instruction; return True only when ENCLU signals an exit."""
        val1 = self.get_val(t1, v1)
        val2 = self.get_val(t2, v2)

        if op == Opcode.MOV:
            if t1 == ArgType.REG: self.set_reg(reg_id=v1, val=val2)
        elif op == Opcode.ADD:
            if t1 == ArgType.REG: self.set_reg(reg_id=v1, val=val1 + val2)
        elif op == Opcode.SUB:
            if t1 == ArgType.REG: self.set_reg(reg_id=v1, val=val1 - val2)
        elif op == Opcode.MUL:
            if t1 == ArgType.REG: self.set_reg(reg_id=v1, val=val1 * val2)
        elif op == Opcode.CMP:
            self.flags_eq = (val1 == val2)
            self.flags_lt = (val1 < val2)
        elif op == Opcode.JE:
            if self.flags_eq: self.set_reg(reg_id=Reg.RIP, val=v1)
        elif op == Opcode.JL:
            if self.flags_lt: self.set_reg(reg_id=Reg.RIP, val=v1)
        elif op == Opcode.LOAD:
            data = self.bus.read(address=val2, length=4)
            unpacked = struct.unpack("I", data)[0]
            self.set_reg(reg_id=v1, val=unpacked)
        elif op == Opcode.STORE:
            self.bus.write(address=val2, data=struct.pack("I", val1))
        elif op == Opcode.ENCLS:
            self._ENCLS()
        elif op == Opcode.ENCLU:
            return self._ENCLU()

        return False

    def _ENCLS(self):
        """Execute the ENCLS leaf selected by RAX with KERNEL privileges.

        Leaves: ECREATE (reset measurement), EADD (copy RDX bytes from untrusted RBX to
        secure RCX), EEXTEND (measure RCX bytes at RBX), EINIT (lock the enclave; RBX
        holds the code length).

        Raises:
            SecurityViolation: EADD/EEXTEND after EINIT, or an EADD whose source is not
                untrusted memory or whose destination is not the secure region.
        """
        self.mode = CPUMode.KERNEL
        try:
            leaf = self.get_reg(Reg.RAX)

            if leaf == ENCLSLeafs.ENCLS_ECREATE:
                self.enclave_measurement = hashlib.sha256()
                self.enclave_initialized = False
            elif leaf == ENCLSLeafs.ENCLS_EADD:
                if self.enclave_initialized: raise SecurityViolation("Enclave already initialized")
                src = self.get_reg(Reg.RBX)
                dst = self.get_reg(Reg.RCX)
                size = self.get_reg(Reg.RDX)
                secure_start = self.bus.SECURE_REGION_START
                # EADD copies untrusted data *into* the enclave. Without these checks an
                # untrusted caller could re-run ECREATE (which clears enclave_initialized)
                # and use this KERNEL-mode copy to move enclave plaintext into normal memory.
                if src + size > secure_start:
                    raise SecurityViolation("EADD source must lie in untrusted memory")
                if dst < secure_start:
                    raise SecurityViolation("EADD destination must lie in the secure region")
                data = self.bus.read(src, size)
                self.bus.write(dst, data)
            elif leaf == ENCLSLeafs.ENCLS_EEXTEND:
                if self.enclave_initialized: raise SecurityViolation("Enclave already initialized")
                target = self.get_reg(Reg.RBX)
                size = self.get_reg(Reg.RCX)
                content = self.bus.read(target, size)
                self.enclave_measurement.update(content)
            elif leaf == ENCLSLeafs.ENCLS_EINIT:
                self.enclave_initialized = True

                self.enclave_code_start = self.bus.SECURE_REGION_START
                self.enclave_code_size  = self.get_reg(Reg.RBX) # Length passed in RBX

                print(f"-> Enclave LOCKED. MRENCLAVE: {self.enclave_measurement.hexdigest()[:10]}...")
                print()
        finally:
            # Drop privileges even if a leaf raised; otherwise the caller would keep
            # running in KERNEL mode after a failed ENCLS.
            self.mode = CPUMode.NORMAL

    def _ENCLU(self):
        """Execute the ENCLU leaf selected by RAX.

        EENTER switches to SECURE mode and runs the enclave from instruction 0 until it
        exits; EEXIT returns to NORMAL mode and returns True to stop the enclave loop.

        Raises:
            SecurityViolation: EENTER before EINIT (plus anything raised while the
                enclave runs).
        """
        leaf = self.get_reg(Reg.RAX)
        if leaf == ENCLULeafs.ENCLU_EENTER:
            if not self.enclave_initialized: raise SecurityViolation("Enclave not initialized before ENCLU_EENTER")
            self.mode = CPUMode.SECURE
            self.set_reg(Reg.RIP, 0)
            try:
                self._run_enclave_loop()
            finally:
                # A crash (RIP out of bounds) or an exception inside the enclave must not
                # leave the untrusted caller running with SECURE privileges.
                self.mode = CPUMode.NORMAL
        elif leaf == ENCLULeafs.ENCLU_EEXIT:
            self.mode = CPUMode.NORMAL
            return True # Signal Exit
        return False

    def _run_enclave_loop(self):
        """Fetch, decode and execute enclave instructions while in SECURE mode."""
        while self.mode == CPUMode.SECURE:
            rip = self.get_reg(Reg.RIP)
            offset = rip * Assembler.get_instr_size()

            if offset >= self.enclave_code_size:
                print("[CPU] Crash: RIP out of bounds")
                break

            phys_addr = self.enclave_code_start + offset

            raw_bytes = self.bus.read(phys_addr, Assembler.get_instr_size())
            op, t1, v1, t2, v2 = Assembler.disassemble(raw_bytes)

            self.set_reg(Reg.RIP, rip + 1)
            if self.exec_instruction(op, t1, v1, t2, v2):
                break

    def exec(self, instr):
        """Encode and execute a single instruction tuple issued by (untrusted) software."""
        return self.exec_instruction(*Assembler.encode(instr))

    def sec_reg_start(self):
        return self.bus.SECURE_REGION_START

    def sec_reg_end(self):
        return self.bus.SECURE_REGION_END

In [None]:
class SecureBankApp:
    def __init__(self, acc_a_bal = 1000, acc_b_bal = 500):
        self.initial_data = struct.pack("II", acc_a_bal, acc_b_bal)
        self.relocate(0) # has to be called by OS when location in enclave is known

    def relocate(self, new_data_base):
        self.data_base_addr = new_data_base
        self.instructions = [
            # desired function is determined by RDI
            (Opcode.CMP, Reg.RDI, 0), (Opcode.JE, 8),   # Deposit
            (Opcode.CMP, Reg.RDI, 1), (Opcode.JE, 16),  # Balance
            (Opcode.CMP, Reg.RDI, 2), (Opcode.JE, 23),  # Transfer
            (Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EEXIT), (Opcode.ENCLU,),

            # --- DEPOSIT (0) ---
            # 8:
            (Opcode.MOV, Reg.RDI, Reg.RBX), (Opcode.MUL, Reg.RDI, 4), (Opcode.ADD, Reg.RDI, self.data_base_addr),
            (Opcode.LOAD, Reg.RAX, Reg.RDI), (Opcode.ADD, Reg.RAX, Reg.RCX), (Opcode.STORE, Reg.RAX, Reg.RDI),
            (Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EEXIT), (Opcode.ENCLU,),

            # --- BALANCE (1) ---
            # 16:
            (Opcode.MOV, Reg.RDI, Reg.RBX), (Opcode.MUL, Reg.RDI, 4), (Opcode.ADD, Reg.RDI, self.data_base_addr),
            (Opcode.LOAD, Reg.RAX, Reg.RDI), (Opcode.MOV, Reg.RBX, Reg.RAX),
            (Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EEXIT), (Opcode.ENCLU,),

            # --- TRANSFER (2) ---
            # 23:
            (Opcode.MOV, Reg.RSI, Reg.RBX), (Opcode.MUL, Reg.RSI, 4), (Opcode.ADD, Reg.RSI, self.data_base_addr),
            (Opcode.MOV, Reg.RDI, Reg.RCX), (Opcode.MUL, Reg.RDI, 4), (Opcode.ADD, Reg.RDI, self.data_base_addr),
            (Opcode.LOAD, Reg.RAX, Reg.RSI), (Opcode.CMP, Reg.RAX, Reg.RDX), (Opcode.JL, 40),
            (Opcode.SUB, Reg.RAX, Reg.RDX), (Opcode.STORE, Reg.RAX, Reg.RSI),
            (Opcode.LOAD, Reg.RBX, Reg.RDI), (Opcode.ADD, Reg.RBX, Reg.RDX), (Opcode.STORE, Reg.RBX, Reg.RDI),
            (Opcode.MOV, Reg.RBX, 1), (Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EEXIT), (Opcode.ENCLU,),

            # 40: Fail
            (Opcode.MOV, Reg.RBX, 0), 
            (Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EEXIT), (Opcode.ENCLU,)
        ]

In [None]:
class OS:
    """Untrusted operating system that builds the enclave and calls into it.

    It lays out the enclave in the secure region (code first, then data), stages code
    and data in normal memory, and drives the ENCLS leaves to create, add, measure and
    initialise the enclave. It then invokes bank functions via ENCLU_EENTER, passing
    arguments in registers.
    """
    def __init__(self, cpu: Processor):
        self.cpu = cpu
        self.epc_free_ptr = self.cpu.sec_reg_start()  # next free byte in the secure region
        self.untrusted_heap_ptr = 0x0  # next free byte in normal memory

    def allocate_epc_page(self, size_bytes):
        # gets a pointer to a memory in secure memory - does not write to it!
        if self.epc_free_ptr + size_bytes > self.cpu.sec_reg_end():
            raise MemoryError("Out of EPC Memory!")
        addr = self.epc_free_ptr
        self.epc_free_ptr += size_bytes
        return addr

    def allocate_untrusted_ram(self, data):
        """Copy *data*, zero-padded to a multiple of 4 bytes, into normal memory.

        The copy is done word by word with STORE instructions, i.e. through the regular
        CPU/memory path.

        Returns:
            The start address of the copy.

        Raises:
            MemoryError: the copy would run into the secure region.
        """
        start_addr = self.untrusted_heap_ptr

        pad_len = (4 - (len(data) % 4)) % 4
        data_to_write = data + b'\x00' * pad_len
        if start_addr + len(data_to_write) > self.cpu.sec_reg_start():
            raise MemoryError("Out of untrusted memory!")

        for i in range(0, len(data_to_write), 4):
            chunk = data_to_write[i : i+4]
            val_int = struct.unpack("I", chunk)[0]
            self.cpu.exec((Opcode.STORE, val_int, start_addr + i))

        self.untrusted_heap_ptr += len(data_to_write)
        return start_addr

    def init_enclave(self, app):
        """Load *app* into a fresh enclave: ECREATE, EADD code+data, EEXTEND both, EINIT."""
        print("--- OS: Initializing Enclave ---")
        est_code_size = len(app.instructions) * Assembler.get_instr_size()
        enclave_code_addr = self.allocate_epc_page(est_code_size)
        self.enclave_data_addr = self.allocate_epc_page(len(app.initial_data))

        # Relocate app and get code and data of app in bytes
        app.relocate(self.enclave_data_addr)
        code_blob = Assembler.assemble(app.instructions)
        data_blob = app.initial_data

        # Get pointers for code and data
        untrusted_code_ptr = self.allocate_untrusted_ram(code_blob)
        untrusted_data_ptr = self.allocate_untrusted_ram(data_blob)

        # ENCLS_ECREATE: Create enclave
        self.cpu.exec((Opcode.MOV, Reg.RAX, ENCLSLeafs.ENCLS_ECREATE))
        self.cpu.exec((Opcode.ENCLS,))

        # ENCLS_EADD: Add Code
        self.cpu.exec((Opcode.MOV, Reg.RAX, ENCLSLeafs.ENCLS_EADD))
        self.cpu.exec((Opcode.MOV, Reg.RBX, untrusted_code_ptr))
        self.cpu.exec((Opcode.MOV, Reg.RCX, enclave_code_addr))
        self.cpu.exec((Opcode.MOV, Reg.RDX, len(code_blob)))
        self.cpu.exec((Opcode.ENCLS,))

        # ENCLS_EADD: Add Data
        self.cpu.exec((Opcode.MOV, Reg.RAX, ENCLSLeafs.ENCLS_EADD))
        self.cpu.exec((Opcode.MOV, Reg.RBX, untrusted_data_ptr))
        self.cpu.exec((Opcode.MOV, Reg.RCX, self.enclave_data_addr))
        self.cpu.exec((Opcode.MOV, Reg.RDX, len(data_blob)))
        self.cpu.exec((Opcode.ENCLS,))

        # ENCLS_EEXTEND: Measure code part
        self.cpu.exec((Opcode.MOV, Reg.RAX, ENCLSLeafs.ENCLS_EEXTEND))
        self.cpu.exec((Opcode.MOV, Reg.RBX, enclave_code_addr))
        self.cpu.exec((Opcode.MOV, Reg.RCX, len(code_blob)))
        self.cpu.exec((Opcode.ENCLS,))

        # ENCLS_EEXTEND: Measure data part
        self.cpu.exec((Opcode.MOV, Reg.RAX, ENCLSLeafs.ENCLS_EEXTEND))
        self.cpu.exec((Opcode.MOV, Reg.RBX, self.enclave_data_addr))
        self.cpu.exec((Opcode.MOV, Reg.RCX, len(data_blob)))
        self.cpu.exec((Opcode.ENCLS,))

        # ENCLS_EINIT: Init, finalizes measurement and initializes enclave
        self.cpu.exec((Opcode.MOV, Reg.RAX, ENCLSLeafs.ENCLS_EINIT))
        self.cpu.exec((Opcode.MOV, Reg.RBX, len(code_blob)))
        self.cpu.exec((Opcode.ENCLS,))

    def call_transfer(self, from_id, to_id, amount, v = True):
        """Enter the enclave with RDI=2 (transfer); return the enclave's RBX (0 = FAILED)."""
        if v:
            print(f"Bank Transfer: User {from_id} -> User {to_id} (${amount})")
        self.cpu.exec((Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EENTER))
        self.cpu.exec((Opcode.MOV, Reg.RDI, 2))
        self.cpu.exec((Opcode.MOV, Reg.RBX, from_id))
        self.cpu.exec((Opcode.MOV, Reg.RCX, to_id))
        self.cpu.exec((Opcode.MOV, Reg.RDX, amount))
        self.cpu.exec((Opcode.ENCLU,))

        res = self.cpu.get_reg(Reg.RBX)
        if v:
            # single quotes inside the f-string keep this valid before Python 3.12
            print(f"\tBank Transfer {'FAILED' if res == 0 else 'SUCCESS'}")
        return res

    def call_get_balance(self, user_id):
        """Enter the enclave with RDI=1 (balance) and return the value left in RBX."""
        self.cpu.exec((Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EENTER))
        self.cpu.exec((Opcode.MOV, Reg.RDI, 1))
        self.cpu.exec((Opcode.MOV, Reg.RBX, user_id))
        self.cpu.exec((Opcode.ENCLU,))
        return self.cpu.get_reg(Reg.RBX)

    def call_deposit(self, user_id, amount, v = True):
        """Enter the enclave with RDI=0 (deposit) and return whatever the enclave left in RBX."""
        if v:
            print(f"Bank Deposit: User {user_id} + ${amount}")
        self.cpu.exec((Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EENTER))
        self.cpu.exec((Opcode.MOV, Reg.RDI, 0))
        self.cpu.exec((Opcode.MOV, Reg.RBX, user_id))
        self.cpu.exec((Opcode.MOV, Reg.RCX, amount))
        self.cpu.exec((Opcode.ENCLU,))
        return self.cpu.get_reg(Reg.RBX)

    def print_balances(self):
        print("Current Bank Balances")
        for i in [0, 1]:
            print(f"    User {i}: ${self.call_get_balance(i)}")
        print()

In [None]:
# NORMAL
# Baseline run: wire up memory, CPU and an honest OS, then exercise the bank app
# through the enclave entry points.
mem = Memory()
cpu = Processor(mem)
mem.connect_cpu(cpu)  # Memory.read/write refuse to run until a CPU is connected
os_kernel = OS(cpu)

app = SecureBankApp()  # user 0 starts with $1000, user 1 with $500

os_kernel.init_enclave(app)  # stage, copy, measure and lock the app inside the enclave
os_kernel.print_balances()
os_kernel.call_transfer(from_id=0, to_id=1, amount=400)
os_kernel.call_transfer(from_id=0, to_id=1, amount=4000)  # more than user 0 holds -> expected to fail
os_kernel.print_balances()
os_kernel.call_deposit(user_id=0, amount=200)
os_kernel.call_deposit(user_id=1, amount=100)
os_kernel.print_balances()

In [None]:
class CompromisedOS(OS):
    """Malicious OS that tries to read or corrupt the enclave's bank balances.

    Attacks 1-2 go through the CPU (LOAD/STORE on enclave addresses); attacks 3-4 use
    direct physical access via ``Memory.probe``/``Memory.inject``. Every attack method
    returns True if the attack succeeded and False if it was blocked.
    """
    # a compromised OS which attempts to perform attacks (using direct memory accesses as well)
    def __init__(self, cpu:Processor):
        super().__init__(cpu)
        self.mem = self.cpu.bus

    def enclave_read(self, v = True):
        """Attack 1: LOAD user 0's balance from enclave memory in NORMAL mode."""
        if v:
            print(f"[Attack 1] Instructing the CPU to read from enclave memory...")
        try:
            self.cpu.exec((Opcode.LOAD, Reg.RAX, self.enclave_data_addr))
        except SecurityViolation as sv:
            if v:
                print(f"\tBLOCKED ({sv})")
            return False

        if v:
            print("\tSUCCESSFUL")
            print(f"\tBalance of User 0: {self.cpu.get_reg(Reg.RAX)}")
            print()
        return True

    def enclave_write(self, v = True):
        """Attack 2: STORE a forged balance for user 0 into enclave memory in NORMAL mode."""
        if v:
            print(f"[Attack 2] Instructing the CPU to write to enclave memory...")
        try:
            self.cpu.exec((Opcode.STORE, 2000, self.enclave_data_addr))
        except SecurityViolation as sv:
            if v:
                print(f"\tBLOCKED ({sv})")
            return False

        if v:
            print("\tSUCCESSFUL")
            print("\tChanged balance of User 0 to $2000")
            print()
        return True

    def direct_read(self, v = True):
        """Attack 3: probe raw memory and check whether it equals the real balance."""
        balance_corr = self.call_get_balance(0)
        balance = struct.unpack("I", self.mem.probe(self.enclave_data_addr, 4))[0]

        if v:
            print(f"[Attack 3] Probing the memory directly...")
        if balance_corr == balance:
            if v:
                print(f"\tSUCCESSFUL")
                print("\tThe enclave memory can be read unencrypted")
                print()
            return True
        if v:
            print("\tBLOCKED")
            print()
        return False

    def direct_write(self, v = True):
        """Attack 4: replay pre-transfer memory contents after a transfer."""
        old_balance_mem = self.mem.probe(self.enclave_data_addr, 4)
        if v:
            print(f"[Attack 4] Writing directly to memory after transfer...")
            self.print_balances()
        self.call_transfer(from_id=0, to_id=1, amount=100, v=v)
        self.mem.inject(self.enclave_data_addr, old_balance_mem)
        try:
            self.call_get_balance(0)
        except SecurityViolation as sv:
            if v:
                print(f"\tBLOCKED ({sv})")
            # Must return regardless of verbosity; otherwise a blocked attack is
            # reported as successful when v=False.
            return False
        if v:
            print("\tSUCCESSFUL")
            print("\tThe bank balance of user 0 has been successfully written back to the pre-transfer value")
            print()
            self.print_balances()
        return True

    def run_attacks(self, v = True, app = None):
        """Load *app* into an enclave and run all four attacks.

        Args:
            v: print progress and results.
            app: application to attack; defaults to a fresh SecureBankApp instead of
                relying on a module-level ``app`` global.

        Returns:
            True if at least one attack succeeded.
        """
        if app is None:
            app = SecureBankApp()
        self.init_enclave(app)
        if v:
            self.print_balances()

        attack_success = False
        attack_success |= self.enclave_read(v=v)
        attack_success |= self.enclave_write(v=v)

        if v:
            print("Post-attack balances..")
            self.print_balances()

        attack_success |= self.direct_read(v=v)
        attack_success |= self.direct_write(v=v)
        return attack_success

In [None]:
# ATTACK 
# keep updating the functionality of the memory class until all attacks fail
# Fresh memory/CPU so this run does not inherit state from the baseline run above.
mem = Memory()
cpu = Processor(mem)
mem.connect_cpu(cpu)
app = SecureBankApp()  # fresh application instance for the attack run
cos = CompromisedOS(cpu)
attack_success = cos.run_attacks()  # True if any of the four attacks succeeded
print(f"Can successfully run an attack: {attack_success}")

In [None]:
class SecureBankAppV2:
    def __init__(self, acc_a_bal = 1000, acc_b_bal = 500):
        self.initial_data = struct.pack("II", acc_a_bal, acc_b_bal)
        self.relocate(0) # has to be called by OS when location in enclave is known

    def relocate(self, new_data_base):
        self.data_base_addr = new_data_base
        # Instructions defined using Enums
        self.instructions = [
            # 0: Dispatcher
            (Opcode.CMP, Reg.RDI, 0), (Opcode.JE, 10),   # Deposit
            (Opcode.CMP, Reg.RDI, 1), (Opcode.JE, 18),  # Balance
            (Opcode.CMP, Reg.RDI, 2), (Opcode.JE, 25),  # Transfer
            (Opcode.CMP, Reg.RDI, 3), (Opcode.JE, 43),  # Withdraw
            (Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EEXIT), (Opcode.ENCLU,),

            # --- DEPOSIT (0) ---
            # 10:
            (Opcode.MOV, Reg.RDI, Reg.RBX), (Opcode.MUL, Reg.RDI, 4), (Opcode.ADD, Reg.RDI, self.data_base_addr),
            (Opcode.LOAD, Reg.RAX, Reg.RDI), (Opcode.ADD, Reg.RAX, Reg.RCX), (Opcode.STORE, Reg.RAX, Reg.RDI),
            (Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EEXIT), (Opcode.ENCLU,),

            # --- BALANCE (1) ---
            # 18:
            (Opcode.MOV, Reg.RDI, Reg.RBX), (Opcode.MUL, Reg.RDI, 4), (Opcode.ADD, Reg.RDI, self.data_base_addr),
            (Opcode.LOAD, Reg.RAX, Reg.RDI), (Opcode.MOV, Reg.RBX, Reg.RAX),
            (Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EEXIT), (Opcode.ENCLU,),

            # --- TRANSFER (2) ---
            # 25:
            (Opcode.MOV, Reg.RSI, Reg.RBX), (Opcode.MUL, Reg.RSI, 4), (Opcode.ADD, Reg.RSI, self.data_base_addr),
            (Opcode.MOV, Reg.RDI, Reg.RCX), (Opcode.MUL, Reg.RDI, 4), (Opcode.ADD, Reg.RDI, self.data_base_addr),
            (Opcode.LOAD, Reg.RAX, Reg.RSI), (Opcode.CMP, Reg.RAX, Reg.RDX), (Opcode.JL, XX),
            (Opcode.SUB, Reg.RAX, Reg.RDX), (Opcode.STORE, Reg.RAX, Reg.RSI),
            (Opcode.LOAD, Reg.RBX, Reg.RDI), (Opcode.ADD, Reg.RBX, Reg.RDX), (Opcode.STORE, Reg.RBX, Reg.RDI),
            (Opcode.MOV, Reg.RAX, 1), (Opcode.MOV, Reg.RBX, Reg.RAX), (Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EEXIT), (Opcode.ENCLU,),

            # --- WITHDRAW (3) --- make sure to check that user has enough balance
            # 43:
            
            
            # XX: Fail
            (Opcode.MOV, Reg.RBX, 0), 
            (Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EEXIT), (Opcode.ENCLU,)
        ]

In [None]:
class OSv2(OS):
    # OS v2 which supports account withdraws
    def call_withdraw(self, user_id, amount, v = True):
        """Ask the enclave to withdraw *amount* from *user_id*'s account.

        Calling convention: RDI = 3 selects withdraw, RBX = user id, RCX = amount (the
        same registers as deposit). The enclave is expected to report the result in RBX:
        0 -> FAIL, 1 -> SUCCESS.

        Returns:
            The value of RBX after the enclave exits.
        """
        if v:
            print(f"Bank Withdraw: User {user_id} - ${amount}")
        self.cpu.exec((Opcode.MOV, Reg.RAX, ENCLULeafs.ENCLU_EENTER))
        self.cpu.exec((Opcode.MOV, Reg.RDI, 3))
        self.cpu.exec((Opcode.MOV, Reg.RBX, user_id))
        self.cpu.exec((Opcode.MOV, Reg.RCX, amount))
        self.cpu.exec((Opcode.ENCLU,))

        res = self.cpu.get_reg(Reg.RBX)
        if v:
            # single quotes inside the f-string keep this valid before Python 3.12
            print(f"\tBank Withdraw {'FAILED' if res == 0 else 'SUCCESS'}")
        return res

In [None]:
# Normal v2
# Same baseline as above, plus withdrawals. Requires the withdraw routine in
# SecureBankAppV2 and the register setup in OSv2.call_withdraw to be implemented.
mem = Memory()
cpu = Processor(mem)
mem.connect_cpu(cpu)
os_kernel = OSv2(cpu)

app = SecureBankAppV2()  # user 0 starts with $1000, user 1 with $500

os_kernel.init_enclave(app)
os_kernel.print_balances()
os_kernel.call_transfer(from_id=0, to_id=1, amount=400)
os_kernel.call_transfer(from_id=0, to_id=1, amount=4000)  # more than user 0 holds -> expected to fail
os_kernel.print_balances()
os_kernel.call_deposit(user_id=0, amount=200)
os_kernel.call_deposit(user_id=1, amount=100)
os_kernel.print_balances()
os_kernel.call_withdraw(user_id=0, amount=500)
os_kernel.call_withdraw(user_id=1, amount=2000)  # user 1 holds $1000 at this point -> should fail the balance check
os_kernel.print_balances()