In [45]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from enum import IntEnum
from collections.abc import Iterable

from numpy.typing import NDArray

In [46]:
# float64 ndarray of any shape; used for spot-price inputs and payoff outputs.
ArrayF = NDArray[np.float64]
# Same runtime type as ArrayF. The separate name only signals that a 1-D array
# is expected, since NDArray itself cannot encode the array's rank.
Array1D = ArrayF

In [39]:
class Instrument(ABC):
    """Abstract base for anything whose payoff is a function of the underlying price."""

    @abstractmethod
    def payoff(self, S: ArrayF) -> ArrayF:
        """Return the payoff evaluated at the underlying price(s) ``S``.

        Concrete subclasses must override this; because it is abstract,
        ``Instrument`` itself cannot be instantiated.
        """
        ...


@dataclass(frozen=True)
class Call(Instrument):
    """Call option with intrinsic payoff ``max(S - strike_price, 0)``."""

    strike_price: float

    def payoff(self, S: ArrayF) -> ArrayF:
        """Return ``max(S - strike_price, 0)`` elementwise for prices ``S``.

        Array-likes are accepted. Entries where ``S - strike_price`` is not
        strictly positive (including NaN) come back as 0.
        """
        # Subtraction always allocates a new array, so no explicit copy is needed.
        moneyness = np.asarray(S) - self.strike_price
        # np.where sends NaN to 0 (np.maximum would propagate NaN instead).
        return np.where(moneyness > 0, moneyness, 0)


@dataclass(frozen=True)
class Put(Instrument):
    """Put option with intrinsic payoff ``max(strike_price - S, 0)``."""

    strike_price: float

    def payoff(self, S: ArrayF) -> ArrayF:
        """Return ``max(strike_price - S, 0)`` elementwise for prices ``S``.

        Array-likes are accepted. Entries where ``strike_price - S`` is not
        strictly positive (including NaN) come back as 0.
        """
        # Subtraction always allocates a new array, so no explicit copy is needed.
        moneyness = self.strike_price - np.asarray(S)
        # np.where sends NaN to 0 (np.maximum would propagate NaN instead).
        return np.where(moneyness > 0, moneyness, 0)


@dataclass(frozen=True)
class Stock(Instrument):
    """Linear holding of the underlying, measured relative to ``initial_price``."""

    initial_price: float

    def payoff(self, S: ArrayF) -> ArrayF:
        """Return ``S - initial_price`` elementwise for prices ``S``.

        ``S`` goes through ``np.asarray`` so array-likes such as lists are
        accepted, as they are by ``Call`` and ``Put``.
        """
        return np.asarray(S) - self.initial_price

    # TODO(review): an unfinished ``def pri`` stub was removed from here. It was a
    # SyntaxError that prevented the whole cell from running. Reintroduce the
    # method once it has a real body.


class Side(IntEnum):
    """Direction of a leg, used as a +1 / -1 multiplier on its payoff."""

    LONG = 1    # holding the instrument: payoff counted as-is
    SHORT = -1  # written/sold: payoff sign flipped

In [34]:
@dataclass(frozen=True)
class Leg:
    """One leg of a position: an instrument held long or short in some quantity.

    The dataclass is frozen, so ``instrument``, ``side`` and ``quantity`` are
    already read-only attributes. Explicit same-named properties would shadow
    the fields: they would become the fields' defaults, make ``__init__`` fail,
    and their getters would recurse infinitely.
    """

    instrument: Instrument
    # Direction multiplier applied to the instrument payoff (+1 long, -1 short).
    side: Side = Side.LONG
    # Number of units; scales the payoff linearly.
    quantity: int = 1

    def __post_init__(self) -> None:
        # Coerce plain ints such as ``side=1`` to ``Side``; any other value raises
        # ValueError. object.__setattr__ is required because the dataclass is frozen.
        object.__setattr__(self, "side", Side(self.side))

    def pnl(self, S: ArrayF) -> ArrayF:
        """Return ``quantity * side * instrument.payoff(S)`` for prices ``S``."""
        return self.quantity * self.side * self.instrument.payoff(S)

In [None]:



class Position():
    def __init__(self, legs: tuple[Leg, ...]):
        """Build a position from its legs.

        Args:
            legs: the legs making up the position. Any iterable is accepted and
                stored as a tuple, so a generator is not consumed twice.
        """
        # ``tuple[Leg, ...]`` means any number of legs; ``tuple[Leg]`` means exactly one.
        self.legs: tuple[Leg, ...] = tuple(legs)
        # NOTE(review): ``_median_price`` is not in the visible part of this class.
        # Confirm it exists and returns an Array1D as annotated.
        self.prices: Array1D = self._median_price(self.legs)

    def pnl(
        self, anchor_point: float | None = None, num_points: int = 201
    ) -> tuple[Array1D, Array1D]:
        """Evaluate the position's total P&L on a price grid around an anchor.

        The grid spans ``[0.5 * anchor_point, 2 * anchor_point]`` with
        ``num_points`` evenly spaced points.

        Args:
            anchor_point: centre of the grid. If None, the median of
                ``self.prices`` is used. NOTE(review): confirm that is the
                intended default anchor.
            num_points: number of grid points passed to ``np.linspace``.

        Returns:
            ``(grid, total_pnl)``: the price grid and the sum of every leg's
            ``pnl`` on that grid. ``total_pnl`` is all zeros if there are no legs.

        Raises:
            ValueError: if the anchor is not a positive, finite number.
        """
        if anchor_point is None:
            anchor_point = float(np.median(self.prices))
        # A non-positive anchor would give an inverted or degenerate price range.
        if not np.isfinite(anchor_point) or anchor_point <= 0:
            raise ValueError(
                f"anchor_point must be a positive finite number, got {anchor_point}"
            )
        lb, rb = anchor_point * 0.5, anchor_point * 2
        grid = np.linspace(lb, rb, num_points)
        # Starting from zeros keeps the result an array even when there are no legs.
        total = np.zeros_like(grid)
        for leg in self.legs:
            total = total + leg.pnl(grid)
        return grid, total
    
    @staticmethod
    def _validate_prices(prices: Iterable[float]) -> Array1D:
        arr = np.asarray(prices, dtype=np.float64)

        if arr.ndim == 0:
            raise ValueError("prices must be a 1D iterable")

        if arr.ndim != 1:
            raise ValueError(f"Prices must be 1D, got shape {arr.shape}")