In [1]:
!git clone https://github.com/nicklashansen/tdmpc

Cloning into 'tdmpc'...
remote: Enumerating objects: 108, done.[K
remote: Counting objects: 100% (53/53), done.[K
remote: Compressing objects: 100% (23/23), done.[K
remote: Total 108 (delta 31), reused 31 (delta 30), pack-reused 55 (from 1)[K
Receiving objects: 100% (108/108), 1.49 MiB | 10.81 MiB/s, done.
Resolving deltas: 100% (38/38), done.


In [2]:
%%writefile tdmpc/src/train.py
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'
import torch
import numpy as np
import gym
gym.logger.set_level(40)
import time
import random
from pathlib import Path
from cfg import parse_cfg
from env import make_env
from algorithm.tdmpc import TDMPC
from algorithm.helper import Episode, ReplayBuffer
import logger
torch.backends.cudnn.benchmark = True
__CONFIG__, __LOGS__ = 'cfgs', 'logs'


def set_seed(seed):
    """Seed every random number generator used by the training script.

    Applies `seed`, in order, to Python's `random` module, NumPy's global
    generator, `torch.manual_seed` and `torch.cuda.manual_seed_all`.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def evaluate(env, agent, num_episodes, step, env_step, video):
    """Roll the agent out in evaluation mode and return its mean episode reward.

    Runs `num_episodes` episodes, choosing every action with
    `agent.plan(..., eval_mode=True)`. When `video` is truthy, recording is
    enabled for the first episode only, and `video.save(env_step)` is called
    after every episode.

    Returns:
        `np.nanmean` of the per-episode reward sums.
    """
    returns = []
    for episode_idx in range(num_episodes):
        obs = env.reset()
        done = False
        total_reward = 0
        t = 0
        if video:
            video.init(env, enabled=(episode_idx == 0))
        while not done:
            is_first_step = t == 0
            action = agent.plan(obs, eval_mode=True, step=step, t0=is_first_step)
            obs, reward, done, _ = env.step(action.cpu().numpy())
            total_reward += reward
            if video:
                video.record(env)
            t += 1
        returns.append(total_reward)
        if video:
            video.save(env_step)
    return np.nanmean(returns)


def train(cfg):
    """Run the TD-MPC training loop for one configuration (CUDA required).

    Every iteration collects one episode with the planner, adds it to the
    replay buffer, performs model updates once `cfg.seed_steps` has been
    reached, logs the episode, and periodically checkpoints and evaluates.
    """
    assert torch.cuda.is_available()
    set_seed(cfg.seed)
    work_dir = Path().cwd() / __LOGS__ / cfg.task / cfg.modality / cfg.exp_name / str(cfg.seed)
    env = make_env(cfg)
    agent = TDMPC(cfg)
    buffer = ReplayBuffer(cfg)

    # Run training
    run_log = logger.Logger(work_dir, cfg)
    started_at = time.time()
    # `step` counts agent steps and advances by one full episode per iteration.
    step_schedule = range(0, cfg.train_steps + cfg.episode_length, cfg.episode_length)
    for episode_idx, step in enumerate(step_schedule, start=1):

        # Collect one trajectory with the current planner
        obs = env.reset()
        episode = Episode(cfg, obs)
        while not episode.done:
            action = agent.plan(obs, step=step, t0=episode.first)
            obs, reward, done, _ = env.step(action.cpu().numpy())
            episode += (obs, action, reward, done)
        assert len(episode) == cfg.episode_length
        buffer += episode

        # Update the model; the first round after seeding runs `seed_steps` updates,
        # later rounds run one update per collected step.
        train_metrics = {}
        if step >= cfg.seed_steps:
            if step == cfg.seed_steps:
                num_updates = cfg.seed_steps
            else:
                num_updates = cfg.episode_length
            for update_idx in range(num_updates):
                update_info = agent.update(buffer, step + update_idx)
                train_metrics.update(update_info)

        # Log the training episode
        env_step = int(step * cfg.action_repeat)
        common_metrics = {
            'episode': episode_idx,
            'step': step,
            'env_step': env_step,
            'total_time': time.time() - started_at,
            'episode_reward': episode.cumulative_reward,
        }
        train_metrics.update(common_metrics)
        run_log.log(train_metrics, category='train')

        # Periodic checkpoint (keyed on environment steps)
        if cfg.save_checkpoint and env_step % cfg.save_freq == 0:
            run_log.save_checkpoint(agent, step)

        # Periodic evaluation; reuses the common metrics with the eval reward swapped in
        if env_step % cfg.eval_freq == 0:
            common_metrics['episode_reward'] = evaluate(
                env, agent, cfg.eval_episodes, step, env_step, run_log.video)
            run_log.log(common_metrics, category='eval')

    run_log.finish(agent)
    print('Training completed successfully')


# Script entry point. The config directory is `<current working directory>/cfgs`,
# so the script has to be launched from the repository root (the notebook runs
# `cd tdmpc` before `python src/train.py`).
if __name__ == '__main__':
    train(parse_cfg(Path().cwd() / __CONFIG__))


Overwriting tdmpc/src/train.py


In [3]:
%%writefile tdmpc/src/logger.py
import sys
import os
import datetime
import re
import numpy as np
import torch
import pandas as pd
from termcolor import colored
from omegaconf import OmegaConf


# Columns printed on every console log line, in order. Each entry is
# (key looked up in the logged dict, short label shown on screen, format type
# handled by Logger._format: 'int', 'float' or 'time').
CONSOLE_FORMAT = [('episode', 'E', 'int'), ('env_step', 'S', 'int'), ('episode_reward', 'R', 'float'), ('total_time', 'T', 'time')]
# Names of agent training metrics.
# NOTE(review): not referenced by any code visible in this module - confirm whether it is still needed.
AGENT_METRICS = ['consistency_loss', 'reward_loss', 'value_loss', 'total_loss', 'weighted_loss', 'pi_loss', 'grad_norm']


def make_dir(dir_path):
	"""Create `dir_path` (and any missing parents) if it does not already exist.

	Args:
		dir_path: Directory to create (string or path-like).

	Returns:
		`dir_path` unchanged, so the call can be used inline.

	Raises:
		OSError: If the directory cannot be created, e.g. permission denied or
			the path exists but is not a directory.
	"""
	# exist_ok tolerates only "already a directory"; every other failure is
	# raised here instead of surfacing later when files are written into it.
	os.makedirs(dir_path, exist_ok=True)
	return dir_path


def print_run(cfg, reward=None):
	"""Pretty-print a summary of the run to stdout. Call at start of training.

	Prints the task title, the number of train steps scaled by the action
	repeat, the observation shape, the action dimension and the experiment
	name between two divider lines.

	Args:
		cfg: Run configuration; reads `task_title`, `train_steps`,
			`action_repeat`, `obs_shape`, `action_dim` and `exp_name`.
		reward: Optional episode reward. When given, an extra row with the
			reward truncated to an integer is printed.
	"""
	prefix, color, attrs = '  ', 'green', ['bold']
	def limstr(s, maxlen=32):
		# Convert before measuring/slicing so non-string values (ints, paths, ...)
		# are handled instead of raising on `s[:maxlen]`.
		s = str(s)
		return s[:maxlen] + '...' if len(s) > maxlen else s
	def pprint(k, v):
		print(prefix + colored(f'{k.capitalize()+":":<16}', color, attrs=attrs), limstr(v))
	kvs = [
		('task', cfg.task_title),
		('train steps', f'{int(cfg.train_steps*cfg.action_repeat):,}'),
		('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
		('actions', cfg.action_dim),
		('experiment', cfg.exp_name),
	]
	if reward is not None:
		kvs.append(('episode reward', colored(str(int(reward)), 'white', attrs=['bold'])))
	# Divider width: longest (possibly truncated) value plus room for prefix and label.
	w = max(len(limstr(kv[1])) for kv in kvs) + 21
	div = '-'*w
	print(div)
	for k, v in kvs:
		pprint(k, v)
	print(div)


def cfg_to_group(cfg, return_list=False):
	"""Build the wandb group identifier for a run.

	The identifier consists of the task, the modality and the experiment name,
	where every run of characters other than ASCII letters and digits in the
	experiment name is replaced by a single '-'.

	Returns:
		The three parts as a list when `return_list` is true, otherwise the
		parts joined with '-'.
	"""
	parts = [cfg.task, cfg.modality]
	parts.append(re.sub('[^0-9a-zA-Z]+', '-', cfg.exp_name))
	if return_list:
		return parts
	return '-'.join(parts)


class VideoRecorder:
	"""Collects rendered evaluation frames and logs them to wandb as a video."""
	def __init__(self, root_dir, wandb, render_size=384, fps=15):
		"""Store the wandb handle and render settings; recording starts disabled."""
		if root_dir:
			self.save_dir = root_dir / 'eval_video'
		else:
			self.save_dir = None
		self._wandb = wandb
		self.render_size = render_size
		self.fps = fps
		self.frames = []
		self.enabled = False

	def init(self, env, enabled=True):
		"""Reset the frame buffer and record the initial frame if recording is on.

		Recording is on only when a save directory and a wandb handle exist and
		`enabled` is truthy.
		"""
		self.frames = []
		self.enabled = self.save_dir and self._wandb and enabled
		self.record(env)

	def record(self, env):
		"""Render one square RGB frame from camera 0 and buffer it (no-op when disabled)."""
		if not self.enabled:
			return
		size = self.render_size
		self.frames.append(env.render(mode='rgb_array', height=size, width=size, camera_id=0))

	def save(self, step):
		"""Log the buffered frames to wandb under 'eval_video' at `step` (no-op when disabled)."""
		if not self.enabled:
			return
		# Stack the frames and move the last axis of each frame to position 1
		# before handing the clip to wandb.Video.
		clip = np.stack(self.frames).transpose(0, 3, 1, 2)
		video = self._wandb.Video(clip, fps=self.fps, format='mp4')
		self._wandb.log({'eval_video': video}, step=step)


class Logger(object):
	"""Primary logger object. Logs either locally or using wandb.

	Console lines and `eval.log` (a CSV of evaluation results inside the log
	directory) are always written. Metrics, videos and model artifacts are
	additionally sent to wandb when it is enabled in the config and the wandb
	run could be initialised.
	"""
	def __init__(self, log_dir, cfg):
		"""Create the log and model directories, print the run summary and set up wandb.

		Args:
			log_dir: Root directory for this run; must support the `/` operator.
			cfg: Run configuration supporting attribute access and `.get`.
		"""
		self._log_dir = make_dir(log_dir)
		self._model_dir = make_dir(self._log_dir / 'models')
		self._save_model = cfg.save_model
		self._group = cfg_to_group(cfg)
		self._seed = cfg.seed
		self._cfg = cfg
		self._eval = []
		print_run(cfg)
		project, entity = cfg.get('wandb_project', 'none'), cfg.get('wandb_entity', 'none')
		# The string 'none' is the "not configured" sentinel for project and entity.
		run_offline = not cfg.get('use_wandb', False) or project == 'none' or entity == 'none'
		if run_offline:
			print(colored('Logs will be saved locally.', 'yellow', attrs=['bold']))
			self._wandb = None
		else:
			try:
				os.environ["WANDB_SILENT"] = "true"
				import wandb
				wandb.init(
					project=project,
					entity=entity,
					name=cfg.exp_name,
					group=self._group,
					tags=cfg_to_group(cfg, return_list=True) + [f'seed:{cfg.seed}'],
					dir=self._log_dir,
					config=OmegaConf.to_container(cfg, resolve=True),
				)
				print(colored('Logs will be synced with wandb.', 'blue', attrs=['bold']))
				self._wandb = wandb
			except Exception:
				# Best-effort fallback: any wandb import/init failure switches to local
				# logging. `Exception` (not a bare except) lets Ctrl-C still interrupt.
				# `attrs` belongs to colored(); passing it to print() raised TypeError
				# and turned this fallback into a crash.
				print(colored('Warning: failed to init wandb. Logs will be saved locally.', 'yellow', attrs=['bold']))
				self._wandb = None
		# Videos are only recorded when they can be uploaded to wandb.
		self._video = VideoRecorder(log_dir, self._wandb) if self._wandb and cfg.save_video else None

	@property
	def video(self):
		"""The VideoRecorder for evaluation videos, or None when video logging is off."""
		return self._video

	def finish(self, agent):
		"""Save the final model (if enabled), close the wandb run and print the run summary.

		The summary includes the most recent evaluation reward when at least one
		evaluation has been logged.
		"""
		if self._save_model:
			fp = self._model_dir / 'model.pt'
			torch.save(agent.state_dict(), fp)
			if self._wandb:
				artifact = self._wandb.Artifact(self._group+'-'+str(self._seed), type='model')
				artifact.add_file(fp)
				self._wandb.log_artifact(artifact)
		if self._wandb:
			self._wandb.finish()
		# Guard against finish() being called before any evaluation was logged;
		# print_run omits the reward row when it receives None.
		last_eval_reward = self._eval[-1][-1] if self._eval else None
		print_run(self._cfg, last_eval_reward)

	def save_checkpoint(self, agent, iteration):
		"""Save the agent's state dict as `model_<iteration>.pt` and, with wandb, upload it as an artifact."""
		fp = self._model_dir / f'model_{iteration}.pt'
		torch.save(agent.state_dict(), fp)
		if self._wandb:
			artifact = self._wandb.Artifact(self._group+'-'+str(self._seed)+'-'+str(iteration), type='model')
			artifact.add_file(fp)
			self._wandb.log_artifact(artifact)

	def _format(self, key, value, ty):
		"""Format one console field as '<grey key>: <value>'.

		Args:
			key: Label to print.
			value: Value to format.
			ty: One of 'int' (thousands separator), 'float' (one decimal) or
				'time' (seconds rendered as H:MM:SS).

		Raises:
			ValueError: If `ty` is not one of the supported format types.
		"""
		if ty == 'int':
			return f'{colored(key+":", "grey")} {int(value):,}'
		elif ty == 'float':
			return f'{colored(key+":", "grey")} {value:.01f}'
		elif ty == 'time':
			value = str(datetime.timedelta(seconds=int(value)))
			return f'{colored(key+":", "grey")} {value}'
		else:
			# Raising a plain string is itself a TypeError; use a real exception.
			raise ValueError(f'invalid log format type: {ty}')

	def _print(self, d, category):
		"""Print one console line with the CONSOLE_FORMAT columns; missing keys show as 0."""
		category = colored(category, 'blue' if category == 'train' else 'green')
		pieces = [f' {category:<14}']
		for k, disp_k, ty in CONSOLE_FORMAT:
			pieces.append(f'{self._format(disp_k, d.get(k, 0), ty):<26}')
		print('   '.join(pieces))

	def log(self, d, category='train'):
		"""Log a metrics dict to wandb (if active), to `eval.log` (eval only) and to the console.

		Args:
			d: Metrics to log; must contain 'env_step', and 'episode_reward' for
				the 'eval' category.
			category: 'train' or 'eval'; used as the wandb key prefix.
		"""
		assert category in {'train', 'eval'}
		if self._wandb is not None:
			for k,v in d.items():
				self._wandb.log({category + '/' + k: v}, step=d['env_step'])
		if category == 'eval':
			keys = ['env_step', 'episode_reward']
			self._eval.append(np.array([d[keys[0]], d[keys[1]]]))
			# The whole evaluation history is rewritten on every eval call.
			pd.DataFrame(np.array(self._eval)).to_csv(self._log_dir / 'eval.log', header=keys, index=None)
		self._print(d, category)


Overwriting tdmpc/src/logger.py


In [4]:
%%writefile tdmpc/cfgs/default.yaml
# environment
# `task` and `modality` select the environment; train.py also uses them in the
# log directory path and logger.py in the wandb group name.
task: quadruped-walk
modality: 'state'
# `???` is OmegaConf's marker for a mandatory value that has no default here.
# NOTE(review): presumably supplied by a per-task config or the CLI - confirm in cfg.parse_cfg.
action_repeat: ???
discount: 0.99
# Lengths in agent steps: train.py multiplies its step counter by action_repeat
# to obtain environment steps. After interpolation the value is still a string
# such as "1000/4". NOTE(review): the division is presumably evaluated by
# cfg.parse_cfg, since train.py uses both values as integers - confirm.
episode_length: 1000/${action_repeat}
train_steps: 500000/${action_repeat}

# planning
# NOTE(review): planner hyperparameters; the code that reads them is not part of
# the files written by this notebook, so their exact meaning is not verified here.
iterations: 6
num_samples: 512
num_elites: 64
mixture_coef: 0.05
# Also the end value of std_schedule below (via interpolation).
min_std: 0.05
temperature: 0.5
momentum: 0.1

# learning
batch_size: 512
max_buffer_size: 1000000
horizon: 5
reward_coef: 0.5
value_coef: 0.1
consistency_coef: 2
rho: 0.5
kappa: 0.1
lr: 1e-3
# NOTE(review): schedule strings of the form linear(start, end, duration) - parsed
# by a helper outside this notebook; confirm the argument meaning there.
std_schedule: linear(0.5, ${min_std}, 25000)
horizon_schedule: linear(1, ${horizon}, 25000)
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
# train.py performs no model update before this many agent steps; the first
# update round then runs `seed_steps` updates at once.
seed_steps: 5000
update_freq: 2
tau: 0.01

# architecture
enc_dim: 256
mlp_dim: 512
latent_dim: 50

# wandb (insert your own)
# logger.py logs locally only when use_wandb is false or when the project or
# entity is the string 'none'.
use_wandb: true
wandb_project: 'TD-MPC'
wandb_entity: 'kurdun_maria'

# misc
seed: 1
exp_name: default
# eval_freq and save_freq are compared against environment steps
# (agent step * action_repeat) in train.py.
eval_freq: 20000
eval_episodes: 10
# Videos are recorded only when wandb logging is active (see logger.py).
save_video: true
# Save the final model when training finishes.
save_model: true
# Save intermediate checkpoints every `save_freq` environment steps.
save_checkpoint: true
save_freq: 100000

Overwriting tdmpc/cfgs/default.yaml


In [5]:
!sudo apt-get install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
libgl1-mesa-glx is already the newest version (23.0.4-0ubuntu1~22.04.1).
The following additional packages will be installed:
  libdrm-dev libgl-dev libglx-dev libosmesa6 libpciaccess-dev mesa-common-dev
Suggested packages:
  libgles1 libvulkan1
The following NEW packages will be installed:
  libdrm-dev libgl-dev libglfw3 libglx-dev libosmesa6 libosmesa6-dev
  libpciaccess-dev mesa-common-dev patchelf
0 upgraded, 9 newly installed, 0 to remove and 129 not upgraded.
Need to get 5,916 kB of archives.
After this operation, 19.0 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 libpciaccess-dev amd64 0.16-3 [21.9 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 libdrm-dev amd64 2.4.113-2~ubuntu0.22.04.1 [292 kB]
Get:3 http://archive.ubuntu.com/ubuntu jammy/main amd64 libglx-dev amd64 1.4.0-1 [14.1 kB]
Get:4 http://archive.ubuntu

In [6]:
!mkdir -p /root/.mujoco

In [7]:
!wget https://github.com/google-deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz

--2025-03-30 09:21:45--  https://github.com/google-deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/400501136/1f51148e-4e64-4a12-a400-d6f1e21be444?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250330%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250330T092145Z&X-Amz-Expires=300&X-Amz-Signature=229a2333ea58880740a1fdb4f41b962017f790be652aff928399477e6694f38b&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dmujoco210-linux-x86_64.tar.gz&response-content-type=application%2Foctet-stream [following]
--2025-03-30 09:21:45--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/400501136/1f51148e-4e64-4a12-a400-d6f1e21be444?X-Amz-Algor

In [8]:
!tar -xf mujoco.tar.gz -C /root/.mujoco

In [9]:
# Install Python 3.8 next to the system interpreter; the virtual environment
# created in a later cell is built with `python3.8`.
# NOTE(review): this first install has no `-y`, unlike the python3.8-venv install
# below - confirm it cannot stop at a confirmation prompt on a fresh machine.
# NOTE(review): no cell in this notebook registers an extra apt repository for
# python3.8 - confirm the base image already provides one.
!sudo apt-get install python3.8 python3.8-dev python3.8-distutils
# NOTE(review): this repoints /usr/bin/python3 at 3.8 system-wide. The venv is
# created with an explicit `python3.8`, so this may be unnecessary - TODO confirm.
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1
!sudo apt-get install python3.8-venv -y
# NOTE(review): `--config` asks for a choice interactively when several
# alternatives are registered, which a notebook cell cannot answer - confirm.
!sudo update-alternatives --config python3

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libpython3.8 libpython3.8-dev libpython3.8-minimal libpython3.8-stdlib
  mailcap mime-support python3.8-lib2to3 python3.8-minimal
Suggested packages:
  python3.8-venv binfmt-support
The following NEW packages will be installed:
  libpython3.8 libpython3.8-dev libpython3.8-minimal libpython3.8-stdlib
  mailcap mime-support python3.8 python3.8-dev python3.8-distutils
  python3.8-lib2to3 python3.8-minimal
0 upgraded, 11 newly installed, 0 to remove and 129 not upgraded.
Need to get 12.1 MB of archives.
After this operation, 45.2 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/main amd64 mailcap all 3.70+nmu1ubuntu1 [23.8 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/main amd64 mime-support all 3.66 [3,696 B]
Get:3 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy/main amd64 libpython3.8-

In [10]:
!sudo apt-get install python3-pip

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  python3-setuptools python3-wheel
Suggested packages:
  python-setuptools-doc
The following NEW packages will be installed:
  python3-pip python3-setuptools python3-wheel
0 upgraded, 3 newly installed, 0 to remove and 129 not upgraded.
Need to get 1,677 kB of archives.
After this operation, 8,968 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 python3-setuptools all 59.6.0-1.2ubuntu0.22.04.2 [340 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 python3-wheel all 0.37.1-2ubuntu0.22.04.1 [32.0 kB]
Get:3 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 python3-pip all 22.0.2+dfsg-1ubuntu0.5 [1,306 kB]
Fetched 1,677 kB in 1s (1,490 kB/s)  
debconf: unable to initialize frontend: Dialog
debconf: (No usable dialog-like program is installed, so the dia

In [11]:
!python3.8 -m venv new-env

In [12]:
# NOTE: every `!` line runs in its own subshell, so this activation does not
# carry over to later cells. That is why the following cells prefix each
# command with `source new-env/bin/activate;`.
!source new-env/bin/activate

In [13]:
!source new-env/bin/activate; echo  $VIRTUAL_ENV

/kaggle/working/new-env


In [14]:
!source new-env/bin/activate; python --version

Python 3.8.20


In [15]:
!source new-env/bin/activate; pip install \
  wheel \
  torch==1.9.0+cu111 \
  torchvision==0.10.0+cu111 \
  -f https://download.pytorch.org/whl/torch_stable.html \
  termcolor \
  omegaconf \
  gym==0.21.0 \
  dm-control==0.0.403778684 \
  pandas \
  wandb \
  moviepy \
  imageio \
  numpy==1.19.5


Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting wheel
  Downloading wheel-0.45.1-py3-none-any.whl (72 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.5/72.5 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch==1.9.0+cu111
  Downloading https://download.pytorch.org/whl/cu111/torch-1.9.0%2Bcu111-cp38-cp38-linux_x86_64.whl (2041.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 GB[0m [31m421.3 kB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torchvision==0.10.0+cu111
  Downloading https://download.pytorch.org/whl/cu111/torchvision-0.10.0%2Bcu111-cp38-cp38-linux_x86_64.whl (23.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.2/23.2 MB[0m [31m46.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hCollecting termcolor
  Downloading termcolor-2.4.0-py3-none-any.whl (7.7 kB)
Collecting omegaconf
  Downloading omegaconf-2.3.0-py3-none-any.whl (79 k

In [16]:
import os
# Keep the real key out of the notebook. A key that is already exported in the
# environment is left untouched (a plain assignment would overwrite it with the
# empty string); otherwise the variable is created as an empty placeholder.
# The training subprocesses started with `!` inherit this variable.
os.environ.setdefault("WANDB_API_KEY", "")

In [32]:
!source new-env/bin/activate; cd tdmpc ; python  src/train.py exp_name=walk_base

-----------------------------------
  [1m[32mTask:           [0m Quadruped Walk
  [1m[32mTrain steps:    [0m 500,000
  [1m[32mObservations:   [0m 78
  [1m[32mActions:        [0m 12
  [1m[32mExperiment:     [0m walk_base
-----------------------------------
[1m[34mLogs will be synced with wandb.[0m
 [34mtrain[0m   [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 510.8            [30mT:[0m 0:00:00       
 [32meval[0m    [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 58.5             [30mT:[0m 0:00:00       
 [34mtrain[0m   [30mE:[0m 2                [30mS:[0m 1,000            [30mR:[0m 498.2            [30mT:[0m 0:01:01       
 [34mtrain[0m   [30mE:[0m 3                [30mS:[0m 2,000            [30mR:[0m 6.8              [30mT:[0m 0:01:02       
 [34mtrain[0m   [30mE:[0m 4                [30mS:[0m 3,000            [30mR:[0m 211.6            [30mT:[0m 0:01:03       
 [34mtrain[0m   [30mE

*Некоторые ячейки без логов, так как запускались с другого аккаунта для экономии времени.*

In [None]:
!source new-env/bin/activate; cd tdmpc ; python  src/train.py exp_name=run_base task=quadruped-run

# Добавим finetuning & freezing

In [17]:
%%writefile tdmpc/src/train.py
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'
import torch
import numpy as np
import gym
gym.logger.set_level(40)
import time
import random
from pathlib import Path
from cfg import parse_cfg
from env import make_env
from algorithm.tdmpc import TDMPC
from algorithm.helper import Episode, ReplayBuffer
import algorithm.helper as h
import logger
torch.backends.cudnn.benchmark = True
__CONFIG__, __LOGS__ = 'cfgs', 'logs'


def set_seed(seed):
    """Seed every random number generator used by the training script.

    Applies `seed`, in order, to Python's `random` module, NumPy's global
    generator, `torch.manual_seed` and `torch.cuda.manual_seed_all`.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def evaluate(env, agent, num_episodes, step, env_step, video):
    """Roll the agent out in evaluation mode and return its mean episode reward.

    Runs `num_episodes` episodes, choosing every action with
    `agent.plan(..., eval_mode=True)`. When `video` is truthy, recording is
    enabled for the first episode only, and `video.save(env_step)` is called
    after every episode.

    Returns:
        `np.nanmean` of the per-episode reward sums.
    """
    returns = []
    for episode_idx in range(num_episodes):
        obs = env.reset()
        done = False
        total_reward = 0
        t = 0
        if video:
            video.init(env, enabled=(episode_idx == 0))
        while not done:
            is_first_step = t == 0
            action = agent.plan(obs, eval_mode=True, step=step, t0=is_first_step)
            obs, reward, done, _ = env.step(action.cpu().numpy())
            total_reward += reward
            if video:
                video.record(env)
            t += 1
        returns.append(total_reward)
        if video:
            video.save(env_step)
    return np.nanmean(returns)


def train(cfg):
    """Training script for TD-MPC with optional warm start and module freezing.

    Requires a CUDA-enabled device. Optionally loads a checkpoint
    (`cfg.checkpoint_path`) and freezes the encoder (`cfg.frozen_h`) and/or the
    dynamics model (`cfg.frozen_d`) before training. Each loop iteration then
    collects one episode, updates the model once `cfg.seed_steps` is reached,
    logs the episode, and periodically checkpoints and evaluates the agent.
    """
    assert torch.cuda.is_available()
    set_seed(cfg.seed)
    work_dir = Path().cwd() / __LOGS__ / cfg.task / cfg.modality / cfg.exp_name / str(cfg.seed)
    env, agent, buffer = make_env(cfg), TDMPC(cfg), ReplayBuffer(cfg)

    # Warm start. A YAML default written as the bare word `none` is loaded as
    # the *string* 'none' (only `null`/`~`/empty are YAML nulls), so a plain
    # `is not None` check would try to load a file called 'none'. Treat the
    # usual "unset" spellings as "no checkpoint".
    checkpoint_path = cfg.get('checkpoint_path', None)
    if checkpoint_path is not None and str(checkpoint_path).strip().lower() not in ('', 'none', 'null'):
        agent.load(checkpoint_path)

    # Optional freezing of the encoder (h) and the latent dynamics (d).
    if cfg.get('frozen_h', False):
        h.set_requires_grad(agent.model._encoder, False)

    if cfg.get('frozen_d', False):
        h.set_requires_grad(agent.model._dynamics, False)

    # Report the resulting trainability of the first layer of each module.
    print('Dynamics requires grad = ',agent.model._dynamics[0].weight.requires_grad)
    if not isinstance(agent.model._encoder, torch.nn.Identity):
        print('Encoder requires grad = ',agent.model._encoder[0].weight.requires_grad)

    # Run training
    L = logger.Logger(work_dir, cfg)
    episode_idx, start_time = 0, time.time()
    # `step` counts agent steps and advances by one full episode per iteration.
    for step in range(0, cfg.train_steps+cfg.episode_length, cfg.episode_length):

        # Collect trajectory
        obs = env.reset()
        episode = Episode(cfg, obs)
        while not episode.done:
            action = agent.plan(obs, step=step, t0=episode.first)
            obs, reward, done, _ = env.step(action.cpu().numpy())
            episode += (obs, action, reward, done)
        assert len(episode) == cfg.episode_length
        buffer += episode

        # Update model: the first round after seeding runs `seed_steps` updates,
        # later rounds run one update per collected step.
        train_metrics = {}
        if step >= cfg.seed_steps:
            num_updates = cfg.seed_steps if step == cfg.seed_steps else cfg.episode_length
            for i in range(num_updates):
                train_metrics.update(agent.update(buffer, step+i))

        # Log training episode
        episode_idx += 1
        env_step = int(step*cfg.action_repeat)
        common_metrics = {
            'episode': episode_idx,
            'step': step,
            'env_step': env_step,
            'total_time': time.time() - start_time,
            'episode_reward': episode.cumulative_reward}
        train_metrics.update(common_metrics)
        L.log(train_metrics, category='train')

        # Periodic checkpoint (keyed on environment steps)
        if cfg.save_checkpoint and env_step % cfg.save_freq == 0:
            L.save_checkpoint(agent, step)

        # Periodic evaluation; reuses the common metrics with the eval reward swapped in
        if env_step % cfg.eval_freq == 0:
            common_metrics['episode_reward'] = evaluate(env, agent, cfg.eval_episodes, step, env_step, L.video)
            L.log(common_metrics, category='eval')

    L.finish(agent)
    print('Training completed successfully')


# Script entry point. The config directory is `<current working directory>/cfgs`,
# so the script has to be launched from the repository root (the notebook runs
# `cd tdmpc` before `python src/train.py`).
if __name__ == '__main__':
    train(parse_cfg(Path().cwd() / __CONFIG__))


Overwriting tdmpc/src/train.py


In [18]:
%%writefile tdmpc/cfgs/default.yaml
# environment
# `task` and `modality` select the environment; train.py also uses them in the
# log directory path and logger.py in the wandb group name.
task: quadruped-walk
modality: 'state'
# `???` is OmegaConf's marker for a mandatory value that has no default here.
# NOTE(review): presumably supplied by a per-task config or the CLI - confirm in cfg.parse_cfg.
action_repeat: ???
discount: 0.99
# Lengths in agent steps: train.py multiplies its step counter by action_repeat
# to obtain environment steps. After interpolation the value is still a string
# such as "1000/4". NOTE(review): the division is presumably evaluated by
# cfg.parse_cfg, since train.py uses both values as integers - confirm.
episode_length: 1000/${action_repeat}
train_steps: 500000/${action_repeat}

# planning
# NOTE(review): planner hyperparameters; the code that reads them is not part of
# the files written by this notebook, so their exact meaning is not verified here.
iterations: 6
num_samples: 512
num_elites: 64
mixture_coef: 0.05
# Also the end value of std_schedule below (via interpolation).
min_std: 0.05
temperature: 0.5
momentum: 0.1

# learning
batch_size: 512
max_buffer_size: 1000000
horizon: 5
reward_coef: 0.5
value_coef: 0.1
consistency_coef: 2
rho: 0.5
kappa: 0.1
lr: 1e-3
# NOTE(review): schedule strings of the form linear(start, end, duration) - parsed
# by a helper outside this notebook; confirm the argument meaning there.
std_schedule: linear(0.5, ${min_std}, 25000)
horizon_schedule: linear(1, ${horizon}, 25000)
per_alpha: 0.6
per_beta: 0.4
grad_clip_norm: 10
# train.py performs no model update before this many agent steps; the first
# update round then runs `seed_steps` updates at once.
seed_steps: 5000
update_freq: 2
tau: 0.01

# architecture
enc_dim: 256
mlp_dim: 512
latent_dim: 50

# wandb (insert your own)
# logger.py logs locally only when use_wandb is false or when the project or
# entity is the string 'none'.
use_wandb: true
wandb_project: 'TD-MPC'
wandb_entity: 'kurdun_maria'

# misc
seed: 1
exp_name: default
# eval_freq and save_freq are compared against environment steps
# (agent step * action_repeat) in train.py.
eval_freq: 20000
eval_episodes: 10
# Videos are recorded only when wandb logging is active (see logger.py).
save_video: true
# Save the final model when training finishes.
save_model: true
# Save intermediate checkpoints every `save_freq` environment steps.
save_checkpoint: true
save_freq: 100000
# Checkpoint to warm-start from. train.py loads it whenever the value is not
# None, so the default must be a real YAML null: the bare word `none` would be
# read as the string 'none' and treated as a file path.
checkpoint_path: null
# Freeze the encoder (frozen_h) and/or the dynamics model (frozen_d) by
# disabling gradients for their parameters before training.
frozen_h: false
frozen_d: false

Overwriting tdmpc/cfgs/default.yaml


In [52]:
!source new-env/bin/activate; cd tdmpc ; python src/train.py task=quadruped-run \
                                                            exp_name=run_transfer_full_finetune \
                                                            checkpoint_path='/kaggle/working/tdmpc/logs/quadruped-walk/state/walk_base/1/models/model.pt'

Dynamics requires grad =  True
Encoder requires grad =  True
-----------------------------------------------
  [1m[32mTask:           [0m Quadruped Run
  [1m[32mTrain steps:    [0m 500,000
  [1m[32mObservations:   [0m 78
  [1m[32mActions:        [0m 12
  [1m[32mExperiment:     [0m run_transfer_full_finetune
-----------------------------------------------
[1m[34mLogs will be synced with wandb.[0m
 [34mtrain[0m   [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 491.3            [30mT:[0m 0:00:00       
 [32meval[0m    [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 528.6            [30mT:[0m 0:00:00       
 [34mtrain[0m   [30mE:[0m 2                [30mS:[0m 1,000            [30mR:[0m 491.9            [30mT:[0m 0:01:09       
 [34mtrain[0m   [30mE:[0m 3                [30mS:[0m 2,000            [30mR:[0m 7.3              [30mT:[0m 0:01:10       
 [34mtrain[0m   [30mE:[0m 4                [30mS:

In [None]:
!source new-env/bin/activate; cd tdmpc ; python src/train.py task=quadruped-run \
                                                            exp_name=run_transfer_finetune_frozen_h \
                                                            frozen_h=true \
                                                            checkpoint_path='/kaggle/input/model_walk_base/pytorch/default/1/model_walk_base.pt'

In [None]:
!source new-env/bin/activate; cd tdmpc ; python src/train.py task=quadruped-run \
                                                            exp_name=run_transfer_finetune_frozen_h_d \
                                                            frozen_h=true \
                                                            frozen_d=true \
                                                            checkpoint_path='/kaggle/input/model_walk_base/pytorch/default/1/model_walk_base.pt'

In [20]:
%%writefile tdmpc/src/algorithm/tdmpc.py
import numpy as np
import torch
import torch.nn as nn
from copy import deepcopy
import algorithm.helper as h


class TOLD(nn.Module):
	"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.

	Bundles an encoder, a latent dynamics model, a reward head, a policy head
	and two Q-networks. The attribute names double as state-dict key prefixes,
	so renaming them would break loading of existing checkpoints.
	"""
	def __init__(self, cfg):
		"""Build all sub-networks from `cfg` and initialise their weights.

		Note: when `cfg.latent` is falsy, `cfg.latent_dim` is overwritten in the
		passed-in config object.
		"""
		super().__init__()
		self.cfg = cfg
		if self.cfg.latent:
			self._encoder = h.enc(cfg)
		else:
			# No learned encoder: observations are used directly as the latent
			# state, so latent_dim is set to the first observation dimension
			# before the heads below are sized from it.
			self._encoder = nn.Identity()
			cfg.latent_dim = cfg.obs_shape[0]
		print('Encoder = ', self._encoder)
		# Dynamics and reward heads both take the concatenated (latent, action) input.
		self._dynamics = h.mlp(cfg.latent_dim+cfg.action_dim, cfg.mlp_dim, cfg.latent_dim)
		self._reward = h.mlp(cfg.latent_dim+cfg.action_dim, cfg.mlp_dim, 1)
		self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, cfg.action_dim)
		self._Q1, self._Q2 = h.q(cfg), h.q(cfg)
		self.apply(h.orthogonal_init)
		# Zero the last layer of the reward and Q heads (after the init above).
		for m in [self._reward, self._Q1, self._Q2]:
			m[-1].weight.data.fill_(0)
			m[-1].bias.data.fill_(0)

	def track_q_grad(self, enable=True):
		"""Utility function. Enables/disables gradient tracking of Q-networks."""
		for m in [self._Q1, self._Q2]:
			h.set_requires_grad(m, enable)

	def h(self, obs):
		"""Encodes an observation into its latent representation (h)."""
		return self._encoder(obs)

	def next(self, z, a):
		"""Predicts next latent state (d) and single-step reward (R).

		`z` and `a` are concatenated along the last dimension; returns the
		tuple (next latent, reward).
		"""
		x = torch.cat([z, a], dim=-1)
		return self._dynamics(x), self._reward(x)

	def pi(self, z, std=0):
		"""Samples an action from the learned policy (pi).

		Returns the tanh-squashed policy output when `std` is not positive;
		otherwise a sample from `h.TruncatedNormal(mu, std)` with `clip=0.3`.
		"""
		mu = torch.tanh(self._pi(z))
		if std > 0:
			std = torch.ones_like(mu) * std
			# NOTE(review): the meaning of `clip` is defined by h.TruncatedNormal - confirm there.
			return h.TruncatedNormal(mu, std).sample(clip=0.3)
		return mu

	def Q(self, z, a):
		"""Predict state-action value (Q). Returns the outputs of both Q-networks as a tuple."""
		x = torch.cat([z, a], dim=-1)
		return self._Q1(x), self._Q2(x)


class TDMPC():
	"""
	Implementation of TD-MPC learning + inference.

	Owns the online TOLD model, a deep-copied target model, one optimizer over
	all model parameters and a second optimizer over the policy head only.
	The device is hard-coded to CUDA.
	"""
	def __init__(self, cfg):
		self.cfg = cfg
		self.device = torch.device('cuda')
		# Lower bound applied to the planner's std; refreshed from the schedule on every update().
		self.std = h.linear_schedule(cfg.std_schedule, 0)
		self.model = TOLD(cfg).cuda()
		self.model_target = deepcopy(self.model)
		self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
		self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
		self.aug = h.RandomShiftsAug(cfg)
		self.model.eval()
		self.model_target.eval()

	def state_dict(self):
		"""Retrieve state dict of TOLD model, including slow-moving target network."""
		return {'model': self.model.state_dict(),
				'model_target': self.model_target.state_dict()}

	def save(self, fp):
		"""Save state dict of TOLD model to filepath."""
		torch.save(self.state_dict(), fp)

	def load(self, fp):
		"""
		Load a saved state dict from filepath into current agent.
		fp: path to a file written by `save` (keys 'model' and 'model_target').
		"""
		# map_location moves the stored tensors onto this agent's device, so a
		# checkpoint written from a different device index (or from CPU) still loads.
		d = torch.load(fp, map_location=self.device)
		self.model.load_state_dict(d['model'])
		self.model_target.load_state_dict(d['model_target'])

	@torch.no_grad()
	def estimate_value(self, z, actions, horizon):
		"""
		Estimate value of a trajectory starting at latent state z and executing given actions.
		z: starting latent state(s).
		actions: action sequence indexed by time step on its first axis.
		horizon: number of model steps to roll out.
		Returns the discounted sum of predicted rewards plus the discounted
		min(Q1, Q2) of the policy action at the final latent state.
		"""
		G, discount = 0, 1
		for t in range(horizon):
			z, reward = self.model.next(z, actions[t])
			G += discount * reward
			discount *= self.cfg.discount
		# Bootstrap the tail of the return with the learned value function.
		G += discount * torch.min(*self.model.Q(z, self.model.pi(z, self.cfg.min_std)))
		return G

	@torch.no_grad()
	def plan(self, obs, eval_mode=False, step=None, t0=True):
		"""
		Plan next action using TD-MPC inference.
		obs: raw input observation.
		eval_mode: uniform sampling and action noise is disabled during evaluation.
		step: current time step. determines e.g. planning horizon. If None, no
			schedule is applied: the seed-step phase is skipped and the full
			`cfg.horizon` is used.
		t0: whether current step is the first step of an episode.
		Returns the action tensor for the current step.
		"""
		# Seed steps
		if step is not None and step < self.cfg.seed_steps and not eval_mode:
			return torch.empty(self.cfg.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1)

		# Sample policy trajectories
		obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(0)
		if step is None:
			horizon = int(self.cfg.horizon)
		else:
			horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step)))
		num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples)
		if num_pi_trajs > 0:
			pi_actions = torch.empty(horizon, num_pi_trajs, self.cfg.action_dim, device=self.device)
			z = self.model.h(obs).repeat(num_pi_trajs, 1)
			for t in range(horizon):
				pi_actions[t] = self.model.pi(z, self.cfg.min_std)
				z, _ = self.model.next(z, pi_actions[t])

		# Initialize state and parameters
		z = self.model.h(obs).repeat(self.cfg.num_samples+num_pi_trajs, 1)
		mean = torch.zeros(horizon, self.cfg.action_dim, device=self.device)
		std = 2*torch.ones(horizon, self.cfg.action_dim, device=self.device)
		# Warm-start from the previous plan shifted by one step. Only valid when the
		# previous plan has the same horizon; otherwise the slice shapes would not match.
		if not t0 and hasattr(self, '_prev_mean') and self._prev_mean.shape[0] == horizon:
			mean[:-1] = self._prev_mean[1:]

		# Iterate CEM
		for i in range(self.cfg.iterations):
			actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * \
				torch.randn(horizon, self.cfg.num_samples, self.cfg.action_dim, device=std.device), -1, 1)
			if num_pi_trajs > 0:
				actions = torch.cat([actions, pi_actions], dim=1)

			# Compute elite actions
			value = self.estimate_value(z, actions, horizon).nan_to_num_(0)
			elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
			elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]

			# Update parameters
			# Subtracting the max before exp keeps the largest exponent at 0.
			max_value = elite_value.max(0)[0]
			score = torch.exp(self.cfg.temperature*(elite_value - max_value))
			score /= score.sum(0)
			_mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
			_std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - _mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9))
			_std = _std.clamp_(self.std, 2)
			mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std

		# Outputs
		# Pick one elite trajectory at random, weighted by its score, and execute its first action.
		score = score.squeeze(1).cpu().numpy()
		actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
		self._prev_mean = mean
		mean, std = actions[0], _std[0]
		a = mean
		if not eval_mode:
			a += std * torch.randn(self.cfg.action_dim, device=std.device)
		return a

	def update_pi(self, zs):
		"""
		Update policy using a sequence of latent states.
		zs: iterable of latent states, one per time step.
		Returns the scalar policy loss as a Python float.
		"""
		self.pi_optim.zero_grad(set_to_none=True)
		# Q-networks are only evaluated here; disable their gradient tracking for the policy step.
		self.model.track_q_grad(False)

		# Loss is a weighted sum of Q-values
		pi_loss = 0
		for t,z in enumerate(zs):
			a = self.model.pi(z, self.cfg.min_std)
			Q = torch.min(*self.model.Q(z, a))
			pi_loss += -Q.mean() * (self.cfg.rho ** t)

		pi_loss.backward()
		torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False)
		self.pi_optim.step()
		self.model.track_q_grad(True)
		return pi_loss.item()

	@torch.no_grad()
	def _td_target(self, next_obs, reward):
		"""
		Compute the TD-target from a reward and the observation at the following time step.
		The next observation is encoded and acted on by the online model; the
		value is the min of the two target-network Q estimates.
		"""
		next_z = self.model.h(next_obs)
		td_target = reward + self.cfg.discount * \
			torch.min(*self.model_target.Q(next_z, self.model.pi(next_z, self.cfg.min_std)))
		return td_target

	def update(self, replay_buffer, step):
		"""
		Main update function. Corresponds to one iteration of the TOLD model learning.
		replay_buffer: provides `sample()` and `update_priorities(idxs, priorities)`.
		step: current time step; drives the std schedule and the target-update cadence.
		Returns a dict of scalar training statistics.
		"""
		obs, next_obses, action, reward, idxs, weights = replay_buffer.sample()
		self.optim.zero_grad(set_to_none=True)
		self.std = h.linear_schedule(self.cfg.std_schedule, step)
		self.model.train()

		# Representation
		z = self.model.h(self.aug(obs))
		# Detached latents are collected for the separate policy update below.
		zs = [z.detach()]

		consistency_loss, reward_loss, value_loss, priority_loss = 0, 0, 0, 0
		for t in range(self.cfg.horizon):

			# Predictions
			# Q is evaluated on the current latent before it is advanced by the dynamics.
			Q1, Q2 = self.model.Q(z, action[t])
			z, reward_pred = self.model.next(z, action[t])
			with torch.no_grad():
				next_obs = self.aug(next_obses[t])
				# Consistency target comes from the target model's encoder.
				next_z = self.model_target.h(next_obs)
				td_target = self._td_target(next_obs, reward[t])
			zs.append(z.detach())

			# Losses
			rho = (self.cfg.rho ** t)
			consistency_loss += rho * torch.mean(h.mse(z, next_z), dim=1, keepdim=True)
			reward_loss += rho * h.mse(reward_pred, reward[t])
			value_loss += rho * (h.mse(Q1, td_target) + h.mse(Q2, td_target))
			priority_loss += rho * (h.l1(Q1, td_target) + h.l1(Q2, td_target))

		# Optimize model
		total_loss = self.cfg.consistency_coef * consistency_loss.clamp(max=1e4) + \
					 self.cfg.reward_coef * reward_loss.clamp(max=1e4) + \
					 self.cfg.value_coef * value_loss.clamp(max=1e4)
		weighted_loss = (total_loss.squeeze(1) * weights).mean()
		# Scale the gradient (not the reported loss value) by 1/horizon.
		weighted_loss.register_hook(lambda grad: grad * (1/self.cfg.horizon))
		weighted_loss.backward()
		grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False)
		self.optim.step()
		replay_buffer.update_priorities(idxs, priority_loss.clamp(max=1e4).detach())

		# Update policy + target network
		pi_loss = self.update_pi(zs)
		if step % self.cfg.update_freq == 0:
			h.ema(self.model, self.model_target, self.cfg.tau)

		self.model.eval()
		return {'consistency_loss': float(consistency_loss.mean().item()),
				'reward_loss': float(reward_loss.mean().item()),
				'value_loss': float(value_loss.mean().item()),
				'pi_loss': pi_loss,
				'total_loss': float(total_loss.mean().item()),
				'weighted_loss': float(weighted_loss.mean().item()),
				'grad_norm': float(grad_norm)}


Overwriting tdmpc/src/algorithm/tdmpc.py


In [21]:
%%writefile tdmpc/cfgs/default.yaml
# environment
# task / modality / exp_name / seed together form the log directory used by train.py.
task: quadruped-walk
modality: 'state'
# NOTE(review): `???` looks like OmegaConf's mandatory-value marker, i.e. the value
# must be supplied elsewhere (per-task config or CLI) — confirm.
action_repeat: ???
# discount factor applied to rewards in value estimation and in the TD-target
discount: 0.99
episode_length: 1000/${action_repeat}
train_steps: 500000/${action_repeat}

# planning
# (read by TDMPC.plan / TDMPC.estimate_value)
# number of CEM iterations per planning call
iterations: 6
# number of action sequences sampled from the Gaussian per iteration
num_samples: 512
# number of top-valued sequences kept to refit mean/std
num_elites: 64
# int(mixture_coef * num_samples) extra sequences are rolled out from the learned policy
mixture_coef: 0.05
# std passed to the policy when sampling actions; also the end value of std_schedule
min_std: 0.05
# multiplies (elite_value - max_value) before exp when weighting elites
temperature: 0.5
# weight of the previous mean when updating the planner's mean
momentum: 0.1

# learning
# (read by TDMPC.update / TDMPC.update_pi unless noted)
batch_size: 512
max_buffer_size: 1000000
# number of model steps unrolled in update(); upper bound on the planning horizon
horizon: 5
# weights of the reward / value / latent-consistency terms in the total loss
reward_coef: 0.5
value_coef: 0.1
consistency_coef: 2
# per-step loss weight is rho ** t
rho: 0.5
# NOTE(review): kappa is not read in algorithm/tdmpc.py as written — confirm where it is used.
kappa: 0.1
# learning rate for both Adam optimizers (model and policy)
lr: 1e-3
# schedule for the lower bound on the planner's std
std_schedule: linear(0.5, ${min_std}, 25000)
# schedule for the planning horizon (capped at `horizon`)
horizon_schedule: linear(1, ${horizon}, 25000)
# NOTE(review): presumably prioritized-replay exponents consumed by ReplayBuffer — confirm.
per_alpha: 0.6
per_beta: 0.4
# max gradient norm for both the model and the policy update
grad_clip_norm: 10
# while step < seed_steps (and not evaluating) plan() returns uniform random actions
seed_steps: 5000
# target network is EMA-updated when step % update_freq == 0, with coefficient tau
update_freq: 2
tau: 0.01

# architecture
enc_dim: 256
mlp_dim: 512
latent_dim: 50
# when false, the encoder is nn.Identity() and latent_dim is overwritten with obs_shape[0]
latent: true

# wandb (insert your own)
use_wandb: true
wandb_project: 'TD-MPC'
wandb_entity: 'kurdun_maria'

# misc
seed: 1
exp_name: default
eval_freq: 20000
eval_episodes: 10
save_video: true
save_model: true
save_checkpoint: true
save_freq: 100000
# NOTE(review): checkpoint_path / frozen_h / frozen_d are not read in algorithm/tdmpc.py;
# presumably train.py loads the checkpoint and freezes the encoder (h) / dynamics (d) — confirm.
checkpoint_path: null
frozen_h: false
frozen_d: false

Overwriting tdmpc/cfgs/default.yaml


In [22]:
# Train quadruped-walk with latent=false (the encoder becomes nn.Identity()).
# `&&` aborts the chain if activating the venv or entering the repo fails,
# so training never starts with the wrong interpreter or working directory.
!source new-env/bin/activate && cd tdmpc && python src/train.py task=quadruped-walk \
                                                            exp_name=walk_without_latent \
                                                            latent=false

Encoder =  Identity()
Dynamics requires grad =  True
----------------------------------------
  [1m[32mTask:           [0m Quadruped Walk
  [1m[32mTrain steps:    [0m 500,000
  [1m[32mObservations:   [0m 78
  [1m[32mActions:        [0m 12
  [1m[32mExperiment:     [0m walk_without_latent
----------------------------------------
[1m[34mLogs will be synced with wandb.[0m
 [34mtrain[0m   [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 510.8            [30mT:[0m 0:00:00       
 [32meval[0m    [30mE:[0m 1                [30mS:[0m 0                [30mR:[0m 58.5             [30mT:[0m 0:00:00       
 [34mtrain[0m   [30mE:[0m 2                [30mS:[0m 1,000            [30mR:[0m 498.2            [30mT:[0m 0:00:57       
 [34mtrain[0m   [30mE:[0m 3                [30mS:[0m 2,000            [30mR:[0m 6.8              [30mT:[0m 0:00:58       
 [34mtrain[0m   [30mE:[0m 4                [30mS:[0m 3,000            [30mR: