# Skip Connections

We learned that changing the activation function and the initialization strategy can mitigate the vanishing gradient problem in relatively shallow neural networks. In deeper networks these strategies are no longer sufficient, and the usual remedy is the skip connection. The objective of this exercise is to understand the role of skip connections in a simple single-neuron-per-layer neural network.

A simple three-layer, single-neuron-per-layer neural network is used.

![Skip Neural Network](SkipNN.jpg)


Closed-form expressions for the gradients are derived for this network and validated by comparing them with the gradients computed by PyTorch.

Analyze the gradient flows for this network. 

```python
import torch
import torch.nn as nn
import plotly.graph_objs as go
import plotly.express as px
```

## Backpropagation in a Single Neuron Per Layer Network with Skip Connections


Assume a network in which each layer contains a single neuron and a skip connection adds each layer's input to its output:

### Forward Pass

1. **Input Layer**:
   $$
   a^{[0]} = x
   $$

2. **Layer 1** (with skip connection from Layer 0):
   $$
   f^{[1]} = W^{[1]} a^{[0]}
   $$

   $$
   a^{[1]} = g^{[1]}(f^{[1]}) + a^{[0]}
   $$

   where $a^{[0]}$ is due to the skip connection from Layer 0

3. **Layer 2** (with skip connection from Layer 1):
   $$
   f^{[2]} = W^{[2]} a^{[1]}  
   $$

   $$
   a^{[2]} = g^{[2]}(f^{[2]}) + a^{[1]} 
   $$

   where  $a^{[1]}$  is due to the skip connection from Layer 1

4. **Layer 3** (with skip connection from Layer 2):
   $$
   f^{[3]} = W^{[3]} a^{[2]}  
   $$
   
   $$
   a^{[3]} = g^{[3]}(f^{[3]})  + a^{[2]} 
   $$

   where $a^{[2]}$ is due to the skip connection from Layer 2

5. **Output Layer**:
   $$
   \hat{y} = a^{[3]}
   $$
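The forward pass above can be sketched in a few lines of plain Python. This is a minimal illustration, not the notebook's implementation: it assumes a sigmoid activation in every layer, and the helper names `sigmoid` and `forward` are hypothetical. The input `x = 1.0` and the weights `0.8` are the values used in the PyTorch validation later in this notebook.

```python
import math

def sigmoid(z):
    """Logistic activation, assumed here for every layer."""
    return 1.0 / (1.0 + math.exp(-z))

def forward(x, weights):
    """Run the residual chain and return (f, a).

    f[l-1] holds the pre-activation of layer l, a[l] holds the activation
    of layer l, and a[0] is the input.
    """
    a = [x]
    f = []
    for w in weights:
        f_l = w * a[-1]                   # f^[l] = W^[l] a^[l-1]
        f.append(f_l)
        a.append(sigmoid(f_l) + a[-1])    # a^[l] = g(f^[l]) + a^[l-1]  (skip connection)
    return f, a

f, a = forward(1.0, [0.8, 0.8, 0.8])
y_hat = a[-1]
print(f"a1 = {a[1]:.6f}, a2 = {a[2]:.6f}, y_hat = {y_hat:.6f}")
```

Because each layer adds its input to a sigmoid output in $(0, 1)$, the activations grow from layer to layer in this sketch instead of staying bounded.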

### Loss Function

Using Mean Squared Error, scaled by $\frac{1}{2}$ so that the derivative has no leading constant, as the loss:

$$
E = \frac{1}{2} (\hat{y} - y)^2
$$

## Gradients 

For any layer $l$, we need to find $\frac{\partial E}{\partial W^{[l]}}$. Let's start with the output layer and work backwards.

### 1. Layer 3
The loss gradient with respect to $W^{[3]}$ can be derived as:

$$
\frac{\partial E}{\partial W^{[3]}} = \frac{\partial E}{\partial a^{[3]}} \cdot \frac{\partial a^{[3]}}{\partial f^{[3]}} \cdot \frac{\partial f^{[3]}}{\partial W^{[3]}}
$$

where:
- From MSE loss: $$\frac{\partial E}{\partial a^{[3]}} = (a^{[3]} - y)$$ 
- Derivative of activation: $$\frac{\partial a^{[3]}}{\partial f^{[3]}} = g'^{[3]}(f^{[3]})$$ 
- From $f^{[3]} = W^{[3]}a^{[2]}$: $$\frac{\partial f^{[3]}}{\partial W^{[3]}} = a^{[2]}$$ 

Therefore:
$$
\frac{\partial E}{\partial W^{[3]}} = (a^{[3]} - y) \cdot \delta^{[3]} \cdot a^{[2]}
$$
where $$\delta^{[3]} =  g'^{[3]}(f^{[3]})$$

### 2. Layer 2

The loss gradient with respect to $W^{[2]}$ can be derived as:

$$
\frac{\partial E}{\partial W^{[2]}} = (a^{[3]} - y) \cdot \left(\delta^{[2]}  +  g'^{[2]}(f^{[2]})\right) \cdot a^{[1]}
$$

where $$\delta^{[2]} = \delta^{[3]} W^{[3]}  g'^{[2]}(f^{[2]})$$

### 3. Layer 1

The loss gradient with respect to $W^{[1]}$ can be derived as:

$$
\frac{\partial E}{\partial W^{[1]}} = (a^{[3]} - y) \cdot \left(\delta^{[1]}   + \delta^{[3]} \cdot  W^{[3]}  g'^{[1]}(f^{[1]})  + g'^{[2]} (f^{[2]}) W^{[2]} g'^{[1]} (f^{[1]}) + g'^{[1]}(f^{[1]})\right)\cdot a^{[0]}
$$

where $$\delta^{[1]} = \delta^{[2]} W^{[2]}  g'^{[1]}(f^{[1]})$$
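As a sanity check that does not depend on any autograd library, the three closed-form expressions can be compared against central finite differences of the loss. The sketch below assumes sigmoid activations and the values $x = 1$, $y = 0.5$, $W^{[l]} = 0.8$; the function names and the step size `h` are illustrative choices.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dsigmoid(z):
    s = sigmoid(z)
    return s * (1.0 - s)

def forward(w, x):
    """Forward pass with w = (w1, w2, w3); returns (f1, a1, f2, a2, f3, a3)."""
    f1 = w[0] * x
    a1 = sigmoid(f1) + x
    f2 = w[1] * a1
    a2 = sigmoid(f2) + a1
    f3 = w[2] * a2
    a3 = sigmoid(f3) + a2
    return f1, a1, f2, a2, f3, a3

def loss(w, x, y):
    return 0.5 * (forward(w, x)[-1] - y) ** 2

def closed_form_grads(w, x, y):
    """Evaluate the closed-form expressions for dE/dW1, dE/dW2, dE/dW3."""
    f1, a1, f2, a2, f3, a3 = forward(w, x)
    err = a3 - y
    d3 = dsigmoid(f3)
    d2 = d3 * w[2] * dsigmoid(f2)
    d1 = d2 * w[1] * dsigmoid(f1)
    g3 = err * d3 * a2
    g2 = err * (d2 + dsigmoid(f2)) * a1
    g1 = err * (d1
                + d3 * w[2] * dsigmoid(f1)
                + dsigmoid(f2) * w[1] * dsigmoid(f1)
                + dsigmoid(f1)) * x
    return g1, g2, g3

def numeric_grads(w, x, y, h=1e-6):
    """Central finite differences of the loss, one weight at a time."""
    grads = []
    for i in range(len(w)):
        up = list(w)
        up[i] += h
        down = list(w)
        down[i] -= h
        grads.append((loss(up, x, y) - loss(down, x, y)) / (2 * h))
    return tuple(grads)

w, x, y = (0.8, 0.8, 0.8), 1.0, 0.5
closed = closed_form_grads(w, x, y)
numeric = numeric_grads(w, x, y)
for name, c, n in zip(("w1", "w2", "w3"), closed, numeric):
    print(f"dE/d{name}: closed-form {c:.8f}, finite-difference {n:.8f}")
```

Agreement between the two columns is numerical evidence only; it does not replace a chain-rule derivation.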

The network is implemented below in PyTorch, and the closed-form gradients are validated against the gradients computed by autograd.

```python
import torch
import torch.nn as nn

# Define the neural network model: three single-neuron layers,
# each with a skip connection that adds the layer input to its output
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Parameter(torch.tensor([0.8]))
        self.w2 = nn.Parameter(torch.tensor([0.8]))
        self.w3 = nn.Parameter(torch.tensor([0.8]))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        f1 = self.w1 * x
        a1 = self.sigmoid(f1) + x

        f2 = self.w2 * a1
        a2 = self.sigmoid(f2) + a1

        f3 = self.w3 * a2
        a3 = self.sigmoid(f3) + a2

        # Return the intermediate values needed for the manual gradient check
        return a3, f1, a1, f2, a2, f3

# Instantiate the model
model = SimpleNN()

# Create input and target tensors
x = torch.tensor([1.0])
y_true = torch.tensor([0.5])

# Forward pass
a3, f1, a1, f2, a2, f3 = model(x)

# Compute loss
# Note: PyTorch's MSELoss does not include the factor 0.5, so it is applied explicitly
criterion = nn.MSELoss()
loss = 0.5 * criterion(a3, y_true)

print("Forward Pass Values:")
print(f"x (a0): {x.item():.8f}")
print(f"a3 (y_pred): {a3.item():.8f}")
print(f"Loss: {loss.item():.8f}")

# Perform backward pass to compute gradients
loss.backward()

# Manual gradient calculation
def sigmoid(x):
    return 1 / (1 + torch.exp(-x))

def sigmoid_derivative(x):
    s = sigmoid(x)
    return s * (1 - s)

# Layer 3
delta3 = sigmoid_derivative(f3).item()
dw3_manual = (a3.item() - y_true.item()) * delta3 * a2.item()

# Layer 2
delta2 = delta3 * model.w3.item() * sigmoid_derivative(f2).item()
dw2_manual = (a3.item() - y_true.item()) * (delta2 * a1.item() + sigmoid_derivative(f2).item() * a1.item())

# Layer 1
delta1 = delta2 * model.w2.item() * sigmoid_derivative(f1).item()
dw1_manual = (a3.item() - y_true.item()) * (
    delta1 * x.item() +
    delta3 * model.w3.item() * sigmoid_derivative(f1).item() * x.item() +
    sigmoid_derivative(f2).item() * model.w2.item() * sigmoid_derivative(f1).item() * x.item() +
    sigmoid_derivative(f1).item() * x.item()
)

print("\nManual Gradient Calculation:")
print(f"dw1_manual: {dw1_manual:.8f}")
print(f"dw2_manual: {dw2_manual:.8f}")
print(f"dw3_manual: {dw3_manual:.8f}")

print("\nPyTorch Gradients:")
print(f"dE/dw1: {model.w1.grad.item():.8f}")
print(f"dE/dw2: {model.w2.grad.item():.8f}")
print(f"dE/dw3: {model.w3.grad.item():.8f}")

# Absolute differences between autograd and the closed-form values
print("\nDifferences in Gradients (PyTorch AutoGrad - Manual):")
print(f"w1 diff: {abs(model.w1.grad.item() - dw1_manual):.8f}")
print(f"w2 diff: {abs(model.w2.grad.item() - dw2_manual):.8f}")
print(f"w3 diff: {abs(model.w3.grad.item() - dw3_manual):.8f}")
```

```
Forward Pass Values:
x (a0): 1.00000000
a3 (y_pred): 3.36391068
Loss: 4.10099220

Manual Gradient Calculation:
dw1_manual: 0.75138211
dw2_manual: 0.85736593
dw3_manual: 0.75415744

PyTorch Gradients:
dE/dw1: 0.75138211
dE/dw2: 0.85736591
dE/dw3: 0.75415742

Differences in Gradients (PyTorch AutoGrad - Manual):
w1 diff: 0.00000001
w2 diff: 0.00000002
w3 diff: 0.00000001
```
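The three-layer check confirms the formulas, but the benefit of skip connections only becomes visible as depth grows. The following plain-Python sketch propagates $\partial a^{[l]} / \partial W^{[1]}$ forward through the layers with the chain rule, once with skip connections and once without. It assumes sigmoid activations, a shared weight of `0.8` in every layer, and input `x = 1.0`; the function `grad_factor_w1` and the depths swept are illustrative choices, not part of the original exercise.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dsigmoid(z):
    s = sigmoid(z)
    return s * (1.0 - s)

def grad_factor_w1(depth, w=0.8, x=1.0, skip=True):
    """Return d a^[depth] / d W^[1] for a single-neuron-per-layer chain.

    With skip=True each layer computes a = g(w * a_prev) + a_prev;
    with skip=False it computes a = g(w * a_prev).
    """
    s = 1.0 if skip else 0.0
    f = w * x
    da = dsigmoid(f) * x                  # d a^[1] / d W^[1]
    a = sigmoid(f) + s * x
    for _ in range(depth - 1):
        f = w * a
        da = (dsigmoid(f) * w + s) * da   # chain rule through one more layer
        a = sigmoid(f) + s * a
    return da

for depth in (3, 10, 20):
    with_skip = grad_factor_w1(depth, skip=True)
    no_skip = grad_factor_w1(depth, skip=False)
    print(f"depth {depth:2d}: with skip {with_skip:.6e}, without skip {no_skip:.6e}")
```

In this setup every layer with a skip connection multiplies the derivative by $g'(f^{[l]}) W^{[l]} + 1$, which is at least 1 for positive weights, whereas the plain chain multiplies it by $g'(f^{[l]}) W^{[l]}$, which is well below 1 for a sigmoid. The printed values therefore shrink rapidly with depth only in the network without skip connections.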


# Homework:

Prove the closed-form expressions for the gradients using the chain rule.