# Build Env

In [1]:
# !python --version
# !pip3 install torch==1.13.1 torchvision==0.14.1 diffusers==0.18.2 \
# scikit-image==0.19.3 scikit-video==1.1.11 zarr==2.12.0 numcodecs==0.10.2 \
# pygame==2.1.2 pymunk==6.2.1 gym==0.26.2 shapely==1.8.4 \
# &> /dev/null # mute output

In [2]:
# !pip install opencv-python-headless scikit-video

# Import

In [3]:
#@markdown ### **Imports**
# diffusion policy import
from typing import Tuple, Sequence, Dict, Union, Optional
import numpy as np
import math
import torch
import torch.nn as nn
import collections
import zarr
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

# env import
import gym
from gym import spaces
import pygame
import pymunk
import pymunk.pygame_util
from pymunk.space_debug_draw_options import SpaceDebugColor
from pymunk.vec2d import Vec2d
import shapely.geometry as sg
import cv2
import skimage.transform as st
from skvideo.io import vwrite
from IPython.display import Video
import gdown
import os

# Load Data

In [4]:
import pickle

# Pickled demonstration data; edit this path to point at your local copy.
# Alternative set used previously:
#   '/Users/qiyangyan/Desktop/Diffusion/Demonstration/VFF-random_9884demos'
DATASET_FILE = '/Users/qiyangyan/Desktop/Diffusion/Demonstration/VFF-bigSteps'

# WARNING: pickle.load can execute arbitrary code -- only load files you trust.
# NOTE(review): despite its name, `dataset_path` holds the unpickled object,
# not a path; it is passed as-is to PushTStateDataset(dataset_path=...).
with open(DATASET_FILE, 'rb') as f:
    dataset_path = pickle.load(f)

In [16]:
#@markdown ### **Dataset Demo**
from my_module import PushTStateDataset

# parameters
# (earlier settings: pred_horizon=16, obs_horizon=2, action_horizon=8)
# pred_horizon must survive the UNet's stride-2 down/upsampling: with the
# default 3-level ConditionalUnet1D it has to be a multiple of 4.
pred_horizon = 4
obs_horizon = 2
action_horizon = 2

# Layout of one sample with the values above (executed actions start at
# index obs_horizon - 1, matching the slicing in the inference cell):
#|o|o|        observations: 2
#| |a|a|      actions executed: 2
#|p|p|p|p|    actions predicted: 4

# create dataset from file
dataset = PushTStateDataset(
    dataset_path=dataset_path,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon
)
# save training data statistics (min, max) for each dim
stats = dataset.stats
print(len(dataset))

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    num_workers=1,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process after each epoch
    persistent_workers=True
)

# visualize data in batch
batch = next(iter(dataloader))
print("batch['obs'].shape:", batch['obs'].shape)
print("batch['action'].shape", batch['action'].shape)

# observations flattened the same way they are fed to the network as FiLM conditioning
flat_obs = batch['obs'].flatten(start_dim=1)
print(flat_obs.shape)
print(flat_obs[0])

# Create Network

In [6]:
#@markdown ### **Network**
#@markdown
#@markdown Defines a 1D UNet architecture `ConditionalUnet1D`
#@markdown as the noise prediction network
#@markdown
#@markdown Components
#@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k
#@markdown - `Downsample1d` Strided convolution to reduce temporal resolution
#@markdown - `Upsample1d` Transposed convolution to increase temporal resolution
#@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish
#@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \
#@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection.
#@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of scalars (e.g. diffusion steps).

    forward(x) maps x of shape (B,) to (B, 2 * (dim // 2)). The first half of
    the features is sin(x * f_i) and the second half cos(x * f_i), with
    geometrically spaced frequencies f_i = 10000 ** (-i / (dim // 2 - 1)).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        log_step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -log_step)
        # outer product via broadcasting: (B, 1) * (1, half) -> (B, half)
        angles = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)


class Downsample1d(nn.Module):
    """Halve the temporal length of a (B, C, T) tensor with a strided conv.

    Channels are preserved; kernel 3 / stride 2 / padding 1 gives
    T_out = ceil(T / 2).
    """

    def __init__(self, dim):
        super().__init__()
        # the attribute name `conv` is part of the state_dict keys
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled

class Upsample1d(nn.Module):
    """Double the temporal length of a (B, C, T) tensor with a transposed conv.

    Channels are preserved; kernel 4 / stride 2 / padding 1 gives T_out = 2 * T.
    """

    def __init__(self, dim):
        super().__init__()
        # the attribute name `conv` is part of the state_dict keys
        self.conv = nn.ConvTranspose1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        upsampled = self.conv(x)
        return upsampled


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish

        Padding is kernel_size // 2, so odd kernels keep the temporal length.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        same_padding = kernel_size // 2
        # Layer order defines the state_dict keys (block.0 = conv, block.1 = norm).
        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=same_padding),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        y = self.block(x)
        return y


class ConditionalResidualBlock1D(nn.Module):
    """Residual pair of Conv1dBlocks with FiLM conditioning in between.

    forward(x, cond):
        x    : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x cond_dim ]
        returns [ batch_size x out_channels x horizon ]

    `cond_encoder` maps cond to 2 * out_channels values per sample. Those
    values are split into a per-channel scale and bias (FiLM,
    https://arxiv.org/abs/1709.07871), which are applied to the output of the
    first conv block. A 1x1 conv adapts the skip path when in_channels and
    out_channels differ.
    """

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        # Submodules are created in the same order as before (blocks,
        # cond_encoder, residual_conv). This keeps random initialisation and
        # state_dict keys unchanged.
        conv_in = Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
        conv_out = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        self.blocks = nn.ModuleList([conv_in, conv_out])

        self.out_channels = out_channels
        film_width = 2 * out_channels  # one scale and one bias per channel
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, film_width),
            nn.Unflatten(-1, (-1, 1)),
        )

        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        hidden = self.blocks[0](x)

        film = self.cond_encoder(cond)
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        scale, bias = film.unbind(dim=1)
        hidden = scale * hidden + bias

        return self.blocks[1](hidden) + self.residual_conv(x)


class ConditionalUnet1D(nn.Module):
    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level (any sequence).
          The length of this sequence determines the number of levels.
          Every level except the last halves the temporal length, and the
          decoder doubles it back. Skip connections are concatenated along
          channels, so the horizon T passed to forward() must be divisible by
          2 ** (len(down_dims) - 1).
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm (must divide every channel size)
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        cond_dim = dsed + global_cond_dim

        # NOTE: submodules are created in this exact order so that random
        # initialisation and state_dict keys match existing checkpoints.
        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # ind never reaches len(in_out) - 1 here (the loop has one fewer
            # item), so every decoder level upsamples. This mirrors the
            # len(in_out) - 1 downsamples in the encoder.
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        sample: (B,T,input_dim) noisy action sequence
        timestep: (B,) tensor, 0-d tensor, or Python number -- diffusion step
        global_cond: (B,global_cond_dim), concatenated to the step embedding
        output: (B,T,input_dim)
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], dim=-1)

        x = sample
        # skip connections, one per encoder level
        h = []
        for resnet, resnet2, downsample in self.down_modules:
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # The deepest skip (same resolution as the mid blocks) is never popped:
        # the decoder has one fewer level than the encoder.
        for resnet, resnet2, upsample in self.up_modules:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x


In [7]:
#@markdown ### **Network Demo**

# Observation and action dimensions. These must match the per-step vectors
# produced by the dataset (checked against `batch` from the dataset cell below).
obs_dim = 24
action_dim = 2
assert batch['obs'].shape[-1] == obs_dim, \
    f"obs_dim={obs_dim} but dataset observations have {batch['obs'].shape[-1]} features"

# create network object
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
)

# example inputs
noised_action = torch.randn((1, pred_horizon, action_dim))
obs = torch.zeros((1, obs_horizon, obs_dim))
diffusion_iter = torch.zeros((1,))

print(noised_action.shape)

# The noise prediction network takes the noisy action, the diffusion iteration
# and the observation as input, and predicts the noise added to the action.
# This is only a shape check, so no autograd graph is needed.
with torch.no_grad():
    noise = noise_pred_net(
        sample=noised_action,
        timestep=diffusion_iter,
        global_cond=obs.flatten(start_dim=1))
assert noise.shape == noised_action.shape, (noise.shape, noised_action.shape)

# illustration of removing noise
# the actual noise removal is performed by NoiseScheduler
# and is dependent on the diffusion noise schedule
denoised_action = noised_action - noise

# for this demo, we use DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    # the choice of beta schedule has big impact on performance
    # we found squared cosine works the best
    beta_schedule='squaredcos_cap_v2',
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # our network predicts noise (instead of denoised action)
    prediction_type='epsilon'
)

# device transfer
device = torch.device('cpu')
_ = noise_pred_net.to(device)

In [8]:
import gymnasium as gym
import numpy as np

# Sanity-check the custom environment: print the first observation dict,
# then the shape of each of its entries.
env = gym.make("VariableFriction-v8")
reset_output = env.reset()
env_dict = reset_output[0]
print(env_dict)
for key, value in env_dict.items():
    print(np.shape(value))

In [9]:
import os
import gdown
import torch

# Set the path to the checkpoint file
ckpt_path = '/Users/qiyangyan/Downloads/ema_noise_pred_net_epoch_99.ckpt'

# Ensure the file exists
if not os.path.isfile(ckpt_path):
    print(f"Checkpoint file {ckpt_path} does not exist.")
else:
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))

    # Accept both a checkpoint that wraps the weights under 'model_state_dict'
    # (the format this notebook was written against) and a bare state dict.
    if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']

    # ema_noise_pred_net is an alias of noise_pred_net (same object, not a
    # copy). Running the training cell afterwards fine-tunes these weights.
    ema_noise_pred_net = noise_pred_net
    ema_noise_pred_net.load_state_dict(state_dict)

    print('Pretrained weights loaded.')

In [12]:
#@markdown ### **Training**
#@markdown
#@markdown Takes about an hour. To skip training, run the checkpoint-loading
#@markdown cell above instead. Running this cell after it continues training
#@markdown from the loaded weights (the EMA net aliases `noise_pred_net`).

num_epochs = 100

# Exponential Moving Average
# accelerates training and improves stability
# holds a copy of the model weights
ema = EMAModel(
    parameters=noise_pred_net.parameters(),
    power=0.75)

# Standard ADAM optimizer
# Note that EMA parameters are not optimized
optimizer = torch.optim.AdamW(
    params=noise_pred_net.parameters(),
    lr=1e-4, weight_decay=1e-6)

# Cosine LR schedule with linear warmup
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

with tqdm(range(num_epochs), desc='Epoch') as tglobal:
    # epoch loop
    for epoch_idx in tglobal:
        epoch_loss = list()
        # batch loop
        with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
            for nbatch in tepoch:
                # data normalized in dataset
                # device transfer
                nobs = nbatch['obs'].to(device)
                naction = nbatch['action'].to(device)
                B = nobs.shape[0]

                # observation as FiLM conditioning
                # (B, obs_horizon, obs_dim)
                obs_cond = nobs[:,:obs_horizon,:]
                # (B, obs_horizon * obs_dim)
                obs_cond = obs_cond.flatten(start_dim=1)

                # sample noise to add to actions
                noise = torch.randn(naction.shape, device=device)

                # sample a diffusion iteration for each data point
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps,
                    (B,), device=device
                ).long()

                # add noise to the clean actions according to the noise magnitude at each diffusion iteration
                # (this is the forward diffusion process)
                noisy_actions = noise_scheduler.add_noise(
                    naction, noise, timesteps)

                # predict the noise residual
                noise_pred = noise_pred_net(
                    noisy_actions, timesteps, global_cond=obs_cond)

                # L2 loss
                loss = nn.functional.mse_loss(noise_pred, noise)

                # optimize
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # step lr scheduler every batch
                # this is different from standard pytorch behavior
                lr_scheduler.step()

                # update Exponential Moving Average of the model weights
                ema.step(noise_pred_net.parameters())

                # logging
                loss_cpu = loss.item()
                epoch_loss.append(loss_cpu)
                tepoch.set_postfix(loss=loss_cpu)
        tglobal.set_postfix(loss=np.mean(epoch_loss))

# Weights of the EMA model are used for inference.
# NOTE: ema_noise_pred_net is the same object as noise_pred_net, so copy_to
# overwrites the raw training weights with the EMA weights.
ema_noise_pred_net = noise_pred_net
ema.copy_to(ema_noise_pred_net.parameters())

In [None]:
import time
import rotations

def discretize_action_to_control_mode_E2E(action):
    """
    Map a scalar policy action in [-1, 1] to one of six discrete primitives.

    The action is rescaled to action_norm = (action + 1) / 2, so [-1, 1] maps
    to [0, 1]. That range is split into six bins of width 1/6, one per row of
    the table below; the last bin also includes action_norm == 1. A value
    outside [0, 1] after rescaling fails the final assertion. The primitive's
    name is printed.

    Returns: (friction_state, control_mode, pos_idx)
    """
    # (message, control_mode, friction_state, pos_idx), in bin order
    primitives = (
        ("| Slide up on right finger", 0, 1, 0),   # friction_state 1: left finger high friction
        ("| Slide down on right finger", 1, 1, 1),
        ("| Slide up on left finger", 2, -1, 1),
        ("| Slide down on left finger", 3, -1, 0),
        ("| Rotate clockwise", 4, 0, 0),
        ("| Rotate anticlockwise", 5, 0, 1),
    )
    action_norm = (action + 1) / 2

    for bin_idx, (message, control_mode, friction_state, pos_idx) in enumerate(primitives[:-1]):
        if bin_idx / 6 <= action_norm < (bin_idx + 1) / 6:
            break
    else:
        # closed upper bin [5/6, 1]; anything else is out of range
        assert 1 >= action_norm >= 5 / 6
        message, control_mode, friction_state, pos_idx = primitives[-1]

    print(message)
    return friction_state, control_mode, pos_idx

def compute_orientation_diff(goal_a, goal_b):
    ''' Position and rotation difference between two 7-element poses.

    Each pose's last axis is [x, y, z, q0, q1, q2, q3] (7 entries, asserted).
    Before comparing:
      * goal_a's z is replaced by goal_b's, so d_pos measures only the
        x/y offset;
      * goal_a's first two Euler components (from rotations.quat2euler) are
        replaced by goal_b's, so the two orientations differ only in the
        third component.
    The inputs are not modified; batched (..., 7) arrays are supported for
    the position part.

    Returns (d_pos, d_rot):
      d_pos: Euclidean norm of the (z-aligned) position difference
      d_rot: 2 * arccos(clip(q_diff[..., 0])), with q_diff = q_a * conj(q_b)
             (assumes the rotations module stores the scalar part first -- TODO confirm)

    left motor pos: 0.037012 -0.1845 0.002
    right motor pos: -0.037488 -0.1845 0.002
    '''
    assert goal_a.shape == goal_b.shape, f"Check: {goal_a.shape}, {goal_b.shape}"
    assert goal_a.shape[-1] == 7
    # Work on a copy so the caller's array is not mutated. Index the last axis
    # so a batched (..., 7) input has its z *column* aligned, not row 2.
    goal_a = np.array(goal_a, copy=True)
    goal_a[..., 2] = goal_b[..., 2]

    delta_pos = goal_a[..., :3] - goal_b[..., :3]
    d_pos = np.linalg.norm(delta_pos, axis=-1)

    quat_a, quat_b = goal_a[..., 3:], goal_b[..., 3:]

    euler_a = rotations.quat2euler(quat_a)
    euler_b = rotations.quat2euler(quat_b)
    if euler_a.ndim == 1:
        euler_a = euler_a[np.newaxis, :]  # Reshape 1D to 2D (1, 3)
    if euler_b.ndim == 1:
        euler_b = euler_b[np.newaxis, :]  # Reshape 1D to 2D (1, 3)
    # copy the first two Euler components so only the third one can differ
    euler_a[:, :2] = euler_b[:, :2]
    quat_a = rotations.euler2quat(euler_a)
    quat_a = quat_a.reshape(quat_b.shape)

    quat_diff = rotations.quat_mul(quat_a, rotations.quat_conjugate(quat_b))  # q_diff = q1 * q2*
    # NOTE(review): 2*arccos(w) lies in [0, 2*pi] and is not invariant to the
    # q / -q sign ambiguity. Left unchanged because callers may depend on the
    # current range.
    angle_diff = 2 * np.arccos(np.clip(quat_diff[..., 0], -1.0, 1.0))
    d_rot = angle_diff
    return d_pos, d_rot

def change_friction_full_obs(action, env, timeout=2.5):
    '''Repeat env.step(action) until the friction-change action finishes.

    Stops and returns the last step's 5-tuple
    (obs, distance, terminated, truncated, info) as soon as one of these holds:
      * the env reports `terminated` (any truthy value, including numpy bools);
      * distance["action_complete"] is truthy (normal completion);
      * more than `timeout` seconds of wall-clock time have passed since the
        first step. In that case `terminated` is forced to True.

    One loop might finish early once the finger position is reached, so the
    friction change itself might not be complete when this returns.
    '''
    start_t = time.time()
    while True:
        next_env_dict, distance, terminated, truncated, info = env.step(action)
        # truthiness rather than `is True`: envs may return np.bool_
        if terminated:
            break
        if distance["action_complete"]:
            break
        if time.time() - start_t > timeout:
            terminated = True
            break
    return next_env_dict, distance, terminated, truncated, info

def friction_change(friction_state, env):
    '''Change finger friction in two phases via change_friction_full_obs.

    Phase 1 sends [2, 0, True]; phase 2 sends [2, friction_state, True].
    If a phase terminates, its message is printed and phase 2 is skipped.

    Returns (obs, rewards["RL_IHM"], terminated, truncated, info) from the
    last phase that ran.
    '''
    phases = (
        ([2, 0, True], "terminate at friction change to high"),
        ([2, friction_state, True], "terminate at friction change to low"),
    )
    for friction_action, message in phases:
        new_obs, rewards, terminated, truncated, infos = change_friction_full_obs(
            np.array(friction_action), env)
        if terminated:
            print(message)
            break
    return new_obs, rewards["RL_IHM"], terminated, truncated, infos

def pick_up(inAir, env):
    '''Grasp the object and, if inAir is True, lift it.

    1. Step [0.9557, 2, False] until observation[0] is within 0.003 of
       0.9557 (the position-controlled finger reaches the middle). There is
       no step limit.
    2. Step the same action 50 more times (lets the torque-controlled finger
       reach the middle).
    3. Only if inAir is True: step [0, -3, False] until
       reward["action_complete"] is truthy. There is no step limit.

    Returns (state, reward) from the final env.step call.
    '''
    target_pos = 0.9557
    grasp_cmd = [target_pos, 2, False]

    # Phase 1: position-controlled finger reaches the middle
    while True:
        # a fresh array per step, as before
        state, reward, _, _, _ = env.step(np.array(grasp_cmd))
        if abs(state["observation"][0] - target_pos) < 0.003:
            print("Pick up complete")
            break

    # Phase 2: wait for the torque-controlled finger
    for _ in range(50):
        state, reward, _, _, _ = env.step(np.array(grasp_cmd))

    # Phase 3: optionally raise the block into the air
    lift_cmd = [0, -3, False]
    if inAir is True:
        print("Lifting the block")
        while True:
            state, reward, _, _, _ = env.step(np.array(lift_cmd))
            if reward["action_complete"]:
                break
    return state, reward

def ihm_step(env, action, last_friction_state):
    '''Execute one policy action, changing finger friction first if needed.

    action[1] is discretized into (friction_state, control_mode, pos_idx).
    If friction_state differs from last_friction_state, friction_change() runs
    first, and env.step(action) is skipped if that change terminated.
    Afterwards the last three entries of obs['observation'] are overwritten in
    place with goal-difference values:
        [-3] desired_goal[7] - achieved_goal[7]
        [-2] desired_goal[8] - achieved_goal[8]
        [-1] rotation difference from compute_orientation_diff on the first
             7 goal entries

    Returns (obs, r_dict, terminated, truncated, info, friction_state).
    NOTE(review): after env.step, r_dict is whatever the env returns, but on
    the path where friction_change() terminates it is rewards["RL_IHM"]. The
    two can differ in type.
    '''
    friction_state, control_mode_dis, pos_idx = discretize_action_to_control_mode_E2E(action[1])
    if last_friction_state == friction_state:
        obs, r_dict, terminated, truncated, info_ = env.step(action)
    else:
        obs, r_dict, terminated, truncated, info_ = friction_change(friction_state, env)
        if not terminated:
            obs, r_dict, terminated, truncated, info_ = env.step(action)

    _, angle_diff = compute_orientation_diff(np.array(obs['desired_goal'][:7]), np.array(obs['achieved_goal'][:7]))
    radi_diff = obs['desired_goal'][7:9] - obs['achieved_goal'][7:9]
    obs['observation'][-1] = angle_diff
    obs['observation'][-2] = radi_diff[1]
    obs['observation'][-3] = radi_diff[0]
    # return the env's real truncated flag (previously `terminated` was returned twice)
    return obs, r_dict, terminated, truncated, info_, friction_state

In [None]:
#@markdown ### **Inference**
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt


# limit environment interaction to max_steps steps before termination
max_steps = 800

env = gym.make("VariableFriction-v8", render_mode="human")

# get first observation
# obs_dict, info = env.reset(seed=100000)
obs_dict, info = env.reset()
obs_dict, _ = pick_up(False, env)
# NOTE(review): this first observation comes straight from the env. Every
# later one goes through ihm_step, which overwrites its last three entries with
# desired-vs-achieved goal differences. Confirm the two are consistent.
obs = obs_dict['observation']

# keep a queue of the last obs_horizon observations
obs_deque = collections.deque(
    [obs] * obs_horizon, maxlen=obs_horizon)
# save visualization and rewards
rewards = list()
done = False
step_idx = 0
last_friction_state = 0

# NOTE(review): normalize_data / unnormalize_data are not defined in this
# notebook; presumably they come from my_module -- TODO confirm they are in scope.
with tqdm(total=max_steps, desc="Eval PushTStateEnv") as pbar:
    while not done:
        B = 1
        # stack the last obs_horizon number of observations
        obs_seq = np.stack(obs_deque)
        # normalize observation
        nobs = normalize_data(obs_seq, stats=stats['obs'])
        # device transfer
        nobs = torch.from_numpy(nobs).to(device, dtype=torch.float32)

        # infer action
        with torch.no_grad():
            # reshape observation to (B,obs_horizon*obs_dim)
            obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)

            # initialize action from Gaussian noise
            noisy_action = torch.randn(
                (B, pred_horizon, action_dim), device=device)
            naction = noisy_action

            # init scheduler
            noise_scheduler.set_timesteps(num_diffusion_iters)

            for k in noise_scheduler.timesteps:
                # predict noise
                noise_pred = ema_noise_pred_net(
                    sample=naction,
                    timestep=k,
                    global_cond=obs_cond
                )

                # inverse diffusion step (remove noise)
                naction = noise_scheduler.step(
                    model_output=noise_pred,
                    timestep=k,
                    sample=naction
                ).prev_sample

        # unnormalize action
        naction = naction.detach().to('cpu').numpy()
        # (B, pred_horizon, action_dim)
        naction = naction[0]
        action_pred = unnormalize_data(naction, stats=stats['action'])

        # only take action_horizon number of actions
        start = obs_horizon - 1
        end = start + action_horizon
        action = action_pred[start:end,:]
        # (action_horizon, action_dim)

        # execute action_horizon number of steps
        # without replanning

        print((action[:, 0]+1)/2)
        for i in range(len(action)):
            # stepping env
            obs_dict, reward, done, _, info, friction_state = ihm_step(env, action[i], last_friction_state)
            obs = obs_dict['observation']

            last_friction_state = friction_state
            # save observations
            obs_deque.append(obs)
            # placeholder: the env reward is not recorded
            rewards.append(0)

            # update progress bar
            step_idx += 1
            pbar.update(1)
            pbar.set_postfix(reward=reward)
            # stop after exactly max_steps environment steps
            if step_idx >= max_steps:
                done = True
            if done:
                break

# resetting a closed env is pointless, so only close it here
env.close()

In [None]:
# Release the environment's resources (e.g. its render window, if one is open).
env.close()

In [None]:
import gymnasium as gym

# Build the environment with on-screen rendering, reset it with a fixed seed,
# and show the initial observation dict and info.
env = gym.make("VariableFriction-v8", render_mode="human")
reset_result = env.reset(seed=100000)
obs_dict, info = reset_result
print(obs_dict, info)