In [1]:
# FIXED Query-to-Query Similarity Search for FinQuest
# Addresses all alignment issues with your embedding generation code

import torch
import numpy as np
import json
import os
import pickle
import argparse
import multiprocessing
from pathlib import Path
import logging
from transformers import AutoTokenizer
from tqdm import tqdm
from collections import defaultdict
from datetime import datetime

# Configure the root logger so INFO-level messages from this module are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Disable warnings (same as your embedding code)
# NOTE(review): these variables are assigned after torch and transformers have
# already been imported at the top of the file — confirm the libraries still
# honour them when set this late.
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class FinQuestRetriever(torch.nn.Module):
    """Transformer encoder followed by a projection head and an attention layer.

    The layer layout and the forward computation mirror the embedding
    generation code, so a checkpoint trained there can be loaded here and new
    embeddings stay comparable with the stored ones.

    The instance does not create its own tokenizer: callers must assign
    ``tokenizer`` on the instance before calling ``encode_sequence``.
    """

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2', hidden_size=384, dropout_rate=0.1):
        """Build the encoder, projection MLP, attention layer and output dropout.

        Args:
            model_name: Hugging Face identifier of the backbone encoder.
            hidden_size: Width of the final embedding.
            dropout_rate: Dropout probability used in the head and attention.
        """
        super().__init__()
        from transformers import AutoModel

        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_size = hidden_size

        # encoder width -> 2 * hidden_size -> hidden_size, then layer norm.
        backbone_width = self.encoder.config.hidden_size
        expanded_width = hidden_size * 2
        head_layers = [
            torch.nn.Linear(backbone_width, expanded_width),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(expanded_width, hidden_size),
            torch.nn.LayerNorm(hidden_size),
        ]
        self.projection = torch.nn.Sequential(*head_layers)

        self.financial_attention = torch.nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True,
        )

        self.dropout = torch.nn.Dropout(dropout_rate)

    def encode_sequence(self, sequences):
        """Return one L2-normalised embedding row per input string.

        Args:
            sequences: List of strings to embed.

        Returns:
            A ``(len(sequences), hidden_size)`` tensor on the model's device.
            When the list is empty or every string is blank, an all-zero
            tensor is returned without running the encoder.
        """
        target_device = next(self.parameters()).device

        if not sequences or not any(text.strip() for text in sequences):
            return torch.zeros(len(sequences), self.hidden_size).to(target_device)

        batch = self.tokenizer(
            sequences,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        ).to(target_device)

        with torch.cuda.amp.autocast():
            encoded = self.encoder(**batch)
            pooled = self.mean_pooling(encoded, batch['attention_mask'])

        # The head runs in float32 regardless of what autocast produced.
        projected = self.projection(pooled.float())

        # Each sentence is treated as a length-1 sequence for the attention
        # layer. Three separate views are passed, exactly as in the embedding
        # code; reusing a single tensor object could let PyTorch pick a
        # different self-attention code path.
        attended = self.financial_attention(
            projected.unsqueeze(1),
            projected.unsqueeze(1),
            projected.unsqueeze(1),
        )[0].squeeze(1)

        return torch.nn.functional.normalize(self.dropout(attended), p=2, dim=1)

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions kept by the mask.

        Args:
            model_output: Encoder output exposing ``last_hidden_state``.
            attention_mask: Mask with one entry per token; zeros are ignored.

        Returns:
            The masked mean along dimension 1. The divisor is clamped to
            ``1e-9`` so a fully masked row yields zeros instead of NaN.
        """
        hidden_states = model_output.last_hidden_state
        weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        weighted_total = torch.sum(hidden_states * weights, 1)
        token_count = torch.clamp(weights.sum(1), min=1e-9)
        return weighted_total / token_count

class RetrievalBasedSimilaritySearch:
    """
    FIXED: Retrieval-based similarity search using your existing embeddings
    """
    
    def __init__(self, test_queries_file, test_dataset='your_dataset', embedder_name='FinQuest', device='cuda'):
        self.test_dataset = test_dataset
        self.embedder_name = embedder_name
        self.device = device
        self.root = '/root/nfs/AJ FinRag/Models/'
        
        # Load model (same as your EMBEDDER class)
        self.model = self._load_finquest_model()
        self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
        
        # Load test queries
        self.test_queries = self._load_test_queries(test_queries_file)
        
        # Load existing embeddings (using your exact file structure)
        self.embeddings_dir = '/root/nfs/AJ FinRag/Embeddings/embeddings/test/FinQuest'
        self.historical_queries = self._load_historical_query_embeddings()
        self.candidates = self._load_candidate_embeddings()
        
        # Create query-candidate mapping
        self.query_candidate_mapping = self._create_query_candidate_mapping()
        
        logger.info(f"Loaded {len(self.test_queries)} test queries")
        logger.info(f"Loaded {len(self.historical_queries)} historical queries")
        logger.info(f"Loaded {len(self.candidates)} candidates")
        logger.info(f"Created mapping for {len(self.query_candidate_mapping)} queries")
    
    def _load_finquest_model(self):
        """Instantiate ``FinQuestRetriever`` and restore the trained weights.

        The checkpoint is read from ``finquest_models/finquest_retriever_best.pth``
        under ``self.root`` and its ``'model_state_dict'`` entry is loaded into
        the model. A tokenizer is attached and the model is switched to eval
        mode before it is returned.
        """
        checkpoint_path = os.path.join(self.root, 'finquest_models/finquest_retriever_best.pth')

        retriever = FinQuestRetriever()
        retriever = retriever.to(self.device)

        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        saved = torch.load(checkpoint_path, map_location=self.device)
        retriever.load_state_dict(saved['model_state_dict'])

        # encode_sequence() expects the tokenizer as an attribute of the model.
        retriever.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
        retriever.eval()

        logger.info(f"Loaded FinQuest model from: {checkpoint_path}")
        return retriever
    
    def _load_test_queries(self, test_queries_file):
        """Load new test queries - FIXED: Better error handling"""
        queries = []
        
        if not os.path.exists(test_queries_file):
            logger.error(f"Test queries file not found: {test_queries_file}")
            return queries
            
        try:
            with open(test_queries_file, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    try:
                        query = json.loads(line.strip())
                        queries.append(query)
                    except json.JSONDecodeError as e:
                        logger.warning(f"Skipping invalid JSON on line {line_num}: {e}")
                        continue
        except Exception as e:
            logger.error(f"Error reading test queries file: {e}")
            
        return queries
    
    def _load_historical_query_embeddings(self):
        """FIXED: Load historical query embeddings with better error handling"""
        historical_queries = {}
        
        # Your query embedding file pattern: q_{test_dataset}_FinQuest_embeddings.pkl
        query_file = os.path.join(self.embeddings_dir, f'q_{self.test_dataset}_FinQuest_embeddings.pkl')
        
        if not os.path.exists(query_file):
            logger.error(f"Query embeddings file not found: {query_file}")
            logger.error("Please check:")
            logger.error(f"1. Embeddings directory exists: {self.embeddings_dir}")
            logger.error(f"2. File naming matches: q_{self.test_dataset}_{self.embedder_name}_embeddings.pkl")
            logger.error("3. Query embeddings were generated successfully")
            return historical_queries
        
        logger.info(f"Loading query embeddings: {query_file}")
        try:
            with open(query_file, 'rb') as f:
                embedding_data = pickle.load(f)
            
            # Your structure: list of {date: [{'data': query, 'embedding': emb}, ...]}
            for date_group in embedding_data:
                if not isinstance(date_group, dict):
                    logger.warning(f"Unexpected date_group type: {type(date_group)}")
                    continue
                    
                for date, queries in date_group.items():
                    if not isinstance(queries, list):
                        logger.warning(f"Unexpected queries type for date {date}: {type(queries)}")
                        continue
                        
                    for query_item in queries:
                        if 'data' not in query_item or 'embedding' not in query_item:
                            logger.warning(f"Query item missing required keys: {query_item.keys()}")
                            continue
                            
                        query_data = query_item['data']
                        embedding = torch.tensor(query_item['embedding'], dtype=torch.float32).to(self.device)
                        
                        # Use data_index as unique identifier (from your data structure)
                        query_id = query_data.get('data_index', len(historical_queries))
                        historical_queries[query_id] = {
                            'data': query_data,
                            'embedding': embedding,
                            'date': date
                        }
            
            logger.info(f"Successfully loaded {len(historical_queries)} historical queries")
                        
        except Exception as e:
            logger.error(f"Error loading query embeddings: {e}")
            
        return historical_queries
    
    def _load_candidate_embeddings(self):
        """FIXED: Load candidate embeddings with flexible file name matching"""
        candidates = {}
        
        # Search for candidate embedding files with flexible naming
        group = 1
        files_loaded = 0
        
        while True:
            possible_candidate_files = [
                os.path.join(self.embeddings_dir, f'c_{self.test_dataset}_FinQuest_embeddings_{group}.pkl'),
                os.path.join(self.embeddings_dir, f'c_your_dataset_FinQuest_embeddings_{group}.pkl'),
            ]
            
            candidate_file = None
            for file_path in possible_candidate_files:
                if os.path.exists(file_path):
                    candidate_file = file_path
                    break
            
            if candidate_file is None:
                break
            
            logger.info(f"Loading candidate embeddings: {candidate_file}")
            try:
                with open(candidate_file, 'rb') as f:
                    embedding_data = pickle.load(f)
                
                # Your structure: list of {date: [{'data': candidate, 'embedding': emb}, ...]}
                for date_group in embedding_data:
                    if not isinstance(date_group, dict):
                        logger.warning(f"Unexpected date_group type in file {group}: {type(date_group)}")
                        continue
                        
                    for date, candidate_list in date_group.items():
                        if not isinstance(candidate_list, list):
                            logger.warning(f"Unexpected candidate_list type for date {date}: {type(candidate_list)}")
                            continue
                            
                        for candidate_item in candidate_list:
                            if 'data' not in candidate_item or 'embedding' not in candidate_item:
                                logger.warning(f"Candidate item missing required keys: {candidate_item.keys()}")
                                continue
                                
                            candidate_data = candidate_item['data']
                            embedding = torch.tensor(candidate_item['embedding'], dtype=torch.float32).to(self.device)
                            
                            # Use data_index as unique identifier
                            candidate_id = candidate_data.get('data_index', len(candidates))
                            candidates[candidate_id] = {
                                'data': candidate_data,
                                'embedding': embedding,
                                'date': date
                            }
                
                files_loaded += 1
                            
            except Exception as e:
                logger.error(f"Error loading candidate file {group}: {e}")
            
            group += 1
        
        logger.info(f"Successfully loaded {files_loaded} candidate embedding files with {len(candidates)} total candidates")
        return candidates
    
    def _get_query_str(self, query):
        """Convert query dict to string - EXACT COPY from EMBEDDER"""
        if 'query_str' in query:
            return query['query_str']
        
        # Convert from your query format to string
        seq_dict = {
            'query_stock': query.get('query_stock', 'Unknown'),
            'query_date': query.get('query_date', 'Unknown'),
            'recent_date_list': query.get('recent_date_list', []),
            'adjusted_close_list': query.get('adjusted_close_list', [])
        }
        return str(seq_dict)
    
    def _create_query_candidate_mapping(self):
        """FIXED: Create mapping with better logic and error handling"""
        mapping = defaultdict(list)
        
        if not self.historical_queries or not self.candidates:
            logger.warning("Cannot create mapping: missing historical queries or candidates")
            return mapping
        
        logger.info("Creating query-candidate mapping based on dates and stocks...")
        
        successful_mappings = 0
        
        for query_id, query_info in self.historical_queries.items():
            query_date = query_info['date']
            query_stock = query_info['data'].get('query_stock', '')
            
            candidates_for_this_query = []
            
            for candidate_id, candidate_info in self.candidates.items():
                candidate_date = candidate_info['date']
                candidate_stock = candidate_info['data'].get('candidate_stock', '')
                
                # Strategy 1: Same date, same stock (primary association)
                if (query_date == candidate_date) and (query_stock == candidate_stock):
                    candidates_for_this_query.append(candidate_id)
                
                # Strategy 2: Candidate within 1-3 days before query, same stock
                elif query_stock == candidate_stock and query_stock != '':
                    try:
                        query_dt = datetime.strptime(query_date, "%Y-%m-%d")
                        candidate_dt = datetime.strptime(candidate_date, "%Y-%m-%d")
                        date_diff = (query_dt - candidate_dt).days
                        if 0 <= date_diff <= 3:
                            candidates_for_this_query.append(candidate_id)
                    except ValueError as e:
                        # Skip invalid date formats
                        continue
                    except Exception as e:
                        # Skip other date processing errors
                        continue
            
            if candidates_for_this_query:
                mapping[query_id] = candidates_for_this_query
                successful_mappings += 1
        
        # Log statistics
        total_queries = len(self.historical_queries)
        mapped_queries = len(mapping)
        
        if mapping:
            avg_candidates = np.mean([len(candidates) for candidates in mapping.values()])
            max_candidates = max([len(candidates) for candidates in mapping.values()])
        else:
            avg_candidates = 0
            max_candidates = 0
        
        logger.info(f"Mapping Statistics:")
        logger.info(f"  Total historical queries: {total_queries}")
        logger.info(f"  Queries with candidates: {mapped_queries} ({mapped_queries/total_queries*100:.1f}%)")
        logger.info(f"  Average candidates per query: {avg_candidates:.1f}")
        logger.info(f"  Max candidates for any query: {max_candidates}")
        
        if mapped_queries == 0:
            logger.warning("⚠️  No query-candidate mappings created!")
            logger.warning("This could indicate:")
            logger.warning("1. Date format mismatches between queries and candidates")
            logger.warning("2. Stock symbol mismatches")
            logger.warning("3. No overlapping dates in your data")
            
            # Show sample data for debugging
            if self.historical_queries:
                sample_query = next(iter(self.historical_queries.values()))
                logger.info(f"Sample query date format: {sample_query['date']}")
                logger.info(f"Sample query stock: {sample_query['data'].get('query_stock', 'N/A')}")
            
            if self.candidates:
                sample_candidate = next(iter(self.candidates.values()))
                logger.info(f"Sample candidate date format: {sample_candidate['date']}")
                logger.info(f"Sample candidate stock: {sample_candidate['data'].get('candidate_stock', 'N/A')}")
        
        return mapping
    
    def encode_test_query(self, test_query):
        """Encode a new test query using the same method as embedding generation"""
        query_str = self._get_query_str(test_query)
        
        with torch.no_grad():
            embedding = self.model.encode_sequence([query_str])
        
        return embedding[0]
    
    def find_similar_historical_queries(self, test_query, top_k=20):
        """Find most similar historical queries to a test query"""
        if not self.historical_queries:
            logger.warning("No historical queries available for similarity search")
            return []
            
        test_embedding = self.encode_test_query(test_query)
        
        similarities = []
        
        for query_id, query_info in self.historical_queries.items():
            historical_embedding = query_info['embedding']
            
            # Calculate cosine similarity (same as your training)
            similarity = torch.cosine_similarity(
                test_embedding.unsqueeze(0), 
                historical_embedding.unsqueeze(0)
            ).item()
            
            similarities.append({
                'query_id': query_id,
                'similarity': similarity,
                'historical_query': query_info['data'],
                'date': query_info['date']
            })
        
        # Sort by similarity and return top-K
        similarities.sort(key=lambda x: x['similarity'], reverse=True)
        return similarities[:top_k]
    
    def retrieve_candidates_from_similar_queries(self, similar_queries, top_k=10):
        """Retrieve candidates associated with similar historical queries"""
        if not similar_queries:
            return []
            
        candidate_scores = defaultdict(float)
        candidate_info = {}
        
        for similar_query in similar_queries:
            query_id = similar_query['query_id']
            query_similarity = similar_query['similarity']
            
            # Get candidates associated with this historical query
            associated_candidate_ids = self.query_candidate_mapping.get(query_id, [])
            
            for candidate_id in associated_candidate_ids:
                if candidate_id in self.candidates:
                    # Use max similarity as candidate score
                    candidate_scores[candidate_id] = max(
                        candidate_scores[candidate_id], 
                        query_similarity
                    )
                    candidate_info[candidate_id] = self.candidates[candidate_id]
        
        # Create final candidate list
        final_candidates = []
        for candidate_id, score in candidate_scores.items():
            final_candidates.append({
                'candidate_index': candidate_id,  # Match your expected output format
                'candidate_score': score,         # Match your expected output format
                'candidate_data': candidate_info[candidate_id]['data'],
                'candidate_date': candidate_info[candidate_id]['date']
            })
        
        # Sort by score and return top-K
        final_candidates.sort(key=lambda x: x['candidate_score'], reverse=True)
        return final_candidates[:top_k]
    
    def search_all_test_queries(self, top_k_queries=20, top_k_candidates=10):
        """FIXED: Run similarity search for all test queries with better error handling"""
        all_results = []
        
        if not self.test_queries:
            logger.error("No test queries to process")
            return all_results
        
        logger.info(f"Processing {len(self.test_queries)} test queries...")
        
        queries_with_candidates = 0
        
        for i, test_query in enumerate(self.test_queries):
            if (i + 1) % 100 == 0:
                logger.info(f"Processed {i + 1}/{len(self.test_queries)} queries")
            
            # Step 1: Find similar historical queries
            similar_queries = self.find_similar_historical_queries(test_query, top_k_queries)
            
            if not similar_queries:
                logger.warning(f"No similar queries found for test query {i}")
                all_results.append({
                    'query_id': test_query.get('query_id', f"query_{test_query.get('data_index', i)}"),
                    'query_index': test_query.get('data_index', i),
                    'query_stock': test_query.get('query_stock', 'Unknown'),
                    'query_date': test_query.get('query_date', 'Unknown'),
                    'similarity_list': [],
                    'total_candidates_searched': 0
                })
                continue
            
            # Step 2: Retrieve candidates from similar queries
            candidates = self.retrieve_candidates_from_similar_queries(similar_queries, top_k_candidates)
            
            if candidates:
                queries_with_candidates += 1
            
            # Format result to match your expected structure
            result = {
                'query_id': test_query.get('query_id', f"query_{test_query.get('data_index', i)}"),
                'query_index': test_query.get('data_index', i),
                'query_stock': test_query.get('query_stock', 'Unknown'),
                'query_date': test_query.get('query_date', 'Unknown'),
                'similarity_list': candidates,  # This matches your original output format
                'total_candidates_searched': len(candidates)
            }
            
            all_results.append(result)
        
        logger.info(f"Completed similarity search for {len(all_results)} test queries")
        logger.info(f"Queries with candidates found: {queries_with_candidates}/{len(all_results)} ({queries_with_candidates/len(all_results)*100:.1f}%)")
        
        return all_results
    
    def save_similarity_results(self, similarity_results, output_file):
        """Save similarity results to pickle file - same as your original"""
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        
        with open(output_file, 'wb') as f:
            pickle.dump(similarity_results, f)
        
        logger.info(f"Saved results to: {output_file}")
    
    def analyze_results(self, results):
        """FIXED: Better analysis with more detailed statistics"""
        if not results:
            logger.warning("No results to analyze")
            return
        
        total_queries = len(results)
        queries_with_candidates = sum(1 for r in results if r['similarity_list'])
        
        if queries_with_candidates > 0:
            avg_candidates = np.mean([len(r['similarity_list']) for r in results if r['similarity_list']])
            max_candidates = max([len(r['similarity_list']) for r in results])
            min_candidates = min([len(r['similarity_list']) for r in results if r['similarity_list']])
        else:
            avg_candidates = max_candidates = min_candidates = 0
        
        logger.info("="*50)
        logger.info("RETRIEVAL-BASED SEARCH RESULTS ANALYSIS")
        logger.info("="*50)
        logger.info(f"Total test queries processed: {total_queries}")
        logger.info(f"Queries with candidates found: {queries_with_candidates} ({queries_with_candidates/total_queries*100:.1f}%)")
        logger.info(f"Queries with no candidates: {total_queries - queries_with_candidates}")
        logger.info(f"Average candidates per successful query: {avg_candidates:.1f}")
        logger.info(f"Max candidates for any query: {max_candidates}")
        logger.info(f"Min candidates for successful queries: {min_candidates}")
        
        # Show sample result
        successful_results = [r for r in results if r['similarity_list']]
        if successful_results:
            sample = successful_results[0]
            logger.info("\nSAMPLE SUCCESSFUL RESULT:")
            logger.info(f"Test query: {sample['query_stock']} on {sample['query_date']}")
            logger.info(f"Found {len(sample['similarity_list'])} candidates")
            
            if sample['similarity_list']:
                top_candidate = sample['similarity_list'][0]
                logger.info(f"Top candidate: Index {top_candidate['candidate_index']} (score: {top_candidate['candidate_score']:.3f})")
        
        # Debugging info if no results found
        if queries_with_candidates == 0:
            logger.warning("\n⚠️  NO CANDIDATES FOUND FOR ANY QUERIES!")
            logger.warning("This suggests an issue with:")
            logger.warning("1. Query-candidate mapping creation")
            logger.warning("2. Historical data loading")
            logger.warning("3. Date/stock matching logic")

def run_retrieval_search():
    """Check the required files, run the retrieval search and save the results.

    All locations are fixed inside the function. The pre-flight checks verify,
    in order: the test-queries file, the embeddings directory, at least one
    query-embedding pickle, at least one first-group candidate pickle and the
    model checkpoint.

    Returns:
        The list of per-query results, or ``None`` when a check fails, nothing
        could be loaded/mapped, or an exception is raised during the search.
    """
    import glob
    import traceback

    # Configuration - aligned with your paths and structure
    test_queries_file = '/root/nfs/AJ FinRag/Evaluation Results/Test Queries/test_queries_rise_fall_only.json'
    embeddings_dir = '/root/nfs/AJ FinRag/Embeddings/embeddings/test/FinQuest'
    model_path = '/root/nfs/AJ FinRag/Models/finquest_models/finquest_retriever_best.pth'
    test_dataset = 'your_dataset'
    embedder_name = 'FinQuest'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Output configuration
    output_file = os.path.join('similar_candidates/test/FinQuest', 'test_similarity_results_FinQuest.pkl')

    logger.info("🚀 Starting Retrieval-Based Similarity Search")
    logger.info("="*60)

    # ---- Pre-flight checks: stop at the first missing prerequisite ----

    # Check 1: Test queries file
    if not os.path.exists(test_queries_file):
        logger.error(f"❌ Test queries file not found: {test_queries_file}")
        return None

    # Check 2: Embeddings directory
    if not os.path.exists(embeddings_dir):
        logger.error(f"❌ Embeddings directory not found: {embeddings_dir}")
        logger.error("Run embedding generation first!")
        return None

    # Check 3: Query embeddings - any dataset tag is accepted here
    query_files = glob.glob(os.path.join(embeddings_dir, 'q_*_FinQuest_embeddings.pkl'))
    if not query_files:
        logger.error(f"❌ No query embeddings found in: {embeddings_dir}")
        logger.error("Looked for pattern: q_*_FinQuest_embeddings.pkl")

        # List actual files for debugging
        if os.path.exists(embeddings_dir):
            actual_files = [name for name in os.listdir(embeddings_dir) if name.endswith('.pkl')]
            logger.error(f"Actual .pkl files found: {actual_files}")

        logger.error("Generate query embeddings first: get_embeddings(dataset, 'FinQuest', 'query', device)")
        return None

    # Only used for the summary log below; the search system resolves its own file.
    query_file = query_files[0]

    # Check 4: Candidate embeddings - the first group file must exist
    if not glob.glob(os.path.join(embeddings_dir, 'c_*_FinQuest_embeddings_1.pkl')):
        logger.error(f"❌ No candidate embeddings found in: {embeddings_dir}")
        logger.error("Looked for pattern: c_*_FinQuest_embeddings_1.pkl")
        logger.error("Generate candidate embeddings first: get_embeddings(dataset, 'FinQuest', 'candidate', device)")
        return None

    # Check 5: Model file
    if not os.path.exists(model_path):
        logger.error(f"❌ Model file not found: {model_path}")
        return None

    logger.info("✅ All files found - proceeding with search")
    logger.info(f"✅ Test queries: {test_queries_file}")
    logger.info(f"✅ Query embeddings: {query_file}")
    logger.info(f"✅ Candidate embeddings: {embeddings_dir}")
    logger.info(f"✅ Model: {model_path}")

    try:
        # Initialize search system
        search_system = RetrievalBasedSimilaritySearch(
            test_queries_file=test_queries_file,
            test_dataset=test_dataset,
            embedder_name=embedder_name,
            device=device,
        )

        # Verify system is properly initialized
        if not len(search_system.historical_queries):
            logger.error("❌ No historical queries loaded - cannot proceed")
            return None

        if not len(search_system.candidates):
            logger.error("❌ No candidates loaded - cannot proceed")
            return None

        if not len(search_system.query_candidate_mapping):
            logger.error("❌ No query-candidate mappings created - cannot proceed")
            return None

        # Top 20 similar historical queries, top 10 candidates per test query.
        logger.info("Running similarity search...")
        results = search_system.search_all_test_queries(top_k_queries=20, top_k_candidates=10)

        # Analyze, then persist in the same format as your original.
        search_system.analyze_results(results)
        search_system.save_similarity_results(results, output_file)

        logger.info("🎉 Retrieval-based similarity search completed successfully!")
        return results

    except Exception as e:
        # Top-level boundary: report the failure and signal it with None.
        logger.error(f"❌ Error in similarity search: {e}")
        traceback.print_exc()
        return None

# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # Set multiprocessing method (same as your embedding code)
    # force=True is passed so the call also applies when a start method was already selected.
    multiprocessing.set_start_method('spawn', force=True)
    
    # Run the whole pipeline; run_retrieval_search() returns None on any failure.
    results = run_retrieval_search()

INFO:__main__:🚀 Starting Retrieval-Based Similarity Search
INFO:__main__:✅ All files found - proceeding with search
INFO:__main__:✅ Test queries: /root/nfs/AJ FinRag/Evaluation Results/Test Queries/test_queries_rise_fall_only.json
INFO:__main__:✅ Query embeddings: /root/nfs/AJ FinRag/Embeddings/embeddings/test/FinQuest/q_your_dataset_FinQuest_embeddings.pkl
INFO:__main__:✅ Candidate embeddings: /root/nfs/AJ FinRag/Embeddings/embeddings/test/FinQuest
INFO:__main__:✅ Model: /root/nfs/AJ FinRag/Models/finquest_models/finquest_retriever_best.pth
INFO:__main__:Loaded FinQuest model from: /root/nfs/AJ FinRag/Models/finquest_models/finquest_retriever_best.pth
INFO:__main__:Loading query embeddings: /root/nfs/AJ FinRag/Embeddings/embeddings/test/FinQuest/q_your_dataset_FinQuest_embeddings.pkl
INFO:__main__:Successfully loaded 742 historical queries
INFO:__main__:Loading candidate embeddings: /root/nfs/AJ FinRag/Embeddings/embeddings/test/FinQuest/c_your_dataset_FinQuest_embeddings_1.pkl
INFO:_