# 1. Imports

In [1]:
from TModels import Reformer
from TModels.util import storage

import torch
import torch.nn as nn

In [2]:
storage.STORAGE_DIR

'.\\home\\storage\\'

# 2. Parameters

In [3]:
FEEDBACK = True
INPUT_DIMS = 11
OUTPUT_DIMS = 4
FEEDBACK_DIMS = INPUT_DIMS+ (OUTPUT_DIMS+1 if FEEDBACK else 0)
EMBED_SIZE = 32
MAX_SEQ_LEN = 512
LAYERS = 4
HEADS = 1
KV_HEADS = None
DIFFERENTIAL = True
BIAS = False

In [4]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float32

# 3. Create model

In [5]:
MODEL = Reformer(INPUT_DIMS+FEEDBACK_DIMS, OUTPUT_DIMS, 1, EMBED_SIZE, MAX_SEQ_LEN, LAYERS, HEADS, KV_HEADS,
                 DIFFERENTIAL, 0.1, BIAS, FEEDBACK, DEVICE, DTYPE,
                 pri_actv=nn.SiLU(), sec_actv=nn.Sigmoid(), prob=False, dist='normal')

In [6]:
MODEL.single_mode(False)

# 4. Create data

In [7]:
SEQ_LEN = 64
RECORDS = 10

In [8]:
train_inp = torch.randn(RECORDS, SEQ_LEN, INPUT_DIMS+FEEDBACK_DIMS).to(DEVICE, DTYPE)

# Test

In [9]:
torch.cuda.empty_cache()

In [10]:
with torch.no_grad():
    test_policy, test_prob = MODEL.get_action(train_inp)

In [11]:
test_policy[0, -4:].cpu().numpy()

array([[ 2.59615111e+00, -5.85088134e-01,  4.85139608e-01,
        -3.46977949e-01],
       [-2.89305329e-01, -1.16066456e-01, -4.58972633e-01,
         1.31866217e+00],
       [-4.77671623e-04,  5.83943188e-01,  3.92051876e-01,
         1.31441188e+00],
       [ 2.96391785e-01, -4.75585461e-02,  6.87694311e-01,
         6.51149929e-01]], dtype=float32)

In [12]:
test_prob[0, -4:].cpu().numpy()

array([[-4.704492  , -1.656739  , -0.63632584, -1.4085428 ],
       [-1.2683196 , -0.9586648 , -1.2923429 , -1.1007595 ],
       [-0.92307174, -0.5587231 , -0.5997149 , -1.1593663 ],
       [-0.688187  , -0.8738472 , -0.6514114 , -0.5681058 ]],
      dtype=float32)

In [13]:
with torch.no_grad():
    test_policy = MODEL(train_inp)
test_policy[0, -4:].cpu().numpy()

array([[-0.0128023 ,  0.05695152,  0.38148743,  0.3525976 ],
       [ 0.23588344,  0.22957674, -0.19486743,  0.03097785],
       [ 0.05630422, -0.11126813,  0.24086677,  0.16234511],
       [-0.00752723,  0.40036193,  0.2832898 , -0.09665078]],
      dtype=float32)

In [14]:
with torch.no_grad():
    test_latent = MODEL.get_latent(MODEL.pol_proj, train_inp, single=False)

In [15]:
with torch.no_grad():
    test_mean = MODEL.get_mean(test_latent)

In [16]:
test_mean[0].cpu().numpy()

array([[0.43208006, 0.44752106, 0.47871396, 0.547723  ],
       [0.49401018, 0.5832348 , 0.44461805, 0.48210973],
       [0.43682158, 0.4917918 , 0.4789674 , 0.5256978 ],
       [0.4365368 , 0.51397115, 0.49503326, 0.5118097 ],
       [0.43497112, 0.53272694, 0.46833318, 0.47734433],
       [0.43695086, 0.544677  , 0.48249856, 0.50018525],
       [0.46060246, 0.53284943, 0.47581288, 0.4692252 ],
       [0.44859645, 0.69355255, 0.46950236, 0.5190076 ],
       [0.44916576, 0.6219628 , 0.43844798, 0.4515098 ],
       [0.4766424 , 0.6250023 , 0.4468464 , 0.44561255],
       [0.43824303, 0.5813689 , 0.4555352 , 0.49054345],
       [0.4441731 , 0.55428934, 0.43725082, 0.49454314],
       [0.4401511 , 0.584599  , 0.4812964 , 0.45642427],
       [0.43978885, 0.70164394, 0.47258583, 0.48645264],
       [0.44141406, 0.5924039 , 0.46765956, 0.44685724],
       [0.43406504, 0.5344792 , 0.4884993 , 0.43951353],
       [0.48876378, 0.6983953 , 0.4381847 , 0.43885693],
       [0.50692505, 0.59518   ,

In [17]:
with torch.no_grad():
    test_std = MODEL.get_std(test_latent)

In [18]:
test_std[0].cpu().numpy()

array([[0.7181406 , 0.72484815, 0.787713  , 0.8607458 ],
       [0.69976115, 0.7003555 , 0.82031393, 0.8551404 ],
       [0.7903768 , 0.7305795 , 0.74342024, 0.88760835],
       [0.7200433 , 0.7294413 , 0.8122717 , 0.8635201 ],
       [0.7473584 , 0.7091983 , 0.7604784 , 0.80489016],
       [0.6962748 , 0.7432739 , 0.8214704 , 0.91067934],
       [0.7337147 , 0.6941726 , 0.79076654, 0.87454516],
       [0.77653056, 0.7217472 , 0.8813976 , 0.8278057 ],
       [0.77316654, 0.7314873 , 0.75858974, 0.80655366],
       [0.74464273, 0.68631893, 0.7662988 , 0.82410944],
       [0.71258247, 0.77709174, 0.7318598 , 0.86048335],
       [0.75681466, 0.70879364, 0.72987336, 0.75812316],
       [0.7653966 , 0.7639707 , 0.76253873, 0.8888306 ],
       [0.835681  , 0.7521084 , 0.7494852 , 0.8671004 ],
       [0.7254023 , 0.7317365 , 0.79839486, 0.8133906 ],
       [0.74730754, 0.7329521 , 0.7693707 , 0.7986905 ],
       [0.76973236, 0.69859236, 0.7760682 , 0.85222226],
       [0.7430782 , 0.70719   ,

In [19]:
MODEL.mean_std

Linear(in_features=32, out_features=8, bias=False)

In [20]:
test_latent.shape

torch.Size([10, 64, 32])

In [21]:
with torch.no_grad():
    test_dist, test_extra = MODEL.dist(test_mean, test_std, test_latent, verbose=2)

In [22]:
test_dist is None

False

In [23]:
test_corr_params, test_corr_matrix, test_std_diag = test_extra['corr_params'], test_extra['corr_matrix'], test_extra['std_diag']

KeyError: 'corr_params'

In [None]:
test_corr_params[0, -1]

In [None]:
test_corr_matrix[0, -1]

In [None]:
test_std_diag[0, -1]

In [None]:
(test_std_diag @ test_std_diag)[0, -1]

In [None]:
(test_std_diag @ test_corr_matrix @ test_std_diag)[0, -1]

In [None]:
MODEL.single_mode(False)
MODEL.single