In [None]:
import os
import tempfile

# Must be set BEFORE NumPy / Numba / PyTensor / PyMC are imported: a duplicate
# OpenMP runtime clash is detected while those libraries load, so setting the
# variable after the imports (as previously done) cannot suppress it.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import arviz as az
import numba
import pytensor
import pytensor.tensor as pt

# Compile PyTensor graphs with the Numba backend. This previously read
# `pytensor.config.mode == 'NUMBA'`, a comparison whose result was discarded,
# so the mode was never actually changed.
pytensor.config.mode = 'NUMBA'

import pymc as pm  # PyMC 5

In [1]:
import pandas as pd
import numpy as np

def _parse_vehicle_names(header_df, n_vehicles=5):
    """Read ``n_vehicles`` vehicle names from row 1, columns 1..n_vehicles.

    Each name is stripped and internal whitespace runs are collapsed to a
    single space. A missing, blank or out-of-range cell is replaced by the
    placeholder ``'Vehicle k'`` *at its own position*. The old behaviour
    skipped such cells, which shifted every later name onto the wrong vehicle.
    """
    n_rows, n_cols = header_df.shape
    names = []
    for col_idx in range(1, n_vehicles + 1):
        raw = header_df.iloc[1, col_idx] if (n_rows > 1 and col_idx < n_cols) else None
        # str.split() with no argument also strips leading/trailing whitespace.
        name = ' '.join(str(raw).split()) if pd.notna(raw) else ''
        names.append(name or f'Vehicle {col_idx}')
    return names


def load_data(file_path, encoding='gbk'):
    """Load a platoon CSV export into a dict of NumPy arrays.

    The layout is assumed from the fixed indices used here (not validated):
    six preamble rows (row 1 holds vehicle names in columns 1-5), the column
    header on line 6, then one data row per sample.

    Parameters
    ----------
    file_path : str or path-like
        CSV file to read.
    encoding : str, default 'gbk'
        Text encoding of the file.

    Returns
    -------
    dict
        ``time``: column 0.
        ``speeds``: five arrays from columns 1, 8, 15, 22, 29.
        ``head_coords``: ``e``/``n``/``u`` from columns 5, 6, 7.
        ``distances``: four arrays from columns 41-44.
        ``vehicle_names``: five strings.
    """
    df = pd.read_csv(file_path, header=6, encoding=encoding)
    header_df = pd.read_csv(file_path, header=None, nrows=6, encoding=encoding)

    vehicle_names = _parse_vehicle_names(header_df, n_vehicles=5)
    print("Vehicle names:", vehicle_names)

    # Address columns by position; the header text itself is not relied upon.
    df.columns = [f'col_{i}' for i in range(len(df.columns))]

    def col(i):
        return df[f'col_{i}'].values

    return {
        'time': col(0),
        'speeds': [col(i) for i in (1, 8, 15, 22, 29)],
        'head_coords': {'e': col(5), 'n': col(6), 'u': col(7)},
        'distances': [col(i) for i in (41, 42, 43, 44)],
        'vehicle_names': vehicle_names,
    }

def downsample_data(data, dt_target=0.5):
    """Keep every ``step``-th sample so the spacing approaches ``dt_target``.

    ``step = max(1, round(dt_target / dt_current))``, where ``dt_current`` is
    the mean spacing of ``data['time']`` over its finite differences. If no
    positive finite spacing can be estimated (e.g. NaN gaps everywhere, or a
    constant time column), ``step`` falls back to 1 (no downsampling). The
    previous code raised ValueError/OverflowError in those cases.

    Parameters
    ----------
    data : dict
        Must contain ``time``, ``speeds`` (list of arrays), ``distances``
        (list of arrays), ``head_coords`` (dict of arrays), ``vehicle_names``.
    dt_target : float, default 0.5
        Desired spacing, in the units of ``data['time']``.

    Returns
    -------
    dict
        A new dict with every array indexed by the same sample positions and
        ``vehicle_names`` passed through. If ``time`` has at most one sample,
        ``data`` itself is returned unchanged.
    """
    time = data['time']
    if len(time) <= 1:
        return data

    diffs = np.diff(time)
    diffs = diffs[np.isfinite(diffs)]
    dt_current = diffs.mean() if diffs.size else np.nan

    # Guard the division/int conversion against NaN, zero and negative spacing.
    if np.isfinite(dt_current) and dt_current > 0:
        step = max(1, int(round(dt_target / dt_current)))
    else:
        step = 1

    indices = np.arange(0, len(time), step)

    return {
        'time': time[indices],
        'speeds': [speed[indices] for speed in data['speeds']],
        'distances': [dist[indices] for dist in data['distances']],
        'head_coords': {k: v[indices] for k, v in data['head_coords'].items()},
        'vehicle_names': data['vehicle_names'],
    }

def extract_vehicle_tracks(time, speeds, distances, min_points=10):
    """Build follower/leader samples for each consecutive vehicle pair.

    Pair ``k`` is vehicle ``k + 1`` (follower) behind vehicle ``k`` (leader),
    with spacing ``distances[k]``. One pair is built per entry of
    ``distances``, capped at ``len(speeds) - 1``. For 5 speeds and 4 distances
    this gives V2->V1, V3->V2, V4->V3 and V5->V4.

    A sample is valid when both speeds lie in (0, 50) and the spacing lies in
    (2, 200). Units are whatever the CSV uses, presumably m/s and m (TODO
    confirm). A valid sample is kept only if the *next* sample is also valid,
    so that ``vFollReal_next`` is the follower speed exactly one step later.

    Parameters
    ----------
    time : array-like
        Sample times; only its length is used, to bound the next-sample lookup.
    speeds : sequence of 1-D arrays
        One speed array per vehicle, in platoon order.
    distances : sequence of 1-D arrays
        Spacing array for pair ``k`` at index ``k``.
    min_points : int, default 10
        Minimum number of samples a pair needs, both before and after the
        next-sample filter, to be included. Previously the two checks used
        inconsistent thresholds (>= 10 and > 10).

    Returns
    -------
    dict[int, dict]
        Keyed by pair id. Each value holds ``vFollReal``, ``vLeadReal``,
        ``sReal``, ``dvReal`` (leader minus follower speed), ``vFollReal_next``,
        ``driver_id`` and ``vehicle_pair``. Skipped pairs are absent.
    """
    n_pairs = min(len(speeds) - 1, len(distances))
    tracks = {}

    for pair_id in range(n_pairs):
        follower_idx, leader_idx = pair_id + 1, pair_id
        follower_speed = np.asarray(speeds[follower_idx])
        leader_speed = np.asarray(speeds[leader_idx])
        spacing = np.asarray(distances[pair_id])

        valid_mask = (
            (follower_speed > 0) & (follower_speed < 50) &
            (leader_speed > 0) & (leader_speed < 50) &
            (spacing > 2) & (spacing < 200) &
            (~np.isnan(follower_speed)) &
            (~np.isnan(leader_speed)) &
            (~np.isnan(spacing))
        )

        valid_count = int(np.count_nonzero(valid_mask))
        if valid_count < min_points:
            print(f"Vehicle pair {pair_id} (Vehicle {follower_idx+1} following Vehicle {leader_idx+1}) has insufficient data points, skipping")
            continue

        # Row i is usable only if row i+1 exists (within `time`) and is valid too;
        # this replaces the former per-index Python loop with a shifted mask.
        n_rows = min(len(time), valid_mask.size)
        last = max(n_rows - 1, 0)
        has_valid_next = np.zeros_like(valid_mask)
        has_valid_next[:last] = valid_mask[1:last + 1]
        keep = np.flatnonzero(valid_mask & has_valid_next)
        final_count = keep.size

        if final_count < min_points:
            print(f"Vehicle pair {pair_id} has insufficient valid data points, skipping")
            continue

        track_data = {
            'vFollReal': follower_speed[keep],
            'vLeadReal': leader_speed[keep],
            'sReal': spacing[keep],
            'dvReal': leader_speed[keep] - follower_speed[keep],
            'vFollReal_next': follower_speed[keep + 1],
            'driver_id': f"Driver_{follower_idx+1}",
            'vehicle_pair': f"V{follower_idx+1}_following_V{leader_idx+1}"
        }
        tracks[pair_id] = track_data
        print(f"Vehicle pair {pair_id}: {track_data['vehicle_pair']} - {final_count} valid data points")

    return tracks

def convert_to_ar_idm_input(tracks):
    """Concatenate per-pair tracks into flat model-input arrays.

    Parameters
    ----------
    tracks : dict[int, dict]
        As returned by ``extract_vehicle_tracks``. Each value must contain
        ``vFollReal``, ``sReal``, ``dvReal`` and ``vFollReal_next`` arrays.

    Returns
    -------
    tuple of np.ndarray
        ``(vt, s, dv, label_v, id_idx)`` concatenated in dict iteration order.
        ``id_idx`` holds, for every row, the dict key (pair id) of its track.
        If ``tracks`` is empty, five empty arrays are returned; previously
        ``np.concatenate([])`` raised ValueError.

    Notes
    -----
    NOTE(review): ids are the original pair numbers. If a pair was skipped
    upstream they are not contiguous (e.g. 0, 2, 3), so a consumer that sizes
    per-vehicle parameters by the number of unique ids and indexes them with
    ``id_idx`` would go out of range. Confirm against the model code.
    """
    if not tracks:
        return (np.array([], dtype=float), np.array([], dtype=float),
                np.array([], dtype=float), np.array([], dtype=float),
                np.array([], dtype=int))

    items = list(tracks.items())

    def stack(key):
        return np.concatenate([track[key] for _, track in items])

    vt = stack('vFollReal')
    s = stack('sReal')
    dv = stack('dvReal')
    label_v = stack('vFollReal_next')
    id_idx = np.concatenate([
        np.full(len(track['vFollReal']), pair_id, dtype=int)
        for pair_id, track in items
    ])

    return vt, s, dv, label_v, id_idx

def prepare_ar_idm_data(file_path, dt_target=0.1):
    """Run the whole pipeline: load CSV -> downsample -> pair tracks -> flatten.

    Parameters
    ----------
    file_path : str or path-like
        Platoon CSV forwarded to ``load_data``.
    dt_target : float, default 0.1
        Target sampling interval forwarded to ``downsample_data``.

    Returns
    -------
    dict
        ``vt``, ``s``, ``dv``, ``label_v``, ``id_idx`` from
        ``convert_to_ar_idm_input``, plus ``n_vehicles`` (number of distinct
        ids in ``id_idx``) and the per-pair ``tracks`` dict.
    """
    resampled = downsample_data(load_data(file_path), dt_target=dt_target)

    tracks = extract_vehicle_tracks(
        resampled['time'],
        resampled['speeds'],
        resampled['distances'],
    )
    vt, s, dv, label_v, id_idx = convert_to_ar_idm_input(tracks)

    # NOTE(review): this counts distinct pair ids, but the ids are the original
    # pair numbers; if a pair was skipped they are not 0..n_vehicles-1.
    n_vehicles = len(np.unique(id_idx))

    print("Data preparation completed!")
    print(f"Total data points: {len(vt)}")
    print(f"Number of vehicles: {n_vehicles}")

    return dict(
        vt=vt, s=s, dv=dv, label_v=label_v, id_idx=id_idx,
        n_vehicles=n_vehicles, tracks=tracks,
    )

In [2]:
import os

# Raw platoon CSV. The absolute path below is the original author's machine;
# set the AR_IDM_DATA_PATH environment variable to run the notebook elsewhere.
DATA_PATH = os.environ.get(
    'AR_IDM_DATA_PATH',
    r'D:\学习\ETH学习是一辈子的大事\project\3\ASta_040719_platoon311.csv',
)
# Target resampling interval, in the units of the CSV time column.
DT_TARGET = 0.5

file_path = DATA_PATH
ar_idm_data = prepare_ar_idm_data(file_path, dt_target=DT_TARGET)

Vehicle names: ['Audi(A8)', 'Tesla(Model3)', 'BMW(X5)', 'Audi(A6)', 'Mercedes(AClass)']
Vehicle pair 0: V2_following_V1 - 1202 valid data points
Vehicle pair 1: V3_following_V2 - 1202 valid data points
Vehicle pair 2: V4_following_V3 - 1202 valid data points
Vehicle pair 3: V5_following_V4 - 1202 valid data points
Data preparation completed!
Total data points: 4808
Number of vehicles: 4
