In [9]:
import h5py as h5
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from utils.data_utils import make_path, flatten_obs

In [21]:
class StackCubeDataset(Dataset):
    """Behaviour-cloning dataset of (observation, action) pairs read from an HDF5 trajectory file.

    Every top-level group in the file is treated as one trajectory providing an
    ``obs`` entry (flattened with ``flatten_obs``) and an ``actions`` dataset.
    The final observation of each trajectory has no action after it, so it is
    dropped; after that, ``obs[i]`` is paired with ``actions[i]``. All
    trajectories are concatenated into two in-memory arrays.

    Args:
        dataset_path: HDF5 file to read. ``None`` (the default) uses
            ``make_path('..', 'datasets', 'trajectory_state_original.h5')``,
            the path that used to be hard-coded here.
        dtype: NumPy dtype both arrays are cast to. Defaults to ``np.float32``
            so batches from a default ``DataLoader`` match the default float32
            parameters of ``torch.nn.Linear``; pass ``None`` to keep the dtypes
            produced by ``flatten_obs`` / stored in the file.

    Raises:
        ValueError: if the file contains no trajectories, or if a trajectory's
            trimmed observations and its actions differ in length.
    """

    def __init__(self, dataset_path=None, dtype=np.float32):
        if dataset_path is None:
            dataset_path = make_path('..',
                                     'datasets',
                                     'trajectory_state_original.h5')

        obs_chunks = []
        action_chunks = []
        with h5.File(dataset_path, 'r') as data:
            for name, traj in data.items():
                obs, actions = self._prepare_trajectory(
                    flatten_obs(traj['obs']), traj['actions'][:], dtype, name)
                obs_chunks.append(obs)
                action_chunks.append(actions)

        # np.concatenate([]) fails with an opaque message; say what is wrong instead.
        if not obs_chunks:
            raise ValueError(f"no trajectories found in {dataset_path!r}")

        self.obs = np.concatenate(obs_chunks, axis=0)
        self.actions = np.concatenate(action_chunks, axis=0)

    @staticmethod
    def _prepare_trajectory(obs, actions, dtype=np.float32, name='<trajectory>'):
        """Trim one trajectory's last observation, check alignment, and cast.

        Returns the ``(obs, actions)`` pair. Raises ``ValueError`` when their
        lengths differ after trimming. ``dtype=None`` skips the cast.
        """
        # The terminal observation has no action following it.
        obs = obs[:-1]
        # Checked per trajectory: a single check on the concatenated totals can
        # pass even when individual trajectories are misaligned in opposite directions.
        if len(obs) != len(actions):
            raise ValueError(
                f"trajectory {name}: {len(obs)} observations (after dropping "
                f"the last one) but {len(actions)} actions")
        if dtype is not None:
            obs = np.asarray(obs, dtype=dtype)
            actions = np.asarray(actions, dtype=dtype)
        return obs, actions

    def __getitem__(self, index):
        """Return the ``(observation, action)`` pair at ``index``."""
        return self.obs[index], self.actions[index]

    def __len__(self):
        """Number of (observation, action) pairs across all trajectories."""
        return len(self.obs)

In [16]:
class BC(nn.Module):
    """Behaviour-cloning policy: a three-layer MLP mapping observations to actions.

    Architecture: Linear(obs_dim, hidden_dim) -> ReLU -> Linear(hidden_dim,
    hidden_dim) -> ReLU -> Linear(hidden_dim, act_dim). The output layer has no
    activation, so predicted actions are unbounded.

    Args:
        obs_dim: Size of the flattened observation vector (default 55).
        act_dim: Size of the action vector (default 8).
        hidden_dim: Width of both hidden layers (default 256, the value that
            used to be hard-coded). Layer names ``fc1``/``fc2``/``fc3`` are
            unchanged, so checkpoints saved with the default width still load.
    """

    def __init__(self, obs_dim=55, act_dim=8, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, act_dim)

    def forward(self, x):
        """Map observations ``x`` (last dim ``obs_dim``) to actions (last dim ``act_dim``)."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

145716
