In [57]:
from itertools import islice
from matplotlib import pyplot as plt
import matplotlib.dates as mdates
import torch
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.dataset.pandas import PandasDataset
import pandas as pd
import numpy as np
import sys, os

PROJ_DIR = "E:\Study\Lab\TimeSeriesAD\lag-llama-tuning\lag-llama"
sys.path.append(os.path.join(PROJ_DIR))
from lag_llama.gluon.estimator import LagLlamaEstimator

In [58]:
path = "./dataset/swat/SWaT_test.csv"
df_origin = pd.read_csv(
    path, usecols=[' Timestamp', 'LIT101', 'label'], index_col=0, parse_dates=True, dayfirst=True
)

In [59]:
for col in ['LIT101']:
    # 检查列的类型不是字符串类型
    if df_origin[col].dtype != 'object' and pd.api.types.is_string_dtype(df_origin[col]) == False:
        df_origin[col] = df_origin[col].astype('float32')

In [65]:
def window_splitter(input_data, lag_window, predict_window):
    # 评估时滑动步长为predict_window
    step = predict_window
    output_seq = []
    L = len(input_data)
    for i in range((L - lag_window) // step):
        use_seq = input_data[i*step : i*step + lag_window]
        use_seq.loc[:, ('item_id')] = i
        output_seq.append(use_seq)
    return output_seq

In [67]:
max_context_length_of_laglallma = 1092
predict_length = 6
context_length = predict_length*3
df_list = window_splitter(df_origin, max_context_length_of_laglallma, predict_length)

In [68]:
df_list

[                         LIT101  label  item_id
  Timestamp                                     
 2015-12-28 10:00:00  522.846680      0        0
 2015-12-28 10:00:01  522.885986      0        0
 2015-12-28 10:00:02  522.846680      0        0
 2015-12-28 10:00:03  522.964478      0        0
 2015-12-28 10:00:04  523.474792      0        0
 ...                         ...    ...      ...
 2015-12-28 10:18:07  696.461670      0        0
 2015-12-28 10:18:08  697.089722      0        0
 2015-12-28 10:18:09  698.031799      0        0
 2015-12-28 10:18:10  698.385071      0        0
 2015-12-28 10:18:11  698.620605      0        0
 
 [1092 rows x 3 columns],
                          LIT101  label  item_id
  Timestamp                                     
 2015-12-28 10:00:06  524.102783      0        1
 2015-12-28 10:00:07  524.220581      0        1
 2015-12-28 10:00:08  524.495422      0        1
 2015-12-28 10:00:09  524.063599      0        1
 2015-12-28 10:00:10  524.102783      0  

In [69]:
print(len(df_list))
print(df_list[-1])

74804
                         LIT101  label  item_id
 Timestamp                                     
2016-01-02 14:41:39  504.712006      0    74803
2016-01-02 14:41:40  505.065186      0    74803
2016-01-02 14:41:41  504.672699      0    74803
2016-01-02 14:41:42  504.790497      0    74803
2016-01-02 14:41:43  504.437195      0    74803
...                         ...    ...      ...
2016-01-02 14:59:46  516.644775      0    74803
2016-01-02 14:59:47  516.487793      0    74803
2016-01-02 14:59:48  516.370117      0    74803
2016-01-02 14:59:49  516.566284      0    74803
2016-01-02 14:59:50  516.801819      0    74803

[1092 rows x 3 columns]


In [70]:
def get_lag_llama_predictions(dataset, prediction_length, num_samples=100):
    ckpt = torch.load(
        "../content/lag-llama.ckpt", map_location=torch.device("cuda:0")
    )  # Uses GPU since in this Colab we use a GPU.
    estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

    estimator = LagLlamaEstimator(
        ckpt_path="../content/lag-llama.ckpt",
        prediction_length=prediction_length,
        context_length=32,  # Should not be changed; this is what the released Lag-Llama model was trained with
        # estimator args
        input_size=estimator_args["input_size"],
        n_layer=estimator_args["n_layer"],
        n_embd_per_head=estimator_args["n_embd_per_head"],
        n_head=estimator_args["n_head"],
        scaling=estimator_args["scaling"],
        time_feat=estimator_args["time_feat"],
        batch_size=1,
        num_parallel_samples=100,
    )

    lightning_module = estimator.create_lightning_module()
    transformation = estimator.create_transformation()
    predictor = estimator.create_predictor(transformation, lightning_module)

    forecast_it, ts_it = make_evaluation_predictions(
        dataset=dataset, predictor=predictor, num_samples=num_samples
    )
    forecasts = list(forecast_it)
    tss = list(ts_it)
    return forecasts, tss

In [85]:
long_seq = pd.concat(df_list)

In [86]:
print(type(long_seq))
print(long_seq)

<class 'pandas.core.frame.DataFrame'>
                         LIT101  label  item_id
 Timestamp                                     
2015-12-28 10:00:00  522.846680      0        0
2015-12-28 10:00:01  522.885986      0        0
2015-12-28 10:00:02  522.846680      0        0
2015-12-28 10:00:03  522.964478      0        0
2015-12-28 10:00:04  523.474792      0        0
...                         ...    ...      ...
2016-01-02 14:59:46  516.644775      0    74803
2016-01-02 14:59:47  516.487793      0    74803
2016-01-02 14:59:48  516.370117      0    74803
2016-01-02 14:59:49  516.566284      0    74803
2016-01-02 14:59:50  516.801819      0    74803

[81685968 rows x 3 columns]


In [90]:
dataset = PandasDataset.from_long_dataframe(long_seq, target="LIT101", item_id="item_id")

In [91]:
num_samples = 100
forecasts, tss = get_lag_llama_predictions(dataset, predict_length, num_samples)

In [None]:
# for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)),len(forecasts)):
#     forecast.plot(color='g', name="predict")  # 绘制预测结果
#     plt.xticks(rotation=60)  # x 轴标签旋转 60 度
#     plt.gcf().autofmt_xdate()
#     plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))