In [1]:
import rosbag
import csv
import numpy as np
import pandas as pd

  matches = re.match("#ROS(.*) V(\d).(\d)", version_line)


In [2]:
def extract_rosbag_to_csv_with_timegap(bag_path, csv_path, start_time, end_time):
    """
    Extract joint torques and gripper positions from a rosbag into a CSV file.

    Only messages whose bag timestamp lies in the closed interval
    [start_time, end_time] are considered. One CSV row is written per
    '/franka_state_controller/franka_states' message; each row carries the most
    recent '/franka_gripper/joint_states' position seen inside the window so far
    (the gripper cells stay empty until the first gripper message arrives).

    Args:
        bag_path: Path of the rosbag file to read.
        csv_path: Path of the CSV file to create (overwritten if it exists).
        start_time: Window start, in seconds as returned by ``t.to_sec()``.
        end_time: Window end, in the same units as ``start_time``.
    """
    state_topic = '/franka_state_controller/franka_states'
    gripper_topic = '/franka_gripper/joint_states'

    # Context managers guarantee that both the bag and the CSV file are closed
    # even when a message is malformed or writing fails part-way through.
    with rosbag.Bag(bag_path) as bag, open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)

        # Write header
        writer.writerow([
            'time',
            'tau_J_0', 'tau_J_1', 'tau_J_2', 'tau_J_3', 'tau_J_4', 'tau_J_5', 'tau_J_6',
            'gripper_position_left', 'gripper_position_right'
        ])

        # None is rendered as an empty cell by csv.writer.
        latest_gripper_position = [None, None]

        for topic, msg, t in bag.read_messages(topics=[state_topic, gripper_topic]):
            timestamp = t.to_sec()

            if timestamp < start_time:
                continue
            if timestamp > end_time:
                # Assumes read_messages yields messages in chronological order,
                # so nothing later can fall back inside the window.
                break

            if topic == gripper_topic:
                # Remember the gripper state; it is attached to later torque rows.
                latest_gripper_position = list(msg.position) if msg.position else [None, None]

            elif topic == state_topic:
                writer.writerow([timestamp] + list(msg.tau_J) + latest_gripper_position)


def extract_rosbag_to_matrix_with_timegap(bag_path, start_time, end_time):
    """
    Extract joint torques and gripper positions from a rosbag into a list of rows.

    Only messages whose bag timestamp lies in the closed interval
    [start_time, end_time] are considered. One row is produced per
    '/franka_state_controller/franka_states' message, laid out as
    ``list(msg.tau_J) + latest gripper position``. Unlike the CSV variant, the
    rows do NOT contain a timestamp column. Gripper entries are NaN until the
    first '/franka_gripper/joint_states' message inside the window is seen.

    Args:
        bag_path: Path of the rosbag file to read.
        start_time: Window start, in seconds as returned by ``t.to_sec()``.
        end_time: Window end, in the same units as ``start_time``.

    Returns:
        A list of rows (lists); empty when no state message falls in the window.
    """
    state_topic = '/franka_state_controller/franka_states'
    gripper_topic = '/franka_gripper/joint_states'

    data_matrix = []
    latest_gripper_position = [np.nan, np.nan]

    # The context manager closes the bag even if a message is malformed and
    # processing aborts with an exception.
    with rosbag.Bag(bag_path) as bag:
        for topic, msg, t in bag.read_messages(topics=[state_topic, gripper_topic]):
            timestamp = t.to_sec()

            if timestamp < start_time:
                continue
            if timestamp > end_time:
                # Assumes read_messages yields messages in chronological order,
                # so nothing later can fall back inside the window.
                break

            if topic == gripper_topic:
                # Remember the gripper state; it is attached to later torque rows.
                latest_gripper_position = list(msg.position) if msg.position else [np.nan, np.nan]

            elif topic == state_topic:
                data_matrix.append(list(msg.tau_J) + latest_gripper_position)

    return data_matrix

In [15]:
# Bag recording to analyse (relative path).
bag_path = 'rosbag_files/Good/8_P2.bag'
#csv_path = 'franka_torque_and_gripper_log_filtered.csv'

# Extraction window bounds; they are compared against t.to_sec() of each bag message.
start_time = 1752829041.346629   # replace with your start timestamp
end_time = 1752829043.479962   # replace with your end timestamp

# CSV export is currently disabled; only the in-memory matrix is built here.
#extract_rosbag_to_csv_with_timegap(bag_path, csv_path, start_time, end_time)
data = extract_rosbag_to_matrix_with_timegap(bag_path,start_time,end_time)

In [5]:
def max_pool_matrix(matrix, pool_size, has_timestamp=True):
    """
    Reduce the number of rows of a matrix by max-pooling consecutive windows.

    Rows are grouped into non-overlapping windows of ``pool_size`` rows (the
    last window may be shorter). Each window collapses to one output row holding
    the column-wise maximum. NaN values propagate through ``np.max``, so a
    window containing a NaN in some column yields NaN for that column.

    Args:
        matrix: List of lists (rows x features) or 2-D numpy array.
        pool_size: Positive integer, the pooling window size in rows.
        has_timestamp: When True (default, the original behaviour) column 0 is
            treated as a timestamp and pooled with the mean instead of the max.
            Pass False for matrices without a timestamp column so that every
            column, including the first, is max-pooled.

    Returns:
        pooled_matrix: List of pooled rows; ``[]`` for empty input.

    Raises:
        ValueError: If ``pool_size`` is smaller than 1 or ``matrix`` is
            non-empty but not 2-dimensional.
    """
    if pool_size < 1:
        raise ValueError("pool_size must be a positive integer")

    matrix = np.asarray(matrix)
    if matrix.size == 0:
        # Nothing to pool (e.g. no messages fell into the extraction window).
        return []
    if matrix.ndim != 2:
        raise ValueError("matrix must be 2-dimensional (rows x features)")

    pooled_matrix = []
    for start in range(0, matrix.shape[0], pool_size):
        # range() stops before n_rows, so every window holds at least one row.
        window = matrix[start:start + pool_size, :]

        if has_timestamp:
            # For timestamp: use the mean of the window.
            timestamp = np.mean(window[:, 0])
            # For other columns: max pooling.
            pooled_row = [timestamp] + np.max(window[:, 1:], axis=0).tolist()
        else:
            pooled_row = np.max(window, axis=0).tolist()

        pooled_matrix.append(pooled_row)

    return pooled_matrix

In [16]:
# Peek at the first extracted row: the tau_J values followed by the gripper positions.
data[0]

[-0.23191767930984497,
 -49.09267807006836,
 -2.614889144897461,
 23.03385353088379,
 1.2294764518737793,
 2.6619012355804443,
 -0.20399820804595947,
 nan,
 nan]

In [17]:
# Shape of the extracted matrix as (rows, columns).
print(np.array(data).shape)

(63, 9)


In [18]:
# Pool the extracted rows in windows of 8 rows.
# NOTE(review): rows in `data` come from extract_rosbag_to_matrix_with_timegap and
# have no timestamp column, yet max_pool_matrix averages column 0 as if it were a
# timestamp, so the first torque column is mean-pooled rather than max-pooled.
# Confirm whether that is intended.
pooled_matrix = max_pool_matrix(data,8)

In [19]:
# Shape of the pooled matrix as (rows, columns).
print(np.array(pooled_matrix).shape)

(8, 9)
