<a href="https://colab.research.google.com/github/Shuo-Zh/Shuo-Zh.github.io/blob/main/mcpsystem.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import argparse
import hmac
import json
import logging
import secrets
import threading
import time
from concurrent import futures
from uuid import uuid4

import grpc
import numpy as np
from google.protobuf import json_format
from sklearn.ensemble import RandomForestClassifier
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Configure root logging at import time: INFO and above, each record prefixed
# with its timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Shared logger used by every component in this module.
logger = logging.getLogger("MCP-System")

# ======================
# 1. Protocol Buffers Definitions
# ======================
class ToolRequest:
    """In-process stand-in for the ``ToolRequest`` message.

    Attributes:
        tool_name: registry key of the tool to run (see ``MCPService.tools``).
        input_data: JSON-encoded string carrying the tool's input payload.
        auth_token: bearer token previously issued by ``AuthService.GetToken``.
    """

    def __init__(self, tool_name="", input_data="", auth_token=""):
        self.tool_name, self.input_data, self.auth_token = (
            tool_name,
            input_data,
            auth_token,
        )

class ToolResponse:
    """In-process stand-in for the ``ToolResponse`` message.

    Attributes:
        status: ``"success"`` or ``"error"`` as set by ``MCPService.ExecuteTool``.
        output_data: JSON-encoded tool result (populated on success).
        error_detail: human-readable failure reason (populated on error).
    """

    def __init__(self, status="", output_data="", error_detail=""):
        self.status, self.output_data, self.error_detail = (
            status,
            output_data,
            error_detail,
        )

class TokenRequest:
    """In-process stand-in for the ``TokenRequest`` message.

    Attributes:
        client_id: identifier checked against ``AuthService.valid_clients``.
        client_secret: secret that must match the registered one for ``client_id``.
    """

    def __init__(self, client_id="", client_secret=""):
        self.client_id, self.client_secret = client_id, client_secret

class TokenResponse:
    """In-process stand-in for the ``TokenResponse`` message.

    An instance built with all defaults (empty ``access_token``) is what
    ``AuthService.GetToken`` returns for rejected credentials.

    Attributes:
        access_token: opaque bearer token, or ``""`` on failure.
        token_type: token scheme, ``"Bearer"`` when issued.
        expires_in: token lifetime in seconds (``0`` on failure).
    """

    def __init__(self, access_token="", token_type="", expires_in=0):
        self.access_token, self.token_type, self.expires_in = (
            access_token,
            token_type,
            expires_in,
        )

# ======================
# 2. Security Tools Implementation
# ======================
class SecurityTool:
    """Common base for every tool registered with the MCP service.

    Subclasses supply a ``name`` / ``description`` pair (used for logging and
    the server's tool listing) and override :meth:`execute`.
    """

    def __init__(self, name, description):
        self.name, self.description = name, description

    def execute(self, input_data):
        """Run the tool on a decoded JSON payload; subclasses must override.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

class ContentSecurityTool(SecurityTool):
    """Content Security Tool - Generates synthetic data

    Produces random stand-in records so that real content does not have to be
    shared. Only ``"financial"`` and ``"medical"`` data types are supported.
    """

    def __init__(self):
        super().__init__("ContentGuard", "Generates synthetic data to protect real content")

    def execute(self, input_data):
        """Generate ``size`` synthetic records of the requested ``data_type``.

        Args:
            input_data: dict; reads ``"data_type"`` (default ``"financial"``)
                and ``"size"`` (default ``100``).

        Returns:
            A dict of parallel lists (``transactions``/``accounts`` or
            ``patients``/``diagnoses``), or ``{"error": ...}`` for an
            unsupported data type.
        """
        kind = input_data.get("data_type", "financial")
        count = input_data.get("size", 100)

        logger.info(f"Generating synthetic data: {kind}, size: {count}")

        if kind == "financial":
            # Amounts are drawn first, then the account ids, one per record.
            amounts = np.random.normal(1000, 500, count).tolist()
            account_ids = [f"ACC{np.random.randint(10000, 99999)}" for _ in range(count)]
            return {"transactions": amounts, "accounts": account_ids}

        if kind == "medical":
            conditions = ["Hypertension", "Diabetes", "Asthma", "Arthritis"]
            patient_ids = [f"PT{np.random.randint(1000, 9999)}" for _ in range(count)]
            diagnoses = [np.random.choice(conditions) for _ in range(count)]
            return {"patients": patient_ids, "diagnoses": diagnoses}

        return {"error": f"Unsupported data type: {kind}"}

class HallucinationDetectionTool(SecurityTool):
    """Hallucination Detection Tool - Simulated QA validation

    Matches an incoming question against a small, hard-coded bank of vetted
    question/answer pairs per domain and returns the closest vetted answer,
    with the word-level Jaccard similarity reported as ``confidence``.
    """
    def __init__(self):
        super().__init__("TruthValidator", "Question-answer validation system to detect and prevent model hallucinations")
        # domain -> {vetted question: vetted answer}
        self.qa_pairs = {
            "finance": {
                "What is the prime rate?": "The current prime rate is 8.5%",
                "How to calculate return on investment?": "ROI = (Net Profit / Cost of Investment) × 100%"
            },
            "healthcare": {
                "Symptoms of diabetes?": "Increased thirst, frequent urination, unexplained weight loss",
                "Normal blood pressure range?": "Below 120/80 mmHg"
            }
        }

    def execute(self, input_data):
        """Return the vetted answer whose question best matches ``input_data["question"]``.

        Args:
            input_data: dict; reads ``"domain"`` (default ``"finance"``) and
                ``"question"`` (default ``""``; a JSON ``null`` is treated as ``""``).

        Returns:
            dict with ``question``, ``matched_question`` (``None`` when no word
            overlaps), ``answer`` and ``confidence`` (Jaccard similarity in
            [0, 1]); or ``{"error": ...}`` for an unknown domain.
        """
        domain = input_data.get("domain", "finance")
        question = input_data.get("question", "")
        # Tolerate null / non-string questions instead of crashing in .lower().
        question = "" if question is None else str(question)

        logger.info(f"Validating question: {domain} - {question}")

        if domain not in self.qa_pairs:
            return {"error": f"Unsupported domain: {domain}"}

        # Keep the first strictly-best match; a score of 0 means "no match".
        best_match = None
        best_score = 0

        for known_question, answer in self.qa_pairs[domain].items():
            similarity = self._jaccard_similarity(question, known_question)
            if similarity > best_score:
                best_score = similarity
                best_match = (known_question, answer)

        return {
            "question": question,
            "matched_question": best_match[0] if best_match else None,
            "answer": best_match[1] if best_match else "No reliable information found",
            "confidence": float(best_score)
        }

    @staticmethod
    def _tokenize(text):
        """Lower-case *text* and return its set of alphanumeric words.

        Punctuation is treated as a separator so that ``"rate?"`` and
        ``"rate"`` produce the same token.
        """
        cleaned = "".join(ch if ch.isalnum() else " " for ch in text.lower())
        return set(cleaned.split())

    def _jaccard_similarity(self, a, b):
        """Jaccard similarity of the word sets of *a* and *b* (0 when both are empty)."""
        a_set = self._tokenize(a)
        b_set = self._tokenize(b)
        intersection = len(a_set & b_set)
        union = len(a_set | b_set)
        return intersection / union if union > 0 else 0

class RiskAssessmentTool(SecurityTool):
    """Risk Assessment Tool - Rule-based risk evaluation

    Scores risk as a fixed weighted sum of known factors and buckets the
    score into Low (<= 0.3), Medium (<= 0.6) or High (> 0.6). No model is
    trained here; GPRiskAssessor is the GPR-based variant.
    """
    # Weight of each recognised factor; factor names not listed are ignored.
    _WEIGHTS = {
        "sensitivity": 0.4,
        "access_control": 0.3,
        "data_volume": 0.2,
        "external_access": 0.1,
    }

    def __init__(self):
        # Description reflects the actual method (a weighted sum), not GPR.
        super().__init__("RiskOracle", "Rule-based risk assessment using a weighted sum of risk factors")
        self.risk_levels = ["Low", "Medium", "High"]

    def execute(self, input_data):
        """Score ``input_data["risk_factors"]`` and return level plus recommendations.

        Args:
            input_data: dict; reads ``"risk_factors"``, a mapping of factor
                name to numeric value (default ``{}``).

        Returns:
            dict with ``risk_score``, ``risk_level`` and a two-slot
            ``recommendations`` list in which a non-triggered slot is ``""``.
        """
        risk_factors = input_data.get("risk_factors", {})

        logger.info(f"Risk assessment: {len(risk_factors)} factors")

        # Accumulated in the factors' own order, starting from integer 0.
        total_score = sum(
            self._WEIGHTS[factor] * value
            for factor, value in risk_factors.items()
            if factor in self._WEIGHTS
        )

        if total_score > 0.6:
            risk_level = self.risk_levels[2]
        elif total_score > 0.3:
            risk_level = self.risk_levels[1]
        else:
            risk_level = self.risk_levels[0]

        return {
            "risk_score": total_score,
            "risk_level": risk_level,
            "recommendations": [
                "Strengthen access control" if risk_factors.get("access_control", 0) > 0.5 else "",
                "Implement data anonymization" if risk_factors.get("sensitivity", 0) > 0.6 else ""
            ]
        }

# ======================
# NEW TOOLS IMPLEMENTATION
# ======================
class GPRRiskAssessmentTool(SecurityTool):
    """Gaussian Process Regression Tool - Advanced Risk Assessor

    Fits a GPR model on historical records (four numeric factors -> numeric
    ``outcome``), reports its hold-out MSE, and predicts a risk score with a
    standard deviation for the current factors.
    """
    # Feature order shared by historical records and the current assessment.
    _FEATURES = ("sensitivity", "access_control", "data_volume", "external_access")

    def __init__(self):
        super().__init__("GPRiskAssessor", "Advanced risk assessment using Gaussian Process Regression")
        self.risk_levels = ["Low", "Medium", "High"]
        self.kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))

    @classmethod
    def _to_features(cls, record):
        """Return the feature vector for *record*; missing factors count as 0."""
        return [record.get(name, 0) for name in cls._FEATURES]

    def execute(self, input_data):
        """Assess risk for ``input_data["risk_factors"]`` using ``historical_data``.

        Args:
            input_data: dict; reads ``"risk_factors"`` (mapping, default ``{}``)
                and ``"historical_data"`` (list of dicts with the factor keys
                plus ``"outcome"``, default ``[]``).

        Returns:
            dict with ``risk_score``, ``std_dev``, ``risk_level``,
            ``model_mse`` and a three-slot ``recommendations`` list (``""``
            for non-triggered slots), or ``{"error": ...}``.
        """
        risk_factors = input_data.get("risk_factors", {})
        historical_data = input_data.get("historical_data", [])

        logger.info(f"GPR Risk assessment with {len(risk_factors)} factors and {len(historical_data)} historical records")

        if not historical_data:
            return {"error": "Historical data required for GPR assessment"}
        # train_test_split needs at least one training and one test sample.
        if len(historical_data) < 2:
            return {"error": "At least 2 historical records are required for GPR assessment"}

        try:
            X = np.array([self._to_features(record) for record in historical_data])
            y = np.array([record.get("outcome", 0) for record in historical_data])

            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

            gpr = GaussianProcessRegressor(kernel=self.kernel, n_restarts_optimizer=10)
            gpr.fit(X_train, y_train)

            # Hold-out error of the fitted model.
            y_pred = gpr.predict(X_test)
            mse = np.mean((y_pred - y_test) ** 2)

            current_features = np.array(self._to_features(risk_factors)).reshape(1, -1)
            risk_score, std_dev = gpr.predict(current_features, return_std=True)
            risk_score = risk_score[0]
            std_dev = std_dev[0]

            if risk_score > 0.6:
                risk_level = self.risk_levels[2]
            elif risk_score > 0.3:
                risk_level = self.risk_levels[1]
            else:
                risk_level = self.risk_levels[0]

            return {
                "risk_score": float(risk_score),
                "std_dev": float(std_dev),
                "risk_level": risk_level,
                "model_mse": float(mse),
                "recommendations": [
                    "Enhanced monitoring required" if std_dev > 0.15 else "",
                    "Conduct threat modeling" if risk_score > 0.5 else "",
                    "Review access policies" if risk_factors.get("access_control", 0) > 0.4 else ""
                ]
            }
        except Exception as e:
            # Best-effort tool: report the failure to the caller instead of raising.
            logger.error(f"GPR execution error: {str(e)}")
            return {"error": f"GPR processing failed: {str(e)}"}

class GANSyntheticDataTool(SecurityTool):
    """GAN-style Generator - Virtual Data Factory

    No adversarial network is trained. Each field is drawn from a fixed
    per-field sampler (or a default generator). With ``complexity="high"``,
    simple cross-field correlations are applied once the record is complete:
    the amount is scaled by merchant category, and treatment is derived from
    the diagnosis.
    """
    def __init__(self):
        super().__init__("GANDataFactory", "Advanced synthetic data generation using GAN-inspired techniques")
        # data_type -> {"fields": output field order, "distributions": field -> sampler}
        self.data_templates = {
            "financial": {
                "fields": ["transaction_id", "amount", "account_id", "timestamp", "merchant_category"],
                "distributions": {
                    "amount": lambda: np.random.lognormal(6, 0.5),
                    "merchant_category": lambda: np.random.choice(["Retail", "Dining", "Travel", "Utilities"])
                }
            },
            "medical": {
                "fields": ["patient_id", "age", "diagnosis", "treatment", "outcome"],
                "distributions": {
                    "age": lambda: np.random.randint(18, 90),
                    "diagnosis": lambda: np.random.choice(["Hypertension", "Diabetes", "Asthma", "Arthritis"]),
                    "outcome": lambda: np.random.choice(["Improved", "Stable", "Worsened"])
                }
            }
        }

    def execute(self, input_data):
        """Generate ``size`` synthetic records plus dataset-level statistics.

        Args:
            input_data: dict; reads ``"data_type"`` (default ``"financial"``),
                ``"size"`` (default ``100``) and ``"complexity"`` (default
                ``"medium"``; only ``"high"`` enables correlations).

        Returns:
            dict with ``data_type``, ``complexity``, ``records`` and
            ``statistics`` (``{}`` when no records were generated), or
            ``{"error": ...}`` for an unsupported data type.
        """
        data_type = input_data.get("data_type", "financial")
        size = input_data.get("size", 100)
        complexity = input_data.get("complexity", "medium")  # low, medium, high

        logger.info(f"GAN-style data generation: {data_type}, size: {size}, complexity: {complexity}")

        if data_type not in self.data_templates:
            return {"error": f"Unsupported data type: {data_type}"}

        template = self.data_templates[data_type]
        records = [self._generate_record(template, complexity) for _ in range(size)]

        return {
            "data_type": data_type,
            "complexity": complexity,
            "records": records,
            "statistics": self._compute_statistics(data_type, records)
        }

    def _generate_record(self, template, complexity):
        """Build one record in template field order, then apply high-complexity correlations."""
        distributions = template["distributions"]
        record = {}
        for field in template["fields"]:
            sampler = distributions.get(field)
            record[field] = sampler() if sampler is not None else self._default_value(field)
        # Correlations run after every field exists: in the templates the
        # dependent field (amount) can precede the field it depends on.
        if complexity == "high":
            self._apply_correlations(record)
        return record

    @staticmethod
    def _default_value(field):
        """Generate a value for a field that has no sampler in the template."""
        if field == "transaction_id":
            return f"TX{np.random.randint(1000000, 9999999)}"
        if field == "patient_id":
            return f"PT{np.random.randint(10000, 99999)}"
        if field == "account_id":
            return f"ACC{np.random.randint(10000, 99999)}"
        if field == "timestamp":
            # randint's upper bound is exclusive: months 1..12, days 1..28.
            return f"2023-{np.random.randint(1, 13):02d}-{np.random.randint(1, 29):02d}"
        return "N/A"

    @staticmethod
    def _apply_correlations(record):
        """Mutate *record* in place with category/diagnosis-dependent values."""
        if "amount" in record:
            category = record.get("merchant_category")
            if category == "Travel":
                record["amount"] *= 1.5
            elif category == "Utilities":
                record["amount"] *= 0.8
        if "treatment" in record and "diagnosis" in record:
            diagnosis = record["diagnosis"]
            if diagnosis == "Diabetes":
                record["treatment"] = np.random.choice(["Metformin", "Insulin", "Lifestyle"])
            elif diagnosis == "Hypertension":
                record["treatment"] = np.random.choice(["ACE Inhibitor", "Beta Blocker", "Diuretic"])
            else:
                record["treatment"] = "Standard Care"

    @staticmethod
    def _compute_statistics(data_type, records):
        """Summarise *records*; empty input yields ``{}`` (min/max of nothing is undefined)."""
        if not records:
            return {}
        if data_type == "financial":
            amounts = [r["amount"] for r in records]
            return {
                "mean_amount": float(np.mean(amounts)),
                "max_amount": float(np.max(amounts)),
                "min_amount": float(np.min(amounts))
            }
        if data_type == "medical":
            diagnosis_counts = {}
            for r in records:
                diagnosis_counts[r["diagnosis"]] = diagnosis_counts.get(r["diagnosis"], 0) + 1
            return {
                "mean_age": float(np.mean([r["age"] for r in records])),
                "diagnosis_distribution": diagnosis_counts
            }
        return {}

class KnowledgeDistillationTool(SecurityTool):
    """Knowledge Distillation Framework - SOP for Knowledge Impartation

    Trains a larger random-forest "teacher" on the supplied records, then a
    smaller random-forest "student", either on the teacher's soft labels
    (class probabilities) or directly on the hard labels. Reports both
    models' accuracy and how many fewer tree nodes the student has.
    """
    def __init__(self):
        super().__init__("KDImpartation", "Knowledge distillation from complex to simple models")
        # task name -> builder returning (fitted teacher, X, y)
        self.complex_models = {
            "fraud_detection": self._train_complex_fraud_model,
            "diagnosis_prediction": self._train_complex_diagnosis_model
        }

    def execute(self, input_data):
        """Distil a teacher model for ``input_data["task"]`` into a student.

        Args:
            input_data: dict; reads ``"task"`` (default ``"fraud_detection"``),
                ``"training_data"`` (list of dicts, default ``[]``) and
                ``"method"`` (default ``"soft_labels"``; any other value
                trains the student on hard labels).

        Returns:
            dict with accuracies, ``size_reduction`` and recommendations, or
            ``{"error": ...}``.
        """
        task = input_data.get("task", "fraud_detection")
        data = input_data.get("training_data", [])
        distillation_method = input_data.get("method", "soft_labels")

        logger.info(f"Knowledge distillation for: {task}, method: {distillation_method}")

        if task not in self.complex_models:
            return {"error": f"Unsupported task: {task}"}

        if not data:
            return {"error": "Training data required for distillation"}

        try:
            teacher_model, X, y = self.complex_models[task](data)
            student_model = self._train_student_model(X, y, distillation_method, teacher_model)

            # NOTE: both accuracies are measured on the training data itself,
            # so they are optimistic, in-sample estimates.
            teacher_acc = accuracy_score(y, teacher_model.predict(X))
            student_acc = accuracy_score(y, student_model.predict(X))

            # Fraction of tree nodes saved by the student relative to the teacher.
            teacher_nodes = self._count_nodes(teacher_model)
            size_reduction = (
                1.0 - self._count_nodes(student_model) / teacher_nodes if teacher_nodes else 0.0
            )

            return {
                "task": task,
                "distillation_method": distillation_method,
                "teacher_accuracy": float(teacher_acc),
                "student_accuracy": float(student_acc),
                "size_reduction": float(size_reduction),
                "recommendations": [
                    "Use student model for edge deployment",
                    "Retain teacher model for critical decisions"
                ]
            }
        except Exception as e:
            # Best-effort tool: report the failure to the caller instead of raising.
            logger.error(f"Knowledge distillation error: {str(e)}")
            return {"error": f"Distillation failed: {str(e)}"}

    @staticmethod
    def _count_nodes(forest):
        """Total number of nodes across all trees of a fitted random forest."""
        return sum(estimator.tree_.node_count for estimator in forest.estimators_)

    def _train_complex_fraud_model(self, data):
        """Fit the fraud-detection teacher on amount/frequency/location_risk -> is_fraud."""
        X = np.array([[d.get("amount", 0),
                      d.get("frequency", 0),
                      d.get("location_risk", 0)]
                     for d in data])
        y = np.array([d.get("is_fraud", 0) for d in data])

        # Stand-in for a genuinely complex model (e.g. a deep network).
        model = RandomForestClassifier(n_estimators=100, max_depth=10)
        model.fit(X, y)
        return model, X, y

    def _train_complex_diagnosis_model(self, data):
        """Fit the diagnosis teacher on age/symptom_score/test_result -> diagnosis_code."""
        X = np.array([[d.get("age", 0),
                      d.get("symptom_score", 0),
                      d.get("test_result", 0)]
                     for d in data])
        y = np.array([d.get("diagnosis_code", 0) for d in data])

        model = RandomForestClassifier(n_estimators=150, max_depth=12)
        model.fit(X, y)
        return model, X, y

    def _train_student_model(self, X, y, method, teacher_model=None):
        """Fit a small random forest, on the teacher's soft labels or on *y*.

        A classifier cannot be fitted on continuous probability targets, so
        soft labels are encoded as sample weights. Each input row is repeated
        once per teacher class, labelled with that class, and weighted by the
        teacher's probability for it. Tree impurities then see the teacher's
        full class distribution rather than only its argmax.
        """
        model = RandomForestClassifier(n_estimators=30, max_depth=5)
        # Explicit None check: fitted forests define __len__, so truthiness is not a presence test.
        if method == "soft_labels" and teacher_model is not None:
            soft = teacher_model.predict_proba(X)  # (n_samples, n_classes)
            n_samples, n_classes = soft.shape
            X_expanded = np.repeat(X, n_classes, axis=0)  # rows: x0, x0, ..., x1, x1, ...
            y_expanded = np.tile(teacher_model.classes_, n_samples)  # c0, c1, ..., c0, c1, ...
            model.fit(X_expanded, y_expanded, sample_weight=soft.ravel())
        else:
            model.fit(X, y)
        return model

# ======================
# 3. MCP Server Implementation
# ======================
class AuthService:
    """OAuth Authentication Service

    Exchanges client credentials for opaque bearer tokens and keeps issued
    tokens in memory (``self.tokens``), where ``MCPService`` validates them.
    """
    # Lifetime of an issued token, in seconds.
    TOKEN_TTL_SECONDS = 3600

    def __init__(self):
        # NOTE(review): credentials are hard-coded for the demo; load them from a
        # secret store or the environment in any real deployment.
        self.valid_clients = {
            "sec_agent": "agent_secret_123",
            "review_agent": "review_secret_456",
            "data_scientist": "scientist_secret_789"
        }
        # token -> {"client_id": str, "expires": epoch seconds}
        self.tokens = {}

    def _credentials_valid(self, client_id, client_secret):
        """Return True if *client_secret* matches the secret registered for *client_id*.

        Uses a constant-time comparison so response timing does not reveal
        how much of the secret was correct.
        """
        expected = self.valid_clients.get(client_id)
        if expected is None or not isinstance(client_secret, str):
            return False
        return hmac.compare_digest(expected.encode("utf-8"), client_secret.encode("utf-8"))

    def _purge_expired(self):
        """Drop expired tokens so the in-memory store does not grow without bound."""
        now = time.time()
        for token in [t for t, data in self.tokens.items() if data["expires"] <= now]:
            del self.tokens[token]

    def GetToken(self, request):
        """Issue a bearer token for valid credentials.

        Args:
            request: object with ``client_id`` and ``client_secret`` attributes.

        Returns:
            A populated ``TokenResponse`` on success; an all-default
            ``TokenResponse`` (empty ``access_token``) on failure.
        """
        client_id = request.client_id
        client_secret = request.client_secret

        if self._credentials_valid(client_id, client_secret):
            self._purge_expired()
            token = f"token_{secrets.token_hex(16)}"
            self.tokens[token] = {
                "client_id": client_id,
                "expires": time.time() + self.TOKEN_TTL_SECONDS
            }
            logger.info(f"Issuing token to: {client_id}")
            return TokenResponse(
                access_token=token,
                token_type="Bearer",
                expires_in=self.TOKEN_TTL_SECONDS
            )

        logger.warning(f"Invalid client credentials: {client_id}")
        return TokenResponse()  # Empty response signals failure

class MCPService:
    """MCP Core Service

    Holds the tool registry and the auth service, and executes tool
    requests after validating the caller's bearer token.
    """
    def __init__(self):
        # tool name -> tool instance
        self.tools = {
            "ContentGuard": ContentSecurityTool(),
            "TruthValidator": HallucinationDetectionTool(),
            "RiskOracle": RiskAssessmentTool(),
            "GPRiskAssessor": GPRRiskAssessmentTool(),
            "GANDataFactory": GANSyntheticDataTool(),
            "KDImpartation": KnowledgeDistillationTool()
        }
        self.auth_service = AuthService()

    def _validate_token(self, token):
        """Return True if *token* was issued and has not expired.

        An expired token is removed from the store. Only the owning client id
        is logged, never the token itself, so credentials do not leak into logs.
        """
        token_data = self.auth_service.tokens.get(token)
        if token_data is None:
            return False
        if token_data["expires"] > time.time():
            return True
        logger.warning(f"Token expired for client: {token_data['client_id']}")
        del self.auth_service.tokens[token]
        return False

    def ExecuteTool(self, request):
        """Execute a tool request and wrap the outcome in a ``ToolResponse``.

        Args:
            request: object with ``tool_name``, ``input_data`` (JSON object
                string) and ``auth_token`` attributes.

        Returns:
            ``ToolResponse`` with ``status="success"`` and JSON ``output_data``;
            otherwise ``status="error"`` with ``error_detail``. A tool result
            consisting solely of an ``"error"`` key is reported as an error.
        """
        if not self._validate_token(request.auth_token):
            return ToolResponse(
                status="error",
                error_detail="Invalid or expired authentication token"
            )

        tool = self.tools.get(request.tool_name)
        if tool is None:
            return ToolResponse(
                status="error",
                error_detail=f"Tool not found: {request.tool_name}"
            )

        try:
            input_data = json.loads(request.input_data)
        except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
            logger.error(f"Invalid tool input: {e}")
            return ToolResponse(status="error", error_detail=f"Invalid input data: {e}")
        # Every tool reads its input with dict.get(), so require a JSON object.
        if not isinstance(input_data, dict):
            return ToolResponse(
                status="error",
                error_detail="Invalid input data: expected a JSON object"
            )

        try:
            logger.info(f"Executing tool: {tool.name} ({tool.description})")
            result = tool.execute(input_data)
            output_data = json.dumps(result, ensure_ascii=False)
        except Exception as e:
            # Service boundary: convert any tool failure into an error response.
            logger.exception(f"Tool execution error: {str(e)}")
            return ToolResponse(
                status="error",
                error_detail=f"Tool execution error: {str(e)}"
            )

        # Tools report handled failures as {"error": ...}; surface them as errors.
        if isinstance(result, dict) and set(result) == {"error"}:
            return ToolResponse(status="error", error_detail=str(result["error"]))

        return ToolResponse(status="success", output_data=output_data)

# ======================
# 4. Agent System Implementation
# ======================
class SecurityAgent:
    """Security Agent - Performs security tasks

    Authenticates against an ``MCPService`` and invokes its tools. Calls are
    made in-process by direct method invocation rather than over gRPC.
    """
    def __init__(self, name, client_id, client_secret):
        self.name = name
        self.client_id = client_id
        self.client_secret = client_secret
        self.auth_token = None
        # Set by authenticate(); None until then.
        self.mcp_service = None
        # NOTE(review): host/port are not used by the in-process transport below.
        self.mcp_host = "localhost"
        self.mcp_port = 50051

    def _call_mcp(self, method, request):
        """Dispatch *method* to the attached service; None if unknown or not attached."""
        # In a real deployment this would go through a gRPC channel.
        if self.mcp_service is None:
            return None
        if method == "GetToken":
            return self.mcp_service.auth_service.GetToken(request)
        if method == "ExecuteTool":
            return self.mcp_service.ExecuteTool(request)
        return None

    def authenticate(self, mcp_service):
        """Attach *mcp_service* and obtain a bearer token; return True on success."""
        self.mcp_service = mcp_service
        token_request = TokenRequest(
            client_id=self.client_id,
            client_secret=self.client_secret
        )
        response = self._call_mcp("GetToken", token_request)

        if response and response.access_token:
            self.auth_token = response.access_token
            logger.info(f"{self.name} authenticated successfully")
            return True

        logger.error(f"{self.name} authentication failed")
        return False

    def use_tool(self, tool_name, input_data):
        """Run *tool_name* with *input_data*; return the decoded result or None on failure."""
        if not self.auth_token:
            logger.error(f"{self.name} not authenticated, cannot use tools")
            return None

        request = ToolRequest(
            tool_name=tool_name,
            input_data=json.dumps(input_data, ensure_ascii=False),
            auth_token=self.auth_token
        )

        response = self._call_mcp("ExecuteTool", request)

        if response and response.status == "success":
            return json.loads(response.output_data)

        logger.error(f"{self.name} tool call failed: {response.error_detail if response else 'Unknown error'}")
        return None

    def perform_advanced_audit(self, data_context):
        """Run synthetic-data, GPR risk and distillation tools over *data_context*.

        Args:
            data_context: dict; reads ``"data_type"`` (default
                ``"financial"``), ``"size"`` (default 100), ``"risk_factors"``,
                ``"historical_risk_data"``, ``"training_data"`` and
                ``"distillation_task"`` (default ``"fraud_detection"``).

        Returns:
            dict mapping each audit section to the tool result (None on failure).
        """
        logger.info(f"{self.name} starting ADVANCED security audit")

        synthetic_data = self.use_tool("GANDataFactory", {
            "data_type": data_context.get("data_type", "financial"),
            "size": data_context.get("size", 100),
            "complexity": "high"
        })

        risk = self.use_tool("GPRiskAssessor", {
            "risk_factors": data_context.get("risk_factors", {}),
            "historical_data": data_context.get("historical_risk_data", [])
        })

        distillation = self.use_tool("KDImpartation", {
            "task": data_context.get("distillation_task", "fraud_detection"),
            "training_data": data_context.get("training_data", []),
            "method": "soft_labels"
        })

        return {
            "synthetic_data": synthetic_data,
            "risk_assessment": risk,
            "knowledge_distillation": distillation
        }

# ======================
# 5. Server Run Function
# ======================
def run_server():
    """Run MCP server

    Builds an in-process ``MCPService``, logs its tool catalogue, then idles
    until interrupted with Ctrl+C. No gRPC listener is started: the port in
    the log line is informational only.
    """
    logger.info("Starting MCP server...")

    service = MCPService()

    # Serving is simulated; nothing binds to the port logged below.
    logger.info("MCP server running (port: 50051)")
    logger.info("Available tools:")
    for registered_name, registered_tool in service.tools.items():
        logger.info(f"  - {registered_name}: {registered_tool.description}")

    logger.info("Press Ctrl+C to stop server")

    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Stopping MCP server")

# ======================
# 6. Client Test Function
# ======================
def run_client_test():
    """Run client test

    End-to-end smoke test. It builds an in-process ``MCPService``,
    authenticates a ``SecurityAgent``, runs ``perform_advanced_audit`` on
    generated sample data and logs a summary per section. A section whose
    tool call failed (no result, or an ``{"error": ...}`` payload) is logged
    as a warning instead of being printed with placeholder values.
    """
    logger.info("Starting client test...")

    # Simulated server environment.
    mcp_service = MCPService()
    agent = SecurityAgent("Advanced Security Auditor", "sec_agent", "agent_secret_123")

    if not agent.authenticate(mcp_service):
        logger.error("Client test failed: authentication error")
        return

    # Historical risk data for GPR.
    historical_risk_data = [
        {"sensitivity": 0.6, "access_control": 0.3, "data_volume": 0.7, "external_access": 0.2, "outcome": 0.75},
        {"sensitivity": 0.4, "access_control": 0.5, "data_volume": 0.3, "external_access": 0.1, "outcome": 0.35},
        {"sensitivity": 0.8, "access_control": 0.2, "data_volume": 0.9, "external_access": 0.4, "outcome": 0.85},
        {"sensitivity": 0.3, "access_control": 0.6, "data_volume": 0.4, "external_access": 0.1, "outcome": 0.25},
        {"sensitivity": 0.7, "access_control": 0.4, "data_volume": 0.8, "external_access": 0.3, "outcome": 0.78}
    ]

    # Fraud-detection training data for distillation; about 10% positive labels.
    training_data = [
        {
            "amount": np.random.uniform(10, 10000),
            "frequency": np.random.randint(1, 30),
            "location_risk": np.random.uniform(0, 1),
            "is_fraud": 1 if np.random.random() > 0.9 else 0,
        }
        for _ in range(100)
    ]

    audit_report = agent.perform_advanced_audit({
        "data_type": "financial",
        "risk_factors": {
            "sensitivity": 0.7,
            "access_control": 0.4,
            "data_volume": 0.6,
            "external_access": 0.3
        },
        "historical_risk_data": historical_risk_data,
        "training_data": training_data
    })

    def successful_section(key, title):
        """Return the report section for *key*, or None (after a warning) if it failed."""
        section = audit_report.get(key)
        if not section:
            logger.warning(f"{title}: no result returned")
            return None
        if "error" in section:
            logger.warning(f"{title}: {section['error']}")
            return None
        return section

    logger.info("\nADVANCED SECURITY AUDIT REPORT:")

    synth = successful_section("synthetic_data", "1. Synthetic Data")
    if synth:
        logger.info(f"1. Synthetic Data ({synth.get('data_type', 'N/A')}):")
        logger.info(f"   - Records generated: {len(synth.get('records', []))}")
        logger.info(f"   - Complexity: {synth.get('complexity', 'N/A')}")
        stats = synth.get("statistics") or {}
        if "mean_amount" in stats:
            logger.info(f"   - Mean amount: ${stats['mean_amount']:.2f}")

    risk = successful_section("risk_assessment", "2. Risk Assessment")
    if risk:
        logger.info("2. Risk Assessment:")
        logger.info(f"   - Score: {risk.get('risk_score', 0):.2f} ± {risk.get('std_dev', 0):.2f}")
        logger.info(f"   - Level: {risk.get('risk_level', 'N/A')}")
        logger.info(f"   - Recommendations: {', '.join([r for r in risk.get('recommendations', []) if r])}")

    kd = successful_section("knowledge_distillation", "3. Knowledge Distillation")
    if kd:
        logger.info(f"3. Knowledge Distillation ({kd.get('task', 'N/A')}):")
        logger.info(f"   - Teacher Accuracy: {kd.get('teacher_accuracy', 0):.2%}")
        logger.info(f"   - Student Accuracy: {kd.get('student_accuracy', 0):.2%}")
        logger.info(f"   - Size Reduction: {kd.get('size_reduction', 0):.0%}")

# ======================
# 7. Main Program Entry
# ======================
if __name__ == '__main__':
    # Command-line entry point: a single required positional picks the run mode.
    cli = argparse.ArgumentParser(description='MCP System')
    cli.add_argument('mode', choices=['server', 'client'], help='Run mode: server or client')

    # argparse restricts `mode` to the two choices, so this lookup cannot miss.
    entry_points = {'server': run_server, 'client': run_client_test}
    entry_points[cli.parse_args().mode]()


In [None]:
python mcp_system.py server

In [None]:
python mcp_system.py client