# 1. Imports

In [1]:
from build import Reformer

import torch
import torch.nn as nn

# 2. Parameters

In [2]:
FEEDBACK = True
INPUT_DIMS = 11
OUTPUT_DIMS = 4
FEEDBACK_DIMS = INPUT_DIMS+ (OUTPUT_DIMS+1 if FEEDBACK else 0)
EMBED_SIZE = 32
MAX_SEQ_LEN = 512
LAYERS = 4
HEADS = 1
KV_HEADS = None
DIFFERENTIAL = True
BIAS = False

In [3]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float32

# 3. Create model

In [4]:
MODEL = Reformer(INPUT_DIMS+FEEDBACK_DIMS, OUTPUT_DIMS, 1, EMBED_SIZE, MAX_SEQ_LEN, LAYERS, HEADS, KV_HEADS,
                 DIFFERENTIAL, 0.1, BIAS, FEEDBACK, DEVICE, DTYPE,
                 pri_actv=nn.SiLU(), sec_actv=nn.Sigmoid(), prob=False, dist='normal')

In [5]:
MODEL.single_mode(False)

# 4. Create data

In [6]:
SEQ_LEN = 64
RECORDS = 10

In [7]:
train_inp = torch.randn(RECORDS, SEQ_LEN, INPUT_DIMS+FEEDBACK_DIMS).to(DEVICE, DTYPE)

# Test

In [8]:
torch.cuda.empty_cache()

In [9]:
with torch.no_grad():
    test_policy, test_prob = MODEL.get_action(train_inp)

In [10]:
test_policy[0, -4:].cpu().numpy()

array([[ 0.31315535,  1.4885674 ,  1.9157126 , -0.42636538],
       [ 0.42192072,  0.7216093 , -0.7566068 ,  0.35240996],
       [ 1.183352  ,  0.18598446,  0.52469844,  1.1271529 ],
       [ 0.8829951 ,  0.08699873,  0.17565444,  0.921885  ]],
      dtype=float32)

In [11]:
test_prob[0, -4:].cpu().numpy()

array([[-0.67616236, -1.6321081 , -2.5279925 , -1.4023129 ],
       [-0.6547201 , -0.58080894, -2.054491  , -0.5262951 ],
       [-1.0431573 , -0.6434283 , -0.63612366, -1.0166552 ],
       [-0.74163353, -0.7197501 , -0.7030773 , -0.78519225]],
      dtype=float32)

In [12]:
with torch.no_grad():
    test_policy = MODEL(train_inp)
test_policy[0, -4:].cpu().numpy()

array([[-1.10805035e-01, -3.91961932e-02, -2.12976933e-01,
         1.22021943e-01],
       [-1.53779984e-05,  8.38030577e-02, -1.33778602e-01,
         3.53544682e-01],
       [ 1.26930714e-01,  1.01579577e-01,  3.81718278e-02,
         3.38917822e-01],
       [ 3.58568192e-01,  1.76797777e-01,  4.39589381e-01,
         4.05766398e-01]], dtype=float32)

In [13]:
with torch.no_grad():
    test_latent = MODEL.get_latent(MODEL.pol_proj, train_inp, single=False)

In [14]:
with torch.no_grad():
    test_mean = MODEL.get_mean(test_latent)

In [15]:
test_mean[0].cpu().numpy()

array([[0.44488418, 0.45048523, 0.6008035 , 0.50486857],
       [0.4825362 , 0.45154715, 0.53547   , 0.6276694 ],
       [0.4698878 , 0.45722798, 0.51470953, 0.56414926],
       [0.4633518 , 0.45819384, 0.59464437, 0.57283944],
       [0.46823412, 0.44362444, 0.6136421 , 0.5506484 ],
       [0.45120934, 0.45098773, 0.6221001 , 0.5413138 ],
       [0.47960413, 0.4329936 , 0.5710971 , 0.47447306],
       [0.49250448, 0.44657254, 0.5839204 , 0.570655  ],
       [0.48740354, 0.45517468, 0.5883198 , 0.5405111 ],
       [0.49157447, 0.44367963, 0.6003924 , 0.5817482 ],
       [0.4753    , 0.45985132, 0.6135741 , 0.5414556 ],
       [0.4953502 , 0.4457776 , 0.61071974, 0.5423265 ],
       [0.5162331 , 0.46354803, 0.56420225, 0.54153186],
       [0.4941462 , 0.45521453, 0.6219086 , 0.5538573 ],
       [0.4886489 , 0.4434235 , 0.64752394, 0.54139715],
       [0.4587053 , 0.43975934, 0.621249  , 0.5412699 ],
       [0.4572217 , 0.4411254 , 0.6207407 , 0.552625  ],
       [0.5064129 , 0.43935046,

In [16]:
with torch.no_grad():
    test_std = MODEL.get_std(test_latent)

In [17]:
test_std[0].cpu().numpy()

array([[0.69712293, 0.9076096 , 0.74140614, 0.72026736],
       [0.6862685 , 0.8994848 , 0.7031754 , 0.71419966],
       [0.6846762 , 0.85050637, 0.69768274, 0.7222807 ],
       [0.684757  , 0.88155293, 0.71169007, 0.7059252 ],
       [0.7145288 , 0.91787475, 0.7219237 , 0.7128739 ],
       [0.6877118 , 0.8985749 , 0.7066095 , 0.7120015 ],
       [0.70035946, 0.8709994 , 0.73345464, 0.7265952 ],
       [0.6900642 , 0.88900787, 0.69407654, 0.7442038 ],
       [0.6933506 , 0.86473525, 0.71072465, 0.721922  ],
       [0.69135827, 0.90861195, 0.70091754, 0.7074975 ],
       [0.6885493 , 0.8785931 , 0.70137423, 0.7128997 ],
       [0.69341964, 0.85316217, 0.7173903 , 0.6983367 ],
       [0.68232435, 0.9235999 , 0.69678205, 0.7085387 ],
       [0.6932842 , 0.91135967, 0.7084982 , 0.7241763 ],
       [0.6925472 , 0.8946854 , 0.7315603 , 0.7725883 ],
       [0.6964093 , 0.90416276, 0.7388962 , 0.72595763],
       [0.68400955, 0.8595813 , 0.7218003 , 0.7703309 ],
       [0.69406813, 0.9251643 ,

In [18]:
MODEL.mean_std

Linear(in_features=32, out_features=8, bias=False)

In [19]:
test_latent.shape

torch.Size([10, 64, 32])

In [20]:
with torch.no_grad():
    test_dist, test_extra = MODEL.dist(test_mean, test_std, test_latent, verbose=2)

In [21]:
test_dist is None

False

In [22]:
test_corr_params, test_corr_matrix, test_std_diag = test_extra['corr_params'], test_extra['corr_matrix'], test_extra['std_diag']

KeyError: 'corr_params'

In [None]:
test_corr_params[0, -1]

In [None]:
test_corr_matrix[0, -1]

In [None]:
test_std_diag[0, -1]

In [None]:
(test_std_diag @ test_std_diag)[0, -1]

In [None]:
(test_std_diag @ test_corr_matrix @ test_std_diag)[0, -1]

In [None]:
MODEL.single_mode(False)
MODEL.single