In [None]:
# Section 4: Fixed LLM Scoring - Only Probabilities for Candidates

import random
import json
import torch
import os
from datetime import datetime
from itertools import groupby
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
from pathlib import Path
import gc
from tqdm import tqdm
import logging

# Configure root logging (INFO and above, timestamped lines) and create the
# module-level logger used by every function below.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

class DATASTORE:
    """Load the query/candidate JSON-lines files and group their records by date.

    Only queries whose ``movement`` is 'rise' or 'fall' (case-insensitive) are
    kept; candidate records are kept unfiltered.
    """

    def __init__(self, query_file_path, candidate_file_path):
        """Validate both paths, then load and filter the records.

        Raises:
            FileNotFoundError: If either file does not exist.
        """
        self.query_file_path = query_file_path
        self.candidate_file_path = candidate_file_path

        # Validate files exist
        if not os.path.exists(query_file_path):
            raise FileNotFoundError(f"Query file not found: {query_file_path}")
        if not os.path.exists(candidate_file_path):
            raise FileNotFoundError(f"Candidate file not found: {candidate_file_path}")

        # Load and filter query data to only include rise/fall movements
        all_queries = self.load_json_lines(query_file_path)
        self.query_sequence_list = self.filter_rise_fall_queries(all_queries)
        self.candidate_sequence_list = self.load_json_lines(candidate_file_path)

        logger.info(f"Loaded {len(all_queries)} total queries, filtered to {len(self.query_sequence_list)} rise/fall queries")
        logger.info(f"Loaded {len(self.candidate_sequence_list)} candidates")

    @staticmethod
    def _movement_of(record):
        """Return the record's 'movement' lower-cased, or '' when missing or null."""
        # `or ''` also covers an explicit JSON null, which .get(key, '') does not.
        return str(record.get('movement') or '').lower()

    def load_json_lines(self, file_path):
        """Parse a JSON-lines file into a list of dict records.

        Blank lines are ignored. Lines that are not valid JSON, or whose value
        is not a JSON object, are logged and skipped.

        Raises:
            Any error raised while opening or reading the file (logged first).
        """
        data = []
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    stripped = line.strip()
                    if not stripped:
                        # e.g. a trailing empty line: not a malformed record.
                        continue
                    try:
                        record = json.loads(stripped)
                    except json.JSONDecodeError as e:
                        logger.warning(f"Skipping invalid JSON at line {line_num} in {file_path}: {e}")
                        continue
                    if not isinstance(record, dict):
                        # Every consumer calls .get() / ["..."] on the records.
                        logger.warning(f"Skipping non-object JSON at line {line_num} in {file_path}")
                        continue
                    data.append(record)
        except Exception as e:
            logger.error(f"Error reading file {file_path}: {e}")
            raise
        return data

    def filter_rise_fall_queries(self, query_list):
        """Return only the queries whose movement is 'rise' or 'fall'; log the distribution."""
        filtered_queries = []
        counts = {'rise': 0, 'fall': 0, 'freeze': 0, 'other': 0}

        for query in query_list:
            movement = self._movement_of(query)
            if movement in ('rise', 'fall'):
                filtered_queries.append(query)
                counts[movement] += 1
            elif movement == 'freeze':
                counts['freeze'] += 1
            else:
                counts['other'] += 1

        logger.info(f"Query movement distribution:")
        logger.info(f"  - Rise: {counts['rise']}")
        logger.info(f"  - Fall: {counts['fall']}")
        logger.info(f"  - Freeze (skipped): {counts['freeze']}")
        if counts['other']:
            logger.info(f"  - Other/missing (skipped): {counts['other']}")
        logger.info(f"  - Total processed: {counts['rise'] + counts['fall']}")

        return filtered_queries

    def get_query_amount(self):
        """Return the number of (already filtered) queries."""
        return len(self.query_sequence_list)

    def get_rise_fall_stats(self):
        """Return {'rise': n, 'fall': m, 'total': n + m} for the kept queries."""
        movements = [self._movement_of(q) for q in self.query_sequence_list]
        rise_count = movements.count('rise')
        fall_count = movements.count('fall')
        return {"rise": rise_count, "fall": fall_count, "total": rise_count + fall_count}

    @staticmethod
    def _group_by_key(records, key, label):
        """Sort ``records`` in place by ``record[key]`` and return {value: [records]}.

        Raises:
            KeyError: If any record lacks ``key`` (logged first).
        """
        try:
            records.sort(key=lambda x: x[key])
            # groupby only merges adjacent items, hence the sort above.
            return {
                value: list(items)
                for value, items in groupby(records, key=lambda x: x[key])
            }
        except KeyError as e:
            logger.error(f"Missing required field in {label} data: {e}")
            raise

    def group_query_by_date(self):
        """Group queries by their 'query_date' value (sorts the query list in place)."""
        return self._group_by_key(self.query_sequence_list, "query_date", "query")

    def group_candidate_by_date(self):
        """Group candidates by their 'candidate_date' value (sorts the candidate list in place)."""
        return self._group_by_key(self.candidate_sequence_list, "candidate_date", "candidate")

def get_embedding_sequence_str(sequence, query_or_candidate):
    """Render a query or candidate record as the ``str()`` of a compact dict.

    For ``'query'`` the dict holds the stock, date, recent dates and adjusted
    close list. For ``'candidate'`` it holds the stock, date, recent dates and
    the record's indicator series: the first ``*_list`` key that is not one of
    the bookkeeping keys, or ``'values_list': []`` if there is none.

    Args:
        sequence: Query or candidate record (dict).
        query_or_candidate: Either ``'query'`` or ``'candidate'``.

    Returns:
        ``str`` of the compact dict, or ``"{}"`` if building it fails.

    Raises:
        ValueError: If ``query_or_candidate`` is neither value (a caller bug,
            so it is no longer swallowed into ``"{}"``).
    """
    if query_or_candidate not in ('query', 'candidate'):
        raise ValueError(f"Invalid query_or_candidate value: {query_or_candidate}")

    try:
        if query_or_candidate == 'query':
            seq1 = {
                'query_stock': sequence.get('query_stock', 'Unknown'),
                'query_date': sequence.get('query_date', 'Unknown'),
                'recent_date_list': sequence.get('recent_date_list', []),
                'adjusted_close_list': sequence.get('adjusted_close_list', []),
            }
        else:
            standard_keys = {'data_index', 'candidate_stock', 'candidate_date', 'candidate_movement', 'recent_date_list', 'indicator_name'}
            # The first "<indicator>_list" key that is not bookkeeping. Known
            # indicator columns (adj_close_list, close_list, volume_list, ...)
            # already satisfy this test, so no separate fallback search is needed.
            indicator_key = next(
                (key for key in sequence.keys() if key.endswith('_list') and key not in standard_keys),
                None,
            )
            if indicator_key is None:
                indicator_key, indicator_values = 'values_list', []
            else:
                indicator_values = sequence.get(indicator_key, [])

            seq1 = {
                'candidate_stock': sequence.get('candidate_stock', 'Unknown'),
                'candidate_date': sequence.get('candidate_date', 'Unknown'),
                'recent_date_list': sequence.get('recent_date_list', []),
                indicator_key: indicator_values
            }

        return str(seq1)
    except Exception as e:
        # Best-effort rendering for malformed records (e.g. not a dict).
        logger.error(f"Error in get_embedding_sequence_str: {e}")
        return str({})

def generate_candidate_prompt_for_prob(query_sequence, candidate_list, retrieve_number, rng=None):
    """Build one scoring prompt per randomly sampled candidate.

    Each prompt contains the instruction, a single candidate sequence, the
    query sequence and the "[blank]" question for the query's stock and date.

    Args:
        query_sequence: Query record. Only movements 'rise' or 'fall'
            (case-insensitive) produce prompts.
        candidate_list: Pool of candidate records to sample from.
        retrieve_number: Number of candidates to sample. It is capped at the
            pool size, and values <= 0 yield no prompts.
        rng: Optional ``random.Random``-like object whose ``sample`` method is
            used. When None, the module-level ``random`` functions are used.
            Pass a seeded instance for reproducible sampling.

    Returns:
        ``(candidate_index_list, candidate_str_list, prompt_list,
        query_sequence_str)``, where the three lists are parallel.
        - All lists are empty when the query is skipped, nothing is sampled,
          or an error occurs.
        - ``query_sequence_str`` is "" when the query is skipped or on error.
    """
    try:
        # Early validation - skip if not rise/fall (null movement counts as neither)
        movement = str(query_sequence.get('movement') or '').lower()
        if movement not in ('rise', 'fall'):
            logger.debug(f"Skipping query with movement '{movement}' (not rise/fall)")
            return [], [], [], ""

        query_date = query_sequence.get('query_date', 'Unknown')
        query_stock = query_sequence.get('query_stock', 'Unknown')
        query_sequence_str = get_embedding_sequence_str(sequence=query_sequence, query_or_candidate='query')

        instruction = ("Based on the following information, predict stock movement by filling in the [blank] with 'rise' or 'fall'. Just fill in the blank, do not explain.\n")
        query_inst = f'\nQuery: On {query_date}, the movement of ${query_stock} is [blank].\n'
        retrieve_prompt = 'These are sequences that may affect this stock\'s price recently:\n'
        query_prompt = 'This is the query sequence:\n'

        # Clamp to [0, pool size]; random.sample raises on negative k.
        actual_retrieve_number = max(0, min(retrieve_number, len(candidate_list)))
        if actual_retrieve_number == 0:
            logger.warning(f"No candidates available for query on {query_date}")
            return [], [], [], query_sequence_str

        sampler = random if rng is None else rng
        retrieve_result = sampler.sample(candidate_list, actual_retrieve_number)

        prompt_list = []
        candidate_index_list = []
        candidate_str_list = []

        for candidate_sequence in retrieve_result:
            candidate_index_list.append(candidate_sequence.get('data_index', 0))
            candidate_prompt = str({
                'candidate_sequence': get_embedding_sequence_str(sequence=candidate_sequence, query_or_candidate='candidate')
            })
            candidate_str_list.append(candidate_prompt)
            prompt = instruction + retrieve_prompt + candidate_prompt + '\n' + query_prompt + query_sequence_str + '\n' + query_inst
            prompt_list.append(prompt)

        return candidate_index_list, candidate_str_list, prompt_list, query_sequence_str

    except Exception as e:
        logger.error(f"Error generating candidate prompt: {e}")
        return [], [], [], ""

def get_probability_one_sequence(prompt, model, tokenizer, answer, device, max_length=1024):
    """Return the model's probability of starting its reply with ``answer``.

    How it works:
    - The prompt is wrapped in the system/user chat template and run through
      a single forward pass.
    - The probability is read from the softmax of the logits at the position
      where the assistant's reply begins. This is the same distribution
      greedy ``generate`` exposes as its first step's scores.
    - Every candidate is therefore scored on one continuous scale, whether or
      not greedy decoding would have emitted ``answer``. The previous
      generate-and-match approach mapped every "wrong" prediction to 0.01,
      which erased the ranking between candidates.

    Args:
        prompt: User-message text.
        model: Causal LM, called as ``model(input_ids=..., attention_mask=...)``.
        tokenizer: Matching tokenizer with a chat template.
        answer: Expected answer word, e.g. 'rise' or 'fall'.
        device: Device the input tensors are moved to.
        max_length: Maximum number of prompt tokens fed to the model. Longer
            prompts are truncated from the *left*, so the query and the
            assistant-turn marker at the end are kept.

    Returns:
        A float in [0, 1]: the highest next-token probability among the first
        tokens of the answer's surface forms (lower/upper/capitalised, with
        and without a leading space). Returns 0.01 when ``answer`` cannot be
        mapped to a token, or on any error (which is logged).
    """
    try:
        messages = [
            {"role": "system", "content": "You are a stock analyst."},
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # The rendered template already contains the model's special tokens,
        # so the tokenizer must not add a second BOS.
        encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
        input_ids = encoded["input_ids"]
        attention_mask = encoded.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if input_ids.shape[1] > max_length:
            # Keep the tail: the question and the generation prompt live there.
            logger.warning(f"Prompt has {input_ids.shape[1]} tokens; keeping the last {max_length}")
            input_ids = input_ids[:, -max_length:]
            attention_mask = attention_mask[:, -max_length:]

        with torch.no_grad():
            logits = model(input_ids=input_ids.to(device), attention_mask=attention_mask.to(device)).logits

        # Distribution over the first reply token. Upcast from bf16 for a stable softmax.
        next_token_probs = torch.softmax(logits[0, -1].float(), dim=-1)
        vocab_size = next_token_probs.shape[0]

        target = answer.strip().lower()
        answer_token_ids = set()
        for variation in (target, target.upper(), target.capitalize()):
            for surface in (variation, " " + variation):
                token_ids = tokenizer.encode(surface, add_special_tokens=False)
                if not token_ids or token_ids[0] >= vocab_size:
                    continue
                first_piece = tokenizer.decode([token_ids[0]]).strip().lower()
                # Reject whitespace-only pieces (e.g. a lone SentencePiece "▁"),
                # which would credit any word to the answer.
                if first_piece and target.startswith(first_piece):
                    answer_token_ids.add(token_ids[0])

        if not answer_token_ids:
            logger.warning(f"Could not map answer '{answer}' to a vocabulary token")
            return 0.01

        prob = max(next_token_probs[token_id].item() for token_id in answer_token_ids)

        if logger.isEnabledFor(logging.DEBUG):
            top_token = tokenizer.decode([int(next_token_probs.argmax())]).strip().lower()
            logger.debug(f"Final probability: {prob} for answer '{answer}' (most likely first token: '{top_token}')")

        return float(prob)

    except Exception as e:
        logger.error(f"Error in probability calculation: {e}")
        logger.error(f"Prompt length: {len(prompt) if prompt else 0}")
        logger.error(f"Answer: {answer}")
        return 0.01  # Return small non-zero value instead of 0

def _load_scoring_model(model_name, device):
    """Load the causal LM and its tokenizer (pad token falls back to EOS)."""
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
        low_cpu_mem_usage=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _parse_candidate_dates(clist_by_date):
    """Parse each candidate date key once; invalid keys are logged and dropped.

    Returns a list of ``(datetime, candidate_list)`` in the dict's order.
    """
    parsed = []
    for candidate_date, candidate_sequence_list in clist_by_date.items():
        try:
            candidate_date_dt = datetime.strptime(candidate_date, "%Y-%m-%d")
        except (TypeError, ValueError):
            logger.warning(f"Invalid candidate date format: {candidate_date}")
            continue
        parsed.append((candidate_date_dt, candidate_sequence_list))
    return parsed


def _window_candidates(parsed_candidates, cutoff_date, query_date_dt):
    """Concatenate candidates dated in [cutoff_date, query_date_dt), keeping input order."""
    window = []
    for candidate_date_dt, candidate_sequence_list in parsed_candidates:
        if cutoff_date <= candidate_date_dt < query_date_dt:
            window.extend(candidate_sequence_list)
    return window


def _stats_path(output_file):
    """Return '<root>_stats<ext>' for ``output_file`` (ext defaults to '.json')."""
    root, ext = os.path.splitext(output_file)
    return f"{root}_stats{ext or '.json'}"


def process_combined_dataset_scoring(query_file, candidate_file, output_file,
                                   model_name="ElsaShaw/StockLLM", retrieve_number=20,
                                   batch_size=100, max_queries=None,
                                   max_prompts_per_query=10, window_months=18,
                                   min_candidates=100):
    """Score sampled candidates for every rise/fall query with an LLM.

    For each query date, candidates dated in the ``window_months`` months
    strictly before the query date form the pool. Dates with an unparsable
    key, or with fewer than ``min_candidates`` candidates, are skipped, and
    their queries count as processed and skipped.

    For each remaining query:
    - ``min(retrieve_number, max_prompts_per_query)`` candidates are sampled.
    - Each sampled candidate is scored with ``get_probability_one_sequence``.

    Result rows are written as JSON lines to ``output_file`` in batches of
    ``batch_size``; run statistics go to ``<root>_stats<ext>`` beside it.

    Args:
        query_file, candidate_file: JSON-lines inputs (see ``DATASTORE``).
        output_file: JSON-lines output path. It is always (re)written, even
            when no query is scored, so stale results never survive.
        model_name: Hugging Face model id.
        retrieve_number: Candidates to sample per query.
        batch_size: Result rows buffered before each write.
        max_queries: Cap on handled queries (None = no cap; 0 = none).
        max_prompts_per_query: Cap on scored candidates per query
            (None = no cap).
        window_months: Look-back window for candidate dates.
        min_candidates: Minimum pool size for a query date to be scored.

    Returns:
        Number of queries that produced a result row, or 0 if model loading or
        data processing failed.
    """
    logger.info(f"=== Processing Combined Dataset (RISE/FALL QUERIES ONLY) ===")
    logger.info(f"Query file: {query_file}")
    logger.info(f"Candidate file: {candidate_file}")
    logger.info(f"Output file: {output_file}")

    # Load model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Loading model {model_name} on {device}")
    try:
        model, tokenizer = _load_scoring_model(model_name, device)
        logger.info(f"Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return 0

    # Only this many prompts are scored, so sampling more would only add
    # prompt-building work that is then thrown away.
    if max_prompts_per_query is None:
        n_per_query = retrieve_number
    else:
        n_per_query = min(retrieve_number, max_prompts_per_query)

    try:
        # Load data (filtering happens in DATASTORE initialization)
        datastore = DATASTORE(query_file, candidate_file)
        qlist_by_date = datastore.group_query_by_date()
        clist_by_date = datastore.group_candidate_by_date()
        # Parse candidate dates once instead of once per query date.
        parsed_candidates = _parse_candidate_dates(clist_by_date)

        rise_fall_stats = datastore.get_rise_fall_stats()
        total_queries_available = datastore.get_query_amount()
        if max_queries is None:
            queries_to_process = total_queries_available
        else:
            queries_to_process = max(0, min(max_queries, total_queries_available))

        logger.info(f"Rise/Fall query statistics: {rise_fall_stats}")
        logger.info(f"Total rise/fall queries available: {total_queries_available}")
        logger.info(f"Queries to process: {queries_to_process}")

        # dirname is '' for a bare filename, and os.makedirs('') would raise.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        processed_queries = 0
        successful_queries = 0
        skipped_queries = 0
        all_results = []
        wrote_any = False  # first write truncates the file, later ones append

        def limit_reached():
            # `is not None`, not truthiness, so max_queries=0 means "none".
            return max_queries is not None and processed_queries >= max_queries

        with tqdm(total=queries_to_process, desc="Processing RISE/FALL queries") as pbar:
            for query_date in sorted(qlist_by_date.keys()):
                if limit_reached():
                    break

                query_sequence_list = qlist_by_date[query_date]

                qualified_candidate_list = None
                try:
                    query_date_dt = datetime.strptime(query_date, "%Y-%m-%d")
                except (TypeError, ValueError):
                    logger.warning(f"Invalid date format: {query_date}")
                else:
                    cutoff_date = query_date_dt - pd.DateOffset(months=window_months)
                    qualified_candidate_list = _window_candidates(parsed_candidates, cutoff_date, query_date_dt)
                    if len(qualified_candidate_list) < min_candidates:
                        logger.warning(f"Insufficient candidates ({len(qualified_candidate_list)}) for {query_date} "
                                       f"({window_months}-month cutoff: {cutoff_date.strftime('%Y-%m-%d')})")
                        qualified_candidate_list = None

                if qualified_candidate_list is None:
                    # Count the whole date as skipped so the stats and the
                    # progress bar add up to the number of queries.
                    n_skip = len(query_sequence_list)
                    if max_queries is not None:
                        n_skip = min(n_skip, max_queries - processed_queries)
                    processed_queries += n_skip
                    skipped_queries += n_skip
                    pbar.update(n_skip)
                    continue

                if processed_queries < 5:  # Log details for first 5 queries
                    logger.info(f"Query {query_date}: {len(qualified_candidate_list)} candidates "
                                f"from {cutoff_date.strftime('%Y-%m-%d')} to {query_date}")

                for query_sequence in query_sequence_list:
                    if limit_reached():
                        break

                    query_id = query_sequence.get('data_index', processed_queries)
                    reference_answer = str(query_sequence.get('movement') or 'freeze').lower()

                    # Double-check: Only process rise/fall movements (should already be filtered)
                    if reference_answer not in ('rise', 'fall'):
                        logger.debug(f"Skipping query {query_id} with movement '{reference_answer}'")
                        skipped_queries += 1
                        processed_queries += 1
                        pbar.update(1)
                        continue

                    try:
                        candidate_index_list, candidate_str_list, prompt_list, query_sequence_str = generate_candidate_prompt_for_prob(
                            query_sequence, qualified_candidate_list, n_per_query)

                        if not prompt_list:
                            processed_queries += 1
                            pbar.update(1)
                            continue

                        prompts_to_score = prompt_list if max_prompts_per_query is None else prompt_list[:max_prompts_per_query]
                        candidate_probabilities = [
                            get_probability_one_sequence(prompt, model, tokenizer, reference_answer, device)
                            for prompt in prompts_to_score
                        ]
                        n_scored = len(candidate_probabilities)

                        if candidate_probabilities:
                            query_stock = query_sequence.get('query_stock', 'Unknown')
                            all_results.append({
                                "query_id": f"{query_stock}_{query_id}",
                                "query": query_sequence_str,
                                "candidates": candidate_str_list[:n_scored],
                                "candidate_indices": candidate_index_list[:n_scored],
                                "candidate_probabilities": candidate_probabilities,  # parallel to "candidates"
                                "correct_answer": reference_answer,
                                "query_date": query_date,
                                "query_stock": query_stock,
                                "movement": reference_answer
                            })
                            successful_queries += 1

                        processed_queries += 1
                        pbar.update(1)
                    except Exception as e:
                        logger.error(f"Error processing query {query_id}: {e}")
                        processed_queries += 1
                        pbar.update(1)
                        continue

                    # Flushed outside the per-query try: a write failure must not
                    # be counted as a second processed query.
                    if len(all_results) >= batch_size:
                        save_batch_results(all_results, output_file, append=wrote_any)
                        wrote_any = True
                        all_results = []

                        # Force garbage collection
                        gc.collect()
                        if device == "cuda":
                            torch.cuda.empty_cache()

        # Also runs when nothing was scored, so a stale file from an earlier run is truncated.
        if all_results or not wrote_any:
            save_batch_results(all_results, output_file, append=wrote_any)

        stats = {
            "total_rise_fall_queries": total_queries_available,
            "queries_processed": processed_queries,
            "queries_successfully_scored": successful_queries,
            "queries_skipped": skipped_queries,
            "rise_fall_distribution": rise_fall_stats,
            "success_rate": successful_queries / processed_queries if processed_queries > 0 else 0,
            "output_file": output_file,
            "model_used": model_name,
            "retrieve_number": retrieve_number,
            "max_prompts_per_query": max_prompts_per_query,
            "window_months": window_months,
            "min_candidates": min_candidates,
            "processing_mode": "rise_fall_only_candidates_with_probabilities"
        }

        # splitext-based name: never equal to output_file, never edits directory names.
        stats_file = _stats_path(output_file)
        with open(stats_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2)

        logger.info(f"\nProcessing Complete!")
        logger.info(f"Statistics:")
        logger.info(f"  - Total rise/fall queries available: {total_queries_available}")
        logger.info(f"  - Rise queries: {rise_fall_stats['rise']}")
        logger.info(f"  - Fall queries: {rise_fall_stats['fall']}")
        logger.info(f"  - Queries processed: {processed_queries}")
        logger.info(f"  - Queries successfully scored: {successful_queries}")
        logger.info(f"  - Queries skipped: {skipped_queries}")
        logger.info(f"  - Success rate: {stats['success_rate']:.2%}")
        logger.info(f"  - Results saved to: {output_file}")
        logger.info(f"  - Statistics saved to: {stats_file}")

        return successful_queries

    except Exception as e:
        logger.exception(f"Error in combined dataset processing: {e}")
        return 0

def save_batch_results(results, output_file, append=False):
    """Write ``results`` to ``output_file`` as JSON lines, one object per line.

    Args:
        results: Iterable of JSON-serialisable objects.
        output_file: Destination path (UTF-8).
        append: When True, add to the end of the file; otherwise truncate it
            first.
    """
    file_mode = 'a' if append else 'w'
    serialized = (json.dumps(record) + '\n' for record in results)
    with open(output_file, file_mode, encoding='utf-8') as handle:
        handle.writelines(serialized)

def test_model_predictions(query_file, candidate_file, model_name="ElsaShaw/StockLLM", num_tests=5):
    """Print the scored probability for a few queries as a smoke test.

    Up to two rise/fall queries per date are tried (``num_tests`` in total).
    Each is scored against one of three candidates sampled from the 18 months
    before its query date.

    Error handling: errors, including a failure to load the model, are logged
    instead of raised, so this optional check cannot abort the main scoring
    run that follows it.

    Cleanup: before returning, the model is released and the CUDA cache is
    emptied. The main run reloads the model with ``device_map="auto"``, which
    budgets from the device memory reported free; memory still held in
    PyTorch's cache would otherwise count as used. (TODO confirm against the
    installed accelerate version.)
    """
    logger.info("=== Testing Model Predictions (Rise/Fall Only) ===")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = tokenizer = None
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None,
            low_cpu_mem_usage=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        datastore = DATASTORE(query_file, candidate_file)
        qlist_by_date = datastore.group_query_by_date()
        clist_by_date = datastore.group_candidate_by_date()

        logger.info(f"Testing with {datastore.get_query_amount()} rise/fall queries")

        test_count = 0
        for query_date, query_sequence_list in qlist_by_date.items():
            if test_count >= num_tests:
                break

            try:
                query_date_dt = datetime.strptime(query_date, "%Y-%m-%d")
            except ValueError:
                continue

            # Apply 18-month window in testing too
            cutoff_date = query_date_dt - pd.DateOffset(months=18)
            qualified_candidate_list = []

            for candidate_date, candidate_sequence_list in clist_by_date.items():
                try:
                    candidate_date_dt = datetime.strptime(candidate_date, "%Y-%m-%d")
                except ValueError:
                    continue
                if (query_date_dt > candidate_date_dt) and (candidate_date_dt >= cutoff_date):
                    qualified_candidate_list.extend(candidate_sequence_list)

            for query_sequence in query_sequence_list[:2]:  # Test first 2 queries per date
                if test_count >= num_tests:
                    break

                reference_answer = str(query_sequence.get('movement') or 'freeze').lower()
                # Since datastore already filtered, this should only be rise/fall
                if reference_answer not in ['rise', 'fall'] or not qualified_candidate_list:
                    continue

                candidate_index_list, candidate_str_list, prompt_list, query_sequence_str = generate_candidate_prompt_for_prob(
                    query_sequence, qualified_candidate_list, 3)

                if prompt_list:
                    print(f"\n--- Test {test_count + 1} ---")
                    print(f"Stock: {query_sequence.get('query_stock', 'Unknown')}")
                    print(f"Date: {query_sequence.get('query_date', 'Unknown')}")
                    print(f"Expected: {reference_answer}")
                    print(f"Candidates from: {cutoff_date.strftime('%Y-%m-%d')} onwards ({len(qualified_candidate_list)} total)")

                    # Test first prompt
                    prob = get_probability_one_sequence(prompt_list[0], model, tokenizer, reference_answer, device)

                    # flush so stdout stays in order with the (stderr) log lines
                    print(f"Calculated probability: {prob}", flush=True)
                    test_count += 1

    except Exception as e:
        logger.exception(f"Error in testing: {e}")
    finally:
        del model, tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    logger.info("=== Testing Complete ===")

# Main execution
if __name__ == "__main__":
    logger.info("=== LLM Scoring for RISE/FALL Queries (Candidates with Probabilities Only) ===")

    # Your specific file paths
    query_file_path = "/root/nfs/AJ FinRag/Query Candidate/llm_data/all_companies_train_qlist.json"
    candidate_file_path = "/root/nfs/AJ FinRag/Query Candidate/llm_data/all_companies_train_clist.json"
    output_file_path = "/root/nfs/AJ FinRag/LLM Scores/llm_data/all_companies_scored_candidates_probabilities.json"

    # Configuration - Process only rise/fall queries with 18-month window
    config = {
        "model_name": "ElsaShaw/StockLLM",
        "retrieve_number": 20,  # Candidates sampled per query (at most 10 per query are scored)
        "batch_size": 50,       # Process in batches for memory management
        "max_queries": None,    # None = process ALL rise/fall queries (no limit)
        "seed": 42,             # Seeds candidate sampling so the scored file is reproducible
    }

    logger.info(f"Query file: {query_file_path}")
    logger.info(f"Candidate file: {candidate_file_path}")
    logger.info(f"Output file: {output_file_path}")
    logger.info(f"Configuration: {config}")
    logger.info(f"Processing Mode: RISE/FALL ONLY - Candidates with Probabilities")

    # Test the model first (optional - comment out if you want to skip)
    logger.info("\n=== Testing model outputs ===")
    test_model_predictions(
        query_file_path,
        candidate_file_path,
        config["model_name"],
        num_tests=3
    )

    # Seed right before the real run so the sampled candidates do not depend on
    # whether (or how many) smoke tests ran above.
    random.seed(config["seed"])

    # Process ONLY rise/fall queries in the combined dataset with 18-month window
    logger.info(f"\n=== Processing RISE/FALL queries with candidate probabilities ===")
    successful_queries = process_combined_dataset_scoring(
        query_file=query_file_path,
        candidate_file=candidate_file_path,
        output_file=output_file_path,
        model_name=config["model_name"],
        retrieve_number=config["retrieve_number"],
        batch_size=config["batch_size"],
        max_queries=config["max_queries"]
    )

    logger.info(f"\nRise/Fall scoring complete! Total queries successfully processed: {successful_queries}")
    logger.info(f"Output format: Candidates with probabilities (ready for pos/neg allocation)")

    # Show the structure of each output line (one JSON object per line)
    logger.info("\nSample output format (one JSON object per line):")
    logger.info('''{
  "query_id": "AAPL_123",
  "query": "query_string",
  "candidates": ["candidate_1", "candidate_2", "candidate_3"],
  "candidate_indices": [101, 202, 303],
  "candidate_probabilities": [0.85, 0.23, 0.67],
  "correct_answer": "rise",
  "query_date": "2022-01-19",
  "query_stock": "AAPL",
  "movement": "rise"
}''')

2025-08-26 13:50:41,426 - INFO - === LLM Scoring for RISE/FALL Queries (Candidates with Probabilities Only) ===
2025-08-26 13:50:41,428 - INFO - Query file: /root/nfs/AJ FinRag/Query Candidate/llm_data/all_companies_train_qlist.json
2025-08-26 13:50:41,429 - INFO - Candidate file: /root/nfs/AJ FinRag/Query Candidate/llm_data/all_companies_train_clist.json
2025-08-26 13:50:41,430 - INFO - Output file: /root/nfs/AJ FinRag/LLM Scores/llm_data/all_companies_scored_candidates_probabilities.json
2025-08-26 13:50:41,431 - INFO - Configuration: {'model_name': 'ElsaShaw/StockLLM', 'retrieve_number': 20, 'batch_size': 50, 'max_queries': None}
2025-08-26 13:50:41,432 - INFO - Processing Mode: RISE/FALL ONLY - Candidates with Probabilities
2025-08-26 13:50:41,432 - INFO - 
=== Testing model outputs ===
2025-08-26 13:50:41,433 - INFO - === Testing Model Predictions (Rise/Fall Only) ===
2025-08-26 13:50:43,507 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the 

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2025-08-26 13:51:11,161 - INFO - Query movement distribution:
2025-08-26 13:51:11,163 - INFO -   - Rise: 6747
2025-08-26 13:51:11,164 - INFO -   - Fall: 6484
2025-08-26 13:51:11,165 - INFO -   - Freeze (skipped): 5319
2025-08-26 13:51:11,166 - INFO -   - Total processed: 13231
2025-08-26 13:51:14,873 - INFO - Loaded 18550 total queries, filtered to 13231 rise/fall queries
2025-08-26 13:51:14,875 - INFO - Loaded 203275 candidates
2025-08-26 13:51:15,299 - INFO - Testing with 13231 rise/fall queries
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



--- Test 1 ---
Stock: AAPL
Date: 2022-01-19
Expected: fall
Candidates from: 2020-07-19 onwards (225 total)
Calculated probability: 0.7955255508422852

--- Test 2 ---
Stock: ADBE
Date: 2022-01-19
Expected: rise
Candidates from: 2020-07-19 onwards (225 total)
Calculated probability: 0.01

--- Test 3 ---
Stock: AAPL
Date: 2022-01-20
Expected: fall
Candidates from: 2020-07-20 onwards (450 total)


2025-08-26 13:51:17,148 - INFO - === Testing Complete ===


Calculated probability: 0.619976282119751


2025-08-26 13:51:17,743 - INFO - 
=== Processing RISE/FALL queries with candidate probabilities ===
2025-08-26 13:51:17,746 - INFO - === Processing Combined Dataset (RISE/FALL QUERIES ONLY) ===
2025-08-26 13:51:17,748 - INFO - Query file: /root/nfs/AJ FinRag/Query Candidate/llm_data/all_companies_train_qlist.json
2025-08-26 13:51:17,750 - INFO - Candidate file: /root/nfs/AJ FinRag/Query Candidate/llm_data/all_companies_train_clist.json
2025-08-26 13:51:17,753 - INFO - Output file: /root/nfs/AJ FinRag/LLM Scores/llm_data/all_companies_scored_candidates_probabilities.json
2025-08-26 13:51:17,754 - INFO - Loading model ElsaShaw/StockLLM on cuda
2025-08-26 13:51:17,919 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2025-08-26 13:51:20,249 - INFO - Model loaded successfully
2025-08-26 13:51:20,465 - INFO - Query movement distribution:
2025-08-26 13:51:20,466 - INFO -   - Rise: 6747
2025-08-26 13:51:20,468 - INFO -   - Fall: 6484
2025-08-26 13:51:20,469 - INFO -   - Freeze (skipped): 5319
2025-08-26 13:51:20,469 - INFO -   - Total processed: 13231
2025-08-26 13:51:24,372 - INFO - Loaded 18550 total queries, filtered to 13231 rise/fall queries
2025-08-26 13:51:24,375 - INFO - Loaded 203275 candidates
2025-08-26 13:51:24,646 - INFO - Rise/Fall query statistics: {'rise': 6747, 'fall': 6484, 'total': 13231}
2025-08-26 13:51:24,647 - INFO - Total rise/fall queries available: 13231
2025-08-26 13:51:24,648 - INFO - Queries to process: 13231
2025-08-26 13:51:24,667 - INFO - Query 2022-01-19: 225 candidates from 2020-07-19 to 2022-01-19
Processing RISE/FALL queries:  10%|█         | 1363/13231 [43:27<6:24:02,  1.94s/it]