In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import mplfinance as mpf
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from datasets import ForexPricePredictionDataset

In [2]:
# Load the USD/JPY CSV through the project's dataset wrapper, with normalization switched off.
dataset = ForexPricePredictionDataset(
    "./data/USDJPY_H1.csv",
    header=0,
    data_order="ohlc",
    input_duration=20,
    output_duration=1,
    normalize=False,
)

In [3]:
# Grab the frame held by the dataset object.
# NOTE(review): `_data` is a private attribute; prefer a public accessor if the class offers one.
data = dataset._data

In [4]:
# Build a fresh, time-indexed frame instead of assigning into `data` in place:
# `data` is bound directly to `dataset._data`, so the in-place version rewrote the
# dataset's own column, and re-running the cell raised KeyError because "Time"
# had already been moved into the index. This form can be re-run safely.
data = (
    dataset._data
    .assign(Time=lambda frame: pd.to_datetime(frame["Time"]))
    .set_index("Time")
)

In [5]:
# Display the time-indexed frame (Open/High/Low/Close/Volume columns).
data

Unnamed: 0_level_0,Open,High,Low,Close,Volume
Time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2007-03-13 16:00:00,116.145,116.376,116.145,116.165,14746
2007-03-13 17:00:00,116.180,116.208,115.986,116.122,13243
2007-03-13 18:00:00,116.121,116.247,116.075,116.227,15540
2007-03-13 19:00:00,116.208,116.231,116.080,116.107,13368
2007-03-13 20:00:00,116.096,116.170,116.040,116.140,14143
...,...,...,...,...,...
2023-04-07 07:00:00,132.122,132.175,132.098,132.145,3527
2023-04-07 08:00:00,132.147,132.208,132.099,132.187,2978
2023-04-07 09:00:00,132.176,132.251,132.110,132.233,4661
2023-04-07 10:00:00,132.227,132.271,132.151,132.227,1805


In [6]:
def savecandle(data, root, label, name):
    """Render the first 30 rows of `data` as a candlestick PNG in a per-label folder.

    The image is written to ``./{root}/<class>/{name}.png`` where ``<class>`` is
    ``up`` for label 1, ``down`` for label 2 and ``stationary`` for anything else.
    The target folder is created when missing, and the figure is always closed
    afterwards so repeated calls do not accumulate open matplotlib figures.

    Args:
        data: Price frame handed to ``mpf.plot``; only ``data.head(30)`` is drawn.
        root: Output directory, relative to the current working directory.
        label: Class id selecting the sub-folder (1 = up, 2 = down, else stationary).
        name: File name without extension.
    """
    if label == 1:
        path = f"./{root}/up/"
    elif label == 2:
        path = f"./{root}/down/"
    else:
        path = f"./{root}/stationary/"
    # savefig does not create missing directories, so make sure the folder exists.
    Path(path).mkdir(parents=True, exist_ok=True)

    fig, ax = mpf.plot(data.head(30), type='candle', returnfig=True, scale_padding=0, style='charles')
    try:
        ax[0].set_axis_off()
        # NOTE(review): matplotlib documents pad_inches as applying only together with
        # bbox_inches='tight'; left unchanged so the image size stays the same as before.
        fig.savefig(path + name + ".png", pad_inches=0)
    finally:
        # Closing releases the figure; clearing the axes alone kept it registered in pyplot.
        plt.close(fig)

In [84]:
def saveall(data, root, threshold=0.03, step=15):
    """Write one labelled candlestick image per sliding window of `data`.

    Every window covers 31 consecutive rows. The first 30 rows are drawn; the
    Close of the 31st row is compared with the Close of the 30th row to pick
    the label: 1 when it rose by more than `threshold`, 2 when it fell by more
    than `threshold`, 0 otherwise. Each image is written via `savecandle` and
    named after the index label of the window's first row.

    Args:
        data: DataFrame with a "Close" column, in the row order to slide over.
        root: Output directory forwarded to `savecandle`.
        threshold: Minimum absolute Close-to-Close move treated as up/down.
        step: Number of rows the window advances between images.
    """
    window = 30          # rows drawn per image
    span = window + 1    # drawn rows plus the one row used for the label

    # Upper bound is len + 1 so the window that ends on the very last row is
    # not skipped.
    for end in range(span, len(data.index) + 1, step):
        chunk = data.iloc[end - span:end]
        x = chunk.iloc[:window]
        y = chunk.iloc[window]

        # .iloc[-1] instead of [-1]: positional lookup through [] on a
        # label-indexed Series is deprecated in pandas.
        move = y["Close"] - x["Close"].iloc[-1]
        if move > threshold:
            label = 1
        elif move < -threshold:
            label = 2
        else:
            label = 0

        savecandle(x, root, label, str(x.iloc[0].name))
        # One close is enough; calling plt.cla()/plt.clf() afterwards would
        # implicitly create a new empty figure and leave it open.
        plt.close('all')


In [85]:
# shuffle=False keeps the rows in their original order, so `test` is the last 20% of `data`.
train, test = train_test_split(
    data,
    test_size=0.2,
    shuffle=False,
    random_state=42,
)

In [86]:
# Rows whose Open moved by less than 0.03 (in either direction) from the previous row.
train[train["Open"].diff().abs() < 0.03]

Unnamed: 0_level_0,Open,High,Low,Close,Volume
Time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2007-03-13 21:00:00,116.125,116.125,115.752,115.787,14068
2007-03-14 12:00:00,117.041,117.082,116.972,117.060,13714
2007-03-14 13:00:00,117.046,117.390,117.026,117.271,14954
2007-03-14 15:00:00,117.263,117.263,117.076,117.166,13162
2007-03-14 18:00:00,117.222,117.274,117.140,117.189,13714
...,...,...,...,...,...
2020-01-24 02:00:00,109.563,109.627,109.561,109.611,2943
2020-01-24 04:00:00,109.618,109.622,109.520,109.534,3164
2020-01-24 07:00:00,109.442,109.467,109.319,109.326,9351
2020-01-24 09:00:00,109.319,109.324,109.222,109.257,7988


In [87]:
# Rows whose Open rose by more than 0.03 from the previous row.
train.loc[train["Open"].diff().gt(0.03)]

Unnamed: 0_level_0,Open,High,Low,Close,Volume
Time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2007-03-13 17:00:00,116.180,116.208,115.986,116.122,13243
2007-03-13 19:00:00,116.208,116.231,116.080,116.107,13368
2007-03-13 23:00:00,116.035,116.083,115.950,115.995,14759
2007-03-14 01:00:00,116.123,116.329,116.090,116.165,14018
2007-03-14 02:00:00,116.162,116.408,116.140,116.231,14502
...,...,...,...,...,...
2020-01-23 15:00:00,109.566,109.576,109.508,109.534,4031
2020-01-23 20:00:00,109.542,109.549,109.512,109.547,1395
2020-01-24 00:00:00,109.634,109.642,109.563,109.569,4921
2020-01-24 03:00:00,109.611,109.625,109.587,109.618,2122


In [88]:
# Rows whose Open fell by more than 0.03 from the previous row.
train.loc[train["Open"].diff().lt(-0.03)]

Unnamed: 0_level_0,Open,High,Low,Close,Volume
Time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2007-03-13 18:00:00,116.121,116.247,116.075,116.227,15540
2007-03-13 20:00:00,116.096,116.170,116.040,116.140,14143
2007-03-13 22:00:00,115.802,116.051,115.789,116.031,13804
2007-03-14 00:00:00,116.001,116.270,115.968,116.116,13496
2007-03-14 05:00:00,116.534,116.600,116.305,116.501,14497
...,...,...,...,...,...
2020-01-24 01:00:00,109.570,109.589,109.521,109.563,3977
2020-01-24 05:00:00,109.533,109.560,109.430,109.459,6056
2020-01-24 06:00:00,109.459,109.467,109.316,109.441,10123
2020-01-24 08:00:00,109.325,109.354,109.246,109.319,8480


In [90]:
# Export the candlestick images for each split into its own folder.
for frame, out_root in ((train, "images/train"), (test, "images/test")):
    saveall(frame, out_root)

<Figure size 640x480 with 0 Axes>