In [1]:
import torch
import numpy as np
import sys
import os
import torch.nn as nn
import gym
%load_ext autoreload
%autoreload 2

device = torch.device('cuda:4')

### multi dim matmul check

In [2]:
a = torch.Tensor(np.arange(10))
a.reshape(2,5)

tensor([[0., 1., 2., 3., 4.],
        [5., 6., 7., 8., 9.]])

In [3]:
torch.reshape(a, [2,5])

tensor([[0., 1., 2., 3., 4.],
        [5., 6., 7., 8., 9.]])

In [4]:
a = torch.reshape(a, [2,5])
a

tensor([[0., 1., 2., 3., 4.],
        [5., 6., 7., 8., 9.]])

In [5]:
torch.mean(a, axis=-1)

tensor([2., 7.])

In [6]:
a = torch.randn([32, 4, 200])
b = torch.randn([200, 1])
c = torch.matmul(a, b)
c = c.squeeze()
c.shape

torch.Size([32, 4])

### multihead mixture

In [6]:
dim = 10
num_cols = 1
dtype = torch.float32
mat_shape = (dim, dim) if num_cols is None else (dim, num_cols)
# mat = torch.rand(mat_shape, dtype=dtype, device=device)
mat = torch.tensor([[0.9753],
         [0.6686],
         [0.8801],
         [0.2483],
         [0.3483],
         [0.3410],
         [0.4854],
         [0.8116],
         [0.8025],
         [0.3909]], dtype=dtype, device=device)

In [7]:
print(mat.shape)
norm = torch.norm(mat, p=1, dim=0, keepdim=True)
print(norm)
mat /= norm
print(mat,torch.sum(mat))

torch.Size([10, 1])
tensor([[5.9520]], device='cuda:4')
tensor([[0.1639],
        [0.1123],
        [0.1479],
        [0.0417],
        [0.0585],
        [0.0573],
        [0.0816],
        [0.1364],
        [0.1348],
        [0.0657]], device='cuda:4') tensor(1., device='cuda:4')


### MultiHeadQNetwork check

In [2]:
os.getcwd()
os.path.dirname(os.path.dirname(os.getcwd()))

'/home/shenshuo/workspace'

In [14]:
class Linear0(nn.Linear):
    def reset_parameters(self):
        nn.init.constant_(self.weight, 0.0)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)


class Scale(nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        return x * self.scale


class MultiHeadQNetwork(nn.Module):
    def __init__(self, env, frames=4, num_heads=1, transform_strategy='STOCHASTIC', transform_matrix=None):
        self.num_heads = num_heads
        self.transform_strategy = transform_strategy
        self.transform_matrix = transform_matrix
        super(MultiHeadQNetwork, self).__init__()
        self.network = nn.Sequential(
            Scale(1 / 255),
            nn.Conv2d(frames, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            Linear0(512, env.action_space.n * num_heads),
        )

    def forward(self, x, device):
        x = torch.Tensor(x)
        x.to(device)
        unordered_q_heads = self.network(x)
        unordered_q_heads = torch.reshape(unordered_q_heads, [-1, env.action_space.n, self.num_heads])
        q_heads, q_values = combine_q_functions(unordered_q_heads, self.transform_strategy, transform_matrix)
        return q_heads, q_values


In [4]:
def combine_q_functions(q_functions, transform_strategy, transform_matrix=None):
    q_values = torch.mean(q_functions, axis=-1)
    if transform_strategy=='STOCHASTIC':
        # q_functions input shape: (batch_size, num_actions, num_heads)
        # left_stochastic_matrix shape: (num_heads, num_convex_combinations=1(defualt))
        # q_functions output shape: (batch_size, num_actions, 1)
        q_functions = torch.matmul(q_functions, transform_matrix).squeeze()
    elif transform_strategy=='IDENTITY':
        pass
    else:
        raise ValueError(
            '{} is not a valid reordering strategy'.format(transform_strategy))
    return q_functions, q_values

def random_stochastic_matrix(dim, dtype=torch.float32, device=device):
    """Generates a random left stochastic matrix."""
    # check dopamine's notebook 
    # after test this has same result
    mat_shape = (dim, 1)
    mat = torch.rand(mat_shape, dtype=dtype, device=device)
    mat /= torch.norm(mat, p=1, dim=0, keepdim=True)
    return mat


In [5]:
transform_matrix = random_stochastic_matrix(200, device=device)
transform_matrix.shape

torch.Size([200, 1])

In [15]:
env = gym.make('PongNoFrameskip-v0')
q_network = MultiHeadQNetwork(env, num_heads=200, 
                              transform_matrix=transform_matrix).to(device)
q_network.network

Sequential(
  (0): Scale()
  (1): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
  (2): ReLU()
  (3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (4): ReLU()
  (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Flatten(start_dim=1, end_dim=-1)
  (8): Linear(in_features=3136, out_features=512, bias=True)
  (9): ReLU()
  (10): Linear0(in_features=512, out_features=1200, bias=True)
)

In [16]:
state = torch.rand([32, 4, 84, 84], device=device, dtype=torch.float32)
state.device

device(type='cuda', index=4)

In [17]:
print(device)
q_functions, q_values = q_network.forward(state, device=device)
q_functions.shape, q_values.shape

cuda:4


TypeError: expected CPU (got CUDA)

In [9]:
s_actions = torch.ones(32, device=device)
old_val = q_functions.gather(1, s_actions.long().view(-1, 1))
old_val.shape

torch.Size([32, 1])

In [75]:
print(q_network.parameters)

<bound method Module.parameters of MultiHeadQNetwork(
  (network): Sequential(
    (0): Scale()
    (1): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (4): ReLU()
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (6): ReLU()
    (7): Flatten(start_dim=1, end_dim=-1)
    (8): Linear(in_features=3136, out_features=512, bias=True)
    (9): ReLU()
    (10): Linear0(in_features=512, out_features=1200, bias=True)
  )
)>


In [79]:
for k ,v in q_network.state_dict().items():
    print(k)

network.1.weight
network.1.bias
network.3.weight
network.3.bias
network.5.weight
network.5.bias
network.8.weight
network.8.bias
network.10.weight
network.10.bias


In [88]:
q_network.transform_strategy,q_network.transform_matrix.shape


('STOCHASTIC', torch.Size([200, 1]))

In [22]:
a = np.ones(10)
type(a) == np.ndarray

True