<a href="https://colab.research.google.com/github/GuillaumeFuchs/CS221/blob/main/Hello_world.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
from collections import defaultdict
from typing import List, Tuple, Dict, Optional
from pprint import pprint

In [2]:
class UniV2Pool:
    def __init__(self, token0: str, token1: str, reserve0: float, reserve1: float):
        self.token0 = token0
        self.token1 = token1
        self.reserve0 = reserve0
        self.reserve1 = reserve1

    def get_output_amount(self, input_token: str, input_amount: float) -> float:
        """ A Uniswap V2 pool functions as an on-chain market-making engine, enabling the permissionless exchange of two tokens.
        The system is governed by the constant product formula, which must always hold:
        reserve_A * reserve_B = k,
        where reserve_A (respectively reserve_B) represents the available liquidity of asset A (respectively asset B) in the pool.
        This fundamental equation ensures that the pool never depletes either asset, as the price dynamically adjusts—approaching infinity as one reserve approaches zero.
        To determine the output amount (output_amount) received when swapping a given input amount (input_amount) of input_token, the following equation must be solved:
        (reserve_in + input_amount) * (reserve_out − output_amount) = k = reserve_in * reserve_out
        This equation maintains the constant product invariant while reflecting the impact of each trade on price and liquidity. """
        if input_token == self.token0:
            reserve_in, reserve_out = self.reserve0, self.reserve1
        elif input_token == self.token1:
            reserve_in, reserve_out = self.reserve1, self.reserve0
        else:
            return Exception("unsupported token")

        output_amount = (input_amount * reserve_out) / (reserve_in + input_amount)
        return output_amount


    def get_spot_price(self, input_token: str) -> float:
        """ This gives the instataneous price. This is given mostly for information purpose. """
        # Inverse ?
        if input_token == self.token0:
            return self.reserve0 / self.reserve1
        elif input_token == self.token1:
            return self.reserve1 / self.reserve0
        else:
            return Exception("unsupported token")


In [71]:
# tick sizes dérivés des décimales ERC20 standards
DEFAULT_TICKSIZE = {
    "ETH": 1e-18,      # 18 decimals
    "DAI": 1e-18,      # 18 decimals
    "USDC": 1e-6,      # 6 decimals
    "USDT": 1e-6,      # 6 decimals
    "WBTC": 1e-8,      # 8 decimals
}


class Router:
    TOL = 1e-10
    MAX_ITER = 100

    def __init__(self, pools: List["UniV2Pool"], ticksize: Optional[Dict[str,float]]=None):
        self.pools = pools
        self.tokens = set()
        for p in pools:
            self.tokens.add(p.token0)
            self.tokens.add(p.token1)
        self.ticksize = ticksize or DEFAULT_TICKSIZE

        ##
        self.pairs: Dict[Tuple[str, str], List[Tuple[float, float, int]]] = defaultdict(list)
        for idx, p in enumerate(pools):
            print(idx)
            self.pairs[(p.token0, p.token1)].append((p.reserve0, p.reserve1, idx))
            self.pairs[(p.token1, p.token0)].append((p.reserve1, p.reserve0, idx))


    @staticmethod
    def _out(rin: float, rout: float, x: float) -> float:
        return (x * rout) / (rin + x) if x > 0 else 0.0

    @staticmethod
    def _marginal(rin: float, rout: float, x: float) -> float:
        return (rout * rin) / (rin + x) ** 2

    @staticmethod
    def _x_from_lambda(rin: float, rout: float, lam: float) -> float:
        if lam <= 0:
            return float("inf")
        t = math.sqrt((rout * rin) / lam) - rin
        return max(0.0, t)


    def _round_to_tick(self, token: str, qty: float) -> float:
      tick = self.ticksize.get(token, 0.0)
      if tick <= 0:
          return qty
      return round(qty / tick) * tick


    def _aggregate_pair(self, t_in: str, t_out: str, x_total: float):
        pairs = self.pairs.get((t_in, t_out), [])
        if x_total <= 0 or not pairs:
            return 0.0, 0.0, []

        lam_hi = max(Router._marginal(r0, r1, 0.0) for r0, r1, _ in pairs)
        lam_lo = 0.0

        best = None
        iters = 0
        while lam_hi - lam_lo > self.TOL and iters < self.MAX_ITER:
            lam_cand = 0.5 * (lam_lo + lam_hi)

            xs = [Router._x_from_lambda(r0, r1, lam_cand) for r0, r1, _ in pairs]
            # arrondi des allocations d'input par ticksize du token in
            xs = [self._round_to_tick(t_in, xi) for xi in xs]
            spent = sum(xs)

            y = sum(self._round_to_tick(t_out, Router._out(r0, r1, xi)) for (r0, r1, _), xi in zip(pairs, xs))

            best = (lam_cand, y, xs)
            if spent > x_total:
                lam_lo = lam_cand
            else:
                lam_hi = lam_cand
            iters += 1

        if best is None:
            raise RuntimeError(f"_aggregate_pair failed to converge for pair ({t_in},{t_out}) " f"with x_total={x_total}")

        lam_star, y_star, xs_star = best
        splits = [{"pool_index": pairs[i][2], "alloc_in": xs_star[i]} for i in range(len(pairs))]

        return y_star, lam_star, splits

    def _candidate_paths(self, X: str, Y: str):
        paths = []
        if (X, Y) in self.pairs:
            paths.append(("direct", None, None))
        # 1 hop
        for A1 in self.tokens:
            if A1 in (X, Y):
                continue
            if (X, A1) in self.pairs and (A1, Y) in self.pairs:
                paths.append(("via1", A1, None))
        # 2 hops
        for A1 in self.tokens:
            if A1 in (X, Y):
                continue
            for A2 in self.tokens:
                if A2 in (X, Y, A1):
                    continue
                if (X, A1) in self.pairs and (A1, A2) in self.pairs and (A2, Y) in self.pairs:
                    paths.append(("via2", A1, A2))
        return paths

    def _eval_path(self, kind, X, Y, x, A1=None, A2=None):
        if kind == "direct":
            y1, lam1, _ = self._aggregate_pair(X, Y, x)
            return y1, lam1

        if kind == "via1":
            y1, lam1, _ = self._aggregate_pair(X, A1, x)
            if y1 == 0:
                return 0.0, 0.0
            y2, lam2, _ = self._aggregate_pair(A1, Y, y1)
            return y2, lam1 * lam2

        # via2: X->A1->A2->Y
        y1, lam1, _ = self._aggregate_pair(X, A1, x)
        if y1 == 0:
            return 0.0, 0.0
        y2, lam2, _ = self._aggregate_pair(A1, A2, y1)
        if y2 == 0:
            return 0.0, 0.0
        y3, lam3, _ = self._aggregate_pair(A2, Y, y2)
        return y3, lam1 * lam2 * lam3

    def _x_for_path_mu(self, target_mu: float, kind, m1, X, Y, x_hi: float, m2=None) -> float:
        if target_mu <= 0:
            return x_hi
        lo, hi = 0.0, x_hi
        for _ in range(60):
            m = 0.5 * (lo + hi)
            _, mu = self._eval_path(kind, X, Y, m, m1, m2)
            if mu > target_mu:
                lo = m
            else:
                hi = m
        return hi

    def _path_initial_mu(self, kind, X, Y, A1=None, A2=None) -> float:
        if kind == "direct":
            legs = [(X, Y)]
        elif kind == "via1":
            legs = [(X, A1), (A1, Y)]
        else:  # via2
            legs = [(X, A1), (A1, A2), (A2, Y)]

        mu = 1.0
        for u, v in legs:
          mu *= max((rout / rin for rin, rout, _ in self.pairs.get((u, v), [])), default=0.0)

        return mu

    def _allocate_for_mu(self, mu: float, paths: List[Tuple[str, str, str]], tokenA: str, tokenB: str, amount_in: float):
        xs = [self._round_to_tick(tokenA, self._x_for_path_mu(mu, kind, A1, tokenA, tokenB, amount_in, A2)) for (kind, A1, A2) in paths]
        spent = sum(xs)

        routes, total_out = [], 0.0

        for (kind, A1, A2), xi in zip(paths, xs):
            if xi <= 1e-12:
                continue
            if kind == "direct":
                y1, _, splits1 = self._aggregate_pair(tokenA, tokenB, xi)
                y1 = self._round_to_tick(tokenB, y1)
                total_out += y1
                routes.append(
                    {
                        "path": [tokenA, tokenB],
                        "in": xi,
                        "out": y1,
                        "pair": [{"pair": f"{tokenA}/{tokenB}", **s} for s in splits1],
                    }
                )
            elif kind == "via1":
                y1, _, splits1 = self._aggregate_pair(tokenA, A1, xi)
                y1 = self._round_to_tick(A1, y1)
                y2, _, splits2 = self._aggregate_pair(A1, tokenB, y1)
                y2 = self._round_to_tick(tokenB, y2)

                total_out += y2
                routes.append(
                    {
                        "path": [tokenA, A1, tokenB],
                        "in": xi,
                        "mid_out": y1,
                        "out": y2,
                        "pair1": [{"pair": f"{tokenA}/{A1}", **s} for s in splits1],
                        "pair2": [{"pair": f"{A1}/{tokenB}", **s} for s in splits2],
                    }
                )
            else:  # via2
                y1, _, splits1 = self._aggregate_pair(tokenA, A1, xi)
                y1 = self._round_to_tick(A1, y1)
                y2, _, splits2 = self._aggregate_pair(A1, A2, y1)
                y2 = self._round_to_tick(A2, y2)
                y3, _, splits3 = self._aggregate_pair(A2, tokenB, y2)
                y3 = self._round_to_tick(tokenB, y3)

                total_out += y3
                routes.append(
                    {
                        "path": [tokenA, A1, A2, tokenB],
                        "in": xi,
                        "mid1_out": y1,
                        "mid2_out": y2,
                        "out": y3,
                        "pair1": [{"pair": f"{tokenA}/{A1}", **s} for s in splits1],
                        "pair2": [{"pair": f"{A1}/{A2}", **s} for s in splits2],
                        "pair3": [{"pair": f"{A2}/{tokenB}", **s} for s in splits3],
                    }
                )
        return total_out, routes, xs, spent

    def solve(self, tokenA: str, tokenB: str, amount_in: float):
        if amount_in <= 0:
            return {"amount_out": 0.0, "routes": []}

        # List all possible paths
        paths = self._candidate_paths(tokenA, tokenB)
        if not paths:
            return {"amount_out": 0.0, "routes": []}

        # Init marginal range
        mu_lo, mu_hi = 0.0, max(self._path_initial_mu(kind, tokenA, tokenB, A1, A2) for kind, A1, A2 in paths)
        print(mu_hi*amount_in)

        best = None
        iters = 0
        print(paths)
        while mu_hi - mu_lo > self.TOL and iters < self.MAX_ITER:
            mu_cand = 0.5 * (mu_lo + mu_hi)
            amount_out, routes, xs, spent = self._allocate_for_mu(mu_cand, paths, tokenA, tokenB, amount_in)
            best = (mu_cand, amount_out, routes, xs)
            # A appliquer pour éviter de rechercher une égalité stricte.
            if abs(spent - amount_in) <= amount_in * self.TOL:
              break
            elif spent > amount_in:
                mu_lo = mu_cand
            else:
                mu_hi = mu_cand
            iters += 1

        if best is None:
            raise RuntimeError(f"solve failed to converge for ({tokenA},{tokenB}) with amount_in={amount_in}")

        _, amount_out, routes, _ = best
        return {"amount_out": amount_out, "routes": routes}


In [74]:
def main():
    pools = [
        UniV2Pool("ETH", "USDC", 1, 10000),
        UniV2Pool("ETH", "USDC", 2000, 2000000),
        UniV2Pool("ETH", "USDC", 1000, 1000000),
        UniV2Pool("ETH", "DAI", 1000, 900000),
        UniV2Pool("ETH", "DAI", 3000, 2800000),
        UniV2Pool("ETH", "DAI", 3000, 3100000),
        UniV2Pool("DAI", "USDC", 1000000, 1000000),
        UniV2Pool("DAI", "USDC", 2000000, 2000000),
        UniV2Pool("DAI", "USDT", 1000000, 900000),
        UniV2Pool("DAI", "USDT", 900000, 1000000),
        #UniV2Pool("USDC", "USDT", 2000000, 2000000),
        UniV2Pool("ETH", "USDT", 2000, 2000000),
        UniV2Pool("ETH", "USDT", 10000, 10000000),
    ]

    # ETH / USDC
    # (ETH / DAI / USDC)
    # (ETH / USDT / DAI / USDC)
    eth_sell_amount = 10
    usdc_sell_amount = 10000
    router = Router(pools)
    res_0 = router.solve("ETH", "USDC", eth_sell_amount)


    print(f"solution for '{eth_sell_amount} ETH to USDC':")
    pprint(res_0)
    print(' ')

    if False:
      res_1 = router.solve("USDC", "ETH", usdc_sell_amount)
      print(f"solution for '{usdc_sell_amount} USDC to ETH':")
      pprint(res_1)

if __name__ == "__main__":
    main()

0
1
2
3
4
5
6
7
8
9
10
11
100000.0
[('direct', None, None), ('via1', 'DAI', None), ('via2', 'USDT', 'DAI')]
solution for '10 ETH to USDC':
{'amount_out': 15446.162855999999,
 'routes': [{'in': 2.037380076970303,
             'out': 6707.688947,
             'pair': [{'alloc_in': 2.0373800769704027,
                       'pair': 'ETH/USDC',
                       'pool_index': 0},
                      {'alloc_in': 0.0, 'pair': 'ETH/USDC', 'pool_index': 1},
                      {'alloc_in': 0.0, 'pair': 'ETH/USDC', 'pool_index': 2}],
             'path': ['ETH', 'USDC']},
            {'in': 7.962619922529029,
             'mid1_out': 7957.339817,
             'mid2_out': 8764.001881901258,
             'out': 8738.473909,
             'pair1': [{'alloc_in': 1.3271033204782725,
                        'pair': 'ETH/USDT',
                        'pool_index': 10},
                       {'alloc_in': 6.635516602391363,
                        'pair': 'ETH/USDT',
                        '