In [1]:
import pandas as pd
import pymysql
import numpy as np

# --------------- 配置区，根据自己环境修改 ---------------
MYSQL_CONNECTION_INFO = {
    "host": "localhost",
    "user": "root",
    "password": "",
    "database": "dsci551_project"  # 修改为你的数据库名
}
csv_file = "train.csv"  # CSV文件的路径或名称
# ------------------------------------------------------

In [2]:
def create_tables(cursor):
    """
    在数据库中创建3张表：customers, products, orders（如果不存在）。
    这里自行决定实体及其属性，处理好主键和外键。
    """

    # 1. Customers 表
    create_customers = """
    CREATE TABLE IF NOT EXISTS customers (
        customer_id VARCHAR(50) PRIMARY KEY,
        customer_name VARCHAR(100),
        segment VARCHAR(50),
        city VARCHAR(100),
        state VARCHAR(100),
        region VARCHAR(50),
        country VARCHAR(50),
        postal_code VARCHAR(20)
    )
    """

    # 2. Products 表
    create_products = """
    CREATE TABLE IF NOT EXISTS products (
        product_id VARCHAR(50) PRIMARY KEY,
        product_name VARCHAR(200),
        category VARCHAR(50),
        sub_category VARCHAR(50)
    )
    """

    # 3. Orders 表
    # 注意：Row ID 这里用作主键，通常每行对应一个“订单-产品”明细
    create_orders = """
    CREATE TABLE IF NOT EXISTS orders (
        row_id INT PRIMARY KEY,
        order_id VARCHAR(50),
        order_date DATE,
        ship_date DATE,
        ship_mode VARCHAR(50),
        sales FLOAT,
        customer_id VARCHAR(50),
        product_id VARCHAR(50),
        FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
        FOREIGN KEY (product_id) REFERENCES products(product_id)
    )
    """

    cursor.execute(create_customers)
    cursor.execute(create_products)
    cursor.execute(create_orders)


def insert_customers(cursor, df_customers):
    """
    向 customers 表插入数据。
    df_customers：包含 customers 去重后的 DataFrame
    """
    insert_query = """
    INSERT INTO customers (
        customer_id, customer_name, segment, city, state, region, country, postal_code
    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
    """

    for _, row in df_customers.iterrows():
        customer_id = row["Customer ID"]
        customer_name = row["Customer Name"]
        segment = row["Segment"]
        city = row["City"]
        state = row["State"]
        region = row["Region"]
        country = row["Country"]
        # postal_code有可能缺失或是float，这里转成字符串处理
        postal_code = str(int(row["Postal Code"])) if not pd.isnull(row["Postal Code"]) else None

        cursor.execute(insert_query, (
            customer_id, 
            customer_name, 
            segment, 
            city, 
            state, 
            region, 
            country, 
            postal_code
        ))


def insert_products(cursor, df_products):
    """
    向 products 表插入数据。
    df_products：包含 products 去重后的 DataFrame
    """
    insert_query = """
    INSERT INTO products (
        product_id, product_name, category, sub_category
    ) VALUES (%s, %s, %s, %s)
    """

    for _, row in df_products.iterrows():
        product_id = row["Product ID"]
        product_name = row["Product Name"]
        category = row["Category"]
        sub_category = row["Sub-Category"]

        cursor.execute(insert_query, (
            product_id, 
            product_name, 
            category, 
            sub_category
        ))


def insert_orders(cursor, df_orders):
    """
    向 orders 表插入数据。
    df_orders：原表所有行，对应每个订单行。
    """
    insert_query = """
    INSERT INTO orders (
        row_id, order_id, order_date, ship_date, ship_mode, sales, customer_id, product_id
    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
    """

    for _, row in df_orders.iterrows():
        row_id = row["Row ID"]
        order_id = row["Order ID"]

        # 处理日期列
        order_date = pd.to_datetime(row["Order Date"]).date() if not pd.isnull(row["Order Date"]) else None
        ship_date = pd.to_datetime(row["Ship Date"]).date() if not pd.isnull(row["Ship Date"]) else None

        ship_mode = row["Ship Mode"]
        sales = float(row["Sales"]) if not pd.isnull(row["Sales"]) else 0.0
        customer_id = row["Customer ID"]
        product_id = row["Product ID"]

        cursor.execute(insert_query, (
            row_id,
            order_id,
            order_date,
            ship_date,
            ship_mode,
            sales,
            customer_id,
            product_id
        ))



In [6]:

# ------------ 以下在一个Jupyter cell中执行拆分和插入逻辑 ------------
try:
    # A. 连接 MySQL
    connection = pymysql.connect(**MYSQL_CONNECTION_INFO)
    cursor = connection.cursor()

    # B. 创建表结构
    create_tables(cursor)

    # C. 读取 CSV
    df = pd.read_csv(csv_file)

    # D. 拆分
    df_customers = (
        df[[
            "Customer ID",
            "Customer Name",
            "Segment",
            "City",
            "State",
            "Region",
            "Country",
            "Postal Code"
        ]]
        .drop_duplicates(subset=["Customer ID"])
        .reset_index(drop=True)
    )

    df_products = (
        df[[
            "Product ID",
            "Product Name",
            "Category",
            "Sub-Category"
        ]]
        .drop_duplicates(subset=["Product ID"])
        .reset_index(drop=True)
    )

    df_orders = df.reset_index(drop=True)

    # E. 插入数据
    insert_customers(cursor, df_customers)
    insert_products(cursor, df_products)
    insert_orders(cursor, df_orders)

    # F. 提交事务
    connection.commit()
    print("已成功拆分并插入数据到多张表。")

except Exception as e:
    print("发生错误:", e)
    if connection:
        connection.rollback()

finally:
    if 'cursor' in locals():
        cursor.close()
    if 'connection' in locals():
        connection.close()


  ship_date = pd.to_datetime(row["Ship Date"]).date() if not pd.isnull(row["Ship Date"]) else None
  order_date = pd.to_datetime(row["Order Date"]).date() if not pd.isnull(row["Order Date"]) else None


已成功拆分并插入数据到多张表。
