In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

class ALU(nn.Module):
    """ELU-style activation with a learnable negative scale and positive slope.

    For each element ``x`` of the input the module returns

    * ``beta * x``                  when ``x >= 0``
    * ``alpha * (exp(x) - 1)``      when ``x < 0``

    ``alpha`` and ``beta`` are scalar ``nn.Parameter``s initialised to 1.6733
    and 1.0507 (these look like the usual SELU constants -- presumably
    intentional, confirm with the author).
    """

    def __init__(self):
        super(ALU, self).__init__()
        self.alpha = nn.Parameter(torch.tensor(1.6733))
        self.beta = nn.Parameter(torch.tensor(1.0507))

    def forward(self, input):
        """Apply the activation element-wise and return a tensor shaped like ``input``."""
        # torch.where evaluates BOTH branches. Without the clamp, exp() of a
        # large positive input overflows to inf; the value is discarded in the
        # forward pass, but in the backward pass it is multiplied by the zero
        # gradient of the unselected branch (0 * inf = NaN) and poisons the
        # gradients of `input` and `alpha`. Clamping to <= 0 leaves every
        # selected value untouched because the exp branch is only used for
        # negative inputs.
        non_positive = torch.clamp(input, max=0)
        return torch.where(input >= 0, self.beta * input, self.alpha * (torch.exp(non_positive) - 1))
    
    
class RaLU(nn.Module):
    """Linear unit with a learnable floor ``alpha`` and learnable slope ``beta``.

    Element-wise, the output is ``beta * x`` wherever ``x >= alpha`` and the
    constant ``alpha`` everywhere else. Both values are scalar
    ``nn.Parameter``s, initialised to -1.0 and 1.0 respectively.
    """

    def __init__(self):
        super().__init__()
        # Registration order (alpha, then beta) is the order parameters() yields.
        self.alpha = nn.Parameter(torch.tensor(-1.0))
        self.beta = nn.Parameter(torch.tensor(1.0))

    def forward(self, input):
        """Return ``beta * input`` where ``input >= alpha``, otherwise ``alpha``."""
        above_floor = torch.ge(input, self.alpha)
        scaled = self.beta * input
        return torch.where(above_floor, scaled, self.alpha)
        

In [None]:
# Plot the response of a freshly initialised RaLU on 100 evenly spaced
# points in [-4, 4].
ralu = RaLU()
x = torch.from_numpy(np.linspace(-4, 4, 100)).float()
y = ralu(x)
x_np = x.detach().numpy()
y_np = y.detach().numpy()
plt.plot(x_np, y_np)

In [None]:
def run_experiment(activation=None, epochs=1000, batches=100, lr=0.001,
                   optimize=True, max_width=1000):
    """Tune an activation's ``alpha``/``beta`` so repeated application keeps mean 0 and std 1.

    Each epoch draws a random width ``n`` in ``[1, max_width]``, builds a fresh
    bias-free ``nn.Linear(n, n)`` and a standard-normal input of shape
    ``(1, n)``, then applies ``activation(fc(.))`` ``batches`` times in a row,
    recording the mean and standard deviation of every intermediate output.
    The loss is the average over those steps of ``|mean| + |std - 1|``; only the
    activation's parameters are handed to the Adam optimizer.

    Args:
        activation: module exposing scalar ``alpha`` and ``beta`` parameters.
            Defaults to a new ``RaLU()`` (the original hard-coded choice).
        epochs: number of optimisation rounds.
        batches: number of chained ``activation(fc(.))`` applications per epoch.
        lr: Adam learning rate.
        optimize: when False, only the statistics are recorded and the
            parameters are never updated.
        max_width: upper bound (inclusive) for the random layer width.

    Returns:
        dict with
        ``'alphas'`` / ``'lambds'``: per-epoch ``alpha`` / ``beta`` values (floats),
        ``'means'`` / ``'stds'``: detached per-step statistics of the last epoch.

    Raises:
        ValueError: if ``batches`` or ``max_width`` is smaller than 1.
    """
    if batches < 1:
        raise ValueError('batches must be >= 1')
    if max_width < 1:
        raise ValueError('max_width must be >= 1')
    if activation is None:
        activation = RaLU()

    params = list(activation.parameters())
    optimizer = torch.optim.Adam(params, lr=lr)

    alphas = []
    lambds = []
    means = []
    stds = []

    for j in range(epochs):
        n = int(np.random.randint(max_width)) + 1
        fc = nn.Linear(n, n, bias=False)
        y = torch.randn(1, n)
        means = []
        stds = []
        for _ in range(batches):
            y = activation(fc(y))
            means.append(torch.mean(y))
            stds.append(torch.std(y))

        # Build the loss with torch ops so autograd tracks it; pushing
        # grad-tracking tensors through np.abs/np.sum is not a supported path.
        mean_loss = torch.stack(means).abs().sum()
        std_loss = (torch.stack(stds) - 1).abs().sum()
        loss = (mean_loss + std_loss) / batches

        # Skip the update when the loss is not finite (e.g. n == 1 makes
        # torch.std return NaN) or when the final output has collapsed to a
        # constant, as the original std != 0 check did.
        if optimize and bool(torch.isfinite(loss)) and bool(stds[-1] != 0):
            optimizer.zero_grad()
            loss.backward()
            # Never step on non-finite gradients: one NaN would contaminate
            # Adam's running moments for good. Skipping here replaces the old
            # "restore after NaN" logic, which rebound alpha/beta to new
            # Parameter objects that the optimizer was not tracking.
            grads_finite = all(
                p.grad is None or bool(torch.isfinite(p.grad).all())
                for p in params
            )
            if grads_finite:
                optimizer.step()

        alphas.append(activation.alpha.detach().item())
        lambds.append(activation.beta.detach().item())
        print('\r' + str(j), end='')

    print('\r', activation.alpha.detach().item(), activation.beta.detach().item())

    return {
        'alphas': alphas,
        'lambds': lambds,
        'means': [m.detach() for m in means],
        'stds': [s.detach() for s in stds],
    }

In [None]:
# Repeat the whole experiment; every call prints its final alpha and beta.
N_RUNS = 100
for i in range(N_RUNS):
    run_experiment()

In [None]:
# Trajectory of (mean, std) over the recorded steps: a black line, the
# individual points on top, and a round marker on the first step.
# NOTE(review): `means` and `stds` are not assigned at notebook scope in the
# visible cells (they are locals of run_experiment) -- confirm where these
# names are expected to come from before relying on this cell.
mean_trace = torch.stack(means).detach().numpy()
std_trace = torch.stack(stds).detach().numpy()
plt.plot(mean_trace, std_trace, '-', c='black')
plt.plot(mean_trace, std_trace, '.')
plt.plot(means[0].detach(), stds[0].detach(), 'o')

In [None]:
# Path of (alpha, beta) across epochs, with a round marker on the start.
# NOTE(review): `alphas` and `lambds` are not assigned at notebook scope in
# the visible cells (they are locals of run_experiment) -- confirm their
# source before relying on this cell.
plt.plot(alphas, lambds)
start_point = (alphas[0], lambds[0])
plt.plot(*start_point, 'o')