In [None]:
%matplotlib notebook

from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns
from test_data_set.test_data_gen import linear_data

# Number of samples to request and the seed passed to the generator.
data_size = 10_000
seed = 1

# linear_data is a project helper; it is unpacked here as a 2-D feature array
# (at least two columns are indexed below) and a target y.
# NOTE(review): exact shapes/dtypes are not visible from this cell — confirm.
x, y = linear_data(data_size=data_size, seed=seed)

# Gather the two feature columns and the target into one DataFrame so the
# Seaborn calls below can refer to them by column name.
columns = {"x1": x[:, 0], "x2": x[:, 1], "y": y}
df = pd.DataFrame(columns)

# One figure with two side-by-side panels, one per feature.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))
fig.suptitle(f"Linear Data Visualization (n={data_size})", fontsize=16)
left_ax, right_ax = axes

# Appearance shared by both fitted regression lines. scatter=False because the
# points are already drawn by the scatterplot call on the same axes.
trend_style = dict(scatter=False, color="red", line_kws={"linewidth": 2})

# Left panel: x1 against y. The low alpha keeps overlapping points readable.
sns.scatterplot(data=df, x="x1", y="y", alpha=0.3, ax=left_ax)
left_ax.set(
    title="Relationship between x1 and y",
    xlabel="Feature x1",
    ylabel="Target y",
)
# Overlay the fitted regression line on the same axes.
sns.regplot(data=df, x="x1", y="y", ax=left_ax, **trend_style)

# Right panel: x2 against y, drawn the same way.
sns.scatterplot(data=df, x="x2", y="y", alpha=0.3, ax=right_ax)
right_ax.set(
    title="Relationship between x2 and y",
    xlabel="Feature x2",
    ylabel="Target y",
)
# Overlay the fitted regression line on the same axes.
sns.regplot(data=df, x="x2", y="y", ax=right_ax, **trend_style)

# Separate 3D scatter figure showing how y varies jointly with x1 and x2.
fig_3d = plt.figure(figsize=(10, 8))
ax_3d = fig_3d.add_subplot(111, projection="3d")

# Plot only a random subset so the 3D view does not become too dense. The
# subset size is capped at the number of rows because DataFrame.sample raises
# ValueError when asked for more rows than exist (sampling without
# replacement), which would happen for data_size < 1000.
sample = df.sample(min(1000, len(df)), random_state=seed)

sc = ax_3d.scatter(
    sample["x1"],
    sample["x2"],
    sample["y"],
    c=sample["y"],  # colour each point by its y value
    cmap="viridis",
    alpha=0.7,
)

ax_3d.set_title("3D View: x1, x2 and y")
ax_3d.set_xlabel("Feature x1")
ax_3d.set_ylabel("Feature x2")
ax_3d.set_zlabel("Target y")
fig_3d.colorbar(sc, label="y value")

# plt.tight_layout() only adjusts the *current* figure (fig_3d at this point),
# so the 2D figure would be left untouched. Tidy each figure explicitly, then
# display all of them.
fig.tight_layout()
fig_3d.tight_layout()
plt.show()