In [19]:
import torch
from torch.utils.data import Dataset
import pandas as pd
import pytz
import numpy as np
import os
import sys
import glob

scripts_dir = os.path.abspath(os.path.join(os.getcwd(), '..', 'scripts'))
sys.path.append(scripts_dir)
from data_generator import normalize_new_data

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
torch.set_default_device(device)

cuda


In [21]:
# Define the directory and base filename pattern
file_pattern = "../data/final_data/cleaned_compiled_data_normalized_part*.csv"

# Use glob to get a sorted list of all matching CSV parts
csv_files = sorted(glob.glob(file_pattern))

# Load and concatenate all parts
data = pd.concat((pd.read_csv(f) for f in csv_files), ignore_index=True)

# Done! Now `data` holds the full combined normalized DataFrame
print(f"Loaded {len(csv_files)} files. Final shape: {data.shape}")

Loaded 2 files. Final shape: (221369, 42)


In [22]:
class LightningDataset(Dataset):
    def __init__(self, compiled_df, timezone_str="Asia/Singapore"):
        self.compiled_df = compiled_df.copy()
        self.timezone = pytz.timezone(timezone_str)
        self.samples = []

        self._prepare_dataset()

    def _prepare_dataset(self):
        # Ensure datetime index
        self.compiled_df["Timestamp"] = pd.to_datetime(self.compiled_df["Timestamp"])
        if not isinstance(self.compiled_df.index, pd.DatetimeIndex):
            self.compiled_df.set_index("Timestamp", inplace=True)
        self.compiled_df.index = self.compiled_df.index.tz_localize(None)

        # Prepare input features and drop target
        input_df = self.compiled_df.drop(columns=["Lightning_Risk"])
        input_columns = input_df.columns.values.tolist()

        # Get valid timestamps
        min_ts = self.compiled_df.index.min().ceil("2h") + pd.Timedelta(hours=2)
        max_ts = self.compiled_df.index.max().floor("2h")
        valid_ts = self.compiled_df.loc[
            (self.compiled_df.index >= min_ts) &
            (self.compiled_df.index <= max_ts) &
            (self.compiled_df.index.hour % 2 == 0) &
            (self.compiled_df.index.minute == 0)
        ].index

        for timestamp in valid_ts:
            try:
                # Input time windows (past)
                input_times = [timestamp - pd.Timedelta(minutes=delta) for delta in [120, 90, 60, 30, 0]]
                input_data = input_df.loc[input_times].values.flatten()

                # Output time windows (future)
                output_times = [timestamp + pd.Timedelta(minutes=delta) for delta in [0, 30, 60, 90, 120]]
                output_data = self.compiled_df.loc[output_times, "Lightning_Risk"].astype(int).values.flatten()

                self.samples.append((input_data, output_data))
            except KeyError:
                continue  # Skip if any timestamps are missing

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

In [23]:
dataset = LightningDataset(data)

In [24]:
len(dataset)

8398

In [25]:
dataset[0]

(tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8630,
         0.5796, 0.8315, 0.7320, 0.7548, 0.8262, 0.8262, 0.8208, 0.5448, 0.6825,
         0.4646, 0.5755, 0.5682, 0.5555, 0.5392, 0.4169, 0.1023, 0.6465, 0.1563,
         0.4115, 0.5093, 0.1984, 0.3096, 0.1798, 0.1312, 0.4850, 0.5158, 0.9292,
         0.6648, 0.3130, 0.8303, 0.6496, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.8784, 0.5796, 0.8101, 0.7548, 0.7088, 0.8473,
         0.8630, 0.7263, 0.5306, 0.6855, 0.5055, 0.5088, 0.5981, 0.5893, 0.5276,
         0.5216, 0.0954, 0.6576, 0.2164, 0.3348, 0.5329, 0.2008, 0.2839, 0.3110,
         0.0000, 0.4907, 0.5130, 0.5703, 0.6227, 0.4072, 0.8544, 0.5703, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.8155, 0.5796,
         0.8421, 0.7992, 0.7263, 0.8315, 0.8155, 0.7435, 0.5632, 0.6870, 0.5121,
         0.4604, 0.6065, 0.5682, 0.5448, 0.5420, 0.1674, 0.6391, 0.1629, 0.2839,
         0.4812, 0.2361, 0.3