In [20]:
import torch
from torch import nn
from torch import optim
import matplotlib.pyplot as plt

In [297]:
class UniNN(nn.Module):
    """One-hidden-layer network R -> R with learnable per-layer weight-norm bounds.

    Architecture: ``fc1 (1 -> hidden) -> Dropout(p=0.1) -> ReLU -> fc2 (hidden -> 1)``.

    During ``fit`` each layer's weight matrix is rescaled (``project_W``) so that its
    matrix norm of order ``self.order`` (infinity norm) equals a scalar bound stored in
    the parameter ``fc1_max_layer_norm`` / ``fc2_max_layer_norm``. Those bounds are not
    given to the optimizer; they are updated by hand in ``compute_grad_B``/``update_B``.

    Args:
        hidden: number of hidden units.
        B: bound cap, only read by the (currently disabled) bound projection.
    """

    def __init__(self, hidden, B=2):
        super(UniNN, self).__init__()
        self.depth = 1
        self.device = "cpu"
        # Matrix-norm order used for every weight-norm computation in this class.
        self.order = float('inf')
        self.B = B
        self.hidden = hidden
        self.fc1 = nn.Linear(1, self.hidden)
        # B_1 / B_2 are kept for backward compatibility; they are not read elsewhere in this class.
        self.B_1 = torch.ones(1)
        self.register_parameter(name='fc1_max_layer_norm', param=torch.nn.Parameter(self.B_1))
        self.dropout = nn.Dropout(p=0.1)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(self.hidden, 1)
        self.B_2 = torch.ones(1)
        self.register_parameter(name='fc2_max_layer_norm', param=torch.nn.Parameter(self.B_2))
        # Training bookkeeping, populated by fit().
        self.test_loss_reached = False
        self.end_test_loss = 0
        self.end_train_loss = 0
        self.L = None
        self.sup_err = None
        self.grad_B_1 = 0
        self.grad_B_2 = 0

    def forward(self, x):
        """Map inputs whose last dimension is 1 to predictions whose last dimension is 1."""
        hidden = self.fc1(x.to(self.device))
        hidden = self.dropout(hidden)
        hidden = self.activation(hidden)
        return self.fc2(hidden)

    def get_dataloader(self, f, num_samples=5000, batch_size=32):
        """Build a shuffled DataLoader of (x, f(x)) pairs.

        x is ``num_samples`` uniform draws from [0, 1) plus ``num_samples // 20`` extra
        exact zeros (oversampling the left endpoint). Does not use model state.
        """
        X = torch.vstack((torch.rand(num_samples, 1), torch.zeros(num_samples // 20, 1)))
        train_dataset = torch.utils.data.TensorDataset(X, f(X))
        return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    def fit(self, dataloader, dataloader_test, epochs=100, lr=0.001, decay=1e-3, B=10, target_fn=None):
        """Train the weights with RAdam while tying each layer's weight norm to its bound.

        Args:
            dataloader: training batches of (inputs, labels).
            dataloader_test: evaluation batches of (inputs, labels).
            epochs: number of training epochs.
            lr: initial learning rate; StepLR multiplies it by 0.1 every 333 epochs.
            decay: currently unused (the bound penalty weight is fixed at 0.001 in
                ``compute_grad_B``); kept for interface compatibility.
            B: only echoed in the final progress print.
            target_fn: reference function for the sup-norm error on [0, 1]; defaults to
                the notebook-global ``f``.

        Returns:
            List of per-epoch mean test MSE values.
        """
        self.to(self.device)
        criterion = nn.MSELoss()
        # Bounds are deliberately excluded: they are stepped manually in update_B().
        optimizer = optim.RAdam(
            [self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias], lr=lr
        )
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=333, gamma=0.1)
        train_losses = []
        test_losses = []
        consecutive_good_epochs = 0
        for _ in range(epochs):
            train_losses.append(self._train_one_epoch(dataloader, criterion, optimizer))
            avg_test_loss = self._evaluate(dataloader_test, criterion)
            test_losses.append(avg_test_loss)
            # Flag success once the test loss stays below 5e-3 for 10 epochs in a row.
            if avg_test_loss < 5e-3:
                consecutive_good_epochs += 1
            else:
                consecutive_good_epochs = 0
            if consecutive_good_epochs == 10:
                self.test_loss_reached = True
            scheduler.step()
        tail_test = test_losses[-5:]
        tail_train = train_losses[-5:]
        self.end_test_loss = sum(tail_test) / max(len(tail_test), 1)
        self.end_train_loss = sum(tail_train) / max(len(tail_train), 1)
        self.model_err_sup_norm(target_fn)
        print(self.hidden, B, self.L, self.test_loss_reached, test_losses[-5:])

        return test_losses

    def _train_one_epoch(self, dataloader, criterion, optimizer):
        """Run one training pass (weights, bounds, projection) and return the mean batch MSE."""
        self.train()
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            loss = criterion(self(inputs), labels)
            optimizer.zero_grad()
            running_loss += loss.item()
            self.compute_lipschitz_constant()
            # The bound penalty 0.001 * (B1^2 + B2^2) does not depend on the network
            # weights, so it contributes nothing to their gradients; its gradient with
            # respect to the bounds is added by hand in compute_grad_B().
            loss.backward()
            self.compute_grad_B()
            self.update_B()
            optimizer.step()
            self.project_B()
            self.project_W()
        return running_loss / len(dataloader)

    def _evaluate(self, dataloader, criterion):
        """Return the mean batch MSE over ``dataloader`` in eval mode without gradients."""
        self.eval()
        running_loss = 0.0
        with torch.no_grad():
            for inputs, labels in dataloader:
                outputs = self(inputs.to(self.device))
                running_loss += criterion(outputs, labels.to(self.device)).item()
        return running_loss / len(dataloader)

    def update_B(self):
        """Take a plain gradient step (step size 0.1) on both bounds and clear the stored gradients."""
        # NOTE(review): nothing keeps the bounds positive while project_B is disabled.
        self.fc1_max_layer_norm.data = self.fc1_max_layer_norm.data - 0.1 * self.grad_B_1
        self.fc2_max_layer_norm.data = self.fc2_max_layer_norm.data - 0.1 * self.grad_B_2
        self.grad_B_1 = 0
        self.grad_B_2 = 0

    def compute_grad_B(self):
        """Compute d(loss)/d(bound) for each layer from the current weight gradients.

        project_W keeps W = B * W / ||W||, so dL/dB = sum(dL/dW * W / ||W||). The extra
        term 2 * 0.001 * B is the gradient of the 0.001 * B^2 penalty on that layer's bound.
        """
        self.grad_B_1 = self._bound_grad(self.fc1, self.fc1_max_layer_norm)
        self.grad_B_2 = self._bound_grad(self.fc2, self.fc2_max_layer_norm)

    def _bound_grad(self, layer, bound):
        """Return dL/dB for one layer, including the gradient of its own 0.001 * B^2 penalty."""
        weight = layer.weight.data
        direction = weight / torch.linalg.matrix_norm(weight, ord=self.order)
        return torch.sum(layer.weight.grad * direction) + 2 * 0.001 * bound.data

    def project_W(self):
        """Rescale each weight matrix in place so its ``self.order`` norm equals its bound."""
        with torch.no_grad():
            for layer, bound in ((self.fc1, self.fc1_max_layer_norm),
                                 (self.fc2, self.fc2_max_layer_norm)):
                layer.weight.mul_(bound / torch.linalg.matrix_norm(layer.weight, ord=self.order))

    def project_B(self):
        """Project the bounds onto a feasible set. Intentionally disabled (no-op).

        The drafted projection lives in ``_rescale_bounds_to_product``; call it from here
        to re-enable it.
        """
        return

    def _rescale_bounds_to_product(self):
        """(Unused draft) Clamp bounds to [1e-6, B] and rescale the smaller one so B1 * B2 aims at B."""
        B1 = torch.clamp(self.fc1_max_layer_norm.data, min=1e-6, max=self.B)
        B2 = torch.clamp(self.fc2_max_layer_norm.data, min=1e-6, max=self.B)
        if B1 < B2:
            B1 = self.B / B2
        else:
            B2 = self.B / B1
        # The final clamp can break the exact product when self.B / other exceeds self.B.
        self.fc1_max_layer_norm.data = torch.clamp(B1, min=1e-6, max=self.B)
        self.fc2_max_layer_norm.data = torch.clamp(B2, min=1e-6, max=self.B)

    def _predict_on_grid(self, num_points=1000):
        """Evaluate the model on ``num_points`` evenly spaced points of [0, 1] in eval mode.

        The previous train/eval mode is restored afterwards. Returns (x_grid, y_pred) as
        1-D CPU tensors.
        """
        x_grid = torch.linspace(0, 1, num_points)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                y_pred = self(x_grid.view(-1, 1)).to("cpu").view(-1)
        finally:
            self.train(was_training)
        return x_grid, y_pred

    def plot_model(self, f, title):
        """Plot model predictions against ``f`` on [0, 1] and print the sup-norm error."""
        x_grid, y_pred = self._predict_on_grid()
        y_true = f(x_grid)
        print("estimated_sup_norm_error", torch.max(torch.abs(y_pred - y_true)))
        plt.plot(x_grid.numpy(), y_pred.numpy(), label="Model")
        plt.plot(x_grid.numpy(), y_true.numpy(), label="Objective")
        plt.title("Model Predictions vs Data" + ' decay:' + title)
        plt.xlabel("x")
        plt.ylabel("y")
        plt.legend()
        plt.show()

    def reg(self):
        """Return the sum of the order-1 matrix norms of all weight matrices (not used by fit)."""
        reg_loss = 0
        for name, param in self.named_parameters():
            if 'weight' in name:
                reg_loss += torch.linalg.matrix_norm(param, ord=1)
        return reg_loss

    def model_err_sup_norm(self, target_fn=None):
        """Store in ``self.sup_err`` the max |model(x) - target_fn(x)| over 1000 points of [0, 1].

        ``target_fn`` defaults to the notebook-global ``f`` for backward compatibility.
        """
        if target_fn is None:
            # Hidden-state fallback: resolves the notebook-global f at call time.
            target_fn = f
        x_grid, y_pred = self._predict_on_grid()
        self.sup_err = torch.max(torch.abs(y_pred - target_fn(x_grid)))

    def compute_lipschitz_constant(self):
        """Store in ``self.L`` the product of every weight matrix's ``self.order`` norm.

        Since ReLU is 1-Lipschitz, this product bounds the network's Lipschitz constant in
        eval mode. It is computed without autograd tracking.
        """
        with torch.no_grad():
            self.L = 1
            for name, param in self.named_parameters():
                if 'weight' in name:
                    self.L *= torch.linalg.matrix_norm(param, ord=self.order)
                


In [293]:
def f(X):
    """Target function x -> x ** 0.23, applied elementwise to a tensor (or to a scalar).

    Its derivative 0.23 * x ** -0.77 is unbounded as x -> 0, so the function is not
    Lipschitz at the origin.
    """
    exponent = 0.23
    return X ** exponent

In [299]:
# Sweep hidden widths 2**i and initial bound caps B; keep every trained model and its
# per-epoch test losses.
torch.manual_seed(0)  # reproducible data sampling and weight init

WIDTH_EXPONENTS = range(3, 4)  # widths 2**i; currently only width 8
Bs = [2, 9]

# This instance is only a handle for get_dataloader, which does not use model state.
model = UniNN(16)
dataloader_train = model.get_dataloader(f)
dataloader_test = model.get_dataloader(f, num_samples=200)

models = {}       # width -> list of trained models, one per B
test_losses = {}  # width -> list of per-epoch test-loss curves, one per B
for i in WIDTH_EXPONENTS:
    width = 2 ** i
    models[width] = []
    test_losses[width] = []
    for B in Bs:
        model = UniNN(width, B=B)
        # NOTE: fit() does not currently use `decay`; it is passed for interface parity.
        loss = model.fit(lr=1e-3, dataloader=dataloader_train, dataloader_test=dataloader_test, decay=0, B=B)
        models[width].append(model)
        test_losses[width].append(loss)

KeyboardInterrupt: 