**Скачаем репозиторий TD-MPC:**

In [1]:
# Clone the upstream TD-MPC repository; later %%writefile cells overwrite some of its files.
!git clone https://github.com/nicklashansen/tdmpc

Cloning into 'tdmpc'...
remote: Enumerating objects: 108, done.[K
remote: Counting objects: 100% (53/53), done.[K
remote: Compressing objects: 100% (23/23), done.[K
remote: Total 108 (delta 31), reused 31 (delta 30), pack-reused 55 (from 1)[K
Receiving objects: 100% (108/108), 1.49 MiB | 6.19 MiB/s, done.
Resolving deltas: 100% (38/38), done.


**Перепишем скрипт обучения и код логгера, добавив возможность сохранять промежуточные чекпойнты:**

In [2]:
%%writefile tdmpc/src/train.py
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'
import torch
import numpy as np
import gym
gym.logger.set_level(40)
import time
import random
from pathlib import Path
from cfg import parse_cfg
from env import make_env
from algorithm.tdmpc import TDMPC
from algorithm.helper import Episode, ReplayBuffer
import logger
torch.backends.cudnn.benchmark = True
__CONFIG__, __LOGS__ = 'cfgs', 'logs'


def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device) with `seed`."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


def evaluate(env, agent, num_episodes, step, env_step, video):
    """Run `num_episodes` greedy (eval-mode) episodes and return the mean episode return.

    When a video recorder is given, only the first episode is recorded and the recorder
    is asked to save after every episode, tagged with `env_step`.
    """
    returns = []
    for episode in range(num_episodes):
        obs = env.reset()
        done, total_reward, t = False, 0, 0
        if video:
            video.init(env, enabled=(episode == 0))
        while not done:
            action = agent.plan(obs, eval_mode=True, step=step, t0=(t == 0))
            obs, reward, done, _ = env.step(action.cpu().numpy())
            total_reward += reward
            if video:
                video.record(env)
            t += 1
        returns.append(total_reward)
        if video:
            video.save(env_step)
    return np.nanmean(returns)


def train(cfg):
    """Training script for TD-MPC. Requires a CUDA-enabled device.

    Each iteration collects one episode with the planner, updates the agent, logs training
    metrics, optionally saves a checkpoint and periodically evaluates the agent.

    Checkpointing is controlled by two config keys added in this notebook:
    `save_checkpoint` (bool) and `save_freq` (in environment steps). They are read with
    `cfg.get`, like the optional wandb keys in the logger, so configs that predate them
    keep working (checkpointing is then disabled).
    """
    assert torch.cuda.is_available()
    set_seed(cfg.seed)
    work_dir = Path().cwd() / __LOGS__ / cfg.task / cfg.modality / cfg.exp_name / str(cfg.seed)
    env, agent, buffer = make_env(cfg), TDMPC(cfg), ReplayBuffer(cfg)

    # Run training
    L = logger.Logger(work_dir, cfg)
    episode_idx, start_time = 0, time.time()
    for step in range(0, cfg.train_steps+cfg.episode_length, cfg.episode_length):

        # Collect trajectory
        obs = env.reset()
        episode = Episode(cfg, obs)
        while not episode.done:
            action = agent.plan(obs, step=step, t0=episode.first)
            obs, reward, done, _ = env.step(action.cpu().numpy())
            episode += (obs, action, reward, done)
        assert len(episode) == cfg.episode_length
        buffer += episode

        # Update model: a bulk pass of seed_steps updates when seeding ends,
        # afterwards one update per collected (agent) step.
        train_metrics = {}
        if step >= cfg.seed_steps:
            num_updates = cfg.seed_steps if step == cfg.seed_steps else cfg.episode_length
            for i in range(num_updates):
                train_metrics.update(agent.update(buffer, step+i))

        # Log training episode
        episode_idx += 1
        env_step = int(step*cfg.action_repeat)
        common_metrics = {
            'episode': episode_idx,
            'step': step,
            'env_step': env_step,
            'total_time': time.time() - start_time,
            'episode_reward': episode.cumulative_reward}
        train_metrics.update(common_metrics)
        L.log(train_metrics, category='train')

        # Periodic checkpoint: frequency is in env steps, the file is tagged with the agent step.
        if cfg.get('save_checkpoint', False) and env_step % cfg.save_freq == 0:
            L.save_checkpoint(agent, step)

        # Periodic evaluation (eval_freq is in env steps)
        if env_step % cfg.eval_freq == 0:
            common_metrics['episode_reward'] = evaluate(env, agent, cfg.eval_episodes, step, env_step, L.video)
            L.log(common_metrics, category='eval')

    L.finish(agent)
    print('Training completed successfully')


if __name__ == '__main__':
    # Configs are read from ./cfgs relative to the working directory (the repo root).
    config_dir = Path.cwd() / __CONFIG__
    train(parse_cfg(config_dir))


Overwriting tdmpc/src/train.py


In [3]:
%%writefile tdmpc/src/logger.py
import sys
import os
import datetime
import re
import numpy as np
import torch
import pandas as pd
from termcolor import colored
from omegaconf import OmegaConf


CONSOLE_FORMAT = [('episode', 'E', 'int'), ('env_step', 'S', 'int'), ('episode_reward', 'R', 'float'), ('total_time', 'T', 'time')]
AGENT_METRICS = ['consistency_loss', 'reward_loss', 'value_loss', 'total_loss', 'weighted_loss', 'pi_loss', 'grad_norm']


def make_dir(dir_path):
	"""Create `dir_path` (including missing parents) if it does not exist yet; return it unchanged.

	An already existing directory is fine. Real failures, such as a permission error or an
	existing *file* at that path, are raised here instead of being swallowed and surfacing
	later as a confusing error when something is written into the directory.
	"""
	os.makedirs(dir_path, exist_ok=True)
	return dir_path


def print_run(cfg, reward=None):
	"""Pretty-print run information (task, train steps, observation/action sizes, experiment).

	When `reward` is given (the Logger passes the last eval reward from `finish`), it is
	shown as an extra 'Episode reward' row.
	"""
	prefix, color, attrs = '  ', 'green', ['bold']
	def limstr(s, maxlen=32):
		# Convert before slicing so long non-string values are truncated instead of crashing.
		s = str(s)
		return s[:maxlen] + '...' if len(s) > maxlen else s
	def pprint(k, v):
		print(prefix + colored(f'{k.capitalize()+":":<16}', color, attrs=attrs), limstr(v))
	kvs = [('task', cfg.task_title),
		   ('train steps', f'{int(cfg.train_steps*cfg.action_repeat):,}'),
		   ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
		   ('actions', cfg.action_dim),
		   ('experiment', cfg.exp_name)]
	if reward is not None:
		kvs.append(('episode reward', colored(str(int(reward)), 'white', attrs=['bold'])))
	# Divider width: longest (truncated) value plus room for the padded label column.
	w = np.max([len(limstr(v)) for _, v in kvs]) + 21
	div = '-'*w
	print(div)
	for k, v in kvs:
		pprint(k, v)
	print(div)


def cfg_to_group(cfg, return_list=False):
	"""Build a wandb-safe group name from task, modality and the experiment name.

	Runs of non-alphanumeric characters in `exp_name` are replaced by '-'. Returns the parts
	joined with '-', or the list of parts when `return_list` is True.
	"""
	safe_exp_name = re.sub('[^0-9a-zA-Z]+', '-', cfg.exp_name)
	parts = [cfg.task, cfg.modality, safe_exp_name]
	if return_list:
		return parts
	return '-'.join(parts)


class VideoRecorder:
	"""Collects rendered frames during evaluation and uploads them to wandb as an mp4 video.

	Recording only happens when a root directory and a wandb handle were supplied and the
	episode was enabled in `init`.
	"""
	def __init__(self, root_dir, wandb, render_size=384, fps=15):
		self.save_dir = None if not root_dir else root_dir / 'eval_video'
		self._wandb = wandb
		self.render_size, self.fps = render_size, fps
		self.frames = []
		self.enabled = False

	def init(self, env, enabled=True):
		"""Clear buffered frames, decide whether this episode is recorded, and grab the first frame."""
		self.frames = []
		self.enabled = self.save_dir and self._wandb and enabled
		self.record(env)

	def record(self, env):
		"""Append one rendered RGB frame (square, `render_size` pixels) when recording is enabled."""
		if not self.enabled:
			return
		size = self.render_size
		self.frames.append(env.render(mode='rgb_array', height=size, width=size, camera_id=0))

	def save(self, step):
		"""Log the buffered frames to wandb as 'eval_video' at `step` (no-op when disabled)."""
		if not self.enabled:
			return
		# Frames are presumably (H, W, C); reorder the stack to (T, C, H, W) for wandb.Video.
		video = np.stack(self.frames).transpose(0, 3, 1, 2)
		self._wandb.log({'eval_video': self._wandb.Video(video, fps=self.fps, format='mp4')}, step=step)


class Logger(object):
	"""Primary logger object. Logs either locally or using wandb.

	Every `log` call prints a console line; eval entries are also appended to
	`<log_dir>/eval.log`, and model weights are written under `<log_dir>/models`.
	When wandb is enabled in the config and initializes successfully, metrics,
	eval videos and model artifacts are mirrored to wandb as well.
	"""
	def __init__(self, log_dir, cfg):
		self._log_dir = make_dir(log_dir)
		self._model_dir = make_dir(self._log_dir / 'models')
		self._save_model = cfg.save_model
		self._group = cfg_to_group(cfg)
		self._seed = cfg.seed
		self._cfg = cfg
		self._eval = []  # one [env_step, episode_reward] row per logged evaluation
		print_run(cfg)
		project, entity = cfg.get('wandb_project', 'none'), cfg.get('wandb_entity', 'none')
		run_offline = not cfg.get('use_wandb', False) or project == 'none' or entity == 'none'
		if run_offline:
			print(colored('Logs will be saved locally.', 'yellow', attrs=['bold']))
			self._wandb = None
		else:
			self._wandb = self._init_wandb(cfg, project, entity)
		self._video = VideoRecorder(log_dir, self._wandb) if self._wandb and cfg.save_video else None

	def _init_wandb(self, cfg, project, entity):
		"""Start a wandb run; return the `wandb` module, or None if wandb cannot be imported/initialized."""
		try:
			os.environ["WANDB_SILENT"] = "true"
			import wandb
			wandb.init(project=project,
					entity=entity,
					name=cfg.exp_name,
					group=self._group,
					tags=cfg_to_group(cfg, return_list=True) + [f'seed:{cfg.seed}'],
					dir=self._log_dir,
					config=OmegaConf.to_container(cfg, resolve=True))
		except Exception as e:
			# Best effort: fall back to local logging instead of aborting training, but report why.
			# `attrs` belongs to colored(); print() rejects unknown keyword arguments.
			print(colored(f'Warning: failed to init wandb ({e!r}). Logs will be saved locally.', 'yellow', attrs=['bold']))
			return None
		print(colored('Logs will be synced with wandb.', 'blue', attrs=['bold']))
		return wandb

	@property
	def video(self):
		"""VideoRecorder for eval episodes, or None when wandb/video saving is disabled."""
		return self._video

	def _save_agent(self, agent, filename, artifact_name):
		"""Save `agent.state_dict()` to `models/<filename>`; if wandb is active, also upload it as a model artifact."""
		fp = self._model_dir / filename
		torch.save(agent.state_dict(), fp)
		if self._wandb:
			artifact = self._wandb.Artifact(artifact_name, type='model')
			artifact.add_file(fp)
			self._wandb.log_artifact(artifact)
		return fp

	def finish(self, agent):
		"""Save the final model (if `save_model`), close the wandb run and print a run summary."""
		if self._save_model:
			self._save_agent(agent, 'model.pt', f'{self._group}-{self._seed}')
		if self._wandb:
			self._wandb.finish()
		# No evaluation may have been logged (e.g. very short runs); print the summary without a reward then.
		print_run(self._cfg, self._eval[-1][-1] if self._eval else None)

	def save_checkpoint(self, agent, iteration):
		"""Save an intermediate checkpoint as `models/model_<iteration>.pt` (and a wandb artifact if enabled)."""
		self._save_agent(agent, f'model_{iteration}.pt', f'{self._group}-{self._seed}-{iteration}')

	def _format(self, key, value, ty):
		"""Format one console field; `ty` is 'int', 'float' or 'time' (seconds as H:MM:SS)."""
		if ty == 'int':
			return f'{colored(key+":", "grey")} {int(value):,}'
		elif ty == 'float':
			return f'{colored(key+":", "grey")} {value:.01f}'
		elif ty == 'time':
			value = str(datetime.timedelta(seconds=int(value)))
			return f'{colored(key+":", "grey")} {value}'
		else:
			raise ValueError(f'invalid log format type: {ty}')

	def _print(self, d, category):
		"""Print one console line with the CONSOLE_FORMAT fields of `d` (missing keys shown as 0)."""
		category = colored(category, 'blue' if category == 'train' else 'green')
		pieces = [f' {category:<14}']
		for k, disp_k, ty in CONSOLE_FORMAT:
			pieces.append(f'{self._format(disp_k, d.get(k, 0), ty):<26}')
		print('   '.join(pieces))

	def log(self, d, category='train'):
		"""Log metric dict `d` under `category` ('train' or 'eval').

		`d['env_step']` is used as the wandb step. Eval entries must also contain
		'episode_reward'; they are appended to `<log_dir>/eval.log` (rewritten each time).
		"""
		assert category in {'train', 'eval'}
		if self._wandb is not None:
			# One call per record: all keys share the same step.
			self._wandb.log({f'{category}/{k}': v for k, v in d.items()}, step=d['env_step'])
		if category == 'eval':
			keys = ['env_step', 'episode_reward']
			self._eval.append(np.array([d[keys[0]], d[keys[1]]]))
			pd.DataFrame(np.array(self._eval)).to_csv(self._log_dir / 'eval.log', header=keys, index=None)
		self._print(d, category)


Overwriting tdmpc/src/logger.py


**Установим зависимости:**

In [None]:
# System packages: OSMesa / GL / GLFW libraries and patchelf (presumably for headless MuJoCo rendering — confirm).
!sudo apt-get install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf

In [5]:
# Directory the MuJoCo archive is extracted into (see the tar cell below).
!mkdir -p /root/.mujoco

In [None]:
# MuJoCo 2.1.0 binaries for Linux x86_64.
!wget https://github.com/google-deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz

In [7]:
# Unpack the MuJoCo archive into /root/.mujoco.
!tar -xf mujoco.tar.gz -C /root/.mujoco

In [None]:
# `-y` is required: `!` commands get no interactive stdin, so apt-get would abort at its confirmation prompt.
!sudo apt-get install -y python3.8 python3.8-dev python3.8-distutils
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1
!sudo apt-get install -y python3.8-venv
# Select python3.8 non-interactively; `--config` waits for keyboard input when several alternatives exist.
!sudo update-alternatives --set python3 /usr/bin/python3.8

In [None]:
# `-y` avoids apt-get's interactive confirmation prompt, which would abort in a notebook `!` shell.
!sudo apt-get install -y python3-pip

In [10]:
# Separate Python 3.8 virtualenv for TD-MPC's pinned dependencies (torch 1.9, gym 0.21; see the pip cell below).
!python3.8 -m venv new-env

In [11]:
# NOTE: every `!` line runs in its own subshell, so this activation does not carry over to later
# cells; later commands therefore prefix `source new-env/bin/activate;` on the same line.
!source new-env/bin/activate

In [12]:
# Sanity check: activation and the command share one subshell, so VIRTUAL_ENV is visible here.
!source new-env/bin/activate; echo  $VIRTUAL_ENV

/kaggle/working/new-env


In [13]:
# Confirm the venv's interpreter version.
!source new-env/bin/activate; python --version

Python 3.8.20


In [None]:
import os
from getpass import getpass

# getpass() hides the key while it is typed and keeps it out of the saved cell output.
# os.environ is inherited by the `!` shell commands that launch training below.
os.environ["WANDB_API_KEY"] = getpass("W&B API key: ")

In [None]:
# Install the training dependencies into the venv; torch/torchvision CUDA 11.1 wheels come from the -f index.
# NOTE: termcolor, omegaconf, pandas, wandb, moviepy and imageio are unpinned, so reruns may resolve different versions.
!source new-env/bin/activate; pip install \
  wheel \
  torch==1.9.0+cu111 \
  torchvision==0.10.0+cu111 \
  -f https://download.pytorch.org/whl/torch_stable.html \
  termcolor \
  omegaconf \
  gym==0.21.0 \
  dm-control==0.0.403778684 \
  pandas \
  wandb \
  moviepy \
  imageio \
  numpy==1.19.5

**Эксперимент без добавления регуляризации (не используется ни latent state consistency loss, ни reconstruction loss):**

In [16]:
%%writefile tdmpc/cfgs/default.yaml
# environment
task: quadruped-walk
modality: 'state'
action_repeat: ???  # OmegaConf mandatory value; presumably filled in per task by parse_cfg — confirm
discount: 0.99
episode_length: 1000/${action_repeat}
train_steps: 500000/${action_repeat}

# planning
iterations: 6
num_samples: 512
num_elites: 64
mixture_coef: 0.05
min_std: 0.05
temperature: 0.5
momentum: 0.1

# learning
batch_size: 512
max_buffer_size: 1000000
horizon: 5
reward_coef: 0.5
value_coef: 0.1
consistency_coef: 0  # baseline run: latent state consistency loss switched off (coefficient 0)
rho: 0.5
kappa: 0.1
lr: 1e-3
std_schedule: linear(0.5, ${min_std}, 25000)
horizon_schedule: linear(1, ${horizon}, 25000)
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
seed_steps: 5000
update_freq: 2
tau: 0.01

# architecture
enc_dim: 256
mlp_dim: 512
latent_dim: 50

# wandb (insert your own)
use_wandb: true
wandb_project: 'TD-MPC'
wandb_entity: 'askoro'

# misc
seed: 1
exp_name: default
eval_freq: 20000  # in env steps: train.py evaluates when env_step % eval_freq == 0
eval_episodes: 10
save_video: true  # only takes effect when wandb logging is active
save_model: true  # save models/model.pt at the end of training
save_checkpoint: true  # periodically save models/model_<step>.pt during training
save_freq: 100000  # in env steps: train.py checkpoints when env_step % save_freq == 0

Overwriting tdmpc/cfgs/default.yaml


In [17]:
# Underscores instead of a space keep the run name consistent with the other experiments
# and keep whitespace out of the logs/<task>/<modality>/<exp_name>/ directory path.
!source new-env/bin/activate; cd tdmpc; python src/train.py exp_name="quadruped_walk_no_regularization"

-----------------------------------------------------
  [1m[32mTask:           [0m Quadruped Walk
  [1m[32mTrain steps:    [0m 500,000
  [1m[32mObservations:   [0m 78
  [1m[32mActions:        [0m 12
  [1m[32mExperiment:     [0m quadruped_walk_no regularization
-----------------------------------------------------
[1m[34mLogs will be synced with wandb.[0m
 [34mtrain[0m   [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 510.8            [30mT:[0m 0:00:00       
 [32meval[0m    [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 58.5             [30mT:[0m 0:00:00       
 [34mtrain[0m   [30mE:[0m 2                [30mS:[0m 1,000            [30mR:[0m 498.2            [30mT:[0m 0:00:57       
 [34mtrain[0m   [30mE:[0m 3                [30mS:[0m 2,000            [30mR:[0m 6.8              [30mT:[0m 0:00:57       
 [34mtrain[0m   [30mE:[0m 4                [30mS:[0m 3,000            [30mR:[0m 211.6    

---

**Эксперимент с добавлением latent state consistency loss:**

In [19]:
%%writefile tdmpc/cfgs/default.yaml
# environment
task: quadruped-walk
modality: 'state'
action_repeat: ???  # OmegaConf mandatory value; presumably filled in per task by parse_cfg — confirm
discount: 0.99
episode_length: 1000/${action_repeat}
train_steps: 500000/${action_repeat}

# planning
iterations: 6
num_samples: 512
num_elites: 64
mixture_coef: 0.05
min_std: 0.05
temperature: 0.5
momentum: 0.1

# learning
batch_size: 512
max_buffer_size: 1000000
horizon: 5
reward_coef: 0.5
value_coef: 0.1
consistency_coef: 2  # latent state consistency loss enabled (coefficient 2); the only change vs. the baseline config
rho: 0.5
kappa: 0.1
lr: 1e-3
std_schedule: linear(0.5, ${min_std}, 25000)
horizon_schedule: linear(1, ${horizon}, 25000)
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
seed_steps: 5000
update_freq: 2
tau: 0.01

# architecture
enc_dim: 256
mlp_dim: 512
latent_dim: 50

# wandb (insert your own)
use_wandb: true
wandb_project: 'TD-MPC'
wandb_entity: 'askoro'

# misc
seed: 1
exp_name: default
eval_freq: 20000  # in env steps: train.py evaluates when env_step % eval_freq == 0
eval_episodes: 10
save_video: true  # only takes effect when wandb logging is active
save_model: true  # save models/model.pt at the end of training
save_checkpoint: true  # periodically save models/model_<step>.pt during training
save_freq: 100000  # in env steps: train.py checkpoints when env_step % save_freq == 0

Overwriting tdmpc/cfgs/default.yaml


In [20]:
# Same launch command as the baseline; only consistency_coef differs in default.yaml (0 -> 2).
!source new-env/bin/activate; cd tdmpc; python src/train.py exp_name="quadruped_walk_consistency_loss"

----------------------------------------------------
  [1m[32mTask:           [0m Quadruped Walk
  [1m[32mTrain steps:    [0m 500,000
  [1m[32mObservations:   [0m 78
  [1m[32mActions:        [0m 12
  [1m[32mExperiment:     [0m quadruped_walk_consistency_loss
----------------------------------------------------
[1m[34mLogs will be synced with wandb.[0m
 [34mtrain[0m   [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 510.8            [30mT:[0m 0:00:00       
 [32meval[0m    [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 58.5             [30mT:[0m 0:00:00       
 [34mtrain[0m   [30mE:[0m 2                [30mS:[0m 1,000            [30mR:[0m 498.2            [30mT:[0m 0:00:58       
 [34mtrain[0m   [30mE:[0m 3                [30mS:[0m 2,000            [30mR:[0m 6.8              [30mT:[0m 0:00:59       
 [34mtrain[0m   [30mE:[0m 4                [30mS:[0m 3,000            [30mR:[0m 211.6       

---

**Эксперимент с добавлением reconstruction loss:**

In [43]:
%%writefile tdmpc/src/algorithm/helper.py
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal


def __REDUCE__(b):
	"""Map a boolean `reduce` flag to a torch loss `reduction` mode ('mean' or 'none')."""
	return 'mean' if b else 'none'


def l1(pred, target, reduce=False):
	"""Elementwise L1 loss between `pred` and `target`, or its mean when `reduce` is True."""
	return F.l1_loss(pred, target, reduction=__REDUCE__(reduce))


def mse(pred, target, reduce=False):
	"""Elementwise squared error between `pred` and `target`, or its mean when `reduce` is True."""
	return F.mse_loss(pred, target, reduction=__REDUCE__(reduce))


def _get_out_shape(in_shape, layers):
	"""Return the per-sample output shape of `layers` for an input of shape `in_shape`.

	`layers` may be a list of modules (wrapped in nn.Sequential) or a single module. A random
	batch of size 1 is pushed through it, so this draws from torch's global RNG.
	"""
	dummy = torch.randn(*in_shape).unsqueeze(0)
	net = nn.Sequential(*layers) if isinstance(layers, list) else layers
	return net(dummy).squeeze(0).shape


def orthogonal_init(m):
	"""Orthogonally initialize Linear/Conv2d weights and zero their biases; other modules are untouched.

	Conv2d weights use the ReLU gain; Linear weights use gain 1. Meant for `module.apply`.
	"""
	if isinstance(m, nn.Linear):
		nn.init.orthogonal_(m.weight.data)
	elif isinstance(m, nn.Conv2d):
		nn.init.orthogonal_(m.weight.data, nn.init.calculate_gain('relu'))
	else:
		return
	if m.bias is not None:
		nn.init.zeros_(m.bias)


@torch.no_grad()
def ema(m, m_target, tau):
	"""Move each parameter of `m_target` a fraction `tau` of the way toward the matching parameter of `m`.

	Parameters are paired positionally, so both modules must share the same architecture.
	"""
	for src, dst in zip(m.parameters(), m_target.parameters()):
		dst.data.lerp_(src.data, tau)


def set_requires_grad(net, value):
	"""Switch gradient tracking on or off for every parameter of `net`."""
	for p in net.parameters():
		p.requires_grad_(value)


class TruncatedNormal(pyd.Normal):
	"""Normal distribution whose samples are clipped into (low, high) with a straight-through gradient.

	`sample` can additionally clip the noise to [-clip, clip] before it is added to the mean.
	"""
	def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
		super().__init__(loc, scale, validate_args=False)
		self.low, self.high, self.eps = low, high, eps

	def _clamp(self, x):
		bounded = x.clamp(self.low + self.eps, self.high - self.eps)
		# Forward pass returns the clamped value; backward treats the clamp as identity.
		return x - x.detach() + bounded.detach()

	def sample(self, clip=None, sample_shape=torch.Size()):
		noise = _standard_normal(self._extended_shape(sample_shape),
								 dtype=self.loc.dtype,
								 device=self.loc.device)
		noise *= self.scale
		if clip is not None:
			noise = noise.clamp(-clip, clip)
		return self._clamp(self.loc + noise)


class NormalizeImg(nn.Module):
	"""Rescale raw pixel values (0..255) into [0, 1] by dividing by 255."""
	def forward(self, x):
		return x / 255.


class Flatten(nn.Module):
	"""Reshape a batch to (batch, -1) using a view (input must be viewable, e.g. contiguous)."""
	def forward(self, x):
		batch = x.size(0)
		return x.view(batch, -1)


def enc(cfg):
	"""Return a TOLD encoder.

	'pixels': normalization, four stride-2 conv+ReLU layers, then a linear projection to
	`latent_dim`. Any other modality: a 2-layer MLP (ELU) from `obs_shape[0]` to `latent_dim`.
	"""
	if cfg.modality != 'pixels':
		return nn.Sequential(
			nn.Linear(cfg.obs_shape[0], cfg.enc_dim), nn.ELU(),
			nn.Linear(cfg.enc_dim, cfg.latent_dim))
	in_channels = int(3*cfg.frame_stack)  # frame_stack RGB frames stacked along channels
	width = cfg.num_channels
	layers = [NormalizeImg()]
	for c_in, kernel in ((in_channels, 7), (width, 5), (width, 3), (width, 3)):
		layers += [nn.Conv2d(c_in, width, kernel, stride=2), nn.ReLU()]
	out_shape = _get_out_shape((in_channels, cfg.img_size, cfg.img_size), layers)
	layers.extend([Flatten(), nn.Linear(np.prod(out_shape), cfg.latent_dim)])
	return nn.Sequential(*layers)


def dec(cfg):
	"""Return a TOLD state decoder: latent_dim -> enc_dim (ELU) -> obs_shape[0].

	Mirrors the state-modality encoder; it is only meaningful for 'state' observations and
	does not check the modality itself.
	"""
	return nn.Sequential(
		nn.Linear(cfg.latent_dim, cfg.enc_dim),
		nn.ELU(),
		nn.Linear(cfg.enc_dim, cfg.obs_shape[0]),
	)


def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.ELU()):
	"""Return a 3-layer MLP in_dim -> h1 -> h2 -> out_dim.

	`mlp_dim` is either one hidden width (used for both layers) or a sequence whose first two
	entries are the widths. The same `act_fn` module instance follows both hidden layers.
	"""
	hidden = [mlp_dim, mlp_dim] if isinstance(mlp_dim, int) else mlp_dim
	h1, h2 = hidden[0], hidden[1]
	return nn.Sequential(
		nn.Linear(in_dim, h1), act_fn,
		nn.Linear(h1, h2), act_fn,
		nn.Linear(h2, out_dim))

def q(cfg, act_fn=nn.ELU()):
	"""Return a Q-function head that uses Layer Normalization.

	Architecture: Linear(latent_dim+action_dim -> mlp_dim) -> LayerNorm -> Tanh ->
	Linear(mlp_dim -> mlp_dim) -> act_fn -> Linear(mlp_dim -> 1). The input is a latent
	state concatenated with an action.

	`act_fn` is the activation after the second linear layer. Previously this argument was
	accepted but ignored (a fresh nn.ELU was always used). The default is still an ELU, so
	existing callers build the same network.
	"""
	return nn.Sequential(nn.Linear(cfg.latent_dim+cfg.action_dim, cfg.mlp_dim), nn.LayerNorm(cfg.mlp_dim), nn.Tanh(),
						 nn.Linear(cfg.mlp_dim, cfg.mlp_dim), act_fn,
						 nn.Linear(cfg.mlp_dim, 1))


class RandomShiftsAug(nn.Module):
	"""
	Random shift image augmentation.
	Adapted from https://github.com/facebookresearch/drqv2

	Pads each image by `pad` pixels (edge replication) and resamples it back to the original
	size at a random integer offset, chosen independently per batch element, via grid_sample.
	Acts as the identity when the modality is not 'pixels'.
	"""
	def __init__(self, cfg):
		super().__init__()
		# Pad by int(img_size/21) pixels (e.g. 4 for img_size 84); None disables the augmentation.
		self.pad = int(cfg.img_size/21) if cfg.modality == 'pixels' else None

	def forward(self, x):
		# No-op for non-pixel inputs.
		if not self.pad:
			return x
		# Expects a 4-D (n, c, h, w) batch of square images.
		n, c, h, w = x.size()
		assert h == w
		padding = tuple([self.pad] * 4)
		x = F.pad(x, padding, 'replicate')
		# Base sampling grid in normalized [-1, 1] coordinates covering an h x h window
		# of the padded image.
		eps = 1.0 / (h + 2 * self.pad)
		arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype)[:h]
		arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
		base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
		base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
		# Random integer pixel shift in [0, 2*pad] per sample and axis, converted to
		# normalized units.
		shift = torch.randint(0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype)
		shift *= 2.0 / (h + 2 * self.pad)
		grid = base_grid + shift
		return F.grid_sample(x, grid, padding_mode='zeros', align_corners=False)


class Episode(object):
	"""Preallocated storage for one episode's observations, actions and rewards.

	Buffers live on `cfg.device`. `obs` has episode_length+1 slots: the initial observation
	followed by one observation per step. State observations are stored as float32, anything
	else as uint8.
	"""
	def __init__(self, cfg, init_obs):
		self.cfg = cfg
		self.device = torch.device(cfg.device)
		obs_dtype = torch.uint8 if cfg.modality != 'state' else torch.float32
		steps = cfg.episode_length
		self.obs = torch.empty((steps + 1, *init_obs.shape), dtype=obs_dtype, device=self.device)
		self.obs[0] = torch.tensor(init_obs, dtype=obs_dtype, device=self.device)
		self.action = torch.empty((steps, cfg.action_dim), dtype=torch.float32, device=self.device)
		self.reward = torch.empty((steps,), dtype=torch.float32, device=self.device)
		self.cumulative_reward = 0
		self.done = False
		self._idx = 0

	def __len__(self):
		"""Number of transitions added so far."""
		return self._idx

	@property
	def first(self):
		"""True until the first transition has been added."""
		return not len(self)

	def __add__(self, transition):
		"""Support `episode += (obs, action, reward, done)`; appends in place and returns self."""
		self.add(*transition)
		return self

	def add(self, obs, action, reward, done):
		"""Store one transition: the resulting observation, the action, its reward and the done flag."""
		i = self._idx
		self.obs[i + 1] = torch.tensor(obs, dtype=self.obs.dtype, device=self.obs.device)
		self.action[i] = action
		self.reward[i] = reward
		self.cumulative_reward += reward
		self.done = done
		self._idx = i + 1


class ReplayBuffer():
	"""
	Storage and sampling functionality for training TD-MPC / TOLD.
	The replay buffer is stored in GPU memory when training from state.
	Uses prioritized experience replay by default.

	Whole episodes of `cfg.episode_length` transitions are written contiguously into ring
	buffers of `capacity` entries; each episode's final observation is kept separately in
	`_last_obs`. For non-state (pixel) observations only the newest 3 channels (one RGB frame)
	are stored per step, and frame stacks are rebuilt in `_get_obs` at sample time.
	"""
	def __init__(self, cfg):
		self.cfg = cfg
		self.device = torch.device(cfg.device)
		self.capacity = min(cfg.train_steps, cfg.max_buffer_size)
		dtype = torch.float32 if cfg.modality == 'state' else torch.uint8
		obs_shape = cfg.obs_shape if cfg.modality == 'state' else (3, *cfg.obs_shape[-2:])
		self._obs = torch.empty((self.capacity+1, *obs_shape), dtype=dtype, device=self.device)
		# One slot per stored episode for its terminal observation.
		self._last_obs = torch.empty((self.capacity//cfg.episode_length, *cfg.obs_shape), dtype=dtype, device=self.device)
		self._action = torch.empty((self.capacity, cfg.action_dim), dtype=torch.float32, device=self.device)
		self._reward = torch.empty((self.capacity,), dtype=torch.float32, device=self.device)
		self._priorities = torch.ones((self.capacity,), dtype=torch.float32, device=self.device)
		self._eps = 1e-6  # added to updated priorities so none becomes exactly zero
		self._full = False  # becomes True once the write index has wrapped around
		self.idx = 0  # next write position (start of the next episode chunk)

	def __add__(self, episode: Episode):
		self.add(episode)
		return self

	def add(self, episode: Episode):
		"""Copy one complete episode into the buffer at the current write position."""
		self._obs[self.idx:self.idx+self.cfg.episode_length] = episode.obs[:-1] if self.cfg.modality == 'state' else episode.obs[:-1, -3:]
		self._last_obs[self.idx//self.cfg.episode_length] = episode.obs[-1]
		self._action[self.idx:self.idx+self.cfg.episode_length] = episode.action
		self._reward[self.idx:self.idx+self.cfg.episode_length] = episode.reward
		# New transitions get the current maximum priority (1.0 for the very first episode).
		if self._full:
			max_priority = self._priorities.max().to(self.device).item()
		else:
			max_priority = 1. if self.idx == 0 else self._priorities[:self.idx].max().to(self.device).item()
		# The last `horizon` steps of the episode get priority 0, so sampled start indices
		# leave room for a horizon-step rollout inside the same episode.
		mask = torch.arange(self.cfg.episode_length) >= self.cfg.episode_length-self.cfg.horizon
		new_priorities = torch.full((self.cfg.episode_length,), max_priority, device=self.device)
		new_priorities[mask] = 0
		self._priorities[self.idx:self.idx+self.cfg.episode_length] = new_priorities
		# Advance and wrap; assumes capacity is a multiple of episode_length — TODO confirm.
		self.idx = (self.idx + self.cfg.episode_length) % self.capacity
		self._full = self._full or self.idx == 0

	def update_priorities(self, idxs, priorities):
		"""Overwrite priorities at `idxs`; `priorities` presumably has shape (batch, 1) — confirm with caller."""
		self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps

	def _get_obs(self, arr, idxs):
		"""Gather observations at `idxs`; for pixels, rebuild frame_stack frames ending at each index."""
		if self.cfg.modality == 'state':
			return arr[idxs]
		obs = torch.empty((self.cfg.batch_size, 3*self.cfg.frame_stack, *arr.shape[-2:]), dtype=arr.dtype, device=torch.device('cuda'))
		obs[:, -3:] = arr[idxs].cuda()
		_idxs = idxs.clone()
		mask = torch.ones_like(_idxs, dtype=torch.bool)
		for i in range(1, self.cfg.frame_stack):
			# Stop stepping back at an episode's first step, so earlier stack slots repeat
			# that frame instead of reading from the previous episode.
			mask[_idxs % self.cfg.episode_length == 0] = False
			_idxs[mask] -= 1
			obs[:, -(i+1)*3:-i*3] = arr[_idxs].cuda()
		return obs.float()

	def sample(self):
		"""Sample a prioritized batch with horizon+1 consecutive steps per start index.

		Returns (obs, next_obs, action, reward, idxs, weights): next_obs/action/reward have a
		leading horizon+1 time dimension, reward gets a trailing singleton dimension, and
		weights are the importance-sampling weights normalized by their maximum.
		"""
		# P(i) proportional to priority_i ** per_alpha over the filled part of the buffer.
		probs = (self._priorities if self._full else self._priorities[:self.idx]) ** self.cfg.per_alpha
		probs /= probs.sum()
		total = len(probs)
		idxs = torch.from_numpy(np.random.choice(total, self.cfg.batch_size, p=probs.cpu().numpy(), replace=not self._full)).to(self.device)
		# Importance-sampling weights (N * P(i)) ** -per_beta, normalized by their max.
		weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
		weights /= weights.max()

		obs = self._get_obs(self._obs, idxs)
		next_obs_shape = self._last_obs.shape[1:] if self.cfg.modality == 'state' else (3*self.cfg.frame_stack, *self._last_obs.shape[-2:])
		next_obs = torch.empty((self.cfg.horizon+1, self.cfg.batch_size, *next_obs_shape), dtype=obs.dtype, device=obs.device)
		action = torch.empty((self.cfg.horizon+1, self.cfg.batch_size, *self._action.shape[1:]), dtype=torch.float32, device=self.device)
		reward = torch.empty((self.cfg.horizon+1, self.cfg.batch_size), dtype=torch.float32, device=self.device)
		for t in range(self.cfg.horizon+1):
			_idxs = idxs + t
			next_obs[t] = self._get_obs(self._obs, _idxs+1)
			action[t] = self._action[_idxs]
			reward[t] = self._reward[_idxs]

		# If the last step of a rollout is an episode's final step, its next observation is
		# that episode's stored terminal observation rather than the next episode's first.
		mask = (_idxs+1) % self.cfg.episode_length == 0
		next_obs[-1, mask] = self._last_obs[_idxs[mask]//self.cfg.episode_length].cuda().float()
		if not action.is_cuda:
			action, reward, idxs, weights = \
				action.cuda(), reward.cuda(), idxs.cuda(), weights.cuda()

		return obs, next_obs, action, reward.unsqueeze(2), idxs, weights


def linear_schedule(schdl, step):
	"""
	Evaluate a schedule at `step`.
	Adapted from https://github.com/facebookresearch/drqv2

	`schdl` is either a constant (anything `float()` accepts) or a string of the form
	'linear(init, final, duration)', which interpolates from `init` to `final` over
	`duration` steps and then stays at `final`. Other strings raise NotImplementedError.
	"""
	try:
		return float(schdl)
	except ValueError:
		pass
	match = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
	if not match:
		raise NotImplementedError(schdl)
	init, final, duration = map(float, match.groups())
	mix = np.clip(step / duration, 0.0, 1.0)
	return (1.0 - mix) * init + mix * final

Overwriting tdmpc/src/algorithm/helper.py


In [44]:
%%writefile tdmpc/src/algorithm/tdmpc.py
import numpy as np
import torch
import torch.nn as nn
from copy import deepcopy
import algorithm.helper as h


class TOLD(nn.Module):
	"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.

	Components: encoder h, state decoder d (added in this notebook for the
	reconstruction-loss experiment), latent dynamics with a reward head, policy pi,
	and two Q-functions.
	"""
	def __init__(self, cfg):
		super().__init__()
		self.cfg = cfg
		# Submodule creation order matters: self.apply() below visits modules in registration
		# order, which decides the random draws each layer's orthogonal init receives.
		self._encoder = h.enc(cfg)
		self._decoder = h.dec(cfg)
		self._dynamics = h.mlp(cfg.latent_dim+cfg.action_dim, cfg.mlp_dim, cfg.latent_dim)
		self._reward = h.mlp(cfg.latent_dim+cfg.action_dim, cfg.mlp_dim, 1)
		self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, cfg.action_dim)
		self._Q1, self._Q2 = h.q(cfg), h.q(cfg)
		self.apply(h.orthogonal_init)
		# Zero the output layers of the reward and Q heads so they predict 0 at initialization.
		for m in [self._reward, self._Q1, self._Q2]:
			m[-1].weight.data.fill_(0)
			m[-1].bias.data.fill_(0)

	def track_q_grad(self, enable=True):
		"""Utility function. Enables/disables gradient tracking of Q-networks."""
		for m in [self._Q1, self._Q2]:
			h.set_requires_grad(m, enable)

	def h(self, obs):
		"""Encodes an observation into its latent representation (h)."""
		return self._encoder(obs)
	
	def d(self, z):
		"""Decodes a latent state z back into observation space (state modality only)."""
		return self._decoder(z)

	def next(self, z, a):
		"""Predicts next latent state (d) and single-step reward (R)."""
		x = torch.cat([z, a], dim=-1)
		return self._dynamics(x), self._reward(x)

	def pi(self, z, std=0):
		"""Samples an action from the learned policy (pi).

		With std > 0, samples from a truncated normal around tanh(policy output) with noise
		clipped to +-0.3; otherwise returns the deterministic tanh mean.
		"""
		mu = torch.tanh(self._pi(z))
		if std > 0:
			std = torch.ones_like(mu) * std
			return h.TruncatedNormal(mu, std).sample(clip=0.3)
		return mu

	def Q(self, z, a):
		"""Predict state-action value (Q) with both Q-networks; returns the pair (Q1, Q2)."""
		x = torch.cat([z, a], dim=-1)
		return self._Q1(x), self._Q2(x)


class TDMPC():
	"""
	Implementation of TD-MPC learning + inference.

	In this variant the latent consistency loss of the original TD-MPC is replaced
	by an observation reconstruction loss: the TOLD decoder predicts the next
	observation from the predicted next latent state.
	"""
	def __init__(self, cfg):
		self.cfg = cfg
		self.device = torch.device('cuda')
		# Lower bound for the CEM sampling std in plan(); re-annealed in update().
		self.std = h.linear_schedule(cfg.std_schedule, 0)
		self.model = TOLD(cfg).cuda()
		self.model_target = deepcopy(self.model)
		self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
		# Separate optimizer restricted to the policy head; used by update_pi().
		self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
		self.aug = h.RandomShiftsAug(cfg)
		self.model.eval()
		self.model_target.eval()

	def state_dict(self):
		"""Retrieve state dict of TOLD model, including slow-moving target network."""
		return {'model': self.model.state_dict(),
				'model_target': self.model_target.state_dict()}

	def save(self, fp):
		"""Save state dict of TOLD model to filepath."""
		torch.save(self.state_dict(), fp)

	def load(self, fp):
		"""
		Load a saved state dict from filepath into current agent.

		Tensors are mapped onto self.device, so checkpoints written on another
		device (e.g. CPU or a different GPU index) can still be restored.
		NOTE: torch.load unpickles the file -- only load trusted checkpoints.
		"""
		d = torch.load(fp, map_location=self.device)
		self.model.load_state_dict(d['model'])
		self.model_target.load_state_dict(d['model_target'])

	@torch.no_grad()
	def estimate_value(self, z, actions, horizon):
		"""
		Estimate value of a trajectory starting at latent state z and executing given actions.
		Sums discounted predicted rewards over `horizon` steps and bootstraps with the
		minimum of the two Q-heads at the final latent state.
		"""
		G, discount = 0, 1
		for t in range(horizon):
			z, reward = self.model.next(z, actions[t])
			G += discount * reward
			discount *= self.cfg.discount
		G += discount * torch.min(*self.model.Q(z, self.model.pi(z, self.cfg.min_std)))
		return G

	@torch.no_grad()
	def plan(self, obs, eval_mode=False, step=None, t0=True):
		"""
		Plan next action using TD-MPC inference.
		obs: raw input observation.
		eval_mode: uniform sampling and action noise is disabled during evaluation.
		step: current time step. determines e.g. planning horizon. It must be provided:
			it is compared against cfg.seed_steps and fed to the horizon schedule.
		t0: whether current step is the first step of an episode.
		"""
		# Seed steps
		if step < self.cfg.seed_steps and not eval_mode:
			return torch.empty(self.cfg.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)

		# Sample policy trajectories
		obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
		horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
		num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
		if num_pi_trajs > 0:
			pi_actions = torch.empty(horizon, num_pi_trajs, self.cfg.action_dim, device=self.device)
			z = self.model.h(obs).repeat(num_pi_trajs, 1)
			for t in range(horizon):
				pi_actions[t] = self.model.pi(z, self.cfg.min_std)
				z, _ = self.model.next(z, pi_actions[t])

		# Initialize state and parameters
		z = self.model.h(obs).repeat(self.cfg.num_samples+num_pi_trajs, 1)
		mean = torch.zeros(horizon, self.cfg.action_dim, device=self.device)
		std = 2*torch.ones(horizon, self.cfg.action_dim, device=self.device)
		# Warm-start from the previous step's plan, shifted by one time step.
		if not t0 and hasattr(self, '_prev_mean'):
			mean[:-1] = self._prev_mean[1:]

		# Iterate CEM
		for i in range(self.cfg.iterations):
			actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * \
				torch.randn(horizon, self.cfg.num_samples, self.cfg.action_dim, device=std.device), -1, 1)
			if num_pi_trajs > 0:
				actions = torch.cat([actions, pi_actions], dim=1)

			# Compute elite actions
			value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
			elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
			elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]

			# Update parameters: softmax-weighted (by temperature) mean/std of the elites
			max_value = elite_value.max(0)[0]
			score = torch.exp(self.cfg.temperature*(elite_value - max_value))
			score /= score.sum(0)
			_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
			_std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9))
			_std = _std.clamp_(self.std, 2)
			mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std

		# Outputs: sample one elite trajectory proportionally to its score
		score = score.squeeze(1).cpu().numpy()
		actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
		self._prev_mean = mean
		mean, std = actions[0], _std[0]
		a = mean
		if not eval_mode:
			a += std * torch.randn(self.cfg.action_dim, device=std.device)
		return a

	def update_pi(self, zs):
		"""
		Update policy using a sequence of latent states.
		Q-network gradients are disabled so only the policy head is optimized.
		Returns the scalar policy loss.
		"""
		self.pi_optim.zero_grad(set_to_none=True)
		self.model.track_q_grad(False)

		# Loss is a weighted sum of Q-values
		pi_loss = 0
		for t,z in enumerate(zs):
			a = self.model.pi(z, self.cfg.min_std)
			Q = torch.min(*self.model.Q(z, a))
			pi_loss += -Q.mean() * (self.cfg.rho ** t)

		pi_loss.backward()
		torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False)
		self.pi_optim.step()
		self.model.track_q_grad(True)
		return pi_loss.item()

	@torch.no_grad()
	def _td_target(self, next_obs, reward):
		"""Compute the TD-target from a reward and the observation at the following time step."""
		next_z = self.model.h(next_obs)
		td_target = reward + self.cfg.discount * \
			torch.min(*self.model_target.Q(next_z, self.model.pi(next_z, self.cfg.min_std)))
		return td_target

	def update(self, replay_buffer, step):
		"""
		Main update function. Corresponds to one iteration of the TOLD model learning.

		Samples a batch from replay_buffer, unrolls the latent model for cfg.horizon
		steps and minimizes a priority-weighted sum of reconstruction, reward and value
		losses. Afterwards it updates replay priorities, the policy (update_pi) and,
		every cfg.update_freq steps, the target network via EMA.
		Returns a dict of scalar training metrics.
		"""
		obs, next_obses, action, reward, idxs, weights = replay_buffer.sample()
		self.optim.zero_grad(set_to_none=True)
		self.std = h.linear_schedule(self.cfg.std_schedule, step)
		self.model.train()

		# Representation
		z = self.model.h(self.aug(obs))
		zs = [z.detach()]

		reconstruction_loss, reward_loss, value_loss, priority_loss = 0, 0, 0, 0
		for t in range(self.cfg.horizon):

			# Predictions
			Q1, Q2 = self.model.Q(z, action[t])
			z, reward_pred = self.model.next(z, action[t])
			dec_obs = self.model.d(z)
			with torch.no_grad():
				next_obs = self.aug(next_obses[t])
				td_target = self._td_target(next_obs, reward[t])
			zs.append(z.detach())

			# Losses
			rho = (self.cfg.rho ** t)
			# Per-sample reconstruction error averaged over every non-batch dimension, so
			# the term is (batch, 1) for flat and multi-dimensional observations alike.
			# Assumes h.mse returns an unreduced, element-wise loss.
			reconstruction_loss += rho * h.mse(dec_obs, next_obs).flatten(1).mean(dim=1, keepdim=True)
			reward_loss += rho * h.mse(reward_pred, reward[t])
			value_loss += rho * (h.mse(Q1, td_target) + h.mse(Q2, td_target))
			priority_loss += rho * (h.l1(Q1, td_target) + h.l1(Q2, td_target))

		# Optimize model
		total_loss = self.cfg.reconstruction_coef * reconstruction_loss.clamp(max=1e4) + \
					 self.cfg.reward_coef * reward_loss.clamp(max=1e4) + \
					 self.cfg.value_coef * value_loss.clamp(max=1e4)
		weighted_loss = (total_loss.squeeze(1) * weights).mean()
		# Scale gradients by 1/horizon to keep their magnitude independent of the unroll length.
		weighted_loss.register_hook(lambda grad: grad * (1/self.cfg.horizon))
		weighted_loss.backward()
		grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False)
		self.optim.step()
		replay_buffer.update_priorities(idxs, priority_loss.clamp(max=1e4).detach())

		# Update policy + target network
		pi_loss = self.update_pi(zs)
		if step % self.cfg.update_freq == 0:
			h.ema(self.model, self.model_target, self.cfg.tau)

		self.model.eval()
		# 'consistency_loss' is kept for compatibility with existing logging; in this
		# variant it carries the reconstruction loss, also reported under its real name.
		return {'consistency_loss': float(reconstruction_loss.mean().item()),
				'reconstruction_loss': float(reconstruction_loss.mean().item()),
				'reward_loss': float(reward_loss.mean().item()),
				'value_loss': float(value_loss.mean().item()),
				'pi_loss': pi_loss,
				'total_loss': float(total_loss.mean().item()),
				'weighted_loss': float(weighted_loss.mean().item()),
				'grad_norm': float(grad_norm)}

Overwriting tdmpc/src/algorithm/tdmpc.py


In [45]:
%%writefile tdmpc/cfgs/default.yaml
# environment
task: quadruped-walk
modality: 'state'
action_repeat: ???
discount: 0.99
episode_length: 1000/${action_repeat}
train_steps: 500000/${action_repeat}

# planning
iterations: 6
num_samples: 512
num_elites: 64
mixture_coef: 0.05
min_std: 0.05
temperature: 0.5
momentum: 0.1

# learning
batch_size: 512
max_buffer_size: 1000000
horizon: 5
reward_coef: 0.5
value_coef: 0.1
reconstruction_coef: 2
rho: 0.5
kappa: 0.1
lr: 1e-3
std_schedule: linear(0.5, ${min_std}, 25000)
horizon_schedule: linear(1, ${horizon}, 25000)
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
seed_steps: 5000
update_freq: 2
tau: 0.01

# architecture
enc_dim: 256
mlp_dim: 512
latent_dim: 50

# wandb: replace wandb_project / wandb_entity with your own W&B project and username/team
use_wandb: true
wandb_project: 'TD-MPC'
wandb_entity: 'askoro'

# misc
seed: 1
exp_name: default
eval_freq: 20000
eval_episodes: 10
save_video: true
save_model: true
save_checkpoint: true
save_freq: 100000

Overwriting tdmpc/cfgs/default.yaml


In [46]:
!source new-env/bin/activate; cd tdmpc; python src/train.py exp_name="quadruped_walk_reconstruction_loss"

--------------------------------------------------------
  [1m[32mTask:           [0m Quadruped Walk
  [1m[32mTrain steps:    [0m 500,000
  [1m[32mObservations:   [0m 78
  [1m[32mActions:        [0m 12
  [1m[32mExperiment:     [0m quadruped_walk_reconstruction_lo...
--------------------------------------------------------
[1m[34mLogs will be synced with wandb.[0m
 [34mtrain[0m   [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 510.8            [30mT:[0m 0:00:00       
 [32meval[0m    [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 58.5             [30mT:[0m 0:00:00       
 [34mtrain[0m   [30mE:[0m 2                [30mS:[0m 1,000            [30mR:[0m 498.2            [30mT:[0m 0:00:55       
 [34mtrain[0m   [30mE:[0m 3                [30mS:[0m 2,000            [30mR:[0m 6.8              [30mT:[0m 0:00:56       
 [34mtrain[0m   [30mE:[0m 4                [30mS:[0m 3,000            [30mR:[0m 