In [None]:
import numpy as np
import dsuite
import gym
from dsuite.dclaw.turn import DClawTurnImage, DClawTurnFixed
from softlearning.environments.adapters.gym_adapter import GymAdapter
import os
import imageio
import pickle
import matplotlib.pyplot as plt

from IPython.display import clear_output


In [None]:
# Name of this data-collection run; it is joined onto the current working
# directory to form the folder that all collected examples are written to.
exp_name = 'fixed_screw_2_goals'

In [None]:
# Output root for this experiment: <current working directory>/<exp_name>.
cur_dir = os.getcwd()
directory = os.path.join(cur_dir, exp_name)
# exist_ok=True makes the cell safely re-runnable and avoids the race between
# a separate existence check and the directory creation.
os.makedirs(directory, exist_ok=True)

In [None]:
# Collect "positive" goal examples for each target angle by running random
# actions in an environment that is initialised close to the goal, keeping
# every sampled observation whose object-to-target angle distance is below
# ANGLE_THRESHOLD, and pickling the stacked observations to disk.

# --- Collection settings -----------------------------------------------------
mixed_goal_pool = False  # True: pool positives of all goals into one file under `directory`
images = True            # True: display every accepted frame with matplotlib
goals = [-90, 90]        # target angles in degrees (converted to radians per goal)
num_goals = len(goals)

# NOTE(review): index 0 is passed as render width and index 1 as height below;
# harmless for square images, but confirm the intended axis order otherwise.
image_shape = (32, 32, 3)
NUM_TOTAL_EXAMPLES, ROLLOUT_LENGTH, STEPS_PER_SAMPLE = 500, 25, 5
# Acceptance threshold on obs_dict['object_to_target_angle_dist']
# (presumably radians -- TODO confirm against the environment).
ANGLE_THRESHOLD = 0.15
# Clear the cell output each time this many additional positives were gathered.
CLEAR_OUTPUT_EVERY = 100
observations = []

for goal_index, goal in enumerate(goals):
    print(f'\n\n ===== GOAL INDEX: {goal_index}, GOAL: {goal} ===== ')
    if not mixed_goal_pool:
        observations = []  # each goal gets its own, separate pool

    num_positives = 0
    cleared_batches = 0  # number of CLEAR_OUTPUT_EVERY-sized batches already cleared
    goal_angle = np.pi / 180. * goal  # convert to radians

    env_kwargs = {
        'camera_settings': {
            'azimuth': 180,
            'distance': 0.3,
            'elevation': -50,
            'lookat': np.array([0.02, 0.004, 0.09])
        },
        'goals': (goal_angle,),
        'goal_collection': True,
        # Start the object within +/- 0.05 of the goal angle so that random
        # actions produce many near-goal observations.
        'init_object_pos_range': (goal_angle - 0.05, goal_angle + 0.05),
        'target_pos_range': (goal_angle, goal_angle),
        'pixel_wrapper_kwargs': {
            'pixels_only': False,
            'normalize': False,
            'render_kwargs': {
                'width': image_shape[0],
                'height': image_shape[1],
                'camera_id': -1
            },
        },
        'swap_goals_upon_completion': True,
        'one_hot_goal_index': True,
        'observation_keys': (
            'pixels',
            'claw_qpos',
            'last_action',
            'goal_index',
            'one_hot_goal_index',
        ),
    }
    env = GymAdapter(
        domain='DClaw',
        task='TurnMultiGoal-v0',
        **env_kwargs
    )

    # Pooled mode writes to the experiment root; otherwise one folder per goal.
    if mixed_goal_pool:
        path = directory
    else:
        path = os.path.join(directory, f'goal_{goal_index}_{goal}')
    os.makedirs(path, exist_ok=True)

    # Roll out random actions until exactly NUM_TOTAL_EXAMPLES positives have
    # been gathered for this goal (strict `<`, and the rollout stops as soon
    # as the target is reached, so no surplus examples are stored).
    while num_positives < NUM_TOTAL_EXAMPLES:
        observation = env.reset()
        print("Resetting environment...")
        t = 0
        while t < ROLLOUT_LENGTH and num_positives < NUM_TOTAL_EXAMPLES:
            # Hold one random action for STEPS_PER_SAMPLE simulator steps and
            # keep only the final observation.
            action = env.action_space.sample()
            for _ in range(STEPS_PER_SAMPLE):
                observation = env.step(action)[0]

            obs_dict = env.get_obs_dict()

            # For fixed screw
            object_target_angle_dist = obs_dict['object_to_target_angle_dist']

            if object_target_angle_dist < ANGLE_THRESHOLD:
                # Overwrite the goal labels so they refer to this notebook's
                # `goals` list rather than the single-goal environment.
                if 'one_hot_goal_index' in observation:
                    one_hot = np.zeros(num_goals).astype(np.float32)
                    one_hot[goal_index] = 1.
                    observation['one_hot_goal_index'] = one_hot
                observation['goal_index'] = np.array([goal_index])
                observations.append(observation)
                if images:
                    img_obs = observation['pixels']
                    plt.imshow(img_obs)
                    plt.show()
                num_positives += 1
            t += 1

        # Several positives can be added per rollout, so compare completed
        # batches instead of testing for an exact multiple.
        completed_batches = num_positives // CLEAR_OUTPUT_EVERY
        if completed_batches > cleared_batches:
            cleared_batches = completed_batches
            clear_output()

    # Stack every observation key along a new leading axis.
    goal_examples = {
        key: np.concatenate([
            obs[key][None] for obs in observations
        ], axis=0)
        for key in observations[0].keys()
    }

    # In pooled mode this file is rewritten after every goal with the pool
    # accumulated so far.
    with open(os.path.join(path, 'positives.pkl'), 'wb') as file:
        pickle.dump(goal_examples, file)

In [None]:
# Load the positives saved for goal index 0 (-90 degrees). The folder name has
# to follow the per-goal layout used when saving: `goal_{goal_index}_{goal}`.
# NOTE: pickle.load can execute arbitrary code; only open files you generated.
positives_path = os.path.join(directory, 'goal_0_-90', 'positives.pkl')
with open(positives_path, 'rb') as f:
    data = pickle.load(f)

In [None]:
# Inspect the shape of the loaded 'goal_index' entry (presumably
# (num_examples, 1) if the file was produced by the collection cell -- TODO confirm).
data['goal_index'].shape

In [None]:
# Peek at the goal index stored on the first observation still held in memory
# in the `observations` list.
observations[0]['goal_index']