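# Neural Memory System - PyTorchEncoderLinear demo

This notebook builds a `PyTorchEncoderLinear` encoder, feeds it a randomly generated batch stored in a dictionary under `content_key`, and prints the encoded result.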

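## Environment setup

The cells below move the working directory three levels up, to the repository root, only once per kernel (an environment variable guards against repeated moves). They also record this notebook's folder relative to that root.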

In [1]:
import os
from pathlib import Path

In [2]:
# Remember the folder the notebook was started from, before the next cell
# changes the working directory.
CURRENT_FOLDER = Path.cwd()

In [3]:
# Move the working directory to the repository root so project-relative paths
# resolve from there. This notebook lives three levels below the root
# (demo/components/encoders), hence "../../..".
#
# The environment variable is a once-per-kernel guard: without it, re-running
# this cell would keep climbing the directory tree. Values in `os.environ` are
# always strings, so an unset, empty or "false" flag means "not moved yet".
CD_KEY = "--PYTORCH_ENCODER_LINEAR_DEMO_IN_ROOT"

if os.environ.get(CD_KEY, "") in ("", "false"):
    # Plain `os.chdir` instead of the `%cd` magic keeps this cell valid Python.
    os.chdir("../../..")

# From here on the root is the working directory, so it is reported as ".".
# It is defined unconditionally so later cells never hit a NameError when the
# guard above skips the directory change.
ROOT_FOLDER = Path(".")

# Express the notebook folder relative to the root. Only an absolute path still
# needs converting; after a successful first run it is already relative.
if CURRENT_FOLDER.is_absolute():
    CURRENT_FOLDER = CURRENT_FOLDER.relative_to(ROOT_FOLDER.absolute())

os.environ[CD_KEY] = "true"

In [4]:
# Show where paths are resolved from (the root) and where this notebook lives
# relative to it.
print(
    f"Root folder:    {ROOT_FOLDER}",
    f"Current folder: {CURRENT_FOLDER}",
    sep="\n",
)

Root folder:    .
Current folder: demo/components/encoders


## Modules

In [5]:
import torch
import torch.nn

In [6]:
from nemesys.modelling.encoders.modules.pytorch_encoder_linear import PyTorchEncoderLinear

In [7]:
# Seed PyTorch's global RNG before the input batch (and, presumably, the
# encoder's initial weights) are created in later cells, so that
# Restart & Run All reproduces the same numbers.
SEED = 0
torch.manual_seed(SEED)

# Print tensors in fixed-point rather than scientific notation for readability.
# Kept as the last statement: it returns None, so the cell displays no output
# (unlike `manual_seed`, whose Generator return value would be shown).
torch.set_printoptions(sci_mode=False)

## Encoder setup

In [8]:
# Dictionary key that holds the input tensor in the batch; the printed result
# shows the encoder's output comes back under the same key.
content_key = "content"

# Feature sizes handed to the encoder as `in_features` / `out_features`.
in_features, out_features = 128, 4

In [9]:
# Encoder under test: `content_key` selects which dictionary entry it processes;
# the two feature sizes configure its input and output widths.
encoder = PyTorchEncoderLinear(
    content_key=content_key,
    in_features=in_features,
    out_features=out_features,
)

## Data setup

In [10]:
# Leading dimensions of the synthetic input: batch and sequence sizes.
batch_size, sequence_length = 4, 4

In [11]:
# Synthetic standard-normal input of shape
# (batch_size, sequence_length, in_features).
data_batch_tensor = torch.normal(
    size=(batch_size, sequence_length, in_features),
    mean=0,
    std=1,
)

In [12]:
# The encoder consumes a dictionary; wrap the tensor under `content_key`.
data_batch = {content_key: data_batch_tensor}

## Results

In [13]:
# Run the encoder on the dictionary batch; it returns a dictionary as well.
result = encoder(data_batch)

# Sanity check: the leading (batch, sequence) dimensions are kept and the last
# one goes from in_features to out_features.
assert result[content_key].shape == (batch_size, sequence_length, out_features), (
    f"unexpected output shape {tuple(result[content_key].shape)}"
)

In [14]:
# Let Jupyter display the encoded batch as the cell's output value.
result

{'content': tensor([[[-1.0879,  0.7450,  0.1692, -0.4972],
         [-0.1790,  0.0096, -0.4710, -0.1816],
         [ 0.2590, -1.0448,  0.7050,  0.8934],
         [-0.8352, -1.8020, -0.7525,  0.1799]],

        [[-0.4038,  0.0855, -0.1205,  0.7312],
         [ 0.0677, -0.6606, -0.4024,  0.2764],
         [ 0.4004,  0.1318,  0.4349, -0.8467],
         [ 0.4286,  1.1520,  0.4444,  0.6032]],

        [[-0.0438, -0.3562,  0.9923, -0.5161],
         [-0.1946,  0.6147, -0.1772, -0.1157],
         [-0.1370,  1.4367, -0.1107, -0.5074],
         [ 0.5097,  0.4605, -1.0206, -0.4527]],

        [[ 0.5473, -0.2980, -0.5326,  0.6689],
         [-0.6044, -0.5035, -0.5976,  0.1729],
         [ 0.1855, -0.1002, -0.3496,  0.8018],
         [-0.5788, -0.8932,  0.2982, -0.4266]]], grad_fn=<UnsafeViewBackward>)}
