In [1]:
import scm
import causal_mdp
import causal_discovery
import numpy as np

# Notebook-wide display setting: print floating-point values in NumPy arrays
# with 4 digits of precision.
np.set_printoptions(precision=4)

In [2]:
class MyMdp(causal_mdp.CausalMdp):
    """Toy 2-D MDP whose one-step dynamics are written as a structural causal model.

    Variables built in ``__init__``:

    * exogenous: ``x[t]``, ``y[t]`` (state), ``ax[t]``, ``ay[t]`` (action) and
      ``u[t]`` (a 2-component noise term that defaults to ``(0, 0)``);
    * endogenous:
      ``x[t+1] = x[t] + ax[t] + u[t][0]``,
      ``y[t+1] = y[t] + ay[t] + u[t][1]``,
      ``r[t+1] = -(x[t+1]**2 + y[t+1]**2)``.
    """

    def __init__(self):
        super().__init__()

        # Exogenous variables at time t.
        x = scm.ExoVar(name='x[t]')  # x
        y = scm.ExoVar(name='y[t]')  # y
        # Additive noise; stays at (0, 0) unless a value is assigned explicitly.
        u = scm.ExoVar(name='u[t]', default=(0, 0))
        ax = scm.ExoVar(name='ax[t]')  # dx
        ay = scm.ExoVar(name='ay[t]')  # dy

        # Next-step coordinates: each one depends only on its own coordinate,
        # its own action component and its own noise component.
        xo = scm.EndoVar((x, ax, u), lambda x, ax, u: x + ax + u[0], name='x[t+1]')
        yo = scm.EndoVar((y, ay, u), lambda y, ay, u: y + ay + u[1], name='y[t+1]')
        # Reward is the negative squared distance of the *next* state from the origin.
        r = scm.EndoVar((xo, yo), lambda x, y: - (x*x + y*y), name='r[t+1]')

        # State variables are passed as (current, next) pairs.
        self.config(state_vars=[(x, xo), (y, yo)],
                    action_vars=(ax, ay),
                    reward_vars=[r])

    def sample(self, rng=None):
        """Draw a random state, action and noise vector.

        The values are only returned; this method does not write them into the
        model's variables.

        Args:
            rng: optional random source exposing a NumPy-style
                ``normal(scale=..., size=...)`` method, e.g.
                ``np.random.default_rng(seed)``. When omitted, the global
                ``np.random`` module is used, which is the original behavior.

        Returns:
            A tuple ``((x, y), (ax, ay), u)`` where ``x`` and ``y`` are drawn
            from a zero-mean normal with standard deviation 10, ``ax`` and
            ``ay`` with standard deviation 2, and ``u`` is an array of two
            draws with standard deviation 0.2.
        """
        if rng is None:
            rng = np.random

        # The draw order (x, y, ax, ay, u) is kept fixed so that a seeded
        # random source always yields the same sample sequence.
        x = rng.normal(scale=10)
        y = rng.normal(scale=10)
        ax = rng.normal(scale=2)
        ay = rng.normal(scale=2)
        u = rng.normal(scale=0.2, size=2)
        return (x, y), (ax, ay), u

# Build the MDP once; the cells below reuse this instance.
mdp = MyMdp()

# Draw the causal graph and view it under ./causal_graph. The cell's recorded
# result is 'causal_graph.pdf' (presumably the rendered file name -- confirm
# against the plotting backend used by CausalMdp.plot).
causal_graph = mdp.plot()
causal_graph.view('./causal_graph')

'causal_graph.pdf'

In [5]:
# Fill an experience buffer with independently sampled one-step transitions.
n_sample = 5000  # number of transitions to draw
n_preview = 5    # only the first few transitions are printed, to keep the output readable

# MyMdp.sample() draws from NumPy's global random state by default; seeding it
# here makes this cell (and everything computed from the buffer) reproducible.
np.random.seed(0)

# The constructor argument is presumably the buffer capacity (TODO confirm);
# it is tied to n_sample so the two can no longer drift apart.
buffer = causal_mdp.ExperienceBuffer(n_sample)
# Register every variable that is recorded per transition
# (`declear` is the spelling used by the ExperienceBuffer API).
buffer.declear_state('x[t]', 'x[t+1]', float)
buffer.declear_state('y[t]', 'y[t+1]', float)
buffer.declear_reward('r[t+1]')
buffer.declear_action('ax[t]', float)
buffer.declear_action('ay[t]', float)
# The noise term has two components, so it is registered with shape=2.
buffer._declear('u[t]', float, shape=2)

for t in range(n_sample):
    mdp.clear()
    state, action, u = mdp.sample()
    # NOTE(review): the sampled noise `u` is never handed to the model because
    # the assignment below is disabled, so u[t] keeps its default (0, 0) -- the
    # recorded output shows exactly that -- and each transition is a
    # deterministic function of (state, action). Confirm whether this is intended.
    # mdp.assign(u=u)
    new_state, reward = mdp.model(state, action)
    buffer.write_from_mdp(mdp)

    # Printing all n_sample steps would emit tens of thousands of lines;
    # a short preview is enough to sanity-check the variable values.
    if t < n_preview:
        print(f'step {t}:')
        for var in mdp.variables:
            print(f'\t{var.name}: {var.value}')

step 0:
	u[t]: (0, 0)
	x[t]: -2.106814864398095
	ax[t]: 2.5167127740561055
	ay[t]: 1.1096126307382503
	x[t+1]: 0.40989790965801065
	y[t]: -23.001595320590276
	y[t+1]: -21.891982689852025
	r[t+1]: -479.4269223891227
step 1:
	u[t]: (0, 0)
	x[t]: -18.88730435900095
	ax[t]: 2.085373923703943
	ay[t]: -0.9739819449493325
	x[t+1]: -16.801930435297006
	y[t]: 0.1863467967079695
	y[t+1]: -0.787635148241363
	r[t+1]: -282.925235479305
step 2:
	u[t]: (0, 0)
	x[t]: -3.8257430639349943
	ax[t]: -1.3624652113833462
	ay[t]: -1.1706605893044475
	x[t+1]: -5.18820827531834
	y[t]: 0.3658336283475796
	y[t+1]: -0.8048269609568679
	r[t+1]: -27.565251545164774
step 3:
	u[t]: (0, 0)
	x[t]: -7.4660383297808055
	ax[t]: 4.3302388895487445
	ay[t]: 0.4908033888957607
	x[t+1]: -3.135799440232061
	y[t]: 5.954723545505141
	y[t+1]: 6.445526934400902
	r[t+1]: -51.37805559144719
step 4:
	u[t]: (0, 0)
	x[t]: 5.176677645100552
	ax[t]: -1.223995689170359
	ay[t]: 1.3870098502018238
	x[t+1]: 3.9526819559301933
	y[t]: 4.57963475

In [7]:
# p-value: the probability of observing an example at least as extreme as the
# data if the two variables were independent. The larger it is, the more
# plausible the independence hypothesis; a small value is evidence of dependence.
#
# NOTE(review): the second argument looks like the p-value cutoff -- in the
# recorded output every pair with p-value < 0.1 is kept as a parent -- confirm
# against causal_discovery.discover.
p_value_threshold = 0.1
parents = causal_discovery.discover(buffer, p_value_threshold)

independent test (x[t], y[t+1]) done, p-value = 0.38669 
independent test (x[t], x[t+1]) done, p-value = 0.00000 
independent test (x[t], r[t+1]) done, p-value = 0.00001 
independent test (ax[t], y[t+1]) done, p-value = 0.86980 
independent test (ax[t], x[t+1]) done, p-value = 0.00001 
independent test (ax[t], r[t+1]) done, p-value = 0.04046 
independent test (ay[t], y[t+1]) done, p-value = 0.00524 
independent test (ay[t], x[t+1]) done, p-value = 0.00383 
independent test (ay[t], r[t+1]) done, p-value = 0.33870 
independent test (y[t], y[t+1]) done, p-value = 0.00000 
independent test (y[t], x[t+1]) done, p-value = 0.69509 
independent test (y[t], r[t+1]) done, p-value = 0.00000 
-------------------discovered-causal-graph---------------------
(ay[t], y[t]) --> y[t+1]
(x[t], ax[t], ay[t]) --> x[t+1]
(x[t], ax[t], y[t]) --> r[t+1]


{'y[t+1]': ['ay[t]', 'y[t]'],
 'x[t+1]': ['x[t]', 'ax[t]', 'ay[t]'],
 'r[t+1]': ['x[t]', 'ax[t]', 'y[t]']}

In [11]:
import torch
import torch.nn as nn

class Net(nn.Module):
    """Minimal module that owns a single learnable 5x5 matrix, ``x``."""

    def __init__(self):
        super().__init__()

        # Wrapping the tensor in a Parameter registers it with the module, so
        # it is listed by parameters() and tracks gradients.
        initial_values = torch.randn(5, 5)
        self.x = nn.Parameter(initial_values)

# Scratch check: index a Parameter with a tuple of row indices.
# Rows 2 and 3 (all columns) are selected; the printed tensor carries a
# grad_fn, i.e. the selection is still connected to the parameter's graph.
net = Net()
row_ids = (2, 3)
y = net.x[row_ids, :]
print(y)


tensor([[-2.0226e-01,  1.0011e+00,  4.1016e-02, -2.1168e-01,  4.4501e-02],
        [ 5.8520e-01,  7.5187e-03, -6.4849e-01,  6.0488e-04, -1.0947e-01]],
       grad_fn=<IndexBackward0>)
