In [1]:
import os

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import tensorflow as tf


import torch
import torch.utils.data as data
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.utils.data import DataLoader


import torchvision.transforms as transforms
from torch import nn, optim
from tqdm import tqdm
from torch.utils.data import random_split
import torchmetrics
from torchmetrics import Metric


In [2]:
DATA_DIR = 'gs://time_series_datasets'
LOCAL_CACHE_DIR = './data_loader/dataset/'

In [3]:
data="weather"

In [4]:
feature_type='MS'
target='OT'

In [5]:
target_slice=slice(0, None)

In [6]:
seq_len=int(336)
pred_len=int(96)

In [7]:
if not os.path.isdir(LOCAL_CACHE_DIR):
    os.mkdir(LOCAL_CACHE_DIR)
file_name = data + '.csv'
cache_filepath = os.path.join(LOCAL_CACHE_DIR, file_name)
if not os.path.isfile(cache_filepath):
    tf.io.gfile.copy(
  os.path.join(DATA_DIR, file_name), cache_filepath, overwrite=True
  )
    # Download the data from the cloud storage
    # (Implement cloud storage download here)
df_raw = pd.read_csv(cache_filepath)
df = df_raw.set_index('date')

In [8]:
df.shape

(52696, 21)

In [9]:
if feature_type == 'S':
    df = df[[target]]
elif feature_type == 'MS':
    target_idx = df.columns.get_loc(target)
    target_slice = slice(target_idx, target_idx + 1)
# split train/valid/test
n = len(df)

In [10]:
target_slice 

slice(20, 21, None)

In [11]:
if data.startswith('ETTm'):
    train_end = 12 * 30 * 24 * 4
    val_end = train_end + 4 * 30 * 24 * 4
    test_end = val_end + 4 * 30 * 24 * 4
elif data.startswith('ETTh'):
     train_end = 12 * 30 * 24
     val_end = train_end + 4 * 30 * 24
     test_end = val_end + 4 * 30 * 24
else:
    train_end = int(n * 0.7)
    val_end = n - int(n * 0.2)
    test_end = n

In [12]:
train_end, val_end , test_end 

(36887, 42157, 52696)

In [13]:
train_df = df[:train_end]
val_df = df[train_end - seq_len : val_end]
test_df = df[val_end - seq_len : test_end]

In [14]:
train_df.head()

Unnamed: 0_level_0,p (mbar),T (degC),Tpot (K),Tdew (degC),rh (%),VPmax (mbar),VPact (mbar),VPdef (mbar),sh (g/kg),H2OC (mmol/mol),...,wv (m/s),max. wv (m/s),wd (deg),rain (mm),raining (s),SWDR (W/m�),PAR (�mol/m�/s),max. PAR (�mol/m�/s),Tlog (degC),OT
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2020-01-01 00:10:00,1008.89,0.71,273.18,-1.33,86.1,6.43,5.54,0.89,3.42,5.49,...,1.02,1.6,224.3,0.0,0.0,0.0,0.0,0.0,11.45,428.1
2020-01-01 00:20:00,1008.76,0.75,273.22,-1.44,85.2,6.45,5.49,0.95,3.39,5.45,...,0.43,0.84,206.8,0.0,0.0,0.0,0.0,0.0,11.51,428.0
2020-01-01 00:30:00,1008.66,0.73,273.21,-1.48,85.1,6.44,5.48,0.96,3.39,5.43,...,0.61,1.48,197.1,0.0,0.0,0.0,0.0,0.0,11.6,427.6
2020-01-01 00:40:00,1008.64,0.37,272.86,-1.64,86.3,6.27,5.41,0.86,3.35,5.37,...,1.11,1.48,206.4,0.0,0.0,0.0,0.0,0.0,11.7,430.0
2020-01-01 00:50:00,1008.61,0.33,272.82,-1.5,87.4,6.26,5.47,0.79,3.38,5.42,...,0.49,1.4,209.6,0.0,0.0,0.0,0.0,0.0,11.81,432.2


In [15]:
def extract_contiguous_rows(df):
    # Ensure the number of rows is a multiple of 10
    num_rows = len(df)
    num_chunks = num_rows // 10
    total_rows = num_chunks * 10
    
    # Extract contiguous rows and reshape them into (None, 10, 10) format
    contiguous_rows = df.iloc[:total_rows].values.reshape(-1, 10, 10)
    return contiguous_rows

In [17]:
# Example usage:
# Creating a sample DataFrame with 25 rows and 10 columns
dat = np.random.rand(25, 10)
df = pd.DataFrame(dat)

(25, 10)

In [18]:
# Extracting contiguous rows using the function
contiguous_rows = extract_contiguous_rows(df)

# Printing the extracted rows
print("Extracted Contiguous Rows:")
print(contiguous_rows)

Extracted Contiguous Rows:
[[[0.60109751 0.57316002 0.82235306 0.90830239 0.53543699 0.42309066
   0.12857248 0.14243364 0.15345652 0.28752917]
  [0.34046589 0.44083654 0.45938235 0.82803973 0.53560273 0.08630841
   0.82641381 0.7381179  0.78836775 0.61092037]
  [0.53632828 0.85833998 0.98647752 0.89447627 0.90133282 0.26340022
   0.52546261 0.08881456 0.0936704  0.15455947]
  [0.00100327 0.48139877 0.55410048 0.42744961 0.82685272 0.10608306
   0.43728202 0.8961071  0.72276642 0.94977733]
  [0.28487528 0.05914871 0.70827748 0.25527995 0.57306995 0.25726398
   0.5705552  0.70136006 0.84599999 0.85001693]
  [0.01057907 0.40862433 0.8177182  0.03474773 0.08423594 0.85216681
   0.13997663 0.45938804 0.59493748 0.00928901]
  [0.25555435 0.06962285 0.66310731 0.90705797 0.99488619 0.8787966
   0.4273659  0.0096768  0.36168258 0.84363766]
  [0.67014206 0.31630909 0.20485813 0.55347073 0.96628385 0.95822507
   0.78587996 0.43900638 0.16688151 0.10912353]
  [0.33743262 0.32633121 0.29078137 0.

In [19]:
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,0.601098,0.57316,0.822353,0.908302,0.535437,0.423091,0.128572,0.142434,0.153457,0.287529
1,0.340466,0.440837,0.459382,0.82804,0.535603,0.086308,0.826414,0.738118,0.788368,0.61092
2,0.536328,0.85834,0.986478,0.894476,0.901333,0.2634,0.525463,0.088815,0.09367,0.154559
3,0.001003,0.481399,0.5541,0.42745,0.826853,0.106083,0.437282,0.896107,0.722766,0.949777
4,0.284875,0.059149,0.708277,0.25528,0.57307,0.257264,0.570555,0.70136,0.846,0.850017
5,0.010579,0.408624,0.817718,0.034748,0.084236,0.852167,0.139977,0.459388,0.594937,0.009289
6,0.255554,0.069623,0.663107,0.907058,0.994886,0.878797,0.427366,0.009677,0.361683,0.843638
7,0.670142,0.316309,0.204858,0.553471,0.966284,0.958225,0.78588,0.439006,0.166882,0.109124
8,0.337433,0.326331,0.290781,0.929663,0.781281,0.371212,0.231701,0.637805,0.116704,0.040279
9,0.333112,0.938909,0.432541,0.434792,0.728332,0.925696,0.546428,0.435906,0.828084,0.753313


In [20]:
def extract_contiguous_rows_with_stride(df, row_length=10):
    num_rows = len(df)
    num_chunks = num_rows - row_length + 1
    
    contiguous_rows = [df.iloc[i:i+row_length].values for i in range(num_chunks)]
    return np.array(contiguous_rows)

In [21]:
# Example usage:
# Creating a sample DataFrame with 25 rows and 10 columns
data = np.random.rand(25, 10)
df = pd.DataFrame(data)

In [22]:
# Extracting contiguous rows with a sequence stride of 1
contiguous_rows = extract_contiguous_rows_with_stride(df)

# Printing the extracted contiguous rows
print("Extracted Contiguous Rows with Stride 1:")
print(contiguous_rows)

Extracted Contiguous Rows with Stride 1:
[[[0.62017669 0.06605796 0.20121442 ... 0.25180125 0.10957105 0.24447385]
  [0.96186528 0.91577514 0.84535325 ... 0.34581923 0.84497572 0.46958808]
  [0.2156882  0.24649177 0.94341997 ... 0.80063642 0.1625548  0.33401217]
  ...
  [0.79262017 0.54227225 0.03531824 ... 0.88310964 0.41052776 0.97786787]
  [0.72288606 0.26022681 0.35429503 ... 0.76484748 0.21470965 0.657434  ]
  [0.40928304 0.33846398 0.06130671 ... 0.65260465 0.86120869 0.56838362]]

 [[0.96186528 0.91577514 0.84535325 ... 0.34581923 0.84497572 0.46958808]
  [0.2156882  0.24649177 0.94341997 ... 0.80063642 0.1625548  0.33401217]
  [0.64760649 0.5440651  0.4057327  ... 0.38691962 0.05159714 0.25262573]
  ...
  [0.72288606 0.26022681 0.35429503 ... 0.76484748 0.21470965 0.657434  ]
  [0.40928304 0.33846398 0.06130671 ... 0.65260465 0.86120869 0.56838362]
  [0.97950767 0.40087974 0.55409259 ... 0.65803212 0.82789983 0.84509592]]

 [[0.2156882  0.24649177 0.94341997 ... 0.80063642 0.16

In [27]:
def extract_contiguous_rows_with_stride(df, row_length=10, tail_length=4):
    num_rows = len(df)
    num_chunks = num_rows - row_length + 1
    
    contiguous_rows = []
    last_four_rows = []
    
    for i in range(num_chunks):
        chunk = df.iloc[i:i+row_length].values
        contiguous_rows.append(chunk[:row_length-tail_length])
        last_four_rows.append(chunk[-tail_length:])
    
    return np.array(contiguous_rows), np.array(last_four_rows)

In [28]:
# Example usage:
# Creating a sample DataFrame with 25 rows and 10 columns
data = np.random.rand(30, 10)
df = pd.DataFrame(data)

# Extracting contiguous rows with a sequence stride of 1 and the last 4 rows of each stride
contiguous_rows, last_four_rows = extract_contiguous_rows_with_stride(df)

# Printing the extracted contiguous rows and the last four rows of each stride
print("Extracted Contiguous Rows:")
print(contiguous_rows)
print("Last Four Rows of Each Stride:")
print(last_four_rows)

Extracted Contiguous Rows:
[[[0.11077741 0.8569821  0.75367966 ... 0.69978501 0.32801627 0.16326507]
  [0.65543091 0.95245975 0.70772906 ... 0.49688678 0.24574771 0.89238147]
  [0.58306365 0.42720285 0.42511236 ... 0.71538214 0.05785101 0.10702529]
  [0.53731147 0.60126131 0.0026757  ... 0.36649318 0.1326927  0.10944562]
  [0.55367853 0.20319867 0.34497508 ... 0.30397808 0.41939615 0.00907801]
  [0.81126477 0.30655458 0.52839141 ... 0.50183096 0.02420151 0.80476532]]

 [[0.65543091 0.95245975 0.70772906 ... 0.49688678 0.24574771 0.89238147]
  [0.58306365 0.42720285 0.42511236 ... 0.71538214 0.05785101 0.10702529]
  [0.53731147 0.60126131 0.0026757  ... 0.36649318 0.1326927  0.10944562]
  [0.55367853 0.20319867 0.34497508 ... 0.30397808 0.41939615 0.00907801]
  [0.81126477 0.30655458 0.52839141 ... 0.50183096 0.02420151 0.80476532]
  [0.75711584 0.98950622 0.10095464 ... 0.89606313 0.82509788 0.35641763]]

 [[0.58306365 0.42720285 0.42511236 ... 0.71538214 0.05785101 0.10702529]
  [0.53

In [29]:
def extract_contiguous_rows_with_stride(df, row_length=10, tail_length=4):
    num_rows = len(df)
    num_chunks = num_rows - row_length + 1
    
    contiguous_rows = []
    last_four_rows = []
    indices = []
    
    for i in range(num_chunks):
        chunk = df.iloc[i:i+row_length].values
        contiguous_rows.append(chunk[:row_length-tail_length])
        last_four_rows.append(chunk[-tail_length:])
        indices.append(i)  # Adding the index
        
    return np.array(contiguous_rows), np.array(last_four_rows), np.array(indices)

# Example usage:
# Creating a sample DataFrame with 25 rows and 10 columns
data = np.random.rand(25, 10)
df = pd.DataFrame(data)

# Extracting contiguous rows with a sequence stride of 1 and the last 4 rows of each stride
contiguous_rows, last_four_rows, indices = extract_contiguous_rows_with_stride(df)

# Printing the extracted contiguous rows, the last four rows of each stride, and the indices
print("Extracted Contiguous Rows:")
print(contiguous_rows)
print("Last Four Rows of Each Stride:")
print(last_four_rows)
print("Indices:")
print(indices)

Extracted Contiguous Rows:
[[[0.3165526  0.45030371 0.02555237 0.55658248 0.52699726 0.15864778
   0.91404655 0.30557164 0.68396498 0.84734853]
  [0.77954928 0.84197442 0.05323632 0.7954458  0.96310801 0.76669572
   0.66645954 0.30261818 0.04746557 0.75145899]
  [0.83274792 0.78805185 0.50035178 0.1367943  0.54623913 0.75208813
   0.10837886 0.74791636 0.86857351 0.58165527]
  [0.79515591 0.88203724 0.20868399 0.10633299 0.99820447 0.66678087
   0.37105407 0.83310598 0.15384178 0.96007038]
  [0.62927062 0.49609499 0.10547657 0.604357   0.03020461 0.10364393
   0.96487511 0.88768433 0.57133854 0.39159458]
  [0.91132914 0.67294207 0.07134428 0.99146044 0.94556458 0.78676524
   0.91041948 0.15139911 0.93587883 0.55702056]]

 [[0.77954928 0.84197442 0.05323632 0.7954458  0.96310801 0.76669572
   0.66645954 0.30261818 0.04746557 0.75145899]
  [0.83274792 0.78805185 0.50035178 0.1367943  0.54623913 0.75208813
   0.10837886 0.74791636 0.86857351 0.58165527]
  [0.79515591 0.88203724 0.20868399

In [32]:
contiguous_rows.shape, last_four_rows.shape, indices.shape

((16, 6, 10), (16, 4, 10), (16,))

In [21]:
class DataExtractor:
    def __init__(self, df, row_length=10, tail_length=4):
        self.data = self.extract_contiguous_rows_with_stride(df, row_length, tail_length)

    def extract_contiguous_rows_with_stride(self, df, row_length=10, tail_length=4):
        num_rows = len(df)                      # tail_length is pred_len
        num_chunks = num_rows - row_length + 1 #500-10 +1 = 491 (25-10 +1 =16) so seq_len is the same as row_length 

        contiguous_rows = []
        last_four_rows = []
        indices = []

        for i in range(num_chunks):
            chunk = df.iloc[i:i+row_length].values
            contiguous_rows.append(chunk[:row_length-tail_length])
            last_four_rows.append(chunk[-tail_length:])
            indices.append(i)  # Adding the index

        data = {
            "contiguous_rows": np.array(contiguous_rows),
            "last_four_rows": np.array(last_four_rows),
            "indices": np.array(indices)
        }

        return data

    def __len__(self):
        return len(self.data["indices"])

    def __getitem__(self, index):
        idx = self.data["indices"][index]
        return {
            "contiguous_rows": self.data["contiguous_rows"][idx],
            "last_four_rows": self.data["last_four_rows"][idx]
        }

# Example usage:
# Creating a sample DataFrame with 25 rows and 10 columns
data = np.random.rand(25, 10)
df = pd.DataFrame(data)

# Creating a DataExtractor instance
data_extractor = DataExtractor(df)

# Printing the length of the data extractor
print("Length of DataExtractor:", len(data_extractor))

# Accessing data using indices
index = 1
sample_data = data_extractor[index]
print(f"Data at index {index}:")
print("Contiguous Rows:")
print(sample_data["contiguous_rows"])
print("Last Four Rows:")
print(sample_data["last_four_rows"])
print("Target Variable")
print(sample_data["last_four_rows"][:,-1])

Length of DataExtractor: 16
Data at index 1:
Contiguous Rows:
[[0.15285431 0.65462181 0.56167772 0.52696453 0.49640244 0.37910133
  0.13925963 0.6570447  0.38317061 0.69102872]
 [0.25321915 0.64109552 0.11114328 0.18960985 0.70781113 0.53701501
  0.48673221 0.38424617 0.12962291 0.88919029]
 [0.20080035 0.15545148 0.36214991 0.16847312 0.51507733 0.46489874
  0.67801585 0.39552835 0.31524421 0.32085704]
 [0.03792975 0.95827748 0.48929388 0.34531968 0.880029   0.4360724
  0.47271623 0.50182544 0.05820181 0.88169814]
 [0.94487076 0.0322945  0.882504   0.4397508  0.34112055 0.0283295
  0.16091922 0.22940213 0.34205537 0.82598114]
 [0.69294692 0.56992619 0.08944812 0.55355615 0.5176281  0.65459783
  0.91250194 0.41022413 0.23006759 0.31659683]]
Last Four Rows:
[[0.89805834 0.44149838 0.09612947 0.05196111 0.00995504 0.71771621
  0.69550615 0.56586653 0.73921402 0.85830038]
 [0.39328632 0.91392636 0.38466649 0.31992372 0.79605568 0.51798845
  0.19680281 0.26919567 0.06761409 0.16437797]
 [0

In [22]:
# Creating a DataExtractor instance
train_data_extractor = DataExtractor(train_df)

# Printing the length of the data extractor
print("Length of DataExtractor:", len(data_extractor))

Length of DataExtractor: 16


In [None]:
train_df

In [23]:
train_data_loader = train_data_extractor[index]

In [25]:
print(f"Data at index {index}:")
print("Contiguous Rows:")
print(train_data_loader["contiguous_rows"])
print("Last Four Rows:")
print(train_data_loader["last_four_rows"])
print("Target Variable")
print(train_data_loader["last_four_rows"][:,-1])

Data at index 1:
Contiguous Rows:
[[ 1.00876e+03  7.50000e-01  2.73220e+02 -1.44000e+00  8.52000e+01
   6.45000e+00  5.49000e+00  9.50000e-01  3.39000e+00  5.45000e+00
   1.28033e+03  4.30000e-01  8.40000e-01  2.06800e+02  0.00000e+00
   0.00000e+00  0.00000e+00  0.00000e+00  0.00000e+00  1.15100e+01
   4.28000e+02]
 [ 1.00866e+03  7.30000e-01  2.73210e+02 -1.48000e+00  8.51000e+01
   6.44000e+00  5.48000e+00  9.60000e-01  3.39000e+00  5.43000e+00
   1.28029e+03  6.10000e-01  1.48000e+00  1.97100e+02  0.00000e+00
   0.00000e+00  0.00000e+00  0.00000e+00  0.00000e+00  1.16000e+01
   4.27600e+02]
 [ 1.00864e+03  3.70000e-01  2.72860e+02 -1.64000e+00  8.63000e+01
   6.27000e+00  5.41000e+00  8.60000e-01  3.35000e+00  5.37000e+00
   1.28197e+03  1.11000e+00  1.48000e+00  2.06400e+02  0.00000e+00
   0.00000e+00  0.00000e+00  0.00000e+00  0.00000e+00  1.17000e+01
   4.30000e+02]
 [ 1.00861e+03  3.30000e-01  2.72820e+02 -1.50000e+00  8.74000e+01
   6.26000e+00  5.47000e+00  7.90000e-01  3.380

In [28]:
x = np.random.rand(4, 5, 10)
print("Original Input Tensor:")
print(x)

Original Input Tensor:
[[[0.1121732  0.21145365 0.22691804 0.84945216 0.24424286 0.86074256
   0.22154613 0.4374376  0.49588739 0.78989321]
  [0.81563125 0.41585561 0.82290778 0.81856687 0.61974323 0.27029609
   0.68879049 0.23470474 0.37293059 0.58290402]
  [0.8180706  0.6918732  0.14762489 0.4465482  0.01016277 0.39431155
   0.32187705 0.07863415 0.39889036 0.38394825]
  [0.89075422 0.75567664 0.29328732 0.32541461 0.20811468 0.0238194
   0.15650301 0.88512653 0.71146054 0.31503806]
  [0.17679821 0.28008859 0.82933239 0.08072794 0.7444727  0.83763619
   0.96138969 0.80538641 0.7189273  0.6098467 ]]

 [[0.75540113 0.90994093 0.34456729 0.78894882 0.02488765 0.10830681
   0.21183278 0.63196654 0.65349304 0.45881887]
  [0.79534675 0.18982417 0.30494628 0.07188426 0.03524168 0.86382279
   0.7767972  0.67422577 0.87776016 0.14068815]
  [0.91756188 0.11220004 0.78429788 0.90882453 0.73188567 0.1359645
   0.47847985 0.12991273 0.95624889 0.29094846]
  [0.50293377 0.39382575 0.96616551 0.265

In [29]:
# Transpose the input tensor for processing in the model
x_transposed = np.transpose(x, (0, 2, 1))  # [Batch, Input Length, Channel]
print("\nTransposed Input Tensor:")
print(x_transposed)


Transposed Input Tensor:
[[[0.1121732  0.81563125 0.8180706  0.89075422 0.17679821]
  [0.21145365 0.41585561 0.6918732  0.75567664 0.28008859]
  [0.22691804 0.82290778 0.14762489 0.29328732 0.82933239]
  [0.84945216 0.81856687 0.4465482  0.32541461 0.08072794]
  [0.24424286 0.61974323 0.01016277 0.20811468 0.7444727 ]
  [0.86074256 0.27029609 0.39431155 0.0238194  0.83763619]
  [0.22154613 0.68879049 0.32187705 0.15650301 0.96138969]
  [0.4374376  0.23470474 0.07863415 0.88512653 0.80538641]
  [0.49588739 0.37293059 0.39889036 0.71146054 0.7189273 ]
  [0.78989321 0.58290402 0.38394825 0.31503806 0.6098467 ]]

 [[0.75540113 0.79534675 0.91756188 0.50293377 0.95323637]
  [0.90994093 0.18982417 0.11220004 0.39382575 0.88615299]
  [0.34456729 0.30494628 0.78429788 0.96616551 0.94611553]
  [0.78894882 0.07188426 0.90882453 0.26514335 0.21619811]
  [0.02488765 0.03524168 0.73188567 0.22263514 0.48831947]
  [0.10830681 0.86382279 0.1359645  0.74788597 0.81766777]
  [0.21183278 0.7767972  0.4

In [32]:
# Apply target slicing to select channels 1 and 2
target_slice = slice(4, 5)
print(target_slice)

slice(4, 5, None)


In [33]:
sliced_output = x_transposed[:, :, target_slice]
print("\nSliced Output Tensor (Channels 1 and 2):")
print(sliced_output)


Sliced Output Tensor (Channels 1 and 2):
[[[0.17679821]
  [0.28008859]
  [0.82933239]
  [0.08072794]
  [0.7444727 ]
  [0.83763619]
  [0.96138969]
  [0.80538641]
  [0.7189273 ]
  [0.6098467 ]]

 [[0.95323637]
  [0.88615299]
  [0.94611553]
  [0.21619811]
  [0.48831947]
  [0.81766777]
  [0.62219875]
  [0.71422961]
  [0.65889709]
  [0.17232261]]

 [[0.83343803]
  [0.71024409]
  [0.01332391]
  [0.45065694]
  [0.36953463]
  [0.84220447]
  [0.11644779]
  [0.33224332]
  [0.62075559]
  [0.49488454]]

 [[0.41153344]
  [0.1353498 ]
  [0.04096466]
  [0.76332931]
  [0.78057134]
  [0.12977589]
  [0.87585884]
  [0.90540503]
  [0.467968  ]
  [0.87882748]]]


In [39]:
#sliced_output.transpose(1,2)
np.transpose(sliced_output, (0, 1, 2)) 

array([[[0.17679821],
        [0.28008859],
        [0.82933239],
        [0.08072794],
        [0.7444727 ],
        [0.83763619],
        [0.96138969],
        [0.80538641],
        [0.7189273 ],
        [0.6098467 ]],

       [[0.95323637],
        [0.88615299],
        [0.94611553],
        [0.21619811],
        [0.48831947],
        [0.81766777],
        [0.62219875],
        [0.71422961],
        [0.65889709],
        [0.17232261]],

       [[0.83343803],
        [0.71024409],
        [0.01332391],
        [0.45065694],
        [0.36953463],
        [0.84220447],
        [0.11644779],
        [0.33224332],
        [0.62075559],
        [0.49488454]],

       [[0.41153344],
        [0.1353498 ],
        [0.04096466],
        [0.76332931],
        [0.78057134],
        [0.12977589],
        [0.87585884],
        [0.90540503],
        [0.467968  ],
        [0.87882748]]])

In [None]:
def process_sequential(data):
        """
        Pre-process dataset for multiple-step-ahead prediction: explodes dataset to a larger one with rolling origin
        Args:

        """

        assert self.processed

        if not self.processed_sequential:
            logger.info(
                f"Processing {self.subset_name} dataset before training (multiple sequences)"
            )

            outputs = self.data["outputs"]
            sequence_lengths = self.data["sequence_lengths"]
            active_entries = self.data["active_entries"]
            current_treatments = self.data["current_treatments"]
            previous_treatments = self.data["prev_treatments"][
                :, 1:, :
            ]  # Without zero_init_treatment
            current_covariates = self.data["current_covariates"]
            stabilized_weights = (
                self.data["stabilized_weights"]
                if "stabilized_weights" in self.data
                else None
            )

            num_patients, seq_length, num_features = outputs.shape

            num_seq2seq_rows = num_patients * seq_length
            # the encoder in the new model might not be defined: check for the dimension of the encoder_r and the stabilizer weight.  
            seq2seq_state_inits = np.zeros((num_seq2seq_rows, encoder_r.shape[-1])) # the encoder in the new model might not be defined
            seq2seq_active_encoder_r = np.zeros((num_seq2seq_rows, seq_length))
            seq2seq_original_index = np.zeros((num_seq2seq_rows,))
            seq2seq_previous_treatments = np.zeros(
                (num_seq2seq_rows, projection_horizon, previous_treatments.shape[-1])
            )
            seq2seq_current_treatments = np.zeros(
                (num_seq2seq_rows, projection_horizon, current_treatments.shape[-1])
            )
            seq2seq_current_covariates = np.zeros(
                (num_seq2seq_rows, projection_horizon, current_covariates.shape[-1])
            )
            seq2seq_outputs = np.zeros(
                (num_seq2seq_rows, projection_horizon, outputs.shape[-1])
            )
            seq2seq_active_entries = np.zeros(
                (num_seq2seq_rows, projection_horizon, active_entries.shape[-1])
            )
            seq2seq_sequence_lengths = np.zeros(num_seq2seq_rows)
            seq2seq_stabilized_weights = (
                np.zeros((num_seq2seq_rows, projection_horizon + 1))
                if stabilized_weights is not None
                else None
            )

            total_seq2seq_rows = 0  # we use this to shorten any trajectories later

            for i in range(num_patients):
                sequence_length = int(sequence_lengths[i])

                for t in range(
                    1, sequence_length - projection_horizon
                ):  # shift outputs back by 1
                    seq2seq_state_inits[total_seq2seq_rows, :] = encoder_r[
                        i, t - 1, :
                    ]  # previous state output
                    seq2seq_original_index[total_seq2seq_rows] = i
                    seq2seq_active_encoder_r[total_seq2seq_rows, :t] = 1.0

                    max_projection = min(projection_horizon, sequence_length - t)

                    seq2seq_active_entries[
                        total_seq2seq_rows, :max_projection, :
                    ] = active_entries[i, t : t + max_projection, :]
                    seq2seq_previous_treatments[
                        total_seq2seq_rows, :max_projection, :
                    ] = previous_treatments[i, t - 1 : t + max_projection - 1, :]
                    seq2seq_current_treatments[
                        total_seq2seq_rows, :max_projection, :
                    ] = current_treatments[i, t : t + max_projection, :]
                    seq2seq_outputs[total_seq2seq_rows, :max_projection, :] = outputs[
                        i, t : t + max_projection, :
                    ]
                    seq2seq_sequence_lengths[total_seq2seq_rows] = max_projection
                    seq2seq_current_covariates[
                        total_seq2seq_rows, :max_projection, :
                    ] = current_covariates[i, t : t + max_projection, :]

                    if (
                        seq2seq_stabilized_weights is not None
                    ):  # Also including SW of one-step-ahead prediction
                        seq2seq_stabilized_weights[
                            total_seq2seq_rows, :
                        ] = stabilized_weights[i, t - 1 : t + max_projection]

                    total_seq2seq_rows += 1

            # Filter everything shorter
            seq2seq_state_inits = seq2seq_state_inits[:total_seq2seq_rows, :]
            seq2seq_original_index = seq2seq_original_index[:total_seq2seq_rows]
            seq2seq_active_encoder_r = seq2seq_active_encoder_r[:total_seq2seq_rows, :]
            seq2seq_previous_treatments = seq2seq_previous_treatments[
                :total_seq2seq_rows, :, :
            ]
            seq2seq_current_treatments = seq2seq_current_treatments[
                :total_seq2seq_rows, :, :
            ]
            seq2seq_current_covariates = seq2seq_current_covariates[
                :total_seq2seq_rows, :, :
            ]
            seq2seq_outputs = seq2seq_outputs[:total_seq2seq_rows, :, :]
            seq2seq_active_entries = seq2seq_active_entries[:total_seq2seq_rows, :, :]
            seq2seq_sequence_lengths = seq2seq_sequence_lengths[:total_seq2seq_rows]
            if seq2seq_stabilized_weights is not None:
                seq2seq_stabilized_weights = seq2seq_stabilized_weights[
                    :total_seq2seq_rows
                ]

            # Package outputs
            seq2seq_data = {
                "init_state": seq2seq_state_inits,
                "original_index": seq2seq_original_index,
                "active_encoder_r": seq2seq_active_encoder_r,
                "prev_treatments": seq2seq_previous_treatments,
                "current_treatments": seq2seq_current_treatments,
                "current_covariates": seq2seq_current_covariates,
                "prev_outputs": seq2seq_current_covariates[:, :, :1],
                "static_features": seq2seq_current_covariates[:, 0, 1:],
                "outputs": seq2seq_outputs,
                "sequence_lengths": seq2seq_sequence_lengths,
                "active_entries": seq2seq_active_entries,
                "unscaled_outputs": seq2seq_outputs * self.scaling_params["output_stds"]
                + self.scaling_params["output_means"],
            }
            if seq2seq_stabilized_weights is not None:
                seq2seq_data["stabilized_weights"] = seq2seq_stabilized_weights

            self.data_original = deepcopy(self.data)
            self.data = seq2seq_data
            data_shapes = {k: v.shape for k, v in self.data.items()}
            logger.info(f"Shape of processed {self.subset_name} data: {data_shapes}")

            if save_encoder_r:
                self.encoder_r = encoder_r[:, :seq_length, :]

            self.processed_sequential = True
            self.exploded = True

        else:
            logger.info(
                f"{self.subset_name} Dataset already processed (multiple sequences)"
            )

        return self.data