In [None]:
# Cell 1: Core Imports and Base Classes
import os
import time
import json
import requests
import logging
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from collections import defaultdict

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

@dataclass
class TherapeuticResponse:
    """Enhanced response structure for therapeutic context"""
    text: str
    timestamp: float
    error: bool = False
    processing_time: float = 0.0
    error_details: str = ""
    timeout: bool = False
    empathy_score: float = 0.0
    safety_checks: List[str] = None
    ethical_considerations: List[str] = None
    refinement_suggestions: List[str] = None
    crisis_flag: bool = False

class OllamaClient:
    """Robust Ollama client with configurable timeouts"""
    def __init__(self, model_name: str = "hf.co/TheDrummer/Gemmasutra-Mini-2B-v1-GGUF:Q3_K_L", base_url: str = "http://localhost:11434"):
        self.model_name = model_name
        self.base_url = base_url
        self.max_retries = 5
        self.request_timeout = 300
        self._verify_model()

    def _parse_json_safe(self, text: str):
        """Enhanced JSON parsing with fallback"""
        clean_text = text.strip()
        if not clean_text:
            return {"error": "Empty response"}
            
        try:
            return json.loads(clean_text)
        except json.JSONDecodeError:
            try:
                start = clean_text.find('{')
                end = clean_text.rfind('}') + 1
                return json.loads(clean_text[start:end])
            except:
                return {"error": f"Invalid JSON format: {clean_text[:200]}..."}
        except Exception as e:
            return {"error": str(e)}

    def _verify_model(self):
        """Model verification with status checks"""
        for attempt in range(self.max_retries):
            try:
                resp = requests.get(f"{self.base_url}/api/tags", timeout=10)
                if resp.status_code == 200:
                    data = self._parse_json_safe(resp.text)
                    models = [m['name'] for m in data.get('models', [])]
                    if any(self.model_name in m for m in models):
                        return
                    self._pull_model()
                    return
                logger.warning(f"Model check failed (status {resp.status_code})")
            except Exception as e:
                logger.warning(f"Model check attempt {attempt+1} failed: {e}")
                time.sleep(2 ** attempt)
        raise ConnectionError(f"Couldn't connect to Ollama after {self.max_retries} attempts")

    def _pull_model(self):
        """Model pulling with progress tracking"""
        try:
            resp = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": self.model_name},
                stream=True,
                timeout=600
            )
            for line in resp.iter_lines():
                if line:
                    try:
                        status = self._parse_json_safe(line).get('status', '')
                        logger.info(f"Pull progress: {status}")
                    except:
                        continue
        except Exception as e:
            logger.error(f"Model pull failed: {e}")
            raise

    def generate(self, prompt: str) -> Tuple[str, bool]:
        """Generation with configurable timeout and retries"""
        for attempt in range(self.max_retries):
            try:
                with ThreadPoolExecutor() as executor:
                    future = executor.submit(
                        requests.post,
                        f"{self.base_url}/api/generate",
                        json={
                            "model": self.model_name,
                            "prompt": prompt[:4000],
                            "stream": False,
                            "options": {"temperature": 0.5}
                        },
                        timeout=self.request_timeout
                    )
                    resp = future.result(timeout=self.request_timeout)
                    data = self._parse_json_safe(resp.text)
                    return data.get("response", ""), False
            except FutureTimeoutError:
                logger.warning(f"Generation timed out (attempt {attempt+1})")
                return f"Error: Timeout after {self.request_timeout}s", True
            except Exception as e:
                logger.warning(f"Attempt {attempt+1} failed: {e}")
                time.sleep(1)
        return f"Error: Failed after {self.max_retries} attempts", True

# Cell 2: Base Agent Framework
class BaseAgent:
    """Timeout-aware base agent"""
    def __init__(self, client: OllamaClient):
        self.client = client
        self.retry_count = 3
        self.max_wait = 300
        
    def safe_generate(self, prompt: str) -> TherapeuticResponse:
        """Generation with time budget tracking"""
        start_time = time.time()
        error_state = False
        timeout_occurred = False
        
        if not isinstance(prompt, str) or len(prompt.strip()) == 0:
            return TherapeuticResponse(
                text="Error: Invalid input prompt",
                timestamp=start_time,
                error=True,
                error_details="Empty or non-string prompt",
                processing_time=0.0
            )
            
        for attempt in range(self.retry_count):
            try:
                with ThreadPoolExecutor() as executor:
                    future = executor.submit(self.client.generate, prompt)
                    text, error = future.result(timeout=self.max_wait)
                    
                    return TherapeuticResponse(
                        text=text,
                        timestamp=start_time,
                        error=error,
                        processing_time=time.time() - start_time,
                        error_details=text if error else "",
                        timeout=timeout_occurred
                    )
            except FutureTimeoutError:
                logger.error(f"Generation timed out after {self.max_wait}s")
                timeout_occurred = True
            except Exception as e:
                error_msg = str(e)
                logger.error(f"Generation error: {e}")
                
        return TherapeuticResponse(
            text=f"Final error: {error_msg}" if 'error_msg' in locals() else "Unknown error",
            timestamp=start_time,
            error=True,
            error_details=error_msg if 'error_msg' in locals() else "",
            processing_time=time.time() - start_time,
            timeout=timeout_occurred
        )

# Cell 3: Specialized Therapy Agents
class PatientContextAnalyzer(BaseAgent):
    """Analyzes patient input and conversation history"""
    def analyze_context(self, patient_input: str, history: List[str]) -> Tuple[Dict, float]:
        safe_input = patient_input[:2000] if patient_input else "Patient remains silent"
        history_context = "\n".join(history[-3:])[:3000] if history else "No history available"
        
        prompt = f"""Analyze therapeutic context:
        Patient Statement: {safe_input}
        Conversation History: {history_context}
        Identify:
        - Primary emotions expressed
        - Potential underlying issues
        - Immediate needs
        - Therapeutic approach suggestions
        - Crisis indicators
        
        Output JSON with:
        - emotions: List[str]
        - key_themes: List[str]
        - suggested_approaches: List[str]
        - crisis_warnings: List[str]
        - context_summary: str"""
        
        response = self.safe_generate(prompt)
        return self.client._parse_json_safe(response.text), response.processing_time

class TherapeuticResponseGenerator(BaseAgent):
    """Generates and evaluates therapist responses"""
    def generate_response(self, context: Dict, history: List[str]) -> TherapeuticResponse:
        prompt = f"""Generate therapeutic response:
        Context: {json.dumps(context)[:3000]}
        History: {" | ".join(history[-3:])[:2000] if history else "First session"}
        
        Guidelines:
        - Show empathy and validation
        - Use open-ended questions
        - Avoid medical advice
        - Maintain professional boundaries
        - Focus on patient's feelings
        - Use CBT techniques when appropriate
        
        Response:"""
        
        response = self.safe_generate(prompt)
        return self._enhance_response(response)

    def _enhance_response(self, raw_response: TherapeuticResponse) -> TherapeuticResponse:
        """Add therapeutic metrics to response"""
        empathy_score = min(1.0, raw_response.text.count("understand") * 0.1)
        return TherapeuticResponse(
            **vars(raw_response),
            empathy_score=empathy_score,
            safety_checks=["Basic validation passed"],
            refinement_suggestions=[]
        )

class ClinicalSafetyChecker(BaseAgent):
    """Ensures responses meet safety and ethical standards"""
    def evaluate_response(self, response: str, context: Dict) -> Dict:
        prompt = f"""Evaluate therapeutic response:
        Response: {response[:3000]}
        Context: {json.dumps(context)[:2000]}
        
        Check for:
        - Boundary violations
        - Medical advice
        - Crisis handling
        - Cultural sensitivity
        - Ethical compliance
        
        Output JSON with:
        - safety_score: float (0-1)
        - concerns: List[str]
        - improvement_suggestions: List[str]
        - crisis_handling: str
        - ethical_violations: List[str]"""
        
        response = self.safe_generate(prompt)
        return self.client._parse_json_safe(response.text)

# Cell 4: Therapy Response System
class TherapeuticResponseSystem:
    """End-to-end therapeutic response system"""
    def __init__(self):
        self.client = OllamaClient(model_name="llama2:13b")
        self.agents = {
            'context': PatientContextAnalyzer(self.client),
            'generator': TherapeuticResponseGenerator(self.client),
            'safety': ClinicalSafetyChecker(self.client)
        }
        self.conversation_history = []
        self.metrics = defaultdict(lambda: {'count': 0, 'errors': 0, 'timeouts': 0})
        
    def process_session(self, patient_input: str) -> Dict:
        """Main processing pipeline with safety checks"""
        result = {
            'response': '',
            'context_analysis': {},
            'safety_check': {},
            'errors': [],
            'warnings': [],
            'timings': {},
            'crisis_alert': False
        }
        
        try:
            processed_input = str(patient_input)[:5000] if patient_input else ""
            
            # Context analysis
            ctx_start = time.time()
            context, ctx_time = self.agents['context'].analyze_context(processed_input, self.conversation_history)
            result['context_analysis'] = context
            result['timings']['context'] = ctx_time
            
            # Crisis handling
            crisis_warnings = context.get('crisis_warnings', [])
            if crisis_warnings:
                result['crisis_alert'] = True
                result['response'] = self._handle_crisis_situation(crisis_warnings)
                return self._compile_result(result)
                
            # Generate response
            gen_start = time.time()
            response = self.agents['generator'].generate_response(context, self.conversation_history)
            result['timings']['generation'] = time.time() - gen_start
            
            if response.error:
                result['errors'].append("Response generation failed")
            else:
                result['response'] = response.text
                
                # Safety check
                safety_start = time.time()
                safety_check = self.agents['safety'].evaluate_response(response.text, context)
                result['safety_check'] = safety_check
                result['timings']['safety'] = time.time() - safety_start
                
                # Refinement
                if safety_check.get('safety_score', 0) < 0.7:
                    refined = self._refine_response(response.text, safety_check)
                    result['response'] = refined
                    result['timings']['refinement'] = time.time() - safety_start
                
            # Update history
            self._update_history(processed_input, result['response'])
            
        except Exception as e:
            result['errors'].append(f"Processing failed: {str(e)}")
        
        return self._compile_result(result)
        
    def _handle_crisis_situation(self, crisis_warnings: List[str]) -> str:
        """Standard crisis response protocol"""
        crisis_template = """I hear that you're experiencing {issues}. That sounds incredibly difficult. \
        Please know you're not alone. I strongly recommend contacting {resource} immediately at {phone}. \
        Would you like me to help you connect with support?"""
        
        resources = {
            'suicidal': ('the National Suicide Prevention Lifeline', '988'),
            'abuse': ('the National Domestic Violence Hotline', '1-800-799-7233'),
            'general': ('Crisis Text Line', 'Text HOME to 741741')
        }
        
        crisis_type = next((ct for ct in resources if ct in str(crisis_warnings)), 'general')
        resource, phone = resources[crisis_type]
        
        return crisis_template.format(
            issues=", ".join(crisis_warnings),
            resource=resource,
            phone=phone
        )
        
    def _refine_response(self, response: str, safety_check: Dict) -> str:
        """Improve response based on safety check"""
        prompt = f"""Refine therapist response:
        Original Response: {response[:2000]}
        Safety Concerns: {json.dumps(safety_check['concerns'])[:1000]}
        
        Create improved response that:
        - Maintains therapeutic boundaries
        - Removes any medical advice
        - Enhances empathy
        - Uses open-ended questions
        
        Revised Response:"""
        
        refined = self.agents['generator'].safe_generate(prompt)
        return refined.text if not refined.error else response
        
    def _update_history(self, patient_input: str, therapist_response: str):
        """Maintain conversation history"""
        self.conversation_history.extend([
            f"Patient: {patient_input[:500]}",
            f"Therapist: {therapist_response[:500]}"
        ])
        self.conversation_history = self.conversation_history[-12:]
        
    def _compile_result(self, raw_data: dict) -> dict:
        """Structure final output"""
        return {
            'response': raw_data.get('response', ''),
            'context_analysis': raw_data.get('context_analysis', {}),
            'safety_check': raw_data.get('safety_check', {}),
            'crisis_alert': raw_data.get('crisis_alert', False),
            'errors': raw_data.get('errors', []),
            'warnings': raw_data.get('warnings', []),
            'timings': raw_data.get('timings', {}),
            'conversation_history': self.conversation_history[-4:]
        }

# Cell 5: Test Scenarios
def test_empty_input():
    """Test handling of empty patient input"""
    system = TherapeuticResponseSystem()
    result = system.process_session("")
    print("\nTest 1 - Empty Input:")
    print(f"Response: {result['response'][:150]}")
    print(f"Errors: {result['errors']}")
    print(f"Crisis Alert: {result['crisis_alert']}")

def test_crisis_detection():
    """Test suicide risk detection"""
    system = TherapeuticResponseSystem()
    result = system.process_session("I can't take it anymore, I have a bottle of pills ready")
    print("\nTest 2 - Crisis Detection:")
    print(f"Response: {result['response'][:250]}")
    print(f"Crisis Alert: {result['crisis_alert']}")
    print(f"Safety Concerns: {result['safety_check'].get('concerns', [])[:2]}")

def test_invalid_input_type():
    """Test non-string input handling"""
    system = TherapeuticResponseSystem()
    result = system.process_session(12345)
    print("\nTest 3 - Invalid Input Type:")
    print(f"Response: {result['response'][:150]}")
    print(f"Errors: {result['errors']}")

if __name__ == "__main__":
    test_empty_input()
    test_crisis_detection()
    test_invalid_input_type()

2025-02-19 16:44:43,102 - INFO - Pull progress: pulling manifest
2025-02-19 16:44:43,563 - INFO - Pull progress: pulling 2609048d349e
2025-02-19 16:44:43,794 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:43,855 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:43,914 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:43,976 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:44,035 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:44,095 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:44,154 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:44,215 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:44,274 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:44,334 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:44,394 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:44,455 - INFO - Pull progress: pulling 8c17c2ebb0ea
2025-02-19 16:44:44,515 - INFO - Pull 

In [5]:
import logging
# from your_module import OllamaClient  # Replace your_module

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def load_and_verify_model(model_name: str, base_url: str):
    try:
        client = OllamaClient(model_name=model_name, base_url=base_url)  # This will load/verify
        logger.info(f"Model {model_name} loaded successfully.")
        return True # Indicate success
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False # Indicate failure


if __name__ == "__main__":
    model_name = "hf.co/TheDrummer/Gemmasutra-Mini-2B-v1-GGUF:Q3_K_L"
    base_url = "http://localhost:11434"
    success = load_and_verify_model(model_name, base_url)

    if success:
        print("Model loaded. You can now run your main script.")
    else:
        print("Model loading failed.")

2025-02-19 16:42:56,931 - INFO - Model hf.co/TheDrummer/Gemmasutra-Mini-2B-v1-GGUF:Q3_K_L loaded successfully.


Model loaded. You can now run your main script.
