### 没有batch_size这个维度的前提下

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CDQN(nn.Module):
    
    def __init__(self, state_size, action_size, N, hidden=[128, 256]):
        super(CDQN, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.N = N
        
        self.fc1 = nn.Linear(state_size, hidden[0])
        self.fc2 = nn.Linear(hidden[0], hidden[1])
        self.fc3 = nn.Linear(hidden[1], N*action_size)
        self.output = nn.LogSoftmax(dim=2)
    
    def forward(self, state):
        x = state
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = x.view(-1, self.action_size, self.N)
        x = self.output(x)
        return x

In [2]:
import numpy as np

def new_distribution_checker(prev_distribution, reward, gamma):
    new_distribution = np.zeros((1,N))
    for j in range(N):
        prob = prev_distribution[0,j] 
        new_val = vals[j] * gamma + reward
        new_val = np.clip(new_val, v_min, v_max)
        lower_index = np.floor((new_val - v_min)/unit).astype('int')
        upper_index = np.minimum(lower_index + 1, N - 1)
        lower_distance_ratio = 1 - (new_val - vals[lower_index])/unit
        new_distribution[0,lower_index] += prob * lower_distance_ratio
        new_distribution[0,upper_index] += prob * (1 - lower_distance_ratio)
    return new_distribution

In [3]:
def new_distribution(prev_distribution, reward, gamma):
    new_vals = vals * gamma + reward
    new_vals = np.clip(new_vals, v_min, v_max)
    lower_indexes  = np.floor((new_vals - v_min)/unit).astype('int')
    lower_distances = 1 - np.minimum((new_vals - vals[lower_indexes])/unit,1)
    upper_indexes = np.minimum(lower_indexes + 1, N-1)
    transition = np.zeros((N,N))
    first_dim = range(N)
    transition[first_dim, lower_indexes] += lower_distances
    transition[first_dim, upper_indexes] += 1 - lower_distances
    return prev_distribution.dot(transition)

In [4]:
N = 51
v_min = -5
v_max = 5
state_size = 8
action_size = 4

unit = (v_max - v_min)/(N - 1)
vals = np.linspace(v_min,v_max,N)
test_state = np.random.standard_normal((1,state_size))
test_net = CDQN(state_size, action_size, N)
with torch.no_grad():
    result = test_net(torch.tensor(test_state,dtype = torch.float32)).exp()
result = result[0, np.random.randint(0,action_size,1),:].numpy()
result.shape

(1, 51)

In [5]:
test_reward = np.random.normal(1,1)
test_gamma = 0.99
new_dict = new_distribution(result, test_reward, test_gamma)
new_dict_checker = new_distribution_checker(result, test_reward, test_gamma)

(np.abs(new_dict - new_dict_checker) <= 0.001).all()

True

In [6]:
import time
t1 = time.time()
for _ in range(10):
    new_distribution(result, test_reward, test_gamma)
t2 = time.time() - t1
print("numpy implement:",t2)

t1 = time.time()
for _ in range(10):
    new_distribution_checker(result, test_reward, test_gamma)
t3 = time.time() - t1
print("python implement:",t3)
print("numpy is {} times to pthon implement".format(t3/t2))

numpy implement: 0.0019927024841308594
python implement: 0.043897151947021484
numpy is 22.028954295285953 times to pthon implement


### 有batch_size的前提下

In [7]:
def new_distribution_checker(prev_distribution, reward, gamma):
    length = len(reward)
    assert prev_distribution.shape[0] == length
    if len(prev_distribution.shape) == 2:
        prev_distribution = np.expand_dims(prev_distribution, 1)
    new_distribution = np.zeros((length, 1, N))
    for i in range(length):
        for j in range(N):
            prob = prev_distribution[i,0,j]
            new_val = vals[j] * gamma + reward[i]
            new_val = np.clip(new_val,v_min,v_max)
            lower_index = np.floor((new_val - v_min)/unit).astype('int')
            upper_index = np.minimum(lower_index+1,N-1)
            lower_distance = 1 - (new_val - vals[lower_index])/unit
            new_distribution[i,0,lower_index] += lower_distance*prob
            new_distribution[i,0,upper_index] += (1-lower_distance)*prob
    return new_distribution

In [8]:
def new_distribution(prev_distribution, reward, gamma):
    length = reward.shape[0]
    assert prev_distribution.shape[0] == length
    reward = reward.reshape(-1,1)
    new_vals = vals.reshape(1,-1) * gamma + reward
    new_vals = np.clip(new_vals,v_min,v_max)
    lower_indexes = np.floor((new_vals-v_min)/unit).astype('int')
    upper_indexes = np.minimum(lower_indexes+1,N-1)
    lower_distance = 1 - np.minimum((new_vals - vals[lower_indexes])/unit,1)
    transition = np.zeros((length,N,N))
    first_dim = np.repeat(range(length),N)
    second_dim = length * list(range(N)) 
    transition[first_dim, second_dim, lower_indexes.reshape(-1)] += lower_distance.reshape(-1)
    transition[first_dim, second_dim, upper_indexes.reshape(-1)] += 1 - lower_distance.reshape(-1)
    if len(prev_distribution.shape) == 2:
        prev_distribution = np.expand_dims(prev_distribution,1) # (L,1,N)
    return np.matmul(prev_distribution,transition)

In [9]:
N = 51
v_min = -5
v_max = 5
state_size = 8
action_size = 4
batch_size = 128

unit = (v_max - v_min)/(N - 1)
vals = np.linspace(v_min,v_max,N)
test_state = np.random.standard_normal((batch_size,1,state_size))
test_net = CDQN(state_size, action_size, N)
with torch.no_grad():
    test_distribution = test_net(torch.tensor(test_state,dtype=torch.float32)).exp()
test_distribution = test_distribution[range(batch_size),np.random.randint(0,action_size,batch_size),:].numpy()

In [10]:
test_distribution.shape

(128, 51)

In [11]:
test_reward = np.random.uniform(-2,2,batch_size)
test_gamma = 0.99
new_dict = new_distribution(test_distribution, test_reward, test_gamma)
new_dict_checker = new_distribution_checker(test_distribution, test_reward, test_gamma)
# np.abs(new_dict - new_dict_checker)
(np.abs(new_dict - new_dict_checker) <= 0.001).all()

True

In [12]:
import time
t1 = time.time()
for _ in range(10):
    new_distribution(test_distribution, test_reward, test_gamma)
t2 = time.time() - t1
print("numpy implement:",t2)

t1 = time.time()
for _ in range(10):
    new_distribution_checker(test_distribution, test_reward, test_gamma)
t3 = time.time() - t1
print("python implement:",t3)
print("numpy is {} times to pthon implement".format(t3/t2))

numpy implement: 0.06682038307189941
python implement: 2.231031894683838
numpy is 33.38849303337912 times to pthon implement


### torch实现batch_size

In [64]:
def torch_new_distribution(prev_distribution, reward, gamma):
    vals_torch = torch.tensor(vals,dtype = torch.float32)
    reward = reward.view(-1,1)
    length = reward.size(0)
    assert prev_distribution.shape[0] == length
    new_vals = vals_torch.view(1,-1) * gamma + reward
    new_vals = torch.clamp(new_vals,v_min,v_max)
    lower_indexes = torch.floor((new_vals-v_min)/unit).long()
    print(lower_indexes.shape)
    upper_indexes = torch.min(lower_indexes+1,other = torch.tensor(N-1,dtype=torch.long))
    lower_vals = vals_torch[lower_indexes]
    print(lower_vals.shape)
    lower_distance = 1 - torch.min((new_vals-lower_vals)/unit,other=torch.tensor(1,dtype=torch.float32))
    transition = torch.zeros((length,N,N))
    first_dim = torch.tensor(range(length),dtype=torch.long).view(-1,1).repeat(1,N).view(-1)
    second_dim = torch.tensor(range(N),dtype=torch.long).repeat(length)
    transition[first_dim, second_dim, lower_indexes.view(-1)] += lower_distance.view(-1)
    transition[first_dim, second_dim, upper_indexes.view(-1)] += 1 - lower_distance.view(-1)
    if len(prev_distribution.shape) == 2:
        prev_distribution = prev_distribution.unsqueeze(1)
    return torch.bmm(prev_distribution,transition)

In [65]:
test_distribution_torch = torch.tensor(test_distribution, dtype=torch.float32)
test_reward_torch = torch.tensor(test_reward, dtype=torch.float32)

new_dist_torch = torch_new_distribution(test_distribution_torch, test_reward_torch, test_gamma)
(np.abs(new_dist_torch.numpy() - new_dict)<=0.0001).all()

torch.Size([128, 51])
torch.Size([128, 51])


True

In [15]:
test_distribution_torch = torch.tensor(test_distribution, dtype=torch.float32)
test_reward_torch = torch.tensor(test_reward, dtype=torch.float32)

t1 = time.time()
for _ in range(10):
    torch_new_distribution(test_distribution_torch, test_reward_torch, test_gamma)
print(time.time() - t1)

t1 = time.time()
for _ in range(10):
    new_distribution(test_distribution, test_reward, test_gamma)
print(time.time() - t1)

0.025931596755981445
0.06382966041564941


In [59]:
import torch
value = torch.randn(128,4)

value,indexs = torch.max(input = value,dim = 1,keepdims = True)
indexs = indexs.unsqueeze(1).repeat(1,1,51)

distribution = torch.randn(128,4,51)
print(indexs.shape)
print(distribution.shape)

new_prob = torch.gather(input = distribution, dim = 1, index = indexs)
print(new_prob.shape)

torch.Size([128, 1, 51])
torch.Size([128, 4, 51])
torch.Size([128, 1, 51])


In [52]:
distribution

tensor([[[ 2.0751,  0.3832,  1.0936,  ..., -0.8018,  0.2763,  0.1824],
         [ 0.6037,  0.6597,  0.6175,  ...,  0.0908,  0.6352,  0.8452],
         [ 0.9136, -0.4437,  2.5809,  ...,  0.1672, -1.0536,  0.1368],
         [ 0.6401, -0.7442, -0.1780,  ...,  0.3915, -0.0816,  0.1285]],

        [[-0.6981,  2.5210,  2.0635,  ...,  1.7243, -0.0054, -0.0454],
         [-0.3141, -1.2203, -0.9132,  ...,  0.4656, -0.2350, -0.7108],
         [ 0.4908,  0.6577,  1.2031,  ..., -0.1226,  0.9048,  0.6337],
         [-0.6079,  0.0950, -0.3259,  ...,  0.0594,  0.7983,  0.5678]],

        [[ 0.2770, -0.3961,  0.1367,  ...,  0.8476, -1.0872, -1.0798],
         [ 1.9798, -0.8580, -0.3838,  ...,  0.0803, -1.6462,  0.1190],
         [ 0.1148, -0.0337,  0.2880,  ..., -1.5193, -0.2405,  0.0761],
         [-1.3058,  0.6355, -0.2018,  ...,  1.4698, -1.1038, -0.5393]],

        ...,

        [[ 0.6385,  1.3177,  0.5646,  ...,  0.8797, -1.0995, -0.4916],
         [-3.2929,  1.5723, -1.0038,  ..., -0.4519,  0.91

In [49]:
prob.sum(axis = 2)

tensor([[  7.0875],
        [  6.8500],
        [  5.7859],
        [-13.0540],
        [ -9.4969],
        [ -2.0813],
        [ -2.1127],
        [ -2.5662],
        [  2.4416],
        [  2.5607],
        [ -2.7764],
        [  1.8979],
        [ -1.3486],
        [  9.1250],
        [ -6.7386],
        [  3.0038],
        [  0.8740],
        [  5.1195],
        [ 14.9664],
        [ 10.9062],
        [  3.1666],
        [ -2.8788],
        [ -3.5492],
        [ -0.8710],
        [ 12.6076],
        [ 11.8837],
        [  0.5997],
        [  5.6149],
        [  0.4704],
        [  0.8800],
        [  1.1790],
        [ -2.8200],
        [  1.2729],
        [ -2.6118],
        [-10.3240],
        [ -0.9134],
        [ -5.5520],
        [  3.9151],
        [  8.3325],
        [  1.7525],
        [ -1.4840],
        [  3.0378],
        [ -7.3753],
        [ -5.2572],
        [ -4.1449],
        [ -2.4071],
        [ -9.0305],
        [  9.9860],
        [  4.0289],
        [ -9.2216],


In [57]:
x = torch.linspace(-5,5,51).view(1,1,51)

value = torch.randn(128,4,51)

torch.Size([128, 4, 51])

In [61]:
x = torch.linspace(-5,5,51)
x

tensor([-5.0000, -4.8000, -4.6000, -4.4000, -4.2000, -4.0000, -3.8000, -3.6000,
        -3.4000, -3.2000, -3.0000, -2.8000, -2.6000, -2.4000, -2.2000, -2.0000,
        -1.8000, -1.6000, -1.4000, -1.2000, -1.0000, -0.8000, -0.6000, -0.4000,
        -0.2000,  0.0000,  0.2000,  0.4000,  0.6000,  0.8000,  1.0000,  1.2000,
         1.4000,  1.6000,  1.8000,  2.0000,  2.2000,  2.4000,  2.6000,  2.8000,
         3.0000,  3.2000,  3.4000,  3.6000,  3.8000,  4.0000,  4.2000,  4.4000,
         4.6000,  4.8000,  5.0000])

In [83]:
torch.tensor(range(128),dtype=torch.long).view(-1,1).repeat(1,N).view(-1)

tensor([  0,   0,   0,  ..., 127, 127, 127])

In [95]:
torch.tensor(range(51),dtype=torch.long).repeat(128)

tensor([ 0,  1,  2,  ..., 48, 49, 50])

In [125]:
a = torch.rand(128,1,51)
b = torch.rand(128,1,51)
print((a * b).mean(dim=2,keepdims = False).sum())
print((a * b).sum(dim=2,keepdims = False).mean())

tensor(32.6059)
tensor(12.9914)


In [121]:
a

tensor([[[0.2304, 0.9909, 0.5938,  ..., 0.5209, 0.6763, 0.8405]],

        [[0.2790, 0.2014, 0.4780,  ..., 0.3732, 0.3972, 0.0119]],

        [[0.0977, 0.1232, 0.8942,  ..., 0.2844, 0.5890, 0.9955]],

        ...,

        [[0.3420, 0.0195, 0.0962,  ..., 0.0829, 0.0053, 0.2469]],

        [[0.2213, 0.9321, 0.8601,  ..., 0.0686, 0.8871, 0.8576]],

        [[0.0290, 0.3218, 0.6698,  ..., 0.0227, 0.9045, 0.9442]]])