In [None]:
pip install torch



In [None]:
import torch
import torch.nn as nn
import networkx as nx
import numpy as np
from datetime import datetime
from typing import Dict, List, Tuple, Optional
import json

class MedicalKnowledgeCell(nn.Module):
    """Specialized neural memory cell for medical knowledge.

    Encodes an input vector, optionally integrates a context via multi-head
    attention, and predicts:
      * a 3-way evidence-level distribution (index 0=High, 1=Medium, 2=Low)
      * a source-reliability score squashed into (0, 1) by a sigmoid.
    """

    def __init__(self, input_size: int, hidden_size: int, num_heads: int = 4):
        """
        Args:
            input_size: Size of the last dimension of ``input_data``.
            hidden_size: Width of the encoded representation. Must be
                divisible by ``num_heads`` (required by
                ``nn.MultiheadAttention``).
            num_heads: Number of attention heads used for context
                integration. Defaults to 4, the previously hard-coded value.
        """
        super().__init__()
        self.hidden_size = hidden_size

        # Knowledge processing layers
        self.knowledge_encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )

        # Evidence assessment network -> logits for High / Medium / Low
        self.evidence_network = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, 3)  # Evidence levels: High, Medium, Low
        )

        # Source reliability assessment -> single score in (0, 1)
        self.source_network = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, 1),
            nn.Sigmoid()
        )

        # Context integration (sequence-first layout, i.e. batch_first=False)
        self.context_attention = nn.MultiheadAttention(hidden_size, num_heads=num_heads)

    def forward(self, input_data: torch.Tensor, context: Optional[torch.Tensor] = None) -> Dict:
        """Encode ``input_data`` and score its evidence level and reliability.

        Args:
            input_data: Tensor whose last dimension is ``input_size``. In this
                file it is called with a 1-D vector.
            context: Optional tensor that serves as attention key/value.
                NOTE(review): after ``unsqueeze(0)`` it must be shape-compatible
                with the encoded input, so it presumably has the same shape as
                the encoded input (last dim ``hidden_size``). Confirm with callers.

        Returns:
            Dict with ``'encoded_knowledge'`` (last dim ``hidden_size``),
            ``'evidence_scores'`` (softmax over 3 levels, last dim 3) and
            ``'reliability'`` (sigmoid output, last dim 1).
        """
        # Encode input knowledge
        encoded = self.knowledge_encoder(input_data)

        # Process context if available. Adding a leading length-1 axis makes
        # the encoded vector a single-element query sequence for attention.
        if context is not None:
            encoded, _ = self.context_attention(
                encoded.unsqueeze(0),
                context.unsqueeze(0),
                context.unsqueeze(0)
            )
            encoded = encoded.squeeze(0)

        # Assess evidence level (probabilities over High / Medium / Low)
        evidence_scores = torch.softmax(self.evidence_network(encoded), dim=-1)

        # Calculate source reliability
        reliability = self.source_network(encoded)

        return {
            'encoded_knowledge': encoded,
            'evidence_scores': evidence_scores,
            'reliability': reliability,
        }

class HealthcareKnowledgeAssistant:
    """Proof-of-concept store for scored medical knowledge snippets.

    Each snippet is scored by a ``MedicalKnowledgeCell`` (evidence level and
    source reliability). It is stored as a node in a ``networkx.DiGraph`` and
    indexed by category in ``self.categories`` for querying.

    NOTE: text encoding and query relevance are random placeholders, so the
    scores produced here are not meaningful yet.
    """

    def __init__(self, input_size: int = 256, hidden_size: int = 512):
        self.knowledge_graph = nx.DiGraph()
        self.memory_cell = MedicalKnowledgeCell(input_size, hidden_size)
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Category index: category name -> {knowledge_id: summary dict}
        self.categories = {
            'diagnosis': {},
            'treatment': {},
            'medication': {},
            'contraindication': {},
            'research': {}
        }

        # Evidence level thresholds (not read by any method in this class yet)
        self.evidence_thresholds = {
            'high': 0.8,
            'medium': 0.6,
            'low': 0.4
        }

    @staticmethod
    def _evidence_levels(scores) -> Dict[str, float]:
        """Map a 3-element evidence score vector to named plain Python floats.

        Index order follows the memory cell's evidence head: 0=high,
        1=medium, 2=low. Converting with ``float`` keeps every public result
        JSON-serializable (NumPy ``float32`` scalars are not).
        """
        return {
            'high': float(scores[0]),
            'medium': float(scores[1]),
            'low': float(scores[2]),
        }

    def _new_knowledge_id(self, category: str, now: datetime) -> str:
        """Build a ``<category>_<YYYYmmdd_HHMMSS>`` ID not already in the graph.

        The timestamp only has one-second resolution. If the base ID is
        already taken, a numeric suffix (``_1``, ``_2``, ...) is appended.
        Without it, a second insertion in the same second would overwrite the
        first one's node and index entry.
        """
        base_id = f"{category}_{now.strftime('%Y%m%d_%H%M%S')}"
        knowledge_id = base_id
        suffix = 1
        while knowledge_id in self.knowledge_graph:
            knowledge_id = f"{base_id}_{suffix}"
            suffix += 1
        return knowledge_id

    def encode_medical_text(self, text: str) -> torch.Tensor:
        """Convert medical text to tensor (simplified for POC).

        Placeholder: ignores ``text`` and returns a random vector of length
        ``self.input_size``.
        """
        # In real implementation, use medical-specific text encoder
        return torch.randn(self.input_size)  # Placeholder

    def add_medical_knowledge(self,
                              content: str,
                              category: str,
                              source: str,
                              related_ids: Optional[List[str]] = None) -> Dict:
        """Score a piece of medical knowledge and store it.

        Args:
            content: Free-text knowledge statement.
            category: Category name. Categories not predefined in
                ``self.categories`` get their own index entry.
            source: Citation / origin of the statement.
            related_ids: Existing knowledge IDs to link to with outgoing
                ``'related'`` edges. IDs not present in the graph are ignored.

        Returns:
            Dict with ``'knowledge_id'``, ``'reliability'`` (float) and
            ``'evidence_levels'`` (``high``/``medium``/``low`` floats).
        """
        encoded_content = self.encode_medical_text(content)

        # Inference only: no autograd graph is needed to store scores.
        with torch.no_grad():
            result = self.memory_cell(encoded_content)

        evidence_scores = result['evidence_scores'].detach().numpy()
        reliability = result['reliability'].item()

        # One timestamp so the ID and the stored timestamp agree.
        now = datetime.now()
        knowledge_id = self._new_knowledge_id(category, now)

        self.knowledge_graph.add_node(
            knowledge_id,
            category=category,
            content=content,
            source=source,
            evidence_scores=evidence_scores,
            reliability=reliability,
            timestamp=now
        )

        for related_id in related_ids or []:
            if related_id in self.knowledge_graph:
                self.knowledge_graph.add_edge(
                    knowledge_id,
                    related_id,
                    relationship_type='related'
                )

        # Index under the category. Unknown categories used to be stored in
        # the graph but never indexed, making them unreachable by queries.
        self.categories.setdefault(category, {})[knowledge_id] = {
            'content': content,
            'reliability': reliability,
            'evidence_scores': evidence_scores
        }

        return {
            'knowledge_id': knowledge_id,
            'reliability': reliability,
            'evidence_levels': self._evidence_levels(evidence_scores),
        }

    def query_medical_knowledge(self, query: str, category: Optional[str] = None) -> List[Dict]:
        """Query medical knowledge with confidence scoring.

        Args:
            query: Free-text query.
            category: If truthy, restrict the search to this category.
                Otherwise search every indexed category.

        Returns:
            Result dicts sorted by descending ``'confidence'``. Confidence is
            0.4 * relevance + 0.3 * reliability + 0.3 * P(high evidence).
            All numeric values are plain Python floats.
        """
        # Placeholder: the encoded query is not used yet. Relevance below is
        # random until semantic similarity is implemented.
        encoded_query = self.encode_medical_text(query)

        results = []
        categories_to_search = [category] if category else list(self.categories)

        for cat in categories_to_search:
            for knowledge_id in self.categories.get(cat, {}):
                node_data = self.knowledge_graph.nodes[knowledge_id]
                evidence_levels = self._evidence_levels(node_data['evidence_scores'])

                # Calculating relevance (simplified for POC)
                relevance = torch.rand(1).item()  # In real implementation, use semantic similarity

                # Combine relevance with reliability and evidence
                confidence = (
                    relevance * 0.4 +
                    node_data['reliability'] * 0.3 +
                    evidence_levels['high'] * 0.3
                )

                results.append({
                    'knowledge_id': knowledge_id,
                    'content': node_data['content'],
                    'category': node_data['category'],
                    'confidence': confidence,
                    'reliability': node_data['reliability'],
                    'evidence_levels': evidence_levels,
                    'source': node_data['source'],
                    'timestamp': node_data['timestamp']
                })

        # Sort by confidence
        results.sort(key=lambda x: x['confidence'], reverse=True)
        return results

    def get_related_knowledge(self, knowledge_id: str) -> List[Dict]:
        """Return summaries of knowledge linked FROM ``knowledge_id``.

        Only outgoing edges are followed. Returns ``[]`` for unknown IDs.
        """
        if knowledge_id not in self.knowledge_graph:
            return []

        related_knowledge = []
        for related_id in self.knowledge_graph.successors(knowledge_id):
            node_data = self.knowledge_graph.nodes[related_id]
            related_knowledge.append({
                'knowledge_id': related_id,
                'content': node_data['content'],
                'category': node_data['category'],
                'reliability': node_data['reliability'],
                'evidence_levels': self._evidence_levels(node_data['evidence_scores'])
            })

        return related_knowledge

def demo_healthcare_assistant():
    """Run a small end-to-end example of the Healthcare Knowledge Assistant.

    Adds one diagnosis entry and one treatment entry linked to it. It then
    queries the treatment category and fetches the knowledge linked from the
    treatment entry.

    Returns:
        Dict with keys ``'diagnosis'``, ``'treatment'``, ``'query_results'``
        and ``'related_information'``.
    """
    assistant = HealthcareKnowledgeAssistant()

    # Illustrative sample entries (demonstration purposes only).
    migraine_diagnosis = assistant.add_medical_knowledge(
        content="Persistent headache with aura may indicate migraine disorder",
        category="diagnosis",
        source="Medical Journal of Neurology",
    )
    diagnosis_id = migraine_diagnosis['knowledge_id']

    migraine_treatment = assistant.add_medical_knowledge(
        content="Beta-blockers shown effective in migraine prevention",
        category="treatment",
        source="Clinical Neurology Research",
        related_ids=[diagnosis_id],
    )

    matches = assistant.query_medical_knowledge(
        query="migraine treatment options",
        category="treatment",
    )
    linked = assistant.get_related_knowledge(migraine_treatment['knowledge_id'])

    return dict(
        diagnosis=migraine_diagnosis,
        treatment=migraine_treatment,
        query_results=matches,
        related_information=linked,
    )

if __name__ == "__main__":
    # Run the demo and pretty-print it; non-JSON values (e.g. datetimes)
    # are rendered with str().
    demo_output = demo_healthcare_assistant()
    serialized = json.dumps(demo_output, indent=2, default=str)
    print(serialized)

{
  "diagnosis": {
    "knowledge_id": "diagnosis_20241026_103844",
    "reliability": 0.48270535469055176,
    "evidence_levels": {
      "high": 0.32098549604415894,
      "medium": 0.34192436933517456,
      "low": 0.3370901942253113
    }
  },
  "treatment": {
    "knowledge_id": "treatment_20241026_103844",
    "reliability": 0.502500057220459,
    "evidence_levels": {
      "high": 0.3310445249080658,
      "medium": 0.3301960527896881,
      "low": 0.3387594223022461
    }
  },
  "query_results": [
    {
      "knowledge_id": "treatment_20241026_103844",
      "content": "Beta-blockers shown effective in migraine prevention",
      "category": "treatment",
      "confidence": 0.2594859391450882,
      "reliability": 0.502500057220459,
      "evidence_levels": {
        "high": "0.33104452",
        "medium": "0.33019605",
        "low": "0.33875942"
      },
      "source": "Clinical Neurology Research",
      "timestamp": "2024-10-26 10:38:44.883262"
    }
  ],
  "related_infor