In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch.nn as nn
import torch.nn.functional as F
import gym
import torch
import argparse
from environment import atari_env
from utils import read_config
from torchvision import transforms
import numpy as np
from torch.optim import Adam
from environment import atari_env
from rnet import RepresentNet, TDClass
# Seed every RNG this notebook draws from: torch (network weight init) and
# NumPy's legacy global RNG (np.random.randint is used for all batch sampling
# in the training cell). Seeding only torch left the sampled batches, and hence
# the whole training run, non-reproducible.
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

parser = argparse.ArgumentParser(description='A3C')
parser.add_argument(
    '--env',
    default='SpaceInvaders-v0',
    metavar='ENV',
    help='environment to train on (default: SpaceInvaders-v0)')
parser.add_argument(
    '--env-config',
    default='config.json',
    metavar='EC',
    help='environment to crop and resize info (default: config.json)')
parser.add_argument(
    '--skip-rate',
    type=int,
    default=4,
    metavar='SR',
    help='frame skip rate (default: 4)')
parser.add_argument(
    '--max-episode-length',
    type=int,
    default=10000,
    metavar='M',
    help='maximum length of an episode (default: 10000)')
# Parse an empty argv so only the defaults are used; inside a notebook kernel
# sys.argv holds the kernel's own arguments, not ours.
args = parser.parse_args([])
# NOTE(review): the training cell loads pong-*.npy traces while the default env
# here is SpaceInvaders-v0 — confirm which environment is intended.

# Start from the "Default" section, then override with the section of every
# key that occurs as a substring of the env name (the last matching key wins).
setup_json = read_config(args.env_config)
env_conf = setup_json["Default"]
for key in setup_json.keys():
    if key in args.env:
        env_conf = setup_json[key]


In [None]:
# Sampling hyper-parameters. These touch no RNG, so defining them first does
# not perturb the seeded network initialisation further down.
batch_size = 64
# Temporal-distance classes: entry c is a half-open [low, high) range of frame
# gaps (it is passed to np.random.randint, whose upper bound is exclusive), and
# c itself is the label the classifier has to predict.
td_class = [
    (0, 1),
    (1, 2),
    (2, 3),
    (3, 5),
    (5, 7),
    (7, 9),
]
trace = []  # frame trace; (re)built by the loading step below

# Prefer the GPU when one is visible (CUDA_VISIBLE_DEVICES is pinned to "0" in
# the first cell). To force CPU execution use torch.device("cpu") unconditionally.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Keep RepresentNet constructed before TDClass: constructors presumably draw
# their initial weights from torch's seeded RNG, so order affects the init.
r_net = RepresentNet().to(device)
c_net = TDClass().to(device)
optimizer_r, optimizer_c = (Adam(net.parameters(), lr=5e-5) for net in (r_net, c_net))
loss_fn = nn.CrossEntropyLoss()

# NOTE(review): not referenced by any visible cell.
transform = transforms.ToTensor()

# Load the pre-recorded frame traces.
# NOTE(review): the .npy files are presumably produced offline by a rollout that
# steps atari_env with random actions (env.action_space.sample()) and skips the
# first 20 frames after each reset because they do not render completely —
# confirm against the script that actually wrote traces/save/.
N_TRACE_FILES = 11
TRACE_PATTERN = "traces/save/pong-s-{}.npy"

# Each file becomes a float32 CPU tensor; all files are concatenated along
# dim 0 into one long trace (batches are moved to `device` only when sampled).
trace = torch.cat(
    [torch.from_numpy(np.load(TRACE_PATTERN.format(idx))).float()
     for idx in range(N_TRACE_FILES)],
    dim=0,
)
print("trace:", trace.shape)

# NOTE(review): files (and any episode resets inside them) are concatenated with
# no boundary markers, so a sampled frame pair may straddle two episodes.

# The sampler calls np.random.randint(0, len(trace) - TD - batch_size), which
# needs a strictly positive upper bound for the largest possible TD. Fail
# early with a readable message instead of a ValueError deep in training.
max_td = max(high for _, high in td_class) - 1
assert len(trace) - max_td - batch_size > 0, (
    "trace has {} frames; need more than {} to sample a batch".format(
        len(trace), max_td + batch_size))

def _sample_pair_batch(trace, td_class, batch_size, device):
    """Draw one batch of frame pairs separated by a random temporal distance.

    A TD class index ``range_c`` is drawn uniformly, then a frame gap ``TD``
    uniformly from that class's half-open [low, high) range, then a start index
    ``begin``. ``former`` is ``trace[begin:begin+batch_size]`` and ``latter`` is
    the same window shifted by ``TD`` frames, so row k of the two tensors is a
    pair of frames exactly ``TD`` steps apart.

    Returns:
        (former, latter, target, range_c, TD), where ``target`` is a long tensor
        of length ``batch_size`` on ``device`` filled with ``range_c``.
    """
    range_c = np.random.randint(0, len(td_class))
    low, high = td_class[range_c]
    TD = np.random.randint(low, high)
    begin = np.random.randint(0, len(trace) - TD - batch_size)
    former = trace[begin:begin + batch_size].to(device)
    latter = trace[begin + TD:begin + TD + batch_size].to(device)
    target = torch.full((batch_size,), int(range_c), dtype=torch.long, device=device)
    return former, latter, target, range_c, TD


def _evaluate(n_batches=50):
    """Return classification accuracy over ``n_batches`` freshly sampled batches.

    Batches are drawn from the same ``trace`` used for training, so this is
    training-distribution accuracy, not held-out accuracy. Runs under
    ``torch.no_grad()`` (no autograd graph is built) with both networks in eval
    mode; train mode is restored afterwards even if an error occurs.
    """
    r_net.eval()
    c_net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for _ in range(n_batches):
                former, latter, target, _, _ = _sample_pair_batch(
                    trace, td_class, batch_size, device)
                output = c_net(r_net(former), r_net(latter), False)
                correct += (output.argmax(dim=1) == target).sum().item()
                total += batch_size
    finally:
        r_net.train()
        c_net.train()
    return correct / total


N_STEPS = 5000000       # number of minibatch updates
LOG_EVERY = 500         # print grad / loss every LOG_EVERY steps
EVAL_EVERY = 10000      # print raw output + accuracy every EVAL_EVERY steps
DEBUG_EVERY = 30        # pass the debug flag to c_net every DEBUG_EVERY steps
GRAD_CLIP_NORM = None   # e.g. 20.0 to clip each network's gradient norm

print("begin train")
# Training: each iteration is one minibatch update (the loop variable keeps its
# historical name ``epoch``).
for epoch in range(N_STEPS):
    former, latter, target, range_c, TD = _sample_pair_batch(
        trace, td_class, batch_size, device)

    rep_f, rep_l = r_net(former), r_net(latter)
    # Third argument is a flag forwarded to TDClass; its meaning is defined in
    # rnet (presumably debug printing) — TODO confirm.
    output = c_net(rep_f, rep_l, epoch % DEBUG_EVERY == 0)
    loss = loss_fn(output, target)

    optimizer_r.zero_grad()
    optimizer_c.zero_grad()
    loss.backward()
    if epoch % LOG_EVERY == 0:
        print("grad", r_net.conv1.weight.grad.mean().item())
    if GRAD_CLIP_NORM is not None:
        nn.utils.clip_grad_norm_(r_net.parameters(), GRAD_CLIP_NORM)
        nn.utils.clip_grad_norm_(c_net.parameters(), GRAD_CLIP_NORM)
    optimizer_r.step()
    optimizer_c.step()

    if epoch % LOG_EVERY == 0:
        print("range {} TD {} loss {:.3f}".format(range_c, TD, loss.item()))
    if epoch % EVAL_EVERY == 0:
        print(output)
        print("accu_rate {:.4f}".format(_evaluate()))

In [None]:
# Save both networks as whole pickled modules (not state_dicts).
# NOTE(review): loading a pickled module requires the `rnet` classes to be
# importable, and recent PyTorch needs torch.load(..., weights_only=False);
# consider saving state_dict() if the consumers can be updated.
os.makedirs("pre_models", exist_ok=True)  # torch.save fails on a missing dir
for net, path in ((r_net, "pre_models/r_net_5wsam.pkl"),
                  (c_net, "pre_models/c_net_5wsam.pkl")):
    torch.save(net.cpu(), path)
    # nn.Module.cpu() moves parameters in place; move the live model back so
    # later cells can keep using it on `device`.
    net.to(device)