In [1]:
import torch

# Define the custom Tanh activation function
class Tanh:
    def forward(self, x):
        x_exp, neg_x_exp = torch.exp(x), torch.exp(-x)
        return (x_exp - neg_x_exp) / (x_exp + neg_x_exp)

# Define the neural network
class SimpleNeuralNetwork:
    def __init__(self):
        # Initialize weights randomly in the range [-0.5, 0.5]
        self.weights = (torch.rand(8) - 0.5)  # 8 random weights in [-0.5, 0.5]
        self.biases = torch.tensor([0.5, 0.7])  # Biases as a tensor
        
        # Reshape weights for matrix operations
        self.w_hidden = self.weights[:4].reshape(2, 2)  # Weights for hidden layer
        self.w_output = self.weights[4:].reshape(2, 2)  # Weights for output layer
        
        self.tanh = Tanh()  # Use the custom Tanh implementation

    def forward(self, inputs):
        # Reshape inputs for matrix multiplication
        inputs = inputs.reshape(2, 1)
        
        # Hidden layer calculations
        net_hidden = torch.mm(self.w_hidden, inputs) + self.biases[0]  # Matrix multiplication
        out_hidden = self.tanh.forward(net_hidden)
        
        # Output layer calculations
        net_output = torch.mm(self.w_output, out_hidden) + self.biases[1]  # Matrix multiplication
        out_output = self.tanh.forward(net_output)
        
        return out_output.flatten()  # Return flattened output

    def calculate_error(self, outputs, targets):
        # Squared error for each output neuron
        error = 0.5 * torch.sum((targets - outputs) ** 2)
        return error

# Inputs and targets
inputs = torch.tensor([0.05, 0.10])
targets = torch.tensor([0.01, 0.99])

# Initialize the neural network
nn = SimpleNeuralNetwork()

# Print initialized weights
print("Initialized Weights:")
print(nn.weights)

# Forward pass
outputs = nn.forward(inputs)
print(f"Outputs: o1 = {outputs[0].item():.4f}, o2 = {outputs[1].item():.4f}")

# Calculate error
error = nn.calculate_error(outputs, targets)
print(f"Total Error: {error.item():.4f}")

Initialized Weights:
tensor([ 0.2892,  0.2590,  0.1349, -0.4456, -0.0562, -0.3067,  0.1278, -0.0220])
Outputs: o1 = 0.4928, o2 = 0.6373
Total Error: 0.1788
