In [5]:
import os
import json
import pandas as pd
from glob import glob

# Paths
input_dir = "01_raw_data/01_species_details"
output_path = "02_wrangled_data/Table02_GeneralPlantDescriptionTable.csv"

In [6]:
# Cycle label mapping
cycle_map = {
    "Perennial": "Every year",
    "Annual": "Once a year",
    "Biennial": "Every 2 years"
}

# Flatten general plant record
def flatten_general_description(data):
    plant_id = data.get("id")
    if plant_id > 3000:
        return None  # Skip threatened plants

    return {
        "general_plant_id": plant_id,
        "if_edible": data.get("edible_fruit", False) or data.get("edible_leaf", False) or data.get("cuisine", False),
        "if_indoors": data.get("indoor", False),
        "if_medicinal": data.get("medicinal", False),
        "if_poisonous": data.get("poisonous_to_humans", False) or data.get("poisonous_to_pets", False),
        "if_fruits": data.get("fruits", False),
        "if_flowers": data.get("flowers", False),
        "plant_type": data.get("type"),
        "plant_cycle": cycle_map.get(data.get("cycle"), data.get("cycle")),
        "attracts": json.dumps(data.get("attracts", []), ensure_ascii=False), # Change a little bit to fit MySQL Workbench
        "propagation": json.dumps(data.get("propagation", []), ensure_ascii=False),  # Change a little bit to fit MySQL Workbench
        "description": data.get("description")
    }

In [7]:
# Load and flatten JSON file
json_files = glob(os.path.join(input_dir, "plant_species_details_*.json"))
flattened_data = []

for file in json_files:
    with open(file, "r", encoding="utf-8") as f:
        data = json.load(f)
        record = flatten_general_description(data)
        if record:
            flattened_data.append(record)

In [8]:
# Create DataFrame and sort
df = pd.DataFrame(flattened_data)
df = df.sort_values(by="general_plant_id").reset_index(drop=True)

df["general_plant_id"] = pd.to_numeric(df["general_plant_id"], errors="coerce").astype("Int64")

ordered_cols = [
    "general_plant_id", "if_edible", "if_indoors", "if_medicinal", "if_poisonous",
    "if_fruits", "if_flowers", "plant_type", "plant_cycle",
    "attracts", "propagation", "description"
]
df = df[ordered_cols]

os.makedirs(os.path.dirname(output_path), exist_ok=True)
df.to_csv(output_path, index=False)

## 执行 LOAD DATA LOCAL INFILE 命令上传到数据库

In [9]:
import mysql.connector
from mysql.connector import Error

# 数据库连接配置
db_config = {
    'host': 'database-plantx.cqz06uycysiz.us-east-1.rds.amazonaws.com',
    'user': 'zihan',
    'password': '2002317Yzh12138.',
    'database': 'FIT5120_PlantX_Database',
    'allow_local_infile': True,
    'use_pure': True  # 使用纯Python实现
}

try:
    # 建立连接
    connection = mysql.connector.connect(**db_config)
    
    if connection.is_connected():
        print("成功连接到 MySQL 服务器")
        
        # 创建游标
        cursor = connection.cursor()
        
        # 构建 LOAD DATA LOCAL INFILE 命令
        # 注意：请将下面的文件路径替换为你实际的CSV文件路径
        load_data_query = """
        LOAD DATA LOCAL INFILE '02_wrangled_data/Table02_GeneralPlantDescriptionTable.csv'
        INTO TABLE Table02_GeneralPlantDescriptionTable
        CHARACTER SET utf8mb4
        FIELDS TERMINATED BY ',' 
        OPTIONALLY ENCLOSED BY '"'
        LINES TERMINATED BY '\\r\\n'
        IGNORE 1 LINES
        (   
            general_plant_id, if_edible, if_indoors, if_medicinal, if_poisonous,
            if_fruits, if_flowers, plant_type, plant_cycle, attracts, propagation,
            description
        );
        """
        
        # 执行命令
        cursor.execute(load_data_query)
        connection.commit()  # 提交事务
        
        print(f"数据导入成功！影响了 {cursor.rowcount} 行。")
        
except Error as e:
    print(f"执行过程中发生错误：{e}")
    
finally:
    # 关闭连接
    if connection.is_connected():
        cursor.close()
        connection.close()
        print("MySQL 连接已关闭。")

成功连接到 MySQL 服务器
数据导入成功！影响了 485 行。
MySQL 连接已关闭。


## 验证导入结果

In [11]:
# 在同一个连接会话中，或者在新的连接中执行
try:
    connection = mysql.connector.connect(**db_config)
    cursor = connection.cursor()
    
    cursor.execute("SELECT COUNT(*) FROM Table02_GeneralPlantDescriptionTable")
    row_count = cursor.fetchone()[0]
    print(f"表中现有 {row_count} 行数据")
    
    # 查看前几行数据
    cursor.execute("SELECT * FROM Table01_PlantMainTable LIMIT 5")
    rows = cursor.fetchall()
    for row in rows:
        print(row)
        
except Error as e:
    print(f"查询过程中发生错误：{e}")
finally:
    if connection.is_connected():
        cursor.close()
        connection.close()

表中现有 494 行数据
(1, 1, 0, 'European Silver Fir', 'Abies alba', 'Common Silver Fir', 'False', 'False', 'False', 'True', 'False', 'False', 'False', '["full sun"]', 'Frequent', 'Every year', 'High')
(2, 2, 0, 'Pyramidalis Silver Fir', "Abies alba 'Pyramidalis'", '', 'False', 'False', 'False', 'False', 'False', 'False', 'False', '["full sun"]', 'Average', 'Every year', 'Low')
(3, 3, 0, 'White Fir', 'Abies concolor', 'Silver Fir', 'False', 'False', 'False', 'True', 'False', 'False', 'True', '["Full sun", "part shade"]', 'Average', 'Every year', 'Low')
(4, 4, 0, 'Candicans White Fir', "Abies concolor 'Candicans'", 'Silver Fir', 'False', 'False', 'False', 'False', 'False', 'False', 'False', '["full sun"]', 'Average', 'Every year', 'Low')
(5, 5, 0, 'Fraser Fir', 'Abies fraseri', 'Southern Fir', 'False', 'False', 'False', 'False', 'False', 'False', 'True', '["full sun", "part shade", "filtered shade"]', 'Frequent', 'Every year', 'Moderate')
