In [1]:
import gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import rl_utils
from datetime import datetime

In [2]:
from SheepDogEnv import SheepDogEnv


def sample_expert_data(n_episode):
    states = []
    actions = []
    with tqdm(total=n_episode, desc="进度条") as pbar:
        for i in range(n_episode):
            done = False
            while not done:
                _st = env._get_obs_array()
                L1 = _st[0]
                L2 = env.circle_R
                theta1 = env.sheep_polar_coor[1]
                theta2 = env.dog_theta[0]
                if(L1**2 + L2**2 - 2*L1*L2*np.cos(theta2-theta1)) < 0:
                    break
                L3 = np.sqrt(L1**2 + L2**2 - 2*L1*L2*np.cos(theta2-theta1))
                if(L3 == 0):
                    break
                theta3 = np.arcsin(L2/L3*np.sin(theta2-theta1))
                if(np.abs(theta2-theta1) < np.arccos(L1/L2)):
                    theta3 = np.pi/2 if theta3 > 0 else -np.pi/2
                action = theta3

                # if(np.abs(_st[3]/env.dog_theta_v) > (env.circle_R-_st[0])/env.sheep_v + 1):
                #     action = 0

                _st[0] /= env.circle_R
                states.append(_st)
                actions.append(action)
                _, _, done, _, _ = env.step(action)
                # print(i, _st, action)
            pbar.update(1)
            env.reset()
            env.sheep_polar_coor = np.array(
                [np.random.randint(env.sheep_v, env.circle_R-1), np.random.random()*np.pi*2])
            env.dog_theta = np.array(
                [env.sheep_polar_coor[1]+np.pi/2+np.pi*np.random.random()])
    return np.array(states), np.array(actions)


env = SheepDogEnv(circle_R=350, sheep_v=30, dog_v=80,
                  sec_split_n=5, store_mode=False, render_mode=False)

random.seed(0)
n_episode = 20000
expert_s, expert_a = sample_expert_data(n_episode)

print(len(expert_s))
# n_samples = 30  # 采样30个数据
# random_index = random.sample(range(expert_s.shape[0]), n_samples)
# expert_s = expert_s[random_index]
# expert_a = expert_a[random_index]


进度条: 100%|██████████| 20000/20000 [02:34<00:00, 129.08it/s]


3889677


In [55]:
# index = expert_s[:,3]!=0

# print(len(expert_s))
# expert_a=expert_a[index]
# expert_s=expert_s[index]
print(len(expert_s))

np.save("expert_a-{}.npy".format(str(datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))),expert_a)
np.save("expert_s-{}.npy".format(str(datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))),expert_s)

1135509


In [38]:
import os

def list_files_with_prefix(directory, prefix):
    file_list = [file for file in os.listdir(directory) if file.startswith(prefix)]
    return file_list

expert_s_files = list_files_with_prefix('./', 'expert_s')
expert_a_files = list_files_with_prefix('./', 'expert_a')

# expert_s=np.array([[0,0,0,0]])
# expert_a=np.array([])
for expert_s_file in  expert_s_files:
    tmp_expert_s=np.load(expert_s_file)
    expert_s=np.concatenate((expert_s,tmp_expert_s),axis=0)
# expert_s=expert_s[1:]
for expert_a_file in  expert_a_files:
    tmp_expert_a=np.load(expert_a_file)
    expert_a=np.concatenate((expert_a,tmp_expert_a),axis=0)

print(len(expert_s),len(expert_a))

10388787 10388788


In [21]:
expert_s=expert_s[:,[0,3]]

In [3]:
class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = np.pi/2*torch.tanh(self.fc3(x))
        return x
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

In [22]:
import time


class BehaviorClone:
    def __init__(self, state_dim, hidden_dim, action_dim, lr):
        self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    def learn(self, states, actions):
        states = torch.tensor(states, dtype=torch.float).to(device)
        actions = torch.tensor(
            actions, dtype=torch.float).view(-1, 1).to(device)
        mse_loss = F.mse_loss(self.policy(states), actions).float()
        # print(mse_loss)

        self.optimizer.zero_grad()
        mse_loss.backward()
        self.optimizer.step()

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(device)
        action = self.policy(state).cpu().detach().numpy()
        return action[0]


def test_agent(agent, env, n_episode):
    sheep_coor_test = [0, 1.7, np.pi, 4.5, 2*np.pi-0.1]
    dog_theta_test = [0, 1.7, np.pi, 4.5, 2*np.pi-0.1]
    specific_return_list = []
    catched_num = 0
    catch_init_state_list=[]
    for s in sheep_coor_test:
        for d in dog_theta_test:
            env.reset()
            env.sheep_polar_coor = np.array([env.sheep_v, s])
            env.dog_theta = np.array([d])
            episode_return = 0
            for i in range(2000):
                _st = env._get_obs_array()
                _st[0] /= env.circle_R
                action = agent.take_action(_st[[0,3]])[0]
                _, reward, done, _, _ = env.step(action)
                episode_return += reward
                # print(_st, action, reward)
                if done:
                    break
            specific_return_list.append(episode_return)
            # print("s:{},d:{},最终得分：{}".format(s, d, episode_return))
            if reward < -900:
                catched_num += 1
                catch_init_state_list.append((s,d))
    print("sheep has been catched: {}/{},init_states: {}".format(len(sheep_coor_test)
          * len(dog_theta_test), catched_num,",".join([str(i) for i in catch_init_state_list])))
    if(catched_num==0):
        torch.save(agent.policy.state_dict(),"bc-"+str(datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))
    return np.mean(specific_return_list)


torch.manual_seed(0)
np.random.seed(0)

lr = 5e-6
env = SheepDogEnv(circle_R=350, sheep_v=30, dog_v=80,
                  sec_split_n=5, store_mode=False, render_mode=False)
bc_agent = BehaviorClone(
    2, 128, env.action_space.shape[0], lr)
n_iterations = 20000
batch_size = 512
test_returns = []

with tqdm(total=n_iterations, desc="进度条") as pbar:
    for i in range(n_iterations):
        sample_indices = np.random.randint(low=0,
                                           high=expert_s.shape[0],
                                           size=batch_size)
        t1 = time.time()
        bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])
        if i > 9000 and (i+1) % 100 == 0:
            current_return = test_agent(bc_agent, env, 5)
            test_returns.append(current_return)
            # if (i + 1) % 100 == 0:
            #     pbar.set_postfix(
            #         {'return': '%.3f' % np.mean(test_returns[-10:])})
        pbar.update(1)


  state = torch.tensor([state], dtype=torch.float).to(device)
进度条:  46%|████▌     | 9162/20000 [00:14<00:58, 183.85it/s]

sheep has been catched: 25/0,init_states: 


进度条:  46%|████▋     | 9262/20000 [00:15<01:14, 143.64it/s]

sheep has been catched: 25/0,init_states: 


进度条:  47%|████▋     | 9361/20000 [00:16<01:23, 127.17it/s]

sheep has been catched: 25/0,init_states: 


进度条:  47%|████▋     | 9461/20000 [00:17<01:25, 123.59it/s]

sheep has been catched: 25/0,init_states: 


进度条:  48%|████▊     | 9558/20000 [00:18<01:27, 119.85it/s]

sheep has been catched: 25/0,init_states: 


进度条:  48%|████▊     | 9658/20000 [00:19<01:26, 119.27it/s]

sheep has been catched: 25/0,init_states: 


进度条:  49%|████▉     | 9754/20000 [00:20<01:29, 115.06it/s]

sheep has been catched: 25/0,init_states: 


进度条:  49%|████▉     | 9861/20000 [00:21<01:23, 121.20it/s]

sheep has been catched: 25/0,init_states: 


进度条:  50%|████▉     | 9959/20000 [00:22<01:24, 119.44it/s]

sheep has been catched: 25/0,init_states: 


进度条:  50%|█████     | 10067/20000 [00:23<01:15, 131.30it/s]

sheep has been catched: 25/0,init_states: 


进度条:  51%|█████     | 10174/20000 [00:24<01:10, 138.66it/s]

sheep has been catched: 25/0,init_states: 


进度条:  51%|█████▏    | 10273/20000 [00:25<01:14, 130.75it/s]

sheep has been catched: 25/0,init_states: 


进度条:  52%|█████▏    | 10381/20000 [00:26<01:13, 130.05it/s]

sheep has been catched: 25/0,init_states: 


进度条:  52%|█████▏    | 10479/20000 [00:27<01:19, 120.09it/s]

sheep has been catched: 25/0,init_states: 


进度条:  53%|█████▎    | 10575/20000 [00:28<01:20, 117.36it/s]

sheep has been catched: 25/0,init_states: 


进度条:  53%|█████▎    | 10671/20000 [00:29<01:19, 116.80it/s]

sheep has been catched: 25/0,init_states: 


进度条:  54%|█████▍    | 10778/20000 [00:30<01:12, 127.14it/s]

sheep has been catched: 25/0,init_states: 


进度条:  54%|█████▍    | 10885/20000 [00:31<01:08, 132.75it/s]

sheep has been catched: 25/0,init_states: 


进度条:  55%|█████▍    | 10987/20000 [00:32<01:07, 132.63it/s]

sheep has been catched: 25/0,init_states: 


进度条:  55%|█████▌    | 11091/20000 [00:33<01:06, 134.59it/s]

sheep has been catched: 25/0,init_states: 


进度条:  56%|█████▌    | 11192/20000 [00:34<01:08, 129.18it/s]

sheep has been catched: 25/0,init_states: 


进度条:  56%|█████▋    | 11291/20000 [00:35<01:12, 120.50it/s]

sheep has been catched: 25/0,init_states: 


进度条:  57%|█████▋    | 11385/20000 [00:36<01:15, 113.71it/s]

sheep has been catched: 25/0,init_states: 


进度条:  57%|█████▋    | 11483/20000 [00:37<01:14, 113.89it/s]

sheep has been catched: 25/0,init_states: 


进度条:  58%|█████▊    | 11582/20000 [00:38<01:13, 114.69it/s]

sheep has been catched: 25/0,init_states: 


进度条:  58%|█████▊    | 11682/20000 [00:39<01:12, 114.43it/s]

sheep has been catched: 25/0,init_states: 


进度条:  59%|█████▉    | 11768/20000 [00:40<01:17, 106.46it/s]

sheep has been catched: 25/0,init_states: 


进度条:  59%|█████▉    | 11856/20000 [00:41<01:18, 103.21it/s]

sheep has been catched: 25/0,init_states: 


进度条:  60%|█████▉    | 11959/20000 [00:42<01:13, 109.97it/s]

sheep has been catched: 25/0,init_states: 


进度条:  60%|██████    | 12066/20000 [00:43<01:05, 122.06it/s]

sheep has been catched: 25/0,init_states: 


进度条:  61%|██████    | 12172/20000 [00:44<00:59, 130.52it/s]

sheep has been catched: 25/0,init_states: 


进度条:  61%|██████▏   | 12273/20000 [00:45<01:00, 127.02it/s]

sheep has been catched: 25/0,init_states: 


进度条:  62%|██████▏   | 12365/20000 [00:46<01:09, 109.56it/s]

sheep has been catched: 25/0,init_states: 


进度条:  62%|██████▏   | 12459/20000 [00:47<01:08, 110.53it/s]

sheep has been catched: 25/0,init_states: 


进度条:  63%|██████▎   | 12568/20000 [00:48<01:01, 121.70it/s]

sheep has been catched: 25/0,init_states: 


进度条:  63%|██████▎   | 12674/20000 [00:49<00:56, 129.48it/s]

sheep has been catched: 25/0,init_states: 


进度条:  64%|██████▍   | 12781/20000 [00:50<00:54, 132.05it/s]

sheep has been catched: 25/0,init_states: 


进度条:  64%|██████▍   | 12883/20000 [00:51<00:58, 120.77it/s]

sheep has been catched: 25/0,init_states: 


进度条:  65%|██████▍   | 12979/20000 [00:53<01:01, 114.08it/s]

sheep has been catched: 25/0,init_states: 


进度条:  65%|██████▌   | 13078/20000 [00:54<01:00, 114.26it/s]

sheep has been catched: 25/0,init_states: 


进度条:  66%|██████▌   | 13179/20000 [00:55<00:59, 113.82it/s]

sheep has been catched: 25/0,init_states: 


进度条:  66%|██████▋   | 13278/20000 [00:56<01:00, 111.66it/s]

sheep has been catched: 25/0,init_states: 


进度条:  67%|██████▋   | 13373/20000 [00:57<01:01, 108.32it/s]

sheep has been catched: 25/0,init_states: 


进度条:  67%|██████▋   | 13473/20000 [00:58<00:58, 110.75it/s]

sheep has been catched: 25/0,init_states: 


进度条:  68%|██████▊   | 13567/20000 [00:59<00:59, 107.33it/s]

sheep has been catched: 25/0,init_states: 


进度条:  68%|██████▊   | 13665/20000 [01:00<00:59, 106.75it/s]

sheep has been catched: 25/0,init_states: 


进度条:  69%|██████▉   | 13764/20000 [01:01<00:58, 106.85it/s]

sheep has been catched: 25/0,init_states: 


进度条:  69%|██████▉   | 13865/20000 [01:02<00:56, 108.67it/s]

sheep has been catched: 25/0,init_states: 


进度条:  70%|██████▉   | 13966/20000 [01:04<00:56, 107.46it/s]

sheep has been catched: 25/0,init_states: 


进度条:  70%|███████   | 14064/20000 [01:05<00:56, 104.22it/s]

sheep has been catched: 25/0,init_states: 


进度条:  71%|███████   | 14159/20000 [01:06<00:57, 101.33it/s]

sheep has been catched: 25/1,init_states: (6.183185307179587, 0)


进度条:  71%|███████▏  | 14260/20000 [01:07<00:58, 98.96it/s] 

sheep has been catched: 25/1,init_states: (6.183185307179587, 0)


进度条:  72%|███████▏  | 14356/20000 [01:08<00:56, 99.60it/s]

sheep has been catched: 25/1,init_states: (6.183185307179587, 0)


进度条:  72%|███████▏  | 14467/20000 [01:09<00:49, 110.86it/s]

sheep has been catched: 25/1,init_states: (6.183185307179587, 0)


进度条:  73%|███████▎  | 14574/20000 [01:11<00:46, 117.08it/s]

sheep has been catched: 25/6,init_states: (0, 0),(1.7, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 6.183185307179587)


进度条:  73%|███████▎  | 14674/20000 [01:12<00:46, 113.39it/s]

sheep has been catched: 25/7,init_states: (0, 0),(0, 6.183185307179587),(1.7, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 6.183185307179587)


进度条:  74%|███████▍  | 14779/20000 [01:13<00:45, 114.69it/s]

sheep has been catched: 25/6,init_states: (0, 0),(1.7, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 6.183185307179587)


进度条:  74%|███████▍  | 14875/20000 [01:14<00:49, 103.99it/s]

sheep has been catched: 25/7,init_states: (0, 0),(0, 6.183185307179587),(1.7, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 6.183185307179587)


进度条:  75%|███████▍  | 14971/20000 [01:15<00:51, 96.87it/s] 

sheep has been catched: 25/7,init_states: (0, 0),(0, 6.183185307179587),(1.7, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 6.183185307179587)


进度条:  75%|███████▌  | 15069/20000 [01:16<00:49, 99.60it/s]

sheep has been catched: 25/7,init_states: (0, 0),(0, 6.183185307179587),(1.7, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 6.183185307179587)


进度条:  76%|███████▌  | 15162/20000 [01:18<00:51, 94.50it/s]

sheep has been catched: 25/7,init_states: (0, 0),(0, 6.183185307179587),(1.7, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 6.183185307179587)


进度条:  76%|███████▋  | 15259/20000 [01:19<00:50, 93.96it/s]

sheep has been catched: 25/7,init_states: (0, 0),(0, 6.183185307179587),(1.7, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 6.183185307179587)


进度条:  77%|███████▋  | 15354/20000 [01:20<00:50, 91.20it/s]

sheep has been catched: 25/8,init_states: (0, 0),(0, 6.183185307179587),(1.7, 1.7),(3.141592653589793, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 6.183185307179587)


进度条:  77%|███████▋  | 15463/20000 [01:22<00:48, 94.22it/s]

sheep has been catched: 25/9,init_states: (0, 0),(0, 6.183185307179587),(1.7, 1.7),(3.141592653589793, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 6.183185307179587)


进度条:  78%|███████▊  | 15564/20000 [01:23<00:43, 102.44it/s]

sheep has been catched: 25/13,init_states: (0, 0),(0, 4.5),(0, 6.183185307179587),(1.7, 0),(1.7, 1.7),(1.7, 6.183185307179587),(3.141592653589793, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 4.5),(6.183185307179587, 6.183185307179587)


进度条:  78%|███████▊  | 15663/20000 [01:24<00:40, 105.84it/s]

sheep has been catched: 25/13,init_states: (0, 0),(0, 4.5),(0, 6.183185307179587),(1.7, 0),(1.7, 1.7),(1.7, 6.183185307179587),(3.141592653589793, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 4.5),(6.183185307179587, 6.183185307179587)


进度条:  79%|███████▉  | 15758/20000 [01:25<00:41, 101.65it/s]

sheep has been catched: 25/13,init_states: (0, 0),(0, 4.5),(0, 6.183185307179587),(1.7, 0),(1.7, 1.7),(1.7, 6.183185307179587),(3.141592653589793, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 4.5),(6.183185307179587, 6.183185307179587)


进度条:  79%|███████▉  | 15850/20000 [01:26<00:44, 92.35it/s] 

sheep has been catched: 25/13,init_states: (0, 0),(0, 4.5),(0, 6.183185307179587),(1.7, 0),(1.7, 1.7),(1.7, 6.183185307179587),(3.141592653589793, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 4.5),(6.183185307179587, 6.183185307179587)


进度条:  80%|███████▉  | 15954/20000 [01:28<00:43, 93.55it/s]

sheep has been catched: 25/13,init_states: (0, 0),(0, 4.5),(0, 6.183185307179587),(1.7, 0),(1.7, 1.7),(1.7, 6.183185307179587),(3.141592653589793, 1.7),(3.141592653589793, 3.141592653589793),(4.5, 3.141592653589793),(4.5, 4.5),(6.183185307179587, 0),(6.183185307179587, 4.5),(6.183185307179587, 6.183185307179587)


进度条:  80%|███████▉  | 15999/20000 [01:28<00:22, 180.99it/s]


KeyboardInterrupt: 

In [6]:
def test_agent(agent, env, n_episode):
    sheep_coor_test = [0, 1.7, np.pi, 4.5, 2*np.pi-0.1]
    dog_theta_test = [0, 1.7, np.pi, 4.5, 2*np.pi-0.1]
    specific_return_list = []
    catched_num = 0
    catch_init_state_list=[]
    for s in sheep_coor_test:
        for d in dog_theta_test:
            env.reset()
            env.sheep_polar_coor = np.array([env.sheep_v, s])
            env.dog_theta = np.array([d])
            episode_return = 0
            for i in range(2000):
                _st = env._get_obs_array()
                _st[0] /= env.circle_R
                action = agent.take_action(_st)[0]
                _, reward, done, _, _ = env.step(action)
                episode_return += reward
                # print(_st, action, reward)
                if done:
                    break
            specific_return_list.append(episode_return)
            print("s:{},d:{},最终得分：{}".format(s, d, episode_return))
            if reward < -900:
                catched_num += 1
                catch_init_state_list.append((s,d))
    print("sheep has been catched: {}/{},init_states: {}".format(len(sheep_coor_test)
          * len(dog_theta_test), catched_num,",".join([str(i) for i in catch_init_state_list])))
    return np.mean(specific_return_list)


test_agent(bc_agent, env, 5)


s:0,d:0,最终得分：-1324.328041493539
s:0,d:1.7,最终得分：613.714447679481
s:0,d:3.141592653589793,最终得分：755.8463644482215
s:0,d:4.5,最终得分：774.964312341708
s:0,d:6.183185307179587,最终得分：756.8657026697047
s:1.7,d:0,最终得分：773.0200719665925
s:1.7,d:1.7,最终得分：732.752254705843
s:1.7,d:3.141592653589793,最终得分：632.8096865214945
s:1.7,d:4.5,最终得分：777.4381304789692
s:1.7,d:6.183185307179587,最终得分：771.7153007977226
s:3.141592653589793,d:0,最终得分：771.8395152019557
s:3.141592653589793,d:1.7,最终得分：746.3676194663918
s:3.141592653589793,d:3.141592653589793,最终得分：772.5737135569807
s:3.141592653589793,d:4.5,最终得分：744.5173100029378
s:3.141592653589793,d:6.183185307179587,最终得分：771.821691564696
s:4.5,d:0,最终得分：770.7485704354729
s:4.5,d:1.7,最终得分：751.5368815301912
s:4.5,d:3.141592653589793,最终得分：762.3170737973666
s:4.5,d:4.5,最终得分：766.9470525297797
s:4.5,d:6.183185307179587,最终得分：770.2244192137242
s:6.183185307179587,d:0,最终得分：-1317.9678791263775
s:6.183185307179587,d:1.7,最终得分：484.1831442943279
s:6.183185307179587,d:3.141592653589793,最

573.9550423604549

In [15]:
torch.save(bc_agent.policy.state_dict(),"bc_network-1")

In [80]:
from SheepDogEnv import SheepDogEnv


def sample_expert_data(n_episode):
    states = []
    actions = []
    with tqdm(total=n_episode, desc="进度条") as pbar:
        for i in range(n_episode):
            done = False
            while not done:
                _st = env._get_obs_array()
                L1 = _st[0]
                L2 = env.circle_R
                theta1 = _st[1]
                theta2 = _st[2]
                if(L1**2 + L2**2 - 2*L1*L2*np.cos(theta2-theta1))<0:
                    break
                L3 = np.sqrt(L1**2 + L2**2 - 2*L1*L2*np.cos(theta2-theta1))
                if(L3==0):
                    break
                theta3 = np.arcsin(L2/L3*np.sin(theta2-theta1))
                if(np.abs(theta2-theta1) < np.arccos(L1/L2)):
                    theta3 = np.pi/2 if theta3 > 0 else -np.pi/2
                action = theta3

                # if(np.abs(_st[3]/env.dog_theta_v) > (env.circle_R-_st[0])/env.sheep_v + 1):
                #     action = 0

                _st[0] /= env.circle_R
                states.append(_st)
                actions.append(action)
                _, _, done, _, _ = env.step(action)
                # print(i, _st, action)
            pbar.update(1)
            env.reset()
            env.sheep_polar_coor = np.array(
                [np.random.randint(env.sheep_v, env.circle_R-1), 0.0])
            env.dog_theta = np.array([0])
    return np.array(states), np.array(actions)


env = SheepDogEnv(circle_R=350, sheep_v=30, dog_v=80,
                  sec_split_n=5, store_mode=False, render_mode=False)

random.seed(0)
n_episode = 500
expert_s, expert_a = sample_expert_data(n_episode)

print(len(expert_s))
# n_samples = 30  # 采样30个数据
# random_index = random.sample(range(expert_s.shape[0]), n_samples)
# expert_s = expert_s[random_index]
# expert_a = expert_a[random_index]


进度条: 100%|██████████| 500/500 [00:29<00:00, 17.03it/s]


743871


In [81]:
import time


class BehaviorClone:
    def __init__(self, state_dim, hidden_dim, action_dim, lr):
        self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    def learn(self, states, actions):
        states = torch.tensor(states, dtype=torch.float).to(device)
        actions = torch.tensor(
            actions, dtype=torch.float).view(-1, 1).to(device)
        mse_loss = F.mse_loss(self.policy(states), actions).float()
        # print(mse_loss)

        self.optimizer.zero_grad()
        mse_loss.backward()
        self.optimizer.step()

    def take_action(self, state):
        state = torch.tensor([state], dtype=torch.float).to(device)
        action = self.policy(state).cpu().detach().numpy()
        return action[0]


def test_agent(agent, env, n_episode):
    sheep_coor_test = [0, 1.7, np.pi, 4.5, 2*np.pi-0.1]
    dog_theta_test = [0, 1.7, np.pi, 4.5, 2*np.pi-0.1]
    specific_return_list = []
    catched_num = 0
    catch_init_state_list=[]
    for s in sheep_coor_test:
        for d in dog_theta_test:
            env.reset()
            env.sheep_polar_coor = np.array([env.sheep_v, s])
            env.dog_theta = np.array([d])
            episode_return = 0
            for i in range(2000):
                _st = env._get_obs_array()
                _st[0] /= env.circle_R
                action = agent.take_action(_st)[0]
                _, reward, done, _, _ = env.step(action)
                episode_return += reward
                # print(_st, action, reward)
                if done:
                    break
            specific_return_list.append(episode_return)
            # print("s:{},d:{},最终得分：{}".format(s, d, episode_return))
            if reward < -900:
                catched_num += 1
                catch_init_state_list.append((s,d))
    print("sheep has been catched: {}/{},init_states: {}".format(len(sheep_coor_test)
          * len(dog_theta_test), catched_num,",".join([str(i) for i in catch_init_state_list])))
    if(catched_num==0):
        torch.save(agent.policy.state_dict(),"bc-"+str(datetime.now().strftime("%Y-%m-%d-%H-%M-%S")))
    return np.mean(specific_return_list)


torch.manual_seed(0)
np.random.seed(0)

lr = 5e-6
env = SheepDogEnv(circle_R=350, sheep_v=30, dog_v=80,
                  sec_split_n=5, store_mode=False, render_mode=False)
# bc_agent = BehaviorClone(
#     env.observation_space.shape[0], 128, env.action_space.shape[0], lr)
n_iterations = 2000
batch_size = 512
test_returns = []

with tqdm(total=n_iterations, desc="进度条") as pbar:
    for i in range(n_iterations):
        sample_indices = np.random.randint(low=0,
                                           high=expert_s.shape[0],
                                           size=batch_size)
        t1 = time.time()
        bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])
        if i > 900 and (i+1) % 200 == 0:
            current_return = test_agent(bc_agent, env, 5)
            test_returns.append(current_return)
            # if (i + 1) % 100 == 0:
            #     pbar.set_postfix(
            #         {'return': '%.3f' % np.mean(test_returns[-10:])})
        # pbar.update(1)


进度条:   0%|          | 0/2000 [00:00<?, ?it/s]

sheep has been catched: 25/25,init_states: (0, 0),(0, 1.7),(0, 3.141592653589793),(0, 4.5),(0, 6.183185307179587),(1.7, 0),(1.7, 1.7),(1.7, 3.141592653589793),(1.7, 4.5),(1.7, 6.183185307179587),(3.141592653589793, 0),(3.141592653589793, 1.7),(3.141592653589793, 3.141592653589793),(3.141592653589793, 4.5),(3.141592653589793, 6.183185307179587),(4.5, 0),(4.5, 1.7),(4.5, 3.141592653589793),(4.5, 4.5),(4.5, 6.183185307179587),(6.183185307179587, 0),(6.183185307179587, 1.7),(6.183185307179587, 3.141592653589793),(6.183185307179587, 4.5),(6.183185307179587, 6.183185307179587)
sheep has been catched: 25/24,init_states: (0, 0),(0, 1.7),(0, 3.141592653589793),(0, 4.5),(0, 6.183185307179587),(1.7, 0),(1.7, 1.7),(1.7, 3.141592653589793),(1.7, 4.5),(1.7, 6.183185307179587),(3.141592653589793, 0),(3.141592653589793, 1.7),(3.141592653589793, 3.141592653589793),(3.141592653589793, 4.5),(3.141592653589793, 6.183185307179587),(4.5, 0),(4.5, 3.141592653589793),(4.5, 4.5),(4.5, 6.183185307179587),(6.183

进度条:   0%|          | 0/2000 [00:30<?, ?it/s]

sheep has been catched: 25/25,init_states: (0, 0),(0, 1.7),(0, 3.141592653589793),(0, 4.5),(0, 6.183185307179587),(1.7, 0),(1.7, 1.7),(1.7, 3.141592653589793),(1.7, 4.5),(1.7, 6.183185307179587),(3.141592653589793, 0),(3.141592653589793, 1.7),(3.141592653589793, 3.141592653589793),(3.141592653589793, 4.5),(3.141592653589793, 6.183185307179587),(4.5, 0),(4.5, 1.7),(4.5, 3.141592653589793),(4.5, 4.5),(4.5, 6.183185307179587),(6.183185307179587, 0),(6.183185307179587, 1.7),(6.183185307179587, 3.141592653589793),(6.183185307179587, 4.5),(6.183185307179587, 6.183185307179587)



