In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# Prepare samples: 7 training examples with 3 input features and 1 target each.
# The targets follow the linear rule out = x0 + 0.1*x1 + 0.01*x2
# (e.g. [0.1, 0.2, 0.3] -> 0.123), so a purely affine model can fit them.
samples_inp = torch.tensor(
    [
        [0.1, 0.2, 0.3],
        [0.1, 0.1, 0.1],
        [0.4, 0.4, 0.4],
        [0.9, 0.9, 0.9],
        [0.3, 0.2, 0.1],
        [0.0, 0.0, 0.1],
        [0.0, 0.0, 0.0],
    ],
    dtype=torch.float,
)
samples_out = torch.tensor(
    [
        [0.123],
        [0.111],
        [0.444],
        [0.999],
        [0.321],
        [0.001],
        [0.0],
    ],
    dtype=torch.float,
)

In [None]:
# Reshape (N, C) -> (N, C, 1, 1): NCHW with a 1x1 spatial grid, so the 1x1
# convolutions of FCPerceptron act as per-sample fully connected layers.
# An explicit reshape (instead of chained unsqueeze) keeps this cell idempotent:
# re-running it on the already 4-D tensors leaves their shape unchanged rather
# than appending more singleton dimensions that Conv2d would reject.
samples_inp = samples_inp.reshape(samples_inp.shape[0], samples_inp.shape[1], 1, 1)
samples_out = samples_out.reshape(samples_out.shape[0], samples_out.shape[1], 1, 1)

In [None]:
class FCPerceptron(nn.Module):
    """Two-layer perceptron built from 1x1 convolutions.

    Each spatial position of an (N, num_inputs, H, W) input goes through the
    same two affine layers independently. With H = W = 1 this is an ordinary
    fully connected network. On larger grids it evaluates many weight-sharing
    perceptrons in parallel.

    No activation sits between the layers, so the whole network is a single
    affine map of its input.

    Args:
        num_inputs: number of input channels (features per sample).
        num_hiddens: number of hidden channels.
        num_outputs: number of output channels.
    """

    def __init__(self, num_inputs, num_hiddens, num_outputs):
        super().__init__()
        # Keep the creation order and the attribute names: they fix the random
        # initialisation sequence and the state_dict keys (layer1.*, layer2.*).
        self.layer1 = nn.Conv2d(num_inputs, num_hiddens, kernel_size=1)
        self.layer2 = nn.Conv2d(num_hiddens, num_outputs, kernel_size=1)

    def forward(self, x):
        """Map (N, num_inputs, H, W) to (N, num_outputs, H, W)."""
        hidden = self.layer1(x)
        return self.layer2(hidden)

In [None]:
# Define model.
# Seed torch's global RNG first: the Conv2d layers are randomly initialised, so
# without a fixed seed every Restart & Run All yields different weights, losses
# and test outputs.
SEED = 0
torch.manual_seed(SEED)
model = FCPerceptron(3, 10, 1)  # 3 input features -> 10 hidden channels -> 1 output

In [None]:
# Define loss function.
# reduction='sum' adds up the squared errors of all elements instead of averaging
# them, so the printed loss scales with the number of training samples.
loss_function = torch.nn.MSELoss(reduction='sum')

In [None]:
# Define optimizer: Adam over every tensor returned by model.parameters(),
# using PyTorch's default hyper-parameters (lr = 1e-3).
optimizer = optim.Adam(params=model.parameters())

In [None]:
# Training configuration: number of full-batch passes over the 7 samples.
num_epochs = 8_000

In [None]:
# Full-batch training: every epoch feeds all samples through the model at once.
# Print the loss every `log_every` epochs and at the final epoch. This keeps the
# cell output short instead of emitting hundreds of lines.
log_every = 500
for t in range(num_epochs):
    # Forward pass on the whole batch
    out = model(samples_inp)
    loss = loss_function(out, samples_out)
    # The loss is computed before this epoch's update, i.e. for the current weights.
    if t % log_every == 0 or t == num_epochs - 1:
        print(t, loss.item())

    # Clear gradients left over from the previous backward() before accumulating new ones
    optimizer.zero_grad()

    # Backward pass
    loss.backward()

    # Update model parameters (weights)
    optimizer.step()

In [None]:
# Saving: both paths are relative to the current working directory.
# 'perceptron.pth' holds only the parameter tensors (state_dict). Restore it with
#   model.load_state_dict(torch.load('perceptron.pth'))
# into a freshly built FCPerceptron; this is the portable format.
# 'perceptron.pt' pickles the whole nn.Module. Loading it needs the FCPerceptron
# class to be importable and, on PyTorch >= 2.6 (where torch.load defaults to
# weights_only=True), an explicit weights_only=False.
weights_state = model.state_dict()
torch.save(weights_state, 'perceptron.pth')
torch.save(model, 'perceptron.pt')

In [None]:
# Test on one sample (0.7, 0.8, 0.9), shaped as (N=1, C=3, H=1, W=1).
# Under the training data's rule x0 + 0.1*x1 + 0.01*x2 the ideal answer is 0.789.
# no_grad: this is pure inference, so do not record an autograd graph.
with torch.no_grad():
    result = model(torch.tensor([0.7, 0.8, 0.9], dtype=torch.float).reshape(1, 3, 1, 1))
print('test result', round(result[0].item(), 3))

In [None]:
# Run 3 x 3 perceptrons in parallel that share the same weights.
# `data` is written as (N, H, W, C) = (1, 3, 3, 3): each innermost triple is one
# 3-feature sample at grid position (h, w). Conv2d expects (N, C, H, W), so the
# feature axis is moved to dim 1 with permute(0, 3, 1, 2).
# numpy (np) is imported in the top imports cell.
data = torch.tensor(
    [[
        [[0.7, 0.8, 0.9], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]],
        [[0.7, 0.8, 0.8], [0.4, 0.5, 0.5], [0.1, 0.2, 0.2]],
        [[0.7, 0.8, 1.0], [0.4, 0.4, 0.4], [0.1, 0.2, 0.1]],
    ]],
    dtype=torch.float,
)
# Pure inference: skip autograd bookkeeping.
with torch.no_grad():
    result = model(data.permute(0, 3, 1, 2))
# result is (1, 1, 3, 3): one prediction per grid position; result[0] is (1, 3, 3).
print('test result:\n', np.round(result[0].detach().cpu().numpy(), 3))