In [None]:
# SECTION 3: FINQUEST SIMILARITY RETRIEVAL EXPERIMENT
# LLM prediction with FinQuest similarity-based retrieved candidates
# Tests the effectiveness of learned financial pattern similarity

import json
import logging
import os
import pickle
import re
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Configure the root logger so INFO-level progress messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the experiment class and the driver function below.
logger = logging.getLogger(__name__)

class FinQuestSimilarityExperiment:
    """
    FinQuest similarity retrieval experiment: LLM + similarity-based retrieved candidates
    Tests the effectiveness of trained FinQuest retriever vs. baselines
    """
    
    def __init__(self, 
                 test_queries_file: str, 
                 ground_truth_file: str, 
                 embeddings_dir: str,
                 similarity_results_file: str):
        
        self.test_queries_file = test_queries_file
        self.ground_truth_file = ground_truth_file
        self.embeddings_dir = embeddings_dir
        self.similarity_results_file = similarity_results_file
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Load all required data
        self._load_data()
        
        logger.info(f"Loaded {len(self.test_queries)} test queries")
        logger.info(f"Loaded {len(self.ground_truth)} ground truth entries")
        logger.info(f"Loaded {len(self.candidates)} candidates")
        logger.info(f"Loaded similarity results for {len(self.similarity_results)} queries")
        
        # LLM configurations
        self.llm_configs = {
            'StockLLM': {
                'model_name': 'ElsaShaw/StockLLM',
                'description': 'Specialized financial LLM',
                'max_length': 1024
            },
            'Llama3.2-3B': {
                'model_name': 'meta-llama/Llama-3.2-3B-Instruct',
                'description': 'Medium general-purpose LLM',
                'max_length': 2048
            },
            'Qwen2.5-1.5B': {
                'model_name': 'Qwen/Qwen2.5-1.5B-Instruct',
                'description': 'Qwen instruction-following model',
                'max_length': 2048
            },
            'Phi3-Mini': {
                'model_name': 'microsoft/Phi-3-mini-4k-instruct',
                'description': 'Microsoft compact LLM',
                'max_length': 4096
            }
        }
        
        # Current LLM state
        self.current_llm = None
        self.current_tokenizer = None
        self.current_llm_name = None
    
    def _load_data(self):
        """Load all required datasets"""
        logger.info("Loading experimental data...")
        
        # Load test queries
        self.test_queries = self._load_json_file(self.test_queries_file)
        
        # Load ground truth
        self.ground_truth = self._load_json_file(self.ground_truth_file)
        self.gt_lookup = {gt['query_id']: gt for gt in self.ground_truth}
        
        # Load candidates from embeddings
        self.candidates = self._load_candidates()
        
        # Load pre-computed similarity results
        if os.path.exists(self.similarity_results_file):
            with open(self.similarity_results_file, 'rb') as f:
                self.similarity_results = pickle.load(f)
            
            # Create lookup for fast access
            self.similarity_lookup = {}
            for result in self.similarity_results:
                key = f"{result['query_stock']}_{result['query_date']}"
                self.similarity_lookup[key] = result
            
        else:
            logger.error(f"Similarity results file not found: {self.similarity_results_file}")
            raise FileNotFoundError("FinQuest similarity results required for this experiment")
    
    def _load_json_file(self, file_path: str) -> List[Dict]:
        """Load JSONL file"""
        data = []
        with open(file_path, 'r') as f:
            for line in f:
                try:
                    data.append(json.loads(line.strip()))
                except json.JSONDecodeError:
                    continue
        return data
    
    def _load_candidates(self) -> Dict:
        """Load candidates from embedding files"""
        candidates = {}
        embedding_files = list(Path(self.embeddings_dir).glob("c_*_FinQuest_embeddings_*.pkl"))
        
        logger.info(f"Loading candidates from {len(embedding_files)} files...")
        
        for file_path in embedding_files:
            with open(file_path, 'rb') as f:
                embedding_data = pickle.load(f)
            
            for date_group in embedding_data:
                for date, candidates_on_date in date_group.items():
                    for candidate_item in candidates_on_date:
                        candidate_data = candidate_item['data']
                        candidates[candidate_data['data_index']] = candidate_data
        
        return candidates
    
    def load_llm(self, llm_name: str):
        """Load specified LLM for testing"""
        if llm_name not in self.llm_configs:
            raise ValueError(f"Unknown LLM: {llm_name}")
        
        # Clear previous model
        if self.current_llm is not None:
            del self.current_llm
            del self.current_tokenizer
            torch.cuda.empty_cache()
        
        config = self.llm_configs[llm_name]
        model_name = config['model_name']
        
        logger.info(f"Loading {llm_name}: {model_name}")
        
        try:
            self.current_tokenizer = AutoTokenizer.from_pretrained(model_name)
            if self.current_tokenizer.pad_token is None:
                self.current_tokenizer.pad_token = self.current_tokenizer.eos_token
            
            self.current_llm = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16 if self.device.type == 'cuda' else torch.float32,
                device_map="auto" if self.device.type == 'cuda' else None,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            )
            
            self.current_llm_name = llm_name
            logger.info(f"✅ Successfully loaded {llm_name}")
            
        except Exception as e:
            logger.error(f"❌ Failed to load {llm_name}: {e}")
            raise
    
    def get_finquest_candidates(self, query: Dict, k: int = 5) -> Tuple[List[Dict], List[float], List[int]]:
        """
        Get top-k candidates from pre-computed FinQuest similarity results
        Returns candidates, similarity scores, and movement distribution
        """
        query_stock = query.get('query_stock', '')
        query_date = query.get('query_date', '')
        
        # Look up pre-computed similarity results
        key = f"{query_stock}_{query_date}"
        
        if key not in self.similarity_lookup:
            logger.warning(f"No similarity results found for {query_stock} on {query_date}")
            return [], [], [0, 0]
        
        similarity_result = self.similarity_lookup[key]
        similarity_list = similarity_result.get('similarity_list', [])
        
        # Extract top-k candidates
        top_candidates = []
        similarity_scores = []
        candidate_movement_count = [0, 0]  # [rise, fall]
        
        for candidate_info in similarity_list[:k]:
            candidate_index = candidate_info['candidate_index']
            candidate_score = candidate_info['candidate_score']
            
            # Get candidate data
            if candidate_index in self.candidates:
                candidate_data = self.candidates[candidate_index]
                top_candidates.append(candidate_data)
                similarity_scores.append(candidate_score)
                
                # Track movement distribution
                movement = candidate_data.get('movement', 'unknown')
                if movement == 'rise':
                    candidate_movement_count[0] += 1
                elif movement == 'fall':
                    candidate_movement_count[1] += 1
        
        return top_candidates, similarity_scores, candidate_movement_count
    
    def generate_finquest_prompt(self, 
                                query: Dict, 
                                candidates: List[Dict], 
                                similarity_scores: List[float]) -> str:
        """
        Generate prompt with FinQuest similarity-retrieved candidates
        Uses same format as your working similarity search
        """
        query_stock = query.get('query_stock', 'Unknown')
        query_date = query.get('query_date', 'Unknown')
        
        instruction = (
            "Based on the following information, predict stock movement by filling in the [blank] with 'rise' or 'fall'. "
            "Just fill in the blank, do not explain.\n"
        )
        
        retrieve_prompt = 'These are sequences that may affect this stock\'s price recently, where similarity score shows the similarity to the query sequence:\n'
        
        # Format similarity-retrieved candidates
        candidate_text = ""
        
        for candidate, score in zip(candidates, similarity_scores):
            # Format candidate sequence (matching your working format)
            candidate_sequence = {
                'candidate_stock': candidate.get('candidate_stock', 'Unknown'),
                'candidate_date': candidate.get('candidate_date', 'Unknown'),
                'recent_date_list': candidate.get('recent_date_list', []),
                'adjusted_close_list': candidate.get('adjusted_close_list', [])
            }
            
            candidate_text += str({
                'candidate_sequence': candidate_sequence,
                'similarity_score': round(score, 4)
            }) + '\n'
        
        # Query section
        query_prompt = 'This is the query sequence:\n'
        query_sequence = {
            'query_stock': query_stock,
            'query_date': query_date,
            'recent_date_list': query.get('recent_date_list', []),
            'adjusted_close_list': query.get('adjusted_close_list', [])
        }
        
        query_instruction = f'\nQuery: On {query_date}, the movement of ${query_stock} is [blank].\n'
        
        # Combine all parts
        full_prompt = instruction + retrieve_prompt + candidate_text + '\n' + query_prompt + str(query_sequence) + '\n' + query_instruction
        
        return full_prompt
    
    def ask_llm(self, prompt: str) -> str:
        """Get prediction from current LLM with appropriate context length"""
        if self.current_llm is None:
            raise RuntimeError("No LLM loaded")
        
        config = self.llm_configs[self.current_llm_name]
        max_length = config.get('max_length', 1024)
        
        # Format prompt based on LLM architecture
        if 'Llama' in self.current_llm_name:
            messages = [
                {"role": "system", "content": "You are a financial analyst. Use the similar historical patterns to predict stock movements accurately."},
                {"role": "user", "content": prompt}
            ]
            
            formatted_prompt = self.current_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            formatted_prompt = f"System: You are a financial analyst.\nUser: {prompt}\nAssistant:"
        
        # Tokenize with appropriate max length
        input_ids = self.current_tokenizer.encode(
            formatted_prompt, 
            return_tensors="pt", 
            truncation=True, 
            max_length=max_length
        ).to(self.device)
        
        with torch.no_grad():
            outputs = self.current_llm.generate(
                input_ids,
                max_new_tokens=10,
                temperature=0.1,
                do_sample=False,
                pad_token_id=self.current_tokenizer.eos_token_id,
                eos_token_id=self.current_tokenizer.eos_token_id
            )
        
        generated_ids = outputs[0][input_ids.shape[1]:]
        response = self.current_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        
        return response
    
    def extract_prediction(self, response: str, reference: str) -> Tuple[str, bool]:
        """Extract rise/fall prediction from LLM response"""
        response_clean = response.lower().strip()
        reference_clean = reference.lower().strip()
        
        if 'rise' in response_clean:
            prediction = 'rise'
        elif 'fall' in response_clean:
            prediction = 'fall'
        else:
            prediction = 'freeze'
        
        correct = (prediction == reference_clean)
        return prediction, correct
    
    def run_single_llm_experiment(self, llm_name: str, output_dir: str, k_candidates: int = 5) -> pd.DataFrame:
        """Run the FinQuest similarity retrieval experiment for a single LLM.

        For each test query with a ground truth other than 'freeze', retrieves
        up to ``k_candidates`` pre-computed candidates, prompts the LLM, parses
        the answer and records one result row. Rows are written to
        ``<output_dir>/<llm_name>_finquest_similarity_k<k>_results.csv`` and
        summary metrics are logged.

        Args:
            llm_name: Key into ``self.llm_configs``.
            output_dir: Directory for the results CSV (created if missing).
            k_candidates: Number of similarity entries to consider per query.

        Returns:
            One row per scored query. When nothing could be scored the frame
            is empty but still carries every result column.

        Raises:
            Exception: Whatever ``load_llm`` raises. Per-query errors are
                logged and the query is skipped.
        """
        logger.info(f"🚀 Starting FINQUEST SIMILARITY RETRIEVAL experiment with {llm_name} (k={k_candidates})")

        self.load_llm(llm_name)

        # Explicit schema: keeps the CSV column order fixed and lets the
        # metrics below index columns even when no query produced a row.
        result_columns = [
            'llm_name', 'query_id', 'query_stock', 'query_date', 'method',
            'prompt', 'llm_response', 'prediction', 'reference', 'correct',
            'candidate_count', 'candidate_movement_dist', 'similarity_scores',
            'avg_similarity_score', 'max_similarity_score',
            'min_similarity_score', 'k_candidates',
        ]

        results = []
        processed_count = 0
        correct_count = 0
        queries_with_candidates = 0
        queries_without_candidates = 0
        total_queries = len(self.test_queries)

        for i, query in enumerate(self.test_queries):
            query_id = query['query_id']
            ground_truth = self.gt_lookup.get(query_id)

            if not ground_truth:
                continue

            reference_answer = ground_truth['actual_movement']

            # Only rise/fall queries are evaluated.
            if reference_answer == 'freeze':
                continue

            if (i + 1) % 50 == 0:
                attempted = queries_with_candidates + queries_without_candidates
                accuracy = correct_count / processed_count if processed_count > 0 else 0
                candidate_coverage = queries_with_candidates / attempted if attempted > 0 else 0
                logger.info(f"Progress: {i+1}/{total_queries} | Accuracy: {accuracy:.3f} | Candidate Coverage: {candidate_coverage:.3f}")

            try:
                candidates, similarity_scores, candidate_movement_count = self.get_finquest_candidates(query, k_candidates)

                if len(candidates) == 0:
                    logger.warning(f"No FinQuest candidates found for query {query_id}")
                    queries_without_candidates += 1
                    continue

                queries_with_candidates += 1

                # Plain floats: NumPy scalars would be stringified with their
                # type name (e.g. "np.float64(0.5)") under NumPy >= 2.
                scores = [float(s) for s in similarity_scores]

                prompt = self.generate_finquest_prompt(query, candidates, scores)
                llm_response = self.ask_llm(prompt)
                prediction, correct = self.extract_prediction(llm_response, reference_answer)

                processed_count += 1
                if correct:
                    correct_count += 1

                results.append({
                    'llm_name': llm_name,
                    'query_id': query_id,
                    'query_stock': query.get('query_stock', ''),
                    'query_date': query.get('query_date', ''),
                    'method': 'finquest_similarity_retrieval',
                    'prompt': prompt,
                    'llm_response': llm_response,
                    'prediction': prediction,
                    'reference': reference_answer,
                    'correct': correct,
                    'candidate_count': len(candidates),
                    'candidate_movement_dist': str(candidate_movement_count),
                    'similarity_scores': str([round(s, 4) for s in scores]),
                    'avg_similarity_score': float(np.mean(scores)) if scores else 0,
                    'max_similarity_score': max(scores) if scores else 0,
                    'min_similarity_score': min(scores) if scores else 0,
                    'k_candidates': k_candidates
                })

            except Exception as e:
                # Best effort: one bad query must not abort the whole run.
                logger.error(f"Error processing query {query_id}: {e}")
                continue

        df = pd.DataFrame(results, columns=result_columns)

        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f'{llm_name}_finquest_similarity_k{k_candidates}_results.csv')
        df.to_csv(output_file, index=False)

        total_predictions = len(df)
        if total_predictions == 0:
            logger.warning(f"No predictions were produced for {llm_name} (k={k_candidates})")
        overall_accuracy = df['correct'].mean() if total_predictions > 0 else 0

        rise_df = df[df['reference'] == 'rise']
        fall_df = df[df['reference'] == 'fall']

        rise_accuracy = rise_df['correct'].mean() if len(rise_df) > 0 else 0
        fall_accuracy = fall_df['correct'].mean() if len(fall_df) > 0 else 0

        avg_candidates_used = df['candidate_count'].mean() if total_predictions > 0 else 0
        avg_similarity = df['avg_similarity_score'].mean() if total_predictions > 0 else 0

        attempted = queries_with_candidates + queries_without_candidates
        candidate_coverage = queries_with_candidates / attempted if attempted > 0 else 0

        logger.info("="*70)
        logger.info(f"FINQUEST SIMILARITY RETRIEVAL RESULTS - {llm_name}")
        logger.info("="*70)
        logger.info(f"Total Predictions: {total_predictions}")
        logger.info(f"Overall Accuracy: {overall_accuracy:.4f}")
        logger.info(f"Rise Accuracy: {rise_accuracy:.4f} ({len(rise_df)} samples)")
        logger.info(f"Fall Accuracy: {fall_accuracy:.4f} ({len(fall_df)} samples)")
        logger.info(f"Candidate Coverage: {candidate_coverage:.4f} ({queries_with_candidates}/{attempted})")
        logger.info(f"Avg Candidates Used: {avg_candidates_used:.1f}")
        logger.info(f"Avg Similarity Score: {avg_similarity:.4f}")
        logger.info(f"Results saved to: {output_file}")
        logger.info("="*70)

        return df
    
    def run_multi_llm_experiment(self, llm_list: List[str], output_dir: str, k_candidates: int = 5) -> pd.DataFrame:
        """Run the FinQuest similarity retrieval experiment across several LLMs.

        Each LLM is run independently; a failure (including an unknown LLM
        name) is logged with its traceback and the remaining LLMs still run.
        Non-empty results are concatenated, saved, and passed to the
        comparison report and the similarity-pattern analysis.

        Args:
            llm_list: Keys into ``self.llm_configs`` to evaluate, in order.
            output_dir: Directory for all output files.
            k_candidates: Number of similarity entries to consider per query.

        Returns:
            The combined per-query results, or an empty DataFrame when no LLM
            produced any prediction.
        """
        logger.info(f"🚀 Starting MULTI-LLM FINQUEST SIMILARITY RETRIEVAL experiment")
        logger.info(f"Testing LLMs: {llm_list}")
        logger.info(f"Candidates per query: {k_candidates}")

        all_results = []

        for llm_name in llm_list:
            # .get(): an unknown name must not raise here, outside the try
            # block, or it would abort the run for every remaining LLM.
            description = self.llm_configs.get(llm_name, {}).get('description', 'n/a')
            logger.info(f"\n{'='*60}")
            logger.info(f"TESTING LLM: {llm_name}")
            logger.info(f"Description: {description}")
            logger.info(f"{'='*60}")

            try:
                llm_results = self.run_single_llm_experiment(llm_name, output_dir, k_candidates)
            except Exception as e:
                # logger.exception records the traceback alongside the message.
                logger.exception(f"❌ Failed {llm_name}: {e}")
                continue

            if len(llm_results) > 0:
                all_results.append(llm_results)
            else:
                # An empty frame adds nothing to the report and may lack the
                # columns the report code indexes.
                logger.warning(f"{llm_name} produced no predictions; excluded from combined results")

            logger.info(f"✅ Completed {llm_name}")

        if not all_results:
            logger.error("❌ No successful experiments!")
            return pd.DataFrame()

        combined_df = pd.concat(all_results, ignore_index=True)

        os.makedirs(output_dir, exist_ok=True)
        combined_file = os.path.join(output_dir, f'all_llms_finquest_similarity_k{k_candidates}_combined.csv')
        combined_df.to_csv(combined_file, index=False)

        self._generate_comprehensive_comparison_report(combined_df, output_dir, k_candidates)
        self._analyze_similarity_patterns(combined_df, output_dir)

        logger.info(f"\n✅ Multi-LLM FinQuest similarity retrieval experiment completed!")
        logger.info(f"Combined results saved to: {combined_file}")

        return combined_df
    
    def run_k_ablation_study(self, llm_name: str, output_dir: str, k_values: Optional[List[int]] = None) -> pd.DataFrame:
        """Run the single-LLM experiment once per ``k`` and analyze the effect.

        Args:
            llm_name: Key into ``self.llm_configs``.
            output_dir: Directory for all output files.
            k_values: Candidate counts to test; defaults to
                ``[1, 3, 5, 10, 15]``.

        Returns:
            The concatenated per-query results of every k that completed, or
            an empty DataFrame when every run failed.
        """
        # None sentinel instead of a mutable list default shared across calls.
        if k_values is None:
            k_values = [1, 3, 5, 10, 15]

        logger.info(f"🔬 Starting FINQUEST K-ABLATION STUDY for {llm_name}")
        logger.info(f"Testing k values: {k_values}")

        all_results = []
        # Track which k each result belongs to: if a run fails, the results
        # list no longer lines up with k_values position by position.
        completed_k = []

        for k in k_values:
            logger.info(f"\n{'='*40}")
            logger.info(f"TESTING k={k} FinQuest candidates")
            logger.info(f"{'='*40}")

            try:
                k_results = self.run_single_llm_experiment(llm_name, output_dir, k)
            except Exception as e:
                logger.error(f"❌ Failed k={k}: {e}")
                continue

            all_results.append(k_results)
            completed_k.append(k)
            logger.info(f"✅ Completed k={k}")

        if not all_results:
            logger.error("❌ FinQuest K-ablation study failed!")
            return pd.DataFrame()

        self._analyze_finquest_k_ablation(all_results, llm_name, output_dir, completed_k)

        return pd.concat(all_results, ignore_index=True)
    
    def _analyze_finquest_k_ablation(self, results_list: List[pd.DataFrame], llm_name: str, output_dir: str, k_values: List[int]):
        """Analyze effect of different k values for FinQuest similarity retrieval"""
        
        k_analysis = []
        
        for i, df in enumerate(results_list):
            k = k_values[i]
            
            # Basic accuracy metrics
            accuracy = df['correct'].mean()
            rise_acc = df[df['reference'] == 'rise']['correct'].mean()
            fall_acc = df[df['reference'] == 'fall']['correct'].mean()
            
            # Similarity-specific metrics
            avg_similarity = df['avg_similarity_score'].mean()
            max_similarity_avg = df['max_similarity_score'].mean()
            
            k_analysis.append({
                'k_candidates': k,
                'overall_accuracy': accuracy,
                'rise_accuracy': rise_acc,
                'fall_accuracy': fall_acc,
                'total_predictions': len(df),
                'avg_similarity_score': avg_similarity,
                'avg_max_similarity': max_similarity_avg
            })
        
        # Save k-ablation analysis
        k_df = pd.DataFrame(k_analysis)
        k_file = os.path.join(output_dir, f'{llm_name}_finquest_k_ablation_analysis.csv')
        k_df.to_csv(k_file, index=False)
        
        # Print detailed analysis
        logger.info(f"\nFINQUEST K-ABLATION ANALYSIS - {llm_name}")
        logger.info("="*80)
        for result in k_analysis:
            logger.info(f"k={result['k_candidates']:2d} | Accuracy: {result['overall_accuracy']:.4f} | "
                       f"Rise: {result['rise_accuracy']:.3f} | Fall: {result['fall_accuracy']:.3f} | "
                       f"Avg Sim: {result['avg_similarity_score']:.3f}")
        
        logger.info(f"\nFinQuest K-ablation analysis saved to: {k_file}")
    
    def _analyze_similarity_patterns(self, combined_df: pd.DataFrame, output_dir: str):
        """Analyze patterns in similarity scores vs. prediction accuracy"""
        
        # Similarity score binning analysis
        similarity_bins = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]
        combined_df['similarity_bin'] = pd.cut(combined_df['avg_similarity_score'], 
                                             bins=similarity_bins, 
                                             labels=['0.0-0.3', '0.3-0.5', '0.5-0.7', '0.7-0.9', '0.9-1.0'])
        
        similarity_analysis = []
        
        for bin_name in ['0.0-0.3', '0.3-0.5', '0.5-0.7', '0.7-0.9', '0.9-1.0']:
            bin_df = combined_df[combined_df['similarity_bin'] == bin_name]
            
            if len(bin_df) > 0:
                similarity_analysis.append({
                    'similarity_range': bin_name,
                    'count': len(bin_df),
                    'accuracy': bin_df['correct'].mean(),
                    'avg_similarity': bin_df['avg_similarity_score'].mean()
                })
        
        # Save similarity analysis
        sim_df = pd.DataFrame(similarity_analysis)
        sim_file = os.path.join(output_dir, 'finquest_similarity_score_analysis.csv')
        sim_df.to_csv(sim_file, index=False)
        
        logger.info("\nFINQUEST SIMILARITY SCORE ANALYSIS")
        logger.info("="*50)
        for result in similarity_analysis:
            logger.info(f"Similarity {result['similarity_range']} | Count: {result['count']:4d} | "
                       f"Accuracy: {result['accuracy']:.4f}")
        
        logger.info(f"\nSimilarity analysis saved to: {sim_file}")
    
    def _generate_comprehensive_comparison_report(self, combined_df: pd.DataFrame, output_dir: str, k_candidates: int):
        """Generate comprehensive comparison report across LLMs with FinQuest-specific metrics"""
        
        comparison_results = []
        
        for llm_name in combined_df['llm_name'].unique():
            llm_df = combined_df[combined_df['llm_name'] == llm_name]
            
            # Basic accuracy metrics
            total = len(llm_df)
            accuracy = llm_df['correct'].mean()
            
            rise_df = llm_df[llm_df['reference'] == 'rise']
            fall_df = llm_df[llm_df['reference'] == 'fall']
            
            rise_acc = rise_df['correct'].mean() if len(rise_df) > 0 else 0
            fall_acc = fall_df['correct'].mean() if len(fall_df) > 0 else 0
            
            # Prediction distribution
            pred_dist = llm_df['prediction'].value_counts()
            
            # FinQuest-specific metrics
            avg_candidates = llm_df['candidate_count'].mean()
            avg_similarity = llm_df['avg_similarity_score'].mean()
            max_similarity_avg = llm_df['max_similarity_score'].mean()
            
            # High similarity predictions (>0.7 avg similarity)
            high_sim_df = llm_df[llm_df['avg_similarity_score'] > 0.7]
            high_sim_accuracy = high_sim_df['correct'].mean() if len(high_sim_df) > 0 else 0
            
            comparison_results.append({
                'llm_name': llm_name,
                'model_description': self.llm_configs[llm_name]['description'],
                'total_predictions': total,
                'overall_accuracy': accuracy,
                'rise_accuracy': rise_acc,
                'fall_accuracy': fall_acc,
                'rise_predictions': len(rise_df),
                'fall_predictions': len(fall_df),
                'predicted_rise': pred_dist.get('rise', 0),
                'predicted_fall': pred_dist.get('fall', 0),
                'predicted_freeze': pred_dist.get('freeze', 0),
                'avg_candidates_used': avg_candidates,
                'avg_similarity_score': avg_similarity,
                'avg_max_similarity': max_similarity_avg,
                'high_similarity_accuracy': high_sim_accuracy,
                'high_similarity_count': len(high_sim_df),
                'k_candidates': k_candidates,
                'method': 'finquest_similarity_retrieval'
            })
        
        # Save comprehensive comparison
        comparison_df = pd.DataFrame(comparison_results)
        comparison_file = os.path.join(output_dir, f'finquest_similarity_k{k_candidates}_comprehensive_comparison.csv')
        comparison_df.to_csv(comparison_file, index=False)
        
        # Print comprehensive summary
        logger.info("\n" + "="*100)
        logger.info(f"FINQUEST SIMILARITY RETRIEVAL (k={k_candidates}) - COMPREHENSIVE LLM COMPARISON")
        logger.info("="*100)
        
        for result in comparison_results:
            logger.info(f"{result['llm_name']:15} | Accuracy: {result['overall_accuracy']:.4f} | "
                       f"Rise: {result['rise_accuracy']:.3f} | Fall: {result['fall_accuracy']:.3f} | "
                       f"AvgSim: {result['avg_similarity_score']:.3f} | "
                       f"HighSim: {result['high_similarity_accuracy']:.3f}({result['high_similarity_count']})")
        
        logger.info(f"\nComprehensive comparison saved to: {comparison_file}")
        
        return comparison_df

# USAGE EXAMPLE
def run_finquest_similarity_experiment():
    """Run the multi-LLM FinQuest similarity retrieval experiment end to end.

    Uses the hard-coded input/output locations below, evaluates every LLM in
    ``test_llms`` with ``k_candidates`` retrieved candidates per query, and
    logs a summary. Any exception is logged with its traceback rather than
    propagated.
    """

    # Configuration
    test_queries_file = '/root/nfs/AJ FinRag/Evaluation Results/Test Queries/test_queries_rise_fall_only.json'
    ground_truth_file = '/root/nfs/AJ FinRag/Evaluation Results/Test Queries/ground_truth_rise_fall_only.json'
    embeddings_dir = '/root/nfs/AJ FinRag/Embeddings/embeddings/test/FinQuest'
    similarity_results_file = 'similar_candidates/test/FinQuest/test_similarity_results_FinQuest.pkl'
    output_dir = 'finquest_similarity_experiments'

    # LLMs to test
    test_llms = [
        'StockLLM',      # Specialized financial model
        'Qwen2.5-1.5B',  # Small general model
        'Llama3.2-3B'    # Medium general model
    ]

    # Number of similarity-based candidates to retrieve per query
    k_candidates = 5

    try:
        experiment = FinQuestSimilarityExperiment(
            test_queries_file,
            ground_truth_file,
            embeddings_dir,
            similarity_results_file
        )

        # Option 1: Run multi-LLM experiment with fixed k
        logger.info("🚀 RUNNING MULTI-LLM FINQUEST SIMILARITY EXPERIMENT")
        results = experiment.run_multi_llm_experiment(test_llms, output_dir, k_candidates)

        if len(results) > 0:
            # Count the LLMs that actually produced rows; some of the
            # requested ones may have failed and been skipped.
            llm_count = results['llm_name'].nunique()
            logger.info("🎉 FINQUEST SIMILARITY EXPERIMENT COMPLETED SUCCESSFULLY!")
            logger.info(f"📊 Total results: {len(results)} predictions across {llm_count} LLMs")

        # Option 2: Run k-ablation study (uncomment to run)
        # logger.info("🔬 RUNNING FINQUEST K-ABLATION STUDY")
        # ablation_results = experiment.run_k_ablation_study('StockLLM', output_dir, [1, 3, 5, 10, 15])

        # Option 3: Deep analysis of best performing LLM (uncomment to run)
        # best_llm = 'StockLLM'  # Choose based on initial results
        # logger.info(f"🔍 RUNNING DETAILED ANALYSIS FOR {best_llm}")
        # detailed_results = experiment.run_single_llm_experiment(best_llm, output_dir, 10)

    except Exception as e:
        logger.error(f"❌ Experiment failed: {e}")
        import traceback
        traceback.print_exc()

# Start the experiment only when this code runs as the main module, not on import.
if __name__ == "__main__":
    run_finquest_similarity_experiment()

INFO:__main__:Loading experimental data...
INFO:__main__:Loading candidates from 75 files...
INFO:__main__:Loaded 2642 test queries
INFO:__main__:Loaded 2642 ground truth entries
INFO:__main__:Loaded 8151 candidates
INFO:__main__:Loaded similarity results for 2642 queries
INFO:__main__:🚀 RUNNING MULTI-LLM FINQUEST SIMILARITY EXPERIMENT
INFO:__main__:🚀 Starting MULTI-LLM FINQUEST SIMILARITY RETRIEVAL experiment
INFO:__main__:Testing LLMs: ['StockLLM', 'Qwen2.5-1.5B', 'Llama3.2-3B']
INFO:__main__:Candidates per query: 5
INFO:__main__:
INFO:__main__:TESTING LLM: StockLLM
INFO:__main__:Description: Specialized financial LLM
INFO:__main__:🚀 Starting FINQUEST SIMILARITY RETRIEVAL experiment with StockLLM (k=5)
INFO:__main__:Loading StockLLM: ElsaShaw/StockLLM
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

INFO:__main__:✅ Successfully loaded StockLLM
INFO:__main__:Progress: 50/2642 | Accuracy: 0.510 | Candidate Coverage: 1.000
INFO:__main__:Progress: 100/2642 | Accuracy: 0.455 | Candidate Coverage: 1.000
INFO:__main__:Progress: 150/2642 | Accuracy: 0.470 | Candidate Coverage: 1.000
INFO:__main__:Progress: 200/2642 | Accuracy: 0.472 | Candidate Coverage: 1.000
INFO:__main__:Progress: 250/2642 | Accuracy: 0.486 | Candidate Coverage: 1.000
INFO:__main__:Progress: 300/2642 | Accuracy: 0.482 | Candidate Coverage: 1.000
INFO:__main__:Progress: 350/2642 | Accuracy: 0.493 | Candidate Coverage: 1.000
INFO:__main__:Progress: 400/2642 | Accuracy: 0.496 | Candidate Coverage: 1.000
INFO:__main__:Progress: 450/2642 | Accuracy: 0.499 | Candidate Coverage: 1.000
INFO:__main__:Progress: 500/2642 | Accuracy: 0.501 | Candidate Coverage: 1.000
INFO:__main__:Progress: 550/2642 | Accuracy: 0.488 | Candidate Coverage: 1.000
INFO:__main__:Progress: 600/2642 | Accuracy: 0.486 | Candidate Coverage: 1.000
INFO:__m

In [3]:
# Debug script to investigate StockLLM's response bias

import pandas as pd
import numpy as np
import os
import pickle

def analyze_stockllm_responses(results_files):
    """Print response/prediction statistics for each StockLLM results CSV.

    Args:
        results_files: Mapping from method name ('no_retrieval',
            'random_retrieval', 'finquest_similarity') to the path of a CSV
            with the columns 'llm_response', 'prediction' and 'reference'.
            Methods that are absent, or whose file does not exist, are
            reported and skipped.

    Returns:
        None. Everything is written to stdout.
    """

    print("🔍 ANALYZING STOCKLLM RESPONSE PATTERNS")
    print("="*60)

    methods = ['no_retrieval', 'random_retrieval', 'finquest_similarity']

    for method in methods:
        path = results_files.get(method)
        if path is None or not os.path.exists(path):
            # Say so explicitly: a silently missing section is easy to
            # misread as "this method was analyzed and had nothing to show".
            print(f"\n⚠️  Skipping {method}: results file not found ({path})")
            continue

        print(f"\n📊 {method.upper()}")
        print("-" * 40)

        df = pd.read_csv(path)
        total = len(df)

        # .to_dict() yields plain Python ints, so the printed dict stays
        # readable regardless of the installed NumPy version's scalar repr.
        prediction_dist = df['prediction'].value_counts().to_dict()
        reference_dist = df['reference'].value_counts().to_dict()

        print(f"Total predictions: {total}")
        print(f"Prediction distribution: {prediction_dist}")
        print(f"Reference distribution: {reference_dist}")

        if total == 0:
            # The percentages below divide by the row count.
            print("No rows to analyze")
            continue

        # Sample responses (first five rows)
        print(f"\nSample LLM responses:")
        head = df.head(5)
        rows = zip(head['llm_response'], head['prediction'], head['reference'])
        for i, (response, pred, ref) in enumerate(rows, start=1):
            print(f"  {i}. '{response}' -> {pred} (ref: {ref})")

        # Check for repetitive responses
        unique_responses = df['llm_response'].nunique()
        print(f"\nUnique responses: {unique_responses}/{total} ({unique_responses/total*100:.1f}%)")

        # Most common responses
        common_responses = df['llm_response'].value_counts().head(10)
        print(f"\nMost common responses:")
        for response, count in common_responses.items():
            print(f"  '{response}': {count} times ({count/total*100:.1f}%)")

def debug_prompt_quality(similarity_results_file, sample_size=5):
    """Print the top-3 retrieved candidates for the first few queries.

    Args:
        similarity_results_file: Path to a pickle holding an indexable
            sequence of per-query dicts. Each dict must have 'query_stock'
            and 'query_date' and may have a 'similarity_list' of candidate
            dicts.
        sample_size: Maximum number of queries to print.
    """

    print("\n🔍 DEBUGGING PROMPT QUALITY")
    print("=" * 50)

    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(similarity_results_file, 'rb') as handle:
        loaded = pickle.load(handle)

    how_many = min(sample_size, len(loaded))
    for position in range(how_many):
        entry = loaded[position]
        stock = entry['query_stock']
        date = entry['query_date']
        top_three = entry.get('similarity_list', [])[:3]

        print(f"\nQuery {position + 1}: {stock} on {date}")

        if not top_three:
            print("  No candidates found!")
            continue

        print("Top candidates:")
        for rank, cand in enumerate(top_three, start=1):
            cand_score = cand.get('candidate_score', 0)
            cand_index = cand.get('candidate_index', 'unknown')
            print(f"  {rank}. Index: {cand_index}, Score: {cand_score:.6f}")

            # Stock/date details are only printed when the candidate carries
            # its own 'candidate_data' payload.
            if 'candidate_data' not in cand:
                continue
            details = cand['candidate_data']
            cand_stock = details.get('candidate_stock', 'unknown')
            cand_date = details.get('candidate_date', 'unknown')
            print(f"      Stock: {cand_stock}, Date: {cand_date}")

# Usage: map each retrieval method to its StockLLM results CSV, then run the
# response-pattern report followed by the candidate/prompt inspection.
results_files = dict(
    no_retrieval='no_retrieval_experiments/StockLLM_no_retrieval_results.csv',
    random_retrieval='random_retrieval_experiments/StockLLM_random_retrieval_k5_results.csv',
    finquest_similarity='finquest_similarity_experiments/StockLLM_finquest_similarity_k5_results.csv',
)

analyze_stockllm_responses(results_files)
debug_prompt_quality('similar_candidates/test/FinQuest/test_similarity_results_FinQuest.pkl')

🔍 ANALYZING STOCKLLM RESPONSE PATTERNS

📊 NO_RETRIEVAL
----------------------------------------
Total predictions: 2642
Prediction distribution: {'rise': 2640, 'fall': 2}
Reference distribution: {'rise': 1382, 'fall': 1260}

Sample LLM responses:
  1. 'rise.' -> rise (ref: rise)
  2. 'rise.' -> rise (ref: fall)
  3. 'rise.' -> rise (ref: rise)
  4. 'rise.' -> rise (ref: rise)
  5. 'rise.' -> rise (ref: fall)

Unique responses: 8/2642 (0.3%)

Most common responses:
  'rise.': 2175 times (82.3%)
  'rise

Query: On 2025-02': 165 times (6.2%)
  'rise

Query: On 2025-04': 125 times (4.7%)
  'rise

Query: On 2025-03': 120 times (4.5%)
  'rise

Query: On 2025-01': 48 times (1.8%)
  'rise': 5 times (0.2%)
  'rise

Query: On 2025-05': 2 times (0.1%)
  'fall.': 2 times (0.1%)

📊 RANDOM_RETRIEVAL
----------------------------------------
Total predictions: 2642
Prediction distribution: {'rise': 2611, 'fall': 31}
Reference distribution: {'rise': 1382, 'fall': 1260}

Sample LLM responses:
  1. 'rise

In [4]:
# Script to trace your FinQuest data pipeline and understand the similarity score issue

import pickle
import json
import pandas as pd
import numpy as np
from pathlib import Path
from datetime import datetime

def _load_jsonl(path):
    """Parse a JSON-lines file into a list, reporting malformed lines instead of hiding them."""
    records = []
    skipped = 0
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # blank lines are not data
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                skipped += 1
    if skipped:
        print(f"⚠️  Skipped {skipped} malformed line(s) in {path}")
    return records


def _load_candidate_pool(embedding_files):
    """Collect candidates (keyed by 'data_index'), their dates and their stocks from embedding pickles."""
    candidates = {}
    candidate_dates = []
    candidate_stocks = set()

    for file_path in embedding_files:
        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(file_path, 'rb') as f:
            embedding_data = pickle.load(f)

        # Each pickle is iterated as a sequence of {date: [items]} mappings,
        # where every item carries its metadata under the 'data' key.
        for date_group in embedding_data:
            for candidates_on_date in date_group.values():
                for candidate_item in candidates_on_date:
                    candidate_data = candidate_item['data']
                    candidates[candidate_data['data_index']] = candidate_data

                    if 'candidate_date' in candidate_data:
                        candidate_dates.append(candidate_data['candidate_date'])
                        candidate_stocks.add(candidate_data.get('candidate_stock', ''))

    return candidates, candidate_dates, candidate_stocks


def _percent(part, whole):
    """Return part/whole as a percentage, or 0.0 when whole is zero."""
    return part / whole * 100 if whole else 0.0


def trace_finquest_data_pipeline(
    test_queries_file='/root/nfs/AJ FinRag/Evaluation Results/Test Queries/test_queries_rise_fall_only.json',
    ground_truth_file='/root/nfs/AJ FinRag/Evaluation Results/Test Queries/ground_truth_rise_fall_only.json',
    embeddings_dir='/root/nfs/AJ FinRag/Embeddings/embeddings/test/FinQuest',
    similarity_file='similar_candidates/test/FinQuest/test_similarity_results_FinQuest.pkl',
):
    """Trace the FinQuest data pipeline and print a diagnosis of the similarity scores.

    The report covers the test queries, the candidate pool, query/candidate
    overlap, per-query similarity results, score statistics, a root-cause
    guess and recommendations. All paths default to the locations the
    original script used, so calling with no arguments behaves as before.

    Args:
        test_queries_file: JSON-lines file of test queries (dicts that may
            carry 'query_date' and 'query_stock').
        ground_truth_file: JSON-lines file of ground-truth entries.
        embeddings_dir: Directory scanned for
            'c_*_FinQuest_embeddings_*.pkl' candidate files.
        similarity_file: Pickle with one dict per query, each optionally
            holding a 'similarity_list' of
            {'candidate_index', 'candidate_score'} dicts.

    Returns:
        A dict of summary counts ('test_queries', 'candidates',
        'similarity_results', 'perfect_scores', 'exact_matches',
        'total_scores', 'stock_overlap', 'date_overlap'), or None when
        ``similarity_file`` does not exist.
    """
    from collections import Counter  # local import keeps this notebook cell self-contained

    print("🔍 TRACING FINQUEST DATA PIPELINE")
    print("="*80)

    # Step 1: Understand your data sources
    print("\n📂 STEP 1: DATA SOURCES ANALYSIS")
    print("-" * 50)

    test_queries = _load_jsonl(test_queries_file)
    ground_truth = _load_jsonl(ground_truth_file)

    print(f"✅ Test queries: {len(test_queries)}")
    print(f"✅ Ground truth: {len(ground_truth)}")

    # Analyze test query date ranges (only queries that carry a date count)
    test_dates = []
    test_stocks = set()
    for query in test_queries:
        if 'query_date' in query:
            test_dates.append(query['query_date'])
            test_stocks.add(query.get('query_stock', ''))

    test_dates.sort()
    print(f"Test date range: {test_dates[0] if test_dates else 'None'} to {test_dates[-1] if test_dates else 'None'}")
    print(f"Test stocks: {len(test_stocks)} unique stocks")
    print(f"Sample test stocks: {list(test_stocks)[:10]}")

    # Step 2: Analyze candidate pool (training data embeddings)
    print("\n📂 STEP 2: CANDIDATE POOL ANALYSIS")
    print("-" * 50)

    # Sorted so that, if two files share a data_index, the winner is the
    # same on every run rather than depending on directory listing order.
    embedding_files = sorted(Path(embeddings_dir).glob("c_*_FinQuest_embeddings_*.pkl"))
    print(f"Found {len(embedding_files)} candidate embedding files")

    candidates, candidate_dates, candidate_stocks = _load_candidate_pool(embedding_files)

    candidate_dates.sort()
    print(f"✅ Total candidates: {len(candidates)}")
    print(f"Candidate date range: {candidate_dates[0] if candidate_dates else 'None'} to {candidate_dates[-1] if candidate_dates else 'None'}")
    print(f"Candidate stocks: {len(candidate_stocks)} unique stocks")
    print(f"Sample candidate stocks: {list(candidate_stocks)[:10]}")

    # Check for overlap
    stock_overlap = test_stocks.intersection(candidate_stocks)
    print(f"Stock overlap: {len(stock_overlap)}/{len(test_stocks)} test stocks also in candidates")

    # Date overlap check
    date_overlap = set(test_dates).intersection(candidate_dates)
    print(f"Date overlap: {len(date_overlap)} dates appear in both test and candidate sets")

    if date_overlap:
        print("⚠️  WARNING: Overlapping dates detected!")
        print(f"Sample overlapping dates: {list(date_overlap)[:10]}")

    # Step 3: Analyze similarity results structure
    print("\n📂 STEP 3: SIMILARITY RESULTS ANALYSIS")
    print("-" * 50)

    if not Path(similarity_file).exists():
        print(f"❌ Similarity file not found: {similarity_file}")
        return None

    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(similarity_file, 'rb') as f:
        similarity_results = pickle.load(f)

    print(f"✅ Similarity results: {len(similarity_results)} queries")

    # Analyze first few similarity results in detail
    print("\n🔍 DETAILED ANALYSIS OF FIRST 5 SIMILARITY RESULTS:")

    for i in range(min(5, len(similarity_results))):
        result = similarity_results[i]
        query_stock = result.get('query_stock', 'Unknown')
        query_date = result.get('query_date', 'Unknown')
        similarity_list = result.get('similarity_list', [])

        print(f"\n--- Query {i+1}: {query_stock} on {query_date} ---")
        print(f"Candidates found: {len(similarity_list)}")

        # Analyze top 3 candidates
        for j, candidate_info in enumerate(similarity_list[:3], start=1):
            candidate_index = candidate_info.get('candidate_index', 'Unknown')
            candidate_score = candidate_info.get('candidate_score', 0)

            print(f"  Candidate {j}:")
            print(f"    Index: {candidate_index}")
            print(f"    Score: {candidate_score:.10f}")

            if candidate_index not in candidates:
                print(f"    ❌ Candidate data not found for index {candidate_index}")
                continue

            candidate_data = candidates[candidate_index]
            cand_stock = candidate_data.get('candidate_stock', 'Unknown')
            cand_date = candidate_data.get('candidate_date', 'Unknown')
            cand_movement = candidate_data.get('movement', 'Unknown')

            print(f"    Stock: {cand_stock}")
            print(f"    Date: {cand_date}")
            print(f"    Movement: {cand_movement}")

            # Check if this is an exact match
            if query_stock == cand_stock and query_date == cand_date:
                print("    🚨 EXACT MATCH: Same stock and date!")

            # Check similarity score patterns
            if abs(candidate_score - 1.0) < 1e-10:
                print("    ⚠️  PERFECT SCORE: Exactly 1.0")

    # Step 4: Statistical analysis of similarity scores
    print("\n📊 STEP 4: SIMILARITY SCORE STATISTICS")
    print("-" * 50)

    all_scores = []
    perfect_scores = 0
    exact_matches = 0
    score_patterns = Counter()  # score rounded to 6 decimals -> occurrences

    for result in similarity_results:
        query_stock = result.get('query_stock', '')
        query_date = result.get('query_date', '')

        for candidate_info in result.get('similarity_list', []):
            candidate_index = candidate_info.get('candidate_index', -1)
            candidate_score = candidate_info.get('candidate_score', 0)

            all_scores.append(candidate_score)

            # Count perfect scores
            if abs(candidate_score - 1.0) < 1e-10:
                perfect_scores += 1

            # Count exact matches
            if candidate_index in candidates:
                candidate_data = candidates[candidate_index]
                if (candidate_data.get('candidate_stock', '') == query_stock and
                        candidate_data.get('candidate_date', '') == query_date):
                    exact_matches += 1

            # Track score patterns
            score_patterns[round(candidate_score, 6)] += 1

    total_scores = len(all_scores)
    print(f"Total similarity scores analyzed: {total_scores}")
    # _percent guards the zero-score case, which would otherwise divide by zero.
    print(f"Perfect scores (1.0): {perfect_scores} ({_percent(perfect_scores, total_scores):.2f}%)")
    print(f"Exact stock-date matches: {exact_matches}")
    print(f"Unique score values: {len(score_patterns)}")

    if all_scores:
        print("Score statistics:")
        print(f"  Min: {min(all_scores):.10f}")
        print(f"  Max: {max(all_scores):.10f}")
        print(f"  Mean: {np.mean(all_scores):.10f}")
        print(f"  Std: {np.std(all_scores):.10f}")

    # Show most common scores (ties keep first-seen order)
    print("\nMost common similarity scores:")
    for score, count in score_patterns.most_common(10):
        print(f"  {score:.6f}: {count} times ({_percent(count, total_scores):.1f}%)")

    # Step 5: Root cause analysis
    print("\n🔍 STEP 5: ROOT CAUSE ANALYSIS")
    print("-" * 50)

    if not all_scores:
        print("⚠️  No similarity scores found - nothing to diagnose")

    elif perfect_scores > total_scores * 0.5:
        print("🚨 CRITICAL ISSUE: >50% of similarity scores are exactly 1.0")

        possible_causes = [
            "1. Query-candidate data overlap (same stock-date pairs in both sets)",
            "2. Embedding computation error (all embeddings identical)",
            "3. Cosine similarity calculation bug (returning 1.0 for everything)",
            "4. Normalization issue (all embeddings have norm 0 or identical)",
            "5. Data preprocessing creating duplicate sequences",
            "6. Similarity search using wrong embedding vectors"
        ]

        print("Possible causes:")
        for cause in possible_causes:
            print(f"  {cause}")

    elif exact_matches > 0:
        print(f"🚨 DATA LEAK: {exact_matches} queries match candidates with same stock-date")
        print("This means your 'training' candidates include your test queries!")

    elif len(score_patterns) == 1:
        print("🚨 COMPUTATION ERROR: All similarity scores are identical")
        print("This suggests a bug in similarity computation")

    else:
        print("🤔 UNCLEAR: Similarity scores look normal but something else is wrong")
        print("Check your similarity search generation process")

    # Step 6: Recommendations
    print("\n💡 RECOMMENDATIONS")
    print("-" * 50)

    if exact_matches > 0:
        print("1. 🔧 IMMEDIATE FIX: Remove exact stock-date matches from candidates")
        print("   Add this filter in your similarity search:")
        print("""
   def filter_overlapping_candidates(query, similarity_list, candidates):
       filtered = []
       query_stock = query.get('query_stock', '')
       query_date = query.get('query_date', '')
       
       for candidate_info in similarity_list:
           candidate_index = candidate_info.get('candidate_index', -1)
           if candidate_index in candidates:
               candidate_data = candidates[candidate_index]
               # Skip if same stock and date
               if not (candidate_data.get('candidate_stock', '') == query_stock and 
                      candidate_data.get('candidate_date', '') == query_date):
                   filtered.append(candidate_info)
       
       return filtered
        """)

    if perfect_scores > total_scores * 0.1:
        print("2. 🔧 CHECK EMBEDDINGS: Verify embedding computation")
        print("   - Check if all embeddings are identical")
        print("   - Verify cosine similarity calculation")
        print("   - Check embedding normalization")

    print("3. 🔧 VERIFY DATA SPLIT: Ensure temporal separation")
    print("   - Training data should be from earlier dates")
    print("   - Test data should be from later dates")
    print("   - No overlap in stock-date combinations")

    return {
        'test_queries': len(test_queries),
        'candidates': len(candidates),
        'similarity_results': len(similarity_results),
        'perfect_scores': perfect_scores,
        'exact_matches': exact_matches,
        'total_scores': total_scores,
        'stock_overlap': len(stock_overlap),
        'date_overlap': len(date_overlap)
    }

if __name__ == "__main__":
    results = trace_finquest_data_pipeline()

    # The trace returns None when the similarity results file is missing (it
    # has already printed the reason), so only summarize a real result.
    if results is not None:
        print(f"\n🎯 SUMMARY")
        print("="*50)
        for key, value in results.items():
            print(f"{key}: {value}")

🔍 TRACING FINQUEST DATA PIPELINE

📂 STEP 1: DATA SOURCES ANALYSIS
--------------------------------------------------
✅ Test queries: 2642
✅ Ground truth: 2642
Test date range: 2025-01-17 to 2025-04-29
Test stocks: 50 unique stocks
Sample test stocks: ['PEP', 'CVX', 'JNJ', 'HD', 'AXP', 'MRK', 'COP', 'CAT', 'AMZN', 'INTC']

📂 STEP 2: CANDIDATE POOL ANALYSIS
--------------------------------------------------
Found 75 candidate embedding files
✅ Total candidates: 8151
Candidate date range: 2022-01-18 to 2024-12-27
Candidate stocks: 50 unique stocks
Sample candidate stocks: ['CVX', 'PEP', 'JNJ', 'HD', 'AXP', 'COP', 'MRK', 'CAT', 'AMZN', 'INTC']
Stock overlap: 50/50 test stocks also in candidates
Date overlap: 0 dates appear in both test and candidate sets

📂 STEP 3: SIMILARITY RESULTS ANALYSIS
--------------------------------------------------
✅ Similarity results: 2642 queries

🔍 DETAILED ANALYSIS OF FIRST 5 SIMILARITY RESULTS:

--- Query 1: AAPL on 2025-01-17 ---
Candidates found: 10
  Ca