In [None]:
import torch
from torch.utils.data import Dataset
import pandas as pd
import pytz
import numpy as np
import os
import sys
scripts_dir = os.path.abspath(os.path.join(os.getcwd(), '..', 'scripts'))
sys.path.append(scripts_dir)
from data_generator import normalize_new_data


In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
torch.set_default_device(device)

cuda


In [11]:
data = pd.read_csv("../data/final_data/cleaned_compiled_data.csv")

In [12]:
class LightningDataset(Dataset):
    def __init__(self, compiled_df, timezone_str="Asia/Singapore"):
        self.compiled_df = compiled_df.copy()
        self.timezone = pytz.timezone(timezone_str)
        self.samples = []

        self._prepare_dataset()

    def _prepare_dataset(self):
        # Ensure datetime index
        self.compiled_df["Timestamp"] = pd.to_datetime(self.compiled_df["Timestamp"])
        if not isinstance(self.compiled_df.index, pd.DatetimeIndex):
            self.compiled_df.set_index("Timestamp", inplace=True)
        self.compiled_df.index = self.compiled_df.index.tz_localize(None)

        # Prepare input features and drop target
        input_df = self.compiled_df.drop(columns=["Lightning_Risk"])
        input_columns = input_df.columns.values.tolist()

        # Get valid timestamps
        min_ts = self.compiled_df.index.min().ceil("2h") + pd.Timedelta(hours=2)
        max_ts = self.compiled_df.index.max().floor("2h")
        valid_ts = self.compiled_df.loc[
            (self.compiled_df.index >= min_ts) &
            (self.compiled_df.index <= max_ts) &
            (self.compiled_df.index.hour % 2 == 0) &
            (self.compiled_df.index.minute == 0)
        ].index

        for timestamp in valid_ts:
            try:
                # Input time windows (past)
                input_times = [timestamp - pd.Timedelta(minutes=delta) for delta in [120, 90, 60, 30, 0]]
                input_data = input_df.loc[input_times].values.flatten()

                # Output time windows (future)
                output_times = [timestamp + pd.Timedelta(minutes=delta) for delta in [0, 30, 60, 90, 120]]
                output_data = self.compiled_df.loc[output_times, "Lightning_Risk"].astype(int).values.flatten()

                self.samples.append((input_data, output_data))
            except KeyError:
                continue  # Skip if any timestamps are missing

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

In [13]:
dataset = LightningDataset(data)

In [14]:
len(dataset)

8398

In [15]:
dataset[0]

(tensor([9.6360e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 3.3200e+01, 2.8300e+01, 3.2600e+01,
         3.0800e+01, 3.1200e+01, 3.2500e+01, 3.2500e+01, 3.2400e+01, 5.5000e+00,
         1.2400e+01, 3.1000e+00, 6.7000e+00, 6.4000e+00, 5.9000e+00, 5.3000e+00,
         2.1000e+00, 4.4200e+01, 8.2700e+01, 4.9400e+01, 6.8800e+01, 7.4900e+01,
         5.3100e+01, 6.1800e+01, 5.1500e+01, 1.1000e+01, 9.3000e+01, 1.0400e+02,
         3.1200e+02, 1.6600e+02, 4.3000e+01, 2.5200e+02, 1.5900e+02, 9.6366e+04,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 3.3500e+01, 2.8300e+01, 3.2200e+01, 3.1200e+01,
         3.0400e+01, 3.2900e+01, 3.3200e+01, 3.0700e+01, 5.0000e+00, 1.2600e+01,
         4.2000e+00, 4.3000e+00, 7.7000e+00, 7.3000e+00, 4.9000e+00, 4.7000e+00,
         4.3500e+01, 8.3300e+01, 5.4600e+01, 6.3600e+01, 7.6300e+01, 5.3300e+01,
         5.9900e+01, 6.1900e