In [None]:
import pandas as pd
from sklearn.pipeline import Pipeline

from src import (
    Dataset, Model,
    resample_data_by_10min,
    generate_full_data,
    filter_nan_days,
    encode_datetime,
    merge_external,
    create_samples,
    post_process,
    evaluate
)

In [None]:
class CustomDataset(Dataset):
    """Dataset whose train and test splits share one feature pipeline.

    Each split goes through ``generate_full_data`` (restricted to the
    ``START_TIME``..``END_TIME`` window), ``resample_data_by_10min``,
    ``merge_external`` (with ``EXTERNAL_FILE``), ``encode_datetime`` and
    finally ``create_samples(subtract_prev=True, flatten=True)``.
    """

    # Daily time window passed to ``generate_full_data`` for both splits.
    START_TIME = "07:00"
    END_TIME = "16:59"
    # External file merged into both splits via ``merge_external``.
    EXTERNAL_FILE = "data/10min.csv"

    def pre_process(self, train_data, test_data, target):
        """Build the sample sets for the train and (optionally) test splits.

        Args:
            train_data: raw training data passed to ``generate_full_data``.
            test_data: raw test data, or None to skip building the test split.
            target: forwarded to ``create_samples`` as ``target=`` for the test
                split; None also skips building the test split.

        Returns:
            dict with key "train" and, only when both ``test_data`` and
            ``target`` are given, key "test". Each value is whatever
            ``create_samples`` returns.
        """
        dataset = {}

        # ``filter_nan_days`` is applied to training only; the test split is
        # left unfiltered, presumably so its samples stay aligned with the
        # rows of ``target`` (TODO confirm).
        dataset["train"] = self._build_samples(train_data, drop_nan_days=True)

        if test_data is not None and target is not None:
            dataset["test"] = self._build_samples(
                test_data, drop_nan_days=False, target=target
            )

        return dataset

    def _build_samples(self, raw_data, drop_nan_days, **sample_kwargs):
        """Run the shared feature pipeline on one split.

        Args:
            raw_data: raw split data passed to ``generate_full_data``.
            drop_nan_days: whether to apply ``filter_nan_days`` after resampling.
            **sample_kwargs: extra keyword arguments for ``create_samples``,
                passed alongside ``subtract_prev=True, flatten=True``.

        Returns:
            The result of ``create_samples`` for this split.
        """
        frame = generate_full_data(
            raw_data, start_time=self.START_TIME, end_time=self.END_TIME
        ).pipe(resample_data_by_10min)

        if drop_nan_days:
            frame = frame.pipe(filter_nan_days)

        return (
            frame
            .pipe(merge_external, external_file=self.EXTERNAL_FILE)
            .pipe(encode_datetime)
            .pipe(create_samples, subtract_prev=True, flatten=True, **sample_kwargs)
        )

# Raw CSV inputs for the dataset, keyed by CustomDataset's constructor kwargs.
dataset_files = {
    "train_file": "./data/train.csv",
    "test_file": "./data/test.csv",
    "target_file": "./data/target.csv",
}
dataset = CustomDataset(**dataset_files)
# print() (str) rather than a bare expression (repr): kept as-is in case the
# Dataset base class formats the two differently.
print(dataset)

In [None]:
# Single-step pipeline wrapping the project's MLP model.
# NOTE(review): no random seed is set here; if Model("MLP") trains
# stochastically, predictions will differ between runs — confirm.
mlp_model = Model("MLP", epochs=1000)
pipeline = Pipeline(steps=[("MLP", mlp_model)])

# The train split is unpacked straight into fit() as keyword arguments.
train_split = dataset["train"]
pipeline.fit(**train_split)

predictions = pipeline.predict(dataset["test"]["X"])
predictions

In [None]:
target = pd.read_csv("data/target.csv")

# create_samples was called with subtract_prev=True, so the model's outputs are
# presumably offsets from prev_y; add prev_y back to get final values.
# Build a new array instead of `predictions += ...`: the in-place form mutates
# the previous cell's output, so re-running this cell would add prev_y twice.
final_predictions = predictions + dataset["test"]["prev_y"]
# "答案" ("answer") is the submission column that is overwritten here.
target["答案"] = post_process(final_predictions)

target.to_csv("./data/predictions.csv", index=False)
evaluate(target_file="data/target.csv", prediction_file="./data/predictions.csv")


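# ⏮️ Previous-day baseline

For comparison, score a naive baseline that submits `prev_y` unchanged as the answer, with no model involved.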

In [None]:
target = pd.read_csv("data/target.csv")
# Baseline: use prev_y itself as the answer; ravel() flattens it to 1-D so it
# fits a single DataFrame column.
target["答案"] = dataset["test"]["prev_y"].ravel()
# Write to a separate file so the baseline does not overwrite the model's
# ./data/predictions.csv written by the previous cell.
target.to_csv("./data/predictions_prev_day.csv", index=False)
evaluate(target_file="data/target.csv", prediction_file="./data/predictions_prev_day.csv")