In [1]:
import torch

$$x_i\in \mathbb{R}^{31}$$

In [2]:
x = torch.normal(0, 1, [31])

In [3]:
x.shape

torch.Size([31])

$$h\in \mathbb{R}^{100\times20\times31}$$

In [4]:
h = torch.normal(0, 1, [100, 20, 31])

In [5]:
g = torch.matmul(h, x)
g = torch.heaviside(g, torch.tensor(0.))

$$g = h\cdot x\in\mathbb{R}^{100\times20}$$

In [6]:
g.shape

torch.Size([100, 20])

$$w\in\mathbb{R}^{100\times20\times31}$$

In [7]:
w = torch.normal(0, 1, [100, 20, 31]) 

$$ew = g\cdot w\in\mathbb{R}^{100\times100\times31}$$
$$ew_{sum} \in\mathbb{R}^{100\times31}$$

In [8]:
ew = torch.matmul(g, w)

In [9]:
ew.shape

torch.Size([100, 100, 31])

In [10]:
ew_sum = ew.sum(dim=1)

In [11]:
ew_sum.shape

torch.Size([100, 31])

$$r_{output} = ew_{sum}\cdot x\in\mathbb{R}^{100}$$

In [12]:
r_output = torch.matmul(ew_sum, x)

In [13]:
r_output.shape

torch.Size([100])

In [14]:
target = 1.

In [15]:
grad = (r_output[:, None] - target) * x[None]

In [16]:
grad.shape

torch.Size([100, 31])

In [17]:
(r_output[:, None] - target).shape

torch.Size([100, 1])

In [18]:
learning_rate = 0.01

In [19]:
w -= learning_rate * g[:, :, None] * grad[:, None]

In [20]:
w.shape

torch.Size([100, 20, 31])

---

In [21]:
x = torch.normal(0, 1, [31, 455])

In [22]:
x.shape

torch.Size([31, 455])

In [23]:
g = torch.matmul(h, x)

In [24]:
g.shape

torch.Size([100, 20, 455])

In [42]:
g[:, :, 0].shape

torch.Size([100, 20])

In [25]:
w.shape

torch.Size([100, 20, 31])

In [63]:
a = torch.normal(0, 1, [100, 20, 455])
b = torch.normal(0, 1, [100, 20, 31])

res = []
for i in range(a.shape[2]):
    a_i = a[:, :, i]
    r = torch.matmul(a_i, b)
    res.append(r.sum(dim=1))

final_res = torch.stack(res)

In [64]:
final_res.shape

torch.Size([455, 100, 31])

In [65]:
# Create random tensors
a = torch.normal(0, 1, [100, 20, 455])
b = torch.normal(0, 1, [100, 20, 31])

# Use torch.einsum to perform the batched matrix multiplication and summation
# 'ijk,ijl->ikl' means:
# i = batch dimension
# j = shared dimension for matrix multiplication
# k = output dimensions
# The resulting tensor will have dimensions (455, 100, 31)
res = torch.einsum('ijk,ijl->ikl', a, b).sum(dim=1)

# Transpose the result to match the desired final shape (455, 100)
final_res = res.transpose(0, 1)

print(final_res.shape)

torch.Size([31, 100])


In [68]:
torch.einsum('ijk,ijl->ikl', a, b).shape

torch.Size([100, 455, 31])

In [26]:
ew = torch.tensordot(w, g, dims=([1], [1]))

In [27]:
ew.shape

torch.Size([100, 31, 100, 455])

In [28]:
ew_sum = ew.sum(dim=2)

In [29]:
ew_sum.shape

torch.Size([100, 31, 455])

In [30]:
x.shape

torch.Size([31, 455])

In [31]:
ew_sum = torch.permute(ew_sum,[1, 2, 0])
ew_sum.shape

torch.Size([31, 455, 100])

In [32]:
r = torch.tensordot(ew_sum, x, dims=([0, 1], [0, 1]))

In [33]:
r.shape

torch.Size([100])

In [1]:
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Instantiate the model
model = SimpleModel()

# Access the weights of all layers
for name, param in model.named_parameters():
    if 'weight' in name:
        print(f"Layer: {name}, Weight: {param}")


Layer: fc1.weight, Weight: Parameter containing:
tensor([[ 0.0616, -0.3055, -0.3115,  0.1766,  0.1004,  0.0869,  0.2336, -0.1653,
         -0.0275,  0.0825],
        [ 0.0922,  0.0576,  0.0327,  0.2759, -0.0082, -0.1777, -0.1555,  0.1495,
          0.2586, -0.0996],
        [ 0.1725,  0.2585, -0.2306, -0.2004,  0.2054,  0.2606,  0.0105,  0.1917,
         -0.0732, -0.1815],
        [ 0.1846,  0.2444, -0.1553, -0.2934, -0.0498, -0.0057,  0.2718, -0.1875,
         -0.0040, -0.0633],
        [ 0.0047,  0.2079,  0.2680,  0.3126, -0.2563,  0.2392, -0.1924, -0.2330,
          0.2164, -0.2386]], requires_grad=True)
Layer: fc2.weight, Weight: Parameter containing:
tensor([[-0.0114, -0.4244,  0.1671, -0.2300,  0.4070],
        [-0.1300, -0.3973, -0.0941, -0.1816,  0.1613]], requires_grad=True)
