In [1]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import time

In [2]:
from typing import Tuple, Sequence, Dict, Union, Optional
import numpy as np
import math
import collections
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [3]:
# Load the first training dataset and the environment metadata through the
# project-local `utils` module.
# NOTE(review): the wildcard import hides where `get_dataset` (and anything else
# used later) comes from; an explicit `from utils import get_dataset` would be clearer.
from utils import *
dataset_name1 = "hopper-medium-replay-v2"
dataset1, env_info = get_dataset(dataset_name1)
env_name = env_info['env_name']
# These two values are rebound later from the array shapes
# (see the cell that computes `state_dim, act_dim, next_state_dim`).
state_dim = env_info['state_dim']
act_dim = env_info['action_dim']

  logger.warn(
  logger.warn(
load datafile: 100%|██████████| 11/11 [00:00<00:00, 34.30it/s]


In [4]:
# Load a second dataset and stack it after the first one along the sample axis.
dataset_name2 = "hopper-expert-v2"
dataset2, _ = get_dataset(dataset_name2)

# The state used throughout this notebook is [qpos, qvel] concatenated per step.
dataset1_obs = np.concatenate([dataset1['infos/qpos'], dataset1['infos/qvel']], axis=1)
dataset2_obs = np.concatenate([dataset2['infos/qpos'], dataset2['infos/qvel']], axis=1)
observations = np.concatenate((dataset1_obs, dataset2_obs), axis=0)
actions = np.concatenate((dataset1['actions'], dataset2['actions']), axis=0)
# NOTE(review): this value is replaced in the next cell by a shifted copy of
# `observations`, so the dataset's own 'next_observations' field is never used.
next_observations = np.concatenate((dataset1['next_observations'], dataset2['next_observations']), axis=0)

load datafile: 100%|██████████| 21/21 [00:00<00:00, 28.74it/s]


In [5]:
# Build next-state targets by shifting the state array one step: row i receives
# row i+1, and the last row wraps around to row 0.
next_observations = np.roll(observations, -1, axis=0)
# Keep only steps that are neither terminal nor timed out, i.e. drop the rows
# whose shifted "next state" would belong to a different episode.
# assumes 'terminals' and 'timeouts' are boolean arrays, so `+` acts as logical
# OR and `~` as logical NOT — TODO confirm the dtype returned by get_dataset.
# NOTE(review): if the final row is not flagged, its wrapped-around target is kept.
mask = ~np.concatenate((dataset1['terminals']+dataset1['timeouts'], dataset2['terminals']+dataset2['timeouts']), axis=0)
# NOTE(review): this cell is not idempotent — it rebinds `observations`,
# `actions` and `next_observations` to the filtered arrays, so running it twice
# applies a full-length mask to already-filtered data.
observations = np.array(observations[mask], dtype=np.float32)
actions = np.array(actions[mask], dtype=np.float32)
next_observations = np.array(next_observations[mask], dtype=np.float32)

In [6]:
# NOTE(review): this import is not referenced by any code visible in the
# notebook; it looks like an accidental auto-import and is a candidate for removal.
from huggingface_hub.utils import IGNORE_GIT_FOLDER_PATTERNS
# Print the first filtered transition (state, action, next state) as a sanity check.
obs = observations[0]
act = actions[0]
next_obs = next_observations[0]
# Compact printing: 4 decimals, no scientific notation, arrays longer than 5 elided.
with np.printoptions(precision=4, suppress=True, threshold=5):
    print("Obs: ", repr(obs))
    print("Action: ", repr(act))
    print("Next Obs: ", repr(next_obs))

Obs:  array([-0.0024,  1.2498,  0.0041, ...,  0.0044,  0.0015,  0.0023],
      dtype=float32)
Action:  array([-0.438 ,  0.3708, -0.9334], dtype=float32)
Next Obs:  array([-0.0027,  1.2496,  0.0018, ..., -0.6151, -0.0082, -1.3391],
      dtype=float32)


In [7]:
def create_sample_indices(
        episode_ends:np.ndarray, sequence_length:int,
        pad_before: int=0, pad_after: int=0):
    """Enumerate every fixed-length window that can be cut from each episode.

    Args:
        episode_ends: one-past-the-last buffer index of each episode, in order.
        sequence_length: number of steps in each window.
        pad_before: how many steps a window may start before its episode begins.
        pad_after: how many steps a window may extend past its episode end.

    Returns:
        Array with one row per window:
        ``[buffer_start, buffer_end, sample_start, sample_end]``.
        ``buffer_start:buffer_end`` is the slice of real data to read, and
        ``sample_start:sample_end`` is where that slice sits inside the
        ``sequence_length``-long window (the rest is padding).
    """
    rows = []
    episode_start = 0
    for episode_stop in episode_ends:
        length = episode_stop - episode_start
        first_window = -pad_before
        last_window = length - sequence_length + pad_after

        # last_window is inclusive, hence the +1
        for window_start in range(first_window, last_window + 1):
            window_stop = window_start + sequence_length
            # clip the window to the part that really exists in this episode
            buf_lo = episode_start + max(window_start, 0)
            buf_hi = episode_start + min(window_stop, length)
            # number of missing steps at the front / back of the window
            pad_front = buf_lo - (window_start + episode_start)
            pad_back = (window_stop + episode_start) - buf_hi
            rows.append([buf_lo, buf_hi, pad_front, sequence_length - pad_back])

        episode_start = episode_stop
    return np.array(rows)


def sample_sequence(train_data, sequence_length,
                    buffer_start_idx, buffer_end_idx,
                    sample_start_idx, sample_end_idx):
    """Cut one window out of every array in ``train_data``.

    Args:
        train_data: mapping from name to an array indexed by step on axis 0.
        sequence_length: length of the returned window.
        buffer_start_idx, buffer_end_idx: slice of real data to read.
        sample_start_idx, sample_end_idx: position of that slice in the window.

    Returns:
        Dict with the same keys. When the real slice does not fill the window,
        the front is filled with the first real step and the back with the last
        real step; otherwise the slice itself is returned unchanged.
    """
    # De Morgan form of: sample_start_idx > 0 or sample_end_idx < sequence_length
    fills_window = (sample_start_idx <= 0) and (sample_end_idx >= sequence_length)

    def _window(array):
        chunk = array[buffer_start_idx:buffer_end_idx]
        if fills_window:
            return chunk
        padded = np.zeros(
            shape=(sequence_length,) + array.shape[1:],
            dtype=array.dtype)
        if sample_start_idx > 0:
            # repeat the first real step in front
            padded[:sample_start_idx] = chunk[0]
        if sample_end_idx < sequence_length:
            # repeat the last real step behind
            padded[sample_end_idx:] = chunk[-1]
        padded[sample_start_idx:sample_end_idx] = chunk
        return padded

    return {name: _window(array) for name, array in train_data.items()}

# normalize data
def get_data_stats(data):
    """Return the per-feature minimum and maximum of ``data``.

    All leading axes are collapsed, so the statistics are taken over every row
    and have one entry per element of the last axis.
    """
    flat = data.reshape(-1, data.shape[-1])
    reducers = (('min', np.min), ('max', np.max))
    return {name: reduce(flat, axis=0) for name, reduce in reducers}

def normalize_data(data, stats):
    """Scale ``data`` feature-wise into [-1, 1] using min/max statistics.

    Args:
        data: array whose last axis matches ``stats['min']`` / ``stats['max']``.
        stats: dict with ``'min'`` and ``'max'`` entries (see ``get_data_stats``).

    Returns:
        ``2 * (data - min) / (max - min) - 1``. A feature whose min equals its
        max would divide by zero (giving NaN/inf); such features use a range of
        1 instead, so values equal to that constant map to -1 and
        ``unnormalize_data`` maps them back to the constant.
    """
    value_range = stats['max'] - stats['min']
    # guard constant features against 0/0 while keeping the dtype unchanged
    value_range = np.where(value_range == 0, np.ones_like(value_range), value_range)
    # normalize to [0, 1]
    ndata = (data - stats['min']) / value_range
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata

def unnormalize_data(ndata, stats):
    """Map values from [-1, 1] back to the original range given by ``stats``.

    Inverse of ``normalize_data``: -1 maps to ``stats['min']`` and +1 to
    ``stats['max']``, feature by feature.
    """
    low, high = stats['min'], stats['max']
    # [-1, 1] -> [0, 1]
    unit = (ndata + 1) / 2
    # [0, 1] -> [min, max]
    return unit * (high - low) + low

# dataset
class DynamicsDataset(torch.utils.data.Dataset):
    """Windowed, min/max-normalized (state-action, next-state) dataset.

    ``dataset_root`` is a mapping with three entries:
      - ``'state_action'``: per-step conditioning rows, first axis is the step
      - ``'next_state'``: per-step target rows, first axis is the step
      - ``'episode_ends'``: one-past-the-last step index of every episode

    Both data arrays are normalized to [-1, 1] with their own statistics, which
    are kept in ``self.stats`` so predictions can be mapped back later.
    """

    def __init__(self, dataset_root, pred_horizon=1, obs_horizon=1, next_obs_horizon=1):
        # all episodes are concatenated along the first dimension
        raw_arrays = {
            name: dataset_root[name]
            for name in ('state_action', 'next_state')
        }

        # every window of length pred_horizon; padding lets windows start
        # obs_horizon-1 steps early and end next_obs_horizon-1 steps late so
        # that each step of an episode is covered
        self.indices = create_sample_indices(
            episode_ends=dataset_root['episode_ends'],
            sequence_length=pred_horizon,
            pad_before=obs_horizon-1,
            pad_after=next_obs_horizon-1)

        # per-array min/max, then the arrays rescaled to [-1, 1]
        self.stats = {
            name: get_data_stats(array) for name, array in raw_arrays.items()
        }
        self.normalized_train_data = {
            name: normalize_data(array, self.stats[name])
            for name, array in raw_arrays.items()
        }

        self.pred_horizon = pred_horizon
        self.next_obs_horizon = next_obs_horizon
        self.obs_horizon = obs_horizon

    def __len__(self):
        """Number of windows available."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th normalized window as a dict of arrays."""
        buf_lo, buf_hi, win_lo, win_hi = self.indices[idx]

        window = sample_sequence(
            train_data=self.normalized_train_data,
            sequence_length=self.pred_horizon,
            buffer_start_idx=buf_lo,
            buffer_end_idx=buf_hi,
            sample_start_idx=win_lo,
            sample_end_idx=win_hi,
        )

        # only the first obs_horizon conditioning steps are used
        window['state_action'] = window['state_action'][:self.obs_horizon,:]
        return window

In [8]:
# Pack the filtered transitions in the layout DynamicsDataset expects.
dataset_root = {
    # conditioning input: state and action side by side on the feature axis
    'state_action': np.concatenate((observations, actions), axis=1),
    # regression target: the state one step later
    'next_state': next_observations,
    # a single entry, i.e. the whole filtered array is treated as one episode
    'episode_ends': np.array([len(observations)])
}

In [9]:
# Build the dataset with its default horizons (all 1) and keep the min/max
# statistics at module level; later cells reuse `stats` to normalize test
# inputs and to unnormalize predictions.
hopper_dataset = DynamicsDataset(dataset_root)
stats = hopper_dataset.stats
print("Stats: ")
for key, stat in stats.items():
    print(key+":")
    print("Min: ", stat['min'])
    print("Max: ", stat['max'])

Stats: 
state_action:
Min:  [ -5.0355263    0.70533     -0.19956256  -1.7470077   -1.9270817
  -0.9721309   -2.2943766   -5.580463    -5.682034   -10.
 -10.         -10.          -0.9999998   -0.99999475  -1.        ]
Max:  [22.013454    1.8055555   0.19973862  0.05772436  0.12416191  0.97480917
  5.6330166   3.3520665   7.750577   10.         10.         10.
  1.          1.          1.        ]
next_state:
Min:  [ -5.042045     0.7000183   -0.19996448  -1.7813634   -1.9897504
  -0.9721309   -2.2943766   -5.959067    -5.8681846  -10.
 -10.         -10.        ]
Max:  [22.047226    1.8055555   0.19989353  0.05772436  0.12416191  0.97480917
  5.6330166   3.3520665   7.9372845  10.         10.         10.        ]


In [10]:
# Shuffled mini-batch loader over the normalized transitions; a single worker
# process is kept alive between epochs (persistent_workers=True).
dataloader = DataLoader(hopper_dataset, batch_size=256,
                        num_workers=1, shuffle=True,
                        pin_memory=True, persistent_workers=True)

# Pull one batch to check the tensor shapes produced by the default collate.
batch = next(iter(dataloader))
print("Batch: ")
for key, val in batch.items():
    print(key, val.shape)

Batch: 
state_action torch.Size([256, 1, 15])
next_state torch.Size([256, 1, 12])


In [16]:
# Derive the feature sizes from the arrays themselves; this rebinds the
# `state_dim` / `act_dim` values read from `env_info` earlier in the notebook.
state_dim = observations.shape[1]
act_dim = actions.shape[1]
next_state_dim = next_observations.shape[1]
state_dim, act_dim, next_state_dim

(12, 3, 12)

In [17]:
#@markdown ### **Network**
#@markdown
#@markdown Defines a 1D UNet architecture `ConditionalUnet1D`
#@markdown as the noise prediction network
#@markdown
#@markdown Components
#@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k
#@markdown - `Downsample1d` Strided convolution to reduce temporal resolution
#@markdown - `Upsample1d` Transposed convolution to increase temporal resolution
#@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish
#@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \
#@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection.
#@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.

class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of scalars (the diffusion step).

    Produces ``dim // 2`` sine features followed by ``dim // 2`` cosine
    features, with frequencies spaced geometrically from 1 down to 1/10000.
    ``dim // 2`` must be at least 2, because ``dim // 2 - 1`` is used as a divisor.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` of shape (B,) into shape (B, 2 * (dim // 2))."""
        n_freqs = self.dim // 2
        log_spacing = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=x.device) * -log_spacing)
        # outer product: one row of phase angles per input scalar
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Downsample1d(nn.Module):
    """Strided 1-D convolution (kernel 3, stride 2, padding 1) that keeps the
    channel count and roughly halves the length of the last axis."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim, out_channels=dim,
            kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Apply the strided convolution to ``x`` of shape (B, dim, T)."""
        downsampled = self.conv(x)
        return downsampled

class Upsample1d(nn.Module):
    """Transposed 1-D convolution (kernel 4, stride 2, padding 1) that keeps
    the channel count and doubles the length of the last axis."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim, out_channels=dim,
            kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Apply the transposed convolution to ``x`` of shape (B, dim, T)."""
        upsampled = self.conv(x)
        return upsampled


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish

        The convolution pads by ``kernel_size // 2`` on each side, so an odd
        kernel size leaves the length of the last axis unchanged.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        same_padding = kernel_size // 2
        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=same_padding),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        '''Run ``x`` of shape (B, inp_channels, T) through the three layers.'''
        activated = self.block(x)
        return activated


class ConditionalResidualBlock1D(nn.Module):
    """Two stacked ``Conv1dBlock``s with FiLM conditioning and a residual path.

    The conditioning vector is projected to a per-channel scale and bias that
    modulate the output of the first conv block
    (FiLM, https://arxiv.org/abs/1709.07871).
    """

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()
        self.out_channels = out_channels

        first_conv = Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
        second_conv = Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        self.blocks = nn.ModuleList([first_conv, second_conv])

        # one scale and one bias per output channel
        film_width = 2 * out_channels
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, film_width),
            nn.Unflatten(-1, (-1, 1)),
        )

        # the skip path needs a 1x1 conv only when the channel count changes
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, cond):
        '''
            x : [ batch_size x in_channels x horizon ]
            cond : [ batch_size x cond_dim]

            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        hidden = self.blocks[0](x)

        film = self.cond_encoder(cond)
        # split the projection into (scale, bias), each [ batch x out_channels x 1 ]
        film = film.reshape(film.shape[0], 2, self.out_channels, 1)
        hidden = film[:, 0] * hidden + film[:, 1]

        hidden = self.blocks[1](hidden)
        return hidden + self.residual_conv(x)


class ConditionalUnet1D(nn.Module):
    """1-D U-Net noise-prediction network with FiLM conditioning.

    Every residual block is conditioned on the encoded diffusion step
    concatenated with an optional global conditioning vector.
    """

    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=[256,512,1024],
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Feature size of the sequence being denoised; also the
          feature size of the predicted noise.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this array determines the number of levels.
          The default list is only read (copied via list()), never mutated.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        # encoder for the diffusion step: sinusoidal features followed by an MLP
        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # FiLM input = step embedding concatenated with the global conditioning
        cond_dim = dsed + global_cond_dim

        # (in_channels, out_channels) for each encoder level
        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
            ConditionalResidualBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        # encoder: every level except the last halves the temporal length
        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_in, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_out, dim_out, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

        # decoder: one level fewer than the encoder (in_out[1:]).
        # NOTE(review): `ind` only reaches len(in_out) - 2 in this loop, so
        # `is_last` is never True here and every decoder level ends with an
        # Upsample1d. That matches the number of Downsample1d layers above, so
        # do not "fix" the condition without also changing the encoder.
        up_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (len(in_out) - 1)
            up_modules.append(nn.ModuleList([
                ConditionalResidualBlock1D(
                    dim_out*2, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                ConditionalResidualBlock1D(
                    dim_in, dim_in, cond_dim=cond_dim,
                    kernel_size=kernel_size, n_groups=n_groups),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))

        # project the first-level channels back to the input feature size
        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        # side effect: reports the parameter count whenever a network is built
        print("number of parameters: {:e}".format(
            sum(p.numel() for p in self.parameters()))
        )

    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond=None):
        """
        sample: (B,T,input_dim) noisy sequence to denoise
        timestep: (B,) or int, diffusion step
        global_cond: (B,global_cond_dim)
        output: (B,T,input_dim)

        T is halved by each Downsample1d and doubled by each Upsample1d, and
        the skip connections are concatenated along the channel axis, so T has
        to survive that round trip unchanged (with the default three levels:
        a multiple of 4). T=1, for example, fails at the skip concatenation.
        """
        # (B,T,C)
        sample = sample.moveaxis(-1,-2)
        # (B,C,T)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        global_feature = self.diffusion_step_encoder(timesteps)

        # 2. append the global conditioning to the step embedding
        if global_cond is not None:
            global_feature = torch.cat([
                global_feature, global_cond
            ], axis=-1)

        # 3. encoder: remember each level's output for the skip connections
        x = sample
        h = []
        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            h.append(x)
            x = downsample(x)

        # 4. bottleneck
        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # 5. decoder: there is one decoder level fewer than encoder levels, so
        # the first encoder output stays in `h` unused
        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
            # x and h[-1] must agree on every axis except channels here
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, global_feature)
            x = resnet2(x, global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1,-2)
        # (B,T,C)
        return x


In [18]:
input_dim = state_dim # 12 for hopper (the network denoises a next-state vector)
obs_dim = state_dim + act_dim # 12 + 3 for hopper (state + action)
obs_horizon = 1 # 1 for markov dynamics
global_cond_dim = obs_horizon * obs_dim # 15 for film conditioning on previous state-action
pred_horizon = 1 # 1 for next state prediction (overridden just below)

# Use pred_horizon = 16 for now: the U-Net halves and doubles the temporal
# length, and a length of 1 raises an error inside the network.
pred_horizon = 16

In [19]:
noise_pred_net = ConditionalUnet1D(
    input_dim=input_dim,
    global_cond_dim=global_cond_dim,
)
# conditional unet is a noise prediction network, initialized with input_dim and global_cond_dim
# input_dim is the dimension of the input and output of the network
# global_cond_dim is the dimension of the global conditioning applied with FiLM in addition to diffusion step embedding

# example inputs (CPU smoke test; the network is moved to the GPU at the end of this cell)
noised_next_state = torch.randn((1, pred_horizon, input_dim))
obs = torch.zeros((1, obs_horizon, obs_dim))
diffusion_iter = torch.zeros((1,))


# the noise prediction network
# takes the noisy next state, the diffusion iteration and the (state, action) conditioning as input
# and predicts the noise that was added to the next state
print("Noised Next State: ", noised_next_state.shape)
noise = noise_pred_net(
    sample=noised_next_state,
    timestep=diffusion_iter,
    global_cond=obs.flatten(start_dim=1))

# illustration of removing noise
# the actual noise removal is performed by NoiseScheduler
# and is dependent on the diffusion noise schedule
# (the variable name is historical: this is a denoised next state, not an action)
denoised_action = noised_next_state - noise

# DDPMScheduler with 100 diffusion iterations
num_diffusion_iters = 100
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    beta_schedule='squaredcos_cap_v2',
    # clip output to [-1,1] to improve stability
    clip_sample=True,
    # the network predicts noise (instead of the denoised sample)
    prediction_type='epsilon'
)

# device transfer
# NOTE(review): hard-coded to CUDA; this line does not fall back to CPU.
device = torch.device('cuda')
_ = noise_pred_net.to(device)

number of parameters: 6.544283e+07
Noised Next State:  torch.Size([1, 16, 12])


In [None]:

num_epochs = 100

# Checkpoints are written into this directory every 50 epochs. Create it up
# front: torch.save does not create missing parent directories, and failing
# at epoch 50 would throw away the training done so far.
os.makedirs('./DiffusionDynamicsModels', exist_ok=True)

# Exponential Moving Average
# accelerates training and improves stability
# holds a copy of the model weights
ema = EMAModel(
    parameters=noise_pred_net.parameters(),
    power=0.75)

# AdamW optimizer
# Note that EMA parameters are not optimized
optimizer = torch.optim.AdamW(
    params=noise_pred_net.parameters(),
    lr=1e-4, weight_decay=1e-6)

# Cosine LR schedule with linear warmup
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

with tqdm(range(num_epochs), desc='Epoch') as tglobal:
    # epoch loop
    for epoch_idx in tglobal:
        epoch_loss = list()
        # batch loop
        with tqdm(dataloader, desc='Batch', leave=False) as tepoch:
            for nbatch in tepoch:
                # data normalized in dataset
                # device transfer
                nobs = nbatch['state_action'].to(device)
                nnextobs = nbatch['next_state'].to(device)
                B = nobs.shape[0]

                # observation as FiLM conditioning
                # (B, obs_horizon, obs_dim)
                obs_cond = nobs[:,:obs_horizon,:]
                # (B, obs_horizon * obs_dim)
                obs_cond = obs_cond.flatten(start_dim=1)

                # sample noise to add to nextobs
                noise_shape = nnextobs.shape
                # nnextobs is (B, 1, input_dim): the dataset yields a single
                # next state per sample. The network cannot run on a temporal
                # length of 1, so the noise is drawn with pred_horizon steps:
                # (B, pred_horizon, input_dim)
                noise_shape = (noise_shape[0], pred_horizon, noise_shape[-1])
                noise = torch.randn(noise_shape, device=device)

                # sample a diffusion iteration for each data point
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps,
                    (B,), device=device
                ).long()

                # add noise to the clean next states according to the noise
                # magnitude at each diffusion iteration
                # (this is the forward diffusion process)
                # nnextobs (B, 1, input_dim) is broadcast against noise
                # (B, pred_horizon, input_dim), so the same clean next state is
                # repeated along the horizon axis with independent noise:
                # noisy_nextobs is (B, pred_horizon, input_dim)
                noisy_nextobs = noise_scheduler.add_noise(
                    nnextobs, noise, timesteps)

                # predict the noise residual from the noisy next states,
                # conditioned on the (state, action) observation
                noise_pred = noise_pred_net(
                    noisy_nextobs, timesteps, global_cond=obs_cond)

                # L2 loss between predicted noise and actual noise
                loss = nn.functional.mse_loss(noise_pred, noise)

                # optimize
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # step lr scheduler every batch
                # this is different from standard pytorch behavior
                lr_scheduler.step()

                # update Exponential Moving Average of the model weights
                ema.step(noise_pred_net.parameters())

                # logging
                loss_cpu = loss.item()
                epoch_loss.append(loss_cpu)
                tepoch.set_postfix(loss=loss_cpu)
        # NOTE(review): this saves the raw training weights, not the EMA weights.
        if (epoch_idx+1) % 50 == 0:
            torch.save(noise_pred_net.state_dict(), f'./DiffusionDynamicsModels/ddpm_hopper_dynamics_{epoch_idx+1}.pth')
        tglobal.set_postfix(loss=np.mean(epoch_loss))

# Weights of the EMA model
# is used for inference
# NOTE(review): `ema_noise_pred_net` is the same object as `noise_pred_net`, so
# copying the EMA weights in overwrites the trained weights in place.
ema_noise_pred_net = noise_pred_net
ema.copy_to(ema_noise_pred_net.parameters())

In [21]:
# Reload the epoch-100 checkpoint for inference.
# `ema_noise_pred_net` is an alias of `noise_pred_net` (same object), so this
# replaces the weights of both names. The checkpoint files are written from
# `noise_pred_net.state_dict()` inside the training loop, i.e. before the EMA
# weights are copied into the network.
ema_noise_pred_net = noise_pred_net
ema_noise_pred_net.load_state_dict(torch.load('./DiffusionDynamicsModels/ddpm_hopper_dynamics_100.pth'))
print("Model loaded")

Model loaded


In [55]:
def predict_dynamics(
    model: nn.Module,
    obs,
    act,
    obs_horizon: int = 1,
    device: str = 'cuda',
    pred_horizon: int = 16,
):
    """Sample a next-state prediction for a batch of (state, action) pairs.

    Runs the full reverse diffusion chain using the module-level
    ``noise_scheduler`` and ``num_diffusion_iters``. Calling this re-sets the
    scheduler's inference timesteps.

    Args:
        model: noise-prediction network, already on ``device``.
        obs: normalized states, shape (B, state_dim); torch tensor or numpy array.
        act: normalized actions, shape (B, act_dim); same container type as ``obs``.
        obs_horizon: currently unused; kept so existing callers keep working.
        device: device on which the sampling runs.
        pred_horizon: temporal length of the sequence that is denoised. The
            network needs a length it can halve and double back (16 was used
            for training); only the first step is returned.

    Returns:
        numpy array of shape (B, state_dim) holding the predicted next state,
        still in normalized units. This function does not unnormalize; apply
        ``unnormalize_data`` with the next-state statistics to get real units.

    Raises:
        TypeError: if ``obs`` is neither a torch tensor nor a numpy array.
    """
    # conditioning vector: state and action side by side, (B, state_dim + act_dim)
    if isinstance(obs, torch.Tensor):
        nobs = torch.cat((obs, act), dim=-1)
        nobs = nobs.to(device).to(torch.float32).clone().detach()
    elif isinstance(obs, np.ndarray):
        nobs = np.concatenate((obs, act), axis=-1)
        nobs = torch.tensor(nobs, dtype=torch.float32, device=device).detach()
    else:
        # previously this only printed a message and then crashed on the
        # undefined conditioning tensor
        raise TypeError("Use torch tensors or numpy arrays as input!")

    B = obs.shape[0]
    obs_dim = obs.shape[-1]

    # infer the next state
    with torch.no_grad():
        # start from pure Gaussian noise
        noisy_nextobs = torch.randn((B, pred_horizon, obs_dim), device=device)

        noise_scheduler.set_timesteps(num_diffusion_iters)

        for k in noise_scheduler.timesteps:
            # predict the noise present at step k ...
            noise = model(
                sample=noisy_nextobs,
                timestep=k,
                global_cond=nobs
            )
            # ... and take one reverse-diffusion step (x_k -> x_{k-1})
            noisy_nextobs = noise_scheduler.step(
                model_output=noise,
                timestep=k,
                sample=noisy_nextobs
            ).prev_sample

    noisy_nextobs = noisy_nextobs.cpu().numpy()
    # every horizon step was trained towards the same next state; keep the first
    return noisy_nextobs[:, 0, :]

In [53]:
# use the model to predict the next state
# use the model to predict the next state for one training sample
sample_data = next(iter(dataloader))
# first sample of the batch: (obs_horizon, state_dim + act_dim), already normalized
obs = sample_data['state_action'][0]
# split the conditioning row back into action and state parts
act = obs[0:1, state_dim:]
obs = obs[0:1, :state_dim]
next_obs = sample_data['next_state'][0]
pred = predict_dynamics(
    model=ema_noise_pred_net,
    obs=obs,
    act=act,
    obs_horizon=1,
    device='cuda'
)

# both values are in normalized units: the dataset normalizes its targets and
# the prediction is not unnormalized here
print("Predicted Next State: ", pred)
print("Actual Next State: ", next_obs)

torch.Size([1, 15])
Predicted Next State:  [[-5.7977831e-01 -5.5535835e-01 -9.0648212e-02  7.9119086e-02
   8.8451844e-01  7.4196362e-01 -1.5411229e-01  1.2442957e-01
  -1.5479720e-01 -1.4843488e-01  5.6755950e-04  6.2835328e-03]]
Actual Next State:  tensor([[-5.7969e-01, -5.5526e-01, -9.0659e-02,  7.9296e-02,  8.8481e-01,
          7.4183e-01, -1.5465e-01,  1.2446e-01, -1.5484e-01, -1.4836e-01,
          6.2013e-04,  6.0327e-03]])


In [33]:
# Load a separate dataset for evaluation.
# Note that this rebinds `env_info`, which was first set when loading the
# first training dataset.
test_dataset_name = "hopper-random-v2"
test_dataset, env_info = get_dataset(test_dataset_name)

  logger.warn(
  logger.warn(
load datafile: 100%|██████████| 9/9 [00:00<00:00, 12.52it/s]


In [34]:
# Build test states the same way as the training states: [qpos, qvel] per step.
rand_obs = np.concatenate([test_dataset['infos/qpos'], test_dataset['infos/qvel']], axis=1)
rand_act = test_dataset['actions']
# NOTE(review): taken directly from the dataset, whereas the training targets
# were derived by shifting the [qpos, qvel] array — confirm both have the same layout.
rand_next_obs = test_dataset['next_observations']

In [35]:
# NOTE(review): these three statements are identical to the preceding cell;
# re-running them is harmless but one of the two cells is redundant.
rand_obs = np.concatenate([test_dataset['infos/qpos'], test_dataset['infos/qvel']], axis=1)
rand_act = test_dataset['actions']
rand_next_obs = test_dataset['next_observations']

In [36]:
# Normalize the test inputs with the *training* min/max statistics so they are
# in the units the network was trained on. Values outside the training range
# fall outside [-1, 1].
rand_obs_act = np.concatenate((rand_obs, rand_act), axis=1)
rand_obs_act = normalize_data(rand_obs_act, stats['state_action'])

In [None]:
# Split the normalized conditioning rows back into state and action columns.
normalized_rand_obs = rand_obs_act[:, :state_dim]
normalized_rand_act = rand_obs_act[:, state_dim:]

# Group the rows into 1000 batches of 1000 samples each.
# NOTE(review): the reshape hard-codes exactly 1,000,000 rows and raises for
# any other dataset size; consider deriving the batching from the array length.
normalized_rand_act = normalized_rand_act.reshape(1000, 1000, -1)
normalized_rand_obs = normalized_rand_obs.reshape(1000, 1000, -1)


# Run the full reverse-diffusion sampler once per batch and collect the
# (still normalized) predictions, one (1000, state_dim) array per batch.
all_predicted_next_obs = list()
for this_obs, this_act in zip(normalized_rand_obs, normalized_rand_act):
    predicted_next_obs = predict_dynamics(
                                model=ema_noise_pred_net,
                                obs=this_obs,
                                act=this_act,
                                obs_horizon=1,
                                device='cuda'
                            )
    all_predicted_next_obs.append(predicted_next_obs)

In [61]:
# Stack the per-batch predictions and map them from [-1, 1] back to real units
# using the next-state statistics of the training set, then flatten to one row
# per sample.
# NOTE(review): this cell rebinds `all_predicted_next_obs`, so running it a
# second time would unnormalize already-unnormalized values.
all_predicted_next_obs = np.array(all_predicted_next_obs)
all_predicted_next_obs = unnormalize_data(all_predicted_next_obs, stats['next_state']).reshape(-1, state_dim)
all_predicted_next_obs.shape

(423, 1000, 12)