In [1]:
import os
import sys
from pathlib import Path

from dotenv import load_dotenv

# Locate the project root so that `src.*` imports and relative paths resolve.
# On a fresh kernel the working directory is a sub-directory of the project,
# so the root is its parent. This cell also chdir()s into the root, which means
# a re-run would otherwise take the parent of the root; when the working
# directory already contains the `src` package directory, treat it as the root.
_cwd = Path.cwd()
project_root = _cwd if (_cwd / "src").is_dir() else _cwd.parent

# Keep the root at the front of sys.path without stacking duplicates on re-run.
_root_str = str(project_root)
if not sys.path or sys.path[0] != _root_str:
    sys.path.insert(0, _root_str)

os.chdir(project_root)

# Load variables from a .env file; the boolean result is shown as the cell output.
load_dotenv()

True

In [2]:
import torch
from torch.utils.data import DataLoader

from src.data.pipeline_with_news import get_datasets_with_news
from src.models.transformer_model_with_news import StockTransformerWithNews
from src.training.trainer_with_news import TrainerWithNews
from src.utils.config import load_config

# Load the project configuration and enable news in it for this run.
config = load_config()
config.data.use_news = True

# Echo the news flag and the ticker list.
print(
    f"News enabled: {config.data.use_news}",
    f"Tickers: {config.data.tickers}",
    sep="\n",
)

News enabled: True
Tickers: ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'META']


In [3]:
# Load the train/val/test datasets together with the feature column list,
# using the news-aware pipeline.
print("Loading datasets with news embeddings...")
train_dataset, val_dataset, test_dataset, feature_columns = get_datasets_with_news(
    config=config,
    use_news_cache=True,
    force_refresh_news=False,  # Set to True to refresh news cache
)

# Summarize the split sizes and the number of feature columns.
print("✅ Datasets loaded!")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val: {len(val_dataset)} samples")
print(f"  Test: {len(test_dataset)} samples")
print(f"  Features: {len(feature_columns)}")

Loading datasets with news embeddings...
Зареждане на локален dataset от: data\raw\sp500_stocks_data.parquet
Заредено! Размер: (1048575, 23)
Филтриране на данни... Първоначален размер: (1048575, 25)
Филтрирано! Финален размер: (10068, 25)

📰 Extracting news embeddings...
Loading FinBERT encoder...
Loading FinBERT model: ProsusAI/finbert
Device: cpu
✅ FinBERT loaded successfully
📦 Loading cached news embeddings for AAPL from data\processed\news_cache\AAPL_20100105_20200930.pkl
📦 Loading cached news embeddings for AMZN from data\processed\news_cache\AMZN_20100105_20200930.pkl
📦 Loading cached news embeddings for GOOGL from data\processed\news_cache\GOOGL_20100105_20200930.pkl
📦 Loading cached news embeddings for MSFT from data\processed\news_cache\MSFT_20100105_20200930.pkl
✅ News embeddings extracted: (4, 770)
✅ Datasets loaded!
  Train: 6876 samples
  Val: 1426 samples
  Tes

In [4]:
# Build both loaders from one shared set of options; only shuffling differs
# (enabled for training, disabled for validation).
_loader_options = dict(
    batch_size=config.training.batch_size,
    num_workers=0,
    pin_memory=False,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

# Pull a single training batch to check the news embeddings.
sample_batch = next(iter(train_loader))
print(f"Batch format: {len(sample_batch)} items")
if len(sample_batch) == 3:
    x, news_emb, y = sample_batch
    _news_shape = None if news_emb is None else news_emb.shape
    print(f"  x shape: {x.shape}")
    print(f"  news_emb shape: {_news_shape}")
    print(f"  y shape: {y.shape}")

Batch format: 3 items
  x shape: torch.Size([64, 60, 34])
  news_emb shape: torch.Size([64, 768])
  y shape: torch.Size([64, 1])


In [5]:
# Create enhanced model with news.
# The news settings are named once so the summary printed below always reports
# exactly what was passed to the model constructor.
NEWS_EMBEDDING_DIM = 768  # FinBERT embedding dimension
NEWS_FUSION_METHOD = "concat"  # or "add"

model = StockTransformerWithNews(
    input_dim=len(feature_columns),
    news_embedding_dim=NEWS_EMBEDDING_DIM,
    d_model=config.model.d_model,
    n_heads=config.model.n_heads,
    n_layers=config.model.n_layers,
    d_ff=config.model.d_ff,
    dropout=config.model.dropout,
    activation=config.model.activation,
    prediction_horizon=config.data.prediction_horizon,
    news_fusion_method=NEWS_FUSION_METHOD,
)

print("✅ Model created!")
print(f"  Input dim: {len(feature_columns)}")
print(f"  News embedding dim: {NEWS_EMBEDDING_DIM}")
print(f"  Fusion method: {NEWS_FUSION_METHOD}")

✅ Model created!
  Input dim: 34
  News embedding dim: 768
  Fusion method: concat


In [6]:
# Create the trainer from the model, the config and both data loaders.
trainer = TrainerWithNews(
    model=model,
    config=config,
    train_loader=train_loader,
    val_loader=val_loader,
)

print("✅ Trainer created!")

✅ Trainer created!


In [7]:
# Train the model; the returned history is indexed by 'best_val_loss' below.
history = trainer.train()

# Closing banner with the best validation loss.
print(f"\n{'='*60}")
print("Обучението завърши!")
print(f"Best validation loss: {history['best_val_loss']:.6f}")
print(f"{'='*60}")

Започване на обучение за 50 epochs...

Epoch 1/50... Train loss: 0.075509 Val loss: 0.068093 (18.8s)

Epoch 2/50... Train loss: 0.019546 Val loss: 0.046746 (18.4s)

Epoch 3/50... Train loss: 0.016253 Val loss: 0.053059 (18.3s)

Epoch 4/50... Train loss: 0.014535 Val loss: 0.045858 (17.9s)

Epoch 5/50... Train loss: 0.013542 Val loss: 0.026436 (18.4s)

Epoch 6/50... Train loss: 0.013780 Val loss: 0.027150 (18.4s)

Epoch 7/50... Train loss: 0.013078 Val loss: 0.030975 (18.2s)

Epoch 8/50... Train loss: 0.012565 Val loss: 0.022738 (18.1s)

Epoch 9/50... Train loss: 0.012717 Val loss: 0.025734 (18.1s)

Epoch 10/50... Train loss: 0.012480 Val loss: 0.021866 (17.8s)

Epoch 11/50... Train loss: 0.014767 Val loss: 0.023496 (18.0s)

Epoch 12/50... Train loss: 0.012088 Val loss: 0.019637 (18.0s)

Epoch 13/50... Train loss: 0.011921 Val loss: 0.025595 (19.2s)

Epoch 14/50... Train loss: 0.011808 Val loss: 0.017069 (18.4s)

Epoch 15/50... Train loss: 0.011347 Val loss: 0.01950

In [8]:
# Save enhanced model
import time

from src.utils import config as _cfg

# Checkpoint location: <PROJECT_ROOT>/<config.paths.models_dir>/<checkpoint_name>.
checkpoint_name = "best_model_with_news.pt"
checkpoint_path = _cfg.PROJECT_ROOT / config.paths.models_dir / checkpoint_name
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

print(f"\nЗапазване на модела в: {checkpoint_path}")
# NOTE(review): the weights written here are whatever `model` holds after
# train() returned, while 'score' is the best validation loss and the file is
# named "best_model" - confirm the trainer restores the best weights.
# NOTE(review): 'config' stores the live config object, so loading this
# checkpoint presumably needs its class importable and full unpickling
# (not weights-only loading) - confirm against the loading code.
torch.save({
    'epoch': len(history['train_losses']) - 1,  # zero-based index of the last recorded epoch
    'model_state_dict': model.state_dict(),
    'score': history['best_val_loss'],
    'model_type': 'StockTransformerWithNews',
    'config': config,
}, checkpoint_path)

# Report size and modification time of the written file.
if checkpoint_path.exists():
    _ckpt_stat = checkpoint_path.stat()  # one stat() call serves both size and mtime
    file_size = _ckpt_stat.st_size / (1024 * 1024)  # MB
    mtime = time.ctime(_ckpt_stat.st_mtime)
    print("✓ Файлът е запазен успешно!")
    print(f"  Размер: {file_size:.2f} MB")
    print(f"  Модифициран: {mtime}")


Запазване на модела в: C:\Users\vyoto\OneDrive\Desktop\CODE STUFF\Stock price prediction\models\checkpoints\best_model_with_news.pt
✓ Файлът е запазен успешно!
  Размер: 5.89 MB
  Модифициран: Fri Feb 13 20:30:30 2026
