# 1) Install dependencies and import modules

In [362]:
!pip install torch 

Defaulting to user installation because normal site-packages is not writeable
Looking in links: /usr/share/pip-wheels


In [360]:
import torch
import torch.nn as nn
import torch.optim as optim

# 2) Create a model

In [367]:
class TorchBasicsModel(nn.Module):
    """Learn a single scalar exponent ``n`` and compute ``y = x ** n`` elementwise.

    ``n`` is a trainable ``nn.Parameter`` of shape ``(1,)``. It is initialised
    uniformly in ``[n_min, n_max)``. In ``forward`` the exponent is clamped to
    ``[n_min, n_max]``, so the effective power always stays inside the interval.

    Note: ``torch.clamp`` passes zero gradient for inputs outside its bounds.
    If an optimizer step pushes the raw parameter past a bound, it stops
    receiving gradient there. ``self.n`` may then report a value outside the
    interval even though ``forward`` uses the clamped value.

    Args:
        n_min: lower bound of the exponent (default 1.0).
        n_max: upper bound of the exponent (default 3.0).

    Raises:
        ValueError: if ``n_min > n_max``.
    """

    def __init__(self, n_min=1.0, n_max=3.0):
        super().__init__()
        if n_min > n_max:
            raise ValueError(f"n_min ({n_min}) must not exceed n_max ({n_max})")
        # Fixed hyperparameters, stored as plain floats (not trainable).
        self.n_min = float(n_min)
        self.n_max = float(n_max)

        # torch.rand samples [0, 1), so n starts uniformly in [n_min, n_max).
        self.n = nn.Parameter(self.n_min + (self.n_max - self.n_min) * torch.rand(1))

    def forward(self, x):
        """Return ``x ** n`` with ``n`` clamped to ``[n_min, n_max]``."""
        n_clamped = torch.clamp(self.n, min=self.n_min, max=self.n_max)
        return torch.pow(x, n_clamped)

# 3) Train the model

In [373]:
def create_dataset(batch_size, input_size, n_min=1, n_max=3):
    """Build a synthetic batch ``(x, y)`` with ``y = x ** n`` applied row-wise.

    Each row of ``x`` gets its own integer exponent, drawn uniformly from
    ``{n_min, ..., n_max}`` (inclusive). A model with one scalar exponent
    therefore generally cannot fit every row exactly; it can only reach a
    compromise value.

    Args:
        batch_size: number of rows.
        input_size: number of features per row.
        n_min: inclusive lower bound of the per-row integer exponent (default 1).
        n_max: inclusive upper bound of the per-row integer exponent (default 3).

    Returns:
        ``(x, y)``, both of shape ``(batch_size, input_size)``. ``x`` is
        uniform in ``[0, 1)``.

    Raises:
        ValueError: if ``n_min > n_max``.
    """
    if n_min > n_max:
        raise ValueError(f"n_min ({n_min}) must not exceed n_max ({n_max})")

    # Random input data, uniform in [0, 1)
    x = torch.rand((batch_size, input_size))

    # One exponent per row. randint's upper bound is exclusive, hence the +1.
    # The (batch_size, 1) shape broadcasts across all features of the row.
    n = torch.randint(n_min, n_max + 1, (batch_size, 1)).float()

    y = torch.pow(x, n)
    return x, y

In [371]:
# Define hyperparameters
batch_size = 8          # rows per randomly generated training batch
input_size = 16         # features per row
num_epochs = 3000       # one fresh batch is drawn per epoch
learning_rate = 1e-1    # SGD step size

# Seed torch's global RNG so that model initialisation and the synthetic
# batches are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED);

In [375]:
# Build the model, then pair it with a mean-squared-error objective and a
# plain SGD optimizer over all of the model's trainable parameters.
model = TorchBasicsModel()
criterion = nn.MSELoss(reduction="mean")
optimizer = optim.SGD(params=model.parameters(), lr=learning_rate)

In [332]:
print(f'Initial value of n: {model.n.item()}')

# Each epoch draws a brand-new random batch and takes one SGD step on its MSE.
# Every 100th epoch, the loss of that step's batch is reported.
for epoch in range(num_epochs):
    x_train, y_train = create_dataset(batch_size, input_size)

    # Clear stale gradients up front. The forward pass never reads .grad, so
    # this is equivalent to clearing them just before backward().
    optimizer.zero_grad()
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    # A zero remainder means a reporting epoch
    if not (epoch+1) % 100:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print(f'Final value of n: {model.n.item()}')

Initial value of n: 2.61716890335083
Epoch [100/3000], Loss: 0.0314
Epoch [200/3000], Loss: 0.0074
Epoch [300/3000], Loss: 0.0240
Epoch [400/3000], Loss: 0.0169
Epoch [500/3000], Loss: 0.0136
Epoch [600/3000], Loss: 0.0178
Epoch [700/3000], Loss: 0.0130
Epoch [800/3000], Loss: 0.0134
Epoch [900/3000], Loss: 0.0039
Epoch [1000/3000], Loss: 0.0124
Epoch [1100/3000], Loss: 0.0161
Epoch [1200/3000], Loss: 0.0099
Epoch [1300/3000], Loss: 0.0166
Epoch [1400/3000], Loss: 0.0133
Epoch [1500/3000], Loss: 0.0177
Epoch [1600/3000], Loss: 0.0092
Epoch [1700/3000], Loss: 0.0085
Epoch [1800/3000], Loss: 0.0161
Epoch [1900/3000], Loss: 0.0132
Epoch [2000/3000], Loss: 0.0133
Epoch [2100/3000], Loss: 0.0111
Epoch [2200/3000], Loss: 0.0163
Epoch [2300/3000], Loss: 0.0100
Epoch [2400/3000], Loss: 0.0118
Epoch [2500/3000], Loss: 0.0085
Epoch [2600/3000], Loss: 0.0143
Epoch [2700/3000], Loss: 0.0164
Epoch [2800/3000], Loss: 0.0135
Epoch [2900/3000], Loss: 0.0195
Epoch [3000/3000], Loss: 0.0120
Final value 

# 4) Model usage demonstration

In [377]:
# Run the trained model on a fresh random batch with the training shape.
# This is pure inference, so torch.no_grad() skips recording an autograd graph.
x = torch.rand((batch_size, input_size))
with torch.no_grad():
    y = model(x)

In [379]:
# Display the model's output for the demo batch (the cell's last expression uses the tensor's repr).
y

tensor([[8.0660e-02, 3.2435e-01, 1.4751e-02, 9.0167e-01, 1.4298e-02, 7.5666e-02,
         1.4055e-02, 7.5672e-01, 5.4340e-01, 3.9707e-02, 2.7000e-01, 8.0763e-01,
         2.2602e-01, 2.0862e-01, 9.3224e-02, 1.4543e-02],
        [1.2469e-01, 1.6525e-01, 1.2610e-01, 2.1307e-03, 4.2762e-01, 6.3648e-02,
         6.6643e-02, 9.7582e-01, 6.7547e-01, 6.4563e-02, 8.0592e-01, 9.9746e-05,
         5.8217e-02, 8.4152e-01, 1.2109e-04, 2.6690e-02],
        [7.3771e-01, 2.1780e-03, 4.8226e-01, 3.3530e-01, 6.3120e-01, 2.9788e-01,
         4.8147e-01, 2.7035e-01, 8.2209e-02, 9.8074e-01, 8.2748e-03, 3.3320e-01,
         6.5082e-02, 2.8246e-01, 9.3366e-01, 2.8589e-01],
        [2.3259e-01, 2.4992e-01, 1.0037e-01, 2.7239e-01, 5.7713e-01, 6.5116e-01,
         2.8156e-02, 7.9808e-03, 7.7992e-01, 5.6709e-01, 4.7289e-02, 3.4418e-02,
         9.8151e-01, 5.4788e-01, 4.6068e-01, 2.2861e-03],
        [2.1735e-03, 5.5280e-01, 7.8721e-05, 9.0277e-02, 4.8200e-01, 3.9758e-02,
         2.9575e-01, 2.3760e-02, 9.0986

In [381]:
# Check that all outputs are non-negative.
# x comes from torch.rand, which samples [0, 1) and can return exactly 0.
# Since 0 ** n == 0 for the clamped exponent n >= 1, a strict `y > 0` check
# could report False even though the model is behaving correctly.
torch.all(y >= 0)

tensor(True)