<a href="https://colab.research.google.com/github/Shuo-Zh/Shuo-Zh.github.io/blob/main/define%20tools%20and%20communicate%20with%20MCP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Define tools

In [None]:
from abc import ABC, abstractmethod
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF
import torch
import torch.nn as nn

class SecurityTool(ABC):
    """Abstract base class shared by every security tool.

    The constructor stores ``name`` and ``description`` and always sets
    ``requires_auth`` to ``True``.
    """

    def __init__(self, name, description):
        self.name, self.description = name, description
        self.requires_auth = True

    @abstractmethod
    def execute(self, input_data, auth_token=None):
        """Run the tool on ``input_data``; concrete subclasses must override."""
        ...

class ContentSecurityTool(SecurityTool):
    """Content-security tool that generates mock (fake) datasets.

    Supported ``data_type`` values are the keys of ``self.datasets``:
    ``"financial"``, ``"medical"`` and ``"research"``.
    """

    def __init__(self):
        super().__init__(
            name="ContentGuard",
            description="生成模拟假数据以保护真实内容安全"
        )
        # Dispatch table: data type name -> generator method.
        self.datasets = {
            "financial": self._generate_financial_data,
            "medical": self._generate_medical_data,
            "research": self._generate_research_data
        }

    def execute(self, input_data, auth_token=None):
        """Generate a mock dataset.

        Args:
            input_data: Mapping with optional keys ``"data_type"``
                (default ``"financial"``) and ``"size"`` (default ``100``).
            auth_token: Accepted for interface compatibility; not inspected.

        Returns:
            A dict with ``status``, ``data`` and ``metadata`` on success, or
            ``{"error": ...}`` when the data type is unsupported or ``size``
            is not a non-negative integer.
        """
        data_type = input_data.get("data_type", "financial")
        size = input_data.get("size", 100)

        if data_type not in self.datasets:
            return {"error": f"Unsupported data type: {data_type}"}

        # Reject bad sizes up front: NumPy would otherwise raise from inside a
        # generator instead of producing the tool's usual error dict.
        # bool is excluded explicitly because it is a subclass of int.
        if (
            isinstance(size, bool)
            or not isinstance(size, (int, np.integer))
            or size < 0
        ):
            return {"error": f"Invalid size: {size!r}"}
        size = int(size)  # plain int keeps the metadata JSON-serializable

        return {
            "status": "success",
            "data": self.datasets[data_type](size),
            "metadata": {
                "data_type": data_type,
                "size": size,
                "is_mock": True
            }
        }

    def _generate_financial_data(self, size):
        """Return ``size`` mock transaction amounts and account identifiers."""
        return {
            "transactions": np.random.normal(1000, 500, size).tolist(),
            "accounts": [f"ACC{n}" for n in np.random.randint(10000, 99999, size)]
        }

    def _generate_medical_data(self, size):
        """Return ``size`` mock patient identifiers and diagnoses."""
        conditions = ["Hypertension", "Diabetes", "Asthma", "Arthritis"]
        return {
            "patients": [f"PT{n}" for n in np.random.randint(1000, 9999, size)],
            # .tolist() yields plain Python strings rather than NumPy scalars.
            "diagnoses": np.random.choice(conditions, size).tolist()
        }

    def _generate_research_data(self, size):
        """Return ``size`` mock experiment ids, topics and simulated results."""
        topics = ["AI Ethics", "Quantum Computing", "Climate Modeling", "Genomic Sequencing"]
        return {
            "experiment_ids": [f"EXP{n}" for n in np.random.randint(10000, 99999, size)],
            "research_topics": np.random.choice(topics, size).tolist(),
            "results_simulated": np.random.rand(size).tolist()
        }

class HallucinationDetectionTool(SecurityTool):
    """Hallucination-detection tool backed by a small table of known Q&A pairs.

    A question is matched against the known questions of the requested domain
    using Jaccard similarity over word tokens.
    """

    def __init__(self):
        super().__init__(
            name="TruthValidator",
            description="检测和防止模型幻觉的问答验证系统"
        )
        # domain -> {known question: reference answer}
        self.qa_pairs = {
            "finance": {
                "What is the prime rate?": "The prime rate is currently 8.5%",
                "How to calculate ROI?": "ROI = (Net Profit / Cost of Investment) × 100%"
            },
            "healthcare": {
                "Symptoms of diabetes?": "Increased thirst, frequent urination, unexplained weight loss",
                "Normal blood pressure?": "Less than 120/80 mm Hg"
            }
        }

    def execute(self, input_data, auth_token=None):
        """Look up the best-matching known question for ``input_data["question"]``.

        Args:
            input_data: Mapping with optional keys ``"domain"`` (default
                ``"finance"``) and ``"question"`` (default ``""``).
            auth_token: Accepted for interface compatibility; not inspected.

        Returns:
            ``{"error": ...}`` for an unknown domain; otherwise a dict with the
            original question, the matched question/answer (or a fallback
            answer when nothing overlaps) and the similarity as ``confidence``.
        """
        domain = input_data.get("domain", "finance")
        question = input_data.get("question", "")

        if domain not in self.qa_pairs:
            return {"error": f"Unsupported domain: {domain}"}

        # Simple similarity matching (a real system should use an embedding model).
        best_match = None
        best_score = 0

        for known_question, answer in self.qa_pairs[domain].items():
            similarity = self._jaccard_similarity(question, known_question)
            # Strict '>' keeps the first of equally scored candidates and
            # leaves best_match as None when nothing overlaps at all.
            if similarity > best_score:
                best_score = similarity
                best_match = (known_question, answer)

        return {
            "status": "success",
            "question": question,
            "matched_question": best_match[0] if best_match else None,
            "answer": best_match[1] if best_match else "No reliable information found",
            "confidence": float(best_score)
        }

    @staticmethod
    def _tokenize(text):
        """Return the set of lower-cased alphanumeric word tokens in ``text``.

        Punctuation is treated as a separator so that "rate?" and "rate"
        compare equal. Non-string input yields an empty set.
        """
        if not isinstance(text, str):
            return set()
        normalized = "".join(ch if ch.isalnum() else " " for ch in text.lower())
        return set(normalized.split())

    def _jaccard_similarity(self, a, b):
        """Compute the Jaccard similarity of the word sets of ``a`` and ``b``."""
        a_set = self._tokenize(a)
        b_set = self._tokenize(b)
        intersection = len(a_set & b_set)
        union = len(a_set | b_set)
        return intersection / union if union > 0 else 0

class FedPrivacyTool(SecurityTool):
    """Federated-privacy tool that simulates a pool of heterogeneous clients."""

    def __init__(self):
        super().__init__(
            name="FedShield",
            description="联邦学习环境下的隐私保护工具"
        )
        self.model = SimpleModel()
        self.clients = self._initialize_clients()

    def execute(self, input_data, auth_token=None):
        """Run one simulated federated update for the requested client.

        Reads ``client_id`` (default 0), ``batch_size`` (default 32) and
        ``learning_rate`` (default 0.01) from ``input_data``. Returns an
        error dict when the client id is unknown.
        """
        client_id = input_data.get("client_id", 0)
        batch_size = input_data.get("batch_size", 32)
        learning_rate = input_data.get("learning_rate", 0.01)

        if client_id not in self.clients:
            return {"error": f"Invalid client ID: {client_id}"}

        selected = self.clients[client_id]
        batch = self._generate_client_data(client_id, batch_size)

        # Simulate one federated-learning step on the freshly drawn batch.
        selected.update(batch, learning_rate)

        response = {"status": "success", "client_id": client_id}
        response["data_summary"] = f"{len(batch)} samples"
        response["model_update"] = selected.get_model_update()
        response["privacy_level"] = selected.privacy_level
        return response

    def _initialize_clients(self):
        """Build the ten simulated clients, keyed by their integer id."""
        pool = {}
        for idx in range(10):
            # Even ids are labelled "normal", odd ids "uniform".
            distribution = "uniform" if idx % 2 else "normal"
            level = np.random.choice(["low", "medium", "high"])
            pool[idx] = FedClient(
                client_id=idx,
                data_distribution=distribution,
                privacy_level=level
            )
        return pool

    def _generate_client_data(self, client_id, size):
        """Draw a ``(size, 10)`` batch; the distribution depends on id parity."""
        shape = (size, 10)
        if client_id % 2:
            return np.random.uniform(-1, 1, shape)
        return np.random.normal(0, 1, shape)

class RiskAssessmentTool(SecurityTool):
    """Risk-analysis tool based on Gaussian process regression (GPR)."""

    def __init__(self):
        super().__init__(
            name="RiskOracle",
            description="使用高斯过程回归进行智能风险评估"
        )
        self.kernel = RBF(length_scale=1.0)
        self.model = GaussianProcessRegressor(kernel=self.kernel)

    def execute(self, input_data, auth_token=None):
        """Optionally refit the GPR model and predict on the given features.

        Args:
            input_data: Mapping with ``"features"`` (2-D samples x features, or
                1-D, which is treated as one feature per sample) and
                ``"targets"`` (one value per sample).
            auth_token: Accepted for interface compatibility; not inspected.

        Returns:
            A dict with ``predictions``, ``uncertainty`` and the kernel
            description, or ``{"error": ...}`` for missing/malformed input.
        """
        try:
            X = np.array(input_data.get("features", []), dtype=float)
            y = np.array(input_data.get("targets", []), dtype=float)
        except (TypeError, ValueError) as exc:
            # Ragged or non-numeric input cannot be turned into an array.
            return {"error": f"Invalid features or targets: {exc}"}

        if X.ndim == 0 or y.ndim == 0 or len(X) == 0 or len(y) == 0:
            return {"error": "Missing features or targets"}

        # scikit-learn expects a 2-D feature matrix; promote a flat list of
        # values to a single-feature column.
        if X.ndim == 1:
            X = X.reshape(-1, 1)

        # Train or update the model: refit only when there is enough data.
        if len(X) > 10:
            if len(X) != len(y):
                return {
                    "error": f"Feature/target length mismatch: {len(X)} != {len(y)}"
                }
            self.model.fit(X, y)

        # Predict risk. On a model that has never been fitted this uses the
        # GP prior.
        predictions = self.model.predict(X, return_std=True)

        # ``kernel_`` only exists after fit(); fall back to the configured
        # kernel so small requests on an unfitted model do not crash.
        active_kernel = getattr(self.model, "kernel_", self.model.kernel)

        return {
            "status": "success",
            "predictions": predictions[0].tolist(),
            "uncertainty": predictions[1].tolist(),
            "kernel": str(active_kernel)
        }

# 辅助类和模型
class SimpleModel(nn.Module):
    """Minimal network: a single linear layer mapping 10 inputs to 1 output."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        prediction = self.fc(x)
        return prediction

class FedClient:
    """Simulated federated-learning client with its own model and optimizer."""

    def __init__(self, client_id, data_distribution, privacy_level):
        self.client_id = client_id
        self.data_distribution = data_distribution
        self.privacy_level = privacy_level
        self.model = SimpleModel()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

    def update(self, data, learning_rate):
        """Run one SGD step on ``data`` using ``learning_rate``.

        Args:
            data: Array-like batch convertible to a float32 tensor that the
                client's model accepts.
            learning_rate: Learning rate applied to the optimizer for this step.
        """
        for group in self.optimizer.param_groups:
            group['lr'] = learning_rate
        inputs = torch.tensor(data, dtype=torch.float32)
        outputs = self.model(inputs)

        # Simulated training: the "loss" is simply the mean model output.
        loss = outputs.mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def get_model_update(self):
        """Return the latest gradients, one nested list per model parameter.

        For ``privacy_level == "high"`` Laplace noise (scale 0.01) is added to
        every gradient. Noise is applied per parameter because the parameters
        have different shapes and cannot be stacked into one rectangular
        array. Parameters with no gradient yet (``update`` was never called)
        contribute zeros.
        """
        update = []
        for param in self.model.parameters():
            if param.grad is None:
                grad_values = np.zeros(tuple(param.shape), dtype=np.float32)
            else:
                grad_values = param.grad.detach().cpu().numpy()

            # Add noise according to the privacy level.
            if self.privacy_level == "high":
                noise = np.random.laplace(0, 0.01, grad_values.shape)
                grad_values = grad_values + noise

            update.append(grad_values.tolist())

        return update

Define agents

In [None]:
class Agent:
    """Base class for agents that hold a set of tools and an auth token."""

    def __init__(self, name, role, tools=None):
        self.name = name
        self.role = role
        self.tools = tools if tools else []
        self.auth_token = None

    def authenticate(self, auth_server):
        """Request a token for this agent's name and report whether one was issued.

        A real implementation should use the OAuth2 client-credentials flow.
        """
        token = auth_server.get_token(self.name)
        self.auth_token = token
        return token is not None

    def use_tool(self, tool_name, input_data):
        """Execute the first owned tool called ``tool_name`` with ``input_data``.

        Returns the tool's result, or an error dict when no tool has that name.
        """
        selected = next(
            (candidate for candidate in self.tools if candidate.name == tool_name),
            None,
        )
        if selected is None:
            return {"error": f"Tool not available: {tool_name}"}
        return selected.execute(input_data, self.auth_token)

class SecurityAgent(Agent):
    """Security agent that carries out the audit tasks."""

    def __init__(self):
        toolbox = [
            ContentSecurityTool(),
            HallucinationDetectionTool(),
            RiskAssessmentTool(),
        ]
        super().__init__(
            name="SecBot",
            role="执行安全审查和风险评估",
            tools=toolbox
        )

    def perform_audit(self, data_context):
        """Run the three audit steps and bundle their results.

        ``data_context`` must provide ``data_type``, ``domain``,
        ``primary_question``, ``features`` and ``targets``.
        """
        # Step 1: generate mock data for testing.
        mock_request = {"data_type": data_context["data_type"], "size": 100}
        mock_data = self.use_tool("ContentGuard", mock_request)

        # Step 2: validate the model output against known answers.
        validation_request = {
            "domain": data_context["domain"],
            "question": data_context["primary_question"],
        }
        validation = self.use_tool("TruthValidator", validation_request)

        # Step 3: risk assessment.
        risk_request = {
            "features": data_context["features"],
            "targets": data_context["targets"],
        }
        risk = self.use_tool("RiskOracle", risk_request)

        report = {"audit_status": "completed"}
        report["mock_data"] = mock_data
        report["validation"] = validation
        report["risk_assessment"] = risk
        return report

class SecurityReviewAgent(Agent):
    """Review agent that double-checks the work of the security agent."""

    def __init__(self):
        tools = [
            HallucinationDetectionTool(),
            RiskAssessmentTool()
        ]
        super().__init__(
            name="ReviewBot",
            role="审查安全Agent的工作",
            tools=tools
        )

    def review_work(self, audit_report, domain="security"):
        """Review a security audit report.

        Args:
            audit_report: Report with ``"validation"`` and
                ``"risk_assessment"`` entries, as built by an audit.
            domain: Domain passed to the TruthValidator tool. Defaults to
                ``"security"`` for backward compatibility.
                NOTE(review): the TruthValidator defined in this file only
                knows "finance" and "healthcare", so the default yields an
                "Unsupported domain" result — pass the audited domain instead.

        Returns:
            A dict with the claim validation, the risk review and the
            original report.
        """
        # 1. Validate the key claim. The audit's validation may itself be an
        #    error dict without an "answer" entry, so do not index blindly.
        primary_claim = audit_report["validation"].get("answer", "")
        validation = self.use_tool("TruthValidator", {
            "domain": domain,
            "question": f"Is this statement accurate: {primary_claim}"
        })

        # 2. Evaluate the risk scores.
        predictions = audit_report["risk_assessment"].get("predictions")
        if predictions is None:
            # The audit's risk assessment failed; there is nothing to review.
            risk_review = {"error": "No risk predictions available to review"}
        else:
            risk_data = {
                # One single-feature sample per prediction: the regressor
                # needs a 2-D feature matrix, not a flat list.
                "features": [[value] for value in predictions],
                "targets": [0.5] * len(predictions)
            }
            risk_review = self.use_tool("RiskOracle", risk_data)

        return {
            "review_status": "completed",
            "claim_validation": validation,
            "risk_review": risk_review,
            "original_report": audit_report
        }

class BenchmarkingAgent(Agent):
    """Benchmarking agent that establishes security benchmarks."""

    def __init__(self):
        tools = [
            FedPrivacyTool(),
            RiskAssessmentTool()
        ]
        super().__init__(
            name="BenchmarkBot",
            role="建立和评估安全基准",
            tools=tools
        )

    @classmethod
    def _flatten(cls, nested):
        """Flatten arbitrarily nested lists/tuples into one flat list."""
        if isinstance(nested, (list, tuple)):
            return [leaf for item in nested for leaf in cls._flatten(item)]
        return [nested]

    def set_benchmarks(self, clients):
        """Compute a security benchmark for every client id in ``clients``.

        Returns:
            Dict mapping each client id to its ``privacy_level``,
            ``risk_score`` and ``compliance`` flag. When a tool reports an
            error, the entry carries an ``"error"`` key instead of the score.
        """
        benchmarks = {}
        for client_id in clients:
            client_data = self.use_tool("FedShield", {
                "client_id": client_id,
                "batch_size": 64,
                "learning_rate": 0.02
            })
            if "error" in client_data:
                benchmarks[client_id] = {"error": client_data["error"]}
                continue

            # The model update holds one nested list per parameter, with
            # different shapes. Flatten it into a single feature vector so it
            # forms one rectangular sample matching the single target below.
            features = [self._flatten(client_data["model_update"])]
            risk = self.use_tool("RiskOracle", {
                "features": features,
                "targets": [0.3]  # target risk level
            })
            if "error" in risk:
                benchmarks[client_id] = {
                    "privacy_level": client_data["privacy_level"],
                    "error": risk["error"]
                }
                continue

            risk_score = risk["predictions"][0]
            benchmarks[client_id] = {
                "privacy_level": client_data["privacy_level"],
                "risk_score": risk_score,
                "compliance": risk_score <= 0.35
            }

        return benchmarks

class RedTeamingAgent(Agent):
    """Red-teaming agent that simulates attacks against the system."""

    def __init__(self):
        tools = [
            ContentSecurityTool(),
            FedPrivacyTool()
        ]
        super().__init__(
            name="RedTeamBot",
            role="模拟攻击以测试系统安全性",
            tools=tools
        )

    def perform_attack(self, target_system):
        """Run a simulated attack.

        Args:
            target_system: Mapping with ``"data_type"`` and an iterable of
                client ids under ``"clients"``.

        Returns:
            Dict with the attack type, the generated data, the per-client
            results and the ``success_rate`` — the fraction of client calls
            that returned an error (0.0 when no clients were targeted).
        """
        # 1. Generate the malicious data.
        malicious_data = self.use_tool("ContentGuard", {
            "data_type": target_system["data_type"],
            "size": 50
        })

        # 2. Simulate the federated attack.
        attack_results = [
            self.use_tool("FedShield", {
                "client_id": client_id,
                "batch_size": 32,
                "learning_rate": 0.1  # a high learning rate may destabilize the model
            })
            for client_id in target_system["clients"]
        ]

        # 3. Evaluate the attack. Guard against an empty client list, which
        #    previously raised ZeroDivisionError.
        failures = sum(1 for result in attack_results if "error" in result)
        success_rate = failures / len(attack_results) if attack_results else 0.0

        return {
            "attack_type": "data_poisoning",
            "malicious_data": malicious_data,
            "attack_results": attack_results,
            "success_rate": success_rate
        }


Define the MCP service with OAuth authentication

In [None]:
from concurrent import futures
import grpc
from grpc import ServicerContext
import auth_pb2
import auth_pb2_grpc
import mcp_pb2
import mcp_pb2_grpc
import json # Added explicit import for json

# OAuth认证服务
class AuthService(auth_pb2_grpc.AuthServiceServicer):
    """OAuth-style token service (simplified, hard-coded credentials)."""

    def GetToken(self, request: auth_pb2.TokenRequest, context: ServicerContext):
        """Issue a bearer token for the known client, else flag UNAUTHENTICATED.

        Simplified implementation: a real service should use the OAuth2
        client-credentials flow.
        """
        credentials_ok = (
            request.client_id == "sec_bot"
            and request.client_secret == "secure_password"
        )
        if not credentials_ok:
            context.set_code(grpc.StatusCode.UNAUTHENTICATED)
            context.set_details("Invalid client credentials")
            return auth_pb2.TokenResponse()

        return auth_pb2.TokenResponse(
            access_token="secbot_token_xyz",
            token_type="Bearer",
            expires_in=3600,
        )

# MCP核心服务
class MCPService(mcp_pb2_grpc.MCPServiceServicer):
    """Core MCP service: authenticates requests and dispatches them to tools."""

    def __init__(self):
        # Registry of available tools, keyed by the name clients send.
        self.tools = {
            "ContentGuard": ContentSecurityTool(),
            "TruthValidator": HallucinationDetectionTool(),
            "FedShield": FedPrivacyTool(),
            "RiskOracle": RiskAssessmentTool()
        }

    def ExecuteTool(self, request: mcp_pb2.ToolRequest, context: ServicerContext):
        """Validate the bearer token, run the requested tool, return its output.

        The gRPC status code is set on failure: UNAUTHENTICATED (missing or
        malformed header), PERMISSION_DENIED (wrong token), NOT_FOUND (unknown
        tool), INVALID_ARGUMENT (input is not a JSON object) or INTERNAL (the
        tool raised). On a completed call the response ``status`` is
        ``"error"`` when the tool's result contains an ``"error"`` key and
        ``"success"`` otherwise.
        """
        # Validate the OAuth token.
        auth_header = dict(context.invocation_metadata()).get("authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            context.set_code(grpc.StatusCode.UNAUTHENTICATED)
            context.set_details("Missing or invalid authorization header")
            return mcp_pb2.ToolResponse()

        token = auth_header[7:]  # strip the "Bearer " prefix
        if token != "secbot_token_xyz":  # simplified validation
            context.set_code(grpc.StatusCode.PERMISSION_DENIED)
            context.set_details("Invalid access token")
            return mcp_pb2.ToolResponse()

        # Resolve the tool.
        if request.tool_name not in self.tools:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details(f"Tool not found: {request.tool_name}")
            return mcp_pb2.ToolResponse()

        # A malformed payload is the caller's fault, not an internal error.
        try:
            input_data = json.loads(request.input_data)
        except ValueError as exc:
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details(f"Malformed JSON in input_data: {exc}")
            return mcp_pb2.ToolResponse()
        if not isinstance(input_data, dict):
            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
            context.set_details("input_data must be a JSON object")
            return mcp_pb2.ToolResponse()

        try:
            result = self.tools[request.tool_name].execute(input_data, token)
            output_data = json.dumps(result)
        except Exception as e:  # service boundary: report instead of crashing
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Tool execution error: {str(e)}")
            return mcp_pb2.ToolResponse()

        # Tools signal failures through an "error" key; do not label those
        # results as successful.
        tool_failed = isinstance(result, dict) and "error" in result
        return mcp_pb2.ToolResponse(
            status="error" if tool_failed else "success",
            output_data=output_data
        )

def serve():
    """Start the gRPC server hosting both services and block until it stops."""
    worker_pool = futures.ThreadPoolExecutor(max_workers=10)
    grpc_server = grpc.server(worker_pool)

    # Register the auth service first, then the MCP service.
    auth_pb2_grpc.add_AuthServiceServicer_to_server(AuthService(), grpc_server)
    mcp_pb2_grpc.add_MCPServiceServicer_to_server(MCPService(), grpc_server)

    # TLS is currently disabled because the certificate files are missing.
    # To enable it, build credentials and bind a secure port instead:
    # server_creds = grpc.ssl_server_credentials([
    #     (open('server-key.pem', 'rb').read(),
    #      open('server-cert.pem', 'rb').read())
    # ])
    # server.add_secure_port('[::]:50051', server_creds)

    # Insecure port, used for debugging only.
    grpc_server.add_insecure_port('[::]:50051')
    grpc_server.start()
    print("MCP Server running on port 50051 (insecure, TLS disabled)")
    grpc_server.wait_for_termination()

# Start the server only when this code runs as the main module.
# serve() blocks in wait_for_termination() until the server stops.
if __name__ == '__main__':
    serve()

MCP Server running on port 50051 (insecure, TLS disabled)


KeyboardInterrupt: 

In [None]:
# Install grpcio-tools, which provides the protoc plugin used below to compile
# the .proto files (IPython shell escape).
!pip install grpcio-tools

# Create minimal .proto files for demonstration purposes.
# In a real project auth.proto and mcp.proto would already exist in the working
# directory; replace the strings below with their actual content if you have it.

# Example content for auth.proto: a token request/response pair and the
# AuthService with a single GetToken RPC.
auth_proto_content = """
syntax = "proto3";

package auth;

message TokenRequest {
  string client_id = 1;
  string client_secret = 2;
}

message TokenResponse {
  string access_token = 1;
  string token_type = 2;
  int32 expires_in = 3;
}

service AuthService {
  rpc GetToken (TokenRequest) returns (TokenResponse);
}
"""

# Example content for mcp.proto: a tool request/response pair and the
# MCPService with a single ExecuteTool RPC. Both input_data and output_data
# are plain string fields.
mcp_proto_content = """
syntax = "proto3";

package mcp;

message ToolRequest {
  string tool_name = 1;
  string input_data = 2;
}

message ToolResponse {
  string status = 1;
  string output_data = 2;
}

service MCPService {
  rpc ExecuteTool (ToolRequest) returns (ToolResponse);
}
"""

# Write the .proto files to the current directory (existing files are overwritten).
with open('auth.proto', 'w') as f:
    f.write(auth_proto_content)

with open('mcp.proto', 'w') as f:
    f.write(mcp_proto_content)

# Compile the .proto files into Python modules in the current directory.
# Cells that import auth_pb2 / auth_pb2_grpc / mcp_pb2 / mcp_pb2_grpc need these
# generated files to exist first.
!python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. auth.proto mcp.proto

print("Generated auth_pb2.py, auth_pb2_grpc.py, mcp_pb2.py, and mcp_pb2_grpc.py")

Collecting grpcio-tools
  Downloading grpcio_tools-1.76.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.3 kB)
Collecting protobuf<7.0.0,>=6.31.1 (from grpcio-tools)
  Downloading protobuf-6.33.1-cp39-abi3-manylinux2014_x86_64.whl.metadata (593 bytes)
Downloading grpcio_tools-1.76.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading protobuf-6.33.1-cp39-abi3-manylinux2014_x86_64.whl (323 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m323.2/323.2 kB[0m [31m20.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: protobuf, grpcio-tools
  Attempting uninstall: protobuf
    Found existing installation: protobuf 5.29.5
    Uninstalling protobuf-5.29.5:
      Successfully uninstalled protobuf-5.29.5
[31mERROR: pip's dependency resolver does not currently take into account all the pa

Generated auth_pb2.py, auth_pb2_grpc.py, mcp_pb2.py, and mcp_pb2_grpc.py


In [None]:
import unittest
from unittest.mock import MagicMock, patch
import grpc
import mcp_pb2

class TestSecurityTools(unittest.TestCase):
    """Direct unit tests for the individual security tools."""

    def test_content_security_tool(self):
        """ContentGuard returns the requested number of mock transactions."""
        request = {"data_type": "financial", "size": 50}
        outcome = ContentSecurityTool().execute(request)
        self.assertEqual(outcome["status"], "success")
        transactions = outcome["data"]["transactions"]
        self.assertEqual(len(transactions), 50)

    def test_hallucination_detection(self):
        """TruthValidator maps a paraphrased ROI question to the ROI answer."""
        request = {
            "domain": "finance",
            "question": "How to calculate return on investment?",
        }
        outcome = HallucinationDetectionTool().execute(request)
        self.assertEqual(outcome["status"], "success")
        self.assertIn("ROI", outcome["answer"])

class TestAgents(unittest.TestCase):
    """Agent-level tests."""

    @patch('mcp_pb2_grpc.MCPServiceStub')
    def test_security_agent(self, mock_stub):
        """SecurityAgent.perform_audit reports a completed audit."""
        # Configure the canned response on the patched stub.
        canned_response = mcp_pb2.ToolResponse(
            status="success",
            output_data=json.dumps({"mock_data": "test"})
        )
        mock_stub.ExecuteTool.return_value = canned_response

        agent = SecurityAgent()
        agent.auth_token = "valid_token"
        audit_context = {
            "data_type": "medical",
            "domain": "healthcare",
            "primary_question": "What is diabetes?",
            "features": [[1, 2], [3, 4]],
            "targets": [0, 1],
        }
        report = agent.perform_audit(audit_context)

        self.assertEqual(report["audit_status"], "completed")

class TestMCPIntegration(unittest.TestCase):
    """Exercises the request/response messages through a mocked stub."""

    @patch('grpc.insecure_channel')
    def test_mcp_communication(self, mock_channel):
        """A mocked ExecuteTool call hands back the configured response."""
        # Create the mocked channel and stub.
        stub = MagicMock()
        mock_channel.return_value = MagicMock()

        # Request for the tool under test.
        payload = json.dumps({"data_type": "research", "size": 30})
        tool_request = mcp_pb2.ToolRequest(
            tool_name="ContentGuard",
            input_data=payload
        )

        # Canned response returned by the mocked stub.
        canned_response = mcp_pb2.ToolResponse(
            status="success",
            output_data=json.dumps({"data": "mock_research_data"})
        )
        stub.ExecuteTool.return_value = canned_response

        # Perform the call and check what came back.
        reply = stub.ExecuteTool(tool_request)
        self.assertEqual(reply.status, "success")
        self.assertIn("mock_research_data", reply.output_data)

class OAuthTest(unittest.TestCase):
    """End-to-end check of AuthService over a real local gRPC channel."""

    def test_token_validation(self):
        """Valid client credentials yield the expected access token."""
        # Create the test server.
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
        auth_pb2_grpc.add_AuthServiceServicer_to_server(AuthService(), server)
        server.add_insecure_port('[::]:50052')
        server.start()
        # Register the shutdown right after start so the port is released
        # even when the RPC or the assertion below fails.
        self.addCleanup(server.stop, 0)

        # Create the client.
        with grpc.insecure_channel('localhost:50052') as channel:
            stub = auth_pb2_grpc.AuthServiceStub(channel)
            response = stub.GetToken(auth_pb2.TokenRequest(
                client_id="sec_bot",
                client_secret="secure_password"
            ))
            self.assertEqual(response.access_token, "secbot_token_xyz")

# Run every TestCase defined above when executed as the main module.
# NOTE(review): unittest.main() parses sys.argv and exits the interpreter by
# default; inside a notebook kernel the usual form is
# unittest.main(argv=['ignored'], exit=False) — confirm how this cell is run.
if __name__ == '__main__':
    unittest.main()


Deploy and run

In [None]:
!python mcp_server.py

SyntaxError: invalid syntax (ipython-input-479180457.py, line 1)