In [1]:
import pandas as pd
import os

In [None]:
# Create an association between the timestamps and the corresponding images
def prep_combined_csv(input_dir, output_filepath):
    # Read the CSV files
    cam0_path = os.path.join(input_dir, 'cam0')
    cam1_path = os.path.join(input_dir, 'cam1')
    cam0_csv_mappings = pd.read_csv(os.path.join(cam0_path, 'data.csv'))
    cam1_csv_mappings = pd.read_csv(os.path.join(cam1_path, 'data.csv'))
    imu0_path = os.path.join(input_dir, 'imu0')
    imu_readings = pd.read_csv(os.path.join(imu0_path, 'data.csv')
    
    # Merge the IMU readings with the camera data
    combined_df = pd.merge(cam0_csv_mappings, cam1_csv_mappings, on='#timestamp [ns]', how='outer', suffixes=('_cam0', '_cam1'))
    combined_df = pd.merge(combined_df, imu_readings, on='#timestamp [ns]', how='outer')
    combined_df['filename_cam0'] = combined_df['filename_cam0'].fillna(method='ffill')
    combined_df['filename_cam1'] = combined_df['filename_cam1'].fillna(method='ffill')
    combined_df.columns = ['timestamp', 'cam0_filename', 'cam1_filename', 'w_x', 'w_y', 'w_z', 'a_x', 'a_y', 'a_z']

    # Save the combined dataframe to a new CSV file
    combined_df.to_csv(output_filepath, index=False)

In [3]:
# Build the combined CSV from the cam0/cam1/imu0 data under ../data/mav0.
prep_combined_csv(
    input_dir="../data/mav0",
    output_filepath="../data/mav0/combined.csv",
)

       #timestamp [ns]  w_RS_S_x [rad s^-1]  w_RS_S_y [rad s^-1]  \
0  1403715273262142976            -0.002094             0.017453   
1  1403715273267142912            -0.001396             0.019548   
2  1403715273272143104            -0.002094             0.016755   
3  1403715273277143040            -0.002793             0.020944   
4  1403715273282142976            -0.002094             0.020944   

   w_RS_S_z [rad s^-1]  a_RS_S_x [m s^-2]  a_RS_S_y [m s^-2]  \
0             0.077493           9.087496           0.130755   
1             0.078191           9.079323           0.122583   
2             0.074700           9.038462           0.147100   
3             0.078191           9.071151           0.122583   
4             0.078889           9.079323           0.130755   

   a_RS_S_z [m s^-2]  
0          -3.693838  
1          -3.693838  
2          -3.669322  
3          -3.677494  
4          -3.702010  


  combined_df['filename_cam0'] = combined_df['filename_cam0'].fillna(method='ffill')
  combined_df['filename_cam1'] = combined_df['filename_cam1'].fillna(method='ffill')


In [4]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import torch
import os
import pandas as pd

class IMUImageDataset(Dataset):
    """Dataset that pairs each IMU reading with a cam0 and a cam1 image.

    Every row of the CSV at ``csv_path`` is one sample. The CSV must provide
    the columns ``cam0_filename``, ``cam1_filename`` and the six IMU columns
    listed in ``IMU_COLUMNS``.

    Args:
        csv_path: Path of the combined CSV to read.
        cam0_image_root: Directory that the ``cam0_filename`` values are
            joined onto.
        cam1_image_root: Directory that the ``cam1_filename`` values are
            joined onto.
        transform: Callable applied to each RGB PIL image. When ``None``, the
            images are resized to 224x224 and converted to tensors.
    """

    # Columns read into the IMU feature vector, in output order.
    IMU_COLUMNS = ("w_x", "w_y", "w_z", "a_x", "a_y", "a_z")

    def __init__(self, csv_path, cam0_image_root, cam1_image_root, transform=None):
        self.data = pd.read_csv(csv_path)
        self.cam0_image_root = cam0_image_root
        self.cam1_image_root = cam1_image_root
        # Test against None instead of truthiness: a caller-supplied transform
        # that happens to be falsy (e.g. an empty container module) must not
        # be silently swapped for the default.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def _load_image(self, image_root, filename):
        """Open ``image_root/filename``, convert it to RGB and transform it."""
        # The context manager closes the underlying file as soon as the RGB
        # copy has been made.
        with Image.open(os.path.join(image_root, filename)) as image:
            return self.transform(image.convert("RGB"))

    def __getitem__(self, idx):
        """Return ``(imu_tensor, cam0_image, cam1_image)`` for row ``idx``.

        ``imu_tensor`` is a float32 tensor holding the ``IMU_COLUMNS`` values;
        the two images are whatever ``self.transform`` returns.
        """
        row = self.data.iloc[idx]

        # Load images
        cam0_image = self._load_image(self.cam0_image_root, row['cam0_filename'])
        cam1_image = self._load_image(self.cam1_image_root, row['cam1_filename'])

        # Load IMU features
        imu = row[list(self.IMU_COLUMNS)].values.astype("float32")
        imu_tensor = torch.tensor(imu)

        return imu_tensor, cam0_image, cam1_image

In [8]:
# Example usage: build the dataset, pull the first ten samples, and show the
# shape of one IMU vector and one image from each camera.
dataset = IMUImageDataset(
    "../data/mav0/combined.csv",
    "../data/mav0/cam0/data",
    "../data/mav0/cam1/data",
)

first_window = list(map(dataset.__getitem__, range(10)))
# Transpose the list of (imu, cam0, cam1) samples into three parallel tuples.
imu_data, cam0_images, cam1_images = zip(*first_window)
print(imu_data[0].shape, cam0_images[0].shape, cam1_images[0].shape, sep="\n")

torch.Size([6])
torch.Size([3, 224, 224])
torch.Size([3, 224, 224])
