## Get Dataset

In [65]:
%matplotlib inline
%load_ext autoreload
%autoreload 2


In [None]:
#File exploration
import os
import re
import numpy as np
from collections import defaultdict

def analyze_directory(path):
    """Summarise the numbered files in *path* and inspect one ``.npz`` per category.

    File names of the form ``<category><5 digits>.<extension>`` are grouped by
    their ``<category>`` prefix; every other name is ignored. For each category
    the file count and the min/max file number are printed, then the first
    ``.npz`` file of the category (in sorted name order) is opened and the
    shape of every array it contains is printed.

    Args:
        path: Directory to scan (not recursive).

    Raises:
        OSError: If *path* cannot be listed.
    """
    file_pattern = re.compile(r"^(.*?)(\d{5})\..+$")  # captures category and 5-digit number
    category_files = defaultdict(list)

    # Sorted so that the category order and the sampled file are reproducible;
    # os.listdir() returns entries in arbitrary order.
    for filename in sorted(os.listdir(path)):
        match = file_pattern.match(filename)
        if match:
            category, number_str = match.groups()
            category_files[category].append((filename, int(number_str)))

    # Print summary of categories
    for category, files in category_files.items():
        numbers = [num for _, num in files]
        print(f"Category: {category}")
        print(f"  Number of files: {len(files)}")
        print(f"  Number range: {min(numbers)} to {max(numbers)}")

    print("\nInspecting one file per category:")
    for category, files in category_files.items():
        # A category may consist only of non-.npz files; skip it instead of
        # letting next() raise StopIteration and abort the whole report.
        sample_file = next(
            (name for name, _ in files if name.endswith('.npz')), None
        )
        if sample_file is None:
            print(f"\nNo .npz file found for category '{category}'; skipping.")
            continue

        filepath = os.path.join(path, sample_file)
        print(f"\nSample file for category '{category}': {sample_file}")
        try:
            # The archive object holds an open file handle; close it when done.
            with np.load(filepath) as data:
                for key in data:
                    print(f"  Key: {key}, Shape: {data[key].shape}")
        except Exception as e:
            # Best-effort inspection: report the failure and keep going.
            print(f"  Could not load file '{sample_file}': {e}")

# Example usage: print the per-category summary for the tensor directory.
# The path is relative, so it is resolved against the current working directory.
analyze_directory("../dexnet_2.1/dexnet_2.1_eps_10/tensors")


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import random

# Define all categories and their shape descriptions.
# Keys are filename prefixes: a data file is named "<key><5-digit number>.npz".
# The functions in this cell only iterate over the keys; the shape tuples are
# informational and are never checked against the loaded arrays.
# NOTE(review): the shapes look like they were copied from one sampled file per
# category (hence 466 vs 1000 in the first dimension) — confirm before relying on them.
categories = {
    'camera_poses_': (1000, 7),
    'hand_poses_': (1000, 6),
    'depth_ims_tf_table_': (466, 32, 32, 1),
    'labels_': (1000,),
    'traj_ids_': (1000,),
    'grasp_metrics_': (1000,),
    'camera_intrs_': (1000, 4),
    'grasped_obj_keys_': (1000,),
    'grasp_collision_metrics_': (1000,),
    'pile_ids_': (1000,)
}

def load_file(path, category, file_num):
    """Load the ``arr_0`` array stored in one numbered ``.npz`` file.

    Args:
        path: Directory containing the file.
        category: Filename prefix (e.g. ``'labels_'``).
        file_num: File number; formatted as a zero-padded 5-digit suffix.

    Returns:
        The array saved under the key ``'arr_0'``.

    Raises:
        OSError: If the file does not exist or cannot be read.
        KeyError: If the archive has no ``'arr_0'`` entry.
    """
    fname = f"{category}{file_num:05d}.npz"
    fpath = os.path.join(path, fname)
    # np.load() on an .npz returns an archive object that keeps the file open;
    # the array is fully read on access, so the archive can be closed right away.
    with np.load(fpath) as archive:
        return archive['arr_0']

def find_common_file_numbers(path, category_names=None):
    """Return the sorted file numbers that exist for every category in *path*.

    A file belongs to a category when its name is ``<category><number>.npz``,
    where ``<number>`` is anything ``int()`` accepts. Files whose suffix is not
    an integer are ignored.

    Args:
        path: Directory to scan (not recursive).
        category_names: Iterable of filename prefixes to require. Defaults to
            the keys of the module-level ``categories`` mapping.

    Returns:
        A sorted list of the file numbers present for all categories; empty if
        there are no categories or no number is shared by all of them.

    Raises:
        OSError: If *path* cannot be listed.
    """
    if category_names is None:
        category_names = categories
    prefixes = tuple(category_names)
    if not prefixes:
        # set.intersection() needs at least one set to intersect.
        return []

    category_to_nums = {cat: set() for cat in prefixes}
    for fname in os.listdir(path):
        if not fname.endswith('.npz'):
            continue
        for cat in prefixes:
            if not fname.startswith(cat):
                continue
            try:
                category_to_nums[cat].add(int(fname[len(cat):-4]))
            except ValueError:
                # Prefix matched but the remainder is not a number; not a data file.
                continue

    # Keep only the numbers that every category has a file for.
    return sorted(set.intersection(*category_to_nums.values()))

def visualize_random_example(path):
    """Show one randomly chosen sample from a randomly chosen file number.

    A file number that exists for every category is picked at random, then a
    sample index is picked within the depth-image array of that file. The depth
    image and the grasp metric of that sample are plotted; the entry at the same
    index of every other category is printed.

    Args:
        path: Directory containing the ``<category><number>.npz`` files.
    """
    common_files = find_common_file_numbers(path)
    if not common_files:
        print("No common file numbers found across all categories.")
        return

    chosen_file_num = random.choice(common_files)

    # The depth images define the valid index range; keep the array so it is
    # not read from disk a second time in the loop below.
    depth_maps = load_file(path, 'depth_ims_tf_table_', chosen_file_num)
    num_samples = depth_maps.shape[0]
    if num_samples == 0:
        # random.randint(0, -1) would fail with an unhelpful "empty range" error.
        print(f"File number {chosen_file_num:05d} contains no samples.")
        return
    chosen_index = random.randint(0, num_samples - 1)

    print(f"Selected file number: {chosen_file_num:05d}, sample index: {chosen_index}\n")

    # Store and print/plot each category
    for category in categories:
        if category == 'depth_ims_tf_table_':
            data = depth_maps
        else:
            data = load_file(path, category, chosen_file_num)

        if category == 'depth_ims_tf_table_':
            # squeeze() drops unit-length axes so imshow receives a 2-D image.
            image = data[chosen_index].squeeze()
            plt.figure()
            plt.title("Depth Map")
            plt.imshow(image, cmap='gray')
            plt.colorbar()
            plt.show()

        elif category == 'grasp_metrics_':
            plt.figure()
            plt.title("Grasp Metric (value)")
            plt.bar([0], [data[chosen_index]])
            plt.xticks([0], ['Grasp Metric'])
            plt.ylabel('Score')
            plt.show()

        else:
            print(f"{category}{chosen_file_num:05d} -> Example[{chosen_index}]: {data[chosen_index]}\n")

# Example usage: plot/print one random sample from the tensor directory.
# Replace this path with the actual path to your data folder.
# The selection uses the unseeded `random` module, so the output differs between runs.
visualize_random_example("../dexnet_2.1/dexnet_2.1_eps_10/tensors")


In [1]:
# cd scripts/
# ./download_dexnet_2.sh


In [None]:
# Reference notes on an expected dataset layout, written as lookups.
# NOTE(review): `dataset` is only assigned in a later cell of this notebook, so this
# cell cannot run top-to-bottom as-is; the shapes below are unverified annotations.
dataset['image']['depth_ims']     # Shape: (N, 32, 32)   — depth image
dataset['pose']                   # Shape: (N, 4)        — grasp pose (x, y, z, angle)
dataset['success']                # Shape: (N,)          — binary label: success/failure


In [3]:
"""
For TensorFlow dataset
"""
# import torch
# from torch.utils.data import Dataset
# import h5py
# import numpy as np

# class DexNetDataset(Dataset):
#     def __init__(self, h5_path):
#         self.data = h5py.File(h5_path, 'r')
#         self.depth_images = self.data['image']['depth_ims'][:]
#         self.labels = self.data['grasp_qualities'][:]  # or 'success', depending on file

#     def __len__(self):
#         return len(self.depth_images)

#     def __getitem__(self, idx):
#         img = self.depth_images[idx]
#         img = np.expand_dims(img, axis=0)  # Convert to (1, H, W) for PyTorch CNN
#         label = self.labels[idx]
#         return torch.tensor(img, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)


# from torch.utils.data import DataLoader

# dataset = DexNetDataset('path/to/dexnet_dataset.h5')
# dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


In [84]:
from customize_dataset import DexNetNPZDataset
from torch.utils.data import DataLoader

# Build the dataset from the tensor directory. DexNetNPZDataset comes from the
# project-local `customize_dataset` module, which is not shown in this notebook.
dataset = DexNetNPZDataset('../dexnet_2.1/dexnet_2.1_eps_10/tensors/')
# Shuffled mini-batches of 16 samples, loaded with 4 worker subprocesses.
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

In [None]:
# Peek at the first batch only: print the shape of its first element and the
# contents of its second element, then stop iterating.
# The loop variable `i` stays bound to that batch after the break.
for i in dataloader:
    print(i[0].shape, i[1])
    break


In [None]:
# Shape of the first element of the batch bound to `i` by the loop in the previous cell.
i[0].shape

## Define Model

In [68]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleCNN(nn.Module):
    """Small two-layer CNN that maps a 1x32x32 image to one probability.

    Layer sizes are fixed for 32x32 single-channel input: two unpadded 3x3
    convolutions, each followed by 2x2 max-pooling, reduce the image to
    32 feature maps of 6x6, which feed two fully connected layers.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.fc1 = nn.Linear(32 * 6 * 6, 64)
        self.fc2 = nn.Linear(64, 1)  # Binary classification (grasp success)

    def forward(self, x):
        """Return sigmoid outputs in [0, 1] of shape (B, 1) for input (B, 1, 32, 32).

        Raises:
            RuntimeError: If the spatial size of *x* does not reduce to 6x6
                feature maps (i.e. the input is not 32x32).
        """
        x = self.pool(F.relu(self.conv1(x)))  # (B, 16, 15, 15)
        x = self.pool(F.relu(self.conv2(x)))  # (B, 32, 6, 6)
        # Flatten only the (C, H, W) axes. view(-1, 32*6*6) would instead fold
        # surplus features into the batch axis and silently return a wrong
        # batch size for inputs that are not 32x32.
        x = torch.flatten(x, start_dim=-3)
        x = F.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))


## Define Forward Pass and Evaluation Metrics

In [None]:
# Placeholder cell: only displays the DataLoader object; no evaluation code is defined here yet.
dataloader

## Model Training

In [None]:
# Quick 3-epoch training run: binary cross-entropy on the model's sigmoid
# outputs, optimised with Adam.
# `torch` is imported here because no earlier cell binds that name
# (`import torch.nn as nn` only binds `nn`).
import torch
from tqdm import tqdm

model = SimpleCNN()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(3):
    for imgs, labels in tqdm(dataloader):
        # Drop only the trailing unit axis, (B, 1) -> (B,). A bare squeeze()
        # would also drop the batch axis when a batch holds a single sample,
        # and BCELoss would then reject the shape mismatch with the labels.
        # Assumes labels are float tensors of shape (B,) — TODO confirm against
        # DexNetNPZDataset.
        outputs = model(imgs).squeeze(1)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


In [None]:
import torch
from tqdm import tqdm

# Trains the existing `model` (with the existing `criterion` and `optimizer`)
# for `num_epochs` more epochs, recording the sample-weighted mean loss and the
# accuracy at a 0.5 threshold for every epoch.
num_epochs = 15
loss_history = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for imgs, labels in progress_bar:
        # Drop only the trailing unit axis, (B, 1) -> (B,). A bare squeeze()
        # would also drop the batch axis for a single-sample batch.
        outputs = model(imgs).squeeze(1)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size to
        # get a per-sample average over the epoch.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size

        # Optional: compute accuracy
        predicted = (outputs >= 0.5).float()
        correct += (predicted == labels).sum().item()
        total += batch_size

    if total == 0:
        # Nothing was iterated; avoid dividing by zero below.
        print("Dataloader yielded no samples; stopping training.")
        break

    epoch_loss = running_loss / total
    accuracy = correct / total

    loss_history.append(epoch_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] — Loss: {epoch_loss:.4f} — Accuracy: {accuracy:.4f}")


In [None]:
import matplotlib.pyplot as plt

# Plot the mean training loss recorded for each epoch.
# Epochs are numbered from 1 so the x axis matches the "Epoch [k/N]" lines
# printed by the training loop (plotting the list alone would start at 0).
epochs = range(1, len(loss_history) + 1)
plt.plot(epochs, loss_history, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss per Epoch')
plt.grid(True)
plt.show()
