In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftResurrectReLU(nn.Module):
    """ReLU variant whose negative side is a scaled tanh instead of zero.

    Applied element-wise::

        f(x) = x                        if x >= 0
        f(x) = beta * tanh(alpha * x)   if x <  0

    Non-negative inputs pass through exactly as with ReLU. Negative inputs
    follow a smooth, bounded curve, so they still carry a gradient; the
    intent is to keep units from "dying" during long training runs.
    The original author reports good results with small learning rates
    (roughly 1e-3 down to 1e-6).

    Args:
        alpha: Multiplier applied to the input inside ``tanh``; controls
            how sharply the negative branch saturates.
        beta: Multiplier applied to the ``tanh`` output; controls how much
            the negative branch contributes.
    """

    def __init__(self, alpha=1.0, beta=0.5):
        super().__init__()
        # Stored as plain attributes, not nn.Parameters, so neither value
        # is trained or included in the state dict.
        self.alpha = alpha
        self.beta = beta

    def forward(self, x):
        """Return the activation of ``x``, element-wise, with the same shape."""
        # Smooth bounded curve; only selected where the input is negative.
        soft_tail = torch.tanh(x * self.alpha) * self.beta
        # Plain ReLU; only selected where the input is non-negative.
        rectified = F.relu(x)
        keep_rectified = x >= 0
        return torch.where(keep_rectified, rectified, soft_tail)
