In [None]:
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import RMSE
import os
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

class TFTModel:
    def __init__(self, data_path, features, target_column, max_encoder_length, max_prediction_length, batch_size, checkpoint_dir="./checkpoints", model_save_dir="./weights"):
        # 데이터 로드 및 전처리
        self.df = pd.read_csv(data_path)
        self.features = features
        self.target_column = target_column
        self.max_encoder_length = max_encoder_length
        self.max_prediction_length = max_prediction_length
        self.batch_size = batch_size
        self.checkpoint_dir = checkpoint_dir
        self.model_save_dir = model_save_dir

        self.df['product_id'] = self.df['product_id'].astype(str)
        self.df['week_date'] = pd.to_datetime(self.df['week_date'])
        self.df = self.df.sort_values("week_date").reset_index(drop=True)
        self.df["time_idx"] = (self.df["week_date"] - self.df["week_date"].min()).dt.days // 7  # 주 단위로 인덱스 생성

        # 8:2로 train과 validation 분리
        self.train_df, self.val_df = train_test_split(self.df, test_size=0.2, shuffle=False)
        
        # 훈련 데이터 준비
        self.training_cutoff = self.train_df['time_idx'].max() - max_prediction_length
        self.training = TimeSeriesDataSet(
            self.train_df[lambda x: x.time_idx <= self.training_cutoff],
            time_idx="time_idx",
            target=self.target_column,
            group_ids=["product_id"],
            min_encoder_length=max_encoder_length // 2,
            max_encoder_length=max_encoder_length,
            max_prediction_length=max_prediction_length,
            static_categoricals=["product_id"],
            time_varying_known_reals=["time_idx"] + features,
            time_varying_unknown_reals=[target_column],
            target_normalizer=GroupNormalizer(groups=["product_id"]),
        )

        # 모델 정의
        self.tft = TemporalFusionTransformer.from_dataset(
            self.training,
            learning_rate=0.03,
            hidden_size=16,
            attention_head_size=1,
            dropout=0.1,
            hidden_continuous_size=8,
            output_size=1,
            loss=RMSE(),
        )

        # Trainer 설정
        self.trainer = pl.Trainer(
            max_epochs=100,
            devices=1,
            accelerator='gpu',
            gradient_clip_val=0.1,
            logger=pl.loggers.TensorBoardLogger('tb_logs'),
            callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=5)]
        )

    def fit(self):
        """
        학습을 진행하는 함수
        """
        # 데이터 로더 준비
        train_dataloader = self.training.to_dataloader(train=True, batch_size=self.batch_size, num_workers=0)

        # 모델 학습
        pl.seed_everything(42)
        self.trainer.fit(self.tft, train_dataloaders=train_dataloader)

        # 체크포인트 저장
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(self.checkpoint_dir, "tft_model_checkpoint.ckpt")
        self.trainer.save_checkpoint(checkpoint_path)

        # 모델 가중치 저장
        os.makedirs(self.model_save_dir, exist_ok=True)
        model_save_path = os.path.join(self.model_save_dir, "tft_model_weights.pth")
        torch.save(self.tft.state_dict(), model_save_path)

        print("모델 학습 완료 및 저장.")

    def evaluate(self):
        """
        validation 데이터를 사용하여 모델 성능 평가
        """
        val_dataloader = self.training.to_dataloader(train=False, batch_size=self.batch_size, num_workers=0)
        predictions = self.trainer.predict(self.tft, val_dataloader)

        # 실제 값과 예측값 비교
        true_values = self.val_df[self.target_column].values
        mse = mean_squared_error(true_values, predictions)
        rmse = np.sqrt(mse)
        r2 = r2_score(true_values, predictions)
        
        print(f"Validation RMSE: {rmse:.4f}, R2: {r2:.4f}")

        return predictions

    def plot_predictions(self, product_ids):
        """
        각 상품별 마지막 주 예측 결과를 실제 값과 함께 시각화
        """
        for product_id in product_ids:
            product_data = self.val_df[self.val_df['product_id'] == product_id]
            product_data = product_data.sort_values("week_date")

            # 실제값 (파란색)과 예측값 (빨간색)
            true_values = product_data[self.target_column].values
            last_week_data = product_data.tail(1)
            prediction = self.trainer.predict(self.tft, self.training.to_dataloader(train=False, batch_size=self.batch_size, num_workers=0))

            # 그래프 그리기
            print(product_data['product_name'])
            plt.figure(figsize=(10, 6))
            plt.plot(product_data['week_date'], true_values, color='blue', label='Actual')
            plt.scatter(last_week_data['week_date'], prediction, color='red', label='Prediction', zorder=5)
            plt.title(f"Product {product_id}: Actual vs Predicted for Last Week")
            plt.xlabel("Week")
            plt.ylabel("Sales")
            plt.legend()
            plt.grid(True)
            plt.show()

    def predict_next_week(self, product_id):
        """
        마지막 주에 대해 다음 주 예측값을 계산하고 비교
        """
        product_data = self.df[self.df['product_id'] == product_id]
        last_week_sales = product_data['week_purchase_cnt'].iloc[-1]

        # 예측값 (다음 주 판매량)
        next_week_prediction = self.trainer.predict(self.tft, self.training.to_dataloader(train=False, batch_size=self.batch_size, num_workers=0))[-1]

        # 비교: 평균, 최대, 최소, 마지막 주 판매량
        mean_weekly_sales = product_data['week_purchase_cnt'].mean()
        max_weekly_sales = product_data['week_purchase_cnt'].max()
        min_weekly_sales = product_data['week_purchase_cnt'].min()

        print(f"Product {product_id}:")
        print(f"  Mean Weekly Sales: {mean_weekly_sales:.2f}")
        print(f"  Max Weekly Sales: {max_weekly_sales:.2f}")
        print(f"  Min Weekly Sales: {min_weekly_sales:.2f}")
        print(f"  Last Week Sales: {last_week_sales:.2f}")
        print(f"  Predicted Next Week Sales: {next_week_prediction:.2f}")
        
        return next_week_prediction

# 사용 예시
#if __name__ == "__main__":
# 필요한 변수 설정
data_path = "./fina_preprocessing_data.csv"
features = ['price', 'review_cnt', 'wish_cnt', 'sixMothRatio(puchase_cnt/review_cnt)', 'week_review_count', 'average_review_score', 'category1_encoded',
           'category2_encoded', 'category3_encoded', 'rolling_mean_purchase', 'rolling_std_purchase', 'week_num', 'month', 'month_sin', 'month_cos', 
            'week_sin', 'week_cos']

target_column = 'week_purchase_cnt'
max_encoder_length = 24
max_prediction_length = 1
batch_size = 64

# 모델 인스턴스 생성 및 학습 진행
tft_model = TFTModel(data_path, features, target_column, max_encoder_length, max_prediction_length, batch_size)
tft_model.fit()

# validation 데이터에 대한 예측
predictions = tft_model.evaluate()

# 예시 3개의 상품에 대해 마지막 주 예측 결과를 시각화
random_product_ids = df['product_id'].sample(n=5, random_state=42).tolist()
tft_model.plot_predictions(random_product_ids)  # 예시로 1, 2, 3번 상품에 대해 시각화

# 특정 상품에 대해 다음 주 예측값을 계산 및 비교
product_id = 1
tft_model.predict_next_week(product_id)


In [None]:
week_date                               datetime64[ns]
product_id                                       int64
product_name                                    object
price                                          float64
review_cnt                                     float64
purchase_cnt                                   float64
wish_cnt                                       float64
sixMothRatio(puchase_cnt/review_cnt)           float64
week_review_count                              float64
average_review_score                           float64
week_purchase_cnt                              float64
category1_encoded                              float64
category2_encoded                              float64
category3_encoded                              float64
rolling_mean_purchase                          float64
rolling_std_purchase                           float64
week_num                                        UInt32
month                                            int32
month_sin                                      float64
month_cos                                      float64
week_sin                                       Float64
week_cos                                       Float64

In [10]:
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning import Trainer
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import RMSE
import os
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

In [3]:
data_path = "./fina_preprocessing_data.csv"
ds = pd.read_csv(data_path)
ds

Unnamed: 0,week_date,product_id,product_name,price,review_cnt,purchase_cnt,wish_cnt,sixMothRatio(puchase_cnt/review_cnt),week_review_count,average_review_score,...,category2_encoded,category3_encoded,rolling_mean_purchase,rolling_std_purchase,week_num,month,month_sin,month_cos,week_sin,week_cos
0,2023-09-25,6356018199,DIY 목재재단 나무 원목 합판 집성목 MDF 방부목 자작나무 히노끼,1000.0,61440.0,3745.0,1685.0,2.386871,6.0,4.666667,...,68.0,99.0,14.00,0.000000,39,9,-1.000000,-1.836970e-16,-1.000000,-1.836970e-16
1,2023-10-02,6356018199,DIY 목재재단 나무 원목 합판 집성목 MDF 방부목 자작나무 히노끼,1000.0,61440.0,3745.0,1685.0,2.386871,71.0,4.873239,...,68.0,99.0,91.50,109.601551,40,10,-0.866025,5.000000e-01,-0.992709,1.205367e-01
2,2023-10-09,6356018199,DIY 목재재단 나무 원목 합판 집성목 MDF 방부목 자작나무 히노끼,1000.0,61440.0,3745.0,1685.0,2.386871,97.0,4.742268,...,68.0,99.0,138.00,111.772090,41,10,-0.866025,5.000000e-01,-0.970942,2.393157e-01
3,2023-10-16,6356018199,DIY 목재재단 나무 원목 합판 집성목 MDF 방부목 자작나무 히노끼,1000.0,61440.0,3745.0,1685.0,2.386871,108.0,4.740741,...,68.0,99.0,167.75,108.944558,42,10,-0.866025,5.000000e-01,-0.935016,3.546049e-01
4,2023-10-23,6356018199,DIY 목재재단 나무 원목 합판 집성목 MDF 방부목 자작나무 히노끼,1000.0,61440.0,3745.0,1685.0,2.386871,89.0,4.853933,...,68.0,99.0,176.60,96.401763,43,10,-0.866025,5.000000e-01,-0.885456,4.647232e-01
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23211,2024-09-30,88459191933,모리모토 뉴스타 M1 내야 외야 투수 올라운드 야구 글러브 우투 좌투,79000.0,8.0,52.0,25.0,5.777778,4.0,4.750000,...,13.0,184.0,23.00,0.000000,40,9,-1.000000,-1.836970e-16,-0.992709,1.205367e-01
23212,2024-10-07,88459191933,모리모토 뉴스타 M1 내야 외야 투수 올라운드 야구 글러브 우투 좌투,79000.0,8.0,52.0,25.0,5.777778,5.0,4.800000,...,13.0,184.0,25.50,3.535534,41,10,-0.866025,5.000000e-01,-0.970942,2.393157e-01
23213,2024-09-30,88473568229,노다지 영양 양곰탕 750g 7팩 곱창전골 내장탕 한우사골 즉석국 간편식,49900.0,70.0,1088.0,101.0,8.845528,1.0,3.000000,...,5.0,202.0,8.00,0.000000,40,9,-1.000000,-1.836970e-16,-0.992709,1.205367e-01
23214,2024-10-07,88473568229,노다지 영양 양곰탕 750g 7팩 곱창전골 내장탕 한우사골 즉석국 간편식,49900.0,70.0,1088.0,101.0,8.845528,93.0,4.526882,...,5.0,202.0,415.00,575.584920,41,10,-0.866025,5.000000e-01,-0.970942,2.393157e-01


In [4]:
print(ds.dtypes)

week_date                                object
product_id                                int64
product_name                             object
price                                   float64
review_cnt                              float64
purchase_cnt                            float64
wish_cnt                                float64
sixMothRatio(puchase_cnt/review_cnt)    float64
week_review_count                       float64
average_review_score                    float64
week_purchase_cnt                       float64
category1_encoded                       float64
category2_encoded                       float64
category3_encoded                       float64
rolling_mean_purchase                   float64
rolling_std_purchase                    float64
week_num                                  int64
month                                     int64
month_sin                               float64
month_cos                               float64
week_sin                                

In [18]:
ds[(ds['product_id']==88208430684)]

Unnamed: 0,week_date,product_id,product_name,price,review_cnt,purchase_cnt,wish_cnt,sixMothRatio(puchase_cnt/review_cnt),week_review_count,average_review_score,...,category2_encoded,category3_encoded,rolling_mean_purchase,rolling_std_purchase,week_num,month,month_sin,month_cos,week_sin,week_cos
23047,2024-08-05,88208430684,[9/3블랙예약발송]칼로 네오디뮴 스트랩 3600 활력 부스터 에너지 팔찌,109000.0,142.0,567.0,151.0,3.910345,9.0,5.0,...,39.0,219.0,35.0,0.0,32,8,-0.866025,-0.5,-0.663123,-0.7485107
23048,2024-08-12,88208430684,[9/3블랙예약발송]칼로 네오디뮴 스트랩 3600 활력 부스터 에너지 팔찌,109000.0,142.0,567.0,151.0,3.910345,17.0,4.941176,...,39.0,219.0,50.5,21.92031,33,8,-0.866025,-0.5,-0.748511,-0.6631227
23049,2024-08-19,88208430684,[9/3블랙예약발송]칼로 네오디뮴 스트랩 3600 활력 부스터 에너지 팔찌,109000.0,142.0,567.0,151.0,3.910345,9.0,5.0,...,39.0,219.0,45.333333,17.897858,34,8,-0.866025,-0.5,-0.822984,-0.5680647
23050,2024-08-26,88208430684,[9/3블랙예약발송]칼로 네오디뮴 스트랩 3600 활력 부스터 에너지 팔찌,109000.0,142.0,567.0,151.0,3.910345,24.0,5.0,...,39.0,219.0,57.25,27.956812,35,8,-0.866025,-0.5,-0.885456,-0.4647232
23051,2024-09-02,88208430684,[9/3블랙예약발송]칼로 네오디뮴 스트랩 3600 활력 부스터 에너지 팔찌,109000.0,142.0,567.0,151.0,3.910345,1.0,5.0,...,39.0,219.0,46.4,34.275356,36,9,-1.0,-1.83697e-16,-0.935016,-0.3546049
23052,2024-09-09,88208430684,(팔찌 사전예약)칼로 네오디뮴 스트랩 3600 활력 부스터 에너지 팔찌/목걸이,109000.0,142.0,567.0,151.0,3.910345,9.0,4.777778,...,39.0,219.0,44.5,31.008063,37,9,-1.0,-1.83697e-16,-0.970942,-0.2393157
23053,2024-09-16,88208430684,[9/3블랙예약발송]칼로 네오디뮴 스트랩 3600 활력 부스터 에너지 팔찌,109000.0,142.0,567.0,151.0,3.910345,11.0,4.454545,...,39.0,219.0,45.833333,30.688217,38,9,-1.0,-1.83697e-16,-0.992709,-0.1205367
23054,2024-09-23,88208430684,칼로 네오디뮴 스트랩 3600 활력 부스터 에너지 팔찌/목걸이,109000.0,142.0,567.0,151.0,3.910345,26.0,4.653846,...,39.0,219.0,51.666667,37.792415,39,9,-1.0,-1.83697e-16,-1.0,-1.83697e-16
23055,2024-09-30,88208430684,칼로 네오디뮴 스트랩 3600 활력 부스터 에너지 팔찌/목걸이,109000.0,142.0,567.0,151.0,3.910345,22.0,4.454545,...,39.0,219.0,60.166667,39.009828,40,9,-1.0,-1.83697e-16,-0.992709,0.1205367
23056,2024-10-07,88208430684,칼로 네오디뮴 스트랩 3600 활력 부스터 에너지 팔찌/목걸이,109000.0,142.0,567.0,151.0,3.910345,17.0,4.294118,...,39.0,219.0,55.666667,35.898004,41,10,-0.866025,0.5,-0.970942,0.2393157


# 함수로만  작성

In [51]:
max_prediction_length = 1  # 다음 주 예측
max_encoder_length = 24    # 과거 12주 사용 (기존 24주에서 축소)
min_encoder_length = 8     # 최소 8주의 데이터 필요


def train_fillter(df):
    # 데이터 준비 단계
    # print("데이터 필터링 전:")
    # print(f"전체 제품 수: {len(df_filled['product_id'].unique())}")
    
    # 시퀀스 길이 계산
    sequence_lengths = df_filled.groupby('product_id').size()
    # print("\n시퀀스 길이 통계:")
    # print(f"최소 시퀀스 길이: {sequence_lengths.min()}")
    # print(f"최대 시퀀스 길이: {sequence_lengths.max()}")
    # print(f"평균 시퀀스 길이: {sequence_lengths.mean():.2f}")

    # 최소 필요 길이 설정
    min_required_length = min_encoder_length + max_prediction_length
    valid_products = sequence_lengths[sequence_lengths >= min_required_length].index


In [5]:
'''
week_date                                object
product_id                                int64
category1Id                             float64
category2Id                             float64
category3Id                             float64
price                                   float64
review_cnt                              float64
purchase_cnt                            float64
wish_cnt                                float64
sixMothRatio(puchase_cnt/review_cnt)    float64
week_review_count                       float64
average_review_score                    float64
week_purchase_cnt                       float64
dtype: object
'''

df_filled = ds.copy()
df_filled

Unnamed: 0,week_date,product_id,product_name,price,review_cnt,purchase_cnt,wish_cnt,sixMothRatio(puchase_cnt/review_cnt),week_review_count,average_review_score,...,category2_encoded,category3_encoded,rolling_mean_purchase,rolling_std_purchase,week_num,month,month_sin,month_cos,week_sin,week_cos
0,2023-09-25,6356018199,DIY 목재재단 나무 원목 합판 집성목 MDF 방부목 자작나무 히노끼,1000.0,61440.0,3745.0,1685.0,2.386871,6.0,4.666667,...,68.0,99.0,14.00,0.000000,39,9,-1.000000,-1.836970e-16,-1.000000,-1.836970e-16
1,2023-10-02,6356018199,DIY 목재재단 나무 원목 합판 집성목 MDF 방부목 자작나무 히노끼,1000.0,61440.0,3745.0,1685.0,2.386871,71.0,4.873239,...,68.0,99.0,91.50,109.601551,40,10,-0.866025,5.000000e-01,-0.992709,1.205367e-01
2,2023-10-09,6356018199,DIY 목재재단 나무 원목 합판 집성목 MDF 방부목 자작나무 히노끼,1000.0,61440.0,3745.0,1685.0,2.386871,97.0,4.742268,...,68.0,99.0,138.00,111.772090,41,10,-0.866025,5.000000e-01,-0.970942,2.393157e-01
3,2023-10-16,6356018199,DIY 목재재단 나무 원목 합판 집성목 MDF 방부목 자작나무 히노끼,1000.0,61440.0,3745.0,1685.0,2.386871,108.0,4.740741,...,68.0,99.0,167.75,108.944558,42,10,-0.866025,5.000000e-01,-0.935016,3.546049e-01
4,2023-10-23,6356018199,DIY 목재재단 나무 원목 합판 집성목 MDF 방부목 자작나무 히노끼,1000.0,61440.0,3745.0,1685.0,2.386871,89.0,4.853933,...,68.0,99.0,176.60,96.401763,43,10,-0.866025,5.000000e-01,-0.885456,4.647232e-01
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23211,2024-09-30,88459191933,모리모토 뉴스타 M1 내야 외야 투수 올라운드 야구 글러브 우투 좌투,79000.0,8.0,52.0,25.0,5.777778,4.0,4.750000,...,13.0,184.0,23.00,0.000000,40,9,-1.000000,-1.836970e-16,-0.992709,1.205367e-01
23212,2024-10-07,88459191933,모리모토 뉴스타 M1 내야 외야 투수 올라운드 야구 글러브 우투 좌투,79000.0,8.0,52.0,25.0,5.777778,5.0,4.800000,...,13.0,184.0,25.50,3.535534,41,10,-0.866025,5.000000e-01,-0.970942,2.393157e-01
23213,2024-09-30,88473568229,노다지 영양 양곰탕 750g 7팩 곱창전골 내장탕 한우사골 즉석국 간편식,49900.0,70.0,1088.0,101.0,8.845528,1.0,3.000000,...,5.0,202.0,8.00,0.000000,40,9,-1.000000,-1.836970e-16,-0.992709,1.205367e-01
23214,2024-10-07,88473568229,노다지 영양 양곰탕 750g 7팩 곱창전골 내장탕 한우사골 즉석국 간편식,49900.0,70.0,1088.0,101.0,8.845528,93.0,4.526882,...,5.0,202.0,415.00,575.584920,41,10,-0.866025,5.000000e-01,-0.970942,2.393157e-01


In [59]:
print(df_filled.dtypes)

week_date                                object
product_id                                int64
product_name                             object
price                                   float64
review_cnt                              float64
purchase_cnt                            float64
wish_cnt                                float64
sixMothRatio(puchase_cnt/review_cnt)    float64
week_review_count                       float64
average_review_score                    float64
week_purchase_cnt                       float64
category1_encoded                       float64
category2_encoded                       float64
category3_encoded                       float64
rolling_mean_purchase                   float64
rolling_std_purchase                    float64
week_num                                  int64
month                                     int64
month_sin                               float64
month_cos                               float64
week_sin                                

In [67]:
pip install --upgrade pip

Collecting pip
  Using cached pip-24.3.1-py3-none-any.whl.metadata (3.7 kB)
Using cached pip-24.3.1-py3-none-any.whl (1.8 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.2
    Uninstalling pip-24.2:
      Successfully uninstalled pip-24.2
Successfully installed pip-24.3.1
Note: you may need to restart the kernel to use updated packages.


In [68]:
!pip install -U pytorch-lightning pytorch-forecasting torch



In [12]:
pip show torch pytorch-lightning pytorch-forecasting


Name: torch
Version: 2.5.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: lightning, pytorch-forecasting, pytorch-lightning, sentence-transformers, torchaudio, torchmetrics, torchvision
---
Name: pytorch-lightning
Version: 2.4.0
Summary: PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate.
Home-page: https://github.com/Lightning-AI/lightning
Author: Lightning AI et al.
Author-email: pytorch@lightning.ai
License: Apache-2.0
Location: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages
Requires: fsspec, lightning-utilities, packaging, PyYAML, torch, torchmetrics, tqdm, typing-extensi

In [13]:
pip install --upgrade pytorch-forecasting

Note: you may need to restart the kernel to use updated packages.


In [14]:
pip show torch pytorch-lightning pytorch-forecasting


Name: torch
Version: 2.5.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: lightning, pytorch-forecasting, pytorch-lightning, sentence-transformers, torchaudio, torchmetrics, torchvision
---
Name: pytorch-lightning
Version: 2.4.0
Summary: PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate.
Home-page: https://github.com/Lightning-AI/lightning
Author: Lightning AI et al.
Author-email: pytorch@lightning.ai
License: Apache-2.0
Location: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages
Requires: fsspec, lightning-utilities, packaging, PyYAML, torch, torchmetrics, tqdm, typing-extensi

In [15]:
pip install pytorch-lightning==1.6.0


Collecting pytorch-lightning==1.6.0
  Downloading pytorch_lightning-1.6.0-py3-none-any.whl.metadata (33 kB)
Requested pytorch-lightning==1.6.0 from https://files.pythonhosted.org/packages/09/18/cee67f4849dea9a29b7af7cdf582246bcba9eaa73d9443e138a4172ec786/pytorch_lightning-1.6.0-py3-none-any.whl has invalid metadata: .* suffix can only be used with `==` or `!=` operators
    torch (>=1.8.*)
           ~~~~~~^
Please use pip<24.1 if you need to use this version.[0m[33m
[0m[31mERROR: Could not find a version that satisfies the requirement pytorch-lightning==1.6.0 (from versions: 0.0.2, 0.2, 0.2.2, 0.2.3, 0.2.4, 0.2.4.1, 0.2.5, 0.2.5.1, 0.2.5.2, 0.2.6, 0.3, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.4.1, 0.3.5, 0.3.6, 0.3.6.1, 0.3.6.3, 0.3.6.4, 0.3.6.5, 0.3.6.6, 0.3.6.7, 0.3.6.8, 0.3.6.9, 0.4.0, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.5, 0.4.6, 0.4.7, 0.4.8, 0.4.9, 0.5.0, 0.5.1, 0.5.1.2, 0.5.1.3, 0.5.2, 0.5.2.1, 0.5.3, 0.5.3.1, 0.5.3.2, 0.5.3.3, 0.6.0, 0.7.1, 0.7.3, 0.7.5, 0.7.6, 0.8.1, 0.8.3, 0.8.4, 0.8

In [17]:
import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, QuantileLoss
from lightning.pytorch.tuner import Tuner
import pandas as pd
import numpy as np
from pytorch_lightning import Trainer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import RMSE
import os
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split


# 기존 파라미터 조정
max_prediction_length = 1  # 다음 주 예측
max_encoder_length = 24    # 과거 12주 사용 (기존 24주에서 축소)
min_encoder_length = 8     # 최소 8주의 데이터 필요

features = [ 
             'price',
             'review_cnt',
             'wish_cnt',
             'sixMothRatio(puchase_cnt/review_cnt)',
             'week_review_count',
             'average_review_score',
             'category1_encoded',
             'category2_encoded',
             'category3_encoded',
             'rolling_mean_purchase',
             'rolling_std_purchase',
             'week_num',
             'month',
             'month_sin',
             'month_cos',
             'week_sin',
             'week_cos']



# 데이터 준비 단계
print("데이터 필터링 전:")
print(f"전체 제품 수: {len(df_filled['product_id'].unique())}")

# 시퀀스 길이 계산
sequence_lengths = df_filled.groupby('product_id').size()
print("\n시퀀스 길이 통계:")
print(f"최소 시퀀스 길이: {sequence_lengths.min()}")
print(f"최대 시퀀스 길이: {sequence_lengths.max()}")
print(f"평균 시퀀스 길이: {sequence_lengths.mean():.2f}")

# 최소 필요 길이 설정
min_required_length = min_encoder_length + max_prediction_length
valid_products = sequence_lengths[sequence_lengths >= min_required_length].index

print(f"\n조정된 최소 필요 데이터 길이: {min_required_length}")
print(f"충분한 데이터를 가진 제품 수: {len(valid_products)}")
print(f"포함된 제품 비율: {(len(valid_products) / len(sequence_lengths) * 100):.2f}%")


# 유효한 제품만 필터링
df_filtered = df_filled[df_filled['product_id'].isin(valid_products)].copy()


# time_idx 재생성
# week_date를 datetime 형식으로 변환
df_filtered ['week_date'] = pd.to_datetime(df_filtered ['week_date'])
# product_id를 문자열로 변환
df_filtered ['product_id'] = df_filtered ['product_id'].astype(str)
# 각 product_id마다 0부터 시작하는 time_idx 생성 -> 상품 별로 시계열 예측을 진행하도록 함.
df_filtered = df_filtered.sort_values(["product_id", "week_date"]).reset_index(drop=True)
df_filtered["time_idx"] = df_filtered.groupby("product_id").cumcount()

# training cutoff 설정 (각 시퀀스의 80%를 훈련에 사용)
df_filtered["training_cutoff"] = df_filtered.groupby("product_id")["time_idx"].transform(
    lambda x: int(len(x) * 0.8)
)

# print(df_filtered.dtypes)

# Training dataset 생성
training = TimeSeriesDataSet(
    df_filtered[lambda x: x.time_idx <= x.training_cutoff],
    time_idx="time_idx",
    target="week_purchase_cnt",
    group_ids=["product_id"],
    min_encoder_length=min_encoder_length,
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["product_id"],
    time_varying_known_reals=["time_idx"] + features,
    time_varying_unknown_reals=["week_purchase_cnt"],
    target_normalizer=GroupNormalizer(groups=["product_id"]),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# Validation dataset 생성
validation = TimeSeriesDataSet.from_dataset(
    training,
    df_filtered,
    min_prediction_idx=training.index.time.max() + 1,
    stop_randomization=True
)

# 데이터 로더 생성
batch_size = 128  # 배치 크기 증가
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 2, num_workers=0)


# Early stopping 설정 조정
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=1e-4,
    patience=10,
    verbose=True,
    mode="min"
)

lr_monitor = LearningRateMonitor(logging_interval='epoch')
# 모델 파라미터 조정
# 모델 초기화
pl.seed_everything(42)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=32,
    attention_head_size=2,
    dropout=0.2,  # 드롭아웃 증가
    hidden_continuous_size=16,
    loss=RMSE(),
)

# Trainer 초기화
trainer = pl.Trainer(
    max_epochs=100,
    devices=1,
    accelerator='gpu',
    gradient_clip_val=0.1,
    callbacks=[early_stop_callback, lr_monitor],
    enable_progress_bar=True
)


# 모델 학습
print("\n모델 학습 시작...")
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    # val_dataloaders=val_dataloader
)

# 검증 세트에 대한 성능 평가
print("\n검증 세트 성능 평가:")
validation_predictions = tft.predict(val_dataloader)
validation_actual = torch.cat([y[0] for x, y in iter(val_dataloader)])

# GPU 텐서를 CPU로 이동하여 계산
val_predictions_np = validation_predictions.cpu().numpy()
val_actual_np = validation_actual.cpu().numpy()

rmse = np.sqrt(mean_squared_error(val_actual_np, val_predictions_np))
r2 = r2_score(val_actual_np, val_predictions_np)

print(f"Validation RMSE: {rmse:.2f}")
print(f"Validation R²: {r2:.4f}")




Seed set to 42


데이터 필터링 전:
전체 제품 수: 453

시퀀스 길이 통계:
최소 시퀀스 길이: 2
최대 시퀀스 길이: 310
평균 시퀀스 길이: 51.25

조정된 최소 필요 데이터 길이: 9
충분한 데이터를 가진 제품 수: 431
포함된 제품 비율: 95.14%


/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
  super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_datalo


모델 학습 시작...


/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Training: |                                               | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

# class화로 self로 받기

In [29]:
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import RMSE
import os
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

class TFTModel:
    def __init__(self, data_path, features, target_column, max_encoder_length, max_prediction_length, batch_size, checkpoint_dir="./checkpoints", model_save_dir="./weights"):
        # 데이터 로드 및 전처리
        self.df = pd.read_csv(data_path)
        self.features = features
        self.target_column = target_column
        self.max_encoder_length = max_encoder_length
        self.max_prediction_length = max_prediction_length
        self.batch_size = batch_size
        self.checkpoint_dir = checkpoint_dir
        self.model_save_dir = model_save_dir

        self.df['product_id'] = self.df['product_id'].astype(str)
        self.df['week_date'] = pd.to_datetime(self.df['week_date'])
        self.df = self.df.sort_values("week_date").reset_index(drop=True)
        self.df["time_idx"] = (self.df["week_date"] - self.df["week_date"].min()).dt.days // 7  # 주 단위로 인덱스 생성

        # 8:2로 train과 validation 분리
        self.train_df, self.val_df = train_test_split(self.df, test_size=0.2, shuffle=False)
        
        # 훈련 데이터 준비
        self.training_cutoff = self.train_df['time_idx'].max() - max_prediction_length
        self.training = TimeSeriesDataSet(
            self.train_df[lambda x: x.time_idx <= self.training_cutoff],
            time_idx="time_idx",
            target=self.target_column,
            group_ids=["product_id"],
            min_encoder_length=max_encoder_length // 2,
            max_encoder_length=max_encoder_length,
            max_prediction_length=max_prediction_length,
            static_categoricals=["product_id"],
            time_varying_known_reals=["time_idx"] + features,
            time_varying_unknown_reals=[target_column],
            target_normalizer=GroupNormalizer(groups=["product_id"]),
        )

        # validation 데이터셋 생성
        self.validation = TimeSeriesDataSet.from_dataset(
            self.training,
            self.train_df,
            predict=True,
            stop_randomization=True
        )

        # 데이터 로더 준비
        self.train_dataloader = self.training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
        self.val_dataloader = self.validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

        # 모델 정의
        self.tft = TemporalFusionTransformer.from_dataset(
            self.training,
            learning_rate=0.03,
            hidden_size=16,
            attention_head_size=1,
            dropout=0.1,
            hidden_continuous_size=8,
            output_size=1,  # max_prediction_length와 동일하게 설정
            loss=RMSE(),
        )

        # Trainer 설정
        self.trainer = pl.Trainer(
            max_epochs=10,
            accelerator='gpu' if torch.cuda.is_available() else 'cpu',
            devices=1,
            gradient_clip_val=0.1,
            logger=pl.loggers.TensorBoardLogger('tb_logs'),
            callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=5)]
        )

    def fit(self):
        """
        학습을 진행하는 함수
        """
        # 모델 학습
        pl.seed_everything(42)
        self.trainer.fit(
            model=self.tft,
            train_dataloaders=self.train_dataloader,
            val_dataloaders=self.val_dataloader
        )

        # 체크포인트 저장
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(self.checkpoint_dir, "tft_model_checkpoint.ckpt")
        self.trainer.save_checkpoint(checkpoint_path)

        # 모델 가중치 저장
        os.makedirs(self.model_save_dir, exist_ok=True)
        model_save_path = os.path.join(self.model_save_dir, "tft_model_weights.pth")
        torch.save(self.tft.state_dict(), model_save_path)

        print("모델 학습 완료 및 저장.")

    def evaluate(self):
        """
        validation 데이터를 사용하여 모델 성능 평가
        """
        predictions = self.trainer.predict(self.tft, self.val_dataloader)
        predictions = torch.cat([p.prediction for p in predictions]).cpu().numpy()

        # 실제 값과 예측값 비교
        actuals = torch.cat([self.tft.to_network_output(batch)[0] for batch in iter(self.val_dataloader)]).cpu().numpy()
        
        mse = mean_squared_error(actuals, predictions)
        rmse = np.sqrt(mse)
        r2 = r2_score(actuals, predictions)
        
        print(f"Validation RMSE: {rmse:.4f}, R2: {r2:.4f}")

        return predictions

    def plot_predictions(self, product_ids):
        """
        각 상품별 예측 결과를 실제 값과 함께 시각화
        """
        predictions = self.trainer.predict(self.tft, self.val_dataloader)
        predictions = torch.cat([p.prediction for p in predictions]).cpu().numpy()

        for product_id in product_ids:
            product_data = self.val_df[self.val_df['product_id'] == str(product_id)]
            if len(product_data) == 0:
                print(f"No data found for product {product_id}")
                continue

            product_data = product_data.sort_values("week_date")
            true_values = product_data[self.target_column].values

            # 예측값 매핑
            pred_indices = product_data.index
            if len(pred_indices) > len(predictions):
                pred_indices = pred_indices[-len(predictions):]
            product_predictions = predictions[:len(pred_indices)]

            plt.figure(figsize=(10, 6))
            plt.plot(product_data['week_date'], true_values, color='blue', label='Actual')
            plt.plot(product_data['week_date'].iloc[-len(product_predictions):], 
                    product_predictions, color='red', label='Predicted')
            plt.title(f"Product {product_id}: Actual vs Predicted Values")
            plt.xlabel("Week")
            plt.ylabel("Sales")
            plt.legend()
            plt.grid(True)
            plt.xticks(rotation=45)
            plt.tight_layout()
            plt.show()

    def predict_next_week(self, product_id):
        """
        특정 상품의 다음 주 판매량 예측
        """
        product_data = self.df[self.df['product_id'] == str(product_id)]
        if len(product_data) == 0:
            print(f"No data found for product {product_id}")
            return None

        encoder_data = self.training.get_inference_data(
            product_data,
            predict_mode=True
        )
        
        predictions = self.tft.predict(encoder_data)
        next_week_prediction = predictions[0].cpu().numpy()

        # 통계 계산
        historical_stats = {
            'mean': product_data[self.target_column].mean(),
            'max': product_data[self.target_column].max(),
            'min': product_data[self.target_column].min(),
            'last': product_data[self.target_column].iloc[-1]
        }

        print(f"\nProduct {product_id} Statistics:")
        print(f"  Historical Mean Sales: {historical_stats['mean']:.2f}")
        print(f"  Historical Max Sales: {historical_stats['max']:.2f}")
        print(f"  Historical Min Sales: {historical_stats['min']:.2f}")
        print(f"  Last Week Sales: {historical_stats['last']:.2f}")
        print(f"  Predicted Next Week Sales: {next_week_prediction[0]:.2f}")

        return next_week_prediction[0]

# 사용 예시
#if __name__ == "__main__":
    # 필요한 변수 설정
data_path = "./fina_preprocessing_data.csv"
features = [ 
             'price',
             'review_cnt',
             'wish_cnt',
             'sixMothRatio(puchase_cnt/review_cnt)',
             'week_review_count',
             'average_review_score',
             'category1_encoded',
             'category2_encoded',
             'category3_encoded',
             'rolling_mean_purchase',
             'rolling_std_purchase',
             'week_num',
             'month',
             'month_sin',
             'month_cos',
             'week_sin',
             'week_cos']

target_column = 'week_purchase_cnt'
max_encoder_length = 24
max_prediction_length = 1
batch_size = 64

# 모델 인스턴스 생성 및 학습
tft_model = TFTModel(data_path, features, target_column, 
                    max_encoder_length, max_prediction_length, batch_size)
tft_model.fit()

# 평가 및 예측
predictions = tft_model.evaluate()

# 예시 상품들에 대한 예측 시각화
sample_products = tft_model.df['product_id'].unique()[:3]
print(sample_products)
tft_model.plot_predictions(sample_products)

# 특정 상품의 다음 주 예측
tft_model.predict_next_week(sample_products[0])



KeyError: "Unknown category '88208430684' encountered. Set `add_nan=True` to allow unknown categories"

In [36]:
pip install --upgrade torch torchvision torchaudio pytorch-lightning pytorch-forecasting


Collecting torch
  Downloading torch-2.5.1-cp310-none-macosx_11_0_arm64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.20.1-cp310-cp310-macosx_11_0_arm64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading torchaudio-2.5.1-cp310-cp310-macosx_11_0_arm64.whl.metadata (6.4 kB)
Collecting pytorch-forecasting
  Downloading pytorch_forecasting-1.2.0-py3-none-any.whl.metadata (13 kB)
Downloading torch-2.5.1-cp310-none-macosx_11_0_arm64.whl (63.9 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.9/63.9 MB[0m [31m52.0 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:01[0m
[?25hDownloading torchvision-0.20.1-cp310-cp310-macosx_11_0_arm64.whl (1.8 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m51.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading torchaudio-2.5.1-cp310-cp310-macosx_11_0_arm64.whl (1.8 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [43]:
pip install pytorch

Collecting pytorch
  Downloading pytorch-1.0.2.tar.gz (689 bytes)
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: pytorch
  Building wheel for pytorch (setup.py) ... [?25lerror
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py bdist_wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m [31m[6 lines of output][0m
  [31m   [0m Traceback (most recent call last):
  [31m   [0m   File "<string>", line 2, in <module>
  [31m   [0m   File "<pip-setuptools-caller>", line 34, in <module>
  [31m   [0m   File "/private/var/folders/c4/2xfj413907ng2mds_4tqrm8r0000gn/T/pip-install-2fttb2s9/pytorch_357cb588e0724029ab180095aadf7dd2/setup.py", line 15, in <module>
  [31m   [0m     raise Exception(message)
  [31m   [0m Exception: You tried to install "pytorch". The package named for PyTorch is "torch"
  [31m   [0m [31m[end of output][0m
  
  [1;35mnote[0m: This er

In [45]:
pip install torch torchvision


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [46]:
pip show torch pytorch-lightning pytorch-forecasting

Name: torch
Version: 2.5.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: lightning, pytorch-forecasting, pytorch-lightning, sentence-transformers, torchaudio, torchmetrics, torchvision
---
Name: pytorch-lightning
Version: 2.4.0
Summary: PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate.
Home-page: https://github.com/Lightning-AI/lightning
Author: Lightning AI et al.
Author-email: pytorch@lightning.ai
License: Apache-2.0
Location: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages
Requires: fsspec, lightning-utilities, packaging, PyYAML, torch, torchmetrics, tqdm, typing-extensi

In [50]:
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import RMSE
import os
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.model_selection import train_test_split

class TFTModel:
    def __init__(self, data_path, features, target_column, max_encoder_length, max_prediction_length, batch_size, checkpoint_dir="./checkpoints", model_save_dir="./weights"):
        # 데이터 로드
        self.df = pd.read_csv(data_path)

        # 데이터 전처리
        self.features = features
        self.target_column = target_column
        self.max_encoder_length = max_encoder_length
        self.max_prediction_length = max_prediction_length
        self.batch_size = batch_size
        self.checkpoint_dir = checkpoint_dir
        self.model_save_dir = model_save_dir

        # product_id 처리
        self.df['product_id'] = self.df['product_id'].astype(str)  # product_id를 문자열로 변환
        
        # 날짜 처리
        self.df['week_date'] = pd.to_datetime(self.df['week_date'])
        self.df = self.df.sort_values("week_date").reset_index(drop=True)
        self.df["time_idx"] = (self.df["week_date"] - self.df["week_date"].min()).dt.days // 7

        # 각 product_id별로 데이터를 훈련/검증 세트로 분리
        train_list = []
        val_list = []

        for product_id, group in self.df.groupby('product_id'):
            # group은 각 product_id에 해당하는 데이터
            group_train, group_val = train_test_split(
                group, 
                test_size=0.2, 
                shuffle=False  # 시계열 데이터에서는 시간 순서를 유지
            )
            train_list.append(group_train)
            val_list.append(group_val)

        # 훈련/검증 데이터를 병합
        self.train_df = pd.concat(train_list).reset_index(drop=True)
        self.val_df = pd.concat(val_list).reset_index(drop=True)

        # # 디버그: 카테고리 정보 출력
        # print("Unique product_ids in training:")
        # print(self.train_df['product_id'].unique())
        # print("\nUnique product_ids in validation:")
        # print(self.val_df['product_id'].unique())

        # # 전체 데이터셋의 고유 product_id
        # all_product_ids = self.df['product_id'].unique()
        # print(f"\nTotal unique product_ids: {len(all_product_ids)}")
        
        # TimeSeriesDataSet 생성
        self.training = TimeSeriesDataSet(
            self.train_df[lambda x: x.time_idx <= self.train_df['time_idx'].max() - max_prediction_length],
            time_idx="time_idx",
            target=self.target_column,
            group_ids=["product_id"],
            min_encoder_length=max_encoder_length // 2,
            max_encoder_length=max_encoder_length,
            max_prediction_length=max_prediction_length,
            static_categoricals=["product_id"],
            time_varying_known_reals=["time_idx"] + features,
            time_varying_unknown_reals=[target_column],
            target_normalizer=GroupNormalizer(groups=["product_id"]),
            # add_nan=True  # 새로운 카테고리 처리
        )

        # validation 데이터셋 생성
        self.validation = TimeSeriesDataSet.from_dataset(
            self.training,
            self.val_df,
            predict=True,
            stop_randomization=True
        )

        # 데이터 로더 준비
        self.train_dataloader = self.training.to_dataloader(
            train=True, 
            batch_size=batch_size, 
            num_workers=0
        )
        self.val_dataloader = self.validation.to_dataloader(
            train=False, 
            batch_size=batch_size, 
            num_workers=0
        )

        # 모델 정의
        self.tft = TemporalFusionTransformer.from_dataset(
            self.training,
            learning_rate=0.03,
            hidden_size=16,
            attention_head_size=1,
            dropout=0.1,
            hidden_continuous_size=8,
            output_size=1,
            loss=RMSE(),
        )

        # Trainer 설정
        self.trainer = pl.Trainer(
            max_epochs=10,
            devices=1,
            accelerator='gpu' if torch.cuda.is_available() else 'cpu',
            gradient_clip_val=0.1,
            # logger=pl.loggers.TensorBoardLogger('tb_logs'),
            # callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=5)]
        )
        print(self.tft)


    def fit(self):
        """
        학습을 진행하는 함수
        """
        # 모델 학습
        pl.seed_everything(42)
        self.trainer.fit(
            model=self.tft,
            train_dataloaders=self.train_dataloader,
            val_dataloaders=self.val_dataloader
        )

        # 체크포인트 저장
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(self.checkpoint_dir, "tft_model_checkpoint.ckpt")
        self.trainer.save_checkpoint(checkpoint_path)

        # 모델 가중치 저장
        os.makedirs(self.model_save_dir, exist_ok=True)
        model_save_path = os.path.join(self.model_save_dir, "tft_model_weights.pth")
        torch.save(self.tft.state_dict(), model_save_path)

        print("모델 학습 완료 및 저장.")

    def evaluate(self):
        """
        validation 데이터를 사용하여 모델 성능 평가
        """
        predictions = self.trainer.predict(self.tft, self.val_dataloader)
        predictions = torch.cat([p.prediction for p in predictions]).cpu().numpy()

        # 실제 값과 예측값 비교
        actuals = torch.cat([self.tft.to_network_output(batch)[0] for batch in iter(self.val_dataloader)]).cpu().numpy()
        
        mse = mean_squared_error(actuals, predictions)
        rmse = np.sqrt(mse)
        r2 = r2_score(actuals, predictions)
        
        print(f"Validation RMSE: {rmse:.4f}, R2: {r2:.4f}")

        return predictions

    def plot_predictions(self, product_ids):
        """
        각 상품별 예측 결과를 실제 값과 함께 시각화
        """
        predictions = self.trainer.predict(self.tft, self.val_dataloader)
        predictions = torch.cat([p.prediction for p in predictions]).cpu().numpy()

        for product_id in product_ids:
            product_data = self.val_df[self.val_df['product_id'] == product_id]
            if len(product_data) == 0:
                print(f"No data found for product {product_id}")
                continue

            product_data = product_data.sort_values("week_date")
            true_values = product_data[self.target_column].values

            # 예측값 매핑
            pred_indices = product_data.index
            if len(pred_indices) > len(predictions):
                pred_indices = pred_indices[-len(predictions):]
            product_predictions = predictions[:len(pred_indices)]

            plt.figure(figsize=(10, 6))
            plt.plot(product_data['week_date'], true_values, color='blue', label='Actual')
            plt.plot(product_data['week_date'].iloc[-len(product_predictions):], 
                    product_predictions, color='red', label='Predicted')
            plt.title(f"Product {product_id}: Actual vs Predicted Values")
            plt.xlabel("Week")
            plt.ylabel("Sales")
            plt.legend()
            plt.grid(True)
            plt.xticks(rotation=45)
            plt.tight_layout()
            plt.show()

    def predict_next_week(self, product_id):
        """
        특정 상품의 다음 주 판매량 예측
        """
        product_data = self.df[self.df['product_id'] == product_id]
        if len(product_data) == 0:
            print(f"No data found for product {product_id}")
            return None

        encoder_data = self.training.get_inference_data(
            product_data,
            predict_mode=True
        )
        
        predictions = self.tft.predict(encoder_data)
        next_week_prediction = predictions[0].cpu().numpy()

        # 통계 계산
        historical_stats = {
            'mean': product_data[self.target_column].mean(),
            'max': product_data[self.target_column].max(),
            'min': product_data[self.target_column].min(),
            'last': product_data[self.target_column].iloc[-1]
        }

        print(f"\nProduct {product_id} Statistics:")
        print(f"  Historical Mean Sales: {historical_stats['mean']:.2f}")
        print(f"  Historical Max Sales: {historical_stats['max']:.2f}")
        print(f"  Historical Min Sales: {historical_stats['min']:.2f}")
        print(f"  Last Week Sales: {historical_stats['last']:.2f}")
        print(f"  Predicted Next Week Sales: {next_week_prediction[0]:.2f}")

        return next_week_prediction[0]

# 사용 예시
#if __name__ == "__main__":
    # 필요한 변수 설정
data_path = "./fina_preprocessing_data.csv"
features = [ 
             'price',
             'review_cnt',
             'wish_cnt',
             'sixMothRatio(puchase_cnt/review_cnt)',
             'week_review_count',
             'average_review_score',
             'category1_encoded',
             'category2_encoded',
             'category3_encoded',
             'rolling_mean_purchase',
             'rolling_std_purchase',
             'week_num',
             'month',
             'month_sin',
             'month_cos',
             'week_sin',
             'week_cos']

target_column = 'week_purchase_cnt'
max_encoder_length = 24
max_prediction_length = 1
batch_size = 64

# 모델 인스턴스 생성 및 학습
tft_model = TFTModel(data_path, features, target_column, 
                    max_encoder_length, max_prediction_length, batch_size)
tft_model.fit()

# # 평가 및 예측
# predictions = tft_model.evaluate()

# # 예시 상품들에 대한 예측 시각화
# sample_products = tft_model.df['product_id'].unique()[:3]
# print(sample_products)
# tft_model.plot_predictions(sample_products)

# # 특정 상품의 다음 주 예측
# tft_model.predict_next_week(sample_products[0])

/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
  super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerato

TemporalFusionTransformer(
  	"attention_head_size":               1
  	"categorical_groups":                {}
  	"causal_attention":                  True
  	"dataset_parameters":                {'time_idx': 'time_idx', 'target': 'week_purchase_cnt', 'group_ids': ['product_id'], 'weight': None, 'max_encoder_length': 24, 'min_encoder_length': 12, 'min_prediction_idx': 0, 'min_prediction_length': 1, 'max_prediction_length': 1, 'static_categoricals': ['product_id'], 'static_reals': ['encoder_length'], 'time_varying_known_categoricals': [], 'time_varying_known_reals': ['time_idx', 'price', 'review_cnt', 'wish_cnt', 'sixMothRatio(puchase_cnt/review_cnt)', 'week_review_count', 'average_review_score', 'category1_encoded', 'category2_encoded', 'category3_encoded', 'rolling_mean_purchase', 'rolling_std_purchase', 'week_num', 'month', 'month_sin', 'month_cos', 'week_sin', 'week_cos'], 'time_varying_unknown_categoricals': [], 'time_varying_unknown_reals': ['week_purchase_cnt'], 'variable_groups

TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`

## 출력

In [None]:
Unique product_ids in training:
['81182199627' '82224112824' '7781174622' '81432925718' '83439829432'
 '82855381541' '11659574428' '84366233772' '7506929597' '84759655602'
 '83715863755' '12261313683' '81432917824' '82256221887' '86168008312'
 '85946192170' '80036527278' '83018697152' '82748436013' '9061388274'
 '82035499213' '85747878342' '13381025014' '81769963134' '83586891935'
 '83496016535' '82518325576' '85219719569' '86074262188' '81883717746'
 '80150432791' '83196162009' '86543062469' '12549294732' '82512029192'
 '82343766518' '6766327742' '82245440734' '84338752005' '86520272366'
 '82697228937' '85635072114' '84266140607' '81592586014' '86564530041'
 '82473335098' '84630569977' '82155387209' '85556153151' '84417974537'
 '86011600522' '82987858627' '10990197556' '86552777853' '81490308013'
 '82374636770' '82518358819' '86316244276' '84938641455' '83396511185'
 '83452559067' '11985210471' '82442958695' '85113603549' '81743617172'
 '86596962625' '83989019167' '83058672505' '80151600598' '85489125061'
 '86744054774' '86716255500' '86734323610' '82116389223' '82094228515'
 '9694820143' '86575042896' '86001457602' '13147435734' '10194331314'
 '10129739316' '11831091502' '81162882669' '9879202686' '80179015780'
 '80437478266' '80585484638' '81011449761' '81022718754' '81376726447'
 '81307567948' '11545229477' '12099494038' '80998837833' '85243095782'
 '82431335523' '83032577198' '84482901377' '83722394844' '84575983577'
 '84590609622' '83060973890' '85577317693' '83124441105' '84651414760'
 '83148914744' '82224200414' '8866795532' '10583217322' '82375566008'
 '83916701408' '82629272072' '83969946618' '82803508510' '82554079402'
 '83835508396' '82493840170' '8747453599' '84253456764' '85622876945'
 '82909794326' '86196224969' '82493839762' '86380103731' '86253563442'
 '82109132933' '86718087741' '84920211651' '83355904195' '85338623974'
 '86768411788' '86766394229' '85022115898' '83342948311' '82194575327'
 '85080929940' '84878502957' '83450413204' '85141197525' '86677200634'
 '81898265092' '86003052939' '82167582057' '86779626334' '83469823485'
 '82161091686' '82154885060' '84878420163' '85903271022' '84852924321'
 '83353405499' '83729421264' '81165039187' '83817222587' '83844563428'
 '85676427375' '85827942083' '83886328887' '80137720714' '80070124325'
 '83459206628' '80538088436' '83595717401' '81076354187' '81072587300'
 '85750841326' '81044807932' '83588503201' '81037586677' '81029307872'
 '83588420791' '85786543730' '83582241229' '83565897543' '80916609335'
 '80909194289' '80771696270' '83544700181' '80471449409' '80466364794'
 '80447378951' '80260009881' '85765526144' '83892959540' '85465017384'
 '84861658923' '84773371122' '11779789592' '84739735201' '11913737888'
 '11976178121' '84684925054' '85394939977' '10651348891' '10790326479'
 '10839131148' '11076937403' '11199076229' '85008754353' '11218721969'
 '11421983102' '11598092065' '85050087425' '11239882600' '84595711640'
 '84240189171' '84223183690' '84195356355' '84051723307' '13399566549'
 '80027880899' '83988837837' '83956059687' '83945044942' '85637826843'
 '83944786359' '83937248946' '85603598515' '84593454966' '85578911636'
 '84486897225' '85602479174' '12402425757' '12406570359' '84371446205'
 '84345558721' '12664919543' '12823927460' '83442851433' '82261431626'
 '82250264065' '82244034272' '82243364662' '82271542438' '82239755179'
 '82235144540' '82213682976' '86634985389' '82203873495' '82201860750'
 '82198829607' '82237266787' '82278485824' '82280680602' '82413203779'
 '82409708164' '82396216124' '86512643840' '82386649071' '82371516080'
 '82355482914' '82308116319' '82298266676' '82293068454' '82289916004'
 '82288244906' '82196153989' '82182472219' '81944633724' '86750799549'
 '81251971550' '81817765030' '81743662525' '81683115256' '81500447696'
 '81950814342' '82423604772' '81977652281' '82177741646' '86695180428'
 '82141150395' '82114593556' '82100081230' '82084083327' '82075021040'
 '82057321861' '82048947651' '83441947093' '86512422117' '82439360120'
 '83165781839' '83154350949' '83148737984' '83145373731' '86015286906'
 '83168990735' '83046474193' '83027433275' '83026221734' '86080579187'
 '83176583239' '83185458721' '83423045215' '83412461028' '83405521509'
 '83379992650' '83357689027' '85994962324' '83339174174' '83265103095'
 '83242599500' '83227652369' '83195131633' '82978885242' '82964195824'
 '82938287668' '82584169242' '86307626335' '82579351820' '82506716986'
 '82498044964' '86361182801' '82492189459' '82484077890' '82450880093'
 '82612392967' '86258398286' '82644613191' '82923956263' '82920847039'
 '82904703004' '82888223999' '86238161590' '82839137204' '82797428689'
 '86238199436' '82731789250' '82725368080' '82663901180' '82658175350'
 '10316027559' '6356018199' '9645322788' '10095244123' '8808295228'
 '10022619427' '9508439307' '10276314469' '8597689401' '8974847230'
 '8678210047' '10182440260' '8218195764' '6617143657' '6485255842'
 '9555749627' '86865818817' '86892789310' '86340219528' '82617930980'
 '86947868795' '86615650299' '11195202991' '86931391148' '86985198344'
 '87034617231' '87002755797' '85674849480' '87077078211' '87134721472'
 '13154622884' '87154257622' '87156542226' '87171220291' '87207387394'
 '87044616207' '83021652171' '87194315468' '87256966784' '87297137510'
 '87300027307' '87354688430' '87416266422' '87434901086' '87233647417'
 '84739899624' '86984516623' '87477209865' '82366359065' '10857366633'
 '82330465832' '87612495327' '87631394749' '87616961155' '87642859084'
 '81999320307' '9270155631' '87604505038' '85829941612' '87647012257'
 '87603328228' '87528396162' '87664970794' '87703272849' '87730178392'
 '83890388520' '86644487746' '87737396955' '87730829177' '87958278370'
 '87937248809' '87965476173' '87149885404' '87844396267' '87962457602'
 '87952389253' '87921020002' '88012479244' '87991745133' '87896780670'
 '87817018176' '87988876731' '88038712401' '88081241709' '82562973333'
 '88098187751' '88113889505' '88042110947' '88042074458' '87226364197'
 '88117899204' '88119947030' '88218200275' '88100622609' '88208430684']

Unique product_ids in validation:
['82439360120' '82100081230' '82562973333' '83227652369' '86307626335'
 '87730829177' '84938641455' '83412461028' '85022115898' '80771696270'
 '86677200634' '13147435734' '6356018199' '83423045215' '82923956263'
 '85008754353' '81376726447' '83242599500' '82167582057' '82484077890'
 '12823927460' '85338623974' '80916609335' '85827942083' '87603328228'
 '81432917824' '80036527278' '87965476173' '82109132933' '83439829432'
 '82518358819' '82473335098' '12664919543' '82177741646' '87154257622'
 '86984516623' '83441947093' '6766327742' '82386649071' '82584169242'
 '83937248946' '84482901377' '82239755179' '87233647417' '10990197556'
 '8597689401' '7781174622' '83027433275' '83944786359' '82035499213'
 '82644613191' '82293068454' '84417974537' '81883717746' '86238161590'
 '86543062469' '88119947030' '83145373731' '80466364794' '85635072114'
 '84371446205' '83945044942' '82658175350' '85578911636' '82289916004'
 '86575042896' '11545229477' '82243364662' '83032577198' '87958278370'
 '86931391148' '83956059687' '82298266676' '11598092065' '87226364197'
 '87844396267' '86520272366' '10839131148' '86985198344' '83890388520'
 '82629272072' '86750799549' '82748436013' '84575983577' '81769963134'
 '83148914744' '83892959540' '82235144540' '10857366633' '82330465832'
 '83026221734' '11659574428' '82048947651' '80179015780' '83588420791'
 '85637826843' '82308116319' '82237266787' '8678210047' '81817765030'
 '86892789310' '84486897225' '83916701408' '83148737984' '85577317693'
 '84590609622' '80260009881' '87616961155' '84051723307' '85602479174'
 '11239882600' '83060973890' '11195202991' '84195356355' '82250264065'
 '82725368080' '81950814342' '82271542438' '87703272849' '84266140607'
 '81029307872' '11218721969' '86734323610' '86564530041' '11199076229'
 '83046474193' '82261431626' '85603598515' '84223183690' '83058672505'
 '84253456764' '82697228937' '87297137510' '88038712401' '81977652281'
 '80437478266' '82256221887' '84240189171' '9694820143' '81898265092'
 '86015286906' '86238199436' '83124441105' '87952389253' '82731789250'
 '83969946618' '8866795532' '82288244906' '84366233772' '86011600522'
 '11421983102' '82244034272' '83988837837' '81037586677' '87300027307'
 '82280680602' '82278485824' '80447378951' '81999320307' '87256966784'
 '84345558721' '82663901180' '86744054774' '11076937403' '85622876945'
 '81022718754' '83989019167' '82245440734' '86552777853' '9555749627'
 '84338752005' '81944633724' '88184129427' '86596962625' '87354688430'
 '83342948311' '9879202686' '86168008312' '82075021040' '82375566008'
 '87962457602' '83595717401' '82371516080' '82224112824' '87817018176'
 '87434901086' '87937248809' '86615650299' '81072587300' '84595711640'
 '83817222587' '87034617231' '85556153151' '84651414760' '85489125061'
 '11831091502' '11985210471' '88242709418' '81683115256' '82797428689'
 '81076354187' '85676427375' '11976178121' '82374636770' '83018697152'
 '82213682976' '83168990735' '87194315468' '8808295228' '88113889505'
 '81743617172' '10651348891' '88100622609' '82612392967' '9061388274'
 '83729421264' '81011449761' '83715863755' '86512643840' '84630569977'
 '83886328887' '10583217322' '82617930980' '11913737888' '86718087741'
 '84684925054' '83165781839' '83835508396' '81743662525' '83154350949'
 '82203873495' '82355482914' '80471449409' '11779789592' '84593454966'
 '87207387394' '86003052939' '87647012257' '82987858627' '8218195764'
 '88042074458' '8974847230' '86253563442' '86766394229' '82084083327'
 '10790326479' '81044807932' '85674849480' '82343766518' '82057321861'
 '83021652171' '80151600598' '83844563428' '83722394844' '82803508510'
 '85747878342' '86074262188' '7506929597' '87416266422' '80150432791'
 '82366359065' '83588503201' '84739735201' '82224200414' '87896780670'
 '81182199627' '80998837833' '86947868795' '82904703004' '86001457602'
 '83176583239' '82978885242' '80909194289' '80538088436' '80137720714'
 '82518325576' '87077078211' '83339174174' '87002755797' '82909794326'
 '85946192170' '87737396955' '9508439307' '80585484638' '88098187751'
 '87642859084' '86865818817' '83196162009' '80070124325' '82839137204'
 '81162882669' '85994962324' '82855381541' '82554079402' '82938287668'
 '88081241709' '82964195824' '83265103095' '86258398286' '83185458721'
 '81165039187' '82920847039' '80027880899' '86196224969' '83195131633'
 '86080579187' '82579351820' '88042110947' '87044616207' '82888223999'
 '82409708164' '87991745133' '84920211651' '10182440260' '82182472219'
 '84878502957' '86716255500' '86644487746' '85394939977' '88117899204'
 '87528396162' '82194575327' '84878420163' '10194331314' '84861658923'
 '6617143657' '86695180428' '85219719569' '10022619427' '82141150395'
 '85141197525' '82154885060' '85113603549' '82116389223' '87612495327'
 '87604505038' '82155387209' '10095244123' '85080929940' '87664970794'
 '82161091686' '85050087425' '82114593556' '10129739316' '85243095782'
 '82094228515' '86634985389' '87988876731' '88012479244' '84852924321'
 '82196153989' '84773371122' '85465017384' '87477209865' '10276314469'
 '82198829607' '84759655602' '10316027559' '84739899624' '82201860750'
 '12549294732' '87171220291' '86512422117' '86768411788' '87149885404'
 '82493839762' '83405521509' '86779626334' '83396511185' '87921020002'
 '82493840170' '87134721472' '83469823485' '6485255842' '81307567948'
 '87156542226' '83586891935' '86380103731' '83496016535' '12402425757'
 '83544700181' '82423604772' '81500447696' '82413203779' '83565897543'
 '87730178392' '85750841326' '86340219528' '81592586014' '83582241229'
 '85829941612' '82492189459' '81490308013' '82396216124' '85765526144'
 '82431335523' '12099494038' '12261313683' '82498044964' '9645322788'
 '81251971550' '87631394749' '83355904195' '88218200275' '13399566549'
 '88206700121' '81432925718' '82506716986' '83353405499' '86361182801'
 '83450413204' '85786543730' '82512029192' '82442958695' '13381025014'
 '83452559067' '83379992650' '13154622884' '85903271022' '9270155631'
 '8747453599' '83459206628' '12406570359' '88208430684' '82450880093'
 '83442851433' '83357689027' '86316244276' '88222671639' '88291977922'
 '87896784213' '88284816214' '88351942304' '88303490649' '88321577019'
 '88347463796' '88283419814' '88334988204' '88380500239' '88401777165'
 '88388649996' '88388890335' '88415339810' '88415318369' '88388533454'
 '88356536951' '88411622436' '88438439266' '88378913358' '88415191727'
 '88453896846' '88473568229' '88459191933']