In [1]:
import heapq
import itertools

In [2]:
class Transaction:
    def __init__(self, sender, nonce, gas_price):
        self.sender = sender
        self.nonce = nonce
        self.gas_price = float(gas_price)

        self.hash = abs(hash(sender + hash(str(nonce)) + hash(self.gas_price)))

    def __hash__(self):
        return self.hash

    def __str__(self):
        return f"Tx<0x{self.hash:016x}>(sender=0x{self.sender:016x}, nonce={self.nonce}, gas_price={self.gas_price:.2f})"

In [3]:
class PricedHeap:
    def __init__(self, tx_pool):
        self.tx_pool = tx_pool  # type: TxPool

        self.heap = []
        self.stales = 0
        self.counter = itertools.count()

    def __len__(self):
        return len(self.heap)

    def underpriced(self, tx: Transaction):
        return self.heap and self.heap[0][0] >= tx.gas_price

    def discard(self, slots):
        drop = []
        while self.heap and slots:
            tx = heapq.heappop(self.heap)[2]
            if tx not in self.tx_pool.all:
                self.stales -= 1
                continue
            drop.append(tx)
            slots -= 1
        return drop

    def removed(self, count):
        self.stales += count

    def put(self, tx: Transaction):
        heapq.heappush(self.heap, (tx.gas_price, next(self.counter), tx))

In [4]:
class TxSortedMap:
    def __init__(self):
        self.items = {}  # type: dict[int, Transaction]

    def get(self, nonce) -> Transaction:
        return self.items.get(nonce)

    def put(self, tx: Transaction):
        self.items[tx.nonce] = tx

    def remove(self, nonce):
        try:
            self.items.pop(nonce)
            return True
        except KeyError:
            return False

    def filter(self, filter):
        nonces_to_remove = [nonce for nonce, tx in self.items.items() if filter(tx)]
        return [self.items.pop(nonce) for nonce in nonces_to_remove]

In [5]:
class TxList:
    def __init__(self, strict):
        self.strict = strict

        self.txs = TxSortedMap()

    def __len__(self):
        return len(self.txs)

    def add(self, tx: Transaction, price_bump):
        old = self.txs.get(tx.nonce)
        if old and tx.gas_price <= old.gas_price * (100 + price_bump) / 100:
            return False, None
        self.txs.put(tx)
        return True, old

    def remove(self, tx: Transaction):
        if not self.txs.remove(tx.nonce):
            return False, []
        if self.strict:
            return True, self.txs.filter(lambda other_tx: other_tx.nonce > tx.nonce)
        return True, []

    def overlaps(self, tx: Transaction):
        return self.txs.get(tx.nonce) is not None

In [6]:
class TxNoncer:
    def __init__(self):
        pass

    def set_if_lower(self, addr, nonce):
        pass

In [7]:
class TxPoolConfig:
    def __init__(self, price_bump=10, global_slots=4096, global_queue=1024):
        self.price_bump = price_bump
        self.global_slots = global_slots
        self.global_queue = global_queue

In [8]:
class TxPool:
    def __init__(self, config=TxPoolConfig()):
        self.config = config

        self.pending_nonces = TxNoncer()

        self.pending = {}  # type: dict[int, TxList]
        self.queue = {}  # type: dict[int, TxList]
        self.priced = PricedHeap(self)
        self.all = set()

    def add_txs(self, txs: list[Transaction]):
        for tx in txs:
            if tx in self.all:
                continue
            total_slots = self.config.global_slots + self.config.global_queue
            if len(self.all) >= total_slots:
                if self.priced.underpriced(tx):
                    continue
                drop = self.priced.discard(len(self.all) - total_slots + 1)
                for drop_tx in drop:
                    self.remove_tx(drop_tx, outofbound=False)

            tx_list = self.pending.get(tx.sender)
            if tx_list and tx_list.overlaps(tx):
                inserted, old = tx_list.add(tx, self.config.price_bump)
                if not inserted:
                    continue
                if old:
                    self.all.remove(tx)
                    self.priced.removed(1)
                self.all.add(tx)
                self.priced.put(tx)
                self.queue_tx_event(tx)
                continue

            self.enqueue_tx(tx)


    def remove_tx(self, tx: Transaction, outofbound):
        self.all.remove(tx)
        if outofbound:
            self.priced.removed(1)

        try:
            pending = self.pending[tx.sender]
            removed, invalids = pending.remove(tx)
            if removed:
                if not len(pending):
                    del self.pending[tx.sender]
                for tx in invalids:
                    self.enqueue_tx(tx)
                self.pending_nonces.set_if_lower(tx.sender, tx.nonce)
                return
        except KeyError:
            pass

        try:
            future = self.queue[tx.sender]
            if future.remove(tx)[0]:
                if not len(future):
                    del self.queue[tx.sender]
        except KeyError:
            pass

    def enqueue_tx(self, tx):
        try:
            tx_list = self.queue[tx.sender]
        except KeyError:
            tx_list = TxList(False)
            self.queue[tx.sender] = tx_list
        inserted, old = tx_list.add(tx, self.config.price_bump)
        if not inserted:
            return
        if old:
            self.all.remove(old)
            self.priced.removed(1)

        if tx not in self.all:
            self.all.add(tx)
            self.priced.put(tx)

    def queue_tx_event(self, tx):
        pass

In [9]:
mempool = TxPool()

# initial transactions
mempool.add_txs([
    Transaction(sender=hash("0"), nonce=0, gas_price=10),
    Transaction(sender=hash("1"), nonce=0, gas_price=4.3),
    Transaction(sender=hash("2"), nonce=0, gas_price=11.5),
])

# replacements
mempool.add_txs([
    Transaction(sender=hash("0"), nonce=0, gas_price=10.5),
    Transaction(sender=hash("1"), nonce=0, gas_price=7.7),
    Transaction(sender=hash("2"), nonce=0, gas_price=9),
])

# higher nonces
mempool.add_txs([
    Transaction(sender=hash("0"), nonce=1, gas_price=7),
    Transaction(sender=hash("1"), nonce=4, gas_price=9),
    Transaction(sender=hash("2"), nonce=2, gas_price=8),
])

for tx in mempool.all:
    print(tx)


Tx<0x141a72b974902ce5>(sender=0x7c2596bc8fe6ecce, nonce=0, gas_price=11.50)
Tx<0x184b2d791fcdd9ab>(sender=0x7c2596bc8fe6ecce, nonce=2, gas_price=8.00)
Tx<0x0e2fb09d82c5e60e>(sender=0x47f4dbfce4a94006, nonce=1, gas_price=7.00)
Tx<0x04961703e92c4e0f>(sender=0x-39c52b5f61e359ff, nonce=0, gas_price=7.70)
Tx<0x004fcc91f8ccdbb8>(sender=0x-39c52b5f61e359ff, nonce=4, gas_price=9.00)
Tx<0x0fe9b7f9c952801a>(sender=0x47f4dbfce4a94006, nonce=0, gas_price=10.00)
