In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role
import subprocess
import sys
import time

# Resolve the IAM execution role attached to this notebook and open a default SageMaker session.
role = get_execution_role()
session = sagemaker.Session()

# S3 bucket and key prefixes (no trailing slash) for feature inputs and model outputs.
# Later cells read these module-level names directly.
bucket = "sagemaker-earthquake-prediction-shuhao"
data_prefix = "earthquake-data/processed/2"
model_prefix = "earthquake-data/models/2"

# Print a configuration banner.
_rule = "=" * 60
for _banner_line in (
    "🌍 地震预测Random Forest训练系统",
    _rule,
    f"SageMaker角色: {role}",
    f"S3存储桶: {bucket}",
    f"数据路径: s3://{bucket}/{data_prefix}/",
    f"模型保存路径: s3://{bucket}/{model_prefix}/",
    _rule,
):
    print(_banner_line)

In [None]:
# Cell 2: verify that the S3 bucket is reachable and list the feature data files under data_prefix.
import boto3

s3_client = boto3.client('s3')

# Only objects whose key ends with this suffix are treated as feature data files.
FEATURE_SUFFIX = '_特征データ.csv'
# Number of feature files listed individually before the rest are summarized.
PREVIEW_COUNT = 10

print("🔍 验证S3存储桶和数据...")

# Check that the bucket exists and is accessible with the current role.
try:
    s3_client.head_bucket(Bucket=bucket)
    print(f"✅ 存储桶 {bucket} 访问正常")
except Exception as e:
    print(f"❌ 存储桶访问失败: {e}")
    print("请检查存储桶名称和IAM权限设置")

# List the data files.
try:
    # A single list_objects_v2 call returns at most one page of keys, so paginate to see
    # every object under the prefix instead of silently undercounting.
    paginator = s3_client.get_paginator('list_objects_v2')
    listed_objects = [
        obj
        for page in paginator.paginate(Bucket=bucket, Prefix=f"{data_prefix}/")
        for obj in page.get('Contents', [])
    ]

    if listed_objects:
        # Listing entries already carry 'Size', so no per-object head_object round-trip is needed.
        feature_objects = [obj for obj in listed_objects if obj['Key'].endswith(FEATURE_SUFFIX)]
        feature_files = [obj['Key'] for obj in feature_objects]

        print(f"\n📁 找到 {len(feature_files)} 个特征数据文件:")
        for i, obj in enumerate(feature_objects[:PREVIEW_COUNT]):
            filename = obj['Key'].split('/')[-1]
            print(f"  {i+1:2d}. {filename} ({obj['Size']:,} bytes)")

        if len(feature_files) > PREVIEW_COUNT:
            print(f"  ... 还有 {len(feature_files) - PREVIEW_COUNT} 个文件")

        # The total covers every feature file, not just the ones listed above.
        total_size = sum(obj['Size'] for obj in feature_objects)

        print(f"\n📊 数据文件总计: {len(feature_files)} 个")
        if total_size > 0:
            print(f"📊 已检查文件总大小: {total_size:,} bytes ({total_size/1024/1024:.2f} MB)")

    else:
        print(f"⚠️  在 s3://{bucket}/{data_prefix}/ 下未找到数据文件")
        print("请确保特征数据文件已上传到正确位置")

except Exception as e:
    print(f"❌ 检查数据文件失败: {e}")

In [None]:
# Cell 3: run the training script in a subprocess and stream its output live.
import os
import subprocess
import sys
import time
from datetime import datetime

# Training parameters, forwarded to the script as command-line arguments.
bucket_name = bucket
data_folder = data_prefix
model_folder = model_prefix
min_samples = 80  # passed as --min-samples; its exact meaning is defined by the training script
# Training entry point on this notebook instance; change it here if the scripts directory moves.
train_script = "/home/ec2-user/SageMaker/earthquake-training/scripts/sagemaker_train.py"

print("🚀 准备开始训练...")
print(f"时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"参数配置:")
print(f"  存储桶: {bucket_name}")
print(f"  数据文件夹: {data_folder}")
print(f"  模型保存文件夹: {model_folder}")
print(f"  最小样本数: {min_samples}")

# Build the command as an argument list (no shell), run with the kernel's own interpreter.
train_command = [
    sys.executable, train_script,
    "--bucket-name", bucket_name,
    "--data-folder", data_folder,
    "--model-folder", model_folder,
    "--min-samples", str(min_samples)
]

print(f"\n执行命令:")
print(f"{' '.join(train_command)}")

print(f"\n{'='*60}")
print("开始训练...")
print(f"{'='*60}")

# A Python child writing to a pipe block-buffers its stdout. Without PYTHONUNBUFFERED, the "live"
# progress would only arrive in large chunks or at exit. PYTHONIOENCODING pins the child's output
# to UTF-8 so the decoding below matches regardless of the instance's locale.
child_env = dict(os.environ, PYTHONUNBUFFERED="1", PYTHONIOENCODING="utf-8")

start_time = time.time()
# The context manager closes the stdout pipe when the block exits.
with subprocess.Popen(
    train_command,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,  # merge stderr so errors appear inline with progress output
    encoding="utf-8",
    errors="replace",          # an undecodable byte must not abort the notebook cell
    bufsize=1,                 # line-buffered on the reading side (text mode)
    env=child_env,
) as process:
    try:
        for line in process.stdout:
            # rstrip (not strip) keeps the script's leading indentation intact.
            print(line.rstrip(), flush=True)
    except KeyboardInterrupt:
        # Interrupting the cell should stop training too, not leave an orphaned child running.
        process.terminate()
        raise
    return_code = process.wait()

end_time = time.time()
total_time = end_time - start_time

print(f"\n{'='*60}")
print("训练完成!")
print(f"{'='*60}")
print(f"返回码: {return_code}")
print(f"总耗时: {total_time:.1f}秒 ({total_time/60:.1f}分钟)")
print(f"结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

if return_code == 0:
    print("🎉 训练成功!")
else:
    print("❌ 训练失败!")

In [None]:
# Cell 4: verify that the training artifacts exist in S3 and summarize the model metadata.
import boto3
import json
from datetime import datetime

s3_client = boto3.client('s3')

print("🔍 验证训练结果...")

# Artifacts expected under model_prefix after a successful training run.
model_files = [
    f"{model_prefix}/random_forest_model.pkl",
    f"{model_prefix}/label_encoder.pkl",
    f"{model_prefix}/model_metadata.json"
]

print(f"\n📁 检查S3中的模型文件:")
total_size = 0
files_found = 0

for file_key in model_files:
    filename = file_key.split('/')[-1]
    try:
        response = s3_client.head_object(Bucket=bucket, Key=file_key)
    except Exception as e:
        # head_object reports a missing key as a ClientError carrying code "404". Other failures
        # (e.g. 403 access denied, network errors) are shown verbatim instead of as "missing".
        error_info = getattr(e, 'response', None)
        error_code = error_info.get('Error', {}).get('Code') if isinstance(error_info, dict) else None
        if error_code in ('404', 'NoSuchKey', 'NotFound'):
            print(f"  ❌ {filename} (不存在)")
        else:
            print(f"  ❌ {filename} ({e})")
        continue

    size = response['ContentLength']
    last_modified = response['LastModified']
    total_size += size
    files_found += 1

    print(f"  ✅ {filename}")
    print(f"     大小: {size:,} bytes ({size/1024/1024:.2f} MB)")
    print(f"     修改时间: {last_modified}")

print(f"\n📊 模型文件统计:")
print(f"  找到文件: {files_found}/{len(model_files)}")
print(f"  总大小: {total_size:,} bytes ({total_size/1024/1024:.2f} MB)")

# Read and display the model metadata only when every expected artifact is present.
if files_found == len(model_files):
    try:
        print(f"\n📋 读取模型元数据...")
        metadata_obj = s3_client.get_object(Bucket=bucket, Key=f"{model_prefix}/model_metadata.json")
        metadata = json.loads(metadata_obj['Body'].read().decode('utf-8'))
        metrics = metadata.get('metrics', {})
        # Treat an explicit null the same as an absent value so the format spec cannot fail.
        training_seconds = metrics.get('training_time_seconds') or 0

        print(f"\n🎯 训练结果摘要:")
        print(f"  模型类型: {metadata.get('model_type', 'N/A')}")
        print(f"  模型版本: {metadata.get('model_version', 'N/A')}")
        print(f"  训练时间: {metadata.get('training_time', 'N/A')}")
        print(f"  训练耗时: {training_seconds:.1f}秒")

        # The test-set metrics are required; a missing key is reported by the except below.
        print(f"\n📊 模型性能:")
        print(f"  测试集R²得分: {metrics['test_r2']:.4f}")
        print(f"  测试集MSE: {metrics['test_mse']:.4f}")
        print(f"  测试集MAE: {metrics['test_mae']:.4f}")
        if 'oob_score' in metrics and metrics['oob_score']:
            print(f"  OOB得分: {metrics['oob_score']:.4f}")

        print(f"\n🗾 训练数据:")
        print(f"  参与训练的都道府県: {metadata.get('num_prefectures', 0)}")
        print(f"  总样本数: {metadata.get('total_samples', 0):,}")
        print(f"  特征数量: {len(metadata.get('feature_names', []))}")

        # NOTE(review): assumes prefecture_info is a list of (name, count) pairs — confirm against the training script.
        print(f"\n🏆 参与训练的都道府県 (前10个):")
        prefecture_info = metadata.get('prefecture_info', [])
        for i, (prefecture, count) in enumerate(prefecture_info[:10]):
            print(f"  {i+1:2d}. {prefecture}: {count:,} 条记录")

        if len(prefecture_info) > 10:
            print(f"  ... 还有 {len(prefecture_info) - 10} 个都道府県")

        print(f"\n🎯 重要特征 (Top 8):")
        feature_importance = metadata.get('feature_importance', {})
        sorted_features = sorted(feature_importance.items(),
                                 key=lambda x: x[1], reverse=True)
        for i, (feature, importance) in enumerate(sorted_features[:8]):
            print(f"  {i+1}. {feature}: {importance:.4f}")

        # Rough qualitative rating by test R². Named test_r2 so it does not shadow
        # sklearn.metrics.r2_score if that is imported in this kernel.
        test_r2 = metrics['test_r2']
        if test_r2 >= 0.7:
            print(f"\n🎉 模型性能: 优秀 (R² = {test_r2:.4f})")
        elif test_r2 >= 0.5:
            print(f"\n👍 模型性能: 良好 (R² = {test_r2:.4f})")
        elif test_r2 >= 0.3:
            print(f"\n⚠️  模型性能: 一般 (R² = {test_r2:.4f})")
        else:
            print(f"\n❌ 模型性能: 需要改进 (R² = {test_r2:.4f})")

    except Exception as e:
        print(f"❌ 读取模型元数据失败: {e}")

else:
    print(f"⚠️  模型文件不完整，无法读取元数据")

print(f"\n{'='*60}")
print("训练验证完成!")
print(f"{'='*60}")