In [34]:
import numpy as np
import torch
import torch.nn as nn

BATCH_SIZE = 1

# Network parameters
timesteps = 10
layers = 3
input_size = 10
hidden_size = 10  # Size of each hidden layer
output_size = 1   # Output size for each timestep

# Define the dense layer class that connects forward and backward
class BidirectionalDenseLayer(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(BidirectionalDenseLayer, self).__init__()
        self.backward_layer = nn.Linear(input_size, hidden_size)
        self.forward_layer = nn.Linear(hidden_size, hidden_size)
        self.activations = torch.zeros(BATCH_SIZE, hidden_size)
    
    def forward(self, x_forward=None, x_backward=None):
        h = torch.zeros(BATCH_SIZE, hidden_size)

        if x_forward is not None:
            # Forward propagation from the current layer's input
            h_forward = self.forward_layer(x_forward)
            h += h_forward
        
        # Backward connection from previous layer (if exists)
        if x_backward is not None:
            h_backward = self.backward_layer(x_backward)
            h += h_backward
        
        self.activations = torch.relu(h)  # ReLU activation
        return self.activations

# Define the network with 3 layers and 10 timesteps
class BiDirectionalNetwork(nn.Module):
    def __init__(self, layers, input_size, hidden_size, output_size):
        super(BiDirectionalNetwork, self).__init__()
        self.layers = nn.ModuleList([BidirectionalDenseLayer(hidden_size, hidden_size) for _ in range(layers)])
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Loop over timesteps
        for t in range(timesteps):
            layer_activations = []
            
            # Forward pass through each layer
            for i, layer in enumerate(self.layers):
                if i == 0:
                    # First layer does not have a lower layer to connect to
                    layer(x, self.layers[i+1].activations)
                elif i == len(self.layers) - 1:
                    layer(self.layers[i-1].activations, None)
                else:
                    # Layers above connect to both the previous layer and the current layer's forward pass
                    layer(self.layers[i+1].activations, self.layers[i-1].activations)

        # Output layer (optional for the last layer at each timestep)
        outputs = self.output_layer(self.layers[-1].activations)  # Output from the last layer at the last timestep
        
        return outputs  # Return both output and activations for later gradient computation

# Dummy input to simulate a sequence over 10 timesteps, with batch size = 2
x = torch.randn(BATCH_SIZE, input_size)

# Initialize the network
network = BiDirectionalNetwork(layers, input_size, hidden_size, output_size)

# Forward pass
output = network(x)

# Loss function and backpropagation (global backprop after all timesteps)
criterion = nn.MSELoss()
target = torch.zeros(1, 1)  # Dummy target for loss computation

# Compute loss and backpropagate
loss = criterion(output, target)
loss.backward()  # This will compute gradients globally, considering all timesteps and layers

# Display the output, loss, and gradients
print("Output:", output)
print("Loss:", loss.item())

# Print gradients for the first layer's weights to demonstrate backpropagation
print("Gradients for first layer's forward weights:", network.layers[0].forward_layer.weight.grad)

Output: tensor([[0.0114]], grad_fn=<AddmmBackward0>)
Loss: 0.000130068336147815
Gradients for first layer's forward weights: tensor([[ 6.1538e-05,  7.1619e-06, -6.2029e-05, -3.7185e-05, -1.5717e-04,
          2.0723e-05,  6.9320e-05,  6.9700e-05, -2.6066e-06,  6.9380e-05],
        [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
        [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
        [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
        [-1.5739e-04, -1.8318e-05,  1.5865e-04,  9.5106e-05,  4.0199e-04,
         -5.3004e-05, -1.7730e-04, -1.7827e-04,  6.6667e-06, -1.7745e-04],
        [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e