-
Notifications
You must be signed in to change notification settings - Fork 1
/
normalizer.py
33 lines (28 loc) · 1.61 KB
/
normalizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import numpy as np
from dataset import get_dataset
env_names = ["hopper", "walker2d", "ant", "halfcheetah"]
def normalize(env_name):
TS = torch.load("data/"+env_name+"/TS-read-again-unnormalized.pt")
TA = torch.load("data/"+env_name+"/TA-read-again-unnormalized.pt")
torch.save(TS, 'data/'+env_name+"/TS-unnormalized.pt")
torch.save(TA, "data/"+env_name+"/TA-unnormalized.pt")
states_TA, actions_TA, next_states_TA, terminals_TA, steps_TA, is_TS_TA = get_dataset(TA)
states_TS, actions_TS, next_states_TS, terminals_TS, steps_TS, is_TS_TS = get_dataset(TS)
TA_mean, TA_std = states_TA.mean(dim=0), states_TA.std(dim=0) + 1e-10
print("mean:", TA_mean, "std:", TA_std)
states_TA = (states_TA - TA_mean.view(1, -1)) / TA_std.view(1, -1)
states_TS = (states_TS - TA_mean.view(1, -1)) / TA_std.view(1, -1)
next_states_TA = (next_states_TA - TA_mean.view(1, -1)) / TA_std.view(1, -1)
next_states_TS = (next_states_TS - TA_mean.view(1, -1)) / TA_std.view(1, -1)
assert len(TS) == states_TS.shape[0] and len(TA) == states_TA.shape[0], "Error!"
#print(states_TS[0], TS[0]["state"])
#exit(0)
for i in range(len(TS)):
TS[i]["state"], TS[i]["next_state"] = states_TS[i].cpu().numpy(), next_states_TS[i].cpu().numpy()
for i in range(len(TA)):
TA[i]["state"], TA[i]["next_state"] = states_TA[i].cpu().numpy(), next_states_TA[i].cpu().numpy()
torch.save(TS, 'data/'+env_name+"/TS-normalized.pt")
torch.save(TA, "data/"+env_name+"/TA-normalized.pt")
for env_name in env_names:
normalize(env_name)