In [202]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [203]:
import pandas as pd
import numpy as np
import os
import csv

from torchsummary import summary
from tqdm import tqdm

from torch.utils.data import DataLoader, Dataset
from torch.utils.data import TensorDataset
from torch import nn, optim, tensor, Tensor

In [204]:
class CandelsDataset(Dataset):
    def __init__(self, root_dir: str, years: list, window: int):
        self.window = window
        self.tables_list = self.__get_correct_tables(root_dir, years)
        self.samples = self.__generate_samples(self.tables_list)

    def __get_correct_tables(self, root_dir, years):
        """
        get paths of .csv with len > window 
        """
        tables_list = []
        for root, dirs, files in os.walk(root_dir):
            if dirs:
                tables_dirs = dirs
                break
        for table_dir in tables_dirs:
            if int(table_dir[-4:]) in years:
                full_dir = os.path.join(root_dir, table_dir)
                for root, dirs, files in os.walk(full_dir):
                    for file in files:
                        if file.lower().endswith('.csv'):
                            full_path = os.path.join(full_dir, file)
                            with open(full_path) as f:
                                if sum(1 for line in f) > self.window:
                                    tables_list.append(full_path)
        return tables_list

    def __generate_samples(self, tables_list):
        samples = Tensor()
        for table in tables_list:
            print(table)
            df = pd.read_csv(
                table,
                sep =";",
                names=["figi", "utc", "open", "close", "high", "low", "volume"],
                index_col=False
                )
            df['utc'] = pd.to_datetime(df['utc'], utc=True)
            df['hour'] = df['utc'].dt.hour
            df['minute'] = df['utc'].dt.minute
            df['day_of_week'] = df['utc'].dt.day_of_week
            df.drop(labels=['utc'], axis=1, inplace=True)
            df = df[(df['day_of_week'] < 5)] #drop not tradeble days
            print(df)
        return samples


    def __len__(self):
        return len(self.tables_list)

    def __getitem__(self, idx):
        return self.tables_list[idx][:180].to(device), self.tables_list[idx][-1].to(device)

In [205]:
train_dataset = CandelsDataset(
    root_dir="market_data/unzip_data", 
    years=[2023, 2024], 
    window=180
    )

market_data/unzip_data\BBG004730N88_2023\e6123145-9665-43e0-8413-cd61b8aa9b13_20230101.csv
Empty DataFrame
Columns: [figi, open, close, high, low, volume, hour, minute, day_of_week]
Index: []
market_data/unzip_data\BBG004730N88_2023\e6123145-9665-43e0-8413-cd61b8aa9b13_20230102.csv
                                     figi    open   close    high     low  \
0    e6123145-9665-43e0-8413-cd61b8aa9b13  139.86  139.86  139.86  139.86   
1    e6123145-9665-43e0-8413-cd61b8aa9b13  139.85  139.85  139.85  139.85   
2    e6123145-9665-43e0-8413-cd61b8aa9b13  139.85  139.84  139.85  139.84   
3    e6123145-9665-43e0-8413-cd61b8aa9b13  139.84  139.84  139.84  139.84   
4    e6123145-9665-43e0-8413-cd61b8aa9b13  139.83  139.82  139.83  139.82   
..                                    ...     ...     ...     ...     ...   
343  e6123145-9665-43e0-8413-cd61b8aa9b13  141.84  140.98  141.84  140.98   
344  e6123145-9665-43e0-8413-cd61b8aa9b13  141.84  141.10  141.90  141.10   
345  e6123145-9665-43e0-

KeyboardInterrupt: 