<a href="https://colab.research.google.com/github/MengOonLee/Deep_learning/blob/master/PyTorch/Transformer/Tabular/TimeSeriesForecasting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np

T = 1000
num_entities = 3
num_features = 2

data_dict = {}
for eid in range(num_entities):
    t = np.linspace(0, 100, T)
    f1 = np.sin(t / 5 + eid) + 0.1 * np.random.randn(T)
    f2 = np.cos(t / 7 + eid) + 0.1 * np.random.randn(T)
    data_dict[eid] = np.stack([f1, f2], axis=1)

data_dict

{0: array([[ 0.18338777,  1.06641224],
        [ 0.0691479 ,  0.80338625],
        [ 0.04630641,  1.13123291],
        ...,
        [ 0.77620789, -0.10298175],
        [ 0.89491539, -0.22480243],
        [ 0.85298878, -0.03042633]]),
 1: array([[ 0.8455896 ,  0.47465939],
        [ 0.7771595 ,  0.42744481],
        [ 0.84012823,  0.66833162],
        ...,
        [ 0.86921874, -0.94184081],
        [ 0.87139039, -0.81069245],
        [ 0.99900522, -0.82725451]]),
 2: array([[ 1.05307743, -0.38169608],
        [ 0.9158915 , -0.36339498],
        [ 0.99801587, -0.20528155],
        ...,
        [-0.0245328 , -0.85650579],
        [ 0.04056166, -0.92190291],
        [-0.05096422, -0.7117244 ]])}

In [2]:
import torch

class TimeSeriesDataset(torch.utils.data.Dataset):
    def __init__(self, data_dict, input_window, output_window):
        self.series = []
        self.entity_ids = []
        self.input_window = input_window
        self.output_window = output_window

        for eid, data in data_dict.items():
            for i in range(len(data) - input_window - output_window):
                x = data[i : i + input_window]
                y = data[i + input_window : i + input_window + output_window]
                self.series.append((x, y))
                self.entity_ids.append(eid)

        self.num_entities = len(data_dict)

    def __len__(self):
        return len(self.series)

    def __getitem__(self, idx):
        x, y = self.series[idx]
        entity_id = self.entity_ids[idx]
        return (
            torch.tensor(data=x, dtype=torch.float32),
            torch.tensor(data=y, dtype=torch.float32),
            torch.tensor(data=entity_id, dtype=torch.long)
        )

input_window = 48
output_window = 24
ds_train = TimeSeriesDataset(data_dict=data_dict,
    input_window=input_window, output_window=output_window)
dl_train = torch.utils.data.DataLoader(dataset=ds_train,
    batch_size=32, shuffle=True)
len(dl_train.dataset)

2784

In [10]:
import torch

class PositionalEncoding(torch.nn.Module):
    def __init__(self):
        super().__init__()

class TimeSeriesForecast(torch.nn.Module):
    def __init__(self, num_entities, feature_size, entity_emb_dim=8, d_model=64,
            nhead=8, num_layers=3, dropout=0.1):
        super().__init__()
        self.entity_embedding = torch.nn.Embedding(num_entities, entity_emb_dim)
        self.input_proj = torch.nn.Linear(feature_size + entity_emb_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)

    def forward(self, src, entity_id):
        B, T, F = src.shape
        entity_emb = self.entity_embedding(entity_id)
        entity_emb_expanded = entity_emb.unsqueeze(1).expand(-1, T, -1)
        src = torch.cat([src, entity_emb_expanded], dim=-1)

        src = self.input_proj(src)
        src = self.pos_encoder(src)

        return src

model = TimeSeriesForecast(num_entities=num_entities, feature_size=num_features)

for x, y, eid in dl_train:
    preds = model(x, eid)
    print(preds.shape)

torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
torch.Size([32, 48, 64])
