In [1]:
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from utils import generate_vpvr
import torch


class HourlyDataset(Dataset):
    """Windowed hourly-candle dataset enriched with weekly context.

    A sample is the ``window_size`` candles ending at a "usable" candle, the
    metadata row of that candle's week, its label and a weekday flag.

    Weekly metadata (one row per week, indexed by the week's Monday as "YYYY-MM-DD"):
    - vpvr (on prev week)
    - prev weekly high / low / open / close
    - monday open / high / low / close (Tue, then Wed, then Thu data is used
      when a Monday is missing)

    Returned next to the metadata: a 2-category weekday flag
    (1 = Mon-Fri, 0 = weekend; Monday candles are never usable).
    """

    def __init__(self, ltf_data, window_size, split=None):
        super(HourlyDataset, self).__init__()
        # State at the end of __init__:
        # ltf_data:            copy of the input frame (all rows, not split)
        # ltf_candles:         numpy array, columns [open, high, low, close, volume]
        # weekly_data:         DataFrame, columns
        #                      [pw_high, pw_low, pw_open, pw_close, mon_open, mon_high, mon_low, mon_close, vpvr]
        # usable_ltf_candles:  DataFrame, columns [week_idx, weekday], index = row label in ltf_data

        self.split = split
        self.ltf_data = ltf_data.copy()
        self.ltf_candles = self.ltf_data[["open", "high", "low", "close", "volume", "time"]].copy()
        self.labels = self.ltf_data[["label", "value"]].copy()

        self.ltf_candles = self._add_week_idx(self.ltf_candles)  # adds "week_idx" to ltf_candles
        self.weekly_data = self._extract_weekly_info(self.ltf_candles)
        self.window_size = window_size

        self._handle_weekly_none()                      # for weekly
        self._mark_usable_or_not()                      # for ltf_candles (needs window_size)
        self._guarantee_overlap_weekly_and_candles()    # for weekly and ltf_candles

        # __getitem__ slices the numpy array by row label, so labels must equal positions
        self._control_indices(self.ltf_candles)
        self.usable_ltf_candles = self.ltf_candles[self.ltf_candles.usable == 1][["week_idx", "weekday"]].copy()
        self.usable_ltf_candles = self._handle_consistent_split(self.usable_ltf_candles, self.split)  # handle split

        self._drop_unnecessary_and_order_ltf()  # for ltf_candles
        self._order_weekly_data()               # for weekly
        self.ltf_candles = self.ltf_candles.to_numpy()  # for efficiency

    def get_base_datetime(self):
        """Return the "time" value of the first row of the raw data."""
        return self.ltf_data.iloc[0].time

    def time_price_reaches(self, start_time, price):
        """Return the time of the first row (in row order) after ``start_time`` whose
        [low, high] range contains ``price``, or None when no such row exists."""
        subset = self.ltf_data[(self.ltf_data['time'] > start_time) &
                               ((self.ltf_data['low'] <= price) &
                                (self.ltf_data['high'] >= price))]
        return subset.iloc[0].time if not subset.empty else None

    def __len__(self):
        return len(self.usable_ltf_candles)

    def __getitem__(self, idx):
        """
        :return:
        x            -> numpy array: window_size * OHLCV (prices scaled with the previous
                        week's range, volume standardized within the window)
        metadata[2:] -> numpy array: ['pw_open', 'pw_close', 'mon_open', 'mon_high', 'mon_low', 'mon_close', 'vpvr']
        label        -> numpy array: [direction, value]
        weekday      -> int: 1 if weekday, 0 if weekend
        base_info    -> list: [pw_mid, pw_half_range, datetime of the last candle]
        """
        last_candle = self.usable_ltf_candles.iloc[idx]
        weekday = last_candle.weekday
        idx_in_df = last_candle.name
        last_candle_datetime = self.ltf_data.loc[idx_in_df].time

        # x -> row_idx * OHLCV; window ends (inclusive) at the usable candle
        x = self.ltf_candles[idx_in_df - (self.window_size - 1): idx_in_df + 1].copy()
        weekly_info = self.weekly_data.loc[last_candle.week_idx].to_numpy().copy()
        label = self.labels.loc[idx_in_df].to_numpy().copy()

        # normalize x and metadata components
        x, weekly_info, label, base_info = self._normalize_data_metadata_label(x, weekly_info, label)
        base_info.append(last_candle_datetime)

        return x, weekly_info[2:], label, weekday, base_info

    def _normalize_data_metadata_label(self, x, metadata, label):
        """Scale prices by the previous week's range and standardize the volume.

        x        -> rows * OHLCV
        metadata -> ['pw_high', 'pw_low', 'pw_open', 'pw_close', 'mon_open', 'mon_high', 'mon_low', 'mon_close', 'vpvr']
        label    -> [label, value]

        Prices (OHLC of x, metadata[2:8], value of label) are mapped to
        (p - pw_mid) / pw_half_range, so pw_low -> -1 and pw_high -> +1.
        Returns (x, metadata, label, [pw_mid, pw_half_range]); the arrays are modified in place.
        """
        pw_high = metadata[0]
        pw_low = metadata[1]
        pw_mid = (pw_low + pw_high) / 2
        pw_half_range = (pw_high - pw_low) / 2

        x[:, :4] = (x[:, :4] - pw_mid) / pw_half_range
        metadata[2: 8] = (metadata[2: 8] - pw_mid) / pw_half_range
        label[1] = (label[1] - pw_mid) / pw_half_range

        # standardize the volume; a constant-volume window (std == 0) would otherwise
        # produce NaN/inf, so it is mapped to all zeros instead
        mean_vol = np.mean(x[:, -1])
        std_vol = np.std(x[:, -1])
        if std_vol > 0:
            x[:, -1] = (x[:, -1] - mean_vol) / std_vol
        else:
            x[:, -1] = 0.0

        return x, metadata, label, [pw_mid, pw_half_range]

    def _handle_consistent_split(self, data, split):
        """Return the chronological train / dev / test slice of ``data``.

        The split is positional over the rows: the first 70% are train, the next 20% dev
        and the rest (~10%) test. The first windows of a later split still look back
        into candles of the earlier one (window slicing in __getitem__ ignores the split).
        ``split=None`` returns ``data`` unchanged.
        """
        if split is None:
            return data
        if split not in ("train", "dev", "test"):
            # explicit raise: a bare `assert False` disappears under `python -O`
            raise AssertionError("Split should be 'train', 'dev', or 'test'")

        train_prop = 0.7
        dev_prop = 0.2  # test receives the remaining rows

        num_rows = len(data)
        train_split_index = int(num_rows * train_prop)
        dev_split_index = train_split_index + int(num_rows * dev_prop)

        bounds = {
            "train": (None, train_split_index),
            "dev": (train_split_index, dev_split_index),
            "test": (dev_split_index, None),
        }
        start, stop = bounds[split]
        return data.iloc[start:stop].copy()

    def _order_weekly_data(self):
        order = ['pw_high', 'pw_low', 'pw_open', 'pw_close', 'mon_open', 'mon_high', 'mon_low', 'mon_close', 'vpvr']
        self.weekly_data = self.weekly_data[order]

    def _drop_unnecessary_and_order_ltf(self):
        # keep only the model inputs, in OHLCV order (drops time, week_idx, usable, weekday)
        keep_id = ["open", "high", "low", "close", "volume"]
        self.ltf_candles = self.ltf_candles[keep_id]

    def _control_indices(self, df):
        is_ordered = df.index.equals(pd.RangeIndex(start=0, stop=len(df)))
        if not is_ordered:
            raise AssertionError("The indices must not be corrupted!")

    def _handle_weekly_none(self):
        """Drop leading/trailing weeks containing NaN; a NaN week in the middle is an error."""
        while not self.weekly_data.empty and self.weekly_data.iloc[0].isna().any():
            self.weekly_data = self.weekly_data.iloc[1:]

        while not self.weekly_data.empty and self.weekly_data.iloc[-1].isna().any():
            self.weekly_data = self.weekly_data.iloc[:-1]

        none_rows = self.weekly_data[self.weekly_data.isna().any(axis=1)]
        if len(none_rows) > 0:
            print(none_rows)
            raise AssertionError("Weekly data contains None values!")

    def _guarantee_overlap_weekly_and_candles(self):
        # assumes no None value in weekly data
        known_weeks = set(self.weekly_data.index)
        for idx in self.ltf_candles[self.ltf_candles.usable == 1].week_idx.unique():
            if idx not in known_weeks:
                print(idx)
                raise AssertionError("Found a non-existing week in weekly data")

    def _mark_usable_or_not(self):
        # policy of not usable:
        # if monday or among first (window_size - 1) candles, then not usable
        # (presumably Mondays are excluded because the week's metadata holds the full Monday OHLC)
        # also marks whether weekday (Mon-Fri -> 1) or weekend (Sat/Sun -> 0)
        day_of_week = self.ltf_candles["time"].dt.weekday
        too_early = self.ltf_candles.index < (self.window_size - 1)
        is_monday = (day_of_week == 0).to_numpy()
        self.ltf_candles["usable"] = np.where(too_early | is_monday, 0, 1)
        self.ltf_candles["weekday"] = np.where(day_of_week.to_numpy() > 4, 0, 1)

    def _add_week_idx(self, ltf_data):
        # modifies ltf_data in place but still returns it
        # week_idx = date ("YYYY-MM-DD") of the Monday of the candle's week
        ltf_data["week_idx"] = (
                ltf_data.time - pd.to_timedelta(ltf_data.time.dt.weekday, unit='D')).dt.strftime(
            "%Y-%m-%d")
        return ltf_data

    def _extract_weekly_info(self, hourly_candles):
        """Build one metadata row per week, indexed by that week's Monday ("YYYY-MM-DD").

        Columns:
        - pw_open / pw_high / pw_low / pw_close: OHLC of the previous week
        - mon_open / mon_high / mon_low / mon_close: OHLC of this week's Monday
          (Tuesday, then Wednesday, then Thursday is used when the Monday is missing)
        - vpvr: generate_vpvr buy and sell arrays of the previous week, concatenated

        ``hourly_candles`` must carry "time" and "week_idx" columns. Rows at the edges
        may contain NaN; ``_handle_weekly_none`` trims them.
        """
        time_idx_candles = hourly_candles.set_index("time", inplace=False)
        ohlc_agg = {"open": "first", "high": "max", "low": "min", "close": "last"}
        mon_columns = ["mon_open", "mon_high", "mon_low", "mon_close"]

        # Monday ohlc
        monday_data = time_idx_candles[time_idx_candles.index.weekday == 0]
        monday_ohlc = monday_data.resample("W-MON").agg(ohlc_agg)
        monday_ohlc.index = monday_ohlc.index.strftime('%Y-%m-%d')
        monday_ohlc.columns = mon_columns

        # handle the case monday data was null: fall back to Tue, then Wed, then Thu
        # NOTE(review): _mark_usable_or_not only excludes Mondays, so in a week whose
        # mon_* values came from e.g. Tuesday, the Tuesday candles stay usable while the
        # metadata already contains Tuesday's full-day OHLC (possible look-ahead).
        weekday_idx = 1
        map_day_str_dict = {1: "W-TUE", 2: "W-WED", 3: "W-THU"}
        while monday_ohlc.isna().any(axis=1).any():
            if weekday_idx == 4:
                raise AssertionError("Problem in Monday data")

            null_mask = monday_ohlc.isna().any(axis=1)
            print(f"Trying to replace {int(null_mask.sum())} " +
                  f"monday data (0) with ({weekday_idx})")
            day_data = time_idx_candles[time_idx_candles.index.weekday == weekday_idx]
            day_ohlc = day_data.resample(map_day_str_dict[weekday_idx]).agg(ohlc_agg)
            # relabel each bin with the Monday of the same week
            day_ohlc.index = day_ohlc.index - pd.Timedelta(days=weekday_idx)
            day_ohlc.index = day_ohlc.index.strftime('%Y-%m-%d')
            day_ohlc.columns = mon_columns

            null_indices = monday_ohlc[null_mask].index
            monday_ohlc.loc[null_indices] = day_ohlc.loc[null_indices]

            weekday_idx += 1

        # weekly ohlc: W-SUN bins are labelled with their closing Sunday; +1 day
        # relabels every week with the following Monday, i.e. the week it is "previous" to
        weekly_data = time_idx_candles.resample("W-SUN").agg(ohlc_agg)
        weekly_data.index = weekly_data.index + pd.Timedelta(days=1)
        if not all(weekly_data.index.dayofweek == 0):
            raise AssertionError("All dates in weekly_data need to refer to a Monday")
        weekly_data.index = weekly_data.index.strftime("%Y-%m-%d")
        weekly_data.columns = ["pw_open", "pw_high", "pw_low", "pw_close"]

        if (len(weekly_data) - len(monday_ohlc)) > 1:
            raise AssertionError("Weekly and Monday data do not line up")

        weekly_data = pd.merge(weekly_data, monday_ohlc, left_index=True, right_index=True, how='outer')

        # VPVR: computed on each week's candles and stored under the FOLLOWING Monday, so
        # week W gets the profile of week W-1 (same alignment as pw_*). Storing it under W
        # itself would expose candles that come after a sample's last candle.
        num_bins = 10
        vpvr_weeks = []
        vpvr_profiles = []
        for week_idx, weekly_box in time_idx_candles.groupby("week_idx"):
            buy_array, sell_array = generate_vpvr(weekly_box, num_bins)
            next_monday = pd.Timestamp(week_idx) + pd.Timedelta(days=7)
            vpvr_weeks.append(next_monday.strftime("%Y-%m-%d"))
            vpvr_profiles.append(np.concatenate([buy_array, sell_array]))

        # fill an object array element by element so each cell holds one whole profile
        vpvr_values = np.empty(len(vpvr_profiles), dtype=object)
        for pos, profile in enumerate(vpvr_profiles):
            vpvr_values[pos] = profile
        vpvr_series = pd.Series(vpvr_values, index=vpvr_weeks, name="vpvr")

        return pd.merge(weekly_data, vpvr_series, left_index=True, right_index=True, how='outer')

    @staticmethod
    def get_collate_fn():
        """Return a collate function turning a list of ``__getitem__`` outputs into tensors."""
        def hourly_collate_fn(batch):
            x = [item[0] for item in batch]
            label = [item[2] for item in batch]
            weekday = [item[3] for item in batch]
            base_data = [item[4] for item in batch]

            # flatten the trailing vpvr array into the scalar metadata entries
            metadata = [np.concatenate((item[1][0:-1], item[1][-1])) for item in batch]

            # build one contiguous float32 array first: creating a tensor from a
            # list of numpy arrays is very slow in torch
            x_tensor = torch.tensor(np.asarray(x, dtype=np.float32), dtype=torch.float32)  # [BS, seq_len, OHLCV]
            metadata_tensor = torch.tensor(np.asarray(metadata, dtype=np.float32), dtype=torch.float32)  # [BS, metadata_dim]
            label_tensor = torch.tensor(np.asarray(label, dtype=np.float32), dtype=torch.float32)  # [BS, 2]
            weekday_tensor = torch.tensor(weekday, dtype=torch.float32)  # [BS]

            return x_tensor, metadata_tensor, label_tensor, weekday_tensor, base_data

        return hourly_collate_fn


In [2]:
from torch.utils.data import Dataset


class SuperTrainerDataset(Dataset):
    """Concatenation of several datasets behind one flat index.

    ``limit_indices[k]`` is the total length of the first ``k`` added datasets, so
    dataset ``k`` serves global indices ``[limit_indices[k], limit_indices[k + 1])``.
    """

    def __init__(self):
        super(SuperTrainerDataset, self).__init__()
        self.length_list = [0]      # leading 0 keeps it aligned with limit_indices
        self.limit_indices = [0]    # cumulative lengths, starting at 0
        self.datasets = list()

    def __getitem__(self, item):
        from bisect import bisect_right

        total = len(self)
        if item < 0:
            item += total  # negative indices count from the end, like a Python sequence
        if not 0 <= item < total:
            # IndexError instead of `assert False`: survives `python -O` and lets plain
            # iteration over the dataset stop cleanly at the end
            raise IndexError("Index greater than dataset size")
        # first cumulative limit strictly greater than item identifies the owning dataset
        pos = bisect_right(self.limit_indices, item)
        return self.datasets[pos - 1][item - self.limit_indices[pos - 1]]

    def __len__(self):
        return sum(self.length_list)

    def add_dataset(self, dataset):
        """Append ``dataset``; its items follow those of the previously added datasets."""
        self.length_list.append(len(dataset))
        self.limit_indices.append(sum(self.length_list))
        self.datasets.append(dataset)





In [3]:
import json
import os.path

import torch.optim as optim
import torch
from torch.utils.data import DataLoader

from utils import find_last_model, load_data, prepare_data
from dataloader.SwingDatasets import HourlyDataset
from dataloader.SuperTrainerDataset import SuperTrainerDataset
from trainer.trainer import HourlySwingModelTrainer
from model.swing_model import HourlySwingModel

from dataset_paths import forex_path_dict, crypto_path_dict, index_path_dict, tz_dict

# Map from dataset family (config "type") to its {pair: csv path} dictionary.
pair_types_dict = dict(
    forex=forex_path_dict,
    crypto=crypto_path_dict,
    index=index_path_dict,
)

# --- training configuration -----------------------------------------------------
with open("config_pretrain.json", "r") as file:
    config_train = json.load(file)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

(LOAD_MODEL, TYPE, WINDOW_SIZE, BATCH_SIZE, VALUE_ONLY,
 LR, EPOCHS, EVAL_PERIOD, CHECKPOINT_PERIOD, MODEL_OUT_PATH) = (
    config_train[key] for key in (
        "load_model", "type", "window_size", "batch_size", "value_only",
        "lr", "epochs", "eval_period", "checkpoint_period", "model_out",
    )
)

# --- model configuration --------------------------------------------------------
with open("config_model.json", "r") as file:
    config_model = json.load(file)

(inp_dim, metadata_dim, metadata_bias, metadata_gate_bias,
 fusion_model_dim, fusion_num_heads, fusion_num_layers, fusion_apply_grn, fusion_dropout,
 lstm_num_layers, lstm_bidirectional, lstm_dropout, loss_punish_cert) = (
    config_model[key] for key in (
        "inp_dim", "metadata_dim", "metadata_bias", "metadata_gate_bias",
        "fusion_model_dim", "fusion_num_heads", "fusion_num_layers", "fusion_apply_grn", "fusion_dropout",
        "lstm_num_layers", "lstm_bidirectional", "lstm_dropout", "loss_punish_cert",
    )
)

# --- model (optionally resumed from the latest checkpoint) ----------------------
model = HourlySwingModel(
    inp_dim=inp_dim,
    metadata_dim=metadata_dim,
    metadata_bias=metadata_bias,
    metadata_gate_bias=metadata_gate_bias,
    fusion_model_dim=fusion_model_dim,
    fusion_num_heads=fusion_num_heads,
    fusion_num_layers=fusion_num_layers,
    fusion_apply_grn=fusion_apply_grn,
    fusion_dropout=fusion_dropout,
    lstm_num_layers=lstm_num_layers,
    lstm_bidirectional=lstm_bidirectional,
    lstm_dropout=lstm_dropout,
    loss_punish_cert=loss_punish_cert,
).to(device)

if LOAD_MODEL is not None:
    model_file = find_last_model(LOAD_MODEL)
    checkpoint_path = os.path.join(LOAD_MODEL, model_file)
    print(f"Loading model from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# --- per-pair train/dev datasets merged into super datasets ---------------------
pair_path_dict = pair_types_dict[TYPE]
super_train_dataset = SuperTrainerDataset()
super_dev_dataset = SuperTrainerDataset()
idx = 0
for idx, (pair, path) in enumerate(pair_path_dict.items(), start=1):
    print(f"{idx} - Loading {pair.upper()} from {path}")
    pair_data = prepare_data(load_data(path, add_zigzag_col=True), tz_dict[pair])

    print("Building train dataset...")
    train_dataset = HourlyDataset(pair_data, WINDOW_SIZE, "train")
    print("Building dev dataset...")
    dev_dataset = HourlyDataset(pair_data, WINDOW_SIZE, "dev")

    super_train_dataset.add_dataset(train_dataset)
    super_dev_dataset.add_dataset(dev_dataset)
    print(f"{pair.upper()} loaded.\n")

1 - Loading AUDUSD from data/forex/AUDUSD.csv
Building train dataset...
Trying to replace 1 monday data (0) with (1)
Building dev dataset...
Trying to replace 1 monday data (0) with (1)
AUDUSD loaded.

2 - Loading EURUSD from data/forex/EURUSD.csv
Building train dataset...
Building dev dataset...
EURUSD loaded.

3 - Loading GBPUSD from data/forex/GBPUSD.csv
Building train dataset...
Trying to replace 2 monday data (0) with (1)
Building dev dataset...
Trying to replace 2 monday data (0) with (1)
GBPUSD loaded.

4 - Loading NZDUSD from data/forex/NZDUSD.csv
Building train dataset...
Trying to replace 2 monday data (0) with (1)
Building dev dataset...
Trying to replace 2 monday data (0) with (1)
NZDUSD loaded.

5 - Loading USDCAD from data/forex/USDCAD.csv
Building train dataset...
Trying to replace 2 monday data (0) with (1)
Building dev dataset...
Trying to replace 2 monday data (0) with (1)
USDCAD loaded.

6 - Loading USDCHF from data/forex/USDCHF.csv
Building train dataset...
Trying t

In [6]:
# Count rows with a non-zero zigzag value across all train datasets.
# NOTE(review): each dataset's ltf_data is the full (unsplit) pair data, so this counts
# the whole history of every pair, not only the train split.
summ = sum(int((dataset.ltf_data.zigzag != 0).sum()) for dataset in super_train_dataset.datasets)

In [7]:
# total number of rows with a non-zero zigzag value, summed over the train datasets
summ

23054

In [8]:
BatchSize = 3
metadata_dim = 4  # NOTE(review): shadows the model-config `metadata_dim` set in the training cell
seq_len = 5

input_tensor = torch.randn(BatchSize, metadata_dim)

# Insert a length-1 axis at position 1 and broadcast it to seq_len:
# (BatchSize, metadata_dim) -> (BatchSize, seq_len, metadata_dim); expand returns a view.
output_tensor = input_tensor[:, None, :].expand(-1, seq_len, -1)

In [9]:
# the random (BatchSize, metadata_dim) source tensor
input_tensor

tensor([[-8.3414e-01, -8.8846e-01, -1.5769e+00,  7.2267e-01],
        [-1.2260e+00,  4.3148e-04,  9.7815e-01,  7.1238e-02],
        [-4.9249e-01, -2.4486e+00,  3.6938e-02,  2.4355e+00]])

In [13]:
# one batch element of the expanded tensor: (seq_len, metadata_dim)
output_tensor[0].shape

torch.Size([5, 4])

In [16]:
import torch.nn as nn

In [24]:
# learnable (5, 10) tensor, randomly initialised
par = nn.Parameter(torch.randn(5, 10))

In [25]:
# the (5, 10) learnable parameter (requires_grad=True)
par

Parameter containing:
tensor([[ 1.3859e+00,  2.7784e-01,  1.0979e-02, -7.9550e-02, -2.4466e-01,
         -1.6264e-01,  2.2212e+00, -1.7096e-01, -1.1046e+00, -3.2281e-01],
        [-1.3282e+00,  1.4145e+00,  4.7734e-01,  3.2597e-01, -7.9155e-01,
         -1.9237e+00,  1.4716e+00, -7.7554e-01, -8.3920e-01,  1.8074e-01],
        [-2.2318e-01, -4.0131e-01, -1.1018e+00, -3.5627e-01, -9.2998e-01,
         -1.0346e+00, -1.0416e+00, -2.1275e-01,  1.2484e+00,  6.3164e-04],
        [-2.0751e-01, -8.2800e-01, -1.0970e+00,  1.4049e+00, -5.3343e-01,
          7.3585e-02, -6.6443e-01, -6.6012e-02, -6.3939e-01, -1.2284e+00],
        [-7.6871e-01, -1.5211e-01,  6.6235e-01, -1.2609e+00,  7.2886e-01,
          2.8504e-01, -8.5848e-01,  2.8860e-01, -4.8599e-01, -6.1851e-01]],
       requires_grad=True)

In [21]:
# random (3, 5, 10) tensor used to test broadcasting against `par`
a = torch.randn(3, 5, 10)

In [22]:
# the random (3, 5, 10) tensor
a

tensor([[[ 1.4173e-01,  6.5189e-01, -6.6535e-01, -2.1323e-01,  1.0572e+00,
           1.5039e+00,  6.4828e-01, -1.8274e-01, -1.5906e-01,  1.2856e+00],
         [-1.6971e+00,  1.3763e+00, -1.3518e+00,  3.3029e-01, -4.1408e-01,
           3.7386e-01,  6.9113e-03,  9.8356e-01,  8.7109e-01, -5.6164e-01],
         [-2.6846e+00, -1.1234e+00,  2.7144e-01, -8.2181e-01, -1.9523e-01,
           8.7756e-01, -3.6374e-01, -4.9087e-01,  9.1444e-03, -1.1149e-01],
         [-1.1559e+00,  1.7282e+00, -7.9126e-01, -8.4900e-02, -4.4864e-01,
          -2.1610e-03, -1.2869e+00, -3.3090e-01, -4.0983e-01, -6.6984e-01],
         [ 1.1939e+00,  1.6781e+00, -6.2555e-01,  5.3702e-01, -8.5439e-01,
          -6.0880e-01,  1.3065e-01,  3.4493e-01,  4.5945e-01, -4.7505e-01]],

        [[-1.6285e+00,  4.0523e-01, -1.6392e+00,  7.1000e-02, -8.4359e-02,
          -2.5679e-02, -1.3094e+00,  9.1601e-01,  2.9487e-01,  6.9084e-01],
         [-5.7770e-01,  1.4168e+00, -1.3110e-01,  1.3901e+00,  2.0432e+00,
          -8.1189

In [26]:
# broadcasting: the (5, 10) parameter is added to each of the 3 (5, 10) slices of `a`
a + par

tensor([[[ 1.5277,  0.9297, -0.6544, -0.2928,  0.8125,  1.3412,  2.8695,
          -0.3537, -1.2636,  0.9628],
         [-3.0253,  2.7909, -0.8745,  0.6563, -1.2056, -1.5498,  1.4785,
           0.2080,  0.0319, -0.3809],
         [-2.9078, -1.5247, -0.8304, -1.1781, -1.1252, -0.1570, -1.4053,
          -0.7036,  1.2576, -0.1109],
         [-1.3634,  0.9002, -1.8883,  1.3200, -0.9821,  0.0714, -1.9514,
          -0.3969, -1.0492, -1.8983],
         [ 0.4252,  1.5260,  0.0368, -0.7238, -0.1255, -0.3238, -0.7278,
           0.6335, -0.0265, -1.0936]],

        [[-0.2426,  0.6831, -1.6282, -0.0085, -0.3290, -0.1883,  0.9118,
           0.7451, -0.8097,  0.3680],
         [-1.9059,  2.8313,  0.3462,  1.7161,  1.2517, -2.7356,  0.7576,
          -0.0826,  0.7358, -0.5943],
         [-1.5206, -1.4135,  0.7118,  0.4218, -1.3113, -0.3806,  1.7865,
          -0.5722,  1.6663,  0.5585],
         [-1.4534, -1.7459,  0.2954,  1.0209, -2.3645, -0.2007, -1.9557,
          -0.5275,  0.4339,  0.1563],

In [27]:
class MetricDataloader(DataLoader):
    """DataLoader that also allows random access to whole batches via ``loader[i]``.

    Batch ``i`` is built from dataset indices ``[i * batch_size, (i + 1) * batch_size)``
    in dataset order: ``__getitem__`` ignores the sampler (and therefore ``shuffle``).
    Iteration is not overridden and still goes through the regular DataLoader.
    """

    def __getitem__(self, index):
        """Return the collated batch number ``index`` (negative indices count from the end).

        Raises IndexError when ``index`` is outside the available batches, instead of
        silently collating an empty batch.
        """
        num_items = len(self.dataset)
        num_batches = -(-num_items // self.batch_size)  # ceil division; last batch may be partial

        if index < 0:
            index += num_batches
        if not 0 <= index < num_batches:
            raise IndexError(f"batch index out of range for {num_batches} batches")

        # Calculate start and end indices for the batch
        start_idx = index * self.batch_size
        end_idx = min(start_idx + self.batch_size, num_items)

        batch = [self.dataset[i] for i in range(start_idx, end_idx)]

        if self.collate_fn is not None:
            batch = self.collate_fn(batch)

        return batch

In [31]:
# NOTE(review): `train_dataset` is whatever the loading loop built last (the train split of
# the last pair), not `super_train_dataset`.
train_dataloader = MetricDataloader(
    train_dataset,
    collate_fn=HourlyDataset.get_collate_fn(),
    batch_size=1,
    shuffle=False,
)

In [41]:
# Batch 1004 (batch_size=1): the collated batch ends with the list of per-sample base_info;
# its last entry is that sample's [pw_mid, pw_half_range, datetime of the last candle].
train_dataloader[1004][-1][-1]

[1.70511,
 0.043810000000000016,
 Timestamp('2020-04-03 04:00:00+0200', tz='Etc/GMT-2')]