In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import json
import numpy as np
from tqdm import tqdm
import os
import glob
from pathlib import Path
import logging
from torch.cuda.amp import autocast, GradScaler

# Setup logging
# Configure the root logger to emit INFO-and-above records, then create the
# module-level logger that the rest of this file uses for progress messages,
# data-quality warnings and error reports.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class FinQuestDataset(Dataset):
    """Dataset for training FinQuest - Financial Pattern Retriever.

    Records are read from one or more JSON Lines files (one JSON object per
    line). Each record must provide the keys ``query``, ``pos``, ``neg`` and
    ``teacher_scores``; ``teacher_scores`` lists the scores of the positive
    candidates first, followed by the scores of the negative candidates.

    Only the first positive candidate (and its score) is used for training;
    scores belonging to additional positives are ignored so that the score
    vector handed to the model is always ``[pos_score, neg_score_1, ...]``.
    """

    def __init__(self, pos_neg_data_paths, tokenizer, max_length=512):
        """Load every record and log a summary of the teacher scores.

        Args:
            pos_neg_data_paths: A path, a glob pattern, or a list/tuple of
                paths and glob patterns pointing at JSON Lines files.
            tokenizer: Stored on the instance; not used by this class itself.
            max_length: Stored on the instance; not used by this class itself.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self.load_multiple_pos_neg_files(pos_neg_data_paths)
        logger.info(f"Loaded {len(self.data)} training examples for FinQuest")

        # Analyze score distribution
        self.analyze_teacher_scores()

    def analyze_teacher_scores(self):
        """Log range/mean statistics of positive and negative teacher scores."""
        all_pos_scores = []
        all_neg_scores = []
        separations = []
        zero_score_count = 0
        valid_examples = 0

        for item in self.data:
            scores = item.get('teacher_scores', [])
            if len(scores) == 0:
                continue

            pos_score = scores[0]
            # Negative scores start after *all* positive scores, so items with
            # several positives do not leak positive scores into the negatives.
            num_pos = max(1, len(item.get('pos', [])))
            neg_scores = scores[num_pos:]

            if pos_score > 0:
                all_pos_scores.append(pos_score)
                valid_examples += 1

                if neg_scores:
                    all_neg_scores.extend(neg_scores)
                    separations.append(pos_score - np.mean(neg_scores))
            else:
                zero_score_count += 1

        logger.info("FinQuest training data analysis:")
        logger.info(f"  Total examples: {len(self.data)}")
        logger.info(f"  Valid examples (positive score > 0): {valid_examples}")
        logger.info(f"  Examples with zero positive scores: {zero_score_count}")

        if all_pos_scores:
            logger.info(f"  Positive score range: {min(all_pos_scores):.4f} - {max(all_pos_scores):.4f}")
            logger.info(f"  Average positive score: {np.mean(all_pos_scores):.4f}")

        if all_neg_scores:
            logger.info(f"  Negative score range: {min(all_neg_scores):.4f} - {max(all_neg_scores):.4f}")
            logger.info(f"  Average negative score: {np.mean(all_neg_scores):.4f}")

        if separations:
            logger.info(f"  Average pos-neg separation: {np.mean(separations):.4f}")
            logger.info(f"  Separation std: {np.std(separations):.4f}")

    @staticmethod
    def _expand_paths(paths):
        """Turn a path / glob pattern / list of either into a flat list of paths.

        An entry that exists on disk is always taken literally. Otherwise an
        entry containing glob characters (``*``, ``?``, ``[``) is expanded, with
        matches sorted so the load order is reproducible. Entries that neither
        exist nor contain glob characters are kept so the caller can report
        them as missing.
        """
        if isinstance(paths, (str, os.PathLike)):
            paths = [paths]

        expanded = []
        for path in paths:
            path = os.fspath(path)
            is_pattern = any(ch in path for ch in '*?[')
            if os.path.exists(path) or not is_pattern:
                expanded.append(path)
            else:
                expanded.extend(sorted(glob.glob(path)))
        return expanded

    def load_multiple_pos_neg_files(self, paths):
        """Load and combine pos/neg allocated data from multiple files.

        Args:
            paths: A path, a glob pattern, or a list/tuple of paths and
                patterns. Patterns are expanded regardless of file extension.

        Returns:
            A list with the valid records of every readable file; each record
            gets a ``source_file`` key holding the stem of its file name.
        """
        all_data = []
        paths = self._expand_paths(paths)

        logger.info(f"Processing {len(paths)} pos/neg allocated data files...")

        for path in paths:
            if not os.path.exists(path):
                logger.warning(f"File not found: {path}")
                continue

            logger.info(f"Loading {path}...")
            file_data = self.load_pos_neg_data(path)

            # Add file identifier for tracking
            file_name = Path(path).stem
            for item in file_data:
                item['source_file'] = file_name

            all_data.extend(file_data)
            logger.info(f"Loaded {len(file_data)} examples from {path}")

        return all_data

    def load_pos_neg_data(self, path):
        """Read one JSON Lines file and return the records that pass validation.

        Blank lines are skipped. Malformed JSON and structurally invalid
        records are logged together with their line number and dropped.
        """
        data = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                    if self.validate_item(item):
                        data.append(item)
                    else:
                        logger.warning(f"Invalid item structure at {path}:{line_num}")
                except json.JSONDecodeError as e:
                    logger.warning(f"Skipping malformed JSON at {path}:{line_num}: {e}")
                except Exception as e:
                    logger.error(f"Unexpected error at {path}:{line_num}: {e}")
        return data

    def validate_item(self, item):
        """Return True when ``item`` has the fields needed for pos/neg training.

        A valid item is a dict with ``query``, ``pos``, ``neg`` and
        ``teacher_scores`` keys, at least one positive, at least one negative
        and at least two teacher scores.
        """
        required_fields = ['query', 'pos', 'neg', 'teacher_scores']
        # A JSON line may decode to a list/number/string; reject those up front.
        if not isinstance(item, dict):
            return False
        if not all(field in item for field in required_fields):
            return False

        pos_list = item.get('pos', [])
        neg_list = item.get('neg', [])
        scores = item.get('teacher_scores', [])

        # Check if we have at least one positive and one negative
        if len(pos_list) > 0 and len(neg_list) > 0 and len(scores) >= 2:
            return True

        logger.debug(f"Item failed validation: pos={len(pos_list)}, neg={len(neg_list)}, scores={len(scores)}")
        return False

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the training example at ``idx``, or None if it is unusable.

        None is returned (and filtered out later by ``collate_fn``) when the
        number of teacher scores does not equal ``len(pos) + len(neg)``, when
        there is no positive or no negative candidate, when the first positive
        score is not greater than zero, or when any error occurs.
        """
        try:
            item = self.data[idx]

            # Extract data from pos/neg format
            query = item['query']
            pos_candidates = item['pos']
            neg_candidates = item['neg']
            teacher_scores = item['teacher_scores']  # [pos scores..., neg scores...]

            # Validate lengths
            expected_scores = len(pos_candidates) + len(neg_candidates)
            if len(teacher_scores) != expected_scores:
                logger.warning(f"Score mismatch at idx {idx}: expected {expected_scores}, got {len(teacher_scores)}")
                return None

            # Skip if no positive or negative candidates
            if len(pos_candidates) == 0 or len(neg_candidates) == 0:
                return None

            # Skip examples with zero positive score
            pos_score = teacher_scores[0]
            if pos_score <= 0.0:
                return None
            pos_score = float(pos_score)

            # Negative scores begin after every positive score; slicing from
            # index 1 would misalign them whenever there are several positives.
            neg_scores = [float(s) for s in teacher_scores[len(pos_candidates):]]

            return {
                'query': query,
                'positive': pos_candidates[0],
                'negatives': neg_candidates,
                # Exactly one positive score followed by one score per
                # negative, which is the layout collate_fn pads.
                'teacher_scores_selected': [pos_score] + neg_scores,
                'pos_score': pos_score,
                'neg_scores': neg_scores,
                'source_file': item.get('source_file', 'unknown'),
                'query_stock': item.get('query_stock', 'unknown'),
                'query_date': item.get('query_date', 'unknown')
            }

        except Exception as e:
            logger.error(f"Error processing item at index {idx}: {e}")
            return None

    def collate_fn(self, batch):
        """Custom collate function for FinQuest training.

        Drops None examples, pads every example's negatives with ``''`` up to
        the largest negative count in the batch and pads the teacher scores
        with ``0.0`` to ``max_negatives + 1`` entries.

        Returns:
            None when no usable example remains, otherwise a dict with the raw
            text lists, ``negatives_mask`` (bool tensor, True for real
            negatives) and ``teacher_scores_padded_tensor`` (float32 tensor).
        """
        batch = [item for item in batch if item is not None]
        if not batch:
            return None

        if len(batch) == 1:
            logger.warning("Batch size of 1 may cause training instability")

        # Handle variable number of negatives
        max_num_negatives = max(len(item['negatives']) for item in batch)

        negatives_padded = []
        negatives_mask = []
        teacher_scores_padded = []

        for item in batch:
            num_negatives = len(item['negatives'])
            num_padding = max_num_negatives - num_negatives

            negatives_padded.append(list(item['negatives']) + [''] * num_padding)
            negatives_mask.append([1] * num_negatives + [0] * num_padding)

            # Format: [pos_score, neg_score1, neg_score2, ..., padding_zeros]
            scores = list(item['teacher_scores_selected'])
            score_padding = max_num_negatives + 1 - len(scores)
            teacher_scores_padded.append(scores + [0.0] * score_padding)

        return {
            'query': [item['query'] for item in batch],
            'positive': [item['positive'] for item in batch],
            'negatives_padded': negatives_padded,
            'negatives_mask': torch.tensor(negatives_mask, dtype=torch.bool),
            'teacher_scores_padded_tensor': torch.tensor(teacher_scores_padded, dtype=torch.float32),
            'source_files': [item['source_file'] for item in batch],
            'query_stocks': [item['query_stock'] for item in batch],
            'query_dates': [item['query_date'] for item in batch]
        }

class FinQuestRetriever(nn.Module):
    """FinQuest - Advanced Financial Pattern Retriever.

    Wraps a pretrained transformer encoder and its tokenizer, mean-pools the
    token embeddings and passes them through a small Linear/ReLU/Dropout head.
    """

    def __init__(self, model_name='sentence-transformers/all-MiniLM-L6-v2', hidden_size=384, dropout_rate=0.1):
        """Load the pretrained encoder/tokenizer and build the projection head.

        Args:
            model_name: Identifier handed to ``AutoModel``/``AutoTokenizer``.
            hidden_size: Output width of the projection head.
            dropout_rate: Probability used by both dropout layers.
        """
        super().__init__()

        # Pretrained backbone and the tokenizer that goes with it
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Deliberately shallow adaptation head (kept simple to avoid collapse)
        self.projection = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        self.dropout = nn.Dropout(dropout_rate)
        self.hidden_size = hidden_size

        param_count = sum(p.numel() for p in self.parameters())
        logger.info(f"FinQuest Retriever initialized with {param_count:,} parameters")

    def encode_sequence(self, sequences):
        """Embed a list of strings and return the un-normalized projections.

        If the list is empty or every string is blank, a zero tensor with one
        row per input is returned without running the encoder.
        """
        target_device = next(self.parameters()).device

        nothing_to_encode = not sequences or all(not seq.strip() for seq in sequences)
        if nothing_to_encode:
            return torch.zeros(len(sequences), self.hidden_size).to(target_device)

        tokenized = self.tokenizer(
            sequences,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        ).to(target_device)

        # Backbone forward pass and pooling run under autocast
        with autocast():
            encoder_out = self.encoder(**tokenized)
            pooled = self.mean_pooling(encoder_out, tokenized['attention_mask'])

        # The projection head runs on float32 inputs
        result = self.projection(pooled.float())

        # Second dropout, applied only while the module is in training mode
        if self.training:
            result = self.dropout(result)

        # No L2 normalization here; the loss functions normalize themselves
        return result

    def mean_pooling(self, model_output, attention_mask):
        """Average ``last_hidden_state`` over the positions the mask keeps."""
        hidden = model_output.last_hidden_state
        weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        summed = (hidden * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts

class FinancialContrastiveLoss(nn.Module):
    """Softmax cross-entropy over temperature-scaled cosine similarities.

    For every query the positive candidate occupies logit index 0 and the
    negatives follow; padded negatives are pushed to a very large negative
    logit so they do not contribute.
    """

    def __init__(self, temperature=0.2):
        """Store the temperature that divides every cosine similarity."""
        super().__init__()
        self.temperature = temperature

    def forward(self, query_emb, pos_emb, neg_embs, neg_mask):
        """Return the mean cross-entropy of picking the positive per query.

        ``neg_mask`` is inverted with ``~`` to locate padded negatives, and
        ``neg_embs`` is normalized along dim 2 and batch-multiplied with the
        queries.
        """
        # Unit-length vectors so that dot products are cosine similarities
        q_unit = F.normalize(query_emb, p=2, dim=1)
        p_unit = F.normalize(pos_emb, p=2, dim=1)
        n_unit = F.normalize(neg_embs, p=2, dim=2)

        # One positive logit per query, kept as a column for concatenation
        pos_logit = (torch.sum(q_unit * p_unit, dim=1) / self.temperature).unsqueeze(1)

        # One logit per (query, negative) pair
        neg_logits = torch.bmm(q_unit.unsqueeze(1), n_unit.transpose(1, 2)).squeeze(1)
        neg_logits = neg_logits / self.temperature

        # Padded slots must never win the softmax
        neg_logits = neg_logits.masked_fill(~neg_mask, -1e9)

        all_logits = torch.cat([pos_logit, neg_logits], dim=1)

        # The correct class is always column 0 (the positive)
        labels = torch.zeros(all_logits.size(0), dtype=torch.long).to(all_logits.device)

        return F.cross_entropy(all_logits, labels)

class FinancialKnowledgeDistillation(nn.Module):
    """KL-divergence loss that pulls student similarities toward teacher scores.

    The student distribution is a softmax over temperature-scaled cosine
    similarities between the query and each candidate; the teacher distribution
    is a softmax over clamped, temperature-scaled teacher scores.
    """

    def __init__(self, temperature=2.0):
        """Store the temperature shared by student and teacher softmaxes."""
        super().__init__()
        self.temperature = temperature

    def forward(self, query_emb, candidate_embs, teacher_scores_padded, candidate_mask):
        """Return ``KL(teacher || student)`` with ``batchmean`` reduction.

        ``candidate_mask`` is inverted with ``~`` to find padded candidates;
        those get a very large negative student logit and a teacher
        probability of zero.
        """
        # Cosine similarities via unit-length embeddings
        q_unit = F.normalize(query_emb, p=2, dim=1)
        c_unit = F.normalize(candidate_embs, p=2, dim=2)

        # Student side: scaled similarities -> masked log-probabilities
        student_logits = torch.bmm(q_unit.unsqueeze(1), c_unit.transpose(1, 2)).squeeze(1)
        student_logits = (student_logits / self.temperature).masked_fill(~candidate_mask, -1e9)
        student_log_probs = F.log_softmax(student_logits, dim=1)

        # Teacher side: keep scores inside [0.01, 0.99] before the softmax
        teacher_logits = torch.clamp(teacher_scores_padded, min=0.01, max=0.99) / self.temperature
        teacher_probs = F.softmax(teacher_logits, dim=1)

        # Zero out padded candidates, then renormalize each row
        teacher_probs = teacher_probs * candidate_mask.float()
        row_totals = teacher_probs.sum(dim=1, keepdim=True) + 1e-8
        teacher_probs = teacher_probs / row_totals

        return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

def compute_embedding_diversity_loss(embeddings, min_distance=0.1):
    """Penalize pairs of embeddings whose cosine similarity is too high.

    Each off-diagonal cosine similarity contributes
    ``relu(similarity - (1 - min_distance))``; the result is the mean over all
    ordered pairs. With fewer than two rows the loss is a constant zero.
    """
    unit_vectors = F.normalize(embeddings, p=2, dim=1)
    num_rows = unit_vectors.size(0)

    if num_rows < 2:
        return torch.tensor(0.0, device=embeddings.device)

    # Pairwise cosine similarity matrix
    cosine = unit_vectors @ unit_vectors.t()

    # Keep every entry except the self-similarities on the diagonal
    keep = ~torch.eye(num_rows, dtype=torch.bool, device=cosine.device)
    pair_sims = cosine[keep]

    # Only similarities above the threshold are penalized
    threshold = 1.0 - min_distance
    return torch.relu(pair_sims - threshold).mean()

def _log_embedding_quality(model):
    """Log cosine-similarity statistics of a few probe sentences to spot collapse."""
    test_seqs = [
        "Apple financial data analysis",
        "Tesla stock performance metrics",
        "Google revenue growth patterns",
        "Microsoft quarterly earnings",
        "Amazon business indicators"
    ]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            test_embs = model.encode_sequence(test_seqs)
            test_embs_norm = F.normalize(test_embs, p=2, dim=1)

            # Compute similarities; build the mask on the same device as the
            # similarity matrix so indexing works for GPU-resident models.
            sims = torch.mm(test_embs_norm, test_embs_norm.t())
            mask = ~torch.eye(len(test_seqs), dtype=torch.bool, device=sims.device)
            off_diagonal = sims[mask]

            min_sim = torch.min(off_diagonal).item()
            max_sim = torch.max(off_diagonal).item()
            avg_sim = torch.mean(off_diagonal).item()
            std_sim = torch.std(off_diagonal).item()

            logger.info(f"  Embedding Quality Check:")
            logger.info(f"    Similarity range: {min_sim:.4f} - {max_sim:.4f}")
            logger.info(f"    Avg similarity: {avg_sim:.4f} (±{std_sim:.4f})")

            # COLLAPSE DETECTION
            if avg_sim > 0.9:
                logger.warning("  ⚠️  HIGH COLLAPSE RISK!")
            elif avg_sim > 0.8:
                logger.warning("  ⚠️  Moderate collapse risk")
            else:
                logger.info("  ✅ Healthy embedding diversity")
    finally:
        # Restore whatever mode the model was in before the probe
        model.train(was_training)


def train_finquest(model, dataloader, optimizer, device, num_epochs=15, save_every=3, save_dir="finquest_models"):
    """Train FinQuest retriever with a fixed-weight combination of three losses.

    Each batch contributes ``0.5 * contrastive + 0.4 * distillation +
    0.1 * diversity``. Batches that are None, contain non-finite teacher
    scores, have no non-empty negative, yield an invalid loss, or raise an
    exception are skipped.

    Args:
        model: Module exposing ``encode_sequence(list_of_str)`` and
            ``hidden_size``.
        dataloader: Yields the dicts produced by ``FinQuestDataset.collate_fn``
            (or None for empty batches).
        optimizer: Optimizer over ``model.parameters()``.
        device: Device the tensors are moved to.
        num_epochs: Number of passes over ``dataloader``.
        save_every: A checkpoint is written every ``save_every`` epochs.
        save_dir: Directory for best-model and periodic checkpoints (created
            if missing).

    Returns:
        A list with one statistics dict per epoch that processed at least one
        batch.
    """
    os.makedirs(save_dir, exist_ok=True)

    model.train()
    contrastive_loss_fn = FinancialContrastiveLoss(temperature=0.2)
    kd_loss_fn = FinancialKnowledgeDistillation(temperature=2.0)

    # Mixed precision and gradient scaling only make sense on CUDA; an enabled
    # GradScaler rejects CPU losses, so both are switched off for CPU training.
    use_amp = torch.cuda.is_available() and torch.device(device).type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    best_loss = float('inf')
    training_history = []

    # Gradient clipping threshold (global L2 norm)
    max_grad_norm = 1.0

    for epoch in range(num_epochs):
        total_loss = total_cl = total_kd = total_div = 0
        num_batches = 0
        skipped_batches = 0

        progress = tqdm(dataloader, desc=f'FinQuest Epoch {epoch+1}/{num_epochs}')

        for batch_idx, batch in enumerate(progress):
            if batch is None:
                continue

            optimizer.zero_grad()

            try:
                # Validate teacher scores
                teacher_scores = batch['teacher_scores_padded_tensor'].to(device)
                if not torch.isfinite(teacher_scores).all():
                    logger.warning(f"Skipping batch {batch_idx}: NaN/Inf in teacher scores")
                    skipped_batches += 1
                    continue

                # Robust score clamping for financial data
                teacher_scores = torch.clamp(teacher_scores, min=0.0, max=1.0)

                negatives_mask = batch['negatives_mask'].to(device)
                batch_size = len(batch['query'])

                # Flatten ALL padded negatives, including the '' padding slots.
                # Dropping the padding here would make the reshape below fail
                # or misalign with negatives_mask whenever examples in a batch
                # have different numbers of negatives; padded slots are masked
                # out inside the loss functions instead.
                all_neg_seqs = [neg for negs in batch['negatives_padded'] for neg in negs]
                if not any(neg.strip() for neg in all_neg_seqs):  # Handle edge case
                    skipped_batches += 1
                    continue

                # Forward pass with mixed precision
                with autocast(enabled=use_amp):
                    # Encode query and positive
                    q_emb = model.encode_sequence(batch['query'])
                    p_emb = model.encode_sequence(batch['positive'])

                    # (batch, max_negatives, dim) - aligned with negatives_mask
                    n_embs = model.encode_sequence(all_neg_seqs).view(batch_size, -1, q_emb.size(-1))

                # Compute losses in full precision
                with autocast(enabled=False):
                    # Convert to float32 for stable loss computation
                    q_emb_fp32 = q_emb.float()
                    p_emb_fp32 = p_emb.float()
                    n_embs_fp32 = n_embs.float()

                    # Financial contrastive loss
                    cl_loss = contrastive_loss_fn(q_emb_fp32, p_emb_fp32, n_embs_fp32, negatives_mask)

                    # Financial knowledge distillation loss over [positive, negatives]
                    candidate_embs_fp32 = torch.cat([p_emb_fp32.unsqueeze(1), n_embs_fp32], dim=1)
                    candidate_mask = torch.cat([
                        torch.ones(batch_size, 1, dtype=torch.bool, device=negatives_mask.device),
                        negatives_mask
                    ], dim=1)

                    kd_loss = kd_loss_fn(
                        q_emb_fp32,
                        candidate_embs_fp32,
                        teacher_scores,
                        candidate_mask
                    )

                    # Diversity loss to discourage embedding collapse
                    all_embeddings = torch.cat([q_emb_fp32, p_emb_fp32], dim=0)
                    diversity_loss = compute_embedding_diversity_loss(all_embeddings, min_distance=0.15)

                    # Cap each term so a single outlier batch cannot dominate
                    cl_loss = torch.clamp(cl_loss, max=10.0)
                    kd_loss = torch.clamp(kd_loss, max=10.0)
                    diversity_loss = torch.clamp(diversity_loss, max=2.0)

                    # Fixed weights - no adaptive weighting
                    loss = 0.5 * cl_loss + 0.4 * kd_loss + 0.1 * diversity_loss

                # Validate loss before backward pass
                if torch.isnan(loss) or torch.isinf(loss) or loss > 100:
                    logger.warning(f"Skipping batch {batch_idx}: Invalid loss {loss.item()}")
                    skipped_batches += 1
                    continue

                # Backward pass with gradient scaling
                scaler.scale(loss).backward()

                # Unscale first so clipping sees the true gradient magnitudes
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                scaler.step(optimizer)
                scaler.update()

                # Update metrics
                total_loss += loss.item()
                total_cl += cl_loss.item()
                total_kd += kd_loss.item()
                total_div += diversity_loss.item()
                num_batches += 1

                # Enhanced progress bar
                progress.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'CL': f'{cl_loss.item():.3f}',
                    'KD': f'{kd_loss.item():.3f}',
                    'Div': f'{diversity_loss.item():.3f}',
                    'GN': f'{grad_norm:.2f}',
                    'Skip': skipped_batches
                })

            except Exception as e:
                # Best-effort training: log the failure and move on
                logger.error(f"Error in batch {batch_idx}: {e}")
                skipped_batches += 1
                continue

        # Epoch summary with embedding quality check
        if num_batches > 0:
            avg_loss = total_loss / num_batches
            avg_cl = total_cl / num_batches
            avg_kd = total_kd / num_batches
            avg_div = total_div / num_batches

            epoch_stats = {
                'epoch': epoch + 1,
                'avg_loss': avg_loss,
                'avg_contrastive': avg_cl,
                'avg_kd': avg_kd,
                'avg_diversity': avg_div,
                'skipped_batches': skipped_batches,
                'processed_batches': num_batches
            }
            training_history.append(epoch_stats)

            logger.info(f'FinQuest Epoch {epoch+1} Summary:')
            logger.info(f'  Avg Loss: {avg_loss:.4f} (CL: {avg_cl:.3f}, KD: {avg_kd:.3f}, Div: {avg_div:.3f})')
            logger.info(f'  Processed: {num_batches} batches, Skipped: {skipped_batches}')

            # Embedding quality check on every second epoch (0, 2, 4, ...)
            if epoch % 2 == 0:
                _log_embedding_quality(model)
                model.train()

            # Save best model (lowest average training loss so far)
            if avg_loss < best_loss:
                best_loss = avg_loss
                best_model_path = os.path.join(save_dir, 'finquest_retriever_best.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss,
                    'training_history': training_history,
                    'model_config': {
                        # NOTE(review): hard-coded; only correct when the model
                        # was built from this default checkpoint name.
                        'model_name': 'sentence-transformers/all-MiniLM-L6-v2',
                        'hidden_size': model.hidden_size,
                        'retriever_type': 'FinQuest'
                    }
                }, best_model_path)
                logger.info(f"Saved best FinQuest model with loss {avg_loss:.4f}")

            # Regular checkpoint
            if (epoch + 1) % save_every == 0:
                checkpoint_path = os.path.join(save_dir, f'finquest_retriever_epoch_{epoch+1}.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': avg_loss,
                    'training_history': training_history
                }, checkpoint_path)
                logger.info(f"Saved checkpoint: epoch {epoch+1}")
        else:
            logger.warning(f"Epoch {epoch+1}: No valid batches processed!")

        # Memory cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return training_history

def main():
    """Load the pos/neg data, build the model and optimizer, train and save.

    Returns early (after logging an error) when the data file is missing or
    yields no usable examples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Training FinQuest Retriever on device: {device}")

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

    # Your pos/neg allocated data file path
    pos_neg_data_path = "/root/nfs/AJ FinRag/Training Data/llm_data/all_companies_train_pos_neg.json"
    save_dir = "finquest_models"

    # Check if the file exists
    if not os.path.exists(pos_neg_data_path):
        logger.error(f"Pos/neg data file not found: {pos_neg_data_path}")
        logger.error("Please ensure your pos/neg allocation step completed successfully")
        logger.error("Expected format: {'query': '...', 'pos': [...], 'neg': [...], 'teacher_scores': [...]}")
        return

    logger.info(f"Loading pos/neg allocated data from: {pos_neg_data_path}")

    # Initialize FinQuest dataset
    dataset = FinQuestDataset(pos_neg_data_path, tokenizer)

    # Validate dataset
    if len(dataset) == 0:
        logger.error("FinQuest dataset is empty! Please check your pos/neg data file.")
        return

    logger.info(f"FinQuest dataset loaded successfully with {len(dataset)} examples")

    # Initialize FinQuest model
    model = FinQuestRetriever(hidden_size=384, dropout_rate=0.1).to(device)

    # Dynamic batch size based on GPU memory. gpu_memory stays None when there
    # is no GPU or the query fails, so the log line below never touches an
    # unbound name.
    gpu_memory = None
    if torch.cuda.is_available():
        try:
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9  # GB
        except Exception as e:
            logger.warning(f"Could not query GPU memory, using conservative batch size: {e}")

        if gpu_memory is None:
            batch_size = 4  # Conservative default
        elif gpu_memory < 6:
            batch_size = 2
        elif gpu_memory < 12:
            batch_size = 4
        elif gpu_memory < 24:
            batch_size = 8
        else:
            batch_size = 12
    else:
        batch_size = 2

    if gpu_memory is not None:
        memory_note = f"(GPU memory: {gpu_memory:.1f}GB)"
    elif torch.cuda.is_available():
        memory_note = "(GPU memory: unknown)"
    else:
        memory_note = "(CPU)"
    logger.info(f"Using batch size: {batch_size} " + memory_note)

    # Create optimized dataloader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
        num_workers=0,  # Keep at 0 for stability
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
        persistent_workers=False
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-5,
        weight_decay=0.01,
        eps=1e-8,
        betas=(0.9, 0.999)
    )

    # Learning rate scheduler.
    # NOTE(review): this scheduler is never stepped - train_finquest does not
    # accept a scheduler - so the learning rate stays constant. Wire it into
    # the training loop (scheduler.step(avg_loss) per epoch) or remove it.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.7, patience=3, min_lr=1e-7
    )

    logger.info(f"Starting FinQuest training:")
    logger.info(f"  Training examples: {len(dataset):,}")
    logger.info(f"  Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    logger.info(f"  Estimated memory usage: ~{batch_size * 512 * 384 * 4 / 1e9:.2f}GB")

    # Train FinQuest
    training_history = train_finquest(
        model=model,
        dataloader=dataloader,
        optimizer=optimizer,
        device=device,
        num_epochs=20,  # More epochs for financial domain
        save_every=4,
        save_dir=save_dir
    )

    # Final model save (same directory the training loop writes to)
    os.makedirs(save_dir, exist_ok=True)
    final_model_path = os.path.join(save_dir, "finquest_retriever_final.pth")
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_config': {
            'model_name': 'sentence-transformers/all-MiniLM-L6-v2',
            'hidden_size': model.hidden_size,
            'retriever_type': 'FinQuest',
            'training_examples': len(dataset),
            'final_epoch': len(training_history)
        },
        'training_history': training_history
    }, final_model_path)

    # Training completion summary
    logger.info("\n" + "="*60)
    logger.info("FinQuest Training Completed Successfully!")
    logger.info("="*60)
    logger.info("Saved models:")
    # No validation split is used; "best" means lowest average training loss.
    logger.info("  - finquest_models/finquest_retriever_best.pth (lowest training loss)")
    logger.info("  - finquest_models/finquest_retriever_final.pth (final model)")
    logger.info("  - finquest_models/finquest_retriever_epoch_X.pth (checkpoints)")
    if training_history:
        logger.info(f"Final training loss: {training_history[-1]['avg_loss']:.4f}")
        logger.info(f"Best training loss: {min(h['avg_loss'] for h in training_history):.4f}")
    logger.info("\nFinQuest is ready for financial pattern retrieval!")

if __name__ == "__main__":
    main()

INFO:__main__:Training FinQuest Retriever on device: cuda
INFO:__main__:Loading pos/neg allocated data from: /root/nfs/AJ FinRag/Training Data/llm_data/all_companies_train_pos_neg.json
INFO:__main__:Processing 1 pos/neg allocated data files...
INFO:__main__:Loading /root/nfs/AJ FinRag/Training Data/llm_data/all_companies_train_pos_neg.json...
INFO:__main__:Loaded 13210 examples from /root/nfs/AJ FinRag/Training Data/llm_data/all_companies_train_pos_neg.json
INFO:__main__:Loaded 13210 training examples for FinQuest
INFO:__main__:FinQuest training data analysis:
INFO:__main__:  Total examples: 13210
INFO:__main__:  Valid examples (positive score > 0): 13210
INFO:__main__:  Examples with zero positive scores: 0
INFO:__main__:  Positive score range: 0.0100 - 0.9761
INFO:__main__:  Average positive score: 0.4816
INFO:__main__:  Negative score range: 0.0100 - 0.9571
INFO:__main__:  Average negative score: 0.3197
INFO:__main__:  Average pos-neg separation: 0.1619
INFO:__main__:  Separation st