# Distilling GPT-4o mini into Llama 3.1 8B for Stock Price Prediction: LLM-Based Forecasting (with Risk-Aware PPO Adjustment)

This notebook runs inference with Llama 3.1 8B, the student model distilled from the teacher model GPT-4o mini using sequence-level knowledge distillation.

## Framework Overview:
1. **Stage 1**: LLM-based stock price prediction using historical data, technical indicators, and sentiment analysis
2. **Stage 2**: Risk-aware PPO adjustment incorporating VaR and CVaR to refine predictions (ablation from paper)
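
Stage 2 relies on two tail-risk measures, VaR and CVaR. As a minimal sketch (not the paper's implementation), both can be estimated from a sample of returns with the historical method. The function name `historical_var_cvar`, the sign convention (losses reported as positive numbers), and the sample returns are assumptions made for illustration.

```python
import numpy as np

def historical_var_cvar(returns, alpha=0.95):
    """Return (VaR, CVaR) at confidence level alpha, with losses as positive numbers."""
    losses = -np.asarray(returns, dtype=float)   # a loss is a negative return
    var = np.quantile(losses, alpha)             # loss level exceeded with probability 1 - alpha
    tail = losses[losses >= var]                 # losses at or beyond the VaR threshold
    cvar = tail.mean()                           # average loss within that tail
    return float(var), float(cvar)

# Illustrative daily returns (made-up numbers, not taken from the dataset)
sample_returns = [0.01, -0.02, 0.005, -0.05, 0.015, -0.01, 0.02, -0.03, 0.0, 0.01]
var_95, cvar_95 = historical_var_cvar(sample_returns, alpha=0.95)
print(f"VaR(95%)={var_95:.4f}, CVaR(95%)={cvar_95:.4f}")
```

CVaR is never smaller than VaR at the same confidence level, because it averages only the losses at or beyond the VaR threshold.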

## Dataset:
- Training, validation, and test data from the `finetune_paper` directory (only the test split is loaded in this notebook)
- Stocks: AAPL, HSBC, PEP, 0700.HK (Tencent), 7203.T (Toyota)

## 1. Environment Setup and Dependencies

In [4]:
# Install required packages (run once)
!pip install -r ../requirements.txt

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [5]:
# Install Hugging Face packages (run once if using local Llama)
!pip install transformers accelerate bitsandbytes torch

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [1]:
# Import libraries
import os
import json
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# Standard library
import time
import pickle

# Environment variables
from dotenv import load_dotenv

# HTTP requests for HF endpoint
import requests

# # Machine Learning
# from sklearn.svm import SVR
# from sklearn.preprocessing import StandardScaler
# from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error
# from xgboost import XGBRegressor

# Deep Learning
import torch
import torch.nn as nn

# Reinforcement Learning
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Progress bar
from tqdm import tqdm

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

print("All libraries imported successfully!")


All libraries imported successfully!


## 2. Hugging Face Dedicated Endpoint Configuration

In [2]:
# Load environment variables
load_dotenv('../.env')

# LLM Configuration
MAX_TOKENS = 1024
TEMPERATURE = 0.0

# Hugging Face Dedicated Endpoint
HF_ENDPOINT_URL = "https://xr4if8jpbpi8884c.us-east-1.aws.endpoints.huggingface.cloud"

# Get HF token
hf_token = os.getenv('HF_TOKEN')
if not hf_token:
    raise ValueError("HF_TOKEN not found in .env file. Get token from: https://huggingface.co/settings/tokens")

print(f"✅ Hugging Face Dedicated Endpoint configured!")
print(f"   Endpoint: {HF_ENDPOINT_URL}")
print(f"   Model: Llama 3.1 8B Instruct")
print(f"   Max Tokens: {MAX_TOKENS}")
print(f"   Temperature: {TEMPERATURE}")
print(f"   Rate limits: UNLIMITED! 🎉")

✅ Hugging Face Dedicated Endpoint configured!
   Endpoint: https://xr4if8jpbpi8884c.us-east-1.aws.endpoints.huggingface.cloud
   Model: Llama 3.1 8B Instruct
   Max Tokens: 1024
   Temperature: 0.0
   Rate limits: UNLIMITED! 🎉


## 3. Data Loading and Preprocessing

In [4]:
# Load datasets
def load_jsonl(filepath):
    """Load JSONL file"""
    data = []
    with open(filepath, 'r') as f:
        for line in f:
            data.append(json.loads(line))
    return data

# Load train, val, test data
# train_data = load_jsonl('../finetune_paper/train.jsonl')
# val_data = load_jsonl('../finetune_paper/val.jsonl')
test_data = load_jsonl('../finetune_paper/test.jsonl')

# Load supervised labels
all_labels = pd.read_csv('../finetune_paper/all_supervised_price_labels.csv')

# print(f"Training samples: {len(train_data)}")
# print(f"Validation samples: {len(val_data)}")
print(f"Test samples: {len(test_data)}")
print(f"\nAll labels shape: {all_labels.shape}")
print(f"\nStocks in dataset: {all_labels['ticker'].unique()}")

Test samples: 2477

All labels shape: (12418, 16)

Stocks in dataset: ['AAPL' 'HSBC' '0700.HK' 'PEP' '7203.T']


In [5]:
# Display sample data
print("Sample test data:")
print(f"Prompt (first 500 chars): {test_data[0]['prompt'][:500]}...")
print(f"\nResponse: {test_data[0]['response']}")

print("\n" + "="*80 + "\n")
print("Sample supervised labels:")
all_labels.head()

Sample test data:
Prompt (first 500 chars): You are a financial analyst with expertise in stock market forecasting.
Your task is to analyze market data and predict the next trading day stock price.
Use historical price trends, technical indicators, and sentiment analysis to provide an informed forecast.
Ensure that your predictions are well-justified, considering multiple financial factors.

• Predicted Stock Price: The forecasted close price for the next trading day.
• Price Movement Likelihood: The likelihood of the predicted stock pric...

Response: {"predicted_close": 32.68000030517578, "likelihood": 0.9, "justification": "n/a"}


Sample supervised labels:


Unnamed: 0,Date,SMA_20,SMA_50,EMA_12,EMA_26,RSI_14,MACD,MACD_signal,MACD_hist,BB_width_20_2,headline_count,sent_compound_mean,titles_joined,next_close,confidence_proxy,ticker
0,2015-01-16 00:00:00+00:00,,,27.159062,27.234398,13.536208,-0.075335,-0.01569,-0.059645,,4.0,-0.07955,,27.18,0.5,AAPL
1,2015-01-16 00:00:00+00:00,,,45.765558,46.231136,4.645025,-0.465578,-0.348537,-0.117041,,6.0,0.308567,Which London business pays the highest busines...,45.360001,0.9,HSBC
2,2015-01-16 00:00:00+00:00,,,113.078837,109.846862,68.406756,3.231975,2.607665,0.624309,,1.0,0.0,,113.388344,0.5,0700.HK
3,2015-01-16 00:00:00+00:00,,,96.059458,95.400737,36.54659,0.658721,0.41146,0.247261,,10.0,0.08298,"Audrey P. ""Pep"" Landry Obituary January 16, 20...",97.510002,0.5,PEP
4,2015-01-19 00:00:00+00:00,,,113.126453,110.109194,70.079261,3.017259,2.689584,0.327675,,1.0,0.3612,WeChat apologizes for showering Chinese users ...,114.402382,0.5,0700.HK


In [13]:
# Parse test data for evaluation
POSITIVE_JUSTIFICATION_KEYWORDS = {
    "increase", "growth", "upward", "bullish", "positive", "gain", "improve", "strength", "rally", "optimistic"
}
NEGATIVE_JUSTIFICATION_KEYWORDS = {
    "decrease", "decline", "downward", "bearish", "negative", "loss", "drop", "weakness", "sell", "pessimistic"
}
RISK_JUSTIFICATION_KEYWORDS = {
    "volatility", "volatile", "risk", "uncertain", "uncertainty", "caution", "concern", "warning", "downside"
}

def parse_prompt_data(prompt_text):
    """Extract the ticker, date, and recent closing prices from a prompt."""
    lines = prompt_text.split('\n')
    data = {}
    
    for i, line in enumerate(lines):
        if 'TICKER:' in line:
            data['ticker'] = line.split('TICKER:')[1].strip()
        elif 'DATE:' in line:
            data['date'] = line.split('DATE:')[1].strip()
        elif 'RECENT CLOSING PRICES' in line and i + 1 < len(lines):
            # Prices are expected as comma-separated values on the following line
            prices_line = lines[i + 1]
            if prices_line.strip():
                data['recent_prices'] = [float(p.strip()) for p in prices_line.split(',') if p.strip()]
    
    return data

def safe_float(value, default=0.0) -> float:
    try:
        return float(value)
    except (TypeError, ValueError):
        return float(default)

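def extract_justification_features(justification: str) -> Dict[str, float]:
    """Compute keyword-ratio features (positive, negative, risk, polarity, log length) for a justification."""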
    base = {
        "justification_pos_ratio": 0.0,
        "justification_neg_ratio": 0.0,
        "justification_risk_ratio": 0.0,
        "justification_polarity": 0.0,
        "justification_length": 0.0,
    }
    if not justification:
        return base.copy()
    tokens = re.findall(r"[a-zA-Z']+", justification.lower())
    token_count = max(len(tokens), 1)
    pos_hits = sum(token in POSITIVE_JUSTIFICATION_KEYWORDS for token in tokens)
    neg_hits = sum(token in NEGATIVE_JUSTIFICATION_KEYWORDS for token in tokens)
    risk_hits = sum(token in RISK_JUSTIFICATION_KEYWORDS for token in tokens)
    base.update({
        "justification_pos_ratio": float(pos_hits / token_count),
        "justification_neg_ratio": float(neg_hits / token_count),
        "justification_risk_ratio": float(risk_hits / token_count),
        "justification_polarity": float((pos_hits - neg_hits) / token_count),
        "justification_length": float(np.log1p(token_count)),
    })
    return base

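# Parse test data
# Note: each item's 'response' holds the supervised target, so its
# 'predicted_close' field is the actual next-day close, not a model output.
test_parsed = []
for item in test_data:
    parsed = parse_prompt_data(item['prompt'])
    response = json.loads(item['response'])
    parsed['predicted_close'] = response['predicted_close']
    parsed['likelihood'] = response['likelihood']
    test_parsed.append(parsed)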

test_df = pd.DataFrame(test_parsed)
print(f"Parsed test data shape: {test_df.shape}")
test_df.head()


Parsed test data shape: (2477, 4)


Unnamed: 0,ticker,date,predicted_close,likelihood
0,HSBC,2023-01-03,32.68,0.9
1,0700.HK,2023-01-03,342.870056,0.5
2,PEP,2023-01-03,178.970001,0.9
3,AAPL,2023-01-03,126.360001,0.5
4,7203.T,2023-01-04,1807.5,0.7
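
To make the justification features concrete, here is a standalone restatement of the keyword-ratio computation applied to one sentence. It is a sketch for illustration: the helper name `keyword_ratios`, the shortened feature keys, and the example sentence are assumptions, while the keyword sets and the tokenizer regex are copied from the cell above.

```python
import math
import re

POSITIVE = {"increase", "growth", "upward", "bullish", "positive", "gain", "improve", "strength", "rally", "optimistic"}
NEGATIVE = {"decrease", "decline", "downward", "bearish", "negative", "loss", "drop", "weakness", "sell", "pessimistic"}
RISK = {"volatility", "volatile", "risk", "uncertain", "uncertainty", "caution", "concern", "warning", "downside"}

def keyword_ratios(text):
    """Keyword hit counts divided by token count, plus log(1 + token count)."""
    tokens = re.findall(r"[a-zA-Z']+", text.lower())
    n = max(len(tokens), 1)                      # avoid division by zero on empty text
    pos = sum(t in POSITIVE for t in tokens)
    neg = sum(t in NEGATIVE for t in tokens)
    risk = sum(t in RISK for t in tokens)
    return {
        "pos_ratio": pos / n,
        "neg_ratio": neg / n,
        "risk_ratio": risk / n,
        "polarity": (pos - neg) / n,
        "length": math.log1p(n),
    }

# Hypothetical justification text, written for this example
example = "Bullish MACD and upward trend, but volatility is a risk."
features = keyword_ratios(example)
print(features)
```

The sentence has 10 tokens with two positive hits ("bullish", "upward") and two risk hits ("volatility", "risk"). Matching is exact, so inflected forms such as "gains" or "risks" are not counted.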


## 4. Stage 1: LLM-Based Stock Price Prediction

In [15]:
def llm_predict_stock_price(prompt: str) -> Dict:
    """Use Hugging Face Dedicated Endpoint to predict stock price"""
    try:
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {hf_token}",
            "Content-Type": "application/json"
        }
        
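        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": MAX_TOKENS,
                # A TEMPERATURE of 0 is replaced by 0.1 here, so the value sent
                # to the endpoint differs from the configured 0.0
                "temperature": TEMPERATURE if TEMPERATURE > 0 else 0.1
            }
        }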
        
        response = requests.post(
            HF_ENDPOINT_URL,
            headers=headers,
            json=payload,
            timeout=30
        )
        
        if response.status_code != 200:
            print(f"HF Endpoint Error: {response.status_code} - {response.text}")
            return {"predicted_close": None, "likelihood": 0.5, "justification": f"API Error: {response.status_code}"}
        
        result_data = response.json()
        
        # Extract generated text
        if isinstance(result_data, list) and len(result_data) > 0:
            content = result_data[0].get('generated_text', '')
        elif isinstance(result_data, dict):
            content = result_data.get('generated_text', result_data.get('text', ''))
        else:
            content = str(result_data)
        
        # Parse JSON response
        if '{' in content and '}' in content:
            json_start = content.index('{')
            json_end = content.rindex('}') + 1
            json_str = content[json_start:json_end]
            
            try:
                result = json.loads(json_str)
                
                # Validate required fields
                if 'predicted_close' not in result:
                    result['predicted_close'] = None
                if 'likelihood' not in result:
                    result['likelihood'] = 0.5
                if 'justification' not in result:
                    result['justification'] = ''
                    
                return result
            except json.JSONDecodeError as je:
                print(f"JSON parse error, attempting manual extraction: {je}")
                
                # Try to extract values manually
                pred_match = re.search(r'"predicted_close"\s*:\s*([0-9.]+)', json_str)
                likelihood_match = re.search(r'"likelihood"\s*:\s*([0-9.]+)', json_str)
                
                if pred_match:
                    return {
                        "predicted_close": float(pred_match.group(1)),
                        "likelihood": float(likelihood_match.group(1)) if likelihood_match else 0.5,
                        "justification": "Manually extracted from malformed JSON"
                    }
                else:
                    return {"predicted_close": None, "likelihood": 0.5, "justification": f"JSON parse error: {str(je)}"}
        else:
            return {"predicted_close": None, "likelihood": 0.5, "justification": "No JSON found in response"}
            
    except Exception as e:
        print(f"Error in HF endpoint prediction: {e}")
        return {"predicted_close": None, "likelihood": 0.5, "justification": str(e)}

# Test HF Endpoint
print("🧪 Testing Hugging Face Dedicated Endpoint with a sample prediction...")
print("="*80)
sample_prompt = test_data[0]['prompt']
print("Sample prompt:")
print(sample_prompt + "...\n")

print("⏰ Generating prediction...")
start_time = time.time()

llm_result = llm_predict_stock_price(sample_prompt)
elapsed = time.time() - start_time

print(f"\n⏱️ Inference time: {elapsed:.2f} seconds")
print("\nHF Endpoint Prediction Result:")
print(json.dumps(llm_result, indent=2))

actual_response = json.loads(test_data[0]['response'])
print(f"\nActual Target Price: {actual_response['predicted_close']}")
print(f"\n✅ HF Dedicated Endpoint is working!")
print(f"💡 Speed: ~{elapsed:.1f}s per prediction")
print(f"💡 No rate limits - run unlimited predictions!")
print("="*80)

🧪 Testing Hugging Face Dedicated Endpoint with a sample prediction...
Sample prompt:
You are a financial analyst with expertise in stock market forecasting.
Your task is to analyze market data and predict the next trading day stock price.
Use historical price trends, technical indicators, and sentiment analysis to provide an informed forecast.
Ensure that your predictions are well-justified, considering multiple financial factors.

• Predicted Stock Price: The forecasted close price for the next trading day.
• Price Movement Likelihood: The likelihood of the predicted stock price.
• Justification: Provide an explanation for the predicted stock price and the corresponding likelihood, considering the following:
  - Historical market data (e.g., recent closing prices).
  - Technical indicators (e.g., SMA, EMA, RSI, MACD, Bollinger Bands).
  - Sentiment analysis (e.g., news sentiment, market sentiment).

Please weigh these signals and justify the predicted stock price.

TICKER: HS
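
The request in `llm_predict_stock_price` always sends a positive temperature. If the endpoint runs Hugging Face Text Generation Inference (an assumption, so check your endpoint's container), its `parameters` object also accepts `do_sample` and `return_full_text`, which allow greedy decoding and a completion-only response. The helper `build_payload` below is a hypothetical sketch, not part of the notebook.

```python
def build_payload(prompt, max_new_tokens=1024, temperature=0.0):
    """Build a text-generation request body; a temperature of 0 selects greedy decoding."""
    params = {
        "max_new_tokens": max_new_tokens,
        "return_full_text": False,    # request the completion only, without the echoed prompt
    }
    if temperature > 0:
        params["do_sample"] = True
        params["temperature"] = temperature
    else:
        params["do_sample"] = False   # greedy decoding; no temperature is sent
    return {"inputs": prompt, "parameters": params}

payload = build_payload("TICKER: AAPL", max_new_tokens=64)
print(payload)
```

Requesting the completion only also keeps any braces in the prompt out of the text that the `{` to `}` slicing inspects.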

### 4.1 Run LLM Inference on Test Data

Generate predictions for test data (used for final evaluation).

In [16]:
# Run LLM predictions on test data with checkpointing
import time

# Checkpoint file to save progress
checkpoint_file = '../results/llm_predictions_justification_checkpoint.json'

# Load existing checkpoint if available
if os.path.exists(checkpoint_file):
    print(f"Loading existing checkpoint from {checkpoint_file}")
    with open(checkpoint_file, 'r') as f:
        checkpoint = json.load(f)
    llm_predictions = checkpoint['predictions']
    actual_prices = checkpoint['actual_prices']
    llm_results = checkpoint.get('llm_results', [])
    start_idx = checkpoint['last_idx'] + 1
    print(f"Resuming from index {start_idx}/{len(test_data)}")
else:
    llm_predictions = []
    actual_prices = []
    llm_results = []
    start_idx = 0
    print("Starting fresh LLM predictions...")

# Run LLM predictions with rate limiting and checkpointing
print(f"Generating LLM predictions for {len(test_data)} samples...")
print("This may take a while due to API rate limits...")

for idx in tqdm(range(start_idx, len(test_data)), desc="LLM Inference"):
    item = test_data[idx]
    
    try:
        # Get LLM prediction
        llm_result = llm_predict_stock_price(item['prompt'])
        
        # Store full LLM result
        llm_results.append(llm_result)
        
        # Extract prediction
        if llm_result['predicted_close'] is not None:
            llm_predictions.append(llm_result['predicted_close'])
        else:
            # Fallback: substitute the ground-truth target when the LLM fails.
            # These rows have zero error by construction; llm_results keeps
            # predicted_close=None for them so they can be identified later.
            response = json.loads(item['response'])
            llm_predictions.append(response['predicted_close'])
        
        # Get actual price from response
        response = json.loads(item['response'])
        actual_prices.append(response['predicted_close'])
        
        # Small delay to avoid rate limiting (adjust based on your API limits)
        #time.sleep(0.5)

        # Checkpoint every 50 samples
        if (idx + 1) % 50 == 0:
            checkpoint = {
                'predictions': llm_predictions,
                'actual_prices': actual_prices,
                'llm_results': llm_results,
                'last_idx': idx
            }
            os.makedirs('../results', exist_ok=True)
            with open(checkpoint_file, 'w') as f:
                json.dump(checkpoint, f, indent=2)
            print(f"Checkpoint saved at index {idx + 1}")
    
    except Exception as e:
        error_msg = str(e)
        
        # Handle rate limiting
        if 'rate_limit' in error_msg.lower() or 'too many requests' in error_msg.lower():
            print(f"❌ RATE LIMIT HIT at index {idx}!")
            print(f"Saving checkpoint and stopping execution...")
            
            # Save checkpoint
            checkpoint = {
                'predictions': llm_predictions,
                'actual_prices': actual_prices,
                'llm_results': llm_results,
                'last_idx': idx - 1
            }
            os.makedirs('../results', exist_ok=True)
            with open(checkpoint_file, 'w') as f:
                json.dump(checkpoint, f, indent=2)
            
            print(f"✅ Checkpoint saved to: {checkpoint_file}")
            print(f"📊 Progress: {idx}/{len(test_data)} samples completed")
            print(f"💡 Run this cell again to resume from where you left off.")
            break  # Stop execution
        else:
            print(f"⚠️ Error at index {idx}: {error_msg}")
            # Store error result
            error_result = {"predicted_close": None, "likelihood": 0.5, "justification": f"Error: {error_msg}"}
            llm_results.append(error_result)
            # Fallback: substitute the ground-truth target (zero error by construction)
            response = json.loads(item['response'])
            llm_predictions.append(response['predicted_close'])
            actual_prices.append(response['predicted_close'])

# Final save
checkpoint = {
    'predictions': llm_predictions,
    'actual_prices': actual_prices,
    'llm_results': llm_results,
    'last_idx': len(llm_predictions) - 1,
    'completed': len(llm_predictions) == len(test_data)
}
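os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)  # the directory may not exist if no periodic checkpoint ran
with open(checkpoint_file, 'w') as f:
    json.dump(checkpoint, f, indent=2)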

# Merge with test_df
test_df['llm_prediction'] = llm_predictions
test_df['actual_price'] = actual_prices

if len(llm_results) == len(test_df):
    justifications = []
    likelihoods = []
    feature_rows = []
    for res in llm_results:
        res = res if isinstance(res, dict) else {}
        justification = res.get('justification', '')
        justifications.append(justification)
        likelihoods.append(safe_float(res.get('likelihood'), 0.5))
        feature_rows.append(extract_justification_features(justification))
else:
    justifications = [''] * len(test_df)
    likelihoods = [0.5] * len(test_df)
    feature_rows = [extract_justification_features('') for _ in range(len(test_df))]

if feature_rows:
    feature_keys = list(feature_rows[0].keys())
else:
    feature_keys = list(extract_justification_features('').keys())

test_df['llm_justification'] = justifications
test_df['llm_likelihood'] = likelihoods
for key in feature_keys:
    test_df[key] = [row[key] for row in feature_rows]

if len(llm_predictions) == len(test_data):
    print(f"✅ LLM predictions completed: {len(llm_predictions)} samples")
else:
    print(f"⚠️ Partial completion: {len(llm_predictions)}/{len(test_data)} samples")
print(f"Checkpoint saved to: {checkpoint_file}")
print("Sample predictions:")
print(test_df[['ticker', 'llm_prediction', 'actual_price']].head())


Starting fresh LLM predictions...
Generating LLM predictions for 2477 samples...
This may take a while due to API rate limits...


LLM Inference:   1%|          | 17/2477 [01:02<2:47:38,  4.09s/it]

JSON parse error, attempting manual extraction: Extra data: line 7 column 1 (char 370)


LLM Inference:   2%|▏         | 50/2477 [03:01<2:18:31,  3.42s/it]

Checkpoint saved at index 50


[... intermediate progress output omitted: checkpoints were saved every 50 samples through index 2450, and 15 more "JSON parse error, attempting manual extraction: Extra data" messages were logged ...]


LLM Inference: 100%|██████████| 2477/2477 [2:27:07<00:00,  3.56s/it]

✅ LLM predictions completed: 2477 samples
Checkpoint saved to: ../results/llm_predictions_justification_checkpoint.json
Sample predictions:
    ticker  llm_prediction  actual_price
0     HSBC           31.50     32.680000
1  0700.HK          322.00    342.870056
2      PEP          178.25    178.970001
3     AAPL          124.50    126.360001
4   7203.T         1805.00   1807.500000
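
The `Extra data` messages in the log are what `json.loads` raises when a valid JSON value is followed by more text. That is consistent with the slice from the first `{` to the last `}` sometimes containing more than one object. One way to handle this without a regex fallback is `json.JSONDecoder.raw_decode`, which stops at the end of the first valid value. The helper `first_json_object` and the sample string below are hypothetical, shown only as a sketch.

```python
import json

def first_json_object(text):
    """Return the first JSON object found in text, or None if none can be decoded."""
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            obj, _end = decoder.raw_decode(text, pos)   # parses one value starting at pos
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass                                        # not valid JSON here; try the next brace
        pos = text.find("{", pos + 1)
    return None

# Hypothetical model output with two objects; slicing first "{" to last "}" would span both
raw = 'Answer:\n{"predicted_close": 31.5, "likelihood": 0.7, "justification": "n/a"}\n{"predicted_close": 30.0}'
parsed = first_json_object(raw)
print(parsed)
```

Unlike the manual regex extraction, this keeps the `justification` text of the first object, which the justification features depend on.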





### 4.4 Check for Failed Predictions in Checkpoints

Before training PPO, verify that every prediction succeeded and re-run any that failed.

In [17]:
# Check for failed predictions in all checkpoint files
import json
import os

def check_failed_predictions(checkpoint_file, data_name):
    """Check for failed/None predictions in checkpoint"""
    if not os.path.exists(checkpoint_file):
        print(f"❌ {data_name} checkpoint not found: {checkpoint_file}")
        return None
    
    with open(checkpoint_file, 'r') as f:
        checkpoint = json.load(f)
    
    predictions = checkpoint.get('predictions', [])
    llm_results = checkpoint.get('llm_results', [])
    
    # Find indices with failed predictions
    failed_indices = []
    for idx, (pred, result) in enumerate(zip(predictions, llm_results)):
        if pred is None or (isinstance(result, dict) and result.get('predicted_close') is None):
            failed_indices.append(idx)
    
    print(f"\n{'='*80}")
    print(f"📊 {data_name.upper()} CHECKPOINT ANALYSIS")
    print(f"{'='*80}")
    print(f"Total predictions: {len(predictions)}")
    print(f"Failed predictions: {len(failed_indices)}")
    # Guard against an empty checkpoint to avoid division by zero
    success_rate = (len(predictions) - len(failed_indices)) / len(predictions) * 100 if predictions else 0.0
    print(f"Success rate: {success_rate:.2f}%")
    
    if failed_indices:
        print(f"\n⚠️ Failed prediction indices (first 20): {failed_indices[:20]}")
        if len(failed_indices) > 20:
            print(f"   ... and {len(failed_indices) - 20} more")
    else:
        print(f"\n✅ All predictions successful!")
    
    return {
        'checkpoint_file': checkpoint_file,
        'total': len(predictions),
        'failed': len(failed_indices),
        'failed_indices': failed_indices,
        'checkpoint': checkpoint
    }

# Check the test checkpoint
print("🔍 CHECKING ALL CHECKPOINT FILES FOR FAILED PREDICTIONS")
print("="*80)



# Use the same file that the inference cell in Section 4.1 writes (checkpoint_file)
test_check = check_failed_predictions(
    '../results/llm_predictions_justification_checkpoint.json', 
    'Test'
)

# Summary
print(f"\n{'='*80}")
print(f"📈 OVERALL SUMMARY")
print(f"{'='*80}")

if test_check:
    print(f"Test:       {test_check['failed']}/{test_check['total']} failed")

total_failed = 0
total_samples = 0

if test_check:
    total_failed += test_check['failed']
    total_samples += test_check['total']

# Guard against division by zero when no checkpoint was loaded
failed_pct = (total_failed / total_samples * 100) if total_samples else 0.0
print(f"\nTotal failed: {total_failed}/{total_samples} ({failed_pct:.2f}%)")
print(f"\n💡 If any predictions failed, run the next cell to fix them.")

🔍 CHECKING ALL CHECKPOINT FILES FOR FAILED PREDICTIONS

📊 TEST CHECKPOINT ANALYSIS
Total predictions: 2477
Failed predictions: 0
Success rate: 100.00%

✅ All predictions successful!

📈 OVERALL SUMMARY
Test:       0/2477 failed

Total failed: 0/2477 (0.00%)

💡 If any predictions failed, run the next cell to fix them.
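
Because the inference loop substitutes the ground-truth target whenever the LLM returns no price, those rows contribute zero error to any metric computed later. A hedged sketch of excluding them from an error metric follows. The helper `mape_excluding_fallbacks` and the toy values are assumptions for illustration, and the helper relies on `llm_results` keeping `predicted_close` as `None` for failed rows.

```python
import numpy as np

def mape_excluding_fallbacks(predictions, actuals, llm_results):
    """Return (MAPE in percent, rows used), counting only rows the model actually predicted."""
    preds = np.asarray(predictions, dtype=float)
    acts = np.asarray(actuals, dtype=float)
    # A row counts as a real prediction only if its stored LLM result has a non-null price
    is_real = np.array([
        isinstance(r, dict) and r.get("predicted_close") is not None
        for r in llm_results
    ], dtype=bool)
    if not is_real.any():
        return float("nan"), 0
    errors = np.abs(preds[is_real] - acts[is_real]) / np.abs(acts[is_real])
    return float(errors.mean() * 100), int(is_real.sum())

# Toy values: the third row is a fallback (prediction copied from the target)
toy_preds = [99.0, 202.0, 50.0]
toy_actuals = [100.0, 200.0, 50.0]
toy_results = [{"predicted_close": 99.0}, {"predicted_close": 202.0}, {"predicted_close": None}]
mape, n_used = mape_excluding_fallbacks(toy_preds, toy_actuals, toy_results)
print(f"MAPE over {n_used} real predictions: {mape:.2f}%")
```

Including the fallback row in this toy example would lower the reported error, since that row's error is zero by construction.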


In [18]:
# Re-run inference ONLY for failed predictions
def fix_failed_predictions(checkpoint_info, original_data, data_name):
    """Re-run inference for failed predictions only"""
    if not checkpoint_info:
        # No checkpoint was loaded, so there is nothing to fix or return
        print(f"❌ {data_name}: No checkpoint loaded, nothing to fix.")
        return None
    if not checkpoint_info['failed_indices']:
        print(f"✅ {data_name}: No failed predictions to fix!")
        return checkpoint_info['checkpoint']
    
    print(f"\n{'='*80}")
    print(f"🔄 FIXING FAILED PREDICTIONS FOR {data_name.upper()}")
    print(f"{'='*80}")
    print(f"Failed predictions to fix: {len(checkpoint_info['failed_indices'])}")
    
    checkpoint = checkpoint_info['checkpoint']
    predictions = checkpoint['predictions']
    actual_prices = checkpoint['actual_prices']
    llm_results = checkpoint['llm_results']
    
    fixed_count = 0
    still_failed = []
    
    for idx in tqdm(checkpoint_info['failed_indices'], desc=f"Fixing {data_name}"):
        try:
            item = original_data[idx]
            
            # Re-run LLM prediction
            llm_result = llm_predict_stock_price(item['prompt'])
            
            # Update results
            llm_results[idx] = llm_result
            
            # Update prediction
            if llm_result['predicted_close'] is not None:
                predictions[idx] = llm_result['predicted_close']
                fixed_count += 1
            else:
                # Still failed, use fallback
                response = json.loads(item['response'])
                predictions[idx] = response['predicted_close']
                still_failed.append(idx)
            
            # Small delay
            time.sleep(0.3)
            
        except Exception as e:
            print(f"\n⚠️ Error fixing index {idx}: {e}")
            still_failed.append(idx)
            # Use fallback (ground-truth target) if it can be parsed
            try:
                response = json.loads(original_data[idx]['response'])
                predictions[idx] = response['predicted_close']
            except (KeyError, IndexError, TypeError, ValueError):
                pass
    
    # Save updated checkpoint
    checkpoint['predictions'] = predictions
    checkpoint['llm_results'] = llm_results
    checkpoint['last_idx'] = len(predictions) - 1
    checkpoint['completed'] = True
    
    with open(checkpoint_info['checkpoint_file'], 'w') as f:
        json.dump(checkpoint, f, indent=2)
    
    print(f"\n✅ Fixed {fixed_count}/{len(checkpoint_info['failed_indices'])} predictions")
    if still_failed:
        print(f"⚠️ Still failed: {len(still_failed)} predictions (using fallback values)")
        print(f"   Indices: {still_failed[:10]}")
    print(f"💾 Updated checkpoint saved to: {checkpoint_info['checkpoint_file']}")
    
    return checkpoint

# Fix test data
if test_check and test_check['failed'] > 0:
    print("\n" + "="*80)
    print("FIXING TEST DATA")
    print("="*80)
    test_checkpoint_fixed = fix_failed_predictions(test_check, test_data, "Test")
    # Update global variables
    llm_predictions = test_checkpoint_fixed['predictions']
    actual_prices = test_checkpoint_fixed['actual_prices']
    llm_results = test_checkpoint_fixed['llm_results']
    print(f"✅ Test data updated: {len(llm_predictions)} predictions")
    
    # Update test_df
    test_df['llm_prediction'] = llm_predictions
    test_df['actual_price'] = actual_prices
    
    # Update justification features
    justifications = []
    likelihoods = []
    feature_rows = []
    for res in llm_results:
        res = res if isinstance(res, dict) else {}
        justification = res.get('justification', '')
        justifications.append(justification)
        likelihoods.append(safe_float(res.get('likelihood'), 0.5))
        feature_rows.append(extract_justification_features(justification))
    
    test_df['llm_justification'] = justifications
    test_df['llm_likelihood'] = likelihoods
    
    feature_keys = list(feature_rows[0].keys()) if feature_rows else []
    for key in feature_keys:
        test_df[key] = [row[key] for row in feature_rows]

print("\n" + "="*80)
print("✅ ALL FAILED PREDICTIONS HAVE BEEN PROCESSED!")
print("="*80)
print("You can now proceed with PPO training.")


✅ ALL FAILED PREDICTIONS HAVE BEEN PROCESSED!
You can now proceed with PPO training.
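
Both the inference loop and the fix-up cell overwrite the checkpoint file in place, so an interruption during `json.dump` could leave a truncated file that fails to load on resume. One possible hardening, sketched here under the assumption that the checkpoint stays a JSON-serializable dict, is to write to a temporary file and then swap it in with `os.replace`. The helper name `save_checkpoint_atomic` and the demo path are hypothetical.

```python
import json
import os
import tempfile

def save_checkpoint_atomic(checkpoint, path):
    """Write checkpoint as JSON to a temp file, then replace the target file in one step."""
    directory = os.path.dirname(path) or "."
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(checkpoint, f, indent=2)
        os.replace(tmp_path, path)    # atomic when both paths are on the same filesystem
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)       # do not leave partial temp files behind
        raise

# Demo in a throwaway directory (the path is illustrative)
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "results", "checkpoint.json")
save_checkpoint_atomic({"predictions": [1.0], "last_idx": 0}, demo_path)
with open(demo_path) as f:
    restored = json.load(f)
print(restored)
```

With this pattern the previous checkpoint remains intact until the new one has been written completely.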
