In [1]:
!pip install pandas numpy matplotlib

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

class TrendFollowingStrategy:
    """EMA-crossover trend-following signal generator with a plotting helper.

    The DataFrame passed to the constructor is stored by reference, and
    ``generate_signals`` adds the columns ``ShortEMA``, ``LongEMA``, ``Signal``
    and ``Position`` to it in place, so the caller's frame is modified.
    """

    def __init__(self, data):
        """Wrap a price DataFrame.

        Args:
            data: DataFrame with a ``'Close'`` column. Its index is used as the
                x-axis when plotting (presumably a DatetimeIndex — confirm for
                real data).
        """
        self.data = data

    def calculate_ema(self, window):
        """Return the exponential moving average of ``Close`` with span ``window``.

        ``adjust=False`` selects the recursive EMA form, so a value is
        produced from the very first row.
        """
        return self.data['Close'].ewm(span=window, adjust=False).mean()

    def calculate_sma(self, window):
        """Return the simple moving average of ``Close`` over ``window`` rows.

        The first ``window - 1`` values are NaN (rolling's default
        ``min_periods`` equals the window size).
        """
        return self.data['Close'].rolling(window=window).mean()

    def generate_signals(self, short_window, long_window):
        """Compute EMA-crossover signals and store them as new columns.

        Columns added to ``self.data``:
            ShortEMA / LongEMA: EMAs of ``Close`` with the given spans.
            Signal: 1.0 while ShortEMA > LongEMA, else 0.0. The first
                ``short_window`` rows are forced to 0.0 as a warm-up period.
            Position: ``Signal.diff()`` — 1.0 on a buy crossover, -1.0 on a
                sell crossover, 0.0 otherwise, NaN on the first row.

        Args:
            short_window: span of the fast EMA.
            long_window: span of the slow EMA.
        """
        short_ema = self.calculate_ema(short_window)
        long_ema = self.calculate_ema(long_window)
        self.data['ShortEMA'] = short_ema
        self.data['LongEMA'] = long_ema

        # Build the signal as a plain array and assign the column once.
        # Chained assignment (df[col][slice] = ...) writes into a temporary
        # object under pandas Copy-on-Write and leaves the frame unchanged.
        signal = np.where(short_ema.to_numpy() > long_ema.to_numpy(), 1.0, 0.0)
        signal[:short_window] = 0.0  # warm-up rows carry no signal
        self.data['Signal'] = signal
        self.data['Position'] = self.data['Signal'].diff()

    def _signal_points(self, position):
        """Return ``Close`` prices on rows whose ``Position`` equals ``position``."""
        return self.data.loc[self.data['Position'] == position, 'Close']

    def plot_signals(self):
        """Plot Close, both EMAs, and buy/sell markers at the crossover prices.

        Reads the ``ShortEMA``, ``LongEMA`` and ``Position`` columns, so
        ``generate_signals`` must have been called first.
        """
        plt.figure(figsize=(14, 7))
        plt.plot(self.data['Close'], label='Close Price')
        plt.plot(self.data['ShortEMA'], label='Short EMA', alpha=0.7)
        plt.plot(self.data['LongEMA'], label='Long EMA', alpha=0.7)
        # Place markers at the price on crossover rows. Scattering the boolean
        # mask itself would put every point at y=0 or y=1.
        buys = self._signal_points(1)
        sells = self._signal_points(-1)
        plt.scatter(
            buys.index, buys, label='Buy Signal', marker='^', alpha=1, color='green'
        )
        plt.scatter(
            sells.index, sells, label='Sell Signal', marker='v', alpha=1, color='red'
        )
        plt.title('Trend Following Strategy Signals')
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.legend()
        plt.show()

if __name__ == '__main__':
    # Demo run on synthetic, uniformly random prices.
    # Swap in a real price series (a DataFrame with a 'Close' column) here.
    n_days = 100
    prices = np.random.rand(n_days) * 100
    dates = pd.date_range(start='2022-01-01', periods=n_days)
    data = pd.DataFrame({'Close': prices}, index=dates)

    strategy = TrendFollowingStrategy(data)
    strategy.generate_signals(short_window=20, long_window=50)
    strategy.plot_signals()