In [90]:
import torch
import torch.nn as nn
import torch.optim as optim

### Tensors

In [91]:
# Draw a 3x4 random tensor and immediately replace it with an all-ones tensor
# that inherits its shape and dtype.
t = torch.ones_like(torch.rand(3, 4))
print(t)

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])


In [92]:
# Build a 2x3 tensor from nested Python lists. `torch.tensor` is the documented
# factory for literal data; float literals keep the default floating dtype
# that the legacy `torch.Tensor(...)` constructor produced.
t = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(t.device)  # device the tensor is stored on
print(t.shape)   # torch.Size([2, 3])
print(t[0, 1])   # element at row 0, column 1, returned as a 0-dim tensor

cpu
torch.Size([2, 3])
tensor(2.)


### Layers

In [93]:
# Feature sizes shared by the layers in this notebook.
d_in = 5   # number of input features
d_out = 3  # number of output features

# Affine layer mapping d_in features to d_out features, plus an element-wise ReLU.
lin = nn.Linear(in_features=d_in, out_features=d_out)
relu = nn.ReLU()

In [94]:
# NOTE: despite its name, `channels` is used as the leading (row) dimension of
# the input; the feature dimension consumed by `lin` is d_in.
channels = 4
x = torch.rand((channels, d_in))  # random input of shape (channels, d_in)

# Forward pass: affine layer first, then ReLU applied to its output.
y = lin(x)
y = relu(y)

### Sequential

In [95]:
# Small MLP: d_in -> d_h (batch norm + ReLU) -> d_out.
# NOTE(review): the layer weights are initialised from the global torch RNG,
# which is not seeded in the cells visible here, so the model differs per run.
d_h = 2  # width of the hidden layer
mlp = nn.Sequential(
    nn.Linear(d_in, d_h),
    nn.BatchNorm1d(d_h),
    nn.ReLU(),
    nn.Linear(d_h, d_out),
)

In [96]:
# Push the input through the MLP; the cell then displays the output's shape.
y = mlp(x)
y.size()

torch.Size([4, 3])

### Optimize

In [97]:
# Dedicated random generator with a fixed seed (0), separate from the global
# torch RNG; it is passed explicitly to the sampling call that should be
# repeatable.
g = torch.Generator()
g.manual_seed(0)  # last expression of the cell, so its return value is displayed

<torch._C.Generator at 0x12a5d9430>

In [98]:
# Input tensor of shape (channels, d_in), sampled from the seeded generator `g`.
x = torch.rand((channels, d_in), generator=g)

In [99]:
# Target tensor of shape (channels, d_out), filled with zeros.
# (Alternative kept for reference: a random target drawn from the seeded
# generator, i.e. torch.rand(channels, d_out, generator=g).)
y_true = torch.zeros((channels, d_out))

In [100]:
# Adam optimizer over every parameter of the MLP, with learning rate 0.1.
adam_opt = optim.Adam(params=mlp.parameters(), lr=0.1)

In [101]:
# Training loop: minimise the summed squared error between mlp(x) and y_true.
epochs = 100
adam_opt.zero_grad()  # make sure no gradients are carried into the first epoch

for i in range(epochs):
    # Forward pass, then the sum of squared errors against the target
    y = mlp(x)
    loss = (y - y_true).pow(2).sum()
    # Backpropagate, apply the Adam update, and clear gradients for the next epoch
    loss.backward()
    adam_opt.step()
    adam_opt.zero_grad()
    # Report this epoch's loss (the value computed before the update)
    print(f'Epoch {i}: {loss.item()}')

Epoch 0: 4.132096290588379
Epoch 1: 1.7699825763702393
Epoch 2: 0.6587584018707275
Epoch 3: 0.19137465953826904
Epoch 4: 0.039265070110559464
Epoch 5: 0.04842814803123474
Epoch 6: 0.10276862978935242
Epoch 7: 0.1415027678012848
Epoch 8: 0.16058887541294098
Epoch 9: 0.16690826416015625
Epoch 10: 0.17378759384155273
Epoch 11: 0.17818835377693176
Epoch 12: 0.1788991540670395
Epoch 13: 0.17188221216201782
Epoch 14: 0.15410096943378448
Epoch 15: 0.12613330781459808
Epoch 16: 0.09224522113800049
Epoch 17: 0.058732956647872925
Epoch 18: 0.03157965838909149
Epoch 19: 0.014402243308722973
Epoch 20: 0.007599662058055401
Epoch 21: 0.00899957399815321
Epoch 22: 0.015331693924963474
Epoch 23: 0.023510247468948364
Epoch 24: 0.03125907480716705
Epoch 25: 0.03718496859073639
Epoch 26: 0.04054369777441025
Epoch 27: 0.04095907881855965
Epoch 28: 0.03834375739097595
Epoch 29: 0.033051908016204834
Epoch 30: 0.02600446343421936
Epoch 31: 0.018529649823904037
Epoch 32: 0.01194830983877182
Epoch 33: 0.007148