In [149]:
import csv
import decimal
import os
import random
import time

from Crypto.Hash import keccak

# Limit every Decimal operation in the current context to 18 significant digits.
decimal.getcontext().prec = 18


class PoolOracle:
    """Keeps the latest price per market and quotes skew-adjusted fill prices."""

    def __init__(self):
        # market symbol -> most recently fed price
        self.price_of = {}

    def feed(self, market: str, price: decimal.Decimal):
        """Store ``price`` as the current price of ``market``, replacing any earlier value."""
        self.price_of[market] = price

    def get_price(self, market: str):
        """Return the stored price of ``market``.

        Raises:
            KeyError: if no price was ever fed for ``market``.
        """
        return self.price_of[market]

    def get_adaptive_price(
        self,
        market: str,
        size_delta: decimal.Decimal,
        max_skew: decimal.Decimal,
        skew: decimal.Decimal,
    ):
        """Quote the price for a trade that moves the market skew by ``size_delta``.

        The stored price is marked up (or down) by ``skew / max_skew`` before the
        trade and by ``(skew + size_delta) / max_skew`` after it; the quote is
        the midpoint of those two adjusted prices.

        Raises:
            Exception: with message "market not found" when ``market`` has no
                stored price.
        """
        spot = self.price_of.get(market)
        if spot is None:
            raise Exception("market not found")

        # Premium/discount ratios before and after the trade is applied.
        premium_before, premium_after = (
            skew / max_skew,
            (skew + size_delta) / max_skew,
        )
        return (spot * (1 + premium_before) + spot * (1 + premium_after)) / 2


class Pool:
    """In-memory pool model that opens positions at skew-adjusted oracle prices.

    Besides the individual positions, the pool keeps per-side running sums so
    that the aggregate PnL of a side can be computed without iterating over
    every position (see ``get_market_global_pnl``).
    """

    def __init__(self, oralce_: "PoolOracle", max_skew_: decimal.Decimal):
        # position id -> {"user", "market", "avg_entry_price", "size", "realized_pnl"}
        self.positions = {}
        # str(is_long) -> market -> {"size", "sum_se", "sum_s2e"} running sums
        self.counter_trade_states = {
            "True": {},
            "False": {},
        }
        self.pool_oralce = oralce_
        # market -> signed sum of the sizes of all opened positions
        self.market_skew_of = {}
        self.max_skew = max_skew_

    def get_position_id(self, user: str, market: str):
        """Return the hex Keccak-256 digest of ``user`` and ``market`` concatenated.

        NOTE(review): the two strings are joined without a separator, so
        distinct (user, market) pairs whose concatenation is equal share an id.
        """
        hasher = keccak.new(digest_bits=256)
        hasher.update(f"{user}{market}".encode("utf-8"))
        return hasher.hexdigest()

    def get_market_skew_of(self, market: str):
        """Return the current skew of ``market``, or 0 if it has never been traded."""
        return self.market_skew_of.get(market, 0)

    def get_position_delta(self, user: str, market: str):
        """Return the unrealized PnL of the user's position if it were closed now.

        The position is marked at its close price (see ``get_close_price``).

        Raises:
            KeyError: if the user has no position in ``market``.
        """
        position = self.positions[self.get_position_id(user, market)]
        signed_size = position["size"]
        # get_delta expects a size magnitude plus a direction flag; passing the
        # signed size would flip the sign of a short's PnL a second time.
        return self.get_delta(
            abs(signed_size),
            signed_size > 0,
            self.get_close_price(user, market),
            position["avg_entry_price"],
        )

    def get_close_price(self, user: str, market: str):
        """Return the adaptive price for fully closing the user's position.

        Closing is priced as a trade of the opposite size (``-size``) against
        the market's current skew.

        Raises:
            KeyError: if the user has no position in ``market``.
        """
        return self.pool_oralce.get_adaptive_price(
            market,
            -self.positions[self.get_position_id(user, market)]["size"],
            self.max_skew,
            self.market_skew_of[market],
        )

    def get_delta(
        self,
        size: decimal.Decimal,
        is_long: bool,
        mark_price: decimal.Decimal,
        avg_entry_price: decimal.Decimal,
    ):
        """Return the PnL of a position of magnitude ``size`` marked at ``mark_price``.

        The PnL is ``size * (mark_price - avg_entry_price) / avg_entry_price``,
        negated for shorts. Returns 0 when ``avg_entry_price`` is 0.
        """
        # An entry price of 0 would divide by zero; such a position has no PnL.
        if avg_entry_price == 0:
            return 0
        pnl = size * (mark_price - avg_entry_price) / avg_entry_price
        # A short gains when the price falls, so its PnL is the mirror image.
        if not is_long:
            return -pnl
        return pnl

    def increase_position(self, user: str, market: str, size_delta: decimal.Decimal):
        """Open a position of signed size ``size_delta`` for ``user`` in ``market``.

        Positive ``size_delta`` opens a long, negative opens a short, and zero
        is a no-op. The entry price is the oracle's adaptive price for the
        trade. The per-side running sums and the market skew are updated.

        Raises:
            Exception: "Not implemented when position size is not 0" when the
                user already holds a non-empty position in ``market``.
        """
        # Nothing to do for a zero-sized trade.
        if size_delta == 0:
            return

        is_long = size_delta > 0
        side = str(is_long)

        # Lazily create the position, the market skew and the side's running sums.
        position = self.positions.get(self.get_position_id(user, market))
        if position is None:
            position = {
                "user": user,
                "market": market,
                "avg_entry_price": decimal.Decimal(0),
                "size": decimal.Decimal(0),
                "realized_pnl": decimal.Decimal(0),
            }
            self.positions[self.get_position_id(user, market)] = position
        if self.market_skew_of.get(market) is None:
            self.market_skew_of[market] = decimal.Decimal(0)
        if self.counter_trade_states[side].get(market) is None:
            self.counter_trade_states[side][market] = {
                "size": decimal.Decimal(0),
                "sum_se": decimal.Decimal(0),
                "sum_s2e": decimal.Decimal(0),
            }
        state = self.counter_trade_states[side][market]

        if position["size"] != 0:
            raise Exception("Not implemented when position size is not 0")

        # Price the trade against the skew as it stands before this trade.
        mark_price = self.pool_oralce.get_adaptive_price(
            market,
            size_delta,
            self.max_skew,
            self.market_skew_of[market],
        )

        # The position is new, so its average entry price is the fill price.
        position["user"] = user
        position["avg_entry_price"] = mark_price
        position["size"] = size_delta
        position["realized_pnl"] = decimal.Decimal(0)

        # Running sums consumed by get_market_global_pnl.
        state["sum_se"] += size_delta / mark_price
        state["sum_s2e"] += (size_delta * size_delta) / mark_price
        state["size"] += abs(size_delta)

        # size_delta is already signed (negative for shorts), so the skew is
        # always moved by adding it; subtracting for shorts would push the
        # skew in the wrong direction.
        self.market_skew_of[market] += size_delta

    def get_next_avg_entry_price(
        self,
        is_long: str,
        market: str,
        next_price: decimal.Decimal,
        size_delta: decimal.Decimal,
    ):
        """Return the side's average entry price after adding ``size_delta`` at ``next_price``.

        ``is_long`` is the string key ("True"/"False") into
        ``counter_trade_states``.

        NOTE(review): the state dicts created by ``increase_position`` have no
        "avg_entry_price" key, so this lookup raises ``KeyError`` for them.
        The method is not called anywhere in this class.
        """
        states = self.counter_trade_states[is_long][market]
        avg_entry_price = states["avg_entry_price"]
        size = states["size"]
        if size == 0:
            return next_price
        next_size = states["size"] + size_delta
        price_delta = avg_entry_price - next_price
        delta = states["size"] * price_delta / avg_entry_price
        return next_price * next_size / (next_size - delta)

    def get_position(self, user: str, market: str):
        """Return the user's position record, or an empty placeholder if none exists."""
        return self.positions.get(
            self.get_position_id(user, market),
            {"user": None, "market": None, "avg_entry_price": 0, "size": 0},
        )

    def get_market_global_pnl(self, is_long: bool, market: str):
        """Compute one side's aggregate PnL from the running sums.

        With ``p`` the oracle price, the result is
        ``p*se + p*(skew/max_skew)*se + p/(2*max_skew)*sum_s2e - sz`` where
        ``se`` and ``sz`` are the side's ``sum_se`` and ``size``, both negated
        for the long side. For longs this is the negated sum of
        ``get_position_delta`` over all long positions, up to rounding.

        NOTE(review): the sign conventions for the short side look unverified;
        confirm them before relying on ``is_long=False``.

        Raises:
            KeyError: if the side has no recorded state for ``market``.
        """
        price = decimal.Decimal(self.pool_oralce.get_price(market))
        max_skew = decimal.Decimal(self.max_skew)
        skew = self.market_skew_of[market]
        state = self.counter_trade_states[str(is_long)][market]

        sum_se = -state["sum_se"] if is_long else state["sum_se"]
        sum_s2e = state["sum_s2e"]
        sum_size = -state["size"] if is_long else state["size"]

        a = price * sum_se
        b = price * skew / max_skew * sum_se
        c = price / (2 * max_skew) * sum_s2e
        d = sum_size
        return a + b + c - d


def percentage_diff(base_value, comp_value):
    """Return how far ``comp_value`` is from ``base_value``, as an absolute percentage of ``base_value``."""
    relative_change = (comp_value - base_value) / base_value
    return abs(relative_change) * 100


# Simulation: open many long BTC positions while the price drifts upwards, then
# compare the per-position ("expensive") PnL sum with the aggregate ("cheap")
# PnL computed from the pool's running sums.
oracle = PoolOracle()
pool = Pool(oracle, decimal.Decimal(300_000_000))

# feed prices
oracle.feed("BTC", decimal.Decimal(28_000))
oracle.feed("ETH", decimal.Decimal(1_200))

# print(oracle.get_adaptive_price("ETH", 100, 1_000_000, -200))
# print(oracle.get_adaptive_price("ETH", -100, 1_000_000, -100))

data = [["trade_id", "avg_entry_price", "size", "expected_close_price"]]
total_trades = 100
current_price = decimal.Decimal(28_000)
for i in range(total_trades):
    # Move the price up by a random step, then open a randomly sized long
    # for a fresh user (the loop index doubles as the user id).
    current_price += decimal.Decimal(random.randint(0, 30))
    oracle.feed("BTC", current_price)
    pool.increase_position(
        str(i), "BTC", decimal.Decimal(random.randint(100, 10_000_000))
    )

# Sum the PnL position by position and record each position for the CSV dump.
expensive_pool_pnl = 0
for i in range(total_trades):
    position = pool.get_position(str(i), "BTC")
    expensive_pool_pnl += pool.get_position_delta(str(i), "BTC")
    data.append(
        [
            i,
            position["avg_entry_price"],
            position["size"],
            pool.get_close_price(str(i), "BTC"),
        ]
    )

# Make sure the output directory exists so the dump does not fail after the
# whole simulation has already run.
os.makedirs("./out", exist_ok=True)
# newline="" lets the csv module control line endings itself (otherwise extra
# blank rows can appear on platforms that translate "\n").
with open(f"./out/{int(time.time())}_data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerows(data)

cheap_pool_pnl = pool.get_market_global_pnl(True, "BTC")

print(pool.counter_trade_states)
print("expensive pool pnl:", expensive_pool_pnl)
print("cheap pool pnl:", cheap_pool_pnl)
print("diff:", expensive_pool_pnl - abs(cheap_pool_pnl))
print("diff %:", percentage_diff(expensive_pool_pnl, abs(cheap_pool_pnl)))

# try increase position
# pool.increase_position("0xalice", "BTC", 10_000)
# print(pool.counter_trade_states)
# pool.increase_position("0xbob", "BTC", 5_000)
# print(pool.counter_trade_states)
# pool.increase_position("0xcat", "BTC", -2_000)
# print(pool.counter_trade_states)
# pool.increase_position("0xdog", "BTC", -3_000)
# print(pool.counter_trade_states)
# pool.increase_position("0x1", "BTC", 2_000)

# check position PnL
# print("0xalice's BTC PnL:", pool.get_position_delta("0xalice", "BTC"))
# print("0xbob's BTC PnL:", pool.get_position_delta("0xbob", "BTC"))
# print("0x1's BTC PnL:", pool.get_position_delta("0x1", "BTC"))
# print("0xcat's BTC PnL:", pool.get_position_delta("0xcat", "BTC"))
# print("0xdog's BTC PnL:", pool.get_position_delta("0xdog", "BTC"))
# print("pool's long BTC PnL:", pool.get_market_global_pnl(True, "BTC"))
# print("pool's short BTC PnL:", pool.get_market_global_pnl(False, "BTC"))
# 32418627242
# 3134091025282

{'True': {'BTC': {'size': Decimal('477536617'), 'sum_se': Decimal('9992.39157213278007'), 'sum_s2e': Decimal('67006331658.0278734')}}, 'False': {}}
expensive pool pnl: 282133395.379410378
cheap pool pnl: -282133395.379410380
diff: -2E-9
diff %: 7.08884532194573605E-16
