In [25]:
import requests
import numpy as np
import pandas as pd
import json

In [33]:
# 1. 数据集管理
# 查看可用数据集
datasets = requests.get('http://localhost:5000/datasets').json()
print("可用数据集:", json.dumps(datasets, indent=4))


可用数据集: {
    "datasets": [
        {
            "dataset_id": "built_in_iris",
            "name": "iris",
            "shape": {
                "samples": 150,
                "features": 4
            },
            "type": "built_in",
            "creation_time": "2024-12-09T17:22:21.402144"
        },
        {
            "dataset_id": "built_in_breast_cancer",
            "name": "breast_cancer",
            "shape": {
                "samples": 569,
                "features": 30
            },
            "type": "built_in",
            "creation_time": "2024-12-09T17:22:21.404138"
        },
        {
            "dataset_id": "b2ff525f",
            "name": "my_dataset",
            "description": "\u81ea\u5b9a\u4e49\u6570\u636e\u96c6",
            "shape": {
                "samples": 3,
                "features": 2
            },
            "type": "user_uploaded",
            "creation_time": "2024-12-09T17:18:31.860881"
        }
    ]
}


In [34]:
# 获取内置数据集
iris_data = requests.get('http://localhost:5000/datasets/built_in_iris').json()
# print("内置数据集:\n", json.dumps(iris_data, indent=4))


In [35]:
# 上传新数据集
new_dataset = {
    'name': 'my_dataset',
    'description': '自定义数据集',
    'X': [[1, 2], [3, 4], [5, 6]],
    'y': [0, 1, 0],
    'feature_names': ['feature1', 'feature2'],
    'target_names': ['class0', 'class1']
}
response = requests.post('http://localhost:5000/datasets', json=new_dataset)
dataset_id = response.json()['dataset_id']

In [47]:
data = requests.get(f'http://localhost:5000/datasets/{dataset_id}').json()
print(data)
X = data['X']
y = data['y']
meta_data = data['meta_data']


{'X': [[1, 2], [3, 4], [5, 6]], 'y': [0, 1, 0], 'meta_data': {'name': 'my_dataset', 'description': '自定义数据集', 'feature_names': ['feature1', 'feature2'], 'target_names': ['class0', 'class1'], 'shape': {'samples': 3, 'features': 2}, 'type': 'user_uploaded', 'creation_time': '2024-12-09T17:24:07.208705'}}


In [37]:
# 2. 使用数据集训练模型
response = requests.post('http://localhost:5000/train', json={
    'dataset_id': dataset_id,  # 使用已上传的数据集
    'model_type': 'logistic',
    'params': {'max_iter': 1000}
})
model_id = response.json()['model_id']

In [48]:
# 3. 预测
predictions = requests.post('http://localhost:5000/predict', json={
    'X': X,
    'model_id': model_id
}).json()
print(predictions)

{'predictions': [0, 0, 0]}


In [51]:
# 4. 评估模型
evaluation = requests.post('http://localhost:5000/evaluate', json={
    'X': X,
    'y': y,
    'model_id': model_id,
    'metrics': ['accuracy', 'precision', 'recall', 'f1', 'confusion_matrix']
}).json()

{'accuracy': 0.6666666666666666, 'precision': 0.4444444444444444, 'recall': 0.6666666666666666, 'f1': 0.5333333333333333, 'confusion_matrix': [[2, 0], [1, 0]]}


In [53]:
# 5. 可视化
visualization = requests.post('http://localhost:5000/visualize', json={
    'X': X,
    'y': y,
    'model_id': model_id,
    'plot_type': 'scatter'
}).json()

# base64转图片
# import base64
# from PIL import Image
# from io import BytesIO

# image = Image.open(BytesIO(base64.b64decode(visualization['image'])))
# image.show()


{'image': 'iVBORw0KGgoAAAANSUhEUgAAA+gAAAJYCAYAAADxHswlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/GU6VOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA5uElEQVR4nO3dfXRV9Z3o/88JSKJIArZAQCKiKKCAorYarMVaWqrWK947XsvQghadq8UZ0D4oczsKPkycKq2dURHrVWorcsVbdaQ+FLVgFZwrih3ASlUQsBLsaiEhIBGT/fujP9MbIZAgyfkGXq+1zhrPPt+d8zl7trPm7T4PuSzLsgAAAADyqiDfAwAAAAACHQAAAJIg0AEAACABAh0AAAASINABAAAgAQIdAAAAEiDQAQAAIAECHQAAABIg0AEAACABAh0AAAASINABAAAgAQIdAAAAEiDQAQAAIAECHQAAABIg0AEAACABAh0AAAASINABAAAgAQIdAAAAEiDQAQAAIAECHQAAABIg0AEAACABAh0AAAASINABAAAgAQIdAAAAEiDQAQAAIAECHQAAABIg0AEAACABAh0AAAASINABAAAgAQIdAAAAEiDQAQAAIAECHQAAABIg0AEAACABAh0AAAASINABAAAgAQIdAAAAEiDQAQAAIAECHQAAABIg0AEAACABAh0AAAASINABAAAgAQIdAAAAEiDQAQAAIAECHQAAABIg0AEAACABAh0AAAASINABAAAgAQIdAAAAEiDQAQAAIAECHQAAABIg0AEAACABHfM9AADsT1asWBHDhg2LTp067fTxDz74IJYuXbrbNb/73e9i27Zt+8S6I488cqePA8D+RqADQBvKsiw++9nPxvPPP7/Tx0855ZRmr9lX1gEAf+Et7gAAAJAAgQ4AAAAJEOgAAACQAIEOAAAACRDoAAAAkACBDgAAAAkQ6AAAAJAAgQ4AAAAJEOgAAACQAIEOAAAAC

In [57]:
# 6. 生成综合报告
report = requests.post('http://localhost:5000/generate_report', json={
    'model_id': model_id,
    'dataset_info': {
        'name': 'iris',
        'description': meta_data['description'],
        'feature_names': meta_data['feature_names'],
        'target_names': meta_data['target_names'],
        'shape': list(meta_data['shape'].values())
    },
    'training_info': {
        'model_type': 'logistic',
        'parameters': {},
        'training_time': '2023-XX-XX...',
        'convergence_info': '模型收敛信息'
    },
    'prediction_results': {
        'sample_predictions': predictions['predictions'][:5],
        'prediction_distribution': {'0': 30, '1': 40, '2': 30}
    },
    'evaluation_results': evaluation,
    'visualization_results': {
        'plots': [visualization['image']]
    }
}).json()
print(json.dumps(report, indent=4))

{
    "report_time": "2024-12-09T17:35:21.230534",
    "report_sections": {
        "dataset_analysis": {
            "summary": "\u6570\u636e\u96c6'iris'\u5305\u542b3\u4e2a\u6837\u672c\uff0c2\u4e2a\u7279\u5f81\u3002",
            "feature_importance": [
                "\u7279\u5f81'feature1'\u7684\u91cd\u8981\u6027\u5206\u6790",
                "\u7279\u5f81'feature2'\u7684\u91cd\u8981\u6027\u5206\u6790"
            ],
            "data_distribution": "\u6570\u636e\u5206\u5e03\u5206\u6790\u7ed3\u679c"
        },
        "model_analysis": {
            "model_type": "logistic",
            "parameters": "\u6a21\u578b\u4f7f\u7528\u7684\u53c2\u6570\u914d\u7f6e\u5206\u6790",
            "structure_summary": "logistic\u6a21\u578b\u7ed3\u6784\u6982\u8ff0"
        },
        "performance_analysis": {
            "metrics_summary": "\u51c6\u786e\u7387: 66.67% \u7cbe\u786e\u7387: 44.44% \u53ec\u56de\u7387: 66.67% F1\u5206\u6570: 53.33%",
            "prediction_analysis": "\u9884\u6d4b\u7ed3\

In [60]:
# 清理（可选）
# 删除数据集
requests.delete(f'http://localhost:5000/datasets/{dataset_id}')
# 删除模型
requests.delete(f'http://localhost:5000/models/{model_id}')

<Response [404]>