In [None]:
class MaxSharpeModel:
    """Walk-forward backtest of a periodically re-optimized max-Sharpe portfolio.

    Every ``rebalance`` rows, the last ``lookback`` rows of ``stock_prices``
    (up to and including the current row) are passed to ``MaxSharpe``. That
    class is defined elsewhere and is expected to provide ``optimize()`` and a
    ``weights`` attribute. The resulting weights are held (forward-filled)
    until the next rebalance. They are shifted by one row, so weights chosen
    on a row are applied to the following row's returns.

    After :meth:`fit` the instance exposes:
      * ``returns``: per-row portfolio returns, net of transaction costs.
      * ``growth``: cumulative compounded net return.
      * ``transaction_costs``: the per-row cost series.
      * ``num_trades``: the number of changed weight cells.
    """

    def __init__(
            self, 
            stock_prices, 
            lookback=60, 
            rebalance=60, 
            risk_free_rate=0.03, 
            costs=0.002,
            shorts=False,
    ):
        """
        Args:
            stock_prices: DataFrame of prices, one column per stock, rows in
                time order (presumably a DatetimeIndex; confirm with callers).
            lookback: Number of price rows given to each optimization; >= 1.
            rebalance: Number of rows between re-optimizations; >= 1.
            risk_free_rate: Passed positionally to ``MaxSharpe``.
            costs: Cost rate multiplied by the sum of absolute weight changes
                between consecutive rows.
            shorts: Passed to ``MaxSharpe`` as ``shorts``.

        Raises:
            ValueError: If ``lookback`` or ``rebalance`` is smaller than 1.
        """
        # lookback=0 makes the slicing in fit() wrap around ([-0:] / [-1:]),
        # and rebalance=0 cannot step through rebalance dates.
        if lookback < 1:
            raise ValueError(f'lookback must be >= 1, got {lookback!r}')
        if rebalance < 1:
            raise ValueError(f'rebalance must be >= 1, got {rebalance!r}')

        self.stock_prices = stock_prices
        self.stocks = self.stock_prices.columns.to_list()
        # Simple per-row returns; the first (NaN) row of pct_change is dropped.
        self.stock_returns = stock_prices.pct_change()[1:]
        self.lookback = lookback
        self.rebalance = rebalance
        # Empty (all-NaN) weight frame aligned with stock_returns.
        self.weights = self._init_weights_df(self.stock_returns)
        self.risk_free_rate = risk_free_rate
        self.costs = costs
        self.shorts = shorts

    def fit(self, disp=True):
        """Run the walk-forward optimization and compute portfolio results.

        Each call starts again from ``stock_prices``, so calling ``fit`` more
        than once does not re-trim the results of a previous call.

        Args:
            disp: If True, wrap the rebalance loop in a ``tqdm`` progress bar
                (``tqdm`` is expected to be imported elsewhere in the file).
        """
        # Rebuild from the raw prices rather than reusing self.stock_returns /
        # self.weights, which a previous fit() has already trimmed.
        stock_returns = self.stock_prices.pct_change()[1:]
        weights = self._init_weights_df(stock_returns)

        # Positions (into stock_prices) of the rebalance rows: the first row
        # with a full lookback window, then every `rebalance` rows after it.
        positions = range(self.lookback, len(self.stock_prices), self.rebalance)
        if disp:
            positions = tqdm(positions)

        for pos in positions:
            date = self.stock_prices.index[pos]
            # The `lookback` price rows ending at (and including) `date`. They
            # are only sliced on rebalance rows, not on every row.
            past_prices = self.stock_prices.iloc[pos + 1 - self.lookback:pos + 1]
            opt = MaxSharpe(
                past_prices,
                self.risk_free_rate,
                shorts=self.shorts,
            )
            opt.optimize()
            weights.loc[date] = opt.weights

        # Returns from stock_prices.index[lookback] onward. This keeps one more
        # leading row than the weights below; that row ends up with a zero
        # portfolio return in _calculate_returns.
        self.stock_returns = stock_returns.iloc[self.lookback - 1:]
        # Hold weights until the next rebalance, and shift by one row so the
        # weights chosen on a row are applied to the next row's returns.
        self.weights = weights.ffill().shift().iloc[self.lookback:]
        (
            self.returns,
            self.growth,
            self.transaction_costs,
            self.num_trades,
        ) = self._calculate_returns(self.weights, self.stock_returns, self.costs)

    @staticmethod
    def _calculate_returns(weights, stock_returns, costs, name='max_sharpe_portfolio'):
        """Compute net portfolio returns, growth, transaction costs and trade count.

        Args:
            weights: Per-row portfolio weights with the same columns as
                ``stock_returns``. Rows of ``stock_returns`` missing from
                ``weights`` contribute a zero gross return.
            stock_returns: Per-row simple returns of each stock.
            costs: Cost rate applied to the sum of absolute weight changes.
            name: Name given to the returned ``returns`` Series.

        Returns:
            A tuple ``(returns, growth, transaction_costs, num_trades)``:
              * ``returns``: net of transaction costs.
              * ``growth``: ``(1 + returns).cumprod() - 1`` over those net
                returns.
              * ``transaction_costs``: indexed like ``weights``.
              * ``num_trades``: the number of (row, stock) cells whose weight
                changed.
        """
        gross = (weights * stock_returns).sum(axis=1)
        signals = weights.diff()
        # The first diff() row is NaN. It adds no cost (sum skips NaN) and is
        # excluded from the trade count, so the initial allocation is free.
        transaction_costs = costs * signals.abs().sum(axis=1)
        transaction_costs.name = 'transaction_costs'
        # Align costs to the gross index. Rows without weights (e.g. the
        # leading row kept by fit()) get zero cost instead of becoming NaN.
        returns = gross - transaction_costs.reindex(gross.index, fill_value=0)
        returns.name = name
        # Compound the *net* returns so growth is consistent with `returns`.
        growth = (1 + returns).cumprod() - 1
        num_trades = int(signals.iloc[1:].ne(0).sum().sum())
        return returns, growth, transaction_costs, num_trades

    @staticmethod
    def _init_weights_df(returns):
        """Return an all-NaN float64 frame shaped like ``returns``.

        A float dtype, rather than the object dtype pandas uses for an empty
        frame, keeps the later ffill/shift/diff/arithmetic numeric.
        """
        return pd.DataFrame(columns=returns.columns, index=returns.index, dtype=float)