# AWS Glue Notebook 访问 MySQL RDS 数据库示例

本示例展示如何在 AWS Glue Notebook 中安全地访问 MySQL RDS 数据库：先通过 Glue connection 接入 VPC 内的 RDS 实例，并从该 connection 中获取连接配置（JDBC URL、用户名和密码），再使用 Spark JDBC 读写数据库。

## 1. 配置 Glue 会话

In [2]:
# Session configuration magics for the Glue interactive session.
# The commented-out magics below (idle timeout, Glue version, worker type,
# number of workers) are optional settings that this cell does not apply.
# %idle_timeout 60
# %glue_version 5.0
# %worker_type G.1X
# %number_of_workers 2
%connections rds-mysql-connection
%additional_python_modules mysql-connector-python

Welcome to the Glue Interactive Sessions Kernel
For more information on available magic commands, please type %help in any new cell.

Please view our Getting Started page to access the most up-to-date information on the Interactive Sessions kernel: https://docs.aws.amazon.com/glue/latest/dg/interactive-sessions.html
Installed kernel version: 1.0.8 
Connections to be included:
rds-mysql-connection


In [None]:
%config

## 2. 初始化 Glue 上下文

In [1]:
import sys
from awsglue.transforms import *
from awsglue.utils import getResolvedOptions
from pyspark.context import SparkContext
from awsglue.context import GlueContext
from awsglue.job import Job
import boto3
import json
from pyspark.sql import DataFrame
import pandas as pd

# Initialise the Glue context.
# The interactive session has already started a SparkContext, so reuse it.
# SparkContext.getOrCreate() is the public way to obtain the running context;
# the private SparkContext._active_spark_context attribute is None when no
# context is active, which would hand None to GlueContext.
sc = SparkContext.getOrCreate()
glueContext = GlueContext(sc)
spark = glueContext.spark_session
job = Job(glueContext)

Trying to create a Glue session for the kernel.
Session Type: glueetl
Session ID: d7b08ede-910a-4e0e-b517-f01f8103041a
Applying the following default arguments:
--glue_kernel_version 1.0.8
--enable-glue-datacatalog true
Waiting for session d7b08ede-910a-4e0e-b517-f01f8103041a to get into ready status...
Session d7b08ede-910a-4e0e-b517-f01f8103041a has been created.



In [4]:
# Manually stop a session that is in an abnormal state
%stop_session

Stopping session: 244fb699-d20c-4adb-93e7-320c9cc53d4b
Stopped session.


## 3. 从 Secrets Manager 获取数据库连接信息(可选)

In [4]:
def get_secret(secret_id, region_name):
    """Fetch a secret from AWS Secrets Manager and decode it as JSON.

    Args:
        secret_id: Value passed as ``SecretId`` to ``get_secret_value``.
        region_name: AWS region used to create the Secrets Manager client.

    Returns:
        The object produced by JSON-decoding the response's ``SecretString``.
    """
    secrets_manager = boto3.client('secretsmanager', region_name=region_name)
    payload = secrets_manager.get_secret_value(SecretId=secret_id)['SecretString']
    return json.loads(payload)

# Fetch the MySQL connection details from Secrets Manager
secret_id = "test/mysql-local"
region_name = "ap-southeast-1"  # Singapore region
db_credentials = get_secret(secret_id, region_name)

# Print the connection details (host, database name and port only; no sensitive fields)
print(f"Database Host: {db_credentials['host']}")
print(f"Database Name: {db_credentials['dbname']}")
print(f"Database Port: {db_credentials['port']}")

Database Host: test-mysql-public.cprl6bvujco2.ap-southeast-1.rds.amazonaws.com
Database Name: testdb
Database Port: 3306


## 4. 测试数据库可访问性(可选)

In [11]:
# Fetch the JDBC settings of the Glue connection. Later cells read the
# "url", "user" and "password" keys from this mapping.
connection_options = glueContext.extract_jdbc_conf("rds-mysql-connection")

def test_database_connection(database_name="testdb"):
    """Check that the database behind the Glue connection is reachable.

    Runs ``SELECT 1`` through Spark JDBC using the module-level
    ``connection_options``, then tries two different ways of listing the
    tables of ``database_name``. A failure while listing tables is only
    printed; it does not change the return value.

    Args:
        database_name: Schema whose tables are listed. The value is
            interpolated into the SQL text, so pass trusted values only.

    Returns:
        True if the ``SELECT 1`` probe succeeded, False otherwise.
    """

    def _run_query(sql):
        # Single place that builds the JDBC reader for an ad-hoc query.
        return (
            spark.read.format("jdbc")
            .options(**connection_options)
            .option("query", sql)
            .load()
        )

    # Connectivity probe: the only step that decides the return value.
    try:
        _run_query("SELECT 1 as connection_test").show()
    except Exception as e:
        print(f"❌ 数据库连接测试失败: {e}")
        return False
    print("✅ 数据库连接测试成功!")

    # (label, heading text, SQL) for each table-listing approach.
    table_listings = [
        (
            "方法1",
            "使用 INFORMATION_SCHEMA 查询表",
            f"SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{database_name}'",
        ),
        (
            "方法2",
            "使用 mysql.innodb_table_stats 系统表",
            f"SELECT table_name FROM mysql.innodb_table_stats WHERE database_name = '{database_name}'",
        ),
    ]
    for label, heading, sql in table_listings:
        print(f"\n{label}: {heading}:")
        try:
            _run_query(sql).show()
        except Exception as e:
            # Best effort: report the failure and continue with the next approach.
            print(f"{label}失败: {e}")

    return True

# Run the connection test and keep its boolean result
connection_successful = test_database_connection()

+---------------+
|connection_test|
+---------------+
|              1|
+---------------+

✅ 数据库连接测试成功!

方法1: 使用 INFORMATION_SCHEMA 查询表:
+-----------------+
|       TABLE_NAME|
+-----------------+
|      departments|
|employee_projects|
|        employees|
|         projects|
+-----------------+


方法2: 使用 mysql.innodb_table_stats 系统表:
+-----------------+
|       table_name|
+-----------------+
|      departments|
|employee_projects|
|        employees|
|         projects|
+-----------------+


## 5. 使用 Spark JDBC 读取数据（推荐方法）

In [21]:
# Read MySQL data with the Spark JDBC reader

table_name = 'testdb.departments'

# When going through the Glue connection there is no need to spell out the
# connection details by hand; connection_options already supplies them.
# The manual equivalent would be:
   # .option("url", "jdbc:mysql://test-mysql-public.cprl6bvujco2.ap-southeast-1.rds.amazonaws.com:3306/testdb") \
   # .option("user", "admin") \
   # .option("password", "YourPWD") \
df = spark.read.format("jdbc") \
   .options(**connection_options) \
   .option("dbtable", table_name) \
   .load()

# Display the data
df.show()

+---+---------------+-------------------+----------+
| id|           name|           location|    budget|
+---+---------------+-------------------+----------+
|  1|    Engineering|Building A, Floor 2|1500000.00|
|  2|      Marketing|Building B, Floor 1| 800000.00|
|  3|        Finance|Building A, Floor 3|1200000.00|
|  4|Human Resources|Building B, Floor 2| 500000.00|
|  5|          Sales|Building C, Floor 1|2000000.00|
+---+---------------+-------------------+----------+


In [33]:
def read_with_spark_jdbc(table_name):
    """Load a MySQL table into a Spark DataFrame over JDBC.

    The notebook recommends this path for large datasets. Connection details
    come from the module-level ``connection_options`` mapping.

    Args:
        table_name: Table to read, passed to ``spark.read.jdbc`` as ``table``.

    Returns:
        The DataFrame returned by ``spark.read.jdbc``.

    Raises:
        KeyError: If "url", "user" or "password" is absent from
            ``connection_options`` (reported for the first missing key).
    """
    print(f"Reading table {table_name} using Spark JDBC...")

    # Find the first required setting that is absent, if any.
    missing_key = next(
        (name for name in ("url", "user", "password") if name not in connection_options),
        None,
    )
    if missing_key is not None:
        raise KeyError(f"Required key '{missing_key}' not found in connection_options")

    # NOTE(review): "com.mysql.jdbc.Driver" looks like the legacy Connector/J
    # class name - confirm it matches the driver bundled with the Glue runtime.
    return spark.read.jdbc(
        url=connection_options["url"],
        table=table_name,
        properties={
            "user": connection_options["user"],
            "password": connection_options["password"],
            "driver": "com.mysql.jdbc.Driver",
        },
    )

# Read the employees table, then print its schema and the first 5 rows
employees_df = read_with_spark_jdbc("testdb.employees")
print("员工表结构:")
employees_df.printSchema()
print("\n员工数据预览:")
employees_df.show(5)

Reading table testdb.employees using Spark JDBC...
员工表结构:
root
 |-- id: integer (nullable = true)
 |-- first_name: string (nullable = true)
 |-- last_name: string (nullable = true)
 |-- email: string (nullable = true)
 |-- department: string (nullable = true)
 |-- salary: decimal(10,2) (nullable = true)
 |-- hire_date: date (nullable = true)


员工数据预览:
+---+----------+---------+--------------------+---------------+--------+----------+
| id|first_name|last_name|               email|     department|  salary| hire_date|
+---+----------+---------+--------------------+---------------+--------+----------+
|  1|      John|    Smith|john.smith@exampl...|    Engineering|85000.00|2018-06-15|
|  2|     Emily|  Johnson|emily.johnson@exa...|      Marketing|72000.00|2019-03-22|
|  3|   Michael| Williams|michael.williams@...|        Finance|95000.00|2017-11-08|
|  4|     Sarah|    Brown|sarah.brown@examp...|Human Resources|68000.00|2020-01-15|
|  5|     David|    Jones|david.jones@examp...|    Enginee

## 6. 使用 Glue DynamicFrame 读取数据

In [40]:
# Name of the Glue connection. NOTE(review): no live code in this cell reads
# this variable; the reader below takes its settings from connection_options.
connection_name="rds-mysql-connection"

def read_with_glue_dynamicframe(full_table_name):
    """Read a MySQL table into an AWS Glue DynamicFrame.

    The url, user and password are taken from the module-level
    ``connection_options`` and passed explicitly to
    ``create_dynamic_frame.from_options``; the Glue connection name itself
    is not used here.

    Args:
        full_table_name: Table to read, passed as the ``dbtable`` option.

    Returns:
        The DynamicFrame, or None if creating or counting it raised (the
        error is printed rather than propagated).
    """
    print(f"Reading table {full_table_name} using Glue DynamicFrame...")

    # Explicit JDBC settings handed to the DynamicFrame reader.
    mysql_options = {
        "url": connection_options["url"],
        "user": connection_options["user"],
        "password": connection_options["password"],
        "dbtable": full_table_name,
    }

    try:
        frame = glueContext.create_dynamic_frame.from_options(
            connection_type="mysql",
            connection_options=mysql_options,
        )
        # count() is evaluated here, so a failure while counting is also caught.
        print(f"Successfully read {frame.count()} records from {full_table_name}")
    except Exception as e:
        print(f"Error reading table {full_table_name} with DynamicFrame: {e}")
        return None
    return frame

# Read the departments table.
# read_with_glue_dynamicframe returns None when the read fails, in which case
# the printSchema()/show() calls below would raise AttributeError.
departments_dyf = read_with_glue_dynamicframe("testdb.departments")

print("部门表结构:")
departments_dyf.printSchema()
print("\n部门数据预览:")
departments_dyf.show(5)

Reading table testdb.departments using Glue DynamicFrame...
Successfully read 5 records from testdb.departments
部门表结构:
root
|-- id: int
|-- name: string
|-- location: string
|-- budget: decimal


部门数据预览:
{"id": 2, "name": "Marketing", "location": "Building B, Floor 1", "budget": 800000.00}
{"id": 3, "name": "Finance", "location": "Building A, Floor 3", "budget": 1200000.00}
{"id": 4, "name": "Human Resources", "location": "Building B, Floor 2", "budget": 500000.00}
{"id": 1, "name": "Engineering", "location": "Building A, Floor 2", "budget": 1500000.00}
{"id": 5, "name": "Sales", "location": "Building C, Floor 1", "budget": 2000000.00}


## 7. 执行 SQL 查询

In [45]:
# Read the employees table
employees_df = read_with_spark_jdbc("testdb.employees")

# Register the DataFrame as a temporary view
employees_df.createOrReplaceTempView("employees_view")

# Re-read departments and register it as a view too. The query in this cell
# only uses employees_view; departments_view is used by the join query later.
departments_dyf = read_with_glue_dynamicframe("testdb.departments")
departments_df = departments_dyf.toDF()
departments_df.createOrReplaceTempView("departments_view")

# Run the SQL query: employee count and average salary per department
query = """
SELECT 
    e.department, 
    COUNT(*) as employee_count, 
    AVG(e.salary) as avg_salary
FROM employees_view e
GROUP BY e.department
ORDER BY employee_count DESC
"""

result = spark.sql(query)
print("按部门统计的员工数量和平均薪资:")
result.show()

Reading table testdb.employees using Spark JDBC...
Reading table testdb.departments using Glue DynamicFrame...
Successfully read 5 records from testdb.departments
按部门统计的员工数量和平均薪资:
+---------------+--------------+-------------+
|     department|employee_count|   avg_salary|
+---------------+--------------+-------------+
|    Engineering|             5| 89800.000000|
|        Finance|             3| 92000.000000|
|      Marketing|             3| 75000.000000|
|          Sales|             2|107500.000000|
|Human Resources|             2| 66500.000000|
+---------------+--------------+-------------+


## 8. 执行复杂的联表查询

In [46]:
# Read the projects table and the employee-project link table, registering
# each as a temporary view
projects_df = read_with_spark_jdbc("testdb.projects")
projects_df.createOrReplaceTempView("projects_view")

employee_projects_df = read_with_spark_jdbc("testdb.employee_projects")
employee_projects_df.createOrReplaceTempView("employee_projects_view")

# Run the join query: for each project, its budget, owning department and
# number of distinct participating employees. It relies on departments_view
# having been registered by an earlier cell.
complex_query = """
SELECT 
    p.name as project_name, 
    p.budget as project_budget,
    d.name as department_name,
    COUNT(DISTINCT ep.employee_id) as employee_count
FROM projects_view p
JOIN departments_view d ON p.department_id = d.id
JOIN employee_projects_view ep ON p.id = ep.project_id
GROUP BY p.name, p.budget, d.name
ORDER BY p.budget DESC
"""

complex_result = spark.sql(complex_query)
print("项目、预算和参与员工数量:")
complex_result.show()

Reading table testdb.projects using Spark JDBC...
Reading table testdb.employee_projects using Spark JDBC...
项目、预算和参与员工数量:
+--------------------+--------------+---------------+--------------+
|        project_name|project_budget|department_name|employee_count|
+--------------------+--------------+---------------+--------------+
|     Cloud Migration|     500000.00|    Engineering|             4|
|Mobile App Develo...|     420000.00|    Engineering|             3|
|Financial System ...|     350000.00|        Finance|             3|
|Sales Analytics P...|     300000.00|          Sales|             2|
|    Website Redesign|     250000.00|      Marketing|             3|
|     Employee Portal|     180000.00|Human Resources|             2|
+--------------------+--------------+---------------+--------------+


## 9. 清空表数据

In [47]:
def truncate_table(table_name):
    """Remove every row from a MySQL table while keeping the table itself.

    Spark's JDBC source cannot run a bare ``TRUNCATE TABLE`` through the
    ``query`` option, so this writes an *empty* DataFrame in ``overwrite``
    mode with ``truncate=true``. The empty frame carries the target table's
    own schema, so nothing is inserted after the truncate and no particular
    column has to exist in the table.

    Args:
        table_name: Table to empty, passed as the ``dbtable`` option.

    Returns:
        True if the write completed, False if any step raised (the error is
        printed rather than propagated).
    """
    try:
        # Connection details come from the module-level connection_options.
        jdbc_url = connection_options["url"]
        user = connection_options["user"]
        password = connection_options["password"]

        # Make sure the URL names the database.
        database_name = "testdb"
        if database_name not in jdbc_url:
            if jdbc_url.endswith("/"):
                jdbc_url += database_name
            else:
                jdbc_url += "/" + database_name

        jdbc_options = {
            "url": jdbc_url,
            "dbtable": table_name,
            "user": user,
            "password": password,
        }

        # Look up the table's current schema and build a zero-row frame from it.
        target_schema = spark.read.format("jdbc").options(**jdbc_options).load().schema
        empty_df = spark.createDataFrame([], target_schema)

        # overwrite + truncate=true empties the existing table; with zero rows
        # to insert, the table is left empty.
        empty_df.write \
            .format("jdbc") \
            .options(**jdbc_options) \
            .option("truncate", "true") \
            .mode("overwrite") \
            .save()

        print(f"✅ 表 {table_name} 已成功清空")
        return True
    except Exception as e:
        print(f"❌ 清空表 {table_name} 失败: {e}")
        return False

# Warning: this step is irreversible
truncate_table("testdb.test_data")

✅ 表 testdb.test_data 已成功清空
True


## 10. 将结果写回到 MySQL 数据库

In [48]:
def write_to_mysql(dataframe, target_table, write_mode="overwrite"):
    """Write a Spark DataFrame to a MySQL table over JDBC.

    Connection details come from the module-level ``connection_options``.

    Args:
        dataframe: Spark DataFrame to write.
        target_table: Destination table. If it contains no ".", it is
            prefixed with the "testdb" database name.
        write_mode: Passed to ``DataFrameWriter.jdbc`` as ``mode``.

    Returns:
        True if the write completed, False if it raised (the error is
        printed rather than propagated), so callers can detect a failure.
    """
    jdbc_url = connection_options["url"]
    user = connection_options["user"]
    password = connection_options["password"]

    # Make sure the URL names the database.
    database_name = "testdb"
    if database_name not in jdbc_url:
        if jdbc_url.endswith("/"):
            jdbc_url += database_name
        else:
            jdbc_url += "/" + database_name

    connection_properties = {
        "user": user,
        "password": password,
        "driver": "com.mysql.jdbc.Driver"
    }

    # Qualify the table name with the database when the caller did not.
    if "." not in target_table:
        full_table_name = f"{database_name}.{target_table}"
    else:
        full_table_name = target_table

    print(f"Writing data to table {full_table_name}...")

    try:
        dataframe.write.jdbc(
            url=jdbc_url,
            table=full_table_name,
            mode=write_mode,
            properties=connection_properties
        )
    except Exception as e:
        print(f"❌ 写入数据失败: {e}")
        return False
    print(f"✅ 数据已成功写入表 {full_table_name}")
    return True

# Build a new DataFrame to use for the write test
test_data = [
    (1, "测试部门1", "测试地点1", 100000.00),
    (2, "测试部门2", "测试地点2", 200000.00),
    (3, "测试部门3", "测试地点3", 300000.00)
]

# Create the DataFrame; column names are given as a plain list
schema = ["id", "name", "location", "budget"]
test_df = spark.createDataFrame(test_data, schema)

# Show the data that would be written
print("要写入的测试数据:")
test_df.show()

# Write the data to a new table (left disabled; uncomment to run)
# write_to_mysql(test_df, "test_departments", "overwrite")

要写入的测试数据:
+---+---------+---------+--------+
| id|     name| location|  budget|
+---+---------+---------+--------+
|  1|测试部门1|测试地点1|100000.0|
|  2|测试部门2|测试地点2|200000.0|
|  3|测试部门3|测试地点3|300000.0|
+---+---------+---------+--------+


## 11. 使用 pandas 进行数据分析（适用于小数据集）

In [49]:
# Convert the Spark DataFrame to a pandas DataFrame for further analysis
# (intended for small datasets only)
employees_pd = employees_df.toPandas()

# Per-department mean, min and max salary computed with pandas
print("使用 pandas 计算部门薪资统计:")
salary_stats = employees_pd.groupby('department')['salary'].agg(['mean', 'min', 'max']).reset_index()
print(salary_stats)

使用 pandas 计算部门薪资统计:
        department      mean        min        max
0      Engineering   89800.0   85000.00   95000.00
1          Finance   92000.0   88000.00   95000.00
2  Human Resources   66500.0   65000.00   68000.00
3        Marketing   75000.0   72000.00   78000.00
4            Sales  107500.0  105000.00  110000.00


## 12. 清理资源

In [None]:
# Clear the cache, then drop every temporary view
spark.catalog.clearCache()  # clear cached data
temp_views = spark.catalog.listTables()
for view in temp_views:
    # listTables() entries are filtered on isTemporary so only temp views are dropped
    if view.isTemporary:
        spark.catalog.dropTempView(view.name)
        print(f"删除临时视图: {view.name}")

# Stop the Spark session
# spark.stop() # Do not stop the Spark session in a Glue notebook; the environment manages it

print("资源清理完成!")