# Task02. 搭建开发环境并运行、理解时序插补工作流

## 1. 开发环境配置

### PyPOTS开发环境支持多种安装方式：你可以自由选择从源码安装，通过pip安装PyPI上的发布版本，或者使用conda安装conda-forge上的发行版；如果你熟悉docker的使用方式，也可以通过docker来获取我们已经为你配置好的PyPOTS开发环境容器。

### 从下方选择一种你熟悉的安装方式来为PyPOTS配置Python开发环境

In [None]:
# Install from source: pip fetches the zip archive of the PyPOTS `main` branch on GitHub.
!pip install https://github.com/WenjieDu/PyPOTS/archive/main.zip

In [None]:
# Install the released `pypots` package from PyPI.
!pip install pypots

In [None]:
# 从conda-forge安装 (‼️请确定你熟悉conda的操作并且确认你的电脑上安装了conda)
!conda install conda-forge::pypots

In [None]:
# Run the docker container that ships a preconfigured PyPOTS development environment
# (‼️ make sure you are familiar with docker and that docker is installed on your machine).
# NOTE(review): `-it` requests an interactive TTY, which a notebook `!` shell usually
# cannot provide — presumably this command is meant to be run from a terminal; confirm.
!docker run -it --name pypots wenjiedu/pypots

## 2. 时间序列插补工作流

## 加载并预处理PhysioNet-2012时间序列数据集

In [None]:
from benchpots.datasets import preprocess_physionet2012

# Load and preprocess the PhysioNet-2012 dataset through BenchPOTS.
# NOTE(review): `pattern="point"` and `rate=0.1` presumably ask for point-wise artificial
# missingness on 10% of the observed values — confirm against the BenchPOTS documentation.
physionet2012_dataset = preprocess_physionet2012(
    subset="set-a", 
    pattern="point", 
    rate=0.1,
)
# Show which entries the preprocessing returned; later cells read "train_X", "val_X",
# "val_X_ori", "test_X", "test_X_ori", "n_steps" and "n_features" from this result.
print(physionet2012_dataset.keys())

In [None]:
import numpy as np

# Evaluation mask: True exactly where one of test_X / test_X_ori is NaN and the other is not.
physionet2012_dataset["test_X_indicating_mask"] = np.logical_xor(
    np.isnan(physionet2012_dataset["test_X"]),
    np.isnan(physionet2012_dataset["test_X_ori"]),
)
# Replace the NaNs in the reference test data (np.nan_to_num maps NaN to 0.0 by default).
physionet2012_dataset["test_X_ori"] = np.nan_to_num(physionet2012_dataset["test_X_ori"])

# Package the three splits as the dicts handed to the model in the following cells.
train_set = dict(X=physionet2012_dataset["train_X"])
val_set = dict(
    X=physionet2012_dataset["val_X"],
    X_ori=physionet2012_dataset["val_X_ori"],
)
test_set = dict(
    X=physionet2012_dataset["test_X"],
    X_ori=physionet2012_dataset["test_X_ori"],
)

In [None]:
# Display the feature count stored in the dataset dict (the same value is passed to SAITS below).
physionet2012_dataset['n_features']

In [None]:
from pypots.imputation import SAITS

# Instantiate the SAITS imputation model. The sequence length and feature count are taken
# from the dataset dict; the remaining hyperparameters are fixed for this tutorial.
# NOTE(review): d_model (64) equals n_heads * d_k (4 * 16) here — presumably intentional;
# confirm against the SAITS documentation before changing any of the three.
saits = SAITS(
    n_steps=physionet2012_dataset['n_steps'],
    n_features=physionet2012_dataset['n_features'],
    n_layers=3,
    d_model=64,
    n_heads=4,
    d_k=16,
    d_v=16,
    d_ffn=128,
    dropout=0.1,
    epochs=10,
)

In [None]:
# Train the model on the training split, passing the validation split alongside it.
saits.fit(train_set, val_set)

In [None]:
# Run the trained model on the test split; later cells read the "imputation" entry of the result.
test_set_imputation_results = saits.predict(test_set)

In [None]:
from pypots.nn.functional import calc_mse

# Score the imputation against the reference test data, passing the indicating mask
# built earlier as the third argument.
# NOTE(review): presumably the mask restricts the error to the artificially removed
# positions — confirm against the calc_mse documentation.
test_MSE = calc_mse(
            test_set_imputation_results["imputation"],
            physionet2012_dataset["test_X_ori"],
            physionet2012_dataset["test_X_indicating_mask"],
)
print(f"SAITS test_MSE: {test_MSE}")

In [None]:
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pypots.utils.logging import logger


# TODO: 优化该画图函数
def plot_data(
    X: np.ndarray,
    X_ori: np.ndarray,
    X_imputed: np.ndarray,
    sample_idx: Optional[int] = None,
    n_rows: int = 10,
    n_cols: int = 4,
    fig_size: Optional[list] = None,
):
    """Plot one sample's features in a grid, one panel per feature.

    Each panel draws, over the time steps of the selected sample:
    ``X`` as red "x" markers, ``X_ori`` as blue "o" markers (drawn on top of
    the red ones), and ``X_imputed`` as a solid green line.

    Parameters
    ----------
    X :
        3D array of shape (n_samples, n_steps, n_features).
    X_ori :
        Array indexed the same way as ``X``.
    X_imputed :
        Array indexed the same way as ``X``.
    sample_idx :
        Index of the sample to plot. When None, a random sample is picked
        and a warning is logged.
    n_rows, n_cols :
        Grid dimensions. At most ``n_rows * n_cols`` features are plotted;
        panels beyond the number of features stay empty.
    fig_size :
        ``[width, height]`` of the figure; defaults to ``[24, 36]``.

    Notes
    -----
    Sets ``plt.rcParams["font.size"]`` globally and does not call
    ``plt.show()``; the caller is expected to display the figure.
    """
    vals_shape = X.shape
    assert len(vals_shape) == 3, "X should be a 3D array of shape (n_samples, n_steps, n_features)"
    n_samples, n_steps, n_features = vals_shape

    if sample_idx is None:
        sample_idx = np.random.randint(low=0, high=n_samples)
        logger.warning(f"⚠️ No sample index is specified, a random sample {sample_idx} is selected for visualization.")

    if fig_size is None:
        fig_size = [24, 36]

    n_plots = min(n_features, n_rows * n_cols)
    steps = np.arange(n_steps)
    plt.rcParams["font.size"] = 16
    # squeeze=False keeps `axes` two-dimensional even for a single row or column,
    # so the [row, col] indexing below works for every grid shape.
    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(fig_size[0], fig_size[1]),
        squeeze=False,
    )

    for k in range(n_plots):
        row, col = divmod(k, n_cols)
        ax = axes[row, col]
        ax.plot(steps, X[sample_idx, :, k], color="r", marker="x", linestyle="None")
        ax.plot(steps, X_ori[sample_idx, :, k], color="b", marker="o", linestyle="None")
        ax.plot(steps, X_imputed[sample_idx, :, k], color="g", linestyle="solid")
        if col == 0:
            ax.set_ylabel("value")
        # Label the x-axis on the bottom-most drawn panel of each column. The previous
        # check (`row == -1`) could never be true, so "time" was never shown.
        if k + n_cols >= n_plots:
            ax.set_xlabel("time")

    logger.info("Plotting finished. Please invoke matplotlib.pyplot.show() to display the plot.")


# Visualize test sample 5 on a 7x6 grid of feature panels.
plot_data(
    X=test_set["X"],
    X_ori=test_set["X_ori"],
    X_imputed=test_set_imputation_results["imputation"],
    sample_idx=5,
    n_rows=7,
    n_cols=6,
    fig_size=[100, 50],
)