<a href="https://colab.research.google.com/github/StefanRaduMaris/machine-learning-repo/blob/main/notebook/Automated_prediction_day_trading.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [67]:
import pandas as pd
import pandas_ta as ta
import numpy as np
import plotly.graph_objects as go
from scipy import stats
from neuralprophet import NeuralProphet
import yfinance as yf
from tqdm import tqdm

df = pd.read_csv("https://raw.githubusercontent.com/JasonZhangjc/ml_algo_trading/refs/heads/main/EURUSD_Candlestick_1_Hour_BID_04.05.2003-15.04.2023.csv")


In [68]:
df.columns = [col.capitalize() for col in df.columns]
df=df[df['Volume']!=0]

df.reset_index(drop=True, inplace=True)

df['RSI'] = ta.rsi(df.Close, length=12)
df['EMA'] = ta.ema(df.Close, length=150)
df['ATR'] = ta.atr(df.High, df.Low, df.Close, length=14)


In [69]:
def emea_signal(df):
  backcandles = 15
  win = backcandles + 1

  above = (np.minimum(df['Open'],  df['Close']) > df['EMA']).astype(int)
  below = (np.maximum(df['Open'],  df['Close']) < df['EMA']).astype(int)

  # ------------------------------------------------------------------ #
  # Rolling “all‑true” test with a fast sum
  # ------------------------------------------------------------------ #
  #    upt  == 1  ⇔ every candle in the window is above the EMA
  #    dnt  == 1  ⇔ every candle in the window is below the EMA
  upt = (above.rolling(win, min_periods=win).sum() == win)
  dnt = (below.rolling(win, min_periods=win).sum() == win)

  signal = np.zeros(len(df), dtype=int)
  signal[ upt  &  dnt] = 3          # (theoretically rare / flat on EMA)
  signal[ upt  & ~dnt] = 2          # all‑above window
  signal[ dnt  & ~upt] = 1          # all‑below window

  df['EMASignal'] = signal
  return df
emea_signal(df)


Unnamed: 0,Gmt time,Open,High,Low,Close,Volume,RSI,EMA,ATR,EMASignal
0,04.05.2003 21:00:00.000,1.12284,1.12338,1.12242,1.12305,2.905910e+07,,,,0
1,04.05.2003 22:00:00.000,1.12274,1.12302,1.12226,1.12241,2.609180e+07,0.000000,,,0
2,04.05.2003 23:00:00.000,1.12235,1.12235,1.12160,1.12169,2.924090e+07,0.000000,,,0
3,05.05.2003 00:00:00.000,1.12161,1.12314,1.12154,1.12258,2.991480e+07,11.120367,,,0
4,05.05.2003 01:00:00.000,1.12232,1.12262,1.12099,1.12140,2.837070e+07,9.579569,,,0
...,...,...,...,...,...,...,...,...,...,...
123836,14.04.2023 16:00:00.000,1.09789,1.09851,1.09722,1.09775,1.429995e+07,20.092619,1.096424,0.001677,2
123837,14.04.2023 17:00:00.000,1.09775,1.09901,1.09752,1.09871,9.740260e+06,27.485921,1.096455,0.001664,2
123838,14.04.2023 18:00:00.000,1.09871,1.09989,1.09871,1.09988,9.199190e+06,35.429055,1.096500,0.001629,2
123839,14.04.2023 19:00:00.000,1.09987,1.09992,1.09921,1.09964,4.669030e+06,34.581386,1.096542,0.001563,2


In [70]:
def mark_pivots(df: pd.DataFrame, window: int, high_col='High', low_col='Low') -> pd.Series:
    """
    Vectorized replacement for `isPivot()`.

    Parameters
    ----------
    df        : DataFrame that contains OHLC columns.
    window    : int
        Number of bars to look *before* and *after* a candle.
        (Total span = 2*window + 1)
    high_col  : str, default 'High'
    low_col   : str,  default 'Low'

    Returns
    -------
    pd.Series (dtype=int)
        0 = not a pivot
        1 = pivot high
        2 = pivot low
        3 = both high & low  (flat bar inside a very tight range)
    """
    span = 2 * window + 1                       # total length of the centred window

    # Rolling extrema, centred on each bar
    roll_max = df[high_col].rolling(span, center=True).max()
    roll_min = df[low_col].rolling(span,  center=True).min()

    # Boolean masks (NaNs on the edges stay False)
    pivot_high = (df[high_col] >= roll_max) & roll_max.notna()
    pivot_low  = (df[low_col]  <= roll_min) & roll_min.notna()

    # Encode: 1 = high, 2 = low, 3 = both
    pivots = np.zeros(len(df), dtype=int)
    pivots[pivot_high] += 1
    pivots[pivot_low]  += 2

    return pd.Series(pivots, index=df.index, name='isPivot')


window = 7
df['isPivot'] = mark_pivots(df, window)


In [71]:
def pointpos(x):
    if x['isPivot']==2:
        return x['Low']-1e-3
    elif x['isPivot']==1:
        return x['High']+1e-3
    else:
        return np.nan
df['pointpos'] = df.apply(lambda row: pointpos(row), axis=1)

In [72]:
dfpl = df[300:450]
fig = go.Figure(data=[go.Candlestick(x=dfpl.index,
                open=dfpl['Open'],
                high=dfpl['High'],
                low=dfpl['Low'],
                close=dfpl['Close'])])

fig.add_scatter(x=dfpl.index, y=dfpl['pointpos'], mode="markers",
                marker=dict(size=5, color="MediumPurple"),
                name="pivot")
fig.update_layout(xaxis_rangeslider_visible=False)
fig.show()

In [73]:
def detect_structure(candle: int,
                     backcandles: int,
                     window: int
                     ) -> tuple[int, int | None]:
    """
    Return a signal only on the *first* candle that breaks the key level.

    2  -> bullish breakout (PH → PL₂ < PL₁ → first close > PH.High)
    1  -> bearish breakout (PL → PH₂ > PH₁ → first close < PL.Low)
    0  -> no pattern
    """
    # ----------------------------------------------------------- #
    # 0) History guard
    # ----------------------------------------------------------- #
    if candle - backcandles < 0:
        return 0, None

    prev_bar = candle - 1
    if prev_bar < 0:                       # no previous bar to compare
        return 0, None

    # ----------------------------------------------------------- #
    # 1) Data slices
    # ----------------------------------------------------------- #
    price_df = df.iloc[candle - backcandles : candle]        # exclude current
    close_now     = df.loc[candle, 'Close']
    close_prev    = df.loc[prev_bar, 'Close']

    # confirmed pivots: idx ≤ candle - window
    piv_df = df.iloc[candle - backcandles : candle - window + 1]

    # ----------------------------------------------------------- #
    # 2) Bullish side (PH → PL₂ → first break up)
    # ----------------------------------------------------------- #
    ph_df = piv_df[piv_df['isPivot'] == 1]
    if not ph_df.empty:
        ph_idx = ph_df.index[-1]
        ph_val = ph_df.loc[ph_idx, 'High']

        # ensure previous candle did NOT already break the level
        if close_now > ph_val and close_prev <= ph_val:

            pl_before = piv_df.loc[:ph_idx - 1]
            pl_before = pl_before[pl_before['isPivot'] == 2]

            pl_after  = piv_df.loc[ph_idx + 1:]
            pl_after  = pl_after[pl_after['isPivot'] == 2]

            if not pl_before.empty and not pl_after.empty:
                pl1_val = pl_before.iloc[-1]['Low']     # last PL before PH
                pl2_val = pl_after['Low'].min()         # any PL after PH

                if pl2_val < pl1_val:                   # sweep condition
                    return 2, ph_idx

    # ----------------------------------------------------------- #
    # 3) Bearish side (PL → PH₂ → first break down)
    # ----------------------------------------------------------- #
    pl_df = piv_df[piv_df['isPivot'] == 2]
    if not pl_df.empty:
        pl_idx = pl_df.index[-1]
        pl_val = pl_df.loc[pl_idx, 'Low']

        if close_now < pl_val and close_prev >= pl_val:

            ph_before = piv_df.loc[:pl_idx - 1]
            ph_before = ph_before[ph_before['isPivot'] == 1]

            ph_after  = piv_df.loc[pl_idx + 1:]
            ph_after  = ph_after[ph_after['isPivot'] == 1]

            if not ph_before.empty and not ph_after.empty:
                ph1_val = ph_before.iloc[-1]['High']    # last PH before PL
                ph2_val = ph_after['High'].max()        # any PH after PL

                if ph2_val > ph1_val:
                    return 1, pl_idx

    # ----------------------------------------------------------- #
    return 0, None

In [74]:
df['breakout_signal'] = 0
df['pivot_ref_idx'] = None

start_index = 0
end_index = len(df)

for candle in tqdm(range(start_index, end_index)):
    signal, ref_idx = detect_structure(candle, backcandles=40, window=5)
    df.at[candle, 'breakout_signal'] = signal
    df.at[candle, 'pivot_ref_idx'] = ref_idx

100%|██████████| 123841/123841 [02:25<00:00, 853.44it/s]


In [75]:
df['breakout_signal'] = df.apply(lambda row: row['breakout_signal'] if row['EMASignal'] == row['breakout_signal'] else 0, axis=1)
df[df['breakout_signal']!=0].head(50)

Unnamed: 0,Gmt time,Open,High,Low,Close,Volume,RSI,EMA,ATR,EMASignal,isPivot,pointpos,breakout_signal,pivot_ref_idx
283,20.05.2003 16:00:00.000,1.16516,1.1697,1.16457,1.16845,293631300.0,60.633173,1.154431,0.002944,2,0,,2,273
881,24.06.2003 15:00:00.000,1.15416,1.15502,1.15111,1.15174,25684600.0,34.663327,1.164931,0.002241,1,0,,1,873
2865,17.10.2003 12:00:00.000,1.15878,1.15889,1.15529,1.15627,32252500.0,37.480996,1.166454,0.002599,1,0,,1,2836
3395,18.11.2003 15:00:00.000,1.18135,1.19099,1.18128,1.18928,31769300.0,75.707035,1.169023,0.003028,2,0,,2,3366
3708,05.12.2003 16:00:00.000,1.21387,1.21673,1.21352,1.2154,292903700.0,69.795227,1.202619,0.00262,2,0,,2,3682
3710,05.12.2003 18:00:00.000,1.21501,1.21673,1.21453,1.21582,145868400.0,69.649423,1.202955,0.002583,2,0,,2,3682
4017,24.12.2003 13:00:00.000,1.24029,1.24265,1.24,1.24258,29218100.0,63.542086,1.237101,0.001506,2,0,,2,3989
4084,29.12.2003 08:00:00.000,1.24423,1.24738,1.2442,1.24695,31422900.0,67.301716,1.241539,0.001531,2,0,,2,4056
4757,05.02.2004 09:00:00.000,1.2548,1.25919,1.25429,1.2581,33223900.0,72.63195,1.251257,0.002014,2,0,,2,4734
4860,11.02.2004 16:00:00.000,1.27007,1.28364,1.27007,1.28162,296074100.0,74.49198,1.263976,0.003497,2,0,,2,4849


In [76]:
import plotly.graph_objects as go

def plot_breakouts_with_candles(df, start_idx=None, end_idx=None):
    """
    Plots candlestick chart with breakout markers and pivot reference lines.

    Requirements in df:
    - 'Open', 'High', 'Low', 'Close'
    - 'breakout_signal' column (2 = bullish, 1 = bearish, 0 = no signal)
    - 'pivot_ref_idx' column (index of the pivot candle used for the breakout)
    """

    # Slice the DataFrame
    if start_idx is None:
        start_idx = df.index[0]
    if end_idx is None:
        end_idx = df.index[-1]
    dfpl = df.loc[start_idx:end_idx].copy()

    # Create figure
    fig = go.Figure()

    # Candlestick chart
    fig.add_trace(go.Candlestick(
        x=dfpl.index,
        open=dfpl["Open"],
        high=dfpl["High"],
        low=dfpl["Low"],
        close=dfpl["Close"],
        name="Candles"
    ))

    # Loop through breakout signals
    for idx, row in dfpl.iterrows():
        signal = row.get("breakout_signal", 0)
        pivot_idx = row.get("pivot_ref_idx", None)

        if signal == 2 and pivot_idx in df.index:
            # Bullish breakout
            y_marker = row["Low"] - (row["High"] - row["Low"]) * 0.02
            fig.add_trace(go.Scatter(
                x=[idx],
                y=[y_marker],
                mode="markers",
                marker=dict(symbol="triangle-up", color="white", size=20),
                name="Bullish Breakout"
            ))
            # Reference line from pivot high to current candle
            pivot_high = df.loc[pivot_idx, "High"]
            fig.add_shape(
                type="line",
                x0=pivot_idx,
                x1=idx,
                y0=pivot_high,
                y1=pivot_high,
                line=dict(color="purple", dash="dash")
            )

        elif signal == 1 and pivot_idx in df.index:
            # Bearish breakout
            y_marker = row["High"] + (row["High"] - row["Low"]) * 0.02
            fig.add_trace(go.Scatter(
                x=[idx],
                y=[y_marker],
                mode="markers",
                marker=dict(symbol="triangle-down", color="white", size=20),
                name="Bearish Breakout"
            ))
            # Reference line from pivot low to current candle
            pivot_low = df.loc[pivot_idx, "Low"]
            fig.add_shape(
                type="line",
                x0=pivot_idx,
                x1=idx,
                y0=pivot_low,
                y1=pivot_low,
                line=dict(color="purple", dash="dash")
            )

    # Layout styling
    fig.update_layout(
        title="Candlestick Chart with Breakout Signals",
        width=1200,
        height=800,
        plot_bgcolor='black',
        paper_bgcolor='black',
        xaxis=dict(showgrid=False),
        yaxis=dict(showgrid=False)
    )

    fig.show()

In [77]:
plot_breakouts_with_candles(df, start_idx=3650, end_idx=3780)

#Test for paladium present

In [86]:
df1= yf.download("PA=F", period="1y", interval="1h")
df1.columns = [col[0] if isinstance(col, tuple) else col for col in df1.columns]

df1=df1.reset_index()
df1=df1.rename(columns={'Datetime':'Gmt time'})
df1['RSI'] = ta.rsi(df1.Close, length=12)
df1['EMA'] = ta.ema(df1.Close, length=150)
df1['ATR'] = ta.atr(df1.High, df1.Low, df1.Close, length=14)
df1.head()



YF.download() has changed argument auto_adjust default to True



YF.download() has changed argument auto_adjust default to True


[*********************100%***********************]  1 of 1 completed


Unnamed: 0,Gmt time,Close,High,Low,Open,Volume,RSI,EMA,ATR
0,2024-12-03 12:00:00+00:00,997.0,1001.0,997.0,1001.0,0,,,
1,2024-12-03 13:00:00+00:00,998.5,1003.5,995.0,996.5,262,100.0,,
2,2024-12-03 14:00:00+00:00,994.0,1002.5,988.5,998.5,496,78.571429,,
3,2024-12-03 15:00:00+00:00,983.0,993.5,981.5,993.5,474,50.0,,
4,2024-12-03 16:00:00+00:00,987.0,987.5,980.0,983.0,346,56.303349,,


In [87]:
emea_signal(df1)
window = 7
df1['isPivot'] = mark_pivots(df1, window)
df1['pointpos'] = df1.apply(lambda row: pointpos(row), axis=1)


In [88]:
dfpl = df1[300:450]
fig = go.Figure(data=[go.Candlestick(x=dfpl.index,
                open=dfpl['Open'],
                high=dfpl['High'],
                low=dfpl['Low'],
                close=dfpl['Close'])])

fig.add_scatter(x=dfpl.index, y=dfpl['pointpos'], mode="markers",
                marker=dict(size=5, color="MediumPurple"),
                name="pivot")
fig.update_layout(xaxis_rangeslider_visible=False)
fig.show()

In [89]:
df1['breakout_signal'] = 0
df1['pivot_ref_idx'] = None
start_index = 0
end_index = len(df1)

for candle in tqdm(range(start_index, end_index)):
    signal, ref_idx = detect_structure(candle, backcandles=40, window=5)
    df1.at[candle, 'breakout_signal'] = signal
    df1.at[candle, 'pivot_ref_idx'] = ref_idx

100%|██████████| 5739/5739 [00:06<00:00, 937.74it/s]


In [90]:
df1['breakout_signal'] = df1.apply(lambda row: row['breakout_signal'] if row['EMASignal'] == row['breakout_signal'] else 0, axis=1)
df1[df1['breakout_signal']!=0].head(50)

Unnamed: 0,Gmt time,Close,High,Low,Open,Volume,RSI,EMA,ATR,EMASignal,isPivot,pointpos,breakout_signal,pivot_ref_idx
662,2025-01-16 09:00:00+00:00,969.5,972.0,967.0,968.0,75,53.082222,950.024872,4.880553,2,0,,2,633
906,2025-02-03 00:00:00+00:00,1060.0,1074.5,1055.5,1069.0,143,63.772962,1000.658919,10.417394,2,0,,2,876
2721,2025-05-27 01:00:00+00:00,983.5,991.5,982.5,991.5,103,27.175967,999.88665,4.170352,1,0,,1,2696
2842,2025-06-03 12:00:00+00:00,1011.5,1012.5,1001.5,1004.5,527,76.493828,986.504989,5.231583,2,0,,2,2833
2874,2025-06-04 22:00:00+00:00,1005.0,1005.5,1004.0,1005.0,0,41.422169,996.614858,5.278447,2,0,,2,2843
3227,2025-06-26 08:00:00+00:00,1117.0,1117.5,1111.0,1114.0,294,70.036399,1070.744295,8.12179,2,0,,2,3200
3531,2025-07-16 19:00:00+00:00,1288.0,1288.0,1283.0,1285.0,3969,68.671209,1216.662984,11.314589,2,0,,2,3508
3947,2025-08-12 00:00:00+00:00,1160.5,1163.5,1158.0,1158.0,40,59.995615,1179.329251,7.290143,1,0,,1,3922
4069,2025-08-19 07:00:00+00:00,1131.5,1136.0,1127.0,1127.5,259,56.64622,1143.897889,5.300033,1,0,,1,4048
4467,2025-09-15 01:00:00+00:00,1218.5,1223.0,1213.5,1223.0,427,31.549433,1196.021741,11.118919,2,2,1213.499,2,4443


In [93]:
df1.shape

(5739, 14)

In [95]:
plot_breakouts_with_candles(df1, start_idx=0, end_idx=len(df1))