YangRui2015/UWMSG

Uncertainty Weighted MSG for Offline RL with Data Corruption

This repo contains the official implementation of the UWMSG algorithm from the NeurIPS 2023 paper "Corruption-Robust Offline Reinforcement Learning with General Function Approximation".

Getting started

First, install the dependencies: torch>=1.7.0, wandb, numpy, gym, d4rl, tqdm, and pyrallis.
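Before launching a run, it can help to confirm that the packages listed above are importable. The following is a minimal sketch, not part of this repo: the helper name `missing_packages` is hypothetical, it assumes each package's import name matches the name listed above, and it does not check the torch>=1.7.0 version constraint.

```python
import importlib.util

# Import names taken from the dependency list in this README.
REQUIRED = ["torch", "wandb", "numpy", "gym", "d4rl", "tqdm", "pyrallis"]


def missing_packages(names=REQUIRED):
    """Return the names that cannot be found on the current Python path.

    Hypothetical helper: it only checks that a top-level module is findable
    and does not import it or verify its version.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


print("Missing packages:", missing_packages() or "none")
```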

Random Attack

Run UWMSG with random reward corruption:

CUDA_VISIBLE_DEVICES=${gpu} python UWMSG.py --random_corruption  --corruption_reward --corruption_range ${corruption_range} --corruption_rate ${corruption_rate}  --env_name ${env_name} --seed ${seed} --use_UW  # optional: --use_default_parameter

${env_name} can be 'halfcheetah-medium-v2', 'walker2d-medium-replay-v2', or 'hopper-medium-replay-v2'. ${corruption_range} and ${corruption_rate} are hyperparameters listed in the appendix of the paper. To use the default hyperparameters, replace the '--corruption_range' and '--corruption_rate' arguments with '--use_default_parameter'.
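To sweep over environments and seeds, the command above can be assembled programmatically. This is a sketch under stated assumptions: `build_uwmsg_command` is a hypothetical helper that is not part of this repo, it uses only the flags shown in this README, and it assumes '--use_default_parameter' can stand in for the range and rate arguments whenever they are not supplied.

```python
import shlex

ENV_NAMES = [
    "halfcheetah-medium-v2",
    "walker2d-medium-replay-v2",
    "hopper-medium-replay-v2",
]


def build_uwmsg_command(env_name, seed, target="reward", random_corruption=True,
                        corruption_range=None, corruption_rate=None, use_uw=True):
    """Hypothetical helper: build the argument list for a UWMSG.py run."""
    if target not in ("reward", "dynamics"):
        raise ValueError("target must be 'reward' or 'dynamics'")
    cmd = ["python", "UWMSG.py"]
    if random_corruption:
        cmd.append("--random_corruption")
    cmd.append(f"--corruption_{target}")
    if corruption_range is None or corruption_rate is None:
        # Assumption: fall back to the repo defaults when values are not given.
        cmd.append("--use_default_parameter")
    else:
        cmd += ["--corruption_range", str(corruption_range),
                "--corruption_rate", str(corruption_rate)]
    cmd += ["--env_name", env_name, "--seed", str(seed)]
    if use_uw:
        cmd.append("--use_UW")
    return cmd


# Print one command per environment; pass a list to subprocess.run(cmd, check=True) to execute it.
for env in ENV_NAMES:
    print(shlex.join(build_uwmsg_command(env, seed=0)))
```

Building the command as a list rather than a string avoids shell quoting issues if it is later passed to `subprocess.run`.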

Run UWMSG with random dynamics corruption:

CUDA_VISIBLE_DEVICES=${gpu} python UWMSG.py --random_corruption  --corruption_dynamics --corruption_range ${corruption_range} --corruption_rate ${corruption_rate}  --env_name ${env_name} --seed ${seed} --use_UW  

Adversarial Attack

Run UWMSG with adversarial reward corruption:

CUDA_VISIBLE_DEVICES=${gpu} python UWMSG.py --corruption_reward --corruption_range ${corruption_range} --corruption_rate ${corruption_rate}  --env_name ${env_name} --seed ${seed} --use_UW 

Run UWMSG with adversarial dynamics corruption:

CUDA_VISIBLE_DEVICES=${gpu} python UWMSG.py  --corruption_dynamics --corruption_range ${corruption_range} --corruption_rate ${corruption_rate}  --env_name ${env_name} --seed ${seed} --use_UW 

Note that the adversarial dynamics attack loads an offline dataset from the 'load_attack_data' directory that matches the chosen attack ratio and attack scale. This data can be generated by setting 'corrupt_model_path' and 'gradient_attack' in UWMSG.py or EDAC.py. We use MSG for pretraining on the halfcheetah and walker2d tasks, and EDAC for pretraining on the hopper task.

Baselines

To run SACN or EDAC, replace UWMSG.py with SACN.py or EDAC.py in the commands above. To run the MSG algorithm, keep UWMSG.py and remove the '--use_UW' flag.
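The substitutions described above can be written as a small transformation on a UWMSG command. This is a sketch only: `to_baseline` is a hypothetical helper, not part of this repo, and it applies the two rules literally (swap the script name for SACN and EDAC, drop '--use_UW' for MSG). Whether SACN.py and EDAC.py accept every remaining flag unchanged is an assumption that has not been checked here.

```python
def to_baseline(command, algorithm):
    """Hypothetical helper: derive a baseline command from a UWMSG.py command.

    `command` is an argument list such as ["python", "UWMSG.py", ..., "--use_UW"].
    """
    cmd = list(command)
    if algorithm == "UWMSG":
        return cmd
    if algorithm == "MSG":
        # MSG is UWMSG.py without the uncertainty-weighting flag.
        return [arg for arg in cmd if arg != "--use_UW"]
    if algorithm in ("SACN", "EDAC"):
        # Only the script name changes; all other arguments are left as given.
        return [f"{algorithm}.py" if arg == "UWMSG.py" else arg for arg in cmd]
    raise ValueError(f"unknown algorithm: {algorithm}")


base = ["python", "UWMSG.py", "--random_corruption", "--corruption_reward",
        "--use_default_parameter", "--env_name", "halfcheetah-medium-v2",
        "--seed", "0", "--use_UW"]
for name in ("UWMSG", "MSG", "SACN", "EDAC"):
    print(name, "->", " ".join(to_baseline(base, name)))
```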

Citation

If you find our work helpful for your research, please cite:

@inproceedings{
ye2023corruptionrobust,
title={Corruption-Robust Offline Reinforcement Learning with General Function Approximation},
author={Ye, Chenlu and Yang, Rui and Gu, Quanquan and Zhang, Tong},
booktitle={Advances in Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=K9M7XNS9BX}
}
