In [8]:
import pandas as pd
import ast
from tqdm import tqdm  # For progress bar
import numpy as np
from sklearn.preprocessing import StandardScaler

def process_large_dataset(df, list_columns):
    """
    Process large datasets with string-encoded lists efficiently
    """
    # Create a new dataframe for expanded columns
    expanded_df = df.drop(columns=list_columns).copy()
    
    # Process each list column
    for col in tqdm(list_columns, desc="Processing columns"):
        # Get the first row to determine the list length
        first_row_list = ast.literal_eval(df[col].iloc[0])
        num_elements = len(first_row_list)
        
        # Pre-allocate numpy arrays for better performance
        expanded_values = np.zeros((len(df), num_elements))
        
        # Process chunks of the dataframe
        chunk_size = 1000
        for start_idx in tqdm(range(0, len(df), chunk_size), desc=f"Processing {col}"):
            end_idx = min(start_idx + chunk_size, len(df))
            chunk = df[col].iloc[start_idx:end_idx]
            
            # Process each row in the chunk
            for i, row in enumerate(chunk):
                try:
                    values = ast.literal_eval(row)
                    expanded_values[start_idx + i] = values
                except (ValueError, SyntaxError) as e:
                    print(f"Error processing row {start_idx + i} in column {col}: {e}")
                    expanded_values[start_idx + i] = np.nan
        
        # Add the expanded columns to the dataframe
        for i in range(num_elements):
            expanded_df[f'{col}_{i+1}'] = expanded_values[:, i]
    
    return expanded_df

# Example usage:
path = "/home/richtsai1103/CRL/src/results/HalfCheetah-v5/ppo_20241212_020732/selected_steps.csv"
df = pd.read_csv(path)
list_columns = ['current_state', 'current_action', 'prev_state', 'prev_action']
expanded_df = process_large_dataset(df, list_columns)
expanded_df

Processing current_state: 100%|██████████| 2/2 [00:00<00:00, 10.87it/s]
Processing current_action: 100%|██████████| 2/2 [00:00<00:00, 49.43it/s]
Processing prev_state: 100%|██████████| 2/2 [00:00<00:00, 22.31it/s]
Processing prev_action: 100%|██████████| 2/2 [00:00<00:00, 49.29it/s]
Processing columns: 100%|██████████| 4/4 [00:00<00:00, 10.47it/s]


Unnamed: 0,global_step,episode,current_reward,done,prev_reward,current_state_1,current_state_2,current_state_3,current_state_4,current_state_5,...,prev_state_14,prev_state_15,prev_state_16,prev_state_17,prev_action_1,prev_action_2,prev_action_3,prev_action_4,prev_action_5,prev_action_6
0,1000,0,-0.673945,True,-0.277783,-0.015463,-0.094336,-0.075143,0.034125,0.029438,...,5.867504,-15.696860,2.213892,9.558078,1.797210,-0.170801,0.930828,-2.083316,-2.562076,0.560429
1,2000,1,-0.361510,True,-0.835217,0.077898,0.086809,-0.028441,0.014306,-0.035626,...,-8.760627,3.608866,-0.932511,-9.566932,-0.907863,1.509416,-0.067992,0.321941,0.192680,-0.708834
2,3000,2,0.127424,True,-0.383446,-0.069944,-0.009932,0.059265,-0.053872,-0.089596,...,-5.562939,1.210716,3.529471,14.233454,-2.081728,-0.428466,-0.248964,0.315535,0.650916,0.699920
3,4000,3,-0.490850,True,1.227300,0.090918,-0.000021,-0.014954,0.024043,0.099019,...,7.780116,-7.945633,-10.580946,-2.255747,-1.097691,0.439106,0.471627,-0.642456,-1.389353,-0.367692
4,5000,4,-1.385514,True,-0.524974,0.093585,-0.097059,0.072728,0.096239,0.091442,...,1.149285,7.782791,1.829556,0.883807,-1.473485,0.757580,-1.723526,-0.101329,0.507383,0.552656
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,1996000,0,0.006247,True,0.337107,-0.006135,-0.011026,-0.075260,0.046130,0.014392,...,3.096484,-13.904811,-5.740456,-13.078385,2.301426,1.957026,0.585023,-3.895117,-4.616184,-1.948931
1996,1997000,1,1.253199,True,1.521279,0.017811,0.058344,-0.082932,0.077833,-0.011857,...,4.440222,-24.525218,0.668700,-9.444017,2.190485,1.133501,-0.238043,-1.665647,-4.051864,-1.542027
1997,1998000,2,0.862805,True,0.331325,-0.074969,0.014875,-0.087969,0.020308,0.065355,...,0.460004,-13.474900,-8.441839,0.971193,3.072515,1.665262,-1.790982,-1.040521,-5.605346,0.557282
1998,1999000,3,1.720291,True,2.702325,-0.003330,-0.078384,-0.047495,0.089896,0.030557,...,6.506647,4.167562,-22.490452,-0.171658,-0.884392,-1.652829,0.309836,1.223061,-0.369421,0.810269


In [9]:
expanded_df = expanded_df.iloc[:,2:]
expanded_df

Unnamed: 0,current_reward,done,prev_reward,current_state_1,current_state_2,current_state_3,current_state_4,current_state_5,current_state_6,current_state_7,...,prev_state_14,prev_state_15,prev_state_16,prev_state_17,prev_action_1,prev_action_2,prev_action_3,prev_action_4,prev_action_5,prev_action_6
0,-0.673945,True,-0.277783,-0.015463,-0.094336,-0.075143,0.034125,0.029438,0.023077,-0.023264,...,5.867504,-15.696860,2.213892,9.558078,1.797210,-0.170801,0.930828,-2.083316,-2.562076,0.560429
1,-0.361510,True,-0.835217,0.077898,0.086809,-0.028441,0.014306,-0.035626,0.018860,-0.032418,...,-8.760627,3.608866,-0.932511,-9.566932,-0.907863,1.509416,-0.067992,0.321941,0.192680,-0.708834
2,0.127424,True,-0.383446,-0.069944,-0.009932,0.059265,-0.053872,-0.089596,-0.019090,-0.060297,...,-5.562939,1.210716,3.529471,14.233454,-2.081728,-0.428466,-0.248964,0.315535,0.650916,0.699920
3,-0.490850,True,1.227300,0.090918,-0.000021,-0.014954,0.024043,0.099019,0.089789,-0.007991,...,7.780116,-7.945633,-10.580946,-2.255747,-1.097691,0.439106,0.471627,-0.642456,-1.389353,-0.367692
4,-1.385514,True,-0.524974,0.093585,-0.097059,0.072728,0.096239,0.091442,-0.070247,0.094526,...,1.149285,7.782791,1.829556,0.883807,-1.473485,0.757580,-1.723526,-0.101329,0.507383,0.552656
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,0.006247,True,0.337107,-0.006135,-0.011026,-0.075260,0.046130,0.014392,0.058136,-0.010632,...,3.096484,-13.904811,-5.740456,-13.078385,2.301426,1.957026,0.585023,-3.895117,-4.616184,-1.948931
1996,1.253199,True,1.521279,0.017811,0.058344,-0.082932,0.077833,-0.011857,0.082874,0.096715,...,4.440222,-24.525218,0.668700,-9.444017,2.190485,1.133501,-0.238043,-1.665647,-4.051864,-1.542027
1997,0.862805,True,0.331325,-0.074969,0.014875,-0.087969,0.020308,0.065355,0.070299,-0.010135,...,0.460004,-13.474900,-8.441839,0.971193,3.072515,1.665262,-1.790982,-1.040521,-5.605346,0.557282
1998,1.720291,True,2.702325,-0.003330,-0.078384,-0.047495,0.089896,0.030557,0.047668,-0.062129,...,6.506647,4.167562,-22.490452,-0.171658,-0.884392,-1.652829,0.309836,1.223061,-0.369421,0.810269


In [10]:
expanded_df.describe()

Unnamed: 0,current_reward,prev_reward,current_state_1,current_state_2,current_state_3,current_state_4,current_state_5,current_state_6,current_state_7,current_state_8,...,prev_state_14,prev_state_15,prev_state_16,prev_state_17,prev_action_1,prev_action_2,prev_action_3,prev_action_4,prev_action_5,prev_action_6
count,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,...,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0
mean,0.775488,0.756341,0.000449,0.002789,-0.000995,0.000384,-0.001386,0.000509,0.001967,0.000667,...,-0.094095,0.220261,-0.039584,-0.076933,0.055171,-0.216199,0.357256,-0.052649,-0.987821,0.028049
std,1.095763,1.088285,0.05839,0.057113,0.058402,0.057937,0.058213,0.057983,0.056662,0.057265,...,10.304847,13.33392,12.244519,7.968637,1.268377,2.262779,1.444305,2.068345,1.85247,1.520129
min,-2.685754,-2.330171,-0.09995,-0.099934,-0.099871,-0.099947,-0.099875,-0.099857,-0.099995,-0.09989,...,-28.698649,-25.13565,-28.556873,-17.218609,-3.783534,-5.510066,-4.091074,-6.61973,-7.73984,-4.306869
25%,-0.076432,-0.075667,-0.050224,-0.045193,-0.053042,-0.048208,-0.050699,-0.050405,-0.044419,-0.048617,...,-7.423439,-9.10491,-6.430186,-5.430195,-0.836245,-2.153496,-0.589158,-1.577422,-2.172943,-0.957304
50%,0.734437,0.720548,0.000988,0.002475,-0.001046,-0.000188,-0.002574,0.002016,0.002045,0.000868,...,0.844846,-0.694319,-2.409484,-0.80278,-0.034609,-0.448068,0.342308,-0.259847,-0.962269,-0.043859
75%,1.553394,1.544702,0.050558,0.052981,0.05038,0.051605,0.048469,0.050978,0.050106,0.051674,...,7.574062,4.48271,3.20473,4.708686,0.92223,1.793228,1.229432,1.569933,0.272928,0.866576
max,3.478207,3.403774,0.099994,0.099874,0.099861,0.099952,0.099999,0.099773,0.099752,0.099772,...,22.162896,30.617462,33.898477,20.372753,4.198519,5.219758,5.969357,6.520699,3.725488,6.69506


In [17]:
expanded_df = expanded_df.drop('done', axis=1)

KeyError: "['done'] not found in axis"

In [15]:
expanded_df

Unnamed: 0,current_reward,prev_reward,current_state_1,current_state_2,current_state_3,current_state_4,current_state_5,current_state_6,current_state_7,current_state_8,...,prev_state_14,prev_state_15,prev_state_16,prev_state_17,prev_action_1,prev_action_2,prev_action_3,prev_action_4,prev_action_5,prev_action_6
0,-0.673945,-0.277783,-0.015463,-0.094336,-0.075143,0.034125,0.029438,0.023077,-0.023264,0.099442,...,5.867504,-15.696860,2.213892,9.558078,1.797210,-0.170801,0.930828,-2.083316,-2.562076,0.560429
1,-0.361510,-0.835217,0.077898,0.086809,-0.028441,0.014306,-0.035626,0.018860,-0.032418,-0.021676,...,-8.760627,3.608866,-0.932511,-9.566932,-0.907863,1.509416,-0.067992,0.321941,0.192680,-0.708834
2,0.127424,-0.383446,-0.069944,-0.009932,0.059265,-0.053872,-0.089596,-0.019090,-0.060297,-0.081849,...,-5.562939,1.210716,3.529471,14.233454,-2.081728,-0.428466,-0.248964,0.315535,0.650916,0.699920
3,-0.490850,1.227300,0.090918,-0.000021,-0.014954,0.024043,0.099019,0.089789,-0.007991,0.051546,...,7.780116,-7.945633,-10.580946,-2.255747,-1.097691,0.439106,0.471627,-0.642456,-1.389353,-0.367692
4,-1.385514,-0.524974,0.093585,-0.097059,0.072728,0.096239,0.091442,-0.070247,0.094526,0.077987,...,1.149285,7.782791,1.829556,0.883807,-1.473485,0.757580,-1.723526,-0.101329,0.507383,0.552656
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,0.006247,0.337107,-0.006135,-0.011026,-0.075260,0.046130,0.014392,0.058136,-0.010632,-0.032107,...,3.096484,-13.904811,-5.740456,-13.078385,2.301426,1.957026,0.585023,-3.895117,-4.616184,-1.948931
1996,1.253199,1.521279,0.017811,0.058344,-0.082932,0.077833,-0.011857,0.082874,0.096715,0.060730,...,4.440222,-24.525218,0.668700,-9.444017,2.190485,1.133501,-0.238043,-1.665647,-4.051864,-1.542027
1997,0.862805,0.331325,-0.074969,0.014875,-0.087969,0.020308,0.065355,0.070299,-0.010135,0.033641,...,0.460004,-13.474900,-8.441839,0.971193,3.072515,1.665262,-1.790982,-1.040521,-5.605346,0.557282
1998,1.720291,2.702325,-0.003330,-0.078384,-0.047495,0.089896,0.030557,0.047668,-0.062129,0.091364,...,6.506647,4.167562,-22.490452,-0.171658,-0.884392,-1.652829,0.309836,1.223061,-0.369421,0.810269


In [16]:
expanded_df.to_csv('/home/richtsai1103/CRL/src/results/HalfCheetah-v5/ppo_20241212_020732/expanded_steps.csv', 
          index=False,      # Don't save row indices
          header=True,      # Save column names
          encoding='utf-8'  # Specify encoding
)