In [60]:
import zarr
import torch
from torch.utils.data import Dataset
import numpy as np

In [80]:
import zarr
import torch
from torch.utils.data import Dataset

class RobotDataset(Dataset):
    def __init__(self, zarr_path):
        self.root = zarr.open(zarr_path, 'r')
        self.cam_0 = self.root['data']['cam_0']
        self.cam_1 = self.root['data']['cam_1']
        self.actions = self.root['data']['action']
        self.joint_pos = self.root['data']['joint_pos']
        self.joint_vel = self.root['data']['joint_vel']
    
    def __len__(self):
        return len(self.actions)
    
    def __getitem__(self, idx):
        # Load images
        img0 = torch.from_numpy(self.cam_0[idx]).permute(2, 0, 1)  # (C,H,W)
        img1 = torch.from_numpy(self.cam_1[idx]).permute(2, 0, 1)
        
        # Concatenate cameras (6 channels total)
        obs = torch.cat([img0, img1], dim=0)
        
        # Load action
        action = torch.from_numpy(self.actions[idx])
        
        return obs, action

# Usage
dataset = RobotDataset('/home/vinoth/BTP/multicam_data_collection/collected_data/session_20260101_204329.zarr')
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

In [81]:
# Load dataset
root = zarr.open('/home/vinoth/BTP/multicam_data_collection/collected_data/session_20260101_204329.zarr', 'r')

# Check shapes
print(f"Camera 0: {root['data']['cam_0'].shape}")
print(f"Camera 1: {root['data']['cam_1'].shape}")
print(f"Actions: {root['data']['action'].shape}")
print(f"Joint Pos: {root['data']['joint_pos'].shape}")
print(f"Joint Vel: {root['data']['joint_vel'].shape}")

timestamp = root['data']['timestamp'][:]
# Check episode boundaries
episode_ends = root['meta']['episode_ends'][:]
print(f"Episodes: {len(episode_ends)}")

# Verify image range
img_sample = root['data']['cam_0'][0]
print(f"Image range: [{img_sample.min():.3f}, {img_sample.max():.3f}]")

Camera 0: (418, 96, 96, 3)
Camera 1: (418, 96, 96, 3)
Actions: (418, 6)
Joint Pos: (418, 6)
Joint Vel: (418, 6)
Episodes: 3
Image range: [0.098, 0.761]


In [82]:
from PIL import Image

In [83]:
arr = root['data']['cam_1'][330]  # HWC, float32 in [0,1]
arr_uint8 = (arr * 255).clip(0, 255).astype(np.uint8)
img = Image.fromarray(arr_uint8, mode="RGB")

img.show()        # visualize
img.save("img.png")

  img = Image.fromarray(arr_uint8, mode="RGB")


In [84]:
print(episode_ends)

[105 338 418]


In [66]:
for i in range(len(timestamp)):
    print(f"Frame {i}: Timestamp {timestamp[i]}\n")   

Frame 0: Timestamp 1767278605.9862149

Frame 1: Timestamp 1767278606.3755014

Frame 2: Timestamp 1767278606.7647223

Frame 3: Timestamp 1767278607.153603

Frame 4: Timestamp 1767278607.5423307

Frame 5: Timestamp 1767278607.9313505

Frame 6: Timestamp 1767278608.320517

Frame 7: Timestamp 1767278608.709329

Frame 8: Timestamp 1767278609.0985131

Frame 9: Timestamp 1767278609.4873726

Frame 10: Timestamp 1767278609.8765645

Frame 11: Timestamp 1767278610.2652607

Frame 12: Timestamp 1767278610.6543562

Frame 13: Timestamp 1767278611.0233388

Frame 14: Timestamp 1767278611.3925717

Frame 15: Timestamp 1767278611.7612739

Frame 16: Timestamp 1767278612.1305995

Frame 17: Timestamp 1767278612.5194595

Frame 18: Timestamp 1767278612.8882282

Frame 19: Timestamp 1767278613.2372882

Frame 20: Timestamp 1767278613.6064248

Frame 21: Timestamp 1767278613.9756987

Frame 22: Timestamp 1767278614.3645735

Frame 23: Timestamp 1767278614.754534

Frame 24: Timestamp 1767278615.143435

Frame 25: Times

In [None]:
root['data']['joint_vel'][:]

array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       ...,
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]], dtype=float32)

In [90]:
print(root['data']['joint_pos'][:])

[[ 494. 1170. 2258. 2861.  763.  244.]
 [ 494. 1170. 2258. 2861.  763.  244.]
 [ 494. 1170. 2258. 2861.  763.  244.]
 ...
 [ 493. 1172. 2254. 2856.  761.  221.]
 [ 493. 1172. 2254. 2856.  761.  221.]
 [ 493. 1172. 2254. 2858.  761.  221.]]


In [91]:
for i in range(0,len(root['data']['joint_pos'][:])):
    print(f"Joint Pos {root['data']['joint_pos'][i]}\n")

Joint Pos [ 494. 1170. 2258. 2861.  763.  244.]

Joint Pos [ 494. 1170. 2258. 2861.  763.  244.]

Joint Pos [ 494. 1170. 2258. 2861.  763.  244.]

Joint Pos [ 494. 1170. 2258. 2861.  763.  244.]

Joint Pos [ 521. 1170. 2253. 2841.  763.  244.]

Joint Pos [ 562. 1169. 2257. 2809.  763.  244.]

Joint Pos [ 604. 1156. 2260. 2823.  763.  244.]

Joint Pos [ 570. 1187. 2265. 2823.  763.  244.]

Joint Pos [ 568. 1230. 2253. 2824.  763.  244.]

Joint Pos [ 548. 1284. 2225. 2824.  763.  244.]

Joint Pos [ 534. 1346. 2178. 2824.  763.  244.]

Joint Pos [ 539. 1412. 2179. 2800.  763.  244.]

Joint Pos [ 539. 1481. 2207. 2760.  763.  244.]

Joint Pos [ 539. 1550. 2231. 2716.  763.  244.]

Joint Pos [ 539. 1624. 2234. 2740.  763.  244.]

Joint Pos [ 518. 1681. 2240. 2735.  763.  244.]

Joint Pos [ 529. 1678. 2260. 2710.  763.  244.]

Joint Pos [ 495. 1685. 2260. 2722.  763.  244.]

Joint Pos [ 445. 1685. 2267. 2722.  763.  273.]

Joint Pos [ 398. 1679. 2271. 2722.  763.  273.]

Joint Pos [ 427. 168

In [97]:
for i in range(0,len(root['data']['action'][:])):
    print(f"Joint Pos {root['data']['action'][i]}\n")

Joint Pos [ 494. 1172. 2252. 2860.  763.  244.]

Joint Pos [ 494. 1172. 2252. 2860.  763.  244.]

Joint Pos [ 494. 1172. 2252. 2860.  763.  244.]

Joint Pos [ 673. 1172. 2118. 2715.  763.  244.]

Joint Pos [ 686. 1172. 2284. 2674.  763.  244.]

Joint Pos [ 705. 1130. 2256. 2830.  763.  244.]

Joint Pos [ 567. 1339. 2261. 2820.  763.  244.]

Joint Pos [ 567. 1349. 2085. 2825.  763.  244.]

Joint Pos [ 548. 1390. 2085. 2825.  763.  244.]

Joint Pos [ 534. 1438. 2059. 2825.  763.  244.]

Joint Pos [ 540. 1494. 2171. 2644.  763.  244.]

Joint Pos [ 540. 1560. 2204. 2634.  763.  244.]

Joint Pos [ 540. 1626. 2242. 2600.  763.  244.]

Joint Pos [ 540. 1696. 2228. 2744.  763.  244.]

Joint Pos [ 371. 1687. 2234. 2734.  763.  244.]

Joint Pos [ 530. 1668. 2342. 2625.  763.  244.]

Joint Pos [ 349. 1685. 2253. 2725.  763.  244.]

Joint Pos [ 331. 1685. 2262. 2719.  763.  244.]

Joint Pos [ 325. 1634. 2267. 2723.  763.  255.]

Joint Pos [ 428. 1688. 2271. 2723.  763.  273.]

Joint Pos [ 417. 168

In [95]:
for i in range(len(root['data']['cam_0'][:])):
    arr = root['data']['cam_0'][i]  # HWC, float32 in [0,1]
    arr_uint8 = (arr * 255).clip(0, 255).astype(np.uint8)
    img = Image.fromarray(arr_uint8, mode="RGB")
    img.save(f"img_check/{i}_img.png")
    # img.show()        # visualize

  img = Image.fromarray(arr_uint8, mode="RGB")


In [None]:
for i in range(len(root['data']['cam_0'][:])):
    arr = root['data']['cam_0'][i]  # HWC, float32 in [0,1]
    arr_uint8 = (arr * 255).clip(0, 255).astype(np.uint8)
    img = Image.fromarray(arr_uint8, mode="RGB")
    img.save(f"img_check/cam_0/{i}_img.png")
    # img.show() 

In [96]:
for i in range(len(root['data']['cam_1'][:])):
    arr = root['data']['cam_1'][i]  # HWC, float32 in [0,1]
    arr_uint8 = (arr * 255).clip(0, 255).astype(np.uint8)
    img = Image.fromarray(arr_uint8, mode="RGB")
    img.save(f"img_check/cam_1/{i}_img.png")
    # img.show() 

  img = Image.fromarray(arr_uint8, mode="RGB")
