In [20]:
from datetime import datetime, timedelta
import gc
import os
import json
import numpy as np
import pandas as pd
from pyspark.sql import Window
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import udf, col
from pyspark.sql.functions import avg, from_json, to_json
from pyspark.sql.functions import col, concat, lit, explode, array, struct
from pyspark.sql.types import StringType, ArrayType,IntegerType, FloatType,StructType,StructField,BooleanType, DateType
from pyspark.ml.feature import StringIndexer
from aibrain_common.utils.date_convert_utils import DateConvertUtils
from aibrain_common.component import tools
import uuid
from aibrain_common.data.dataset_builder import DatasetBuilder
from aibrain_common.utils import env_utils

from collections import defaultdict
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler


# %matplotlib notebook


import logging

logger = logging.getLogger(__name__)

import gc


# Shared Spark session for the whole notebook (Hive metastore enabled).
# NOTE(review): JVM-level settings (memory, cores, instances) only take effect when
# the session is first created; getOrCreate() reuses an existing session otherwise.
# NOTE(review): maxResultSize (50000m) is far above driver memory (10g), so large
# collect() calls would exhaust the driver before hitting this limit — confirm intent.
spark = (
    SparkSession.builder
    .config('spark.executor.memory', '12g')
    .config('spark.executor.cores', '6')
    .config('spark.driver.memory', '10g')
    .config('spark.executor.instances', '10')
    .config('spark.driver.maxResultSize', '50000m')
    .appName('ebiktrainfeature')
    .enableHiveSupport()
    .getOrCreate()
)


# strftime / strptime patterns keyed by time granularity.
time_str_formats = dict(
    hour="%Y%m%d%H",
    day="%Y%m%d",
)

# Torch device, pinned to CPU (GPU selection kept below for reference).
# device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

In [21]:
def datetime2str(date: datetime, rtype="hour"):
    """Format ``date`` with the ``time_str_formats`` pattern registered for ``rtype``.

    Raises:
        ValueError: if ``rtype`` is not a key of ``time_str_formats``.
    """
    if rtype in time_str_formats:
        pattern = time_str_formats[rtype]
        return date.strftime(pattern)
    raise ValueError("rtype Error!")


def str2datetime(s, itype="hour"):
    """Parse ``s`` using the ``time_str_formats`` pattern for ``itype``.

    An unknown ``itype`` raises ``KeyError`` (``datetime2str`` raises ``ValueError`` instead).
    """
    pattern = time_str_formats[itype]
    return datetime.strptime(s, pattern)
def add_delta(time_str: str, delta: dict, itype="day", rtype="day"):
    """Shift a formatted time string by a ``timedelta`` and format the result.

    Args:
        time_str: time string in the ``itype`` format.
        delta: keyword arguments for ``timedelta``, e.g. ``{'days': -1}``.
        itype: ``time_str_formats`` key used to parse ``time_str``.
        rtype: ``time_str_formats`` key used to format the result.
    """
    shifted = str2datetime(time_str, itype) + timedelta(**delta)
    return datetime2str(shifted, rtype)

In [22]:
today = '20251020'  # pinned during testing so predictions can be checked against known values

# Production: take the run date from the scheduler.
# today = DateConvertUtils().parse_data_date('${yyyymmdd}')

print("today is :%s" % today)

pt = today  # run date t
week = 6  # number of weeks of training data to pull
tomorrow = add_delta(today, {'days': 1}, "day", "day")  # prediction date (t+1)
yesterday = add_delta(today, {'days': -1}, "day", "day")  # previous day (t-1)

# Offset from end_date back to start_date: the inclusive window [start_date, end_date]
# spans exactly week * 7 days (42 days for week = 6).
day1 = - (week * 7 - 1)
end_date = today  # last day of the pulled data window (label dates stop 2 days earlier, see label_end_date)
start_date = add_delta(end_date, {'days': day1}, "day", "day")
# Start of the inclusive 14-day window used to select busy sites.
twoweek_ago_date = add_delta(end_date, {'days': - 13}, "day", "day")

# Extra history so the 11-day lag features are already populated on start_date.
start_date_minus11 = add_delta(start_date, {'days': -11}, "day", "day")
start_date_2 = add_delta(start_date, {'days': -2}, "day", "day")
start_date_add2 = add_delta(start_date, {'days': 2}, "day", "day")
end_date_add2 = add_delta(end_date, {'days': 2}, "day", "day")
end_date_2 = add_delta(end_date, {'days': -2}, "day", "day")


today is :20250928


In [23]:
# Hourly net-outflow labels, restricted to "busy" sites.
# - bike_start_order: order count per start parking spot per day (pt) over
#   [twoweek_ago_date, end_date], i.e. 14 days inclusive.
# - A site is kept if its max daily count >= 15 OR its mean daily count >= 10
#   (union of streets_max_orders and streets_mean_orders).
# - net_data: label_10 / label_16 / label_21 per (site, hour, pt), pulled from
#   start_date_minus11 so the 11-day lag features exist from start_date onward.
# NOTE(review): days without any order produce no row in bike_start_order, so the
#   mean is over active days only — confirm that is intended.
df_flow = spark.sql(f'''
with 
    bike_start_order as (
        select 
            bike_start_park_guid as parking_guid,
            count(order_id) as daily_order_cnt
        from dwd.dwd_trd_ord_ebik_order_ent_di
        where pt between '{twoweek_ago_date}' and '{end_date}'
        group by bike_start_park_guid,pt
    ),
    
    streets_max_orders as (
        select 
            parking_guid,
            max(daily_order_cnt) as max_daily_orders
        from bike_start_order
        group by parking_guid
        having max_daily_orders >= 15
    ),
        
    streets_mean_orders as (
        select 
            parking_guid,
            mean(daily_order_cnt) as mean_daily_orders
        from bike_start_order
        group by parking_guid
        having mean_daily_orders >= 10
    ),
    
    filtered_streets as (
        select parking_guid from streets_max_orders
        union
        select parking_guid from streets_mean_orders
    ),

    net_data as (
        select 
            site_guid, city_guid, label_10, label_16, label_21, hour, pt
        from turing.ebike_site_period_net_out_label_di
        where pt between '{start_date_minus11}' and '{end_date}'
    )

select net_data.* from net_data
where net_data.site_guid in (select parking_guid from filtered_streets)
and net_data.site_guid is not null

''')


In [24]:
@F.udf(StringType())
def generate_datetime_hour(date_str, hour_int):
    """Spark UDF: append the zero-padded hour to a date string ('yyyyMMdd' + 'HH')."""
    return "{}{:02d}".format(date_str, hour_int)

# Day-level lag features for each (site, hour) slot: lag{k}d_<h> is label_<h> exactly
# k calendar days earlier for the same site and hour, or null (zero-filled below) when
# that day has no row. A row-offset F.lag would silently reach further back across gaps.
lags = [1, 2, 3, 4, 11]

# dt = pt + zero-padded hour, e.g. pt=20250701, hour=12 -> dt=2025070112
long_df = (
    df_flow
    .withColumn("dt", F.concat(F.col("pt"), F.lpad(F.col("hour").cast("string"), 2, "0")))
    .withColumn("date", F.to_date(F.col("dt").cast("string"), "yyyyMMddHH"))
    # Integer day index so the window frame can be expressed in days.
    .withColumn("_day_idx", F.datediff(F.col("date"), F.lit("1970-01-01")))
)

window_spec = Window.partitionBy("site_guid", "hour").orderBy("_day_idx")

for k in lags:
    # Frame containing only rows exactly k days before the current row.
    lag_window = window_spec.rangeBetween(-k, -k)
    for label in ("10", "16", "21"):
        long_df = long_df.withColumn(f"lag{k}d_{label}", F.first(f"label_{label}").over(lag_window))

# Keep only the training window; the earlier days were needed only as lag sources.
long_df = (
    long_df
    .drop("_day_idx")
    .filter(F.col("pt").between(start_date, end_date))
    .fillna(0)
)
# long_df.cache()

In [25]:
# Hourly parking snapshot over [start_date, end_date], keeping on-the-hour rows (min == "00").
# dt2 = pt + hr is built to match long_df.dt ('yyyyMMddHH'); this assumes hr is a
# zero-padded two-character string — TODO confirm against dwb.dwb_veh_ebik_park_hf.
ebik_park_hf = spark.read.format("iceberg").load("dwb.dwb_veh_ebik_park_hf") \
    .filter(F.col("pt").between(start_date,end_date)&
        (F.col("min") == "00")
    ).select("parking_guid","put_veh_cnt","idle_12h_cnt","idle_1d_cnt","hr","pt"
    ).withColumn("dt2", F.concat(F.col("pt"), F.col("hr"))
    )

# Left-join the snapshot counts onto each (site, hour) row. Both sides carry `pt`, so
# shared columns are referenced through their DataFrame (long_df.pt) to stay unambiguous.
# NOTE(review): several snapshot rows per (parking_guid, pt, hr) would duplicate long_df rows.
long_df = long_df.join(ebik_park_hf, (long_df.site_guid==ebik_park_hf.parking_guid) & (long_df.dt==ebik_park_hf.dt2), 'left'
                ).select(long_df.site_guid,long_df.city_guid,"label_10","label_16","label_21","put_veh_cnt","idle_12h_cnt","idle_1d_cnt",
        "lag1d_10","lag1d_16","lag1d_21","lag2d_10","lag2d_16","lag2d_21",
        "lag3d_10","lag3d_16","lag3d_21","lag4d_10","lag4d_16","lag4d_21",
        "lag11d_10","lag11d_16","lag11d_21","hour",long_df.pt,"dt")

In [26]:
# Numeric feature columns in output order. The park-snapshot columns come from a left
# join and may be null, so every one of them is zero-filled before selection.
_hourly_feature_cols = [
    "label_10", "label_16", "label_21",
    "put_veh_cnt", "idle_12h_cnt", "idle_1d_cnt",
    "lag1d_10", "lag1d_16", "lag1d_21",
    "lag2d_10", "lag2d_16", "lag2d_21",
    "lag3d_10", "lag3d_16", "lag3d_21",
    "lag4d_10", "lag4d_16", "lag4d_21",
    "lag11d_10", "lag11d_16", "lag11d_21",
]
fill_dict = dict.fromkeys(_hourly_feature_cols, 0)

long_df = (
    long_df
    .fillna(fill_dict)
    .select('dt', 'site_guid', 'city_guid', *_hourly_feature_cols, 'hour')
    .sort('site_guid', 'dt')
)

# long_df.cache()

In [27]:
from pyspark.sql.functions import avg, from_json, to_json
def create_optimized_intermediate_table_v2(long_df, pred_date, seq_len, id_col="site_guid", time_col="dt", feature_cols=None):
    """Pack each site's most recent ``seq_len`` rows into one JSON feature sequence.

    For every (``id_col``, ``city_guid``) group, keeps the rows with the ``seq_len``
    largest ``time_col`` values, turns each row into a vector of ``feature_cols`` cast
    to double, and serialises the vectors as a JSON string ``[[f1, f2, ...], ...]``
    in ascending ``time_col`` order. Groups with fewer than ``seq_len`` rows are dropped.

    Args:
        long_df: Spark DataFrame with ``id_col``, ``time_col``, ``city_guid`` and features.
        pred_date: unused; kept for call-site compatibility.
        seq_len: required sequence length.
        id_col: entity key column.
        time_col: column whose sort order is chronological (e.g. 'yyyyMMddHH' strings).
        feature_cols: feature columns in vector order; defaults to every column except
            ``id_col``, ``time_col`` and ``city_guid``.

    Returns:
        DataFrame with columns ``id_col``, ``city_guid``, ``features_json``.
    """
    from pyspark.sql.functions import array, row_number, struct, to_json
    from pyspark.sql.window import Window

    if feature_cols is None:
        excluded = {id_col, time_col, 'city_guid'}
        feature_cols = [c for c in long_df.columns if c not in excluded]

    print(f"特征列: {feature_cols}")

    # Keep the latest seq_len rows per site.
    recent_first = Window.partitionBy(id_col).orderBy(F.desc(time_col))
    windowed_df = (
        long_df
        .withColumn("row_num", row_number().over(recent_first))
        .filter(F.col("row_num") <= seq_len)
        .withColumn("feature_vector", array(*[F.col(feat).cast("double") for feat in feature_cols]))
    )

    # collect_list gives no ordering guarantee after the groupBy shuffle, so collect
    # (time, vector) pairs and sort them explicitly before dropping the time key.
    grouped_df = (
        windowed_df
        .groupBy(id_col, "city_guid")
        .agg(
            F.sort_array(
                F.collect_list(struct(F.col(time_col).alias("t"), F.col("feature_vector").alias("v")))
            ).alias("steps"),
            F.count("*").alias("seq_count"),
        )
        .filter(F.col("seq_count") == seq_len)
    )

    # Store the sequence as a JSON string; steps.v extracts the vector from each struct.
    final_df = (
        grouped_df
        .withColumn("features_json", to_json(F.col("steps.v")))
        .select(id_col, "city_guid", "features_json")
    )
    return final_df
# Hourly sequence per site: the most recent week * 7 * 24 hourly rows (1008 for week = 6),
# matching the [start_date, end_date] window pulled above.
# create_optimized_intermediate_table_v2 does not use its pred_date argument.
final_df = create_optimized_intermediate_table_v2(long_df, today, week * 7 * 24, id_col="site_guid", time_col="dt", feature_cols=None)

特征列: ['label_10', 'label_16', 'label_21', 'put_veh_cnt', 'idle_12h_cnt', 'idle_1d_cnt', 'lag1d_10', 'lag1d_16', 'lag1d_21', 'lag2d_10', 'lag2d_16', 'lag2d_21', 'lag3d_10', 'lag3d_16', 'lag3d_21', 'lag4d_10', 'lag4d_16', 'lag4d_21', 'lag11d_10', 'lag11d_16', 'lag11d_21', 'hour']


In [59]:
final_df.show(n=20)

+--------------------+--------------------+--------------------+
|           site_guid|           city_guid|       features_json|
+--------------------+--------------------+--------------------+
|00a1cc65951c42f2a...|e272dd16cd5144168...|[[1.0,3.0,3.0,4.0...|
|0167b931b61742eaa...|d64cfb70450648c1b...|[[0.0,0.0,1.0,22....|
|0721f81075364b429...|f85204eed54348a4b...|[[5.0,5.0,7.0,7.0...|
|0a395e40l1l5ryfvw...|BEEF3892FC8C4F35B...|[[0.0,0.0,0.0,1.0...|
|0a831ec1l1lyewsuj...|85177E9C4BE74D748...|[[0.0,0.0,0.0,1.0...|
|0a83244bl1lv5tps8...|4e8c12b8428947c8a...|[[0.0,0.0,0.0,0.0...|
|0a8326bal1l358gku...|282D06BE2EAE4D859...|[[0.0,0.0,1.0,1.0...|
|0a8326bal1lb4apnh...|83903688AD394E9BB...|[[0.0,1.0,1.0,0.0...|
|0a8326bal1lb86dvf...|83903688AD394E9BB...|[[4.0,4.0,4.0,2.0...|
|0a8326bal1lcj3p9p...|83903688AD394E9BB...|[[7.0,8.0,8.0,7.0...|
|0a8326bal1lhtnvdx...|F495A777469444F08...|[[5.0,5.0,6.0,10....|
|0a832885l1lf23ifm...|B2DD2675A8D44EE49...|[[0.0,0.0,0.0,0.0...|
|0a832a0el1lwmdihi...|101

In [28]:
# Label dates.
# - Hourly data needs 14 days after start_date before the first complete feature set
#   exists; that set is used as t-2 to predict t+1.
# - Daily data uses t-1 to predict t+1, so the last label date is two days before the
#   last ground-truth date.
label_start_date = add_delta(start_date, {'days': 13 + 1}, "day", "day")
label_end_date = add_delta(end_date, {'days': -2}, "day", "day")
start_dt = datetime.strptime(label_start_date, '%Y%m%d')
end_dt = datetime.strptime(label_end_date, '%Y%m%d')

# Every calendar day in [start_dt, end_dt], inclusive.
date_range = [start_dt + timedelta(days=x) for x in range((end_dt - start_dt).days + 1)]

date_df = spark.createDataFrame(
    [(date.strftime('%Y%m%d'), date) for date in date_range],
    ["date_str", "date"]
)

# len(date_range) equals date_df.count() without launching a Spark job.
print(f"生成的日期范围大小: {len(date_range)}")

res_site = final_df.select("site_guid","city_guid")

# Cartesian product: one row per (site, label date).
res_time_site = res_site.crossJoin(date_df)
res_time_site.cache()

生成的日期范围大小: 26


DataFrame[site_guid: string, city_guid: string, date_str: string, date: timestamp]

In [20]:
# Spot-check the (site, date) grid for one site.
test1 = res_time_site.where(F.col("site_guid") == '1074588174745837568')
test1.show(n=50)

+-------------------+--------------------+--------+-------------------+
|          site_guid|           city_guid|date_str|               date|
+-------------------+--------------------+--------+-------------------+
|1074588174745837568|40E4554E4C14445FA...|20250901|2025-09-01 00:00:00|
|1074588174745837568|40E4554E4C14445FA...|20250902|2025-09-02 00:00:00|
|1074588174745837568|40E4554E4C14445FA...|20250903|2025-09-03 00:00:00|
|1074588174745837568|40E4554E4C14445FA...|20250904|2025-09-04 00:00:00|
|1074588174745837568|40E4554E4C14445FA...|20250905|2025-09-05 00:00:00|
|1074588174745837568|40E4554E4C14445FA...|20250906|2025-09-06 00:00:00|
|1074588174745837568|40E4554E4C14445FA...|20250907|2025-09-07 00:00:00|
|1074588174745837568|40E4554E4C14445FA...|20250908|2025-09-08 00:00:00|
|1074588174745837568|40E4554E4C14445FA...|20250909|2025-09-09 00:00:00|
|1074588174745837568|40E4554E4C14445FA...|20250910|2025-09-10 00:00:00|
|1074588174745837568|40E4554E4C14445FA...|20250911|2025-09-11 00

In [2]:
# 用全量数据得出city encode
# Fit the city_guid -> integer index encoding on the distinct cities of active fences
# (area_status = 5, area_type = 207) and save it as CSV for reuse by later cells.
# NOTE(review): the pt window is hard-coded rather than derived from `today`;
# presumably this keeps re-runs producing the same mapping file that trained models
# depend on — confirm before changing it.
city_guid_encode = spark.sql('''
    select distinct city_guid from 
        dim.dim_spt_fence_info
    where pt between 20251010 and 20251020
    and area_status = 5        
    and area_type = 207
''')
indexer = StringIndexer(inputCol="city_guid", outputCol="city_guid_encoded")
indexer_model = indexer.fit(city_guid_encode)

city_index_sdf = indexer_model.transform(city_guid_encode.select("city_guid").distinct())
idxmap = city_index_sdf.orderBy("city_guid_encoded").toPandas()
idxmap.to_csv("ebik_dataset/city_guid_mapping_3.csv", index=False)


In [29]:
# Per-park daily order stats, from 10 days before the first label date to the last one.
window_start_dt = start_dt - timedelta(days=10)
daily_orders = (
    spark.table("dwd.dwd_trd_ord_ebik_order_ent_di")
    .filter(F.col("pt").between(window_start_dt.strftime('%Y%m%d'), end_dt.strftime('%Y%m%d')))
    .groupBy(F.col("bike_start_park_guid").alias("parking_guid"), "pt")
    .agg(
        F.count("order_id").alias("daily_order_cnt"),
        F.max("start_parking_capacity").alias("parking_capacity"),
    )
    # Integer day index so the window frame can be expressed in calendar days.
    .withColumn("_day_idx", F.datediff(F.to_date(F.col("pt"), "yyyyMMdd"), F.lit("1970-01-01")))
)

# lag7d_order_cnt = order count 5 calendar days before date_str, i.e. 7 days before the
# prediction date (date_str + 2): the same weekday one week earlier.
# daily_orders has no row for days without orders, so a row offset (F.lag(..., 5))
# would reach past quiet days; the range frame yields null there (coalesced to 0 below).
window_spec = Window.partitionBy("parking_guid").orderBy("_day_idx").rangeBetween(-5, -5)
daily_orders = daily_orders.withColumn("lag7d_order_cnt", F.first("daily_order_cnt").over(window_spec))

# City mapping; cities missing from it are encoded as max_citys (one past the last index).
mapping_df = pd.read_csv("ebik_dataset/city_guid_mapping_3.csv")
mapping_df2 = spark.createDataFrame(mapping_df)
max_citys = len(mapping_df)

static_feature_df = res_time_site.join(
    daily_orders.select(
        F.col("parking_guid"),
        F.col("pt"),
        F.col("lag7d_order_cnt"),
        F.col("parking_capacity")
    ),
    (F.col("site_guid") == F.col("parking_guid")) &
    (F.col("pt") == F.col("date_str")),
    "left"
).join(
    mapping_df2.select("city_guid", "city_guid_encoded"),
    on="city_guid",
    how="left"
).select(
    F.col("site_guid"),
    "city_guid",
    "date_str",
    F.date_format(F.date_add(F.col("date"), 2), "yyyyMMdd").alias("date_pred"),
    F.coalesce("city_guid_encoded", F.lit(max_citys)).alias("city_guid_encoded"),
    F.dayofweek(F.date_add(F.col("date"), 2)).alias("day_of_week"),
    F.coalesce("lag7d_order_cnt", F.lit(0)).alias("lag7d_order_cnt"),
    F.coalesce("parking_capacity", F.lit(0)).alias("parking_capacity")
)




# 查看编码后的数据
# df_encoded.select("city_guid", "city_guid_encoded").show()
# static_feature_df.show(50)

In [13]:
# Reload the saved city mapping and show how many cities it contains.
mapping_df = pd.read_csv("ebik_dataset/city_guid_mapping_3.csv")
mapping_df2 = spark.createDataFrame(mapping_df)
mapping_df.shape[0]

1023

In [41]:
# 用 pandas 读取 CSV，跳过可能的错误行
# Sanity-check the older mapping file. On a parse error, print only the lines around
# row 105 instead of dumping everything from line 100 to the end of the file.
try:
    df_test = pd.read_csv("ebik_dataset/city_guid_mapping.csv")
    print(df_test.head())
except Exception as e:
    print("Error:", e)
    with open("ebik_dataset/city_guid_mapping.csv", "r") as f:
        for i, line in enumerate(f):
            if i > 110:
                break
            if i >= 100:
                print(f"Line {i}: {line.strip()}")

                          city_guid  city_guid_encoded
0  B2DD2675A8D44EE49224906B90E8EAF9                0.0
1  AA99442B62E7485086F77A1DBD4FDF65                1.0
2  D9DBDF2F159143778C7748C47B262BB4                2.0
3  BEEF3892FC8C4F35BC8B7F832FA311D7                3.0
4  A2A7BDAF3532442CA33B8A5F4BA2F6C5                4.0


In [30]:
@F.udf(StringType())
def generate_datetime_hour_minus2(date_str, hour_int):
    """Spark UDF: 'yyyyMMdd' date shifted back two days, followed by the zero-padded hour."""
    shifted = datetime.strptime(date_str, '%Y%m%d') - timedelta(days=2)
    return shifted.strftime('%Y%m%d') + format(hour_int, '02d')

# static_feature_df.date_pred = date + 2 days with date in [start_dt, end_dt], so the
# weather/workday rows must cover prediction dates [start_dt + 2, end_dt + 2]. A +1
# window here left the last prediction date without weather, silently falling back to
# the defaults below.
start_dt1 = (start_dt + timedelta(days=2)).date()
end_dt1 = (end_dt + timedelta(days=2)).date()

# One weather/workday row per (city_guid, pred_date).
# NOTE(review): F.first over several forecast_date rows is non-deterministic; if more
# than one forecast exists per pred_date, choose one explicitly.
wtw = spark.table('turing_dev.turing_net_pred_wea_temp_wkd_feature') \
    .select('city_guid','forecast_date','pred_date','cycle_weather_level','temperature_avg_val','workday_level') \
    .filter(F.col('pred_date').between(start_dt1.strftime('%Y%m%d'), end_dt1.strftime('%Y%m%d'))) \
    .groupBy('city_guid', 'pred_date') \
    .agg(
        F.first('cycle_weather_level').alias('cycle_weather_level'),
        F.first('temperature_avg_val').alias('temperature_avg_val'),
        F.first('workday_level').alias('workday_level')
    )

# Attach weather by (prediction date, city); missing values default to
# temperature 20, workday_level 1, cycle_weather_level 2. The output `pt` is the
# prediction date. Sorting on pt (= date_str + 2) gives the same order as date_str.
static_feature_df = static_feature_df.join(wtw, 
                       (static_feature_df.date_pred==wtw.pred_date) &
                       (static_feature_df.city_guid==wtw.city_guid), 'left'
        ).dropna(subset = ['site_guid']
        ).select(static_feature_df.site_guid, static_feature_df.city_guid, F.col('date_pred').alias('pt'),
                static_feature_df.lag7d_order_cnt, static_feature_df.parking_capacity,
                F.coalesce(wtw.temperature_avg_val,F.lit(20)).alias("temperature_avg_val"), 
                static_feature_df.city_guid_encoded, static_feature_df.day_of_week,
                F.coalesce(wtw.workday_level,F.lit(1)).alias("workday_level"),
                F.coalesce(wtw.cycle_weather_level,F.lit(2)).alias("cycle_weather_level")
        ).sort('site_guid', 'pt')



In [59]:
static_feature_df.show(n=50)

+--------------------+--------------------+--------+---------------+----------------+-------------------+-----------------+-----------+-------------+-------------------+
|           site_guid|           city_guid|      pt|lag7d_order_cnt|parking_capacity|temperature_avg_val|city_guid_encoded|day_of_week|workday_level|cycle_weather_level|
+--------------------+--------------------+--------+---------------+----------------+-------------------+-----------------+-----------+-------------+-------------------+
|0002139af19b44849...|39CA4DFE94E647909...|20250901|             11|              20|            27.9800|             21.0|          4|            1|                  3|
|0002139af19b44849...|39CA4DFE94E647909...|20250902|              3|             144|            28.7700|             21.0|          5|            1|                  3|
|0002139af19b44849...|39CA4DFE94E647909...|20250903|              5|             144|            29.7400|             21.0|          6|            1| 

In [31]:
# Daily (static) sequence per site: one row per label date, so the required length is
# the number of dates in date_range (26 with the current window settings).
daily_seq_len = len(date_range)
final_df_daily = create_optimized_intermediate_table_v2(static_feature_df, '20250928', daily_seq_len, id_col="site_guid", time_col="pt", feature_cols=None)

# NOTE(review): final_df is rebound to the joined frame, so re-running this cell joins
# the daily features again; re-run the cells that build final_df first.
final_df = final_df.alias("df1").join(final_df_daily.alias("df2"), 
                                     "site_guid", 
                                     'left'
                                    ).select(final_df.site_guid, 
                                           final_df.city_guid,
                                           final_df.features_json, 
                                           final_df_daily.features_json.alias("daily_features_json"))
final_df.createOrReplaceTempView("final_feature_df")
output_table_name = "turing_dev.turing_ebike_fixtime_train_features_df_2"
pred_date = 20251001
# INSERT maps columns by position; presumably the table columns are
# (site_guid, city_guid, hourly_features, static_features) — confirm against the DDL.
spark.sql(f"""
    INSERT OVERWRITE TABLE {output_table_name} 
    PARTITION(pt='{pred_date}') 
    SELECT site_guid, city_guid, features_json, daily_features_json
    FROM final_feature_df
""")

print(f"优化的中间表已保存: {output_table_name}")

特征列: ['lag7d_order_cnt', 'parking_capacity', 'temperature_avg_val', 'city_guid_encoded', 'day_of_week', 'workday_level', 'cycle_weather_level']
优化的中间表已保存: turing_dev.turing_ebike_fixtime_train_features_df_2


In [None]:
static_feature_df.show(n=50)

In [32]:
from tqdm import tqdm
import os
def batch_predict_from_spark_table_optimized(table_name, pred_date, output_dir, batch_size=2500, 
                                           hourly_seq_len=336, static_seq_len=26, hourly_feature_count=11, static_feature_count=5):
    """Export one partition of a feature table to per-batch .npy files.

    Sites are numbered by ``site_guid`` and processed ``batch_size`` at a time. For batch
    ``i`` the following files are written to ``output_dir`` (created if missing):

    - ``batch_{i}_hourly_features.npy``: (n, hourly_seq_len, hourly_feature_count)
    - ``batch_{i}_static_features.npy``: (n, static_seq_len, static_feature_count)
    - ``batch_{i}_guids.txt``: the n site_guids, one per line

    Only sites whose hourly AND static sequences both decode with the expected shape
    are written. Row j of both arrays and line j of the guid file therefore refer to the
    same site.

    Args:
        table_name: table with site_guid, hourly_features, static_features and partition pt.
        pred_date: partition value (pt) to export.
        output_dir: destination directory.
        batch_size: number of sites per batch.
        hourly_seq_len / hourly_feature_count: expected hourly sequence shape.
        static_seq_len / static_feature_count: expected static sequence shape.

    Raises:
        ValueError: if hourly and static rows cannot be aligned (duplicate site_guid).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Number rows once and cache, so each batch filter reuses the same numbering
    # instead of recomputing a global window over the whole table.
    df = spark.sql(f"""
        SELECT *, 
               ROW_NUMBER() OVER (ORDER BY site_guid) as row_id
        FROM {table_name}
        WHERE pt = '{pred_date}'
    """).cache()

    try:
        total_count = df.count()
        print(f"总共需要处理 {total_count} 个点位")

        num_batches = (total_count + batch_size - 1) // batch_size

        for batch_idx in range(num_batches):
            start_row = batch_idx * batch_size + 1  # row_number starts at 1
            end_row = min((batch_idx + 1) * batch_size, total_count)

            print(f"处理批次 {batch_idx + 1}/{num_batches}, 行范围: {start_row}-{end_row}")

            # A single collect, so the hourly and static features come from the same rows.
            batch_rows = (
                df.filter(F.col('row_id').between(start_row, end_row))
                .select('site_guid', 'hourly_features', 'static_features')
                .collect()
            )
            if not batch_rows:
                continue

            hourly_batch_features, hourly_guids = prepare_batch_features(
                batch_rows, hourly_seq_len, hourly_feature_count, 'hourly_features'
            )
            static_batch_features, static_guids = prepare_batch_features(
                batch_rows, static_seq_len, static_feature_count, 'static_features'
            )

            hourly_batch_features, static_batch_features, site_guids = _align_features_by_site(
                hourly_batch_features, hourly_guids, static_batch_features, static_guids
            )
            skipped = len(batch_rows) - len(site_guids)
            if skipped:
                print(f"batch {batch_idx}: skipped {skipped} site(s) with missing or malformed features")
            if not site_guids:
                continue

            np.save(os.path.join(output_dir, f'batch_{batch_idx}_hourly_features.npy'), hourly_batch_features)
            np.save(os.path.join(output_dir, f'batch_{batch_idx}_static_features.npy'), static_batch_features)
            with open(os.path.join(output_dir, f'batch_{batch_idx}_guids.txt'), 'w') as f:
                for guid in site_guids:
                    f.write(f"{guid}\n")

            print(f"已保存时间特征，批次 {batch_idx}: {hourly_batch_features.shape}")
            print(f"已保存静态特征，批次 {batch_idx}: {static_batch_features.shape}")
            del hourly_batch_features, static_batch_features, site_guids
    finally:
        df.unpersist()

    return


def _align_features_by_site(hourly_features, hourly_guids, static_features, static_guids):
    """Keep only sites decoded in both arrays; return (hourly, static, guids) row-aligned."""
    common = set(hourly_guids) & set(static_guids)
    hourly_idx = [i for i, guid in enumerate(hourly_guids) if guid in common]
    static_idx = [i for i, guid in enumerate(static_guids) if guid in common]
    guids = [hourly_guids[i] for i in hourly_idx]
    if guids != [static_guids[i] for i in static_idx]:
        # Both lists follow the same collected row order, so this only happens
        # when a site_guid appears more than once in the batch.
        raise ValueError("hourly/static site_guid order mismatch; is site_guid unique in the table?")
    return hourly_features[hourly_idx], static_features[static_idx], guids

def prepare_batch_features(batch_data, seq_len, feature_count, col_name):
    """Decode JSON-encoded feature sequences into one float32 array.

    Each row must support ``row['site_guid']`` and ``row[col_name]``. The latter is a
    JSON string (or an already-decoded list) of ``seq_len`` time steps, each a list of
    ``feature_count`` numbers. Rows that are null, fail to parse, or have the wrong
    shape are reported and skipped.

    Returns:
        (batch_features, site_guids): ``batch_features`` is float32 with shape
        (n_valid, seq_len, feature_count), or (0, seq_len, feature_count) when no row
        is valid. ``site_guids[i]`` is the site of ``batch_features[i]``.
    """
    import traceback

    batch_features = []
    site_guids = []

    for row in batch_data:
        site_guid = row['site_guid']
        features_json = row[col_name]

        if features_json is None:
            print(f"✗ 缺少特征数据: {site_guid}")
            continue

        try:
            if isinstance(features_json, str):
                features_list = json.loads(features_json)
            else:
                features_list = features_json
                print("直接使用原始数据")

            if not features_list or len(features_list) != seq_len:
                print(f"✗ 序列长度不匹配: {len(features_list) if features_list else 0} vs {seq_len}")
                continue

            feature_matrix = []
            for i, time_step in enumerate(features_list):
                if not (isinstance(time_step, (list, tuple)) and len(time_step) == feature_count):
                    print(f"时间步 {i} 格式错误: {type(time_step)}, 长度: {len(time_step) if hasattr(time_step, '__len__') else 'N/A'}")
                    break
                feature_matrix.append([float(x) for x in time_step])
            else:
                # Runs only when every time step was valid (no break above).
                feature_array = np.array(feature_matrix, dtype=np.float32)
                if feature_array.shape == (seq_len, feature_count):
                    batch_features.append(feature_array)
                    site_guids.append(site_guid)
                else:
                    print(f"✗ 维度不匹配: {feature_array.shape} vs ({seq_len}, {feature_count})")

        except Exception:
            # Malformed JSON or non-numeric values: report and skip this site.
            traceback.print_exc()
            continue

    # Stack to (batch_size, seq_len, feature_count); keep float32 even when empty so
    # callers always see the same dtype.
    if batch_features:
        batch_features = np.stack(batch_features, axis=0)
        print(f"最终批次形状: {batch_features.shape}")
    else:
        batch_features = np.empty((0, seq_len, feature_count), dtype=np.float32)

    return batch_features, site_guids

In [33]:
table_name = "turing_dev.turing_ebike_fixtime_train_features_df_2"
output_dir = "ebik_dataset/traindata1"
pred_date = "20251001"

# Expected shapes: 42 days x 24 hours with 22 hourly features, and 26 days with
# 7 daily features (the feature lists printed when the sequences were built).
batch_predict_from_spark_table_optimized(
    table_name, pred_date, output_dir,
    batch_size=25000,  # tune to available driver memory
    hourly_seq_len=42 * 24,
    static_seq_len=26,
    hourly_feature_count=22,
    static_feature_count=7,
)

总共需要处理 113468 个点位
处理批次 1/5, 行范围: 1-25000
最终批次形状: (25000, 1008, 22)
最终批次形状: (25000, 26, 7)
已保存时间特征，批次 0: (25000, 1008, 22)
已保存静态特征，批次 0: (25000, 26, 7)
处理批次 2/5, 行范围: 25001-50000
最终批次形状: (25000, 1008, 22)
最终批次形状: (25000, 26, 7)
已保存时间特征，批次 1: (25000, 1008, 22)
已保存静态特征，批次 1: (25000, 26, 7)
处理批次 3/5, 行范围: 50001-75000
最终批次形状: (25000, 1008, 22)
最终批次形状: (25000, 26, 7)
已保存时间特征，批次 2: (25000, 1008, 22)
已保存静态特征，批次 2: (25000, 26, 7)
处理批次 4/5, 行范围: 75001-100000
最终批次形状: (25000, 1008, 22)
最终批次形状: (25000, 26, 7)
已保存时间特征，批次 3: (25000, 1008, 22)
已保存静态特征，批次 3: (25000, 26, 7)
处理批次 5/5, 行范围: 100001-113468
最终批次形状: (13468, 1008, 22)
最终批次形状: (13468, 26, 7)
已保存时间特征，批次 4: (13468, 1008, 22)
已保存静态特征，批次 4: (13468, 26, 7)


In [2]:
# DDL for the feature table: site/city keys, JSON hourly and static sequences, partitioned by pt.
# IF NOT EXISTS keeps the cell re-runnable once the table exists.
# NOTE(review): this creates turing_ebike_site_fixtime_predict_features_df_2, while the
# cells above write to and read from turing_ebike_fixtime_train_features_df_2 —
# confirm which table is intended.
spark.sql("""
CREATE TABLE IF NOT EXISTS turing_dev.turing_ebike_site_fixtime_predict_features_df_2(
  site_guid string COMMENT '点位id',
  city_guid string COMMENT '城市guid',
  hourly_features string COMMENT '时序特征',
  static_features string COMMENT '静态特征'
  )
COMMENT '助力车点位定终点模型训练特征数据表2'
PARTITIONED BY ( 
  pt string COMMENT '预测的时间')
""")

DataFrame[]

In [61]:
# Force a garbage-collection pass in the driver process; the cell shows the number of
# unreachable objects found.
gc.collect()

101816