In [2]:
from Crypto.Hash import keccak
import decimal

decimal.getcontext().prec = 18


class PoolOracle:
    def __init__(self):
        self.price_of = {}

    def feed(self, market: str, price: decimal.Decimal):
        self.price_of[market] = price

    def get_price(self, market: str):
        return self.price_of[market]

    def get_adaptive_price(
        self,
        market: str,
        size_delta: decimal.Decimal,
        max_skew: decimal.Decimal,
        skew: decimal.Decimal,
    ):
        if self.price_of.get(market) is None:
            raise Exception("market not found")
        price = self.price_of[market]
        pd_before = skew / max_skew
        pd_after = (skew + size_delta) / max_skew
        price_before = price * (1 + pd_before)
        price_after = price * (1 + pd_after)
        return (price_before + price_after) / 2
        # return self.price_of[market]


class Pool:
    def __init__(self, oralce_: PoolOracle, max_skew_: decimal.Decimal):
        self.positions = {}
        self.counter_trade_states = {
            "True": {},
            "False": {},
        }
        self.pool_oralce = oralce_
        self.market_skew_of = {}
        self.max_skew = max_skew_

    def get_position_id(self, user: str, market: str):
        hasher = keccak.new(digest_bits=256)
        hasher.update(f"{user}{market}".encode("utf-8"))
        return hasher.hexdigest()

    def get_market_skew_of(self, market: str):
        return self.market_skew_of.get(market, 0)

    def get_position_delta(self, user: str, market: str):
        position = self.positions[self.get_position_id(user, market)]
        mark_price = self.pool_oralce.get_adaptive_price(
            market,
            -position["size"],
            self.max_skew,
            self.market_skew_of[market],
        )
        return self.get_delta(
            position["size"],
            position["size"] > 0,
            mark_price,
            position["avg_entry_price"],
        )

    def get_close_price(self, user: str, market: str):
        return self.pool_oralce.get_adaptive_price(
            market,
            -self.positions[self.get_position_id(user, market)]["size"],
            self.max_skew,
            self.market_skew_of[market],
        )

    def get_delta(
        self,
        size: decimal.Decimal,
        is_long: bool,
        mark_price: decimal.Decimal,
        avg_entry_price: decimal.Decimal,
    ):
        # if avg entry price of the position is 0 then return 0
        if avg_entry_price == 0:
            return 0
        # calculate pnl
        pnl = size * (mark_price - avg_entry_price) / avg_entry_price
        # Inverse pnl if short
        if is_long == False:
            return -pnl
        return pnl

    def increase_position(self, user: str, market: str, size_delta: decimal.Decimal):
        # short circuit when size_delta is 0
        if size_delta == 0:
            return

        # long or short
        is_long = size_delta > 0

        # get position id
        position_id = self.get_position_id(user, market)
        # if position not exist, create it
        if self.positions.get(position_id) is None:
            self.positions[position_id] = {
                "user": user,
                "market": market,
                "avg_entry_price": decimal.Decimal(0),
                "size": decimal.Decimal(0),
                "realized_pnl": decimal.Decimal(0),
            }
        # if market skew not exist, create it
        if self.market_skew_of.get(market) is None:
            self.market_skew_of[market] = decimal.Decimal(0)
        # if counter trade states not exist, create it
        if self.counter_trade_states[str(is_long)].get(market) is None:
            self.counter_trade_states[str(is_long)][market] = {
                "size": decimal.Decimal(0),
                "size2": decimal.Decimal(0),
                "weighted_price_sum": decimal.Decimal(0),
            }

        # Load position
        position = self.positions.get(position_id)

        if position["size"] != 0:
            raise Exception("Not implemented when position size is not 0")

        # get prices from oracle
        # get mark price
        mark_price = self.pool_oralce.get_adaptive_price(
            market,
            size_delta,
            self.max_skew,
            self.market_skew_of[market],
        )

        # update position
        # if it is a new position, set avg_entry_price to mark_price
        position["user"] = user
        position["avg_entry_price"] = mark_price
        position["size"] = size_delta
        position["realized_pnl"] = 0
        print(self.get_position(user, market))

        # update counter trade states
        self.counter_trade_states[str(is_long)][market][
            "weighted_price_sum"
        ] += mark_price * abs(size_delta)
        self.counter_trade_states[str(is_long)][market]["size"] += abs(size_delta)
        self.counter_trade_states[str(is_long)][market]["size2"] += abs(size_delta) ** 2

        # update market skew
        if is_long:
            self.market_skew_of[market] += size_delta
        else:
            self.market_skew_of[market] -= size_delta

    def get_next_avg_skew(self, market: str, size_delta: decimal.Decimal):
        prev_skew = self.market_skew_of[market]
        avg_skew = self.counter_trade_states[str(True)][market]["avg_skew"]
        return (avg_skew * prev_skew) + size_delta / (prev_skew + size_delta)

    def get_position(self, user: str, market: str):
        return self.positions.get(
            self.get_position_id(user, market),
            {"user": None, "market": None, "avg_entry_price": 0, "size": 0},
        )

    def get_market_global_average_entry_price(self, is_long: bool, market: str):
        weighted_price_sum = self.counter_trade_states[str(is_long)][market][
            "weighted_price_sum"
        ]
        size = self.counter_trade_states[str(is_long)][market]["size"]
        return weighted_price_sum / size

    def get_market_global_pnl(self, is_long: bool, market: str):
        price = decimal.Decimal(self.pool_oralce.get_price(market))
        max_skew = decimal.Decimal(self.max_skew)
        skew = self.market_skew_of[market]
        sum_size = (
            -self.counter_trade_states[str(is_long)][market]["size"]
            if is_long
            else self.counter_trade_states[str(is_long)][market]["size"]
        )
        sum_size2 = self.counter_trade_states[str(is_long)][market]["size2"]
        weighted_avg_entry_price = self.get_market_global_average_entry_price(
            is_long, market
        )
        weighted_avg_exit_price = (
            (2 * price * sum_size)
            + ((2 * price * skew / max_skew) * sum_size)
            + (price / max_skew * sum_size2)
        ) / (2 * sum_size)
        print(
            "price",
            price,
            "sum_size",
            sum_size,
            "skew",
            skew,
            "max_skew",
            max_skew,
            "sum_size2",
            sum_size2,
        )
        print("a", (2 * price * sum_size))
        print("b", ((2 * price * skew / max_skew) * sum_size))
        print("b1", (2 * price * skew / max_skew))
        print("c", (price / max_skew * sum_size2))
        print("d", (2 * sum_size))
        print(weighted_avg_exit_price)
        if is_long == False:
            # if short, market global pnl turns positive when exit price is higher than entry price
            # due to counter trade
            return (
                (weighted_avg_entry_price - weighted_avg_exit_price)
                * sum_size
                / weighted_avg_entry_price
            )
        # if long, market global pnl turns positive when exit price is lower than entry price
        return (
            (weighted_avg_exit_price - weighted_avg_entry_price)
            * sum_size
            / weighted_avg_entry_price
        )


oracle = PoolOracle()
pool = Pool(oracle, 300_000_000)

# feed prices
oracle.feed("BTC", 28_000)
oracle.feed("ETH", 1_200)

# print(oracle.get_adaptive_price("ETH", 100, 1_000_000, -200))
# print(oracle.get_adaptive_price("ETH", -100, 1_000_000, -100))

# try increase position
pool.increase_position("0xalice", "BTC", 10_000)
# print(pool.counter_trade_states)
pool.increase_position("0xbob", "BTC", 5_000)
# print(pool.counter_trade_states)
pool.increase_position("0xcat", "BTC", -2_000)
# print(pool.counter_trade_states)
pool.increase_position("0xdog", "BTC", -3_000)
# print(pool.counter_trade_states)
pool.increase_position("0x1", "BTC", 2_000)

# check position PnL
print("0xalice's BTC PnL:", pool.get_position_delta("0xalice", "BTC"))
# print("0xbob's BTC PnL:", pool.get_position_delta("0xbob", "BTC"))
# print("0x1's BTC PnL:", pool.get_position_delta("0x1", "BTC"))
# print("0xcat's BTC PnL:", pool.get_position_delta("0xcat", "BTC"))
# print("0xdog's BTC PnL:", pool.get_position_delta("0xdog", "BTC"))
print("pool's long BTC PnL:", pool.get_market_global_pnl(True, "BTC"))
print("pool's short BTC PnL:", pool.get_market_global_pnl(False, "BTC"))

{'user': '0xalice', 'market': 'BTC', 'avg_entry_price': Decimal('28000.4666666666666'), 'size': 10000, 'realized_pnl': 0}
{'user': '0xbob', 'market': 'BTC', 'avg_entry_price': Decimal('28001.1666666666666'), 'size': 5000, 'realized_pnl': 0}
{'user': '0xcat', 'market': 'BTC', 'avg_entry_price': Decimal('28001.3066666666666'), 'size': -2000, 'realized_pnl': 0}
{'user': '0xdog', 'market': 'BTC', 'avg_entry_price': Decimal('28001.4466666666668'), 'size': -3000, 'realized_pnl': 0}
{'user': '0x1', 'market': 'BTC', 'avg_entry_price': Decimal('28001.9600000000000'), 'size': 2000, 'realized_pnl': 0}
0xalice's BTC PnL: 0.399993333444442594
price 28000 sum_size -17000 skew 22000 max_skew 300000000 sum_size2 129000000
a -952000000
b -69813.3333333333334
b1 4.10666666666666667
c 12040.0000000000000
d -34000
28001.6992156862745
pool's long BTC PnL: -0.516651015180052950
price 28000 sum_size 5000 skew 22000 max_skew 300000000 sum_size2 13000000
a 280000000
b 20533.3333333333334
b1 4.10666666666666667