<a href="https://colab.research.google.com/github/JHyunjun/Attention/blob/main/Tabnet_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch numpy scikit-learn pytorch-tabnet

In [None]:
import numpy as np
import torch
from pytorch_tabnet.tab_model import TabNetRegressor
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Synthetic placeholder data standing in for windows of stock-price features:
# n_samples windows, each time_steps long, with n_features values per step.
# NOTE: these are uniform random values in [0, 1), not real prices — swap in real data.
np.random.seed(42)
n_samples = 1000
time_steps = 50
n_features = 3
X = np.random.rand(n_samples, time_steps, n_features)

# Split BEFORE scaling so the scaler's statistics come from the training windows
# only. Fitting the scaler on the full array (as before) leaks validation
# statistics into training. With one array and the same random_state, the
# split should select the same rows as splitting the pre-scaled array did.
X_train_raw, X_val_raw = train_test_split(X, test_size=0.2, random_state=42)

# Standardize each feature with the mean/std computed over every
# (window, time step) row of the TRAINING split, then apply the same
# transform to the validation split.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw.reshape(-1, n_features)).reshape(X_train_raw.shape)
X_val = scaler.transform(X_val_raw.reshape(-1, n_features)).reshape(X_val_raw.shape)
# Whole dataset scaled with the training statistics; kept so the name remains available.
X_norm = scaler.transform(X.reshape(-1, n_features)).reshape(X.shape)

# TabNet consumes 2-D tables, so each window is flattened into one row of
# time_steps * n_features columns. Temporal ordering is not modelled explicitly.
input_dim = time_steps * n_features
# The reconstruction target is the input itself, so the output width matches.
output_dim = time_steps * n_features

X_train_flat = X_train.reshape(-1, input_dim)
X_val_flat = X_val.reshape(-1, input_dim)

# TabNet used as an autoencoder: a TabNetRegressor trained to reproduce its own
# input (target == input). Hyperparameters are unchanged from the original cell.
tabnet_autoencoder = TabNetRegressor(
    n_d=16, n_a=16, n_steps=5, gamma=1.3, n_independent=2, n_shared=2, epsilon=1e-15,
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=1e-3), scheduler_params=None, scheduler_fn=None,
    mask_type='entmax', input_dim=input_dim, output_dim=output_dim,
)

# Train for at most 100 epochs, with patience-based early stopping monitored on
# the validation reconstruction (eval_set).
tabnet_autoencoder.fit(
    X_train=X_train_flat, y_train=X_train_flat,
    eval_set=[(X_val_flat, X_val_flat)],
    max_epochs=100, patience=10, batch_size=32, virtual_batch_size=32,
)

# Reconstruct the validation windows and restore the
# (samples, time_steps, n_features) layout.
# NOTE: despite its name, X_val_encoded holds the decoded reconstruction (same
# standardized space as X_val), not a latent encoding. The name is kept because
# later cells refer to it. X_val also drove early stopping, so errors measured
# on it are somewhat optimistic.
X_val_encoded = tabnet_autoencoder.predict(X_val_flat)
X_val_encoded = X_val_encoded.reshape(-1, time_steps, n_features)


In [None]:
# Display the reconstruction's shape: (validation samples, time_steps, n_features).
print(f"{X_val_encoded.shape}")

In [None]:
import matplotlib.pyplot as plt

# Per-sample reconstruction error: mean squared difference between each
# validation window and its reconstruction, averaged over the time and feature
# axes (1 and 2), which leaves one value per validation sample.
reconstruction_error = np.square(X_val - X_val_encoded).mean(axis=(1, 2))

# Plot the per-sample error against the sample's position in the validation set.
plt.figure(figsize=(6, 6))
plt.plot(reconstruction_error)
plt.gca().set(
    xlabel='Sample index',
    ylabel='Reconstruction Error',
    title='Reconstruction Error for Validation Data',
)
plt.show()