In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

In [2]:
# Set the seed for reproducibility
torch.manual_seed(0)

<torch._C.Generator at 0x2b07eb2c190>

## Let's Compute the Exponential Moving Average (EMA) of Matrix $ R $
$$
\alpha_t = \tanh \left( W_\alpha \cdot \left( r_{t-1} \| r_t \right) + b_\alpha \right)
$$

$$
r_t^{EMA} = \alpha_t \odot r_t + \left( (1-\alpha_t) \odot \delta_t \right) \odot r_{t-1}
$$

## Using `for` Loop

In [3]:

# Define the dimensions
batch_size = 10  # Example batch size
input_dim = 5    # Dimension of the input vectors
hidden_dim = 5   # Dimension of the hidden vectors (can be different from input_dim)

# Define the input matrices
R = torch.randn(batch_size, input_dim)  # Example input matrix R (batch_size, input_dim)

# Initialize the parameters
W_alpha = torch.randn(hidden_dim, 2 * input_dim)  # Weight matrix for alpha (hidden_dim, 2*input_dim)
b_alpha = torch.randn(hidden_dim)                 # Bias vector for alpha (hidden_dim)

# Initialize delta_t, assuming delta_t has the same dimension as r_t
delta_t = torch.randn(batch_size, input_dim)  # Example delta_t (batch_size, input_dim)

# Function to compute alpha_t and r_t^EMA for each time step
def compute_r_t_EMA(R, W_alpha, b_alpha, delta_t):
    batch_size, input_dim = R.shape
    r_t_EMA = torch.zeros_like(R)  # Initialize the EMA matrix
    R = torch.cat((torch.zeros(1, input_dim), R), dim=0)
    batch_size, input_dim = R.shape
    for t in range(1, batch_size):
        r_t_prev = R[t-1]  # r_{t-1}
        r_t = R[t]         # r_t

        # Compute alpha_t
        concat_r = torch.cat((r_t_prev, r_t), dim=-1)  # Concatenate r_{t-1} and r_t
        alpha_t = torch.tanh(F.linear(concat_r, W_alpha, b_alpha))  # Compute alpha_t

        # Compute r_t^EMA
        r_t_EMA[t-1] = alpha_t * r_t + (1 - alpha_t) * delta_t[t-1] * r_t_prev

    return r_t_EMA

# Compute the EMA
r_t_EMA = compute_r_t_EMA(R, W_alpha, b_alpha, delta_t)

# Print the results
print("R:")
print(R)
print("\nr_t^EMA:")
print(r_t_EMA)

R:
tensor([[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487],
        [ 0.6920, -0.3160, -2.1152,  0.3223, -1.2633],
        [ 0.3500,  0.3081,  0.1198,  1.2377,  1.1168],
        [-0.2473, -1.3527, -1.6959,  0.5667,  0.7935],
        [ 0.5988, -1.5551, -0.3414,  1.8530,  0.7502],
        [-0.5855, -0.1734,  0.1835,  1.3894,  1.5863],
        [ 0.9463, -0.8437, -0.6136,  0.0316,  1.0554],
        [ 0.1778, -0.2303, -0.3918,  0.5433, -0.3952],
        [ 0.2055, -0.4503,  1.5210,  3.4105, -1.5312],
        [-1.2341,  1.8197, -0.5515, -1.3253,  0.1886]])

r_t^EMA:
tensor([[ 1.1255, -1.1315, -0.2501, -0.4319, -0.8481],
        [-1.5625, -0.2751,  0.2855,  0.3203, -1.2620],
        [ 0.3421,  0.0778,  0.1201, -0.9087, -0.4462],
        [ 0.5169, -1.3516, -1.6827,  0.5665, -0.9533],
        [-1.0922, -1.5451, -0.4545,  1.8512,  0.7349],
        [-0.1794, -0.1532, -0.0430,  1.3810,  1.5742],
        [-1.2180, -0.8367,  0.3979,  0.0316,  1.0524],
        [ 0.5649, -0.1781, -0.4797,  0.1974, -0.360

## Using Matrix Operations for Efficient Computation
Converting the equations to matrix form allows PyTorch to efficiently compute gradients during backpropagation. This approach leverages batch processing, which is computationally efficient and essential for learning weight matrices.

In [4]:


# Define the dimensions
# batch_size = 10  # Example batch size
# input_dim = 5    # Dimension of the input vectors
# hidden_dim = 5   # Dimension of the hidden vectors (can be different from input_dim)

# Define the input matrices
# R = torch.randn(batch_size, input_dim)  # Example input matrix R (batch_size, input_dim)

# Initialize the parameters
# W_alpha = torch.randn(hidden_dim, 2 * input_dim)  # Weight matrix for alpha (hidden_dim, 2*input_dim)
# b_alpha = torch.randn(hidden_dim)                 # Bias vector for alpha (hidden_dim)

# Initialize delta_t, assuming delta_t has the same dimension as r_t
# delta_t = torch.randn(batch_size, input_dim)  # Example delta_t (batch_size, input_dim)

# Function to compute alpha_t and r_t^EMA for each time step
def compute_r_t_EMA_matrix(R, W_alpha, b_alpha, delta_t):
    batch_size, input_dim = R.shape

    R = torch.cat((torch.zeros(1, input_dim), R), dim=0)
    batch_size, input_dim = R.shape
    # Prepare r_t and r_{t-1} matrices
    r_t = R[1:]                # Exclude the first row
    r_t_prev = R[:-1]          # Exclude the last row

    
    # Compute alpha_t
    concat_r = torch.cat((r_t_prev, r_t), dim=1)  # Concatenate r_{t-1} and r_t along the feature dimension
    alpha_t = torch.tanh(F.linear(concat_r, W_alpha, b_alpha))  # Compute alpha_t

    # Compute r_t^EMA
    r_t_EMA = alpha_t * r_t + (1 - alpha_t) * delta_t * r_t_prev

    return r_t_EMA

# Compute the EMA
r_t_EMA = compute_r_t_EMA_matrix(R, W_alpha, b_alpha, delta_t)

# Print the results
print("R:")
print(R)
print("\nr_t^EMA:")
print(r_t_EMA)


R:
tensor([[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487],
        [ 0.6920, -0.3160, -2.1152,  0.3223, -1.2633],
        [ 0.3500,  0.3081,  0.1198,  1.2377,  1.1168],
        [-0.2473, -1.3527, -1.6959,  0.5667,  0.7935],
        [ 0.5988, -1.5551, -0.3414,  1.8530,  0.7502],
        [-0.5855, -0.1734,  0.1835,  1.3894,  1.5863],
        [ 0.9463, -0.8437, -0.6136,  0.0316,  1.0554],
        [ 0.1778, -0.2303, -0.3918,  0.5433, -0.3952],
        [ 0.2055, -0.4503,  1.5210,  3.4105, -1.5312],
        [-1.2341,  1.8197, -0.5515, -1.3253,  0.1886]])

r_t^EMA:
tensor([[ 1.1255, -1.1315, -0.2501, -0.4319, -0.8481],
        [-1.5625, -0.2751,  0.2855,  0.3203, -1.2620],
        [ 0.3421,  0.0778,  0.1201, -0.9087, -0.4462],
        [ 0.5169, -1.3516, -1.6827,  0.5665, -0.9533],
        [-1.0922, -1.5451, -0.4545,  1.8512,  0.7349],
        [-0.1794, -0.1532, -0.0430,  1.3810,  1.5742],
        [-1.2180, -0.8367,  0.3979,  0.0316,  1.0524],
        [ 0.5649, -0.1781, -0.4797,  0.1974, -0.360

### Step By Step

In [5]:
R = torch.cat((torch.zeros(1, input_dim), R), dim=0)
r_t = R[1:]                # Exclude the first row
r_t_prev = R[:-1] 
concat_r = torch.cat((r_t_prev, r_t), dim=1)  # Concatenate r_{t-1} and r_t along the feature dimension
print(concat_r)

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.1258, -1.1524, -0.2506,
         -0.4339,  0.8487],
        [-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160, -2.1152,
          0.3223, -1.2633],
        [ 0.6920, -0.3160, -2.1152,  0.3223, -1.2633,  0.3500,  0.3081,  0.1198,
          1.2377,  1.1168],
        [ 0.3500,  0.3081,  0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959,
          0.5667,  0.7935],
        [-0.2473, -1.3527, -1.6959,  0.5667,  0.7935,  0.5988, -1.5551, -0.3414,
          1.8530,  0.7502],
        [ 0.5988, -1.5551, -0.3414,  1.8530,  0.7502, -0.5855, -0.1734,  0.1835,
          1.3894,  1.5863],
        [-0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463, -0.8437, -0.6136,
          0.0316,  1.0554],
        [ 0.9463, -0.8437, -0.6136,  0.0316,  1.0554,  0.1778, -0.2303, -0.3918,
          0.5433, -0.3952],
        [ 0.1778, -0.2303, -0.3918,  0.5433, -0.3952,  0.2055, -0.4503,  1.5210,
          3.4105, -1.5312],
        [ 0.2055, -

In [6]:
alpha_t = torch.tanh(F.linear(concat_r, W_alpha, b_alpha))  # Compute alpha_t

In [7]:
alpha_t

tensor([[-0.9997,  0.9819,  0.9981,  0.9954, -0.9993],
        [-1.0000,  0.9253, -0.1193,  0.9950,  0.9978],
        [ 0.9953,  0.3407,  0.9998, -0.8890, -0.1880],
        [-0.9998,  0.9994,  0.9916,  0.9993, -0.9125],
        [-0.9995,  0.9955,  0.8808,  0.9988,  0.9694],
        [ 0.4306,  0.9438,  0.2610,  0.9958,  0.9950],
        [-0.9905,  0.9936, -0.6817,  1.0000,  0.9927],
        [-0.9795,  0.3898,  0.3966,  0.3738,  0.9761],
        [ 0.9746,  0.9999, -0.0159, -0.9929,  0.9998],
        [ 1.0000,  0.6435,  0.8557,  0.7227,  0.9969]])

In [8]:
delta_t

tensor([[-0.2596,  0.1183,  0.2440,  1.1646,  0.2886],
        [ 0.3866, -0.2011, -0.1179,  0.1922, -0.7722],
        [-1.9003,  0.1307, -0.7043,  0.3147,  0.1574],
        [ 0.3854,  0.9671, -0.9911,  0.3016, -0.1073],
        [ 0.9985, -0.4987,  0.7611,  0.6183,  0.3140],
        [ 0.2133, -0.1201,  0.3605, -0.3140, -1.0787],
        [ 0.2408, -1.3962, -0.0661, -0.3584,  0.4069],
        [ 0.3946,  0.1715,  0.8760, -0.2871,  1.0216],
        [-0.5111, -1.7137,  0.3920,  0.5945,  0.6623],
        [-1.2063,  0.6074, -0.5472, -1.1005, -0.7201]])

In [9]:
alpha_t.shape

torch.Size([10, 5])

In [10]:
delta_t.shape

torch.Size([10, 5])

In [11]:
alpha_t * delta_t

tensor([[ 0.2596,  0.1162,  0.2435,  1.1593, -0.2884],
        [-0.3866, -0.1861,  0.0141,  0.1912, -0.7704],
        [-1.8914,  0.0445, -0.7041, -0.2798, -0.0296],
        [-0.3853,  0.9665, -0.9828,  0.3014,  0.0979],
        [-0.9979, -0.4965,  0.6704,  0.6176,  0.3045],
        [ 0.0919, -0.1133,  0.0941, -0.3127, -1.0733],
        [-0.2385, -1.3872,  0.0451, -0.3584,  0.4039],
        [-0.3865,  0.0669,  0.3474, -0.1073,  0.9973],
        [-0.4981, -1.7135, -0.0062, -0.5903,  0.6622],
        [-1.2063,  0.3909, -0.4682, -0.7953, -0.7179]])

In [12]:
r_t.shape

torch.Size([10, 5])

In [13]:
r_t_prev.shape

torch.Size([10, 5])

In [14]:
(1 - alpha_t) * delta_t

tensor([[-5.1921e-01,  2.1376e-03,  4.6026e-04,  5.3371e-03,  5.7694e-01],
        [ 7.7320e-01, -1.5012e-02, -1.3200e-01,  9.5797e-04, -1.7121e-03],
        [-8.9785e-03,  8.6150e-02, -1.4424e-04,  5.9451e-01,  1.8698e-01],
        [ 7.7065e-01,  6.1330e-04, -8.2972e-03,  2.2502e-04, -2.0524e-01],
        [ 1.9964e+00, -2.2317e-03,  9.0716e-02,  7.2296e-04,  9.5949e-03],
        [ 1.2147e-01, -6.7487e-03,  2.6639e-01, -1.3341e-03, -5.4461e-03],
        [ 4.7934e-01, -8.9921e-03, -1.1123e-01, -1.1534e-06,  2.9686e-03],
        [ 7.8108e-01,  1.0465e-01,  5.2862e-01, -1.7978e-01,  2.4370e-02],
        [-1.3002e-02, -2.0807e-04,  3.9825e-01,  1.1848e+00,  1.1159e-04],
        [-1.8694e-06,  2.1654e-01, -7.8929e-02, -3.0521e-01, -2.2258e-03]])

In [15]:
# Awesome

In [16]:
R_EMA = compute_r_t_EMA_matrix(R, W_alpha, b_alpha, delta_t)

# Print the results
print("R:")
print(R)
print("R^EMA:")
print(R_EMA)

RuntimeError: The size of tensor a (11) must match the size of tensor b (10) at non-singleton dimension 0

In [None]:
# Define dimensions
input_dim = 5# Define the input dimension
at_dim = 12 # Define the hidden dimension

# Initialize parameters randomly
W_z = torch.nn.Parameter(torch.randn(at_dim, input_dim))
b_z = torch.nn.Parameter(torch.randn(at_dim))

W_q = torch.nn.Parameter(torch.randn(at_dim, at_dim))
b_q = torch.nn.Parameter(torch.randn(at_dim))

W_k = torch.nn.Parameter(torch.randn(at_dim, at_dim))
b_k = torch.nn.Parameter(torch.randn(at_dim))

W_v = torch.nn.Parameter(torch.randn(at_dim, at_dim))
b_v = torch.nn.Parameter(torch.randn(at_dim))

# Now you can use these parameters in your code
# R_EMA = compute_r_t_EMA_matrix(R, W_alpha, b_alpha, delta_t)

# Print the results
print("R:")
print(R)
print("R^EMA:")
print(R_EMA)

# Calculate Z using linear transformation
silu = nn.SiLU()
Z = F.linear(R_EMA, W_z, b_z)
Z = silu(Z)

# Calculate Q, K, and V using linear transformations of Z
Q = F.linear(Z, W_q, b_q)
K = F.linear(Z, W_k, b_k)
V = F.linear(Z, W_v, b_v)

# Calculate attention scores and weights
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(at_dim)
attention_weights = F.softmax(attention_scores, dim=-1)

# Apply attention weights to values
Z_at = torch.matmul(attention_weights, V)

# Print the result of attention
print("Z_at:")
print(Z_at)
print(Z_at.shape)

In [17]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Generate input values
x = torch.linspace(-5, 5, 100)

# Apply SiLU function to input values
silu = nn.SiLU()
y = silu(x)

# Plot the graph
plt.plot(x.numpy(), y.numpy(), label='SiLU')
plt.title('SiLU Activation Function')
plt.xlabel('Input')
plt.ylabel('Output')
plt.legend()
plt.grid(True)
plt.show()

ModuleNotFoundError: No module named 'matplotlib'

## MEGA Decoder

In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MEGADecoder(nn.Module):
    def __init__(self, r_dim, r_size, p_dim, at_dim, o_dim):
        super(MEGADecoder, self).__init__()
        self.r_dim = r_dim
        self.r_size = r_size
        self.p_dim = p_dim
        self.at_dim = at_dim
        self.o_dim = o_dim

        self.W_alpha = nn.Parameter(torch.Tensor(r_dim, r_dim * 2))
        self.b_alpha = nn.Parameter(torch.Tensor(r_dim))

        self.W_delta = nn.Parameter(torch.Tensor(r_dim, r_dim * 2))
        self.b_delta = nn.Parameter(torch.Tensor(r_dim))

        self.W_EMA = nn.Parameter(torch.Tensor(p_dim, r_dim))
        self.b_EMA = nn.Parameter(torch.Tensor(p_dim))
    
        self.W_q = nn.Parameter(torch.Tensor(at_dim, p_dim))
        self.b_q = nn.Parameter(torch.Tensor(at_dim))
    
        self.W_k = nn.Parameter(torch.Tensor(at_dim, p_dim))
        self.b_k = nn.Parameter(torch.Tensor(at_dim))
    
        self.W_v = nn.Parameter(torch.Tensor(at_dim, p_dim))
        self.b_v = nn.Parameter(torch.Tensor(at_dim))

        self.W_f = nn.Parameter(torch.Tensor(at_dim, p_dim))
        self.b_f = nn.Parameter(torch.Tensor(at_dim))

        self.W_EMA_c = nn.Parameter(torch.Tensor(p_dim, r_dim))
        self.W_z_C = nn.Parameter(torch.Tensor(at_dim, r_dim))
        
        self.b_C = nn.Parameter(torch.Tensor(r_dim))

        self.W_i = nn.Parameter(torch.Tensor(1, p_dim))
        self.b_i = nn.Parameter(torch.Tensor(1))

        self.W_o = nn.Parameter(torch.Tensor(o_dim, r_dim))
        self.b_o = nn.Parameter(torch.Tensor(o_dim))

        
    
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W_alpha)
        nn.init.xavier_uniform_(self.W_delta)
        nn.init.xavier_uniform_(self.W_EMA)
        nn.init.xavier_uniform_(self.W_q)
        nn.init.xavier_uniform_(self.W_k)
        nn.init.xavier_uniform_(self.W_v)
        nn.init.xavier_uniform_(self.W_f)
        nn.init.xavier_uniform_(self.W_EMA_c)
        nn.init.xavier_uniform_(self.W_z_C)
        nn.init.xavier_uniform_(self.W_i)
        nn.init.xavier_uniform_(self.W_o)

        nn.init.zeros_(self.b_alpha)
        nn.init.zeros_(self.b_delta)
        nn.init.zeros_(self.b_EMA)
        nn.init.zeros_(self.b_q)
        nn.init.zeros_(self.b_k)
        nn.init.zeros_(self.b_v)
        nn.init.zeros_(self.b_f)
        nn.init.zeros_(self.b_C)
        nn.init.zeros_(self.b_i)
        nn.init.zeros_(self.b_o)
        
    def forward(self, R):
        R_1 = torch.cat((torch.zeros(1, self.r_dim), R), dim=0)
        m_size, _ = R.shape # m_size : size of the modified R (r_size + 1)

        r_t = R_1[1:]
        r_t_prev = R_1[:-1]

        concat_r = torch.cat((r_t_prev, r_t), dim=1)
        alpha_t = torch.tanh(F.linear(concat_r, self.W_alpha, self.b_alpha))  # Compute alpha_t
        delta_t = torch.tanh(F.linear(concat_r, self.W_delta, self.b_delta))  # Compute delta_t

        R_EMA = alpha_t * r_t + (1 - alpha_t) * delta_t * r_t_prev

        silu = nn.SiLU()

        R_EMA_prime = F.linear(R_EMA, self.W_EMA, self.b_EMA)
        R_EMA_prime = silu(R_EMA_prime)
        
        Q = F.linear(R_EMA_prime, self.W_q, self.b_q)
        K = F.linear(R_EMA_prime, self.W_k, self.b_k)
        V = F.linear(R_EMA_prime, self.W_v, self.b_v)

        at_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.at_dim))
        at_weights = F.softmax(at_scores, dim=-1)
        
        Z_EMA = torch.matmul(at_weights, V)

        f = torch.sigmoid(F.linear(R_EMA_prime, self.W_f, self.b_f))
        
        Z_EMA_f = f * Z_EMA

        Z_EMA_C = torch.matmul(R_EMA_prime, self.W_EMA_c) + torch.matmul(Z_EMA_f, self.W_z_C) + self.b_C
        Z_EMA_C = silu(Z_EMA_C)

        i = torch.sigmoid(F.linear(R_EMA_prime, self.W_i, self.b_i))
        R_h = i * Z_EMA_C + (1 - i) * R
        
        r_cap = F.linear(R_h, self.W_o, self.b_o)
        p_cap = F.softmax(r_cap, dim=0)
        
        return p_cap
        
    def autoregressive_forward(self, R, seq_len):
        outputs = []
        for t in range(seq_len):
            p_cap = self.forward(R)
            outputs.append(p_cap)
            R = torch.cat((R, p_cap), dim=0)  # Update input with the new output
            R = R[1:]  # Keep the input sequence length constant

        outputs = torch.stack(outputs, dim=0)
        return outputs

In [26]:
import torch

# Define the dimensions
r_dim = 2
r_size = 10
p_dim = 8
at_dim = 6
o_dim = 2

# Create an instance of the MEGADecoder model
model = MEGADecoder(r_dim, r_size, p_dim, at_dim, o_dim)

# Create some random input data R
R = torch.randn(r_size, r_dim)

# Define the length of the sequence to generate
seq_len = 10

# Autoregressive forward pass through the model
output = model.forward(R)

print("Autoregressive output shape:", output.shape)
print("Autoregressive output:", output)


Autoregressive output shape: torch.Size([10, 2])
Autoregressive output: tensor([[0.0458, 0.0320],
        [0.0569, 0.0418],
        [0.3356, 0.4229],
        [0.0443, 0.0322],
        [0.0654, 0.0431],
        [0.0667, 0.0585],
        [0.0639, 0.0478],
        [0.0619, 0.0427],
        [0.1266, 0.1669],
        [0.1329, 0.1122]], grad_fn=<SoftmaxBackward0>)


In [25]:
R

tensor([[ 1.9533, -0.4719],
        [ 1.5116, -0.1124],
        [-0.3787,  1.5686],
        [ 1.9994, -1.1063],
        [-1.2984,  0.6461],
        [-0.0882, -1.0636],
        [-0.3504, -0.4489],
        [ 0.0429,  0.3975],
        [-0.9352,  0.2816],
        [-2.0988,  2.7141]])

In [19]:

r_dim = 2
r_size = 10
z_dim = 8
at_dim = 6
g_dim = 4


# Create some random input data R and damp
R = torch.randn(r_size, r_dim)
damp = torch.rand_like(R)  # Create damp matrix with the same size as R


print(R)

R = torch.cat((torch.zeros(1, r_dim), R), dim=0)
m_size, _ = R.shape # m_size : size of the modified R (r_size + 1)
print(R)
r_t = R[1:]
r_t_prev = R[:-1]

print(r_t.shape)
print(r_t_prev.shape)


W_alpha = torch.randn(r_dim, r_dim * 2)
print(W_alpha.shape)
b_alpha = torch.randn(r_dim)
print(b_alpha.shape)
concat_r = torch.cat((r_t_prev, r_t), dim=1)
print(concat_r.shape)
alpha_t = torch.tanh(F.linear(concat_r, W_alpha, b_alpha))  # Compute alpha_t

tensor([[ 1.4998,  0.5431],
        [ 0.4865,  0.6227],
        [ 0.9738,  0.7655],
        [ 1.2955,  0.8909],
        [-0.4898, -1.1727],
        [-0.6870, -2.3349],
        [ 0.1581,  0.1000],
        [-0.0595,  2.0118],
        [-0.3368,  0.3260],
        [ 0.5352,  1.9733]])
tensor([[ 0.0000,  0.0000],
        [ 1.4998,  0.5431],
        [ 0.4865,  0.6227],
        [ 0.9738,  0.7655],
        [ 1.2955,  0.8909],
        [-0.4898, -1.1727],
        [-0.6870, -2.3349],
        [ 0.1581,  0.1000],
        [-0.0595,  2.0118],
        [-0.3368,  0.3260],
        [ 0.5352,  1.9733]])
torch.Size([10, 2])
torch.Size([10, 2])
torch.Size([2, 4])
torch.Size([2])
torch.Size([10, 4])


## Testing

In [20]:
import sys
sys.path.append('..')

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [22]:
from src import MEGADecoder

ImportError: cannot import name 'RCell' from 'src.r_cell' (D:\KDM\Perducer-v1.0\expm\..\src\r_cell.py)

In [23]:
# Define the dimensions
r_dim = 2
r_size = 10
z_dim = 8
at_dim = 6
g_dim = 4

# Create an instance of the MEGADecoder model
model = MEGADecoder(r_dim, r_size, z_dim, at_dim, g_dim)

# Create some random input data R and damp
    R = torch.randn(r_size, r_dim)
damp = torch.rand_like(R)  # Create damp matrix with the same size as R

# Forward pass through the model
output = model(R, damp)

print("Output shape:", output.shape)
print("Output:", output)

IndentationError: unexpected indent (2227218485.py, line 12)

Resources:
https://dmytro-kuzmenko.medium.com/mega-attention-breakdown-8b1b56cf715f

Note: When adding two matrices, it's crucial to align their dimensions to facilitate addition. The term 'g_dim' represents the ground dimension, serving as a reference for ensuring uniformity before performing the addition operation.