Skip to content

2026Graduation-Work/Stock-Prediction

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

6 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Hybrid Stock Prediction Platform

TFT + DRL 기반 심리지수 반영 주가 예측 및 매매 의사결정 시스템


모델 구조 상세 설명

models/tft_model.py - Temporal Fusion Transformer

이 모델은 사전학습 모델을 가져와 파인튜닝한 형태가 아니라, Temporal Fusion Transformer 구조를 PyTorch로 직접 구현한 시계열 예측 모델입니다.
입력으로는 일정 길이의 과거 가격 시계열, 기술적 지표, 심리지수 특성을 받고, 출력으로는 향후 여러 시점의 가격 분위수 예측과 해석용 가중치를 함께 제공합니다.
즉, "앞으로 오를지 내릴지"만 보는 단순 분류기가 아니라, 얼마나 변할 가능성이 있는지와 어떤 입력이 예측에 크게 작용했는지까지 확인할 수 있도록 설계되어 있습니다.

핵심 파이프라인은 아래 5단계입니다.

  1. Variable Selection Network (VSN)

    • 각 시점마다 입력 변수별 중요도를 softmax 가중치로 계산해, 현재 시장 상황에서 더 중요한 특성에 더 큰 비중을 둡니다.
    • 이 프로젝트에서는 기술적 지표 20개와 심리지수 7개를 함께 넣기 때문에, 단일 지표에 고정적으로 의존하지 않고 시장 국면에 따라 중요한 정보를 동적으로 고를 수 있습니다.
  2. Gated Residual Network (GRN)

    • 변수 선택 단계와 후속 표현 학습 단계에서 반복적으로 사용되는 핵심 블록으로, 비선형 변환과 게이트를 통해 정보 흐름을 조절합니다.
    • 잔차 연결과 LayerNorm이 함께 들어가 있어 학습 안정성을 높이고, 필요 없는 정보는 억제하면서 중요한 패턴은 보존하도록 돕습니다.
  3. LSTM Encoder (3층)

    • 선택된 시계열 표현을 순차적으로 읽으면서 최근 며칠의 흐름, 완만한 추세 변화, 단기 반등/하락 패턴 같은 국소적 시간 구조를 인코딩합니다.
    • TFT에서 LSTM은 "시간 순서" 자체를 먼저 안정적으로 요약하는 역할을 하며, 이후 attention이 더 긴 범위의 관계를 보완합니다.
  4. Temporal Self-Attention (Multi-head = 4)

    • LSTM이 만든 시계열 표현 위에서 시점 간 연관성을 다시 계산해, 장기 의존성과 패턴 반복을 포착합니다.
    • causal mask를 적용해 미래 시점 정보가 현재 예측에 섞이지 않도록 막기 때문에, 실제 예측 환경과 동일한 조건으로 학습됩니다.
    • attention weight를 함께 반환하므로, 어떤 과거 구간을 특히 참고했는지 해석에 활용할 수 있습니다.
  5. Quantile Output

    • 마지막 은닉 표현으로부터 단일 값 하나가 아니라 10%, 50%, 90% 분위수를 동시에 예측합니다.
    • 50% 분위수는 대표 예측값으로, 10%~90% 구간은 예측 불확실성을 표현하는 신뢰 범위로 해석할 수 있습니다.
    • 이 결과는 이후 DRL 에이전트의 상태 입력으로도 사용되어, 단순 방향성뿐 아니라 예측 폭까지 의사결정에 반영됩니다.

models/sentiment_encoder.py - 감성 인코더

이 모듈은 시장 심리를 수치화해서 TFT 입력 특성으로 공급하는 역할을 합니다.
외부 심리 지표를 그대로 가져오는 데서 끝나지 않고, 가격과 거래량 패턴에서 추론한 내부 심리 신호까지 함께 생성해 더 풍부한 입력을 만듭니다.

  • fear_greed_index

    • alternative.me API에서 Fear & Greed Index를 수집해 시장 전반의 공포/탐욕 수준을 반영합니다.
    • API 호출이 실패하면 모멘텀, 변동성, 거래량, 52주 고저 위치를 이용해 대체용 합성 FGI를 생성합니다.
  • sentiment_score 계열

    • RSI와 50일 이동평균 대비 괴리율을 이용해 가격 기반 심리 점수를 계산합니다.
    • 여기에 7일/14일 이동평균, 5일 변화량, 공포-탐욕 구간 분류, 복합 심리지수까지 추가해 총 7개의 심리 특성을 만듭니다.
  • FinBERT 선택 로딩

    • ProsusAI/finbert를 사용할 수 있도록 구현되어 있어, 향후 뉴스 기사나 공시 텍스트 감성 분석으로 확장할 여지를 남겨두었습니다.
    • 현재 기본 설정은 use_finbert=False라서 텍스트 입력 없이도 동작하며, 학습 파이프라인은 가격 기반 심리지수 중심으로 구성됩니다.

models/feature_engine.py - 기술적 지표 엔진

이 모듈은 원본 OHLCV 데이터를 모델이 바로 쓰기 좋은 수치 특성으로 변환하는 전처리 엔진입니다.
별도 학습 없이 지표 공식만으로 특성을 생성하므로, 입력 데이터 품질과 지표 설계가 전체 성능의 기반이 됩니다.

  • 가격 기반 지표

    • RSI, MACD, MACD Signal, MACD Histogram, Bollinger Band 상/하단 및 폭, SMA(20/50), EMA(12/26), ATR을 계산합니다.
    • 추세, 과매수/과매도, 평균 회귀 가능성, 변동성 같은 서로 다른 관점의 신호를 동시에 제공합니다.
  • 거래량 및 모멘텀 지표

    • OBV, VWAP 비율, Stochastic %K/%D를 생성해 거래량 흐름과 단기 모멘텀을 반영합니다.
  • 수익률 및 변동성 지표

    • 1일, 5일, 20일 수익률과 20일 변동성을 추가해 최근 성과와 가격 흔들림 정도를 수치화합니다.
  • 데이터 정제

    • 롤링 계산으로 생기는 초기 결측치는 제거하고 인덱스를 다시 정리해, 이후 학습 단계에서 바로 사용할 수 있는 입력 테이블을 만듭니다.

models/drl_agent.py - PPO 매매 에이전트

이 모듈은 예측 모델의 출력값을 실제 매매 행동으로 연결하는 강화학습 계층입니다.
즉, TFT가 "앞으로 어떻게 움직일 가능성이 높은가"를 추정하면, DRL 에이전트는 그 예측과 현재 포트폴리오 상태를 바탕으로 "지금 무엇을 해야 하는가"를 결정합니다.

  • 학습 환경 (TradingEnvironment)

    • 상태는 TFT가 만든 하한/중앙값/상한 예측, 복합 심리지수, 현재 보유 상태, 누적 수익률, 최근 가격 변화율 윈도우로 구성됩니다.
    • 행동 공간은 매도(0), 보유(1), 매수(2)의 3가지이며, 거래 수수료와 보유 자산 수량이 함께 반영됩니다.
  • 보상 함수

    • 기본적으로는 스텝 수익률을 보상으로 사용하되, 잦은 거래에는 패널티를 주고 큰 손실에는 추가 감점을 적용합니다.
    • 최근 수익률 분산을 고려한 Sharpe-like 보너스를 넣어, 단순 고수익보다 위험 대비 성과가 좋은 정책을 학습하도록 유도합니다.
  • PPO 학습 및 추론

    • Stable-Baselines3의 PPO를 사용해 정책을 학습하며, 학습 완료 후 관측값을 넣으면 BUY/HOLD/SELL 결정과 신뢰도를 반환합니다.
    • 환경이 단순 규칙 기반보다 더 많은 상태 정보를 활용할 수 있어, 예측값과 심리 지표를 함께 고려한 정책 최적화가 가능합니다.
  • 폴백 전략

    • Stable-Baselines3가 설치되지 않은 환경에서는 룰 기반 의사결정으로 자동 전환됩니다.
    • 이 경우 중앙 예측값과 심리지수를 가중 합산해 매수/매도/보유를 판단하므로, 최소한의 데모와 파이프라인 검증은 계속 수행할 수 있습니다.

실행 가이드

아래 순서대로 실행하면 됩니다.

1. 가상환경 생성 및 활성화

python -m venv .venv
.venv\Scripts\activate.bat

2. 라이브러리 설치

pip install --upgrade pip
pip install -r requirements.txt

3. 데이터 준비 (필요 시)

python scripts/collect_data.py

4. TFT 학습

python scripts/train_tft.py

학습 완료 시 체크포인트가 models/checkpoints 경로에 저장됩니다.

5. DRL 학습

python scripts/train_drl.py

이 단계에서 TFT 예측 결과를 기반으로 PPO 에이전트를 학습합니다.
Stable-Baselines3가 없는 환경에서는 룰 기반 전략으로 자동 폴백됩니다.

6. 웹 서버 실행

python app.py

정상 실행되면 http://localhost:5000 에서 대시보드를 확인할 수 있습니다.


4) 웹 파일 체크 포인트

대시보드 렌더링을 위해 다음 파일이 필요합니다.

  • web/templates/dashboard.html
  • web/static/css/style.css
  • web/static/js/dashboard.js

정적 파일이 없거나 경로가 어긋나면 화면 스타일/동작이 깨질 수 있습니다.


5) 프로젝트 구조

stock-prediction/
├── app.py
├── requirements.txt
├── configs/
│   └── config.yaml
├── models/
│   ├── tft_model.py
│   ├── sentiment_encoder.py
│   ├── feature_engine.py
│   ├── drl_agent.py
│   └── checkpoints/
├── data/
│   ├── data_loader.py
│   └── processed/
├── scripts/
│   ├── collect_data.py
│   ├── train_tft.py
│   └── train_drl.py
├── utils/
│   ├── metrics.py
│   └── visualizer.py
└── web/
    ├── templates/dashboard.html
    └── static/
        ├── css/style.css
        └── js/dashboard.js

About

A Stock Price Prediction Flatform that was implemented using time series transformers and reinforcement learning, and reflected in the psychological index

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors