# 1. Imports

In [1]:
from TModels import Reformer

import torch
import torch.nn as nn

# 2. Parameters

In [2]:
FEEDBACK = True
INPUT_DIMS = 11
OUTPUT_DIMS = 4
FEEDBACK_DIMS = INPUT_DIMS+ (OUTPUT_DIMS+1 if FEEDBACK else 0)
EMBED_SIZE = 32
MAX_SEQ_LEN = 512
LAYERS = 4
HEADS = 1
KV_HEADS = None
DIFFERENTIAL = True
BIAS = False

In [3]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float32

# 3. Create model

In [4]:
MODEL = Reformer(INPUT_DIMS+FEEDBACK_DIMS, OUTPUT_DIMS, 1, EMBED_SIZE, MAX_SEQ_LEN, LAYERS, HEADS, KV_HEADS,
                 DIFFERENTIAL, 0.1, BIAS, FEEDBACK, DEVICE, DTYPE,
                 pri_actv=nn.SiLU(), sec_actv=nn.Sigmoid(), prob=False, dist='normal')

In [5]:
MODEL.single_mode(False)

# 4. Create data

In [6]:
SEQ_LEN = 64
RECORDS = 10

In [7]:
train_inp = torch.randn(RECORDS, SEQ_LEN, INPUT_DIMS+FEEDBACK_DIMS).to(DEVICE, DTYPE)

# Test

In [8]:
torch.cuda.empty_cache()

In [9]:
with torch.no_grad():
    test_policy, test_prob = MODEL.get_action(train_inp)

In [10]:
test_policy[0, -4:].cpu().numpy()

array([[ 0.93140054, -0.1368388 ,  0.16179374,  0.06433627],
       [ 1.8779953 ,  1.4328828 , -0.9313403 ,  1.3934828 ],
       [ 0.60219544,  0.3142664 ,  0.8114655 ,  0.5416646 ],
       [ 0.73250127,  0.5142795 , -1.1129389 ,  0.04497951]],
      dtype=float32)

In [11]:
test_prob[0, -4:].cpu().numpy()

array([[-0.77868885, -1.0432377 , -0.7311285 , -0.8478738 ],
       [-2.3040166 , -1.3584607 , -2.5207262 , -1.3140364 ],
       [-0.6603654 , -0.6749545 , -0.6967725 , -0.6732235 ],
       [-0.7112341 , -0.6355683 , -3.014853  , -0.8758185 ]],
      dtype=float32)

In [12]:
with torch.no_grad():
    test_policy = MODEL(train_inp)
test_policy[0, -4:].cpu().numpy()

array([[ 0.02161962,  0.5002566 ,  0.16197279,  0.14268732],
       [ 0.18546641,  0.11883789,  0.09956068,  0.3233798 ],
       [ 0.06234944,  0.09738401,  0.36293542, -0.062895  ],
       [ 0.000696  , -0.08218002,  0.26579368, -0.25191885]],
      dtype=float32)

In [13]:
with torch.no_grad():
    test_latent = MODEL.get_latent(MODEL.pol_proj, train_inp, single=False)

In [14]:
with torch.no_grad():
    test_mean = MODEL.get_mean(test_latent)

In [15]:
test_mean[0].cpu().numpy()

array([[0.44983318, 0.5848123 , 0.54332495, 0.6388184 ],
       [0.43702668, 0.606492  , 0.5334037 , 0.5527589 ],
       [0.43265912, 0.50567967, 0.5197109 , 0.54649854],
       [0.4349234 , 0.505543  , 0.59839267, 0.5323648 ],
       [0.43226507, 0.4881165 , 0.59302056, 0.5958958 ],
       [0.43614095, 0.53665197, 0.5898036 , 0.52296   ],
       [0.43172356, 0.49527866, 0.5711678 , 0.5259811 ],
       [0.435579  , 0.5724082 , 0.5239279 , 0.63131624],
       [0.43562177, 0.5897793 , 0.5326336 , 0.6290742 ],
       [0.43309802, 0.53434753, 0.5217888 , 0.5189388 ],
       [0.4312061 , 0.45381516, 0.55461645, 0.50729096],
       [0.4312908 , 0.5610524 , 0.6637143 , 0.56940633],
       [0.43128207, 0.56198937, 0.5659406 , 0.5490827 ],
       [0.43998963, 0.47985917, 0.57044035, 0.5753273 ],
       [0.4355295 , 0.5361447 , 0.4929497 , 0.5245417 ],
       [0.45190054, 0.6018764 , 0.511063  , 0.5497908 ],
       [0.4365834 , 0.580499  , 0.5907361 , 0.48919868],
       [0.4318971 , 0.53598243,

In [16]:
with torch.no_grad():
    test_std = MODEL.get_std(test_latent)

In [17]:
test_std[0].cpu().numpy()

array([[0.7753242 , 0.6820403 , 0.8890484 , 0.80441153],
       [0.8225698 , 0.6813627 , 0.8539791 , 0.7598093 ],
       [0.7714612 , 0.6815383 , 0.8370975 , 0.8007763 ],
       [0.7892363 , 0.6812275 , 0.8006683 , 0.8114363 ],
       [0.71163225, 0.6807042 , 0.8181736 , 0.83070505],
       [0.7998198 , 0.6840701 , 0.75121623, 0.76964825],
       [0.72359955, 0.68321925, 0.7350376 , 0.7448467 ],
       [0.76962924, 0.6807617 , 0.7803558 , 0.7691398 ],
       [0.7817997 , 0.68079823, 0.7706167 , 0.75648844],
       [0.7318111 , 0.68186605, 0.7890731 , 0.7796617 ],
       [0.76162815, 0.69903004, 0.7040756 , 0.7480933 ],
       [0.76374865, 0.68076813, 0.7788914 , 0.75523555],
       [0.7383492 , 0.6806901 , 0.78418934, 0.7797598 ],
       [0.74161166, 0.68090844, 0.7617625 , 0.8023399 ],
       [0.80099666, 0.6835131 , 0.77908677, 0.7865932 ],
       [0.8154441 , 0.6806965 , 0.7445138 , 0.7540894 ],
       [0.7604937 , 0.6810342 , 0.8466199 , 0.8103652 ],
       [0.82560456, 0.683379  ,

In [18]:
MODEL.mean_std

Linear(in_features=32, out_features=8, bias=False)

In [19]:
test_latent.shape

torch.Size([10, 64, 32])

In [20]:
with torch.no_grad():
    test_dist, test_extra = MODEL.dist(test_mean, test_std, test_latent, verbose=2)

In [21]:
test_dist is None

False

In [22]:
test_corr_params, test_corr_matrix, test_std_diag = test_extra['corr_params'], test_extra['corr_matrix'], test_extra['std_diag']

KeyError: 'corr_params'

In [None]:
test_corr_params[0, -1]

In [None]:
test_corr_matrix[0, -1]

In [None]:
test_std_diag[0, -1]

In [None]:
(test_std_diag @ test_std_diag)[0, -1]

In [None]:
(test_std_diag @ test_corr_matrix @ test_std_diag)[0, -1]

In [None]:
MODEL.single_mode(False)
MODEL.single