In [17]:
import bisect
import os

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

class StockForecastDataset(Dataset):
    """Sliding-window dataset over the CSV files found below ``data_dir``.

    Every ``.csv`` file under ``data_dir`` (directories named
    ``Economic_Data`` are skipped) is treated as one time series. A sample is
    ``window_size`` consecutive rows, ordered by the ``timestamp`` column,
    paired with the four forecast labels of the last row in that window.

    Args:
        data_dir: Root directory that is walked recursively for CSV files.
        window_size: Number of consecutive rows per sample; must be >= 1.
        transform: Optional callable applied to the feature tensor only.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """

    # Columns read from every CSV file.
    FEATURE_COLUMNS = ['1. open', '2. high', '3. low', '4. close', '5. adjusted close', '6. volume']
    LABEL_COLUMNS = ['Forecast 1 Week', 'Forecast 2 Week', 'Forecast 3 Week', 'Forecast 4 Week']

    def __init__(self, data_dir, window_size=90, transform=None):
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")

        self.data_dir = data_dir
        self.window_size = window_size
        self.transform = transform
        self.sector_files = []

        # Collect all CSV file paths in subdirectories, except "Economic_Data"
        for root, dirs, files in os.walk(data_dir):
            # Removing the entry from ``dirs`` in place stops os.walk from
            # descending into that directory.
            if 'Economic_Data' in dirs:
                dirs.remove('Economic_Data')
            for file in files:
                if file.endswith(".csv"):
                    self.sector_files.append(os.path.join(root, file))

        # os.walk yields entries in an arbitrary, filesystem-dependent order;
        # sorting makes the index -> sample mapping reproducible.
        self.sector_files.sort()

        # Store the total length (number of sliding windows across all files)
        self.total_windows = self._calculate_total_windows()

    def _count_windows(self, csv_path):
        """Return the number of windows contributed by ``csv_path`` (never negative)."""
        # NOTE(review): a file with N rows admits N - window_size + 1 full
        # windows; the original count of N - window_size is kept so that
        # len(dataset) is unchanged. Confirm whether skipping the last window
        # is intentional.
        return max(0, len(pd.read_csv(csv_path)) - self.window_size)

    def _calculate_total_windows(self):
        """Read each file once, record cumulative window counts, return the total.

        ``self._cumulative_windows[i]`` holds the number of windows contained
        in ``self.sector_files[:i + 1]``; ``__getitem__`` uses it to locate a
        sample without re-reading every file.
        """
        self._cumulative_windows = []
        running_total = 0
        for csv_path in self.sector_files:
            running_total += self._count_windows(csv_path)
            self._cumulative_windows.append(running_total)
        return running_total

    def _get_data_from_file(self, csv_path, start_idx):
        """Load ``csv_path`` and return ``(window_features, window_labels)``.

        Args:
            csv_path: Path of the CSV file to read.
            start_idx: Row offset (after sorting by timestamp) of the window start.

        Returns:
            A ``(window_size, 6)`` array of feature rows and a length-4 array
            holding the labels of the window's last row.
        """
        df = pd.read_csv(csv_path)
        df['timestamp'] = pd.to_datetime(df['timestamp'])
        # Windows must be chronological regardless of the on-disk row order.
        df = df.sort_values('timestamp').reset_index(drop=True)

        # Extract the six price/volume features and the four forecast labels
        features = df[self.FEATURE_COLUMNS].values
        labels = df[self.LABEL_COLUMNS].values

        window_features = features[start_idx:start_idx + self.window_size]
        # Labels for the most recent day of the window
        window_labels = labels[start_idx + self.window_size - 1]

        return window_features, window_labels

    def __len__(self):
        """Return the total number of windows across all files."""
        return self.total_windows

    def __getitem__(self, idx):
        """Return ``(features, labels)`` tensors for the global window ``idx``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``. This
                also lets plain ``for sample in dataset`` iteration terminate.
        """
        if idx < 0:
            idx += self.total_windows
        if not 0 <= idx < self.total_windows:
            raise IndexError(
                f"index {idx} is out of range for dataset with {self.total_windows} windows"
            )

        # First file whose cumulative count exceeds idx holds the window;
        # files contributing zero windows are skipped automatically.
        file_pos = bisect.bisect_right(self._cumulative_windows, idx)
        windows_before = self._cumulative_windows[file_pos - 1] if file_pos else 0
        features, labels = self._get_data_from_file(
            self.sector_files[file_pos], idx - windows_before
        )

        # Convert to torch tensors
        features = torch.tensor(features, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.long)  # Labels are treated as categorical ids

        if self.transform:
            features = self.transform(features)

        return features, labels



In [18]:
# Example usage of DataLoader
# The data location can be overridden with the STOCK_DATA_DIR environment
# variable so the notebook is not tied to one machine; the fallback is the
# path this notebook was originally run with.
data_dir = os.environ.get(
    'STOCK_DATA_DIR',
    '/Users/danielcaraballo/Desktop/TimeSeriesProject/data/raw_data',
)
window_size = 90  # rows per sliding window
batch_size = 4    # windows per batch

# Instantiate dataset
dataset = StockForecastDataset(data_dir=data_dir, window_size=window_size)

# Create DataLoader (shuffled, loading with 4 worker processes)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)




In [19]:
# Example iteration through DataLoader: fetch one batch and report its shapes
for batch_features, batch_labels in dataloader:
    print("Batch features shape:", batch_features.shape)  # Should be [batch_size, 90, 6] (window_size rows x 6 feature columns)
    print("Batch labels shape:", batch_labels.shape)      # Should be [batch_size, 4] (the four 'Forecast N Week' columns)
    break  # one batch is enough for this sanity check

Batch features shape: torch.Size([4, 90, 6])
Batch labels shape: torch.Size([4, 4])


In [20]:
# Display the label tensor of the batch captured by the loop above
batch_labels

tensor([[1, 0, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 0],
        [1, 2, 2, 2]])

In [21]:
# Display the feature rows of the first sample in the captured batch
batch_features[0]

tensor([[2.0690e+01, 2.0940e+01, 2.0115e+01, 2.0120e+01, 1.6921e+01, 2.9903e+06],
        [1.9940e+01, 2.0000e+01, 1.9440e+01, 1.9770e+01, 1.6627e+01, 2.1769e+06],
        [1.9990e+01, 2.0395e+01, 1.9880e+01, 2.0230e+01, 1.7014e+01, 1.9941e+06],
        [2.0380e+01, 2.0490e+01, 2.0040e+01, 2.0410e+01, 1.7165e+01, 1.8599e+06],
        [2.0470e+01, 2.0555e+01, 2.0160e+01, 2.0510e+01, 1.7249e+01, 2.2881e+06],
        [2.0340e+01, 2.0900e+01, 2.0000e+01, 2.0390e+01, 1.7149e+01, 2.2722e+06],
        [2.0990e+01, 2.1475e+01, 2.0950e+01, 2.1230e+01, 1.7855e+01, 3.4474e+06],
        [2.1230e+01, 2.1587e+01, 2.1060e+01, 2.1400e+01, 1.7998e+01, 2.5400e+06],
        [2.1250e+01, 2.1380e+01, 2.1000e+01, 2.1060e+01, 1.7712e+01, 2.0513e+06],
        [2.1380e+01, 2.1500e+01, 2.0895e+01, 2.1160e+01, 1.7796e+01, 2.0281e+06],
        [2.1000e+01, 2.1080e+01, 2.0600e+01, 2.0610e+01, 1.7334e+01, 2.2606e+06],
        [2.0360e+01, 2.0370e+01, 2.0011e+01, 2.0180e+01, 1.6972e+01, 1.7812e+06],
        [2.0440e