In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import random

In [None]:
#df = pd.read_csv("charts//EURUSD_M15.csv", sep="\t")

In [None]:
# Load the tab-separated EURUSD H4 export; plot_forex_chart realigns its
# column labels with csv_to_df before use.
df = pd.read_csv(
    "charts//EURUSD_H4.csv",
    sep="\t",
)

# The original CSV exports have mislabelled columns; csv_to_df realigns them.
# It is not needed for datasets whose columns are already labelled correctly.
def csv_to_df(df):
    """
    Realign the mislabelled columns of a raw CSV export.

    The frame's index is promoted to a regular ``index`` column. The contents
    of every column are then moved one slot to the right along
    ``index -> Time -> Open -> High -> Low -> Close -> Volume``. The moves
    run right-to-left, so nothing is overwritten before it has been copied.
    Finally the helper ``index`` column is dropped.

    Args:
        df: DataFrame as read from the raw export. It is not modified.

    Returns:
        A new DataFrame with the shifted columns. The previous contents of
        ``Volume`` are discarded (overwritten by ``Close``).
    """
    shifted = df.copy().reset_index()
    order = ['index', 'Time', 'Open', 'High', 'Low', 'Close', 'Volume']

    # Pairs (source, target) walked from the right-most pair backwards, so
    # Close -> Volume happens first and index -> Time happens last.
    for source, target in reversed(list(zip(order, order[1:]))):
        shifted[target] = shifted[source]

    return shifted.drop(columns='index')
    
def basic_tp_sl_function(signal, price, pip=0.0001, tp_pips=30, sl_pips=10):
    """
    Place take-profit and stop-loss a fixed number of pips away from a price.

    Args:
        signal: trade direction. A positive value means long; anything else
            (zero or negative) is treated as short.
        price: reference price the levels are measured from (analyze_strategy
            passes the entry candle's close)
        pip: price value of one pip
        tp_pips: distance from price to the take-profit, in pips
        sl_pips: distance from price to the stop-loss, in pips

    Returns:
        tuple (tp, sl, tp_pips, sl_pips): the take-profit price, the
        stop-loss price, and the two pip distances exactly as passed in.
    """
    is_long = signal > 0
    tp_offset = tp_pips * pip
    sl_offset = sl_pips * pip

    # Long: TP above, SL below. Short: mirrored.
    take_profit = price + tp_offset if is_long else price - tp_offset
    stop_loss = price - sl_offset if is_long else price + sl_offset

    return take_profit, stop_loss, tp_pips, sl_pips

def custom_tp_sl_function(df, price, pip=0.0001):
    """
    Template for a custom stop-level function; not implemented yet.

    Replace the body with your own logic. Note that ``analyze_strategy``
    calls ``tp_sl_function(signal, close)``. A function meant to be passed
    there must therefore accept ``(signal, price)`` positionally and return
    ``(tp, sl, tp_pips, sl_pips)``. For example, it can be defined as a
    closure that captures ``df``.

    Example body using the extremes of the candles before position ``i``
    (``signal`` and ``i`` have to be provided by your wrapper):

        lowest_low = df.iloc[:i]["Low"].min()
        highest_high = df.iloc[:i]["High"].max()

        if signal > 0:
            tp = highest_high
            sl = lowest_low
            tp_pips = int(round((tp - price) / pip))
            sl_pips = int(round((price - sl) / pip))
        else:
            tp = lowest_low
            sl = highest_high
            tp_pips = int(round((price - tp) / pip))
            sl_pips = int(round((sl - price) / pip))

        return tp, sl, tp_pips, sl_pips

    Rounding (rather than truncating with ``int``) keeps e.g. 29.999999 pips
    from becoming 29.

    Raises:
        NotImplementedError: always, until a real implementation is written.
            This replaces the old silent ``None`` return, which only failed
            later with an unclear unpacking error.
    """
    raise NotImplementedError(
        "custom_tp_sl_function is a template; implement it before passing it "
        "as tp_sl_function"
    )

def analyze_strategy(df, strategy_function, tp_sl_function, **strategy_params):
    """
    Run a strategy over price data and collect its entry points.

    Args:
        df: DataFrame with OHLCV data; must have 'Time' and 'Close' columns.
        strategy_function: called as ``strategy_function(copy_of_df,
            **strategy_params)``; returns one signal per row (positive = long,
            negative = short, 0 = no trade).
        tp_sl_function: called as ``tp_sl_function(signal, close)``; returns
            ``(take_profit, stop_loss, tp_pips, sl_pips)``.
        strategy_params: extra keyword arguments forwarded to the strategy.

    Returns:
        list: one dict per non-zero signal (the signal on row 0 is ignored),
        with keys 'date', 'price', 'type', 'take_profit', 'stop_loss',
        'tp_pips' and 'sl_pips'. The entry price is the row's close.
    """
    signals = strategy_function(df.copy(), **strategy_params)

    entries = []
    for position, direction in enumerate(signals):
        # Guard clause: no trade on flat rows, nor on the first row.
        if direction == 0 or position == 0:
            continue

        row = df.iloc[position]
        entry_price = row['Close']
        tp, sl, tp_pips, sl_pips = tp_sl_function(direction, entry_price)

        entries.append({
            'date': row['Time'],
            'price': entry_price,
            'type': 'long' if direction > 0 else 'short',
            'take_profit': tp,
            'stop_loss': sl,
            'tp_pips': tp_pips,
            'sl_pips': sl_pips
        })

    return entries

def _first_hit_position(mask):
    """Return the position of the first True in boolean Series ``mask``, or None."""
    hits = np.flatnonzero(mask.to_numpy())
    return int(hits[0]) if hits.size else None


def sim_trade(df, capital, entry_dict, rng=None):
    """Simulate one trade from its entry until TP or SL is touched.

    Candles strictly after ``entry_dict["date"]`` are scanned in row order.
    For a long, the trade wins when a High reaches the take-profit and loses
    when a Low reaches the stop-loss; for a short, High and Low swap roles.
    If both levels are touched on the same candle, the trade is counted as a
    loss, because the order inside the candle is unknown. If neither level is
    ever touched (or the type is neither 'long' nor 'short'), the capital is
    returned unchanged.

    Position size is 0.01 lot per 10 units of capital (minimum 0.01 lot), and
    a pip is worth 10 per lot. A random spread (0.2-2.5 pips) and slippage
    (0-1 pip) are charged on every closed trade.

    Args:
        df: price dataframe with 'Time', 'High' and 'Low' columns, sorted by time
        capital: money used to make the trade
        entry_dict: dict containing strategy entry info ('date', 'type',
            'take_profit', 'stop_loss', 'tp_pips', 'sl_pips')
        rng: optional object with a ``uniform(a, b)`` method, e.g.
            ``random.Random(seed)``, for reproducible costs. Defaults to the
            global ``random`` module.

    Returns:
        new_capital: new value of capital after trade
    """
    rng = random if rng is None else rng

    lot = max((capital // 10) * 0.01, 0.01)
    spread_pips = rng.uniform(0.2, 2.5)
    slippage_pips = rng.uniform(0.0, 1.0)

    after_entry = df[df["Time"] > entry_dict["date"]]
    if entry_dict["type"] == "long":
        tp_mask = after_entry["High"] >= entry_dict["take_profit"]
        sl_mask = after_entry["Low"] <= entry_dict["stop_loss"]
    elif entry_dict["type"] == "short":
        tp_mask = after_entry["Low"] <= entry_dict["take_profit"]
        sl_mask = after_entry["High"] >= entry_dict["stop_loss"]
    else:
        return capital

    # Positions (not index labels) so the order is right for any index.
    tp_pos = _first_hit_position(tp_mask)
    sl_pos = _first_hit_position(sl_mask)

    if tp_pos is None and sl_pos is None:
        return capital  # trade never closed inside the available data

    # Plain ints compare to a real bool, so the win/loss test is reliable
    # (a numpy.bool_ result would never satisfy an `is True` check).
    won = sl_pos is None or (tp_pos is not None and tp_pos < sl_pos)

    pip_value_per_lot = 10  # value of a pip per lot in fx
    if won:
        return capital + lot * (
            entry_dict["tp_pips"] - spread_pips - slippage_pips
        ) * pip_value_per_lot
    return capital - lot * (
        entry_dict["sl_pips"] + spread_pips + slippage_pips
    ) * pip_value_per_lot


def _add_trade_traces(fig, df, entry, i):
    """Draw one trade's entry marker and its legend-toggled TP/SL lines on row 1."""
    is_long = entry['type'] == 'long'
    marker_color = 'green' if is_long else 'red'
    marker_symbol = 'triangle-up' if is_long else 'triangle-down'
    label = entry['type'].capitalize()
    x_span = [df['Time'].iloc[0], df['Time'].iloc[-1]]

    # === Entry marker ===
    fig.add_trace(go.Scatter(
        x=[entry['date']],
        y=[entry['price']],
        mode='markers',
        marker=dict(
            symbol=marker_symbol,
            size=10,
            color=marker_color,
            line=dict(width=1, color='black')
        ),
        name=f"{label} Entry",
        # Plotly hover text breaks lines on <br>; a newline character is ignored.
        text=f"{label} {i+1} entry at {entry['price']:.5f}<br>TP: {entry['take_profit']:.5f}<br>SL: {entry['stop_loss']:.5f}",
        hoverinfo='text',
        showlegend=False
    ), row=1, col=1)

    # === TP and SL lines (hidden until toggled in the legend) ===
    for level_key, prefix, color in (('take_profit', 'TP', 'green'), ('stop_loss', 'SL', 'red')):
        level = entry[level_key]
        fig.add_trace(go.Scatter(
            x=x_span,
            y=[level, level],
            mode='lines',
            line=dict(color=color, width=2, dash='dot'),
            name=f'{prefix} {i+1} ({level:.5f})',
            legendgroup=f'trade_{i}',
            showlegend=True,
            visible='legendonly'
        ), row=1, col=1)


def plot_forex_chart(price_df, strategy_function, tp_sl_function, data_volume=100, capital=100,
                     *, fix_columns=True, **strategy_params):
    """
    Plots a candlestick chart with strategy entry markers, take-profit/stop-loss levels,
    and an equity curve based on simulated trades.

    Args:
        price_df:
            DataFrame containing OHLCV price data. Must include at least
            ['Time', 'Open', 'High', 'Low', 'Close', 'Volume'] columns, or the raw
            mislabelled export that csv_to_df repairs.
            If the column names of your dataframe are lower case, adjust the code accordingly.
        strategy_function:
            Function that takes a price DataFrame and returns a list/array of signals
            (e.g., 1 for long, -1 for short, 0 for no trade).
        tp_sl_function:
            Function called as tp_sl_function(signal, price) that returns
            (take_profit, stop_loss, tp_pips, sl_pips).
        data_volume:
            Number of most recent candles to plot. Defaults to 100.
        capital:
            Starting account capital for equity simulation. Defaults to 100.
        fix_columns:
            Keyword-only. When True (default) the data is passed through csv_to_df
            to realign mislabelled columns; set False for properly labelled data.
        **strategy_params:
            Additional keyword arguments to pass to `strategy_function`.

    Returns:
        Plotly figure object with two subplots:
            - Candlestick chart with entries and TP/SL lines.
            - Equity curve chart.
        None if the data cannot be converted or the selected window is empty
        (a message is printed in both cases).

    Note: assumes simple trade-by-trade simulation without overlapping trades.
    Each equity point is stamped at the trade's entry time, although sim_trade
    resolves the trade on a later candle.
    """
    df = price_df.copy()
    try:
        if fix_columns:
            df = csv_to_df(df)
        df['Time'] = pd.to_datetime(df['Time'])
    except Exception as e:  # deliberate best-effort at the notebook boundary
        print(f"Error converting data: {e}")
        return None

    df = df.tail(data_volume).reset_index(drop=True)
    # Alternative window (the block just before the most recent one):
    # df = df.iloc[-(data_volume*2):-data_volume].reset_index(drop=True)
    if df.empty:
        print("No price data to plot.")
        return None

    entry_points = analyze_strategy(df, strategy_function, tp_sl_function, **strategy_params)

    # Plots for strategy entries and equity curve
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.1,
        row_heights=[0.7, 0.3],
        subplot_titles=(f"{strategy_function.__name__}", "Equity Curve")
    )
    fig.update_layout(height=700, width=900)

    # === Plot candlestick chart ===
    fig.add_trace(go.Candlestick(
        x=df['Time'],
        open=df['Open'],
        high=df['High'],
        low=df['Low'],
        close=df['Close'],
        name='Price Action'),
        row=1, col=1
    )

    # === Equity curve, starting at the first candle with the initial capital ===
    equity = capital
    equity_times = [df['Time'].iloc[0]]
    equity_values = [equity]

    for i, entry in enumerate(entry_points):
        _add_trade_traces(fig, df, entry, i)
        equity_times.append(entry["date"])
        equity = sim_trade(df, equity, entry)  # update capital
        equity_values.append(equity)

    fig.add_trace(go.Scatter(
        x=equity_times,
        y=equity_values,
        mode='lines',
        line=dict(color='green', width=2, shape='spline'),
        fill='tozeroy',  # fill from the line down to the x-axis
        fillcolor='rgba(0, 128, 0, 0.2)',  # semi-transparent green
        name='Equity Curve'
    ), row=2, col=1)

    # Skip weekends and hide the range slider on every x axis (both subplots),
    # not only the first one.
    fig.update_xaxes(
        rangeslider_visible=False,
        rangebreaks=[dict(bounds=["sat", "mon"])]
    )
    fig.update_layout(
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.01
        ),
        hovermode='closest'
    )

    # Single quotes inside the f-string keep this valid before Python 3.12.
    start_time = df['Time'].iloc[0]
    end_time = df['Time'].iloc[-1]
    print(f"Final capital from {start_time} to {end_time}: {round(equity, 2)}")
    return fig


# Usage example:
if __name__ == "__main__":
    # trend_follow is defined in a later notebook cell. On a top-to-bottom run
    # it does not exist yet here, so say so instead of raising a bare NameError.
    try:
        example_strategy = trend_follow
    except NameError:
        example_strategy = None
        fig = None
        print("trend_follow is not defined yet: run the cell that defines it, then re-run this cell.")

    if example_strategy is not None:
        # Plot with example strategy
        fig = plot_forex_chart(
            df,
            strategy_function=example_strategy,
            tp_sl_function=basic_tp_sl_function,
            data_volume=150
        )


In [None]:
# Display the figure produced by the usage example above (a notebook renders
# the value of a cell's last expression).
fig

In [None]:
def get_neighbor_candles(df, idx):
    """
    Return the candles at positions idx-1, idx and idx+1 of ``df``.

    Args:
        df: price DataFrame
        idx: position of the middle candle; must satisfy 1 <= idx <= len(df) - 2

    Returns:
        tuple (c1, c2, c3) of row Series, oldest first if df is time-sorted

    Raises:
        IndexError: if idx lacks a neighbour on either side. Without this
            check, idx=0 would silently wrap around to the last row through
            ``iloc[-1]``.
    """
    if not 1 <= idx <= len(df) - 2:
        raise IndexError(
            f"idx={idx} needs a candle on both sides (valid range 1..{len(df) - 2})"
        )
    return df.iloc[idx - 1], df.iloc[idx], df.iloc[idx + 1]

def is_bullish_fvg(c1, c2, c3):
    """
    Check whether three consecutive candles form a bullish fair value gap.

    Columns are accessed by label ('Open', 'High', 'Low', 'Close'), the same
    way fvg_strategy reads the frame. This makes the result independent of
    column order.

    Args:
        c1, c2, c3: consecutive candles (row Series), oldest first

    Returns:
        bool: True if c1 is bearish, c2 is bullish and c3's low is above
        c1's high (the gap).
    """
    return bool(
        c1["Open"] > c1["Close"]          # first candle bearish (open > close)
        and c2["Close"] > c2["Open"]      # middle candle bullish (close > open)
        and c3["Low"] > c1["High"]        # last candle low > first candle high
    )

def is_bearish_fvg(c1, c2, c3):
    """
    Check whether three consecutive candles form a bearish fair value gap.

    Columns are accessed by label ('Open', 'High', 'Low', 'Close'), the same
    way fvg_strategy reads the frame. This makes the result independent of
    column order.

    Args:
        c1, c2, c3: consecutive candles (row Series), oldest first

    Returns:
        bool: True if c1 is bullish, c2 is bearish and c1's low is above
        c3's high (the gap).
    """
    return bool(
        c1["Open"] < c1["Close"]          # first candle bullish (close > open)
        and c2["Close"] < c2["Open"]      # middle candle bearish (open > close)
        and c1["Low"] > c3["High"]        # first candle low > last candle high
    )

def get_fvg_midpoint(c1, c3):
    """
    Calculate the midpoint of the fair value gap between candles c1 and c3.

    The direction is inferred from the closes. If c1 closes above c3, the
    move is bearish and the gap spans c3's high up to c1's low. Otherwise it
    is bullish and the gap spans c1's high up to c3's low.

    Args:
        c1: first candle of the FVG (row Series with 'High', 'Low', 'Close')
        c3: third candle of the FVG (same labels)

    Returns:
        float: price level halfway across the gap
    """
    if c1["Close"] > c3["Close"]:
        # Bearish gap: between c3's high (bottom) and c1's low (top).
        return (c1["Low"] + c3["High"]) / 2
    # Bullish gap: between c1's high (bottom) and c3's low (top).
    return (c1["High"] + c3["Low"]) / 2

def fvg_strategy(df):
    """
    Simple FVG (Fair Value Gap) strategy implementation.

    Every triple of consecutive candles is checked for a bullish or bearish
    FVG. When one is found, the first later candle (after the third candle
    of the triple) that trades back to the gap's midpoint gets the signal:
    its Low must reach the midpoint for a bullish FVG, its High for a
    bearish one. A later FVG whose entry falls on the same candle
    overwrites the earlier signal.

    Args:
        df: DataFrame with OHLCV data

    Returns:
        np.ndarray (float): one value per row of df, where:
            1 = bullish FVG retest (buy signal)
            -1 = bearish FVG retest (sell signal)
            0 = no signal
    """
    signals = np.zeros(len(df))
    # Positional arrays so entries are positions regardless of df's index labels.
    lows = df["Low"].to_numpy()
    highs = df["High"].to_numpy()

    # i is the middle candle of the (i-1, i, i+1) triple; starting at 1 also
    # examines the very first triple (candles 0, 1, 2).
    for i in range(1, len(df) - 1):
        c1, c2, c3 = get_neighbor_candles(df, i)
        start = i + 2  # first candle after the gap has fully formed

        if is_bullish_fvg(c1, c2, c3):
            mid = get_fvg_midpoint(c1, c3)
            touches = np.flatnonzero(lows[start:] <= mid)
            if touches.size:
                signals[start + touches[0]] = 1

        elif is_bearish_fvg(c1, c2, c3):
            mid = get_fvg_midpoint(c1, c3)
            touches = np.flatnonzero(highs[start:] >= mid)
            if touches.size:
                signals[start + touches[0]] = -1

    return signals

In [None]:
def trend_follow(df, consecutive=2):
    """
    Signal trend continuation after a streak of same-direction candles.

    A candle is bullish when Close > Open and bearish when Close < Open. A
    doji (Close == Open, or any NaN comparison) resets both streaks. Every
    candle at which the current streak length is at least ``consecutive``
    gets a signal, so a long streak signals on each of its candles from the
    ``consecutive``-th one onward. Every candle counts toward a streak,
    including the first.

    Args:
        df (pd.DataFrame): DataFrame containing OHLCV data with columns
            'Open' and 'Close' (the others are not used).
        consecutive (int): Number of consecutive bullish (uptrend) or
            bearish (downtrend) candles required; should be at least 1.

    Returns:
        np.ndarray: Signals array (float) where:
            1 = buy signal (consecutive bullish candles)
            -1 = sell signal (consecutive bearish candles)
            0 = no signal
    """
    signals = np.zeros(len(df))
    opens = df["Open"].to_numpy()
    closes = df["Close"].to_numpy()

    bullish_count = 0
    bearish_count = 0

    for i, (open_, close) in enumerate(zip(opens, closes)):
        if close > open_:       # bullish candle
            bullish_count += 1
            bearish_count = 0
        elif close < open_:     # bearish candle
            bearish_count += 1
            bullish_count = 0
        else:                   # doji / neutral candle resets counts
            bullish_count = 0
            bearish_count = 0

        if bullish_count >= consecutive:
            signals[i] = 1
        elif bearish_count >= consecutive:
            signals[i] = -1

    return signals


## Custom Javascript

In [None]:
# for entry in entry_points:
#         marker_color = 'green' if entry['type'] == 'long' else 'red'
#         marker_symbol = 'triangle-up' if entry['type'] == 'long' else 'triangle-down'
        
#         # Create shapes for TP and SL lines
#         tp_line = dict(
#             type='line',
#             x0=entry['date'],
#             x1=df['Time'].iloc[-1],  # Extend to end of chart
#             y0=entry['take_profit'],
#             y1=entry['take_profit'],
#             line=dict(
#                 color='green',
#                 width=1,
#                 dash='dash'
#             ),
#             visible=False  # Hidden by default
#         )
        
#         sl_line = dict(
#             type='line',
#             x0=entry['date'],
#             x1=df['Time'].iloc[-1],  # Extend to end of chart
#             y0=entry['stop_loss'],
#             y1=entry['stop_loss'],
#             line=dict(
#                 color='red',
#                 width=1,
#                 dash='dash'
#             ),
#             visible=False  # Hidden by default
#         )
        
#         # Add the shapes to the layout
#         fig.add_shape(tp_line)
#         fig.add_shape(sl_line)
        
#         # Add the entry point marker with hover events
#         fig.add_trace(go.Scatter(
#             x=[entry['date']],
#             y=[entry['price']],
#             mode='markers',
#             marker=dict(
#                 symbol=marker_symbol,
#                 size=10,
#                 color=marker_color,
#                 line=dict(width=1, color='black')
#             ),
#             name=f"{entry['type'].capitalize()} Entry",
#             text=(f"{entry['type'].capitalize()} entry at {entry['price']:.5f}\n"
#                   f"Take Profit: {entry['take_profit']:.5f}\n"
#                   f"Stop Loss: {entry['stop_loss']:.5f}"),
#             hoverinfo='text',
#             customdata=[len(fig.layout.shapes)-2, len(fig.layout.shapes)-1]  # Store indices of associated TP/SL lines
#         ))

#     # Add hover events using JavaScript
#     fig.update_layout(
#         hovermode='closest',
#         # Add JavaScript callbacks for hover events
#         updatemenus=[],  # Required for the JavaScript to work
#     )

#     # Add JavaScript for hover functionality
#     fig.add_annotation(
#         dict(
#             text="",
#             showarrow=False,
#             textangle=0,
#             xref='paper', yref='paper',
#             x=0, y=0,
#             # JavaScript for hover events
#             hovertemplate="",
#             hoverlabel=dict(
#                 font=dict(
#                     family="Courier New, monospace",
#                     size=16,
#                     color="white"
#                 )
#             )
#         )
#     )

#     # Add JavaScript to handle hover events
#     fig.update_layout(
#         newshape=dict(line_width=1),
#         paper_bgcolor='rgba(0,0,0,0)',
#         plot_bgcolor='rgba(0,0,0,0)',
#     )

#     custom_js = """
#     var gd = document.getElementById('{plot_div}');
#     var previousTraceIndex = null;

#     gd.on('plotly_hover', function(data) {
#         var traceIndex = data.points[0].curveNumber;
#         var customdata = data.points[0].customdata;
        
#         if (customdata && customdata.length === 2) {
#             var tpIndex = customdata[0];
#             var slIndex = customdata[1];
            
#             Plotly.relayout(gd, {
#                 [`shapes[${tpIndex}].visible`]: true,
#                 [`shapes[${slIndex}].visible`]: true
#             });
#         }
#     });

#     gd.on('plotly_unhover', function(data) {
#         // Hide all shapes when unhovering
#         var updates = {};
#         gd.layout.shapes.forEach((shape, i) => {
#             updates[`shapes[${i}].visible`] = false;
#         });
#         Plotly.relayout(gd, updates);
#     });
#     """

#     # Add the custom JavaScript to the layout config
#     fig.update_layout(
#         newshape_line_width=1,
#     )

#     # Save the plot
#     #output_file = f"forex_strategy_{symbol.replace('=', '_')}.html"
#     #fig.write_html(output_file)
#     #print(f"Chart saved as {output_file}")
    
#     return fig