In [1]:
# FIXED FinQuest Embedding Generation
# Addresses potential issues causing embedding collapse

import torch
import numpy as np
import json
import os
import pickle
import argparse
import multiprocessing
from pathlib import Path
import logging
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

# Logging: configure the root logger at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide environment switches, set in one place.
# NOTE(review): only the first and last entries are about silencing library
# warnings; CUDA_LAUNCH_BLOCKING is presumably here for debugging CUDA errors
# and may slow GPU work down -- confirm it is still wanted.
os.environ.update({
    'TRANSFORMERS_NO_ADVISORY_WARNINGS': 'true',
    'CUDA_LAUNCH_BLOCKING': "1",
    "TOKENIZERS_PARALLELISM": "false",
})

class FinQuestRetrieverFixed(torch.nn.Module):
    """Text encoder with a projection head and an attention layer.

    ``encode_sequence`` runs: tokenizer -> ``self.encoder`` -> masked mean
    pooling -> ``self.projection`` -> ``self.financial_attention`` -> L2
    normalisation.

    Args:
        model_name: Model id handed to ``AutoModel.from_pretrained`` (and to
            ``AutoTokenizer.from_pretrained`` when the tokenizer is loaded
            lazily).
        hidden_size: Width of the projected and returned embeddings.
        dropout_rate: Dropout probability used by the projection head, the
            attention layer and ``self.dropout``.
    """

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2', hidden_size=384, dropout_rate=0.1):
        super().__init__()
        self.model_name = model_name
        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_size = hidden_size

        # Callers may still assign ``model.tokenizer`` themselves; when they do
        # not, ``encode_sequence`` loads the tokenizer matching ``model_name``
        # on first use instead of failing with AttributeError (which the broad
        # ``except`` used to turn into silent all-zero embeddings).
        self.tokenizer = None

        encoder_dim = self.encoder.config.hidden_size
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(encoder_dim, hidden_size * 2),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(hidden_size * 2, hidden_size),
            torch.nn.LayerNorm(hidden_size)
        )

        self.financial_attention = torch.nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        # Kept as an attribute for compatibility; ``encode_sequence`` always
        # runs in eval mode and therefore never applies it.
        self.dropout = torch.nn.Dropout(dropout_rate)

    def _zero_embeddings(self, count, device):
        """Return the ``(count, hidden_size)`` all-zero fallback on ``device``."""
        return torch.zeros(count, self.hidden_size).to(device)

    def _get_tokenizer(self):
        """Return the attached tokenizer, loading it on first use if absent."""
        if getattr(self, 'tokenizer', None) is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return self.tokenizer

    def encode_sequence(self, sequences):
        """Embed a list of strings for inference.

        The module is switched to eval mode and gradients are disabled for
        the duration of the call; the previous training flag is restored
        afterwards.

        Args:
            sequences: List of input strings.

        Returns:
            A ``(len(sequences), hidden_size)`` tensor on the module's device
            with L2-normalised rows. An all-zero tensor of the same shape is
            returned (and the reason is logged) when the input is empty or
            blank, when tokenization or the forward pass raises, or when NaNs
            are detected at any stage.
        """
        device = next(self.parameters()).device
        # ``None`` is treated like an empty list rather than crashing on len().
        count = len(sequences) if sequences else 0

        if not sequences or all(not seq.strip() for seq in sequences):
            logger.warning("Empty sequences provided!")
            return self._zero_embeddings(count, device)

        try:
            inputs = self._get_tokenizer()(
                sequences,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt'
            ).to(device)
        except Exception as e:
            logger.error(f"Tokenization error: {e}")
            return self._zero_embeddings(count, device)

        # Inference only: force eval mode, remember what to restore.
        was_training = self.training
        if was_training:
            self.eval()

        try:
            with torch.no_grad():
                outputs = self.encoder(**inputs)
                embeddings = self.mean_pooling(outputs, inputs['attention_mask'])

                if torch.isnan(embeddings).any():
                    logger.error("NaN detected in encoder embeddings!")
                    return self._zero_embeddings(count, device)

                if torch.abs(embeddings).max() > 100:
                    logger.warning(f"Extreme values in embeddings: {torch.abs(embeddings).max().item()}")

                projected = self.projection(embeddings.float())

                if torch.isnan(projected).any():
                    logger.error("NaN detected in projection!")
                    return self._zero_embeddings(count, device)

                # Each pooled vector is fed as a length-1 sequence
                # (batch_first=True, unsqueeze(1)), so every query attends to
                # exactly one key: the attention weights are trivially 1.
                attended, _ = self.financial_attention(
                    projected.unsqueeze(1),
                    projected.unsqueeze(1),
                    projected.unsqueeze(1)
                )
                attended = attended.squeeze(1)

                if torch.isnan(attended).any():
                    logger.error("NaN detected in attention!")
                    return self._zero_embeddings(count, device)

                # The former ``if self.training: dropout`` step was removed:
                # eval mode is always active here, so it could never execute.

                # eps guards against division by zero for all-zero rows.
                final_embeddings = torch.nn.functional.normalize(attended, p=2, dim=1, eps=1e-8)

                if torch.isnan(final_embeddings).any():
                    logger.error("NaN detected in final embeddings!")
                    return self._zero_embeddings(count, device)

                # Sanity check only: rows should have (close to) unit norm.
                norms = torch.norm(final_embeddings, dim=1)
                if (norms < 0.9).any() or (norms > 1.1).any():
                    logger.warning(f"Embedding normalization issue. Norms range: {norms.min().item():.6f} to {norms.max().item():.6f}")

                return final_embeddings

        except Exception as e:
            logger.error(f"Error in encode_sequence: {e}")
            return self._zero_embeddings(count, device)

        finally:
            # Runs on every exit path above, including the early returns.
            if was_training:
                self.train()

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions selected by the mask.

        Args:
            model_output: Object exposing ``last_hidden_state``.
            attention_mask: Mask whose non-zero entries mark tokens to keep;
                it is broadcast over the last dimension of the hidden states.

        Returns:
            Sum of the masked token embeddings along dim 1 divided by the
            mask sum (clamped to at least 1e-9 so fully masked rows give 0
            instead of NaN).
        """
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

        masked_embeddings = token_embeddings * input_mask_expanded
        summed_embeddings = torch.sum(masked_embeddings, 1)
        summed_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        return summed_embeddings / summed_mask

class EmbedderFixed:
    """Loads the trained FinQuest checkpoint and embeds queries/candidates.

    Args:
        embedder_name: Must be ``'FinQuest'``; anything else raises
            ``ValueError``.
        device: Device string used for ``torch.load`` and for the model.
    """

    def __init__(self, embedder_name, device='cuda'):
        self.embedder_name = embedder_name
        self.device = device
        self.root = '/root/nfs/AJ FinRag/Models/'

        if self.embedder_name != 'FinQuest':
            raise ValueError(f"Embedder {embedder_name} not supported. Only 'FinQuest' is available.")

        logger.info('Loading trained FinQuest model...')
        model_path = os.path.join(self.root, 'finquest_models/finquest_retriever_best.pth')
        self.model = self._load_finquest_model(model_path)
        # Reuse the tokenizer attached during loading instead of loading the
        # same pretrained tokenizer a second time.
        self.tokenizer = self.model.tokenizer

    def _load_finquest_model(self, model_path):
        """Load, sanity-check and return the FinQuest model in eval mode.

        Args:
            model_path: Path of a checkpoint containing ``'model_state_dict'``.

        Returns:
            The model on ``self.device`` with a ``tokenizer`` attribute set.

        Raises:
            FileNotFoundError: ``model_path`` does not exist.
            KeyError: The checkpoint has no ``'model_state_dict'`` entry.
            ValueError: The state dict contains NaNs, or the smoke-test
                embedding contains NaNs.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=self.device)

        if 'model_state_dict' not in checkpoint:
            raise KeyError("No 'model_state_dict' found in checkpoint")

        state_dict = checkpoint['model_state_dict']
        nan_params = [
            name for name, param in state_dict.items()
            if torch.is_tensor(param) and torch.isnan(param).any()
        ]
        if nan_params:
            logger.error(f"NaN parameters found in checkpoint: {nan_params}")
            raise ValueError("Checkpoint contains NaN parameters - model is corrupted")

        model = FinQuestRetrieverFixed().to(self.device)
        model.load_state_dict(state_dict)
        model.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
        model.eval()

        # Smoke test: one forward pass must give a finite, non-trivial vector.
        test_input = "Test sequence for validation"
        try:
            with torch.no_grad():
                test_embedding = model.encode_sequence([test_input])
                if torch.isnan(test_embedding).any():
                    raise ValueError("Model produces NaN embeddings")
                if torch.norm(test_embedding) < 1e-6:
                    logger.warning("Model produces very small embeddings - possible collapse")
        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            raise

        logger.info(f" FinQuest model loaded and validated successfully")
        return model

    def embed_sequences_batch(self, sequences, batch_size=32):
        """Embed ``sequences`` in chunks of ``batch_size``.

        Args:
            sequences: List of strings to embed.
            batch_size: Number of strings per forward pass.

        Returns:
            A list with one NumPy vector per input string, in input order.
            NaN entries are replaced by 0; a batch that raises contributes
            float32 all-zero vectors (same dtype as regular embeddings).
        """
        all_embeddings = []

        for start in range(0, len(sequences), batch_size):
            batch = sequences[start:start + batch_size]
            batch_number = start // batch_size + 1

            try:
                with torch.no_grad():
                    batch_embeddings = self.model.encode_sequence(batch)

                    if torch.isnan(batch_embeddings).any():
                        logger.error(f"NaN embeddings in batch {batch_number}")
                        batch_embeddings = torch.where(
                            torch.isnan(batch_embeddings),
                            torch.zeros_like(batch_embeddings),
                            batch_embeddings
                        )

                    # Move off the GPU right away to keep device memory flat.
                    all_embeddings.extend(batch_embeddings.cpu().numpy())

            except Exception as e:
                logger.error(f"Error in batch {batch_number}: {e}")
                # float32 keeps the fallback rows consistent with real rows;
                # np.zeros would otherwise default to float64.
                all_embeddings.extend(
                    np.zeros((len(batch), self.model.hidden_size), dtype=np.float32)
                )

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return all_embeddings

    def _embed_items(self, items, to_str, fallback_str, label):
        """Stringify ``items`` with ``to_str``, embed them and pair the results.

        Items whose conversion raises are logged and embedded via
        ``fallback_str`` so the output stays aligned with the input.
        """
        items = list(items)
        texts = []
        for item in items:
            try:
                texts.append(to_str(item))
            except Exception as e:
                logger.error(f"Error converting {label} to string: {e}")
                texts.append(fallback_str)

        embeddings = self.embed_sequences_batch(texts)
        return [
            {'data': item, 'embedding': embedding}
            for item, embedding in zip(items, embeddings)
        ]

    def embed_queries_in_parallel(self, queries_on_date):
        """Embed query dicts; returns ``[{'data': query, 'embedding': vec}, ...]``.

        Despite the name, work is done sequentially in batches.
        """
        return self._embed_items(
            queries_on_date,
            self._get_query_str,
            "{'query_stock': 'Unknown', 'query_date': 'Unknown', 'recent_date_list': [], 'adjusted_close_list': []}",
            'query',
        )

    def embed_candidates_in_parallel(self, candidates_on_date):
        """Embed candidate dicts; returns ``[{'data': candidate, 'embedding': vec}, ...]``.

        Despite the name, work is done sequentially in batches.
        """
        return self._embed_items(
            candidates_on_date,
            self._get_candidate_str,
            "{'candidate_stock': 'Unknown', 'candidate_date': 'Unknown', 'recent_date_list': [], 'values_list': []}",
            'candidate',
        )

    def _get_query_str(self, query):
        """Return the text representation of a query dict.

        A non-empty ``'query_str'`` entry is returned as is; otherwise the
        string form of a dict built from the stock, date, date list and
        adjusted-close list (``None`` prices become 0.0) is returned.
        """
        if 'query_str' in query and query['query_str']:
            return query['query_str']

        seq_dict = {
            'query_stock': str(query.get('query_stock', 'Unknown')),
            'query_date': str(query.get('query_date', 'Unknown')),
            'recent_date_list': list(query.get('recent_date_list', [])),
            'adjusted_close_list': [float(x) if x is not None else 0.0 for x in query.get('adjusted_close_list', [])]
        }

        return str(seq_dict)

    def _get_candidate_str(self, candidate):
        """Return the text representation of a candidate dict.

        A non-empty ``'candidate_str'`` entry is returned as is. Otherwise the
        first key ending in ``'_list'`` (other than ``'recent_date_list'``) is
        used as the indicator series, falling back to an empty
        ``'values_list'``. The input dict is never modified.
        """
        if 'candidate_str' in candidate and candidate['candidate_str']:
            return candidate['candidate_str']

        indicator_key = next(
            (key for key in candidate.keys()
             if key.endswith('_list') and key != 'recent_date_list'),
            'values_list',
        )

        # ``.get`` supplies the empty default, so the caller's dict (which is
        # later stored next to the embedding) does not need to be mutated.
        seq_dict = {
            'candidate_stock': str(candidate.get('candidate_stock', 'Unknown')),
            'candidate_date': str(candidate.get('candidate_date', 'Unknown')),
            'recent_date_list': list(candidate.get('recent_date_list', [])),
            indicator_key: [float(x) if x is not None else 0.0 for x in candidate.get(indicator_key, [])]
        }

        return str(seq_dict)

class DatastoreFixed:
    """Loads query/candidate records and groups them by date.

    Args:
        test_dataset: Dataset identifier; stored but not used for loading.
        mode: Key into ``_DATA_PATHS``; only ``'test'`` is defined.

    Raises:
        ValueError: ``mode`` has no configured data paths.
    """

    # mode -> (query file, candidate file).
    # NOTE(review): the 'test' mode points at files named "*_train_*" --
    # confirm this is intentional.
    _DATA_PATHS = {
        'test': (
            '/root/nfs/AJ FinRag/Query Candidate/llm_data/all_companies_train_qlist.json',
            '/root/nfs/AJ FinRag/Query Candidate/llm_data/all_companies_train_clist.json',
        ),
    }

    def __init__(self, test_dataset, mode='test'):
        # Fail early with a clear message instead of hitting an undefined
        # ``self.query_path`` attribute further down.
        if mode not in self._DATA_PATHS:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of {sorted(self._DATA_PATHS)}"
            )

        self.test_dataset = test_dataset
        self.mode = mode
        self.query_path, self.candidate_path = self._DATA_PATHS[mode]

        self.queries = self._load_json_file_safe(self.query_path)
        self.candidates = self._load_json_file_safe(self.candidate_path)

        logger.info(f"Loaded {len(self.queries)} queries and {len(self.candidates)} candidates")

        self._validate_data()

    def _load_json_file_safe(self, file_path):
        """Read a file holding one JSON document per line.

        Blank lines are skipped; lines that fail to parse are counted and the
        first ten are logged.

        Returns:
            The list of parsed records, or ``[]`` when the file is missing or
            cannot be read.
        """
        if not os.path.exists(file_path):
            logger.error(f"File not found: {file_path}")
            return []

        data = []
        errors = 0

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, raw_line in enumerate(f, start=1):
                    line = raw_line.strip()
                    if not line:
                        continue

                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        errors += 1
                        if errors <= 10:  # avoid flooding the log
                            logger.warning(f"JSON decode error on line {line_num}: {e}")
                    except Exception as e:
                        errors += 1
                        if errors <= 10:
                            logger.warning(f"Error on line {line_num}: {e}")

        except Exception as e:
            logger.error(f"Error reading file {file_path}: {e}")
            return []

        if errors > 0:
            logger.warning(f"Encountered {errors} errors while loading {file_path}")

        return data

    def _validate_data(self):
        """Log how many loaded records are valid; warn when under 90%."""
        valid_queries = sum(1 for query in self.queries if self._is_valid_query(query))
        valid_candidates = sum(
            1 for candidate in self.candidates if self._is_valid_candidate(candidate)
        )

        logger.info(f"Valid queries: {valid_queries}/{len(self.queries)}")
        logger.info(f"Valid candidates: {valid_candidates}/{len(self.candidates)}")

        if valid_queries < len(self.queries) * 0.9:
            logger.warning("More than 10% of queries are invalid!")

        if valid_candidates < len(self.candidates) * 0.9:
            logger.warning("More than 10% of candidates are invalid!")

    def _is_valid_query(self, query):
        """Return True when ``query`` is a dict with truthy stock and date."""
        required_fields = ['query_stock', 'query_date']
        # Any JSON value can appear on a line; non-dicts are simply invalid.
        return isinstance(query, dict) and all(
            field in query and query[field] for field in required_fields
        )

    def _is_valid_candidate(self, candidate):
        """Return True when ``candidate`` is a dict with truthy stock and date."""
        required_fields = ['candidate_stock', 'candidate_date']
        return isinstance(candidate, dict) and all(
            field in candidate and candidate[field] for field in required_fields
        )

    def group_no_freeze_query_str_by_date(self):
        """Group valid, non-``'freeze'`` queries by ``query_date``.

        Queries without a ``'movement'`` key are treated as ``'freeze'`` and
        skipped, as are queries whose date is the literal ``'Unknown'``.

        Returns:
            Dict mapping each date to the list of its query dicts.
        """
        queries_by_date = {}
        skipped = 0

        for query in self.queries:
            if (
                not self._is_valid_query(query)
                or query.get('movement', 'freeze') == 'freeze'
                or query.get('query_date', 'Unknown') == 'Unknown'
            ):
                skipped += 1
                continue

            queries_by_date.setdefault(query['query_date'], []).append(query)

        logger.info(f"Grouped queries by date: {len(queries_by_date)} dates, skipped {skipped} invalid/freeze queries")
        return queries_by_date

    def group_candidate_str_by_date(self):
        """Group valid candidates by ``candidate_date``.

        Candidates whose date is the literal ``'Unknown'`` are skipped.

        Returns:
            Dict mapping each date to the list of its candidate dicts.
        """
        candidates_by_date = {}
        skipped = 0

        for candidate in self.candidates:
            if (
                not self._is_valid_candidate(candidate)
                or candidate.get('candidate_date', 'Unknown') == 'Unknown'
            ):
                skipped += 1
                continue

            candidates_by_date.setdefault(candidate['candidate_date'], []).append(candidate)

        logger.info(f"Grouped candidates by date: {len(candidates_by_date)} dates, skipped {skipped} invalid candidates")
        return candidates_by_date

def _save_pickle(payload, output_file):
    """Pickle ``payload`` to ``output_file``; any error propagates to the caller."""
    with open(output_file, 'wb') as f:
        pickle.dump(payload, f)


def generate_embeddings_fixed(test_dataset, embedder_name, q_or_c, device='cuda'):
    """Generate and pickle query or candidate embeddings.

    Queries are written to a single file; candidates are written in groups
    of 10 processed dates (``..._embeddings_<group>.pkl``).

    Args:
        test_dataset: Dataset identifier used in the output file names.
        embedder_name: Embedder to use (passed to ``EmbedderFixed``).
        q_or_c: ``'query'`` or ``'candidate'``.
        device: Device string passed to ``EmbedderFixed``.

    Returns:
        0 on success; 1 when ``q_or_c`` is invalid, no valid records exist,
        any output file could not be written, or an unexpected error occurs.
    """
    logger.info(f"Starting {q_or_c} embedding generation for {embedder_name}")

    # Reject a bad selector before paying for data and model loading.
    if q_or_c not in ('query', 'candidate'):
        logger.error(f"Invalid q_or_c parameter: {q_or_c}")
        return 1

    try:
        datastore = DatastoreFixed(test_dataset, 'test')
        embedder = EmbedderFixed(embedder_name, device)

        output_dir = f'/root/nfs/AJ FinRag/Embeddings/embeddings/test/{embedder_name}'
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Output directory: {output_dir}")

        if q_or_c == 'query':
            logger.info("Generating query embeddings...")

            queries_by_date = datastore.group_no_freeze_query_str_by_date()
            if not queries_by_date:
                logger.error("No valid queries found!")
                return 1

            query_embedding_list = []
            total_queries = 0

            for date, queries_on_date in tqdm(queries_by_date.items(), desc="Processing query dates"):
                try:
                    query_embeddings_on_date = embedder.embed_queries_in_parallel(queries_on_date)
                    query_embedding_list.append({date: query_embeddings_on_date})
                    total_queries += len(queries_on_date)
                except Exception as e:
                    # Best effort: a failing date is logged and skipped.
                    logger.error(f"Error processing queries for date {date}: {e}")
                    continue

            output_file = os.path.join(output_dir, f'q_{test_dataset}_{embedder_name}_embeddings.pkl')
            try:
                _save_pickle(query_embedding_list, output_file)
                logger.info(f" Query embeddings saved: {output_file}")
                logger.info(f"Total queries embedded: {total_queries}")
            except Exception as e:
                logger.error(f"Error saving query embeddings: {e}")
                return 1

        else:
            logger.info("Generating candidate embeddings...")

            candidates_by_date = datastore.group_candidate_str_by_date()
            if not candidates_by_date:
                logger.error("No valid candidates found!")
                return 1

            candidate_embedding_list = []
            total_candidates = 0
            group = 1            # index used in the next output file name
            groups_saved = 0     # files actually written
            failed_saves = 0
            dates_processed = 0

            for date, candidates_on_date in tqdm(candidates_by_date.items(), desc="Processing candidate dates"):
                try:
                    candidate_embeddings_on_date = embedder.embed_candidates_in_parallel(candidates_on_date)
                    candidate_embedding_list.append({date: candidate_embeddings_on_date})
                    total_candidates += len(candidates_on_date)
                    dates_processed += 1
                except Exception as e:
                    logger.error(f"Error processing candidates for date {date}: {e}")
                    continue

                # Flush every 10 dates to bound memory use.
                if dates_processed % 10 == 0:
                    output_file = os.path.join(output_dir, f'c_{test_dataset}_{embedder_name}_embeddings_{group}.pkl')
                    try:
                        _save_pickle(candidate_embedding_list, output_file)
                        logger.info(f" Saved candidate group {group}: {output_file}")
                        groups_saved += 1
                    except Exception as e:
                        logger.error(f"Error saving candidate group {group}: {e}")
                        failed_saves += 1

                    candidate_embedding_list = []
                    group += 1

            # Remainder that did not fill a complete group of 10.
            if candidate_embedding_list:
                output_file = os.path.join(output_dir, f'c_{test_dataset}_{embedder_name}_embeddings_{group}.pkl')
                try:
                    _save_pickle(candidate_embedding_list, output_file)
                    logger.info(f" Saved final candidate group {group}: {output_file}")
                    groups_saved += 1
                except Exception as e:
                    logger.error(f"Error saving final candidate group: {e}")
                    failed_saves += 1

            logger.info(f"Total candidates embedded: {total_candidates}")
            # Count written files, not the next group index (which is one too
            # high whenever the last group was flushed inside the loop).
            logger.info(f"Total groups saved: {groups_saved}")

            if failed_saves:
                # Mirror the query branch: a lost output file is a failure.
                logger.error(f"{failed_saves} candidate group(s) could not be saved")
                return 1

        logger.info(" Embedding generation completed successfully!")
        return 0

    except Exception as e:
        logger.error(f"Fatal error in embedding generation: {e}")
        import traceback
        traceback.print_exc()
        return 1

def test_embeddings_quality(device=None, collapse_threshold=0.95):
    """Smoke-test the loaded model for embedding collapse.

    Three query strings that differ in ticker and price are embedded and
    their pairwise dot products are compared.

    Args:
        device: Device for the embedder. ``None`` picks ``'cuda'`` when
            available and ``'cpu'`` otherwise, so a CPU-only host no longer
            reports a device error as a collapsed model.
        collapse_threshold: The model is reported as collapsed when even the
            smallest pairwise similarity exceeds this value.

    Returns:
        True when the embeddings look diverse; False on collapse or when any
        error occurs (the error is logged).
    """
    from itertools import combinations

    logger.info("🧪 Testing embedding quality...")

    try:
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        embedder = EmbedderFixed('FinQuest', device)

        test_sequences = [
            "{'query_stock': 'AAPL', 'query_date': '2024-01-15', 'recent_date_list': ['2024-01-01'], 'adjusted_close_list': [150.0]}",
            "{'query_stock': 'TSLA', 'query_date': '2024-01-15', 'recent_date_list': ['2024-01-01'], 'adjusted_close_list': [200.0]}",
            "{'query_stock': 'XOM', 'query_date': '2024-01-15', 'recent_date_list': ['2024-01-01'], 'adjusted_close_list': [80.0]}"
        ]

        embeddings = embedder.embed_sequences_batch(test_sequences)

        # Dot product of every unordered pair of embeddings.
        similarities = [np.dot(a, b) for a, b in combinations(embeddings, 2)]

        logger.info(f"Embedding test results:")
        logger.info(f"  Min similarity: {min(similarities):.6f}")
        logger.info(f"  Max similarity: {max(similarities):.6f}")
        logger.info(f"  Mean similarity: {np.mean(similarities):.6f}")

        if min(similarities) > collapse_threshold:
            logger.error(" EMBEDDING COLLAPSE DETECTED in generated embeddings!")
            return False

        logger.info(" Embeddings show reasonable diversity")
        return True

    except Exception as e:
        logger.error(f"Error testing embeddings: {e}")
        return False

def generate_both_embeddings_fixed(test_dataset='your_dataset', embedder_name='FinQuest', device='cuda'):
    """Run the quality check, then build query and candidate embeddings.

    Steps: (0) embedding quality test, (1) query embeddings, (2) candidate
    embeddings, then a check that the query file and the first candidate
    group file exist and can be unpickled.

    Returns:
        0 when every step succeeded and both files exist, otherwise 1. A
        failure while re-reading the files is only logged.
    """
    logger.info(" Starting FIXED FinQuest embedding generation (both query and candidate)")

    # Step 0: refuse to continue with a collapsed model.
    logger.info("Step 0: Testing embedding quality...")
    quality_ok = test_embeddings_quality()
    if not quality_ok:
        logger.error(" Embedding quality test failed - your model has collapsed embeddings!")
        logger.error("   You need to retrain your FinQuest model before proceeding.")
        return 1

    # Step 1: queries.
    logger.info("Step 1: Generating query embeddings...")
    if generate_embeddings_fixed(test_dataset, embedder_name, 'query', device) != 0:
        logger.error(" Query embedding generation failed!")
        return 1

    # Step 2: candidates.
    logger.info("Step 2: Generating candidate embeddings...")
    if generate_embeddings_fixed(test_dataset, embedder_name, 'candidate', device) != 0:
        logger.error(" Candidate embedding generation failed!")
        return 1

    logger.info(" All embedding generation completed successfully!")

    # Locate the expected outputs (only the first candidate group is checked).
    output_dir = f'/root/nfs/AJ FinRag/Embeddings/embeddings/test/{embedder_name}'
    query_file = os.path.join(output_dir, f'q_{test_dataset}_{embedder_name}_embeddings.pkl')
    candidate_file = os.path.join(output_dir, f'c_{test_dataset}_{embedder_name}_embeddings_1.pkl')

    have_query = os.path.exists(query_file)
    have_candidate = os.path.exists(candidate_file)

    if not (have_query and have_candidate):
        logger.error(" Some embedding files missing!")
        for path, present in ((query_file, have_query), (candidate_file, have_candidate)):
            if not present:
                logger.error(f"   Missing: {path}")
        return 1

    logger.info(" All embedding files created successfully!")
    logger.info(f" Query embeddings: {query_file}")
    logger.info(f" Candidate embeddings: {output_dir}")

    # Re-read both files as a quick integrity check; problems are logged only.
    try:
        with open(query_file, 'rb') as handle:
            query_data = pickle.load(handle)
        with open(candidate_file, 'rb') as handle:
            candidate_data = pickle.load(handle)

        logger.info(f" Query file contains {len(query_data)} date groups")
        logger.info(f" Candidate file contains {len(candidate_data)} date groups")

    except Exception as e:
        logger.error(f" Error validating generated files: {e}")

    return 0

if __name__ == "__main__":
    # Force the 'spawn' start method for any multiprocessing use.
    multiprocessing.set_start_method('spawn', force=True)

    # Run configuration.
    test_dataset = 'your_dataset'
    embedder_name = 'FinQuest'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Start-up banner.
    for banner_line in (
        " FIXED FINQUEST EMBEDDING GENERATION",
        "="*70,
        f"Dataset: {test_dataset}",
        f"Embedder: {embedder_name}",
        f"Device: {device}",
        "="*70,
    ):
        logger.info(banner_line)

    # Query and candidate embeddings in one go; 0 means success.
    result = generate_both_embeddings_fixed(test_dataset, embedder_name, device)

    if result != 0:
        logger.error(" FAILED: Embedding generation failed!")
        logger.error("   Please check the error messages above and fix the issues.")

        # Hints for the most likely failure causes.
        for hint in (
            "\n TROUBLESHOOTING GUIDE:",
            "1. If embedding quality test failed:",
            "   → Your FinQuest model has collapsed embeddings",
            "   → You need to retrain the model with better hyperparameters",
            "2. If file loading failed:",
            "   → Check if your data files exist at the specified paths",
            "   → Verify JSON format in your data files",
            "3. If model loading failed:",
            "   → Check if your model checkpoint exists",
            "   → Verify the model was trained successfully",
        ):
            logger.info(hint)
    else:
        logger.info(" SUCCESS: All embeddings generated successfully!")
        logger.info("   You can now run your similarity search and experiments.")

INFO:__main__: FIXED FINQUEST EMBEDDING GENERATION
INFO:__main__:Dataset: your_dataset
INFO:__main__:Embedder: FinQuest
INFO:__main__:Device: cuda
INFO:__main__: Starting FIXED FinQuest embedding generation (both query and candidate)
INFO:__main__:Step 0: Testing embedding quality...
INFO:__main__:🧪 Testing embedding quality...
INFO:__main__:Loading trained FinQuest model...
INFO:__main__: FinQuest model loaded and validated successfully
INFO:__main__:Embedding test results:
INFO:__main__:  Min similarity: 0.999875
INFO:__main__:  Max similarity: 0.999902
INFO:__main__:  Mean similarity: 0.999885
ERROR:__main__: EMBEDDING COLLAPSE DETECTED in generated embeddings!
ERROR:__main__: Embedding quality test failed - your model has collapsed embeddings!
ERROR:__main__:   You need to retrain your FinQuest model before proceeding.
ERROR:__main__: FAILED: Embedding generation failed!
ERROR:__main__:   Please check the error messages above and fix the issues.
INFO:__main__:
 TROUBLESHOOTING GUID