In [4]:
from flask import Flask, request, jsonify
from flask_cors import CORS
import os
import time
import json
import requests
import logging  
import pickle
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from collections import defaultdict

app = Flask(__name__)
# CORS(app, resources={
#     r"/generate": {
#         "origins": ["http://localhost:3000"],
#         "methods": ["POST", "OPTIONS"],
#         "allow_headers": ["Content-Type"]
#     }
# })

CORS(app, resources={r"/generate": {"origins": "*"}})


logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

@dataclass
class TherapeuticResponse:
    """Enhanced response structure for therapeutic context"""
    text: str
    timestamp: float
    error: bool = False
    processing_time: float = 0.0
    error_details: str = ""
    timeout: bool = False
    empathy_score: float = 0.0
    safety_checks: List[str] = None
    ethical_considerations: List[str] = None
    refinement_suggestions: List[str] = None
    crisis_flag: bool = False

class OllamaClient:
    """Robust Ollama client with configurable timeouts"""
    def __init__(self, model_name: str = "hf.co/TheDrummer/Gemmasutra-Mini-2B-v1-GGUF:Q3_K_L", base_url: str = "http://localhost:11434"):
        self.model_name = model_name
        self.base_url = base_url
        self.max_retries = 5
        self.request_timeout = 300
        self._verify_model()

    def _parse_json_safe(self, text: str):
        """Enhanced JSON parsing with fallback"""
        clean_text = text.strip()
        if not clean_text:
            return {"error": "Empty response"}

        try:
            return json.loads(clean_text)
        except json.JSONDecodeError:
            try:
                start = clean_text.find('{')
                end = clean_text.rfind('}') + 1
                return json.loads(clean_text[start:end])
            except:
                return {"error": f"Invalid JSON format: {clean_text[:200]}..."}
        except Exception as e:
            return {"error": str(e)}

    def _verify_model(self):
        """Model verification with status checks"""
        for attempt in range(self.max_retries):
            try:
                resp = requests.get(f"{self.base_url}/api/tags", timeout=10)
                if resp.status_code == 200:
                    data = self._parse_json_safe(resp.text)
                    models = [m['name'] for m in data.get('models', [])]
                    if any(self.model_name in m for m in models):
                        return
                    self._pull_model()
                    return
                logger.warning(f"Model check failed (status {resp.status_code})")
            except Exception as e:
                logger.warning(f"Model check attempt {attempt+1} failed: {e}")
                time.sleep(2 ** attempt)
        raise ConnectionError(f"Couldn't connect to Ollama after {self.max_retries} attempts")

    def _pull_model(self):
        """Model pulling with progress tracking"""
        try:
            resp = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": self.model_name},
                stream=True,
                timeout=600
            )
            for line in resp.iter_lines():
                if line:
                    try:
                        status = self._parse_json_safe(line).get('status', '')
                        logger.info(f"Pull progress: {status}")
                    except:
                        continue
        except Exception as e:
            logger.error(f"Model pull failed: {e}")
            raise

    def generate(self, prompt: str) -> Tuple[str, bool]:
        """Generation with configurable timeout and retries"""
        for attempt in range(self.max_retries):
            try:
                with ThreadPoolExecutor() as executor:
                    future = executor.submit(
                        requests.post,
                        f"{self.base_url}/api/generate",
                        json={
                            "model": self.model_name,
                            "prompt": prompt[:4000],
                            "stream": False,
                            "options": {"temperature": 0.5}
                        },
                        timeout=self.request_timeout
                    )
                    resp = future.result(timeout=self.request_timeout)
                    data = self._parse_json_safe(resp.text)
                    return data.get("response", ""), False
            except FutureTimeoutError:
                logger.warning(f"Generation timed out (attempt {attempt+1})")
                return f"Error: Timeout after {self.request_timeout}s", True
            except Exception as e:
                logger.warning(f"Attempt {attempt+1} failed: {e}")
                time.sleep(1)
        return f"Error: Failed after {self.max_retries} attempts", True

class BaseAgent:
    """Timeout-aware base agent"""
    def __init__(self, client: OllamaClient):
        self.client = client
        self.retry_count = 3
        self.max_wait = 300

    def safe_generate(self, prompt: str) -> TherapeuticResponse:
        """Generation with time budget tracking"""
        start_time = time.time()
        timeout_occurred = False

        if not isinstance(prompt, str) or len(prompt.strip()) == 0:
            return TherapeuticResponse(
                text="Error: Invalid input prompt",
                timestamp=start_time,
                error=True,
                error_details="Empty or non-string prompt",
                processing_time=0.0
            )

        for attempt in range(self.retry_count):
            try:
                with ThreadPoolExecutor() as executor:
                    future = executor.submit(self.client.generate, prompt)
                    text, error = future.result(timeout=self.max_wait)

                    return TherapeuticResponse(
                        text=text,
                        timestamp=start_time,
                        error=error,
                        processing_time=time.time() - start_time,
                        error_details=text if error else "",
                        timeout=timeout_occurred
                    )
            except FutureTimeoutError:
                logger.error(f"Generation timed out after {self.max_wait}s")
                timeout_occurred = True
            except Exception as e:
                error_msg = str(e)
                logger.error(f"Generation error: {e}")

        return TherapeuticResponse(
            text=f"Final error: {error_msg}" if 'error_msg' in locals() else "Unknown error",
            timestamp=start_time,
            error=True,
            error_details=error_msg if 'error_msg' in locals() else "",
            processing_time=time.time() - start_time,
            timeout=timeout_occurred
        )

# Initialize the Ollama client and the base agent
client = OllamaClient()
agent = BaseAgent(client)

def _build_cors_preflight_response():
    response = jsonify({"status": "preflight"})
    response.headers.add("Access-Control-Allow-Origin", "*")
    response.headers.add("Access-Control-Allow-Headers", "Content-Type")
    response.headers.add("Access-Control-Allow-Methods", "POST, OPTIONS")
    return response

def _corsify_actual_response(response):
    response.headers.add("Access-Control-Allow-Origin", "*")
    return response

@app.route('/generate', methods=['POST', 'OPTIONS'])
def generate():
    if request.method == 'OPTIONS':
        return _build_cors_preflight_response()
    elif request.method == 'POST':
        data = request.json
        prompt = data.get('prompt', '')
        response = agent.safe_generate(prompt)
        return _corsify_actual_response(jsonify(response.__dict__))

if __name__ == '__main__':
    app.run(port=5000, debug=True)

 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
2025-03-01 17:49:46,968 - INFO - [33mPress CTRL+C to quit[0m
2025-03-01 17:49:46,968 - INFO -  * Restarting with stat
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/homebrew/Caskroom/miniforge/base/envs/nunu24/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/opt/homebrew/Caskroom/miniforge/base/envs/nunu24/lib/python3.12/site-packages/traitlets/config/application.py", line 1074, in launch_instance
    app.initialize(argv)
  File "/opt/homebrew/Caskroom/miniforge/base/envs/nunu24/lib/python3.12/site-packages/traitlets/config/application.py", line 118, in inner
    return method(app, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Caskroom/miniforge/base/envs/nunu24/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 692, in initialize
    self.i

SystemExit: 1

In [6]:
from flask import Flask, request, jsonify
from flask_cors import CORS
import os
import time
import json
import requests
import logging  
from dataclasses import dataclass
from typing import List, Tuple
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError

app = Flask(__name__)
CORS(app, resources={
    r"/generate": {
        "origins": ["http://localhost:3000"],
        "methods": ["POST", "OPTIONS"],
        "allow_headers": ["Content-Type"],
        "supports_credentials": True
    }
})

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

@dataclass
class TherapeuticResponse:
    """Enhanced response structure for therapeutic context"""
    text: str
    timestamp: float
    error: bool = False
    processing_time: float = 0.0
    error_details: str = ""
    timeout: bool = False
    empathy_score: float = 0.0
    safety_checks: List[str] = None
    ethical_considerations: List[str] = None
    refinement_suggestions: List[str] = None
    crisis_flag: bool = False

class OllamaClient:
    """Production-ready Ollama client with enhanced error handling"""
    def __init__(self, model_name: str = "hf.co/TheDrummer/Gemmasutra-Mini-2B-v1-GGUF:Q3_K_L", 
                 base_url: str = "http://localhost:11434"):
        self.model_name = model_name
        self.base_url = base_url
        self.max_retries = 3
        self.request_timeout = 120
        self.session = requests.Session()
        self._verify_model()

    def _parse_json_safe(self, text: str):
        """Robust JSON parsing with multiple fallback strategies"""
        clean_text = text.strip()
        try:
            return json.loads(clean_text)
        except json.JSONDecodeError:
            for wrapper in ['{}', '[]']:
                try:
                    return json.loads(wrapper[0] + clean_text + wrapper[1])
                except:
                    continue
            return {"error": f"Invalid JSON format: {clean_text[:200]}..."}
        except Exception as e:
            return {"error": str(e)}

    def _verify_model(self):
        """Model verification with exponential backoff"""
        for attempt in range(self.max_retries):
            try:
                resp = self.session.get(f"{self.base_url}/api/tags", timeout=10)
                if resp.ok:
                    models = [m['name'] for m in self._parse_json_safe(resp.text).get('models', [])]
                    if any(self.model_name in m for m in models):
                        return
                    self._pull_model()
                    return
                logger.warning(f"Model check failed (status {resp.status_code})")
            except Exception as e:
                logger.warning(f"Model check attempt {attempt+1} failed: {e}")
                time.sleep(2 ** attempt)
        raise ConnectionError(f"Couldn't connect to Ollama after {self.max_retries} attempts")

    def _pull_model(self):
        """Model pulling with chunked streaming"""
        try:
            with self.session.post(
                f"{self.base_url}/api/pull",
                json={"name": self.model_name},
                stream=True,
                timeout=600
            ) as resp:
                for chunk in resp.iter_content(chunk_size=None):
                    if chunk:
                        status = self._parse_json_safe(chunk.decode()).get('status', '')
                        logger.info(f"Pull progress: {status}")
        except Exception as e:
            logger.error(f"Model pull failed: {e}")
            raise

    def generate(self, prompt: str) -> Tuple[str, bool]:
        """Production-grade generation with circuit breaker pattern"""
        sanitized_prompt = prompt[:4000].strip()
        for attempt in range(self.max_retries):
            try:
                resp = self.session.post(
                    f"{self.base_url}/api/generate",
                    json={
                        "model": self.model_name,
                        "prompt": sanitized_prompt,
                        "stream": False,
                        "options": {"temperature": 0.5}
                    },
                    timeout=self.request_timeout
                )
                if resp.status_code == 200:
                    data = self._parse_json_safe(resp.text)
                    return data.get("response", ""), False
                logger.warning(f"Generation failed (status {resp.status_code})")
            except Exception as e:
                logger.error(f"Attempt {attempt+1} failed: {e}")
                time.sleep(1)
        return f"Error: Service unavailable after {self.max_retries} attempts", True

class BaseAgent:
    """Production-ready agent with timeout management"""
    def __init__(self, client: OllamaClient):
        self.client = client
        self.max_wait = 90

    def safe_generate(self, prompt: str) -> TherapeuticResponse:
        """Generation with circuit breaker and time budget tracking"""
        start_time = time.time()
        response = TherapeuticResponse(
            text="",
            timestamp=start_time,
            error=True,
            processing_time=0.0
        )

        if not isinstance(prompt, str) or not prompt.strip():
            response.error_details = "Invalid input: Empty or non-string prompt"
            return response

        try:
            with ThreadPoolExecutor() as executor:
                future = executor.submit(self.client.generate, prompt)
                text, error = future.result(timeout=self.max_wait)
                
                response.text = text
                response.error = error
                response.processing_time = time.time() - start_time
                response.error_details = text if error else ""
                return response
        except FutureTimeoutError:
            response.error_details = f"Timeout after {self.max_wait}s"
            response.timeout = True
        except Exception as e:
            response.error_details = str(e)
        
        response.processing_time = time.time() - start_time
        return response

client = OllamaClient()
agent = BaseAgent(client)

@app.route('/generate', methods=['POST', 'OPTIONS'])
def generate():
    if request.method == 'OPTIONS':
        return _build_cors_preflight_response()
    
    data = request.get_json(silent=True) or {}
    prompt = data.get('prompt', '')
    
    if not prompt:
        return jsonify({
            "error": True,
            "text": "Empty prompt",
            "error_details": "No prompt provided in request"
        }), 400
    
    response = agent.safe_generate(prompt)
    return _corsify_actual_response(jsonify(response.__dict__))

def _build_cors_preflight_response():
    response = jsonify({"status": "preflight"})
    response.headers.add("Access-Control-Allow-Origin", "http://localhost:3000")
    response.headers.add("Access-Control-Allow-Headers", "Content-Type")
    response.headers.add("Access-Control-Allow-Methods", "POST, OPTIONS")
    return response

def _corsify_actual_response(response):
    response.headers.add("Access-Control-Allow-Origin", "http://localhost:3000")
    response.headers.add("Access-Control-Expose-Headers", "Content-Type")
    return response

if __name__ == '__main__':
    from waitress import serve
    logger.info("Starting production server on port 5000")
    serve(app, host="0.0.0.0", port=5000, threads=4)

2025-03-01 17:52:41,562 - INFO - Starting production server on port 5000


OSError: [Errno 48] Address already in use