# In [3]:
import pandas as pd
import json
import re
from datetime import datetime
import logging
import os
from dataclasses import dataclass
import concurrent.futures
import pickle  # For checkpointing
import openai  # Assuming DeepSeek uses OpenAI-compatible client
import time
import requests
from typing import Dict, List, Optional

# Configure the root logger at INFO level and create this module's logger,
# which every class below uses for progress and error reporting.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class APIConfig:
    """Connection and sampling settings for one LLM API provider.

    Attributes:
        provider: Provider identifier, e.g. 'deepseek'.
        api_key: Secret used to authenticate requests.
        base_url: Optional endpoint override; ``None`` means no override.
        model: Model name sent with every request.
        max_tokens: Default cap on generated tokens per request.
        temperature: Sampling temperature; kept low to reduce hallucinations.
        timeout: Per-request timeout (presumably seconds - confirm with the
            client library).
        rate_limit_delay: Pause between requests, in seconds.
    """

    provider: str
    api_key: str
    base_url: Optional[str] = None
    model: str = "DeepSeek-R1-0528"
    max_tokens: int = 1500
    temperature: float = 0.2
    timeout: int = 30
    rate_limit_delay: float = 0.5

'''
class APIClient:
    """Generic API client for DeepSeek"""
    
    def __init__(self, config: APIConfig):
        self.config = config
        self.client = openai.OpenAI(
            api_key=config.api_key,
            base_url=config.base_url or "https://api.deepseek.com"
        )
    
    def generate_completion(self, prompt: str, max_tokens: Optional[int] = None) -> str:
        """Generate completion using the configured API"""
        max_tokens = max_tokens or self.config.max_tokens
        
        try:
            response = self.client.chat.completions.create(
                model=self.config.model,
                messages=[
                    {"role": "system", "content": "You are a helpful AI assistant specialized in therapeutic conversation analysis. Stick strictly to the user's instructions without adding external knowledge or inferring unstated information."},  # Added to reduce hallucinations
                    {"role": "user", "content": prompt}
                ],
                max_tokens=max_tokens,
                temperature=self.config.temperature,
                timeout=self.config.timeout
            )
            return response.choices[0].message.content.strip()
        
        except Exception as e:
            logger.error(f"API request failed: {e}")
            time.sleep(self.config.rate_limit_delay * 2)
            raise
        
        finally:
            time.sleep(self.config.rate_limit_delay)
'''

class APIClient:
    """Thin wrapper around an OpenAI-compatible chat-completions client.

    The client is built from an ``APIConfig`` and defaults to the DeepSeek
    endpoint when the config carries no ``base_url``.
    """

    def __init__(self, config: APIConfig):
        """Store *config* and create the underlying OpenAI-compatible client."""
        self.config = config
        self.client = openai.OpenAI(
            api_key=config.api_key,
            base_url=config.base_url or "https://api.deepseek.com"
        )

    def generate_completion(self, prompt: str, max_tokens: Optional[int] = None) -> str:
        """Send *prompt* as a single user message and return the answer text.

        Args:
            prompt: User message content.
            max_tokens: Completion-token cap; falls back to
                ``config.max_tokens`` when omitted (or 0).

        Returns:
            The stripped message content. An empty string is returned when
            the API answers without any content (for example when a reasoning
            model produced reasoning but no final answer).

        Raises:
            Exception: Whatever the client raised; it is logged, followed by
                an extra back-off pause, and re-raised unchanged.
        """
        max_tokens = max_tokens or self.config.max_tokens

        try:
            params = {
                "model": self.config.model,
                "messages": [
                    {"role": "system", "content": "You are a helpful AI assistant specialized in therapeutic conversation analysis. Stick strictly to the user's instructions without adding external knowledge or inferring unstated information."},
                    {"role": "user", "content": prompt}
                ],
                "max_tokens": max_tokens,
                "timeout": self.config.timeout
            }

            # The temperature parameter is only sent for non-reasoner models.
            if "reasoner" not in self.config.model.lower():
                params["temperature"] = self.config.temperature

            response = self.client.chat.completions.create(**params)
            message = response.choices[0].message

            # Reasoner responses carry reasoning_content next to content; the
            # reasoning is only logged (truncated) for debugging.
            reasoning = getattr(message, 'reasoning_content', None)
            if reasoning:
                logger.info(f"Reasoning: {reasoning[:200]}...")

            # content may be None; calling .strip() on it used to raise
            # AttributeError and discard the whole response.
            content = message.content
            if content is None:
                logger.warning("API returned no message content; treating it as an empty completion")
                return ""
            return content.strip()

        except Exception as e:
            logger.error(f"API request failed: {e}")
            # Extra back-off on failure; the finally clause adds the regular delay on top.
            time.sleep(self.config.rate_limit_delay * 2)
            raise

        finally:
            # Rate limiting: pause after every request, successful or not.
            time.sleep(self.config.rate_limit_delay)

    def test_api_connection(self):
        """Send a tiny prompt and report whether a non-empty answer came back.

        Returns:
            True when the API returned non-blank content, False when the
            content was empty or the request raised.
        """
        test_prompt = "Say 'test successful' briefly."

        try:
            response = self.generate_completion(test_prompt, max_tokens=20)
            logger.info(f"API Test Response: {response}")
            # Lenient check: any non-blank answer counts as success.
            if response and response.strip():
                logger.info("API connection validated successfully")
                return True
            logger.warning("API responded with empty content")
            return False
        except Exception as e:
            logger.error(f"API Test Failed: {e}")
            return False
            
class ImprovedTrainingDataGenerator:
    def __init__(self, api_config: APIConfig):
        """Create the API client and the fixed label vocabularies.

        Args:
            api_config: Settings used to build the ``APIClient``.
        """
        self.api_client = APIClient(api_config)

        # PHQ-8 item id -> item wording, in questionnaire order.
        self.phq8_questions = dict([
            ("PHQ1", "Little interest or pleasure in doing things"),
            ("PHQ2", "Feeling down, depressed, or hopeless"),
            ("PHQ3", "Trouble falling or staying asleep, or sleeping too much"),
            ("PHQ4", "Feeling tired or having little energy"),
            ("PHQ5", "Poor appetite or overeating"),
            ("PHQ6", "Feeling bad about yourself or that you are a failure"),
            ("PHQ7", "Trouble concentrating on things"),
            ("PHQ8", "Moving or speaking slowly or being fidgety/restless"),
        ])

        # Allowed per-item labels; "Not explored" marks an item with no evidence yet.
        self.phq8_severity_levels = [
            "Not explored",
            "Not at all",
            "Several days",
            "More than half the days",
            "Nearly every day",
        ]

        self.depression_classifications = ["Depressed", "Not depressed"]

        # Suggested tone labels for therapist responses.
        self.emotion_tags = [
            "Empathy",
            "Neutral/Supportive",
            "Probing",
            "Encouraging",
            "Reassuring",
            "Clarifying",
        ]

        # Reference strategy labels offered to the model in prompts.
        self.response_strategies = [
            "Continue PHQ Assessment",
            "Explore Depression Indicators",
            "Provide Emotional Support",
            "Gather Context",
            "Redirect Conversation",
            "Build Rapport",
            "Validate Feelings",
            "Ask Follow-up Questions",
            "Ending conversation",
        ]

    def parse_single_column_csv(self, file_path: str) -> pd.DataFrame:
        """Parse a CSV whose single column fuses timestamps, speaker and text.

        Each row is expected to look like ``<start><stop><speaker><text>`` or
        ``<timestamp><speaker><text>`` with nothing separating the parts. Rows
        matching neither shape are kept with speaker ``"Unknown"`` and the row
        index as a stand-in timestamp.

        Args:
            file_path: Path of the CSV file to read.

        Returns:
            DataFrame with columns ``timestamp``, ``speaker`` and ``text``,
            sorted by timestamp with empty-text rows removed. The three
            columns exist even when no row could be parsed.
        """
        logger.info(f"Loading single-column CSV: {file_path}")

        # NOTE(review): read_csv splits on commas and only column 0 is used
        # below, so text following a comma in a row is not captured - confirm
        # this is acceptable for the expected input.
        df = pd.read_csv(file_path, header=None)

        # Drop a fused header row such as 'start_timestop_timespeakervalue'.
        if not df.empty and 'start_timestop' in str(df.iloc[0, 0]).lower():
            df = df.iloc[1:].reset_index(drop=True)

        # Compiled once instead of on every row.
        start_stop_pattern = re.compile(r'^(\d+\.\d+)(\d+\.\d+)([A-Za-z]+)(.*)')
        single_time_pattern = re.compile(r'^(\d+\.\d+)([A-Za-z]+)(.*)')

        parsed_data = []

        for idx, row in df.iterrows():
            raw_text = str(row[0]).strip()

            # Missing cells are read as NaN, which str() turns into 'nan'.
            if not raw_text or raw_text == 'nan':
                continue

            match = start_stop_pattern.match(raw_text)
            if match:
                # The stop time (group 2) is not needed downstream.
                timestamp = float(match.group(1))
                speaker = match.group(3)
                text = match.group(4).strip()
            else:
                match = single_time_pattern.match(raw_text)
                if match:
                    timestamp = float(match.group(1))
                    speaker = match.group(2)
                    text = match.group(3).strip()
                else:
                    logger.warning(f"Could not parse line {idx}: {raw_text[:50]}...")
                    # Keep the row; its index stands in for the timestamp.
                    timestamp = idx
                    speaker = "Unknown"
                    text = raw_text

            # Normalise common spelling variants of the two speaker names.
            speaker = speaker.strip()
            if speaker.lower() in ['ellie', 'elle', 'eli']:
                speaker = 'Ellie'
            elif speaker.lower() in ['participant', 'p', 'user']:
                speaker = 'Participant'

            parsed_data.append({
                'timestamp': timestamp,
                'speaker': speaker,
                'text': text
            })

        # Explicit columns keep the lookups below valid when nothing parsed.
        parsed_df = pd.DataFrame(parsed_data, columns=['timestamp', 'speaker', 'text'])
        parsed_df = parsed_df[parsed_df['text'].str.len() > 0]
        parsed_df = parsed_df.sort_values('timestamp').reset_index(drop=True)

        logger.info(f"Parsed {len(parsed_df)} individual utterances from single-column CSV")
        logger.info(f"Speakers found: {parsed_df['speaker'].unique()}")

        return parsed_df

    def load_transcript_data(self, file_path: str) -> pd.DataFrame:
        """Load a transcript file, auto-detecting its column layout.

        The separators ``,``, tab, ``;`` and ``|`` are tried in turn on the
        first five rows; the first one that yields at least three columns is
        used. If none does, the file is handed to
        ``parse_single_column_csv``.

        Args:
            file_path: Path of the transcript file.

        Returns:
            DataFrame with columns ``timestamp``, ``speaker`` and ``text``,
            sorted by timestamp.
        """
        logger.info(f"Detecting CSV format for: {file_path}")

        separator_used = None

        for sep in [',', '\t', ';', '|']:
            try:
                df_test = pd.read_csv(file_path, sep=sep, nrows=5)
            except Exception as e:
                # Best-effort probing: a failed attempt just means "try the
                # next separator". Unlike a bare except this lets
                # KeyboardInterrupt/SystemExit through.
                logger.debug(f"Separator {sep!r} could not be used: {e}")
                continue
            if df_test.shape[1] >= 3:
                separator_used = sep
                logger.info(f"Detected {df_test.shape[1]} columns with separator: '{sep}'")
                break

        if separator_used is None:
            logger.info("Falling back to single column parsing")
            df = self.parse_single_column_csv(file_path)
        else:
            logger.info(f"Processing as multi-column format with separator '{separator_used}'")
            df = pd.read_csv(file_path, sep=separator_used)

            if df.shape[1] == 4:
                df.columns = ['start_time', 'stop_time', 'speaker', 'text']
                df['timestamp'] = df['start_time']
            elif df.shape[1] >= 3:
                df.columns = ['timestamp', 'speaker', 'text'] + [f'extra_{i}' for i in range(df.shape[1] - 3)]

            df['timestamp'] = pd.to_numeric(df['timestamp'], errors='coerce')

            # Drop missing values BEFORE casting to str; astype(str) turns NaN
            # into the literal 'nan', which dropna can no longer see.
            df = df.dropna(subset=['timestamp', 'speaker', 'text'])
            df['speaker'] = df['speaker'].astype(str).str.strip()
            df['text'] = df['text'].astype(str).str.strip()
            df = df[df['text'] != '']

            # Remove a header-like first row; guarded so an empty frame does
            # not raise IndexError.
            if not df.empty and df.iloc[0]['speaker'].lower() in ['speaker', 'ellie', 'participant']:
                if any(word in str(df.iloc[0]['text']).lower() for word in ['value', 'text', 'transcript']):
                    df = df.iloc[1:].reset_index(drop=True)

            df = df.sort_values('timestamp').reset_index(drop=True)
            df = df[['timestamp', 'speaker', 'text']].copy()

        logger.info(f"Successfully loaded {len(df)} individual utterances")
        return df

    def group_by_speaker_turns(self, df: pd.DataFrame) -> pd.DataFrame:
        """Merge runs of consecutive same-speaker utterances into turns.

        Args:
            df: Utterances with ``speaker``, ``text`` and ``timestamp`` columns,
                in conversation order.

        Returns:
            One row per turn with ``turn_id``, ``timestamp`` (earliest in the
            run), ``end_timestamp`` (latest), ``speaker``, space-joined
            ``text``, ``utterance_count`` and ``duration``. An empty input is
            returned unchanged.
        """
        logger.info("Grouping consecutive utterances by speaker into turns...")

        if len(df) == 0:
            return df

        turns = []

        def close_turn(who, texts, stamps):
            # Append the finished run as one turn; nothing to do before the first run.
            if who is None or not texts:
                return
            joined = ' '.join(texts)
            first_stamp = min(stamps)
            last_stamp = max(stamps)
            turns.append({
                'turn_id': len(turns),
                'timestamp': first_stamp,
                'end_timestamp': last_stamp,
                'speaker': who,
                'text': joined,
                'utterance_count': len(texts),
                'duration': last_stamp - first_stamp
            })

        active_speaker = None
        pending_texts = []
        pending_stamps = []

        for _, utterance in df.iterrows():
            who = utterance['speaker']
            said = utterance['text']
            at = utterance['timestamp']

            if who != active_speaker:
                # Speaker changed: flush the previous run and start a new one.
                close_turn(active_speaker, pending_texts, pending_stamps)
                active_speaker = who
                pending_texts = [said]
                pending_stamps = [at]
            else:
                pending_texts.append(said)
                pending_stamps.append(at)

        # Flush the run that was still open when the input ended.
        close_turn(active_speaker, pending_texts, pending_stamps)

        turns_df = pd.DataFrame(turns)
        logger.info(f"Grouped {len(df)} utterances into {len(turns_df)} speaker turns")

        if not turns_df.empty:
            logger.info("Turn distribution by speaker:")
            for speaker, count in turns_df['speaker'].value_counts().items():
                logger.info(f"  {speaker}: {count} turns")

        return turns_df

    def extract_conversation_turns(self, df: pd.DataFrame) -> List[Dict]:
        """Group utterances into turns and tag each turn with a speaker role.

        A speaker name containing 'ellie' maps to ``'therapist'``, one
        containing 'participant' maps to ``'participant'``, anything else to
        ``'unknown'``.

        Args:
            df: Utterance-level transcript accepted by ``group_by_speaker_turns``.

        Returns:
            One dict per turn, in conversation order.
        """
        role_markers = (('ellie', 'therapist'), ('participant', 'participant'))

        turns = []
        for _, record in self.group_by_speaker_turns(df).iterrows():
            lowered_name = record['speaker'].lower()
            # First marker found in the name decides the role.
            speaker_role = next(
                (role for marker, role in role_markers if marker in lowered_name),
                'unknown',
            )

            turns.append({
                'turn_id': record['turn_id'],
                'timestamp': record['timestamp'],
                'end_timestamp': record.get('end_timestamp', record['timestamp']),
                'speaker': record['speaker'],
                'speaker_role': speaker_role,
                'text': record['text'],
                'text_length': len(record['text'].split()),  # word count
                'utterance_count': record.get('utterance_count', 1),
                'duration': record.get('duration', 0.0)
            })

        logger.info(f"Extracted {len(turns)} conversation turns with speaker-based switching")
        return turns

    def identify_transition_points(self, turns: List[Dict]) -> List[int]:
        """Return the indices of therapist turns that directly follow a participant turn.

        Args:
            turns: Turn dicts carrying a ``speaker_role`` key.

        Returns:
            Ascending list of indices into *turns*; each index is the
            therapist's turn of a participant-to-therapist hand-over.
        """
        transition_points = []

        # Walk adjacent pairs; `position` is the index of the second turn of the pair.
        for position, (earlier, later) in enumerate(zip(turns, turns[1:]), start=1):
            roles = (earlier['speaker_role'], later['speaker_role'])
            if roles == ('participant', 'therapist'):
                transition_points.append(position)

        logger.info(f"Identified {len(transition_points)} participant-to-therapist transition points for training data generation")
        return transition_points

    def generate_single_summary(self, turns: List[Dict], transition_idx: int) -> Dict:
        """Helper to generate a single summary at a transition point"""
        context_turns = turns[:transition_idx]
        
        context_text = "\n".join([
            f"{t['speaker_role']}: {t['text']}" for t in context_turns
        ])
        #print('summary')
        #print(context_text)
        
        prompt = f"""
Task: Generate a cumulative summary of the therapeutic conversation up to this transition point.

Step-by-Step Instructions:
1. Read the conversation carefully.
2. Identify only explicit emotional expressions, mood indicators, and depression-related symptoms mentioned.
3. Note all PHQ-8 related content if directly discussed (e.g., interest, mood, sleep, energy, appetite, self-worth, concentration, psychomotor changes). List evidence for each explored symptom explicitly, even if minor. Do not infer or omit any.
4. Focus solely on clinically relevant information; ignore casual talk.
5. Build on previous information without repetition. Base ONLY on the provided conversation—do not add external knowledge or assumptions.

Conversation so far:
{context_text}

Output a concise yet comprehensive summary no more than 2 short paragraphs or 6–8 sentences , capturing:
- Key emotional indicators (cite direct quotes).
- All depression symptoms discussed, with explicit mapping to PHQ-8 questions (if any).
- Participant's current state/mood (based on evidence).
- Therapist's assessment approach (if evident).
- Important quotes or expressions.

Summary:"""  # Optimized: Added step-by-step, grounding, citation requirement

        try:
            response = self.api_client.generate_completion(prompt, max_tokens=1000)
            summary_text = response.strip()
            
            return {
                'transition_turn_id': transition_idx,
                'cumulative_summary': summary_text,
                'conversation_length': len(context_turns),
                'context_turns_count': len(context_turns)
            }
            
        except Exception as e:
            logger.error(f"Error generating summary for transition at turn {transition_idx}: {e}")
            return {
                'transition_turn_id': transition_idx,
                'cumulative_summary': "Summary generation failed",
                'conversation_length': len(context_turns),
                'context_turns_count': len(context_turns)
            }

    def generate_summaries_at_transitions(self, turns: List[Dict], transition_points: List[int]) -> List[Dict]:
        """Generate summaries only at speaker transition points, parallelized"""
        summaries = []
        
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:  # Capped to 3 for rate limits
            future_to_idx = {executor.submit(self.generate_single_summary, turns, idx): idx for idx in transition_points}
            
            for future in concurrent.futures.as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    summary = future.result()
                    summaries.append(summary)
                    logger.info(f"Generated summary for transition at turn {idx}")
                except Exception as e:
                    logger.error(f"Error in parallel summary generation for turn {idx}: {e}")
                    summaries.append({
                        'transition_turn_id': idx,
                        'cumulative_summary': "Summary generation failed",
                        'conversation_length': 0,
                        'context_turns_count': 0
                    })
        
        summaries.sort(key=lambda x: x['transition_turn_id'])
        
        return summaries

    def generate_single_strategy(self, turns: List[Dict], transition_idx: int, classifications: List[Dict], i: int, summaries: List[Dict]) -> Dict:
        """Helper to generate a single response strategy at a transition point"""
        before_turn = turns[transition_idx - 1]  # Previous speaker's turn
        after_turn = turns[transition_idx] if transition_idx < len(turns) else None
        
        if after_turn is None:
            return None
            
        if (before_turn['speaker_role'] == 'participant' and 
            after_turn['speaker_role'] == 'therapist'):
            
            current_classification = classifications[i] if i < len(classifications) else {}
            current_summary = summaries[i]['cumulative_summary'] if i < len(summaries) else ""
            
            # Use full conversation history up to the point (excluding current therapist response) for accuracy
            full_history = "\n".join([
                f"{t['speaker_role']}: {t['text']}" for t in turns[:transition_idx]
            ])

            
            prompt = f"""
    Task: Analyze the therapeutic response strategy used at this speaker transition.
    
    Step-by-Step Instructions:
    1. Read the full conversation history up to the participant's statement, therapist's response, and cumulative summary (derived from the full conversation history).
    2. Identify the strategy based ONLY on the response's content (e.g., if asking questions, it's "Ask Follow-up Questions").
    3. Choose the best emotion tag that matches the tone, using the suggestions as a starting point but generating a new one if the context requires a more appropriate fit.
    4. Describe the intent concisely, based on evidence in the response and broader context from the summary and history. Do not infer unstated goals.
    
    Cumulative summary (derived from full conversation): {current_summary}
    
    Full conversation history up to participant: 
    {full_history}
    
    Participant said: "{before_turn['text']}"
    
    Therapist responded: "{after_turn['text']}"
    
    Current Assessment State: {json.dumps(current_classification.get('phq8_scores', {}), indent=2)}
    
    Reference Strategies (use as inspiration; generate a new concise 1-2 word description if none fit the context perfectly): 
    {self.response_strategies}
    
    Suggested Emotion Tags (choose one from the list or generate a new one based on context if more appropriate): {self.emotion_tags}
    
    Output ONLY the JSON object in this EXACT format. No additional text:
    {{
      "strategy_used": "concise 1-2 word strategy description",
      "emotion_tag": "emotion tag (from suggestions or context-based)", 
      "response_intent": "1-sentence description of what the therapist was trying to achieve, based on evidence"
    }}"""
            
            try:
                response = self.api_client.generate_completion(prompt, max_tokens=300)
                
                json_match = re.search(r'\{.*\}', response, re.DOTALL)
                if json_match:
                    strategy_data = json.loads(json_match.group())
                    strategy_data.update({
                        'transition_turn_id': transition_idx,
                        'participant_turn_id': transition_idx - 1,
                        'therapist_turn_id': transition_idx,
                        'participant_text': before_turn['text'],
                        'therapist_response': after_turn['text']
                    })
                    return strategy_data
                else:
                    return self.create_default_strategy_at_transition(transition_idx, before_turn, after_turn)
                
            except Exception as e:
                logger.error(f"Error analyzing strategy for transition at turn {transition_idx}: {e}")
                return self.create_default_strategy_at_transition(transition_idx, before_turn, after_turn)
        return None

    def generate_classifications_at_transitions(self, turns: List[Dict], transition_points: List[int], final_phq_scores: Dict, final_depression_label: str, summaries: List[Dict]) -> List[Dict]:
        """Generate classifications only at speaker transition points with final label validation, parallelized"""
        classifications = []
        
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            future_to_idx = {executor.submit(self.generate_single_classification, turns, transition_points[i], final_phq_scores, final_depression_label, summaries[i]['cumulative_summary']): transition_points[i] for i in range(len(transition_points))}
            
            for future in concurrent.futures.as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    classification = future.result()
                    classifications.append(classification)
                    logger.info(f"Generated classification for transition at turn {idx}")
                except Exception as e:
                    logger.error(f"Error in parallel classification generation for turn {idx}: {e}")
                    classifications.append(self.create_default_classification_at_transition({'transition_turn_id': idx}))
        
        classifications.sort(key=lambda x: x['transition_turn_id'])
        
        if final_phq_scores and final_depression_label and classifications:
            classifications = self.align_with_ground_truth_flexible(
                classifications, final_phq_scores, final_depression_label
            )
        
        return classifications

    def generate_single_classification(self, turns: List[Dict], transition_idx: int, final_phq_scores: Dict, final_depression_label: str, current_summary: str) -> Dict:
        """Ask the model for PHQ-8 severities and a depression label at one transition.

        Args:
            turns: All conversation turns.
            transition_idx: Index of the therapist turn; only turns before it
                are included in the prompt.
            final_phq_scores: Accepted for signature parity; not read here.
            final_depression_label: Accepted for signature parity; not read here.
            current_summary: Cumulative summary for this transition.

        Returns:
            The validated classification tagged with ``transition_turn_id``,
            or the default classification when the model's answer cannot be
            obtained or parsed.
        """
        full_turns = turns[:transition_idx]  # everything before the therapist's response
        full_text = "\n".join(f"{t['speaker_role']}: {t['text']}" for t in full_turns)

        prompt = f"""
    Task: Generate PHQ-8 classifications based on the cumulative summary and full conversation up to this transition point.
    
    Step-by-Step Instructions:
    1. Read the cumulative summary and full conversation carefully and focus ONLY on the participant's sentiments, emotional expressions, and direct statements. Ignore therapist's parts unless they explicitly quote or reflect the participant's sentiment.
    2. Identify explicit evidence for each PHQ-8 question from the participant's sentiment-related content.
    3. Map only direct evidence to severity levels. Be conservative: If no explicit evidence in the participant's sentiments, use "Not explored". Do not infer, hallucinate, or use therapist's interpretations.
    4. For depression classification, assess based solely on mapped severities from participant sentiments (e.g., if multiple high severities in sentiments, classify as "Depressed").
    5. Provide evidence as direct quotes/phrases from the participant's sentiments or "no evidence".
    
    Example:
    If conversation says "Participant: I feel down every day", map PHQ2 to "Nearly every day" with evidence "I feel down every day".
    
    PHQ-8 Questions:
    {json.dumps(self.phq8_questions, indent=2)}
    
    Severity Levels: {self.phq8_severity_levels}
    
    Cumulative summary (derived from full conversation history): {current_summary}
    
    Full conversation up to transition: {full_text}
    
    Output ONLY the JSON object in this EXACT format. No additional text:
    {{
      "phq8_scores": {{
        "PHQ1": "severity_level",
        "PHQ2": "severity_level",
        "PHQ3": "severity_level",
        "PHQ4": "severity_level",
        "PHQ5": "severity_level",
        "PHQ6": "severity_level",
        "PHQ7": "severity_level",
        "PHQ8": "severity_level"
      }},
      "depression_classification": "Depressed" or "Not depressed",
      "evidence_mapping": {{
        "PHQ1": "evidence from participant's sentiments or 'no evidence'",
        "PHQ2": "evidence from participant's sentiments or 'no evidence'",
        "PHQ3": "evidence from participant's sentiments or 'no evidence'",
        "PHQ4": "evidence from participant's sentiments or 'no evidence'",
        "PHQ5": "evidence from participant's sentiments or 'no evidence'",
        "PHQ6": "evidence from participant's sentiments or 'no evidence'",
        "PHQ7": "evidence from participant's sentiments or 'no evidence'",
        "PHQ8": "evidence from participant's sentiments or 'no evidence'"
      }}
    }}"""

        fallback_seed = {'transition_turn_id': transition_idx}

        try:
            response = self.api_client.generate_completion(prompt, max_tokens=700)

            # Take everything from the first '{' to the last '}' as the JSON payload.
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if not json_match:
                return self.create_default_classification_at_transition(fallback_seed)

            classification_data = self.validate_classification_structure(
                json.loads(json_match.group())
            )
            classification_data['transition_turn_id'] = transition_idx
            return classification_data

        except Exception as e:
            logger.error(f"Error generating classification for transition at turn {transition_idx}: {e}")
            return self.create_default_classification_at_transition(fallback_seed)
    def generate_response_strategies_at_transitions(self, turns: List[Dict], transition_points: List[int],
                                                    classifications: List[Dict], summaries: List[Dict]) -> List[Dict]:
        """Generate response strategies only at speaker transition points, parallelized"""
        strategies = []
        
        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            future_to_idx = {executor.submit(self.generate_single_strategy, turns, transition_points[i], classifications, i, summaries): i for i in range(len(transition_points))}
            
            for future in concurrent.futures.as_completed(future_to_idx):
                i = future_to_idx[future]
                try:
                    strategy = future.result()
                    if strategy:
                        strategies.append(strategy)
                        logger.info(f"Analyzed strategy for transition at turn {transition_points[i]}")
                except Exception as e:
                    logger.error(f"Error in parallel strategy generation for turn {transition_points[i]}: {e}")
                    strategies.append(self.create_default_strategy_at_transition(transition_points[i], turns[transition_points[i]-1], turns[transition_points[i]]))
        
        strategies.sort(key=lambda x: x['transition_turn_id'])
        
        return strategies

    def validate_classification_structure(self, classification_data: Dict) -> Dict:
        """Validate and repair a classification dict in place.

        The input is presumably parsed from model output (confirm against the
        caller), so containers may be missing or have the wrong type and
        individual evidence entries may be ``None`` or non-string values.

        Repairs applied:
          * ``phq8_scores`` / ``evidence_mapping`` that are missing or not dicts
            are replaced with defaults.
          * Any PHQ key whose score is missing or not in
            ``self.phq8_severity_levels`` becomes "Not explored".
          * A score backed by "no evidence" (or ``None``) is reset to
            "Not explored" and its evidence entry is annotated.
          * An unknown or missing ``depression_classification`` becomes
            "Not depressed".

        Args:
            classification_data: Classification to validate; mutated in place.

        Returns:
            The same ``classification_data`` object after repair.
        """
        # A present-but-malformed container (e.g. null or a string) would break
        # the item assignments below, so treat it the same as a missing one.
        if not isinstance(classification_data.get('phq8_scores'), dict):
            classification_data['phq8_scores'] = {}

        if not isinstance(classification_data.get('evidence_mapping'), dict):
            classification_data['evidence_mapping'] = {key: "no evidence" for key in self.phq8_questions.keys()}

        scores = classification_data['phq8_scores']
        evidence_mapping = classification_data['evidence_mapping']

        for phq_key in self.phq8_questions.keys():
            if phq_key not in scores or scores[phq_key] not in self.phq8_severity_levels:
                scores[phq_key] = "Not explored"

            # Normalise evidence for comparison only: None counts as no evidence,
            # other non-string values are stringified instead of crashing on .strip().
            raw_evidence = evidence_mapping.get(phq_key, "no evidence")
            if raw_evidence is None:
                evidence = "no evidence"
            else:
                evidence = str(raw_evidence).strip().lower()

            # Post-generation rule: a severity without supporting evidence is not trusted.
            if evidence == "no evidence" and scores[phq_key] != "Not explored":
                scores[phq_key] = "Not explored"
                evidence_mapping[phq_key] = "no evidence (reset due to lack of evidence)"

        if 'depression_classification' not in classification_data or classification_data['depression_classification'] not in self.depression_classifications:
            classification_data['depression_classification'] = "Not depressed"

        return classification_data

    def create_default_classification_at_transition(self, summary_data: Dict) -> Dict:
        """Build the fallback classification used when generation fails at a transition.

        Every key of ``self.phq8_questions`` is scored "Not explored" with
        "no evidence", the overall label is "Not depressed", and the record is
        tagged with ``generation_method='default_fallback'``.

        Args:
            summary_data: Summary record; only its ``'transition_turn_id'`` is read.

        Returns:
            A new classification dict for that transition.
        """
        turn_id = summary_data['transition_turn_id']
        question_keys = list(self.phq8_questions.keys())

        fallback = {'transition_turn_id': turn_id}
        # Two independent dicts so later edits to scores never leak into evidence.
        fallback['phq8_scores'] = dict.fromkeys(question_keys, "Not explored")
        fallback['depression_classification'] = "Not depressed"
        fallback['evidence_mapping'] = dict.fromkeys(question_keys, "no evidence")
        fallback['generation_method'] = 'default_fallback'
        return fallback

    def create_default_strategy_at_transition(self, transition_turn_id: int, participant_turn: Dict, 
                              therapist_turn: Dict) -> Dict:
        """Build the fallback strategy record used when analysis fails at a transition.

        Args:
            transition_turn_id: Index of the therapist turn at the transition.
            participant_turn: Turn dict whose ``'text'`` becomes ``participant_text``.
            therapist_turn: Turn dict whose ``'text'`` becomes ``therapist_response``.

        Returns:
            A strategy dict with fixed default strategy, emotion tag and intent;
            the participant turn id is recorded as ``transition_turn_id - 1``.
        """
        fallback = dict(
            transition_turn_id=transition_turn_id,
            participant_turn_id=transition_turn_id - 1,
            therapist_turn_id=transition_turn_id,
            strategy_used='Gather Context',
            emotion_tag='Neutral/Supportive',
            response_intent='Continue conversation',
        )
        fallback['participant_text'] = participant_turn['text']
        fallback['therapist_response'] = therapist_turn['text']
        return fallback

    def align_with_ground_truth_flexible(self, classifications: List[Dict], final_phq_scores: Dict, 
                                       final_depression_label: str) -> List[Dict]:
        """Overwrite the final classification with ground truth and tag the rest.

        Only the last entry of ``classifications`` is forced to the ground truth:
        each key of ``self.phq8_questions`` that has a numeric score in [0, 3] in
        ``final_phq_scores`` is mapped to its severity label, and the depression
        label is replaced. Earlier entries keep their scores and are only tagged
        as ``'intermediate_flexible'``.

        Args:
            classifications: Classification dicts in transition order; mutated in place.
            final_phq_scores: Mapping of PHQ key to numeric score (0-3). Entries
                that are missing, non-numeric, NaN or out of range are skipped.
            final_depression_label: Ground-truth label written to the last entry.

        Returns:
            The same ``classifications`` list.
        """
        # Local import keeps this method self-contained; `numbers.Real` also
        # matches NumPy scalar types (e.g. np.int64 coming out of pandas), which
        # a plain `(int, float)` check rejects and would silently skip.
        import numbers

        score_to_severity = {0: "Not at all", 1: "Several days", 2: "More than half the days", 3: "Nearly every day"}
        last_index = len(classifications) - 1

        for i, classification in enumerate(classifications):
            if i != last_index:
                classification['ground_truth_aligned'] = False
                classification['alignment_type'] = 'intermediate_flexible'
                continue

            for phq_key in self.phq8_questions.keys():
                if phq_key not in final_phq_scores:
                    continue
                numeric_score = final_phq_scores[phq_key]
                # NaN fails the range comparison, so missing CSV cells are skipped too.
                if isinstance(numeric_score, numbers.Real) and 0 <= numeric_score <= 3:
                    classification['phq8_scores'][phq_key] = score_to_severity[int(numeric_score)]

            classification['depression_classification'] = final_depression_label
            classification['ground_truth_aligned'] = True
            classification['alignment_type'] = 'final_exact'

        return classifications

    def create_aligned_training_samples(self, turns: List[Dict], summaries: List[Dict],
                                      classifications: List[Dict], strategies: List[Dict], 
                                      transition_points: List[int]) -> Dict:
        """Create training samples for the summary, classification and response modules.

        For every transition the input context contains only turns *before* the
        transition turn, so the target therapist response never leaks into the
        input. Summary samples use the full preceding history; classification and
        response samples use the last 10 preceding turns.

        Args:
            turns: Conversation turns; each needs ``'speaker_role'`` and ``'text'``.
            summaries: Records with ``'transition_turn_id'`` and ``'cumulative_summary'``.
            classifications: Records with ``'transition_turn_id'``, ``'phq8_scores'``
                and ``'depression_classification'`` (optionally ``'evidence_mapping'``,
                ``'ground_truth_aligned'``, ``'alignment_type'``).
            strategies: Strategy records for therapist responses.
            transition_points: Unused; kept so existing callers keep working.

        Returns:
            Dict with the lists ``'summary_module_samples'``,
            ``'classification_module_samples'`` and ``'response_generation_samples'``.
            A response sample is emitted only when both a summary and a
            classification exist for the strategy's transition.
        """
        recent_window = 10  # number of preceding turns given to classification/response samples

        summary_instruction = 'You are a specialized AI assistant for generating cumulative summaries in therapeutic conversations. Based on the full conversation history up to this point, create a concise summary (2-4 sentences) focusing on explicit emotional indicators, all depression symptoms discussed (with direct mapping to PHQ-8 questions if applicable), participant\'s current state/mood based on evidence, therapist\'s assessment approach if evident, and key quotes. Cite direct evidence without inference or external knowledge.'
        classification_instruction = 'You are a specialized AI assistant for PHQ-8 classification in therapeutic conversations. Using the recent conversation turns (last 10) combined with the cumulative summary (derived from full history), generate severity levels for each PHQ-8 question based only on explicit participant sentiments and evidence. Be conservative: Use "Not explored" if no direct evidence. Provide evidence mappings from the inputs. Output only valid JSON in this exact format: {"phq8_scores": {"PHQ1": "value", "PHQ2": "value", "PHQ3": "value", "PHQ4": "value", "PHQ5": "value", "PHQ6": "value", "PHQ7": "value", "PHQ8": "value"}, "depression_classification": "Depressed" or "Not depressed", "evidence_mapping": {"PHQ1": "evidence or \'no evidence\'", ...}}. Do not include any other text, explanations, tags, or tool calls.'
        response_instruction = 'You are a specialized AI assistant for generating therapeutic responses. Based on the recent conversation context (last 10 turns), cumulative summary (from full history), and PHQ-8 classification results, craft a natural therapist response. Incorporate an appropriate emotion tag, strategy, and intent based on context. Output in JSON: {"therapist_response": "response text", "emotion_tag": "tag", "strategy_used": "strategy", "response_intent": "1-sentence intent"}. Ensure response is empathetic, probing if needed, and advances the assessment.'

        def history_before(transition_turn_id, window=None):
            """Return role/text copies of the turns preceding the transition (last `window` if given)."""
            # Clamp to 0 so a non-positive id yields no history rather than a negative slice.
            end = max(transition_turn_id, 0)
            start = 0 if window is None else max(0, end - window)
            return [
                {'speaker_role': turn['speaker_role'], 'text': turn['text']}
                for turn in turns[start:end]
            ]

        training_samples = {
            'summary_module_samples': [],
            'classification_module_samples': [],
            'response_generation_samples': []
        }

        # Index records by transition id once instead of rescanning the lists in
        # every iteration; setdefault keeps the first record for a duplicated id.
        summary_by_turn = {}
        for summary_data in summaries:
            transition_turn_id = summary_data['transition_turn_id']
            summary_by_turn.setdefault(transition_turn_id, summary_data)
            conversation_history = history_before(transition_turn_id)

            training_samples['summary_module_samples'].append({
                'task': 'Summary',
                'instruction': summary_instruction,
                'input': {
                    'conversation_history': conversation_history
                },
                'expected_output': {
                    'cumulative_summary': summary_data['cumulative_summary']
                },
                'metadata': {
                    'transition_turn_id': transition_turn_id,
                    'conversation_length': len(conversation_history),
                    'context_turns_count': len(conversation_history)
                }
            })

        classification_by_turn = {}
        for classification_data in classifications:
            transition_turn_id = classification_data['transition_turn_id']
            classification_by_turn.setdefault(transition_turn_id, classification_data)

            matching_summary = summary_by_turn.get(transition_turn_id)
            # A classification without a summary still yields a sample, with an empty summary.
            current_summary = matching_summary['cumulative_summary'] if matching_summary is not None else ""

            training_samples['classification_module_samples'].append({
                'task': 'Classification',
                'instruction': classification_instruction,
                'input': {
                    'recent_history': history_before(transition_turn_id, recent_window),
                    'cumulative_summary': current_summary
                },
                'expected_output': {
                    'phq8_scores': classification_data['phq8_scores'],
                    'depression_classification': classification_data['depression_classification'],
                    'evidence_mapping': classification_data.get('evidence_mapping', {})
                },
                'metadata': {
                    'transition_turn_id': transition_turn_id,
                    'ground_truth_aligned': classification_data.get('ground_truth_aligned', False),
                    'alignment_type': classification_data.get('alignment_type', 'none')
                }
            })

        for strategy_data in strategies:
            transition_turn_id = strategy_data['transition_turn_id']
            current_summary = summary_by_turn.get(transition_turn_id)
            current_classification = classification_by_turn.get(transition_turn_id)

            # Response samples need both upstream outputs; skip the transition otherwise.
            if not (current_summary and current_classification):
                continue

            training_samples['response_generation_samples'].append({
                'task': 'Response',
                'instruction': response_instruction,
                'input': {
                    'recent_context': history_before(transition_turn_id, recent_window),
                    'cumulative_summary': current_summary['cumulative_summary'],
                    'classification_results': {
                        'phq8_scores': current_classification['phq8_scores'],
                        'depression_classification': current_classification['depression_classification']
                    }
                },
                'expected_output': {
                    'therapist_response': strategy_data['therapist_response'],
                    'emotion_tag': strategy_data['emotion_tag'],
                    'strategy_used': strategy_data['strategy_used'],
                    'response_intent': strategy_data['response_intent']
                },
                'metadata': {
                    'transition_turn_id': transition_turn_id,
                    'participant_turn_id': strategy_data['participant_turn_id'],
                    'therapist_turn_id': strategy_data['therapist_turn_id'],
                    'response_length': len(strategy_data['therapist_response'].split())
                }
            })

        return training_samples

    def process_transcript_with_proper_alignment(self, file_path: str, final_phq_scores: Dict,
                                               final_depression_label: str, output_file: str = None) -> Dict:
        """Run the full per-transcript pipeline, working only at speaker transitions.

        Steps, in order: load the transcript, extract turns, find transition
        points, generate summaries, classifications and response strategies, then
        assemble the aligned training samples.

        Args:
            file_path: Transcript file handed to ``load_transcript_data``.
            final_phq_scores: Ground-truth PHQ scores, passed to classification
                generation and echoed in the metadata.
            final_depression_label: Ground-truth label, handled the same way.
            output_file: If truthy, the result is also written there as UTF-8 JSON.

        Returns:
            Dict holding metadata, raw turns, transition points, the three
            intermediate result lists and the aligned training samples.
        """
        logger.info(f"Processing transcript with speaker transition alignment: {file_path}")

        turns = self.extract_conversation_turns(self.load_transcript_data(file_path))
        logger.info(f"Extracted {len(turns)} conversation turns")

        transition_points = self.identify_transition_points(turns)
        logger.info(f"Identified {len(transition_points)} speaker transition points")

        # Each stage consumes the outputs of the previous ones, so the order matters.
        logger.info("Generating summaries at speaker transitions using API...")
        summaries = self.generate_summaries_at_transitions(turns, transition_points)

        logger.info("Generating classifications at speaker transitions using API...")
        classifications = self.generate_classifications_at_transitions(
            turns, transition_points, final_phq_scores, final_depression_label, summaries
        )

        logger.info("Analyzing response strategies at speaker transitions using API...")
        strategies = self.generate_response_strategies_at_transitions(
            turns, transition_points, classifications, summaries
        )

        logger.info("Creating aligned training samples...")
        training_samples = self.create_aligned_training_samples(
            turns, summaries, classifications, strategies, transition_points
        )

        run_metadata = {
            'source_file': file_path,
            'generation_timestamp': datetime.now().isoformat(),
            'total_turns': len(turns),
            'transition_points': len(transition_points),
            'final_phq_scores': final_phq_scores,
            'final_depression_label': final_depression_label,
        }
        api_config = self.api_client.config
        run_metadata['api_provider'] = api_config.provider
        run_metadata['api_model'] = api_config.model

        training_data = {
            'metadata': run_metadata,
            'raw_conversation_turns': turns,
            'speaker_transition_points': transition_points,
            'transition_summaries': summaries,
            'transition_classifications': classifications,
            'transition_response_strategies': strategies,
            'aligned_training_samples': training_samples,
        }

        if not output_file:
            return training_data

        with open(output_file, 'w', encoding='utf-8') as handle:
            json.dump(training_data, handle, indent=2, ensure_ascii=False)
        logger.info(f"Training data saved to: {output_file}")

        return training_data

    def export_training_data(self, training_data: Dict, output_dir: str = "training_data"):
        """Write each module's samples to ``<output_dir>/<module_name>.jsonl``.

        Each sample becomes one JSON line holding a system/user/assistant message
        triple plus the sample's metadata. The user and assistant contents are the
        JSON-encoded ``input`` and ``expected_output``. Modules with no samples
        produce no file.

        Args:
            training_data: Result dict; samples are read from its
                ``'aligned_training_samples'`` entry (missing means nothing to export).
            output_dir: Target directory, created if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)

        samples_by_module = training_data.get('aligned_training_samples', {})

        for module_name, samples in samples_by_module.items():
            if not samples:
                continue

            jsonl_file = os.path.join(output_dir, f"{module_name}.jsonl")
            with open(jsonl_file, 'w', encoding='utf-8') as out:
                for sample in samples:
                    system_message = {
                        "role": "system",
                        "content": f"You are a specialized AI assistant for {sample['task']}. {sample['instruction']}",
                    }
                    user_message = {
                        "role": "user",
                        "content": json.dumps(sample['input'], ensure_ascii=False),
                    }
                    assistant_message = {
                        "role": "assistant",
                        "content": json.dumps(sample['expected_output'], ensure_ascii=False),
                    }
                    record = {
                        "messages": [system_message, user_message, assistant_message],
                        "metadata": sample.get('metadata', {}),
                    }
                    out.write(json.dumps(record, ensure_ascii=False) + "\n")
            logger.info(f"Exported {len(samples)} samples to {jsonl_file}")

    def test_csv_format_detection(self, file_path: str):
        """Load a transcript and print a diagnostic report of how it was parsed.

        The report shows utterance count, speakers, timestamp range and the first
        10 utterances. If grouping by speaker turns yields anything, it also shows
        the number of turns and the speaker transition points found.

        Args:
            file_path: Transcript file handed to ``load_transcript_data``.

        Returns:
            The DataFrame returned by ``load_transcript_data``.
        """
        logger.info("Testing CSV format detection, parsing, and speaker transition identification...")

        df = self.load_transcript_data(file_path)
        rule = "=" * 80

        print("\n" + rule)
        print("CSV FORMAT DETECTION AND PARSING TEST")
        print(rule)
        print(f"Total utterances parsed: {len(df)}")
        print(f"Speakers found: {df['speaker'].unique()}")
        timestamps = df['timestamp']
        print(f"Timestamp range: {timestamps.min():.3f} - {timestamps.max():.3f}")

        print("\nFirst 10 parsed utterances:")
        preview = df.head(10)
        for idx, row in preview.iterrows():
            print(f"Utterance {idx}: [{row['timestamp']:.3f}] {row['speaker']}: {row['text'][:60]}...")

        grouped_df = self.group_by_speaker_turns(df)
        if grouped_df.empty:
            # Nothing could be grouped, so there are no transitions to report.
            return df

        turns = self.extract_conversation_turns(df)
        transition_points = self.identify_transition_points(turns)

        print(f"\nGrouped {len(df)} utterances into {len(grouped_df)} speaker turns")
        print(f"Identified {len(transition_points)} speaker transition points")
        print(f"Transition points: {transition_points}")

        print(f"\nValidation: Ready for training - {'✓ YES' if len(transition_points) > 0 else '✗ NO'}")

        return df

    def test_api_connection(self):
        """Send a tiny prompt to the API and report whether a usable reply came back.

        Returns:
            True if the reply contains any non-whitespace content; False if it is
            empty or if the request (or handling the reply) raised an exception.
        """
        test_prompt = "Say 'test successful' and nothing else."

        try:
            response = self.api_client.generate_completion(test_prompt, max_tokens=20)
            logger.info(f"API Test Response: {response}")
            # Deliberately lenient: the wording is not checked, only that the reply is non-blank.
            if not (response and response.strip()):
                logger.warning("API responded with empty content")
                return False
            logger.info("API connection validated successfully")
            return True
        except Exception as e:
            logger.error(f"API Test Failed: {e}")
            return False


def process_single_transcript(generator, file_name, transcripts_folder, phq_df, checkpoint_dir):
    """Process one transcript file, reusing or creating a pickle checkpoint.

    Args:
        generator: Object providing ``test_csv_format_detection``,
            ``process_transcript_with_proper_alignment`` and ``export_training_data``.
        file_name: Transcript file name; must end in ``.csv`` and start with the
            integer participant id followed by ``_``.
        transcripts_folder: Directory containing ``file_name``.
        phq_df: DataFrame with ``Participant_ID``, ``PHQ8_Binary``, ``PHQ8_Score``,
            ``Gender`` and the per-item PHQ columns.
        checkpoint_dir: Existing directory for ``checkpoint_<id>.pkl`` files.

    Returns:
        The training-data dict, or ``None`` when the file is skipped (wrong
        extension, unparsable name, no PHQ row) or processing fails.
    """
    if not file_name.lower().endswith(".csv"):
        return None
    
    try:
        participant_id = int(file_name.split("_")[0])
    except ValueError:
        print(f"⚠ Skipping file with unexpected name format: {file_name}")
        return None
    
    phq_row = phq_df[phq_df["Participant_ID"] == participant_id]
    if phq_row.empty:
        print(f"⚠ No PHQ data found for participant {participant_id}, skipping...")
        return None
    
    # Everything except the id/aggregate/demographic columns is a per-item score.
    final_phq_scores = phq_row.iloc[0].drop(
        labels=["Participant_ID", "PHQ8_Binary", "PHQ8_Score", "Gender"]
    ).to_dict()
    
    binary_label = phq_row.iloc[0]["PHQ8_Binary"]
    final_depression_label = "Depressed" if binary_label == 1 else "Not depressed"
    
    file_path = os.path.join(transcripts_folder, file_name)
    output_file = f'simplified_training_data_{participant_id}.json'
    output_dir = f"simplified_training_data_{participant_id}"
    checkpoint_file = os.path.join(checkpoint_dir, f"checkpoint_{participant_id}.pkl")
    
    print("\n" + "="*80)
    print(f"PROCESSING TRANSCRIPT: {file_name}")
    print("="*80)
    
    try:
        df = None
        training_data = None

        if os.path.exists(checkpoint_file):
            # NOTE: pickle executes arbitrary code on load; only use checkpoint
            # directories written by this pipeline.
            try:
                with open(checkpoint_file, 'rb') as f:
                    checkpoint_data = pickle.load(f)
            except Exception as e:
                # A truncated or corrupt checkpoint (e.g. from an interrupted run)
                # must not block the transcript forever; unpickling can raise many
                # exception types, so fall back to regenerating on any of them.
                logger.warning(f"Ignoring unreadable checkpoint {checkpoint_file}: {e}")
                checkpoint_data = None

            if isinstance(checkpoint_data, dict):
                cached_df = checkpoint_data.get('df')
                cached_training_data = checkpoint_data.get('training_data')
                # Only a complete checkpoint is reused; a partial one is regenerated.
                if cached_df is not None and cached_training_data is not None:
                    df = cached_df
                    training_data = cached_training_data
                    print(f"✅ Loaded checkpoint for participant {participant_id}")
        
        if df is None:
            df = generator.test_csv_format_detection(file_path)
        
        if training_data is None:
            training_data = generator.process_transcript_with_proper_alignment(
                file_path=file_path,
                final_phq_scores=final_phq_scores,
                final_depression_label=final_depression_label,
                output_file=output_file
            )
        
        # Write to a temp file and rename so an interruption never leaves a
        # half-written checkpoint under the real name.
        temp_checkpoint_file = checkpoint_file + ".tmp"
        with open(temp_checkpoint_file, 'wb') as f:
            pickle.dump({'df': df, 'training_data': training_data}, f)
        os.replace(temp_checkpoint_file, checkpoint_file)
        
        generator.export_training_data(
            training_data=training_data,
            output_dir=output_dir
        )
        
        print(f"✅ Processed participant {participant_id}")
        print(f"Final depression label: {final_depression_label}")
        print(f"Total transitions: {len(training_data['speaker_transition_points'])}")
        
        return training_data
    
    except Exception as e:
        logger.error(f"Error processing {file_name}: {e}")
        print(f"❌ Error processing {file_name}: {e}")
        return None

def main_api(max_workers: int = 150):
    """Batch process transcripts from a folder, matching with PHQ scores, parallelized with checkpoints.

    The API key is read from the ``DEEPSEEK_API_KEY`` environment variable so it
    does not have to be committed to source; the in-file placeholder is used
    only when the variable is unset.

    Args:
        max_workers: Number of transcripts processed concurrently. Defaults to
            the previously hard-coded 150; lower it if the API rate-limits.

    Returns:
        List of training-data dicts for the transcripts that succeeded (in
        completion order), or ``None`` if the API connection test fails.
    """
    
    deepseek_config = APIConfig(
        provider="deepseek",
        api_key=os.environ.get("DEEPSEEK_API_KEY", "sk-"),
        base_url="https://api.deepseek.com",
        model="deepseek-chat",
        max_tokens=1000,
        temperature=0.2,
        rate_limit_delay=0.05
    )
    
    generator = ImprovedTrainingDataGenerator(deepseek_config)
    
    print("Testing API connection...")
    if not generator.test_api_connection():
        print("❌ API connection failed. Please check your configuration.")
        return None
    
    print("✅ API connection successful!")

    transcripts_folder = "Extracted_Text_Transcript_DAIC"
    phq_csv_path = "train_split_Depression_AVEC2017.csv"
    checkpoint_dir = "checkpoints-v2"
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    phq_df = pd.read_csv(phq_csv_path)
    
    file_names = [f for f in os.listdir(transcripts_folder) if f.lower().endswith(".csv")]
    
    results = []
    # One task per transcript; process_single_transcript reports its own failures
    # and returns None, so only unexpected errors surface through future.result().
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_file = {
            executor.submit(process_single_transcript, generator, file_name, transcripts_folder, phq_df, checkpoint_dir): file_name
            for file_name in file_names
        }
        
        for future in concurrent.futures.as_completed(future_to_file):
            file_name = future_to_file[future]
            try:
                result = future.result()
                if result:
                    results.append(result)
            except Exception as e:
                logger.error(f"Error in parallel processing for {file_name}: {e}")
    
    print("\n🎯 All transcripts processed.")
    return results

if __name__ == "__main__":
    # Script entry point: run the batch and print a short report per transcript.
    training_datas = main_api()

    if not training_datas:
        # Covers both a failed API check (None) and a run with no successful transcripts.
        print("\nProcessing failed. Check logs for details.")
    else:
        for training_data in training_datas:
            print("\nSuccessfully generated simplified training data!")
            print(f"Total turns processed: {training_data['metadata']['total_turns']}")
            print(f"Total transition points: {len(training_data['speaker_transition_points'])}")
            print(f"Final depression classification: {training_data['metadata']['final_depression_label']}")

INFO:__main__:Analyzed strategy for transition at turn 175
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Analyzed strategy for transition at turn 177
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Analyzed strategy for transition at turn 179
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Analyzed strategy for transition at turn 181
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Analyzed strategy for transition at turn 183
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Analyzed strategy for transition at turn 185
INFO:httpx:HTTP Request: POST https://api.deepseek.com/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:Analyzed strategy for transition at turn 187
INFO:httpx:HTTP Request: POST https://api.deeps

✅ Processed participant 320
Final depression label: Depressed
Total transitions: 99

🎯 All transcripts processed.

Successfully generated simplified training data!
Total turns processed: 114
Total transition points: 56
Final depression classification: Not depressed

Successfully generated simplified training data!
Total turns processed: 67
Total transition points: 33
Final depression classification: Not depressed

Successfully generated simplified training data!
Total turns processed: 74
Total transition points: 36
Final depression classification: Depressed

Successfully generated simplified training data!
Total turns processed: 70
Total transition points: 34
Final depression classification: Not depressed

Successfully generated simplified training data!
Total turns processed: 73
Total transition points: 36
Final depression classification: Not depressed

Successfully generated simplified training data!
Total turns processed: 75
Total transition points: 37
Final depression classificatio