In [None]:
import numpy as np
import re
import json
import os
import pandas as pd
import requests
import warnings
warnings.filterwarnings('ignore')

class PHIDetector:
    """Detect and mask protected health information (PHI) in free text.

    Detection first tries the Groq chat-completions API (when an API key is
    available and a probe request succeeds) and falls back to the built-in
    regular expressions otherwise.

    Every detected entity is a dict with the keys ``text``, ``start``,
    ``end`` (character offsets into the analysed text), ``label`` and
    ``source`` (``'groq_api'``, ``'groq_fallback'`` or ``'pattern_matching'``).
    """

    MODEL = "llama3-8b-8192"
    # Characters sent per API request, and how far consecutive chunks overlap
    # so that an entity sitting on a chunk boundary is seen whole at least once.
    API_CHUNK_CHARS = 1000
    API_CHUNK_OVERLAP = 100

    # Regular expressions used when the LLM is unavailable, keyed by label.
    _PATTERNS = {
        'PERSON_NAME': [r'\b[A-Z][a-z]+\s+[A-Z][a-z]+(?:\s+O\'[A-Z][a-z]+)?\b'],
        'DATE': [r'\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b'],
        'PHONE': [r'\+\d{3}-\d{2}-\d{3}-\d{4}', r'\b\d{3}-\d{3}-\d{4}\b'],
        'EMAIL': [r'\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b'],
        'ADDRESS': [r'\bApartment\s+\d+[A-Z]?,?\s+\d+\s+[A-Za-z\s]+,\s*[A-Za-z\s]+,\s*[A-Za-z]+\b'],
        'MEDICAL_ID': [r'\bMRN-\d+\b', r'\bVHI-\d+\b', r'\b\d{8}[A-Z]\b']
    }

    def __init__(self, api_key=None):
        """Store the API key (argument, else ``GROQ_API_KEY``) and probe the API.

        Args:
            api_key: Groq API key. When falsy, the ``GROQ_API_KEY`` environment
                variable is used instead; when that is missing too, only
                pattern matching is available.
        """
        self.api_key = api_key or os.getenv('GROQ_API_KEY')
        self.endpoint = "https://api.groq.com/openai/v1/chat/completions"
        self.use_api = self._test_connection() if self.api_key else False

    def _headers(self):
        """Return the HTTP headers shared by every Groq request."""
        return {'Authorization': f'Bearer {self.api_key}', 'Content-Type': 'application/json'}

    def _test_connection(self):
        """Return True when a minimal chat request to the endpoint answers 200."""
        try:
            response = requests.post(
                self.endpoint,
                headers=self._headers(),
                json={"model": self.MODEL, "messages": [{"role": "user", "content": "test"}], "max_tokens": 5},
                timeout=30
            )
        except requests.RequestException:
            # Network problems simply mean "no API"; anything else (for example
            # KeyboardInterrupt) is no longer swallowed as a bare except did.
            return False
        return response.status_code == 200

    def detect_phi(self, text):
        """Return the PHI entities found in ``text``.

        LLM results are used only when the API is available and at least one
        entity came from a cleanly parsed reply (``source == 'groq_api'``);
        otherwise the regular-expression detector is used.
        """
        print(f"API Available: {self.use_api}")
        if not self.use_api:
            print("NO API: Using pattern matching only")
            return self._pattern_detection(text)

        print("ATTEMPTING GROQ LLM DETECTION...")
        entities = self._api_detection(text)
        if entities and any(e.get('source') == 'groq_api' for e in entities):
            print("SUCCESS: Using Groq LLM results")
            return entities

        print("FALLBACK: Groq failed, using pattern matching")
        return self._pattern_detection(text)

    def _api_detection(self, text):
        """Ask the LLM for PHI in ``text`` and return the located entities.

        The text is sent in overlapping chunks of ``API_CHUNK_CHARS``
        characters, so input longer than one chunk is analysed completely
        rather than being cut off after the first chunk.  Text that fits in a
        single chunk results in exactly one request.  If any request fails, an
        empty list is returned so the caller falls back to pattern matching
        for the whole text.
        """
        step = max(1, self.API_CHUNK_CHARS - self.API_CHUNK_OVERLAP)
        entities = []
        offset = 0
        while True:
            chunk = text[offset:offset + self.API_CHUNK_CHARS]
            content = self._request_chunk(chunk)
            if content is None:
                print("API failed - returning empty list")
                return []

            # Offsets are resolved against the full text, so no per-chunk
            # shifting is needed; duplicates from the overlap are removed
            # later by _validate_entities.
            found = self._parse_api_response(content, text)
            print(f"Entities from API: {len(found)}")
            entities.extend(found)

            if offset + self.API_CHUNK_CHARS >= len(text):
                break
            offset += step
        return entities

    def _request_chunk(self, chunk):
        """Send one chunk to Groq; return the reply text, or None on any failure."""
        # The example uses double quotes: JSON does not allow single-quoted
        # strings, and models tend to imitate the example they are shown.
        system_prompt = 'Find all PHI in medical text. Return JSON: [{"text": "match", "category": "PERSON|DATE|PHONE|EMAIL|ADDRESS|ID|LOCATION"}]'

        try:
            print("Making API call to Groq...")
            response = requests.post(
                self.endpoint,
                headers=self._headers(),
                json={
                    "model": self.MODEL,
                    "messages": [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": chunk}
                    ],
                    "max_tokens": 1024,
                    "temperature": 0
                },
                timeout=60
            )
            print(f"API Response Status: {response.status_code}")

            if response.status_code != 200:
                print(f"API Error: {response.text}")
                return None

            result = response.json()['choices'][0]['message']['content']
            if not isinstance(result, str):
                raise TypeError("unexpected message content type")
            print(f"API Response Length: {len(result)} characters")
            print(f"API Response Preview: {result[:200]}...")
            return result
        except Exception as e:
            # Best-effort boundary: any failure is reported and turned into
            # "no result" so detection can continue with pattern matching.
            print(f"API Exception: {e}")
            return None

    @staticmethod
    def _strip_code_fence(cleaned):
        """Return the content of the first Markdown code fence, if there is one.

        A fence that is never closed extends to the end of the string.
        """
        for fence in ('```json', '```'):
            opening = cleaned.find(fence)
            if opening != -1:
                start = opening + len(fence)
                end = cleaned.find('```', start)
                if end == -1:
                    end = len(cleaned)
                return cleaned[start:end].strip()
        return cleaned

    @staticmethod
    def _locate(text, phrase, label, source):
        """Return one entity for every non-overlapping occurrence of ``phrase``."""
        spans = []
        if not phrase:
            return spans
        start = text.find(phrase)
        while start != -1:
            end = start + len(phrase)
            spans.append({
                'text': phrase,
                'start': start,
                'end': end,
                'label': label,
                'source': source
            })
            start = text.find(phrase, end)
        return spans

    def _parse_api_response(self, response, text):
        """Turn the model's reply into entities located in ``text``.

        Args:
            response: Raw reply text, expected to contain a JSON array of
                ``{"text": ..., "category": ...}`` objects, optionally wrapped
                in explanatory prose or a Markdown code fence.
            text: The text the entity offsets refer to.

        Returns:
            A list of entities.  Items whose ``text`` is not a non-empty
            string, or does not occur verbatim in ``text``, are ignored.
            Every occurrence of a reported string is returned.  When the reply
            is not valid JSON a regex-based fallback is used and the entities
            carry ``source == 'groq_fallback'`` and ``label == 'PHI'``.
        """
        print("Parsing Groq API response...")
        cleaned = self._strip_code_fence(response.strip())

        json_start = cleaned.find('[')
        if json_start == -1:
            print("No JSON array found in response")
            return []
        # With no closing bracket (e.g. a truncated reply) the slice below is
        # empty, json.loads fails, and the fallback parser takes over.
        json_end = cleaned.rfind(']') + 1

        try:
            json_str = cleaned[json_start:json_end]
            print(f"Extracted JSON: {json_str[:100]}...")
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            print(f"JSON parsing failed: {e}")
            print("Raw response text:")
            print(repr(response[:200]))
            print("Attempting fallback parsing...")
            return self._fallback_parse(response, text)

        print(f"Successfully parsed JSON with {len(data)} items")
        entities = []
        for item in data:
            if not isinstance(item, dict):
                continue
            phrase = item.get('text')
            if not isinstance(phrase, str) or not phrase:
                continue
            label = item.get('category')
            if not isinstance(label, str) or not label.strip():
                label = 'PHI'
            entities.extend(self._locate(text, phrase, label, 'groq_api'))

        print(f"Valid entities from Groq: {len(entities)}")
        return entities

    def _fallback_parse(self, response, text):
        """Extract ``text`` values from a reply that is not valid JSON.

        Accepts both double- and single-quoted keys/values and any number of
        items per line.  Only strings longer than two characters are used.
        """
        entities = []
        seen = set()
        for match in re.finditer(r'''["']text["']\s*:\s*(["'])(.+?)\1''', response):
            phrase = match.group(2)
            if len(phrase) > 2 and phrase not in seen:
                seen.add(phrase)
                entities.extend(self._locate(text, phrase, 'PHI', 'groq_fallback'))

        print(f"Fallback entities: {len(entities)}")
        return entities

    def _pattern_detection(self, text):
        """Find PHI with the built-in regular expressions only."""
        print("USING PATTERN MATCHING (No LLM)")
        entities = []

        for category, pattern_list in self._PATTERNS.items():
            for pattern in pattern_list:
                for match in re.finditer(pattern, text):
                    if self._is_medical_preserve(match.group()):
                        continue
                    entities.append({
                        'text': match.group(),
                        'start': match.start(),
                        'end': match.end(),
                        'label': category,
                        'source': 'pattern_matching'
                    })

        print(f"Pattern matching found: {len(entities)} entities")
        return entities

    def _is_medical_preserve(self, text):
        """Return True when ``text`` is clinical content that must not be masked.

        Covers ages below 90 written as ``NN-year...``, common measurements
        (mg, bpm, mmHg, percentages, BMI, ml/min) and a short list of medical
        terms matched as substrings.
        """
        text_lower = text.lower()

        age_match = re.match(r'^(\d{1,3})-year', text_lower)
        if age_match and int(age_match.group(1)) < 90:
            return True

        medical_patterns = [
            r'^\d+\s*mg\b', r'^\d+\s*bpm\b', r'^\d+/\d+\s*mmhg\b',
            # No \b after '%': a boundary there would require a word
            # character to follow, so a bare "7.2%" could never match.
            r'^\d+\.?\d*%', r'^bmi\s+\d+\b', r'^\d+\s*ml/min\b'
        ]

        if any(re.match(p, text_lower) for p in medical_patterns):
            return True

        medical_terms = {
            'metformin', 'atorvastatin', 'vitamin', 'diabetes', 'sarcoidosis',
            'cardiology', 'pulmonology', 'ecg', 'ct scan'
        }

        return any(term in text_lower for term in medical_terms)

    def deidentify(self, text):
        """Detect PHI in ``text`` and replace it with numbered category tags.

        Returns:
            A dict with the keys ``original``, ``deidentified``,
            ``entities_found`` and ``entities``.
        """
        entities = self.detect_phi(text)
        entities = self._validate_entities(entities, text)
        replacements = self._generate_tags(entities)
        result_text = self._apply_replacements(text, entities, replacements)

        return {
            'original': text,
            'deidentified': result_text,
            'entities_found': len(entities),
            'entities': entities
        }

    def _validate_entities(self, entities, text):
        """Return de-duplicated, in-bounds, non-overlapping entities sorted by start.

        Entities whose stripped text is shorter than two characters or whose
        span lies outside ``text`` are dropped.  Overlapping spans are merged
        into the earliest (and, at equal start, longest) one; when a later
        span reaches further, the kept span is widened so that no character
        belonging to either entity stays unmasked.  Input dicts are not
        modified.
        """
        candidates = []
        seen = set()
        for entity in entities:
            start, end = entity['start'], entity['end']
            key = (entity['text'], start)
            if key in seen or len(entity['text'].strip()) <= 1:
                continue
            if not 0 <= start < end <= len(text):
                continue
            seen.add(key)
            candidates.append(entity)

        # Earliest start first; at the same start the longest span first.
        candidates.sort(key=lambda e: (e['start'], e['start'] - e['end']))

        merged = []
        for entity in candidates:
            if merged and entity['start'] < merged[-1]['end']:
                previous = merged[-1]
                if entity['end'] > previous['end']:
                    previous['end'] = entity['end']
                    previous['text'] = text[previous['start']:previous['end']]
                continue
            merged.append(dict(entity))
        return merged

    def _categorize(self, text, label):
        """Pick the tag category for an entity from its text, else from ``label``."""
        lowered = text.lower()
        if '@' in text:
            return 'EMAIL'
        if re.match(r'^\+?\d{3}-\d{2}-\d{3}-\d{4}$', text) or re.match(r'^\d{3}-\d{3}-\d{4}$', text):
            return 'PHONE'
        if 'mrn' in lowered or 'vhi' in lowered:
            return 'MEDICAL_ID'
        if re.match(r'^[A-Z][a-z]+\s+[A-Z]', text):
            return 'PERSON'
        # All twelve months, matching the DATE detection pattern.
        if re.search(r'\b(?:january|february|march|april|may|june|july|august|september|october|november|december)\b', lowered):
            return 'DATE'
        if 'apartment' in lowered or 'court' in lowered:
            return 'ADDRESS'
        return label if label != 'PHI' else 'ID'

    def _generate_tags(self, entities):
        """Map each distinct entity text to a tag such as ``[PERSON_1]``.

        Repeated mentions of the same text share one tag, and counters only
        advance for new texts, so numbering is contiguous per category.
        """
        replacements = {}
        counters = {}

        for entity in entities:
            text = entity['text']
            if text in replacements:
                continue
            category = self._categorize(text, entity['label'])
            counters[category] = counters.get(category, 0) + 1
            replacements[text] = f"[{category}_{counters[category]}]"

        return replacements

    def _apply_replacements(self, text, entities, replacements):
        """Return ``text`` with every entity span replaced by its tag.

        The result is assembled left to right from the original text, so
        offsets never shift.  A span that overlaps an already replaced one
        gets no tag of its own, but its characters are still removed.  Spans
        outside ``text`` are ignored.
        """
        pieces = []
        cursor = 0
        ordered = sorted(entities, key=lambda e: (e['start'], e['start'] - e['end']))
        for entity in ordered:
            start, end = entity['start'], entity['end']
            if not 0 <= start < end <= len(text):
                continue
            if start < cursor:
                cursor = max(cursor, end)
                continue
            pieces.append(text[cursor:start])
            pieces.append(replacements[entity['text']])
            cursor = end
        pieces.append(text[cursor:])
        return ''.join(pieces)


def run_demo(groq_api_key=None):
    """De-identify a built-in sample note and print where each entity came from.

    Args:
        groq_api_key: Optional Groq API key, passed straight to ``PHIDetector``.

    Prints a banner, the number of entities per detection source (LLM versus
    pattern matching) with up to five examples each, and finally the original
    and de-identified text.  Returns nothing.
    """
    sample_note = """Patient Sarah Johnson, a 45-year-old female, was admitted on March 15, 2024. She resides at 123 Oak Street, Boston, MA 02101. Contact: sarah.j@email.com or call 617-555-1234. Medical record MRN-789123 shows history of diabetes diagnosed in 2019. Her physician Dr. Michael Chen prescribed Metformin 500mg twice daily. Vital signs: BP 140/90 mmHg, pulse 78 bpm, BMI 28.5. Lab results indicate HbA1c of 7.2% and cholesterol 195 mg/dL. Insurance ID: BC-456789. Emergency contact is her husband James Johnson at work number 617-555-9876. Patient reported chest discomfort and shortness of breath during exercise. ECG showed normal sinus rhythm. Recommended cardiology follow-up in 2 weeks and dietary consultation."""
    rule = "=" * 80

    def list_examples(group):
        # Show at most five entities of a group, one per line.
        for entity in group[:5]:
            print(f"  - '{entity['text']}' -> {entity['label']}")

    print(rule)
    print("PHI DETECTION METHOD ANALYSIS")
    print(rule)

    # The detector decides on its own whether the key is usable.
    outcome = PHIDetector(groq_api_key).deidentify(sample_note)
    found = outcome['entities']

    # Split the entities by the detector that produced them.
    from_llm = [item for item in found if item.get('source', '').startswith('groq')]
    from_patterns = [item for item in found if item.get('source') == 'pattern_matching']

    print("\nDETECTION METHOD BREAKDOWN:")
    print(f"LLM (Groq) Entities: {len(from_llm)}")
    print(f"Pattern Entities: {len(from_patterns)}")
    print(f"Total Entities: {len(found)}")

    if not from_llm:
        print("\nLLM CONTRIBUTION: 0% (API failed or not used)")
    else:
        llm_share = (len(from_llm) / len(found)) * 100
        print(f"\nLLM CONTRIBUTION: {llm_share:.1f}%")
        print("LLM-detected entities:")
        list_examples(from_llm)

    if from_patterns:
        pattern_share = (len(from_patterns) / len(found)) * 100
        print(f"\nPATTERN CONTRIBUTION: {pattern_share:.1f}%")
        print("Pattern-detected entities:")
        list_examples(from_patterns)

    print(rule)
    print("ORIGINAL TEXT:")
    print(outcome['original'])
    print("\nDE-IDENTIFIED TEXT:")
    print(outcome['deidentified'])
    print(rule)

if __name__ == "__main__":
    # Read the key from the environment rather than hard-coding it: a literal
    # placeholder is truthy, so it would override the GROQ_API_KEY fallback in
    # PHIDetector and trigger a pointless failing probe request. Keeping the
    # key out of the source also avoids committing a secret by accident.
    # When the variable is unset the demo runs with pattern matching only.
    GROQ_API_KEY = os.getenv("GROQ_API_KEY")

    # Pass the API key (or None) to the demo function
    run_demo(GROQ_API_KEY)