In [121]:
!pip --version
import sys
sys.path.append("/Users/kerekmen/miniconda3/envs/s4/lib/python3.12/site-packages")
print(sys.path)

import torch
torch.random.manual_seed(0)
import numpy as np

pip 24.0 from /Users/kerekmen/miniconda3/envs/s4/lib/python3.12/site-packages/pip (python 3.12)
['/opt/homebrew/Cellar/python@3.12/3.12.3/Frameworks/Python.framework/Versions/3.12/lib/python312.zip', '/opt/homebrew/Cellar/python@3.12/3.12.3/Frameworks/Python.framework/Versions/3.12/lib/python3.12', '/opt/homebrew/Cellar/python@3.12/3.12.3/Frameworks/Python.framework/Versions/3.12/lib/python3.12/lib-dynload', '', '/opt/homebrew/Cellar/jupyterlab/4.1.6_1/libexec/lib/python3.12/site-packages', '/opt/homebrew/opt/z3/lib/python3.12/site-packages', '/opt/homebrew/opt/llvm/lib/python3.12/site-packages', '/opt/homebrew/opt/certifi/lib/python3.12/site-packages', '/opt/homebrew/opt/z3/lib/python3.12/site-packages', '/opt/homebrew/opt/llvm/lib/python3.12/site-packages', '/opt/homebrew/lib/python3.12/site-packages', '/Users/kerekmen/miniconda3/envs/s4/lib/python3.12/site-packages', '/Users/kerekmen/miniconda3/envs/s4/lib/python3.12/site-packages']


In [122]:
def random_SSM(N):
    A = torch.randn(N, N)
    B = torch.randn(N, 1)
    C = torch.randn(1, N)
    return A, B, C

In [123]:
def discretize(A, B, C, step):
    I = np.eye(A.shape[0])
    BL = torch.linalg.inv(torch.tensor(I - (step / 2.0) * A))
    Ab = BL @ (I + (step / 2.0) * A)
    Bb = (BL * step) @ B
    return Ab, Bb, C

In [124]:
def scan_SSM(Ab, Bb, Cb, u, x0):
    """
    Simulate the state-space model using a for-loop to replicate JAX lax.scan functionality.

    Parameters:
    - Ab (torch.Tensor): The state transition matrix.
    - Bb (torch.Tensor): The input matrix.
    - Cb (torch.Tensor): The output matrix.
    - u (torch.Tensor): The input sequence (time steps, input_dim).
    - x0 (torch.Tensor): The initial state.

    Returns:
    - x_out (torch.Tensor): Sequence of states.
    - y_out (torch.Tensor): Sequence of outputs.
    """
    Ab = torch.tensor(Ab)
    Bb = torch.tensor(Bb)
    Cb = torch.tensor(Cb)
    
    
    def step(x_k_1, u_k):
        x_k_1 = torch.tensor(x_k_1)
        u_k = torch.tensor(u_k)
        
        x_k = Ab @ x_k_1 + Bb @ u_k
        y_k = Cb @ x_k
        return x_k, y_k

    x_out = [x0]
    y_out = []

    x_k = x0
    for u_k in u:
        x_k, y_k = step(x_k, u_k)
        x_out.append(x_k)
        y_out.append(y_k)

    return torch.stack(x_out[1:]), torch.stack(y_out)

In [125]:
def run_SSM(A, B, C, u):
    L = u.shape[0]
    N = A.shape[0]
    Ab, Bb, Cb = discretize(A, B, C, step=1.0 / L)
    # Run recurrence
    return scan_SSM(Ab, Bb, Cb, u[:, np.newaxis], np.zeros(N,))[1]

# example of calculting ode in mechanics

In [126]:
def example_force(t):
    x = np.sin(10 * t)
    return x * (x > 0.5)

In [127]:
def example_mass(k, b, m):
    A = np.array([[0, 1], [-k / m, -b / m]])
    B = np.array([[0], [1.0 / m]])
    C = np.array([[1.0, 0]])
    return A, B, C

In [130]:
def example_ssm():
    # SSM
    ssm = example_mass(k=40, b=5, m=1)

    # L samples of u(t).
    L = 100
    step = 1.0 / L
    ks = np.arange(L)
    u = example_force(ks * step)

    # Approximation of y(t).
    y = run_SSM(*ssm, u)
    print(y.shape)

In [131]:
example_ssm()

torch.Size([100, 1])


  Ab = torch.tensor(Ab)
  Bb = torch.tensor(Bb)
  x_k_1 = torch.tensor(x_k_1)


In [139]:
np.pad(np.arange(10), (0, 20))

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0])

In [145]:
np.random.rand(10,).shape

(10,)

In [148]:
torch.randn(10,)

tensor([-0.5966,  0.1820, -0.8567,  1.1006, -1.0712,  0.1227, -0.5663,  0.3731,
        -0.8920, -1.5091])

In [149]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy input tensor of shape [16]
input_tensor = torch.randn(16)

# Expand to [1, 1, 16] for convolution
expanded_input = input_tensor.unsqueeze(0).unsqueeze(0)

print(f'Expanded input shape: {expanded_input.shape}')  # Should be [1, 1, 16]

# Define a convolution layer with the kernel of size (16, 1, 1)
conv_layer = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=16)

# Apply convolution
output = conv_layer(expanded_input)

print(f'Output shape: {output.shape}')


Expanded input shape: torch.Size([1, 1, 16])
Output shape: torch.Size([1, 1, 1])


In [173]:
a = torch.randn(16).view(4, -1)

In [174]:
a = a.triu(1)
a

tensor([[ 0.0000, -1.1991, -0.0257,  1.8024],
        [ 0.0000,  0.0000, -0.5687, -0.4755],
        [ 0.0000,  0.0000,  0.0000,  1.2937],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])

In [176]:
a - a.transpose(-1, -2)

tensor([[ 0.0000, -1.1991, -0.0257,  1.8024],
        [ 1.1991,  0.0000, -0.5687, -0.4755],
        [ 0.0257,  0.5687,  0.0000,  1.2937],
        [-1.8024,  0.4755, -1.2937,  0.0000]])

In [196]:
torch.dot                            # [D], [D] -> []
batched_dot = torch.vmap(torch.dot, in_dims=0)  # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)

tensor([0.5583, 0.4664])

In [197]:
x[0].dot(y[0])

tensor(0.5583)

In [198]:
x[1].dot(y[1]) 

tensor(0.4664)

In [212]:
layer = nn.Linear(10, 1)
linear_new = torch.vmap(layer, in_dims=0)

In [213]:
linear_old = nn.ModuleList([layer for i in range(2)])

In [214]:
input = torch.randn(5, 10)

x = []
for lin in linear_old:
    x.append(lin(input))
x

[tensor([[-0.4168],
         [ 0.5909],
         [-0.6702],
         [ 0.8237],
         [-0.0082]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4168],
         [ 0.5909],
         [-0.6702],
         [ 0.8237],
         [-0.0082]], grad_fn=<AddmmBackward0>)]

In [216]:
linear_new(input)

tensor([[-0.4168],
        [ 0.5909],
        [-0.6702],
        [ 0.8237],
        [-0.0082]], grad_fn=<AddBackward0>)

In [228]:
input = torch.arange(20).view(2, 2, 5) # B x Samples x features
input

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9]],

        [[10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]]])

In [221]:
x = list()
for i in input:
    x.append(i)

In [229]:
torch.stack(x).shape

torch.Size([5, 2, 3])

In [232]:
torch.sum(input, dim=-1)

tensor([[10, 35],
        [60, 85]])

In [241]:
def generate_look_ahead_mask(size):
    # old: mask = torch.triu(torch.ones(size, size) * float('-inf'), diagonal=1)
    # new from: https://discuss.pytorch.org/t/attn-mask-in-nn-multiheadattention/173603/3
    # seems like we got the attention mask wrong, it should be -inf for the future tokens
    # so we need to set the diagonal to 0 and the upper triangle to -inf
    arr = [[-np.inf for _ in range(size)] for _ in range(size)]
    arr = torch.tensor(arr)
    mask = torch.triu(arr, diagonal=1)
    return mask

mask = generate_look_ahead_mask(5)

In [242]:
torch.softmax(mask, dim=-1)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])

In [243]:
mask

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [244]:
src = torch.tensor([13,  3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82, 89,
         1, 70, 79, 82,  1, 65,  1, 70, 73, 69, 76, 68,  1, 84, 82, 73, 80, 13,
         1, 34,  1, 84, 79, 85, 82,  1, 71, 85, 73, 68, 69,  1, 83, 72, 79, 87,
        83,  1, 46, 82, 15,  1, 40, 65, 82, 82]) 

tgt = torch.tensor([ 3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82, 89,  1,
        70, 79, 82,  1, 65,  1, 70, 73, 69, 76, 68,  1, 84, 82, 73, 80, 13,  1,
        34,  1, 84, 79, 85, 82,  1, 71, 85, 73, 68, 69,  1, 83, 72, 79, 87, 83,
         1, 46, 82, 15,  1, 40, 65, 82, 82, 73])

In [277]:
torch.cat((src, tgt[-1:None]))

tensor([13,  3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82, 89,
         1, 70, 79, 82,  1, 65,  1, 70, 73, 69, 76, 68,  1, 84, 82, 73, 80, 13,
         1, 34,  1, 84, 79, 85, 82,  1, 71, 85, 73, 68, 69,  1, 83, 72, 79, 87,
        83,  1, 46, 82, 15,  1, 40, 65, 82, 82, 73])

In [287]:
m = torch.tensor(np.array([(src, tgt), (src, tgt)]))

In [307]:
torch.cat((m[:, 0, :], m[:, 1, -1:None]), dim=-1)

tensor([[13,  3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82, 89,
          1, 70, 79, 82,  1, 65,  1, 70, 73, 69, 76, 68,  1, 84, 82, 73, 80, 13,
          1, 34,  1, 84, 79, 85, 82,  1, 71, 85, 73, 68, 69,  1, 83, 72, 79, 87,
         83,  1, 46, 82, 15,  1, 40, 65, 82, 82, 73],
        [13,  3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82, 89,
          1, 70, 79, 82,  1, 65,  1, 70, 73, 69, 76, 68,  1, 84, 82, 73, 80, 13,
          1, 34,  1, 84, 79, 85, 82,  1, 71, 85, 73, 68, 69,  1, 83, 72, 79, 87,
         83,  1, 46, 82, 15,  1, 40, 65, 82, 82, 73]])

In [294]:
m[:, 0, :]

tensor([[13,  3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82, 89,
          1, 70, 79, 82,  1, 65,  1, 70, 73, 69, 76, 68,  1, 84, 82, 73, 80, 13,
          1, 34,  1, 84, 79, 85, 82,  1, 71, 85, 73, 68, 69,  1, 83, 72, 79, 87,
         83,  1, 46, 82, 15,  1, 40, 65, 82, 82],
        [13,  3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82, 89,
          1, 70, 79, 82,  1, 65,  1, 70, 73, 69, 76, 68,  1, 84, 82, 73, 80, 13,
          1, 34,  1, 84, 79, 85, 82,  1, 71, 85, 73, 68, 69,  1, 83, 72, 79, 87,
         83,  1, 46, 82, 15,  1, 40, 65, 82, 82]])

In [295]:
m

tensor([[[13,  3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82,
          89,  1, 70, 79, 82,  1, 65,  1, 70, 73, 69, 76, 68,  1, 84, 82, 73,
          80, 13,  1, 34,  1, 84, 79, 85, 82,  1, 71, 85, 73, 68, 69,  1, 83,
          72, 79, 87, 83,  1, 46, 82, 15,  1, 40, 65, 82, 82],
         [ 3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82, 89,
           1, 70, 79, 82,  1, 65,  1, 70, 73, 69, 76, 68,  1, 84, 82, 73, 80,
          13,  1, 34,  1, 84, 79, 85, 82,  1, 71, 85, 73, 68, 69,  1, 83, 72,
          79, 87, 83,  1, 46, 82, 15,  1, 40, 65, 82, 82, 73]],

        [[13,  3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82,
          89,  1, 70, 79, 82,  1, 65,  1, 70, 73, 69, 76, 68,  1, 84, 82, 73,
          80, 13,  1, 34,  1, 84, 79, 85, 82,  1, 71, 85, 73, 68, 69,  1, 83,
          72, 79, 87, 83,  1, 46, 82, 15,  1, 40, 65, 82, 82],
         [ 3, 34, 84,  1, 70, 85, 68, 71, 69,  1, 70, 65, 67, 84, 79, 82, 89,
           1, 70, 79, 82,  1,