In [2]:
from Crypto.Hash import keccak
import random
import decimal
import csv
import time

decimal.getcontext().prec = 18


class PoolOracle:
    def __init__(self):
        self.price_of = {}

    def feed(self, market: str, price: decimal.Decimal):
        self.price_of[market] = price

    def get_price(self, market: str):
        return self.price_of[market]

    def get_adaptive_price(
        self,
        market: str,
        size_delta: decimal.Decimal,
        max_skew: decimal.Decimal,
        skew: decimal.Decimal,
    ):
        if self.price_of.get(market) is None:
            raise Exception("market not found")
        price = self.price_of[market]
        pd_before = skew / max_skew
        pd_after = (skew + size_delta) / max_skew
        price_before = price * (1 + pd_before)
        price_after = price * (1 + pd_after)
        return (price_before + price_after) / 2
        # return self.price_of[market]


class Pool:
    def __init__(self, oralce_: PoolOracle, max_skew_: decimal.Decimal):
        self.positions = {}
        self.counter_trade_states = {
            "True": {},
            "False": {},
        }
        self.pool_oralce = oralce_
        self.market_skew_of = {}
        self.max_skew = max_skew_

    def get_position_id(self, user: str, market: str):
        hasher = keccak.new(digest_bits=256)
        hasher.update(f"{user}{market}".encode("utf-8"))
        return hasher.hexdigest()

    def get_market_skew_of(self, market: str):
        return self.market_skew_of.get(market, 0)

    def get_position_delta(self, user: str, market: str):
        position = self.positions.get(self.get_position_id(user, market))
        if position is None or position["size"] == 0:
            return 0
        mark_price = self.pool_oralce.get_adaptive_price(
            market,
            -position["size"],
            self.max_skew,
            self.market_skew_of[market],
        )
        if mark_price < 0:
            print(self.counter_trade_states)
        return self.get_delta(
            position["size"],
            position["size"] > 0,
            mark_price,
            position["avg_entry_price"],
        )

    def get_close_price(self, user: str, market: str):
        return self.pool_oralce.get_adaptive_price(
            market,
            -self.positions.get(self.get_position_id(user, market))["size"],
            self.max_skew,
            self.market_skew_of[market],
        )

    def get_delta(
        self,
        size: decimal.Decimal,
        is_long: bool,
        mark_price: decimal.Decimal,
        avg_entry_price: decimal.Decimal,
    ):
        # if avg entry price of the position is 0 then return 0
        if avg_entry_price == 0:
            return 0
        # calculate pnl
        pnl = size * (mark_price - avg_entry_price) / avg_entry_price
        # Inverse pnl if short
        if is_long == False:
            return -pnl
        return pnl

    def increase_position(self, user: str, market: str, size_delta: decimal.Decimal):
        # short circuit when size_delta is 0
        if size_delta == 0:
            return

        # long or short
        is_long = size_delta > 0

        # get position id
        position_id = self.get_position_id(user, market)
        # if position not exist, create it
        if self.positions.get(position_id) is None:
            self.positions[position_id] = {
                "user": user,
                "market": market,
                "avg_entry_price": decimal.Decimal(0),
                "size": decimal.Decimal(0),
                "realized_pnl": decimal.Decimal(0),
            }
        # if market skew not exist, create it
        if self.market_skew_of.get(market) is None:
            self.market_skew_of[market] = decimal.Decimal(0)
        # if counter trade states not exist, create it
        if self.counter_trade_states[str(is_long)].get(market) is None:
            self.counter_trade_states[str(is_long)][market] = {
                "size": decimal.Decimal(0),
                "sum_se": decimal.Decimal(0),
                "sum_s2e": decimal.Decimal(0),
            }

        # Load position
        position = self.positions.get(position_id)

        # get prices from oracle
        # get mark price
        mark_price = self.pool_oralce.get_adaptive_price(
            market,
            size_delta,
            self.max_skew,
            self.market_skew_of[market],
        )

        is_new_position = position["size"] == 0

        # update position
        # if it is a new position, set avg_entry_price to mark_price
        next_close_price = self.get_next_close_price(
            self.market_skew_of[market],
            self.pool_oralce.get_price(market),
            position["size"],
            (size_delta if is_long else -size_delta),
        )

        old_sum_se = 0
        old_sum_s2e = 0
        if is_new_position:
            position["user"] = user
            position["avg_entry_price"] = mark_price
            position["size"] = size_delta
            position["realized_pnl"] = 0
        # existing position
        else:
            unrealizedPnl = self.get_position_delta(user, market)
            old_sum_se = abs(position["size"]) / position["avg_entry_price"]
            old_sum_s2e = (position["size"] * position["size"]) / position[
                "avg_entry_price"
            ]
            position["size"] += size_delta
            position["avg_entry_price"] = self.get_average_entry_price(
                position["size"], next_close_price, unrealizedPnl
            )

        # update counter trade states
        if is_new_position:
            self.counter_trade_states[str(is_long)][market]["sum_se"] += (
                abs(size_delta) / mark_price
            )
            self.counter_trade_states[str(is_long)][market]["sum_s2e"] += (
                size_delta * size_delta
            ) / mark_price
            self.counter_trade_states[str(is_long)][market]["size"] += abs(size_delta)
        else:
            self.counter_trade_states[str(is_long)][market]["sum_se"] -= old_sum_se
            self.counter_trade_states[str(is_long)][market]["sum_se"] += (
                abs(position["size"]) / position["avg_entry_price"]
            )
            self.counter_trade_states[str(is_long)][market]["sum_s2e"] -= old_sum_s2e
            self.counter_trade_states[str(is_long)][market]["sum_s2e"] += (
                position["size"] * position["size"]
            ) / position["avg_entry_price"]
            self.counter_trade_states[str(is_long)][market]["size"] += abs(size_delta)

            # update market skew
        if is_long:
            self.market_skew_of[market] += size_delta
        else:
            self.market_skew_of[market] += size_delta

    def decrease_position(self, user: str, market: str, size_delta: decimal.Decimal):
        # short circuit when size_delta is 0
        if size_delta == 0:
            return

        if size_delta < 0:
            raise Exception("Size delta > 0")

        # long or short

        # get position id
        position_id = self.get_position_id(user, market)
        if self.positions.get(position_id) is None:
            raise Exception("Can't decrease non-existent position")

        # Load position
        position = self.positions.get(position_id)
        is_long = position["size"] > 0

        # get prices from oracle
        # get mark price
        mark_price = self.pool_oralce.get_adaptive_price(
            market,
            -size_delta if is_long else size_delta,
            self.max_skew,
            self.market_skew_of[market],
        )

        positionPnl = self.get_position_delta(user, market)
        toRealizedPnl = positionPnl * size_delta / abs(position["size"])
        unrealizedPnl = positionPnl - toRealizedPnl

        old_weighted_price_sum = position["avg_entry_price"] * abs(size_delta)
        old_sum_se = abs(position["size"]) / position["avg_entry_price"]
        old_sum_s2e = position["size"] * position["size"] / position["avg_entry_price"]

        # update position
        next_close_price = self.get_next_close_price(
            self.market_skew_of[market],
            self.pool_oralce.get_price(market),
            position["size"],
            (-size_delta if is_long else size_delta),
        )
        position["size"] -= size_delta if is_long else -size_delta
        if position["size"] != 0:
            position["avg_entry_price"] = self.get_average_entry_price(
                position["size"], next_close_price, unrealizedPnl
            )
        else:
            position["avg_entry_price"] = 0
        position["realized_pnl"] += toRealizedPnl

        # update counter trade states
        self.counter_trade_states[str(is_long)][market]["sum_se"] -= old_sum_se
        self.counter_trade_states[str(is_long)][market]["sum_se"] += (
            (abs(position["size"]) / position["avg_entry_price"])
            if position["avg_entry_price"] > 0
            else 0
        )
        self.counter_trade_states[str(is_long)][market]["sum_s2e"] -= old_sum_s2e
        self.counter_trade_states[str(is_long)][market]["sum_s2e"] += (
            (position["size"] * position["size"]) / position["avg_entry_price"]
            if position["avg_entry_price"] > 0
            else 0
        )
        self.counter_trade_states[str(is_long)][market]["size"] -= abs(size_delta)

        # update market skew
        if is_long:
            self.market_skew_of[market] -= size_delta
        else:
            self.market_skew_of[market] += size_delta

    def get_average_entry_price(
        self,
        position_size: decimal.Decimal,
        markPrice: decimal.Decimal,
        unrealizedPnl: decimal.Decimal,
    ):
        return (markPrice * position_size) / (position_size + unrealizedPnl)

    def get_next_close_price(
        self,
        skew: decimal.Decimal,
        oracle_price: decimal.Decimal,
        position_size: decimal.Decimal,
        size_delta: decimal.Decimal,
    ):
        _newPositionSize = position_size + size_delta
        _newMarketSkew = skew + size_delta

        _premiumBefore = _newMarketSkew / self.max_skew
        _premiumAfter = (_newMarketSkew - _newPositionSize) / self.max_skew

        _premium = (_premiumBefore + _premiumAfter) / 2

        return oracle_price * (1 + _premium)

    def get_next_avg_entry_price(
        self,
        is_long: str,
        market: str,
        next_price: decimal.Decimal,
        size_delta: decimal.Decimal,
    ):
        states = self.counter_trade_states[is_long][market]
        avg_entry_price = states["avg_entry_price"]
        size = states["size"]
        if size == 0:
            return next_price
        next_size = states["size"] + size_delta
        price_delta = avg_entry_price - next_price
        delta = states["size"] * price_delta / avg_entry_price
        return next_price * next_size / (next_size - delta)

    def get_position(self, user: str, market: str):
        return self.positions.get(
            self.get_position_id(user, market),
            {"user": None, "market": None, "avg_entry_price": 0, "size": 0},
        )

    def get_market_global_pnl(self, is_long: bool, market: str):
        price = decimal.Decimal(self.pool_oralce.get_price(market))
        max_skew = decimal.Decimal(self.max_skew)
        skew = self.market_skew_of[market]
        sum_se = (
            -self.counter_trade_states[str(is_long)][market]["sum_se"]
            if is_long
            else self.counter_trade_states[str(is_long)][market]["sum_se"]
        )
        sum_s2e = self.counter_trade_states[str(is_long)][market]["sum_s2e"]
        sum_size = (
            -self.counter_trade_states[str(is_long)][market]["size"]
            if is_long
            else self.counter_trade_states[str(is_long)][market]["size"]
        )
        print(
            "price:",
            price,
            "max_skew:",
            max_skew,
            "skew:",
            skew,
            "sum_se:",
            sum_se,
            "sum_s2e:",
            sum_s2e,
            "sum_size:",
            sum_size,
        )
        a = price * sum_se
        b = price * skew / max_skew * sum_se
        c = price / (2 * max_skew) * sum_s2e
        d = sum_size
        global_pnl = a + b + c - d
        return -global_pnl


def percentage_diff(base_value, comp_value):
    return abs((comp_value - base_value) / base_value) * 100


oracle = PoolOracle()
pool = Pool(oracle, 300_000_000)

# feed prices
oracle.feed("BTC", 28_000)
oracle.feed("ETH", 1_200)

data = [["trade_id", "avg_entry_price", "size", "expected_close_price"]]
total_trades = 1000
iterations = 100000
current_price = 28_000

# # simple simulation
# for i in range(total_trades):
#     current_price += decimal.Decimal(random.randint(-30, 30))
#     oracle.feed("BTC", current_price)
#     pool.increase_position(
#         str(i), "BTC", decimal.Decimal(random.randint(100, 10_000_000))
#     )

# for i in range(total_trades):
#     current_price += decimal.Decimal(random.randint(-30, 30))
#     oracle.feed("BTC", current_price)
#     position = pool.get_position(str(i), "BTC")
#     pool.decrease_position(
#         str(i), "BTC", decimal.Decimal(random.randint(100, position["size"]))
#     )

# complex simulation
open_count = 0
increase_count = 0
decrease_count = 0
for i in range(iterations):
    current_price += decimal.Decimal(random.randint(-30, 30))
    oracle.feed("BTC", current_price)
    trader = random.choice(range(total_trades))
    position = pool.get_position(trader, "BTC")

    if position["size"] == 0:
        pool.increase_position(
            trader, "BTC", decimal.Decimal(random.randint(-1_000_000, 1_000_000))
        )
        open_count += 1
    else:
        is_decrease = random.choice([True, False])
        if is_decrease is True:
            pool.decrease_position(
                trader, "BTC", decimal.Decimal(random.randint(1, abs(position["size"])))
            )
            decrease_count += 1
        else:
            if position["size"] > 0:
                pool.increase_position(
                    trader, "BTC", decimal.Decimal(random.randint(100, 100_000))
                )
            else:
                pool.increase_position(
                    trader, "BTC", decimal.Decimal(random.randint(-100_000, -100))
                )
            increase_count += 1

expensive_pool_pnl_long = 0
expensive_pool_pnl_short = 0
for i in range(total_trades):
    position = pool.get_position(str(i), "BTC")
    if position is not None and position["size"] != 0:
        is_long = position["size"] > 0
        if is_long:
            expensive_pool_pnl_long += pool.get_position_delta(str(i), "BTC")
        else:
            expensive_pool_pnl_short += pool.get_position_delta(str(i), "BTC")
        data.append(
            [
                i,
                position["avg_entry_price"],
                position["size"],
                pool.get_close_price(str(i), "BTC"),
            ]
        )

with open(f"./out/{int(time.time())}_data.csv", "w") as f:
    writer = csv.writer(f)
    writer.writerows(data)

cheap_pool_pnl_long = pool.get_market_global_pnl(True, "BTC")
# cheap_pool_pnl_long = 0
cheap_pool_pnl_short = pool.get_market_global_pnl(False, "BTC")
# cheap_pool_pnl_short = 0
cheap_pool_pnl = cheap_pool_pnl_long + cheap_pool_pnl_short

print(current_price)
print(pool.counter_trade_states)
print("expensive pool pnl long:", expensive_pool_pnl_long)
print("expensive pool pnl short:", expensive_pool_pnl_short)
print("cheap pool pnl long:", cheap_pool_pnl_long)
print("cheap pool pnl short:", cheap_pool_pnl_short)
print("diff long:", abs(expensive_pool_pnl_long) - abs(cheap_pool_pnl_long))
print(
    "diff long %:",
    percentage_diff(abs(expensive_pool_pnl_long), abs(cheap_pool_pnl_long)),
)
print("diff short:", abs(expensive_pool_pnl_short) - abs(cheap_pool_pnl_short))
print(
    "diff short %:",
    percentage_diff(abs(expensive_pool_pnl_short), abs(cheap_pool_pnl_short)),
)
print("open_count", open_count)
print("increase_count", increase_count)
print("decrease_count", decrease_count)

# validate if decrease is correct
# oracle.feed("BTC", 20000)
# pool.increase_position("alice", "BTC", 1000)
# print("position alice", pool.get_position("alice", "BTC"))

# oracle.feed("BTC", 22000)
# print("alice pnl", pool.get_position_delta("alice", "BTC"))
# pool.increase_position("alice", "BTC", 100)
# print("position alice", pool.get_position("alice", "BTC"))
# print("alice pnl", pool.get_position_delta("alice", "BTC"))
# print("long global pnl", pool.get_market_global_pnl(True, "BTC"))

# pool.increase_position("bob", "BTC", 1000)
# print("position bob", pool.get_position("bob", "BTC"))
# print("bob pnl", pool.get_position_delta("bob", "BTC"))
# print("long global pnl", pool.get_market_global_pnl(True, "BTC"))

# oracle.feed("BTC", 20000)
# pool.increase_position("bob", "BTC", 100)
# print("position bob", pool.get_position("bob", "BTC"))
# print("alice pnl", pool.get_position_delta("alice", "BTC"))
# print("bob pnl", pool.get_position_delta("bob", "BTC"))
# print("long global pnl", pool.get_market_global_pnl(True, "BTC"))

# oracle.feed("JPY", decimal.Decimal(0.007346297098947275625720855402))
# pool.increase_position("alice", "JPY", -1000)
# print("position alice", pool.get_position("alice", "JPY"))
# print("alice pnl", pool.get_position_delta("alice", "JPY"))
# print("short global pnl", pool.get_market_global_pnl(False, "JPY"))

# oracle.feed("JPY", decimal.Decimal(0.007419773696902244481543312928))
# pool.increase_position("alice", "JPY", -100)
# print("position alice", pool.get_position("alice", "JPY"))
# print("alice pnl", pool.get_position_delta("alice", "JPY"))
# print("short global pnl", pool.get_market_global_pnl(False, "JPY"))

# ============================

price: 23269 max_skew: 300000000 skew: -2333056 sum_se: -1996.16399782705058 sum_s2e: 359467406.171542200 sum_size: -47412475
price: 23269 max_skew: 300000000 skew: -2333056 sum_se: 2107.10206757408815 sum_s2e: 482765764.049911847 sum_size: 49745531
23269
{'True': {'BTC': {'size': Decimal('47412475'), 'sum_se': Decimal('1996.16399782705058'), 'sum_s2e': Decimal('359467406.171542200')}}, 'False': {'BTC': {'size': Decimal('49745531'), 'sum_se': Decimal('2107.10206757408815'), 'sum_s2e': Decimal('482765764.049911847')}}}
expensive pool pnl long: -1338900.71869306721
expensive pool pnl short: -1077950.87643597890
cheap pool pnl long: -1338900.7186930684
cheap pool pnl short: 1077950.8764359756
diff long: -1.19E-9
diff long %: 8.88788827570118315E-14
diff short: 3.30E-9
diff short %: 3.06136399360865670E-13
open_count 1078
increase_count 49437
decrease_count 49485
