## Imports and Classes

In [1]:
from tqdm.notebook import tqdm  # Import tqdm for Jupyter Notebook

from src.optimizee import *
from src.initializer import *

from torch.utils.tensorboard import SummaryWriter

## V1 - Output Convex Combination

In [29]:
class LSTMConcurrent_Post(nn.Module):
    """
    LSTM that scores several concurrent gradient sources ("output convex combination").

    For every optimizee parameter the network receives one gradient value per
    gradient source and emits one score per source.  The accompanying
    training / evaluation loops aggregate these scores over the parameters and
    softmax-normalise them into mixing weights for the gradients.

    Args:
        num_optims: number of gradient sources evaluated in parallel.
        hidden_size: width of every LSTM layer.
        preproc: if True, gradients are log/sign preprocessed, which doubles
            the LSTM input width to ``2 * num_optims``.
        preproc_factor: constant ``p`` of the preprocessing; gradients whose
            magnitude is below ``exp(-p)`` take the "small gradient" branch.
        learning_rate: step size, stored as ``self.lr`` for the outer loops.
        num_layers: number of stacked LSTM layers (defaults to 2, the value
            that used to be hard-coded).
    """
    def __init__(self, num_optims, hidden_size=20, preproc=True, preproc_factor=10, learning_rate=0.001, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.preproc = preproc
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_factor = torch.tensor(preproc_factor, device=self.device)
        # Both constants are loop-invariant, so they are converted to Python
        # floats once here instead of on every preprocessing call.
        self.preproc_threshold = float(torch.exp(-self.preproc_factor))
        self.preproc_scale = float(torch.exp(self.preproc_factor))
        self.lr = learning_rate

        self.input_size = 2 * num_optims if preproc else 1 * num_optims
        self.lstm = nn.LSTM(self.input_size, hidden_size, num_layers, batch_first=True).to(self.device)
        self.output_layer = nn.Linear(hidden_size, num_optims).to(self.device)


    def forward(self, x, hidden_state):
        """
        Score every gradient source for every parameter.

        Args:
            x: gradients of shape (num_params, num_optims).
            hidden_state: tuple (h, c), each of shape
                (num_layers, num_params, hidden_size).

        Returns:
            (scores, new_hidden_state) where scores has shape
            (num_params, num_optims).
        """
        x = x.to(self.device)
        if self.preproc:
            x = self.preprocess_gradients(x)  # (num_params, 2 * num_optims)

        # Each parameter is one batch element with a sequence length of 1.
        x = x.unsqueeze(1)  # (num_params, 1, input_size)

        out, new_hidden_state = self.lstm(x, hidden_state)
        out = self.output_layer(out).squeeze(1)  # (num_params, 1, num_optims) -> (num_params, num_optims)
        return out, new_hidden_state


    def preprocess_gradients(self, gradients):
        """
        Apply the log / sign transformation to a batch of gradients.

        The input is detached, so no autograd graph flows through the result.
        A 1-D input is treated as a single gradient source.

        Args:
            gradients: tensor of shape (num_params, num_optims) or (num_params,).

        Returns:
            Tensor of shape (num_params, 2 * num_optims).  For source ``i``,
            column ``2*i`` holds ``log(|g| + 1e-8) / p`` and column ``2*i + 1``
            holds ``sign(g)`` when ``|g| >= exp(-p)``; otherwise the two
            columns hold ``-1`` and ``exp(p) * g``.
        """
        gradients = gradients.detach().to(self.device)
        if gradients.dim() == 1:
            gradients = gradients.unsqueeze(-1)

        magnitude = torch.abs(gradients)
        keep_grads = magnitude >= self.preproc_threshold

        # All gradient sources are handled at once; torch.where selects the
        # branch per element, so no Python loop or masked writes are needed.
        log_channel = torch.where(
            keep_grads,
            torch.log(magnitude + 1e-8) / self.preproc_factor,
            torch.full_like(gradients, -1.0),
        )
        sign_channel = torch.where(
            keep_grads,
            torch.sign(gradients),
            self.preproc_scale * gradients,
        )

        # (num_params, num_optims, 2) -> (num_params, 2 * num_optims) keeps the
        # two channels of each source interleaved in adjacent columns.
        return torch.stack((log_channel, sign_channel), dim=-1).flatten(start_dim=1)


    def initialize_hidden_state(self, num_params):
        """Return zero-filled (h0, c0), each of shape (num_layers, num_params, hidden_size)."""
        h0 = torch.zeros(self.num_layers, num_params, self.hidden_size, device=self.device)
        c0 = torch.zeros(self.num_layers, num_params, self.hidden_size, device=self.device)
        return (h0, c0)



In [None]:
def train_LSTM_Post(lstm_optimizer, meta_optimizer, initializer, num_epochs=500, time_horizon=200, discount=1, scheduler = None, writer=None):
    """
    Meta-train an output-convex-combination optimizer by unrolling it on fresh optimizees.

    Every epoch draws new optimizees from ``initializer``, unrolls
    ``time_horizon`` update steps on the parameters of the first optimizee,
    and back-propagates the resulting loss into ``lstm_optimizer``.

    Args:
        lstm_optimizer: module returning ``(scores, hidden_state)`` with
            ``scores`` of shape (num_params, num_optims); must expose ``lr``
            and ``initialize_hidden_state``.
        meta_optimizer: torch optimizer over ``lstm_optimizer``'s parameters.
        initializer: object whose ``initialize()`` returns the list of
            optimizees; only the first one contributes to the training loss.
        num_epochs: number of meta-training episodes.
        time_horizon: number of unrolled steps per episode (must be >= 1).
        discount: if truthy, the loss at step ``t`` is weighted by
            ``discount ** (time_horizon - t - 1)`` and summed; if falsy, only
            the loss of the last step is used.
        scheduler: optional LR scheduler stepped once per epoch; a constant
            schedule is used when omitted.
        writer: optional TensorBoard-style writer; receives the scalar "Loss".

    Returns:
        The trained ``lstm_optimizer`` (same object that was passed in).

    Raises:
        ValueError: if ``time_horizon`` is smaller than 1.
    """
    if time_horizon < 1:
        # Without at least one step there is no loss to back-propagate.
        raise ValueError("time_horizon must be at least 1")

    lstm_optimizer.train()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.ConstantLR(meta_optimizer, factor=1.0, total_iters=num_epochs)

    # Report about ten times per run.  max(1, ...) keeps the modulo below from
    # dividing by zero when fewer than ten epochs are requested.
    print_every = max(1, num_epochs // 10)

    with tqdm(range(num_epochs), desc="Training Progress") as pbar:
        for epoch in pbar:
            optimizees = initializer.initialize()
            optimizees[0].set_params()
            params = optimizees[0].all_parameters().to(device)
            hidden_state = lstm_optimizer.initialize_hidden_state(params.size(0))
            cumulative_loss = None
            for t in range(time_horizon):
                gradients = []
                for i, optimizee in enumerate(optimizees):
                    loss, grad_params = optimizee.compute_loss(params, return_grad=True)
                    if i == 0:
                        # Only the first optimizee defines the meta-objective.
                        if discount:
                            weighted = loss * discount ** (time_horizon - t - 1)
                            cumulative_loss = weighted if cumulative_loss is None else cumulative_loss + weighted
                        else:
                            cumulative_loss = loss
                    gradients.append(grad_params.squeeze().to(device))

                grad_params = torch.stack(gradients).T  # (num_params, num_optims)
                output, hidden_state = lstm_optimizer(grad_params, hidden_state)
                # Average the per-parameter scores, then softmax them into
                # convex weights over the gradient sources.
                alpha = output.mean(dim=0)
                lambda_ = nn.functional.softmax(alpha, dim=0)
                update = grad_params @ lambda_
                # NOTE(review): the mixed gradient is *added* to the parameters;
                # this is a descent step only if compute_loss returns a descent
                # direction - confirm against the optimizee implementation.
                params = params + lstm_optimizer.lr * update.unsqueeze(-1)
                optimizees[0].set_params(params)

            # Backpropagation through time (BPTT)
            if writer: writer.add_scalar("Loss", cumulative_loss.item(), epoch)
            meta_optimizer.zero_grad()
            cumulative_loss.backward()
            meta_optimizer.step()
            scheduler.step()

            # Update progress bar
            pbar.set_postfix(loss=cumulative_loss.item())

            if (epoch + 1) % print_every == 0:
                current_lr = meta_optimizer.param_groups[0]['lr']
                print(f"Epoch [{epoch+1}/{num_epochs}], Cumulative Loss: {cumulative_loss.item():.4f}, LR: {current_lr:.3e}")
                print(f"Final parameters: {(params.detach().cpu().numpy().T)[:10]}...")
                print(f"Lambdas: {(lambda_.detach().cpu().numpy().T)[:10]}...")


    print("\nTraining complete!")
    return lstm_optimizer




def test_LSTM_Post(lstm_optimizer, initializer, time_horizon=200, writer=None):
    lstm_optimizer.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optimizees = initializer.initialize()
    optimizees[0].set_params()
    params = optimizees[0].all_parameters().to(device)
    hidden_state = lstm_optimizer.initialize_hidden_state(params.size(0))

    lambdas = []
    for t in range(time_horizon):
        gradients = []
        for i in range(len(optimizees)):
            optimizee = optimizees[i]
            loss, grad_params = optimizee.compute_loss(params)
            if writer and i==0: writer.add_scalar("Loss", loss, t)
            gradients.append(grad_params.squeeze().to(device))

        grad_params = torch.stack(gradients).T
        # if len(grad_params.shape)==1: grad_params = grad_params.unsqueeze(-1)

        output, hidden_state = lstm_optimizer(grad_params, hidden_state)
        alpha = output.sum(dim=0)
        lambda_ = nn.functional.softmax(alpha, dim=0)
        lambdas.append(lambda_)        
        update = grad_params @ lambda_
        params = params + lstm_optimizer.lr * update.unsqueeze(-1)
        optimizees[0].set_params(params)

    final_loss = optimizees[0].compute_loss(params, return_grad=False)
    print(f"Final Loss: {final_loss}")
    print(f"Final parameters: {(params.detach().cpu().numpy().T)[:10]}...")
    lambdas = torch.stack(lambdas).T
    return params, lambdas

In [31]:
# Experiment V1: train and evaluate the output-convex-combination optimizer.
# Seeds are fixed first so that W below is the same on every run.
torch.manual_seed(1)  # Set random seed for reproducibility
np.random.seed(1)
n=10
W = torch.randn(n, n)  # Random weights for the linear model
theta0 = torch.ones(n,1)  # All-ones theta for the linear model (deterministic, not random)

# Arguments handed to Param_Initializer; presumably one optimizee is built per
# noise_std value, all sharing W and theta0 - TODO confirm against src.initializer.
kwargs = {"W": [W], "theta0": [theta0], "noise_std": [0.01, 0.05, 0.1, 0.15, 0.2]}

initializer = Param_Initializer(QuadraticOptimizee, kwargs)

# learning_rate is the step size applied to the mixed gradient (lstm_optimizer.lr);
# the Adam lr below is the meta-learning rate for the LSTM's own weights.
lstm_optimizer = LSTMConcurrent_Post(num_optims=initializer.get_num_optims(), preproc=True, learning_rate=1e-4)
meta_optimizer = optim.Adam(lstm_optimizer.parameters(), lr=0.001)

# discount=0 makes train_LSTM_Post use only the loss of the last unrolled step.
lstm_optimizer = train_LSTM_Post(lstm_optimizer, meta_optimizer, initializer, num_epochs=100, time_horizon=200, discount=0)
params, lambdas = test_LSTM_Post(lstm_optimizer, initializer, time_horizon=1000)

# Mixing weights at the final evaluation step.
print(lambdas[:,-1])

Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch [10/100], Cumulative Loss: 538.5627, LR: 1.000e-03
Final parameters: [[-0.87955356  0.52410847 -2.448278   -1.2886281  -1.0808189  -1.8231136
   0.982601    0.5773259  -0.6236282   0.6687098 ]]...
Lambdas: [0.2319119  0.16179897 0.18114139 0.17833005 0.24681768]...
Epoch [20/100], Cumulative Loss: 207.1111, LR: 1.000e-03
Final parameters: [[ 0.9308614   0.15551014 -0.13941224  0.11179635 -0.98587835 -1.4441396
  -1.2371871  -0.83737427  0.82849985 -0.49477306]]...
Lambdas: [0.2274417  0.16046692 0.18279006 0.18195894 0.24734229]...
Epoch [30/100], Cumulative Loss: 451.1354, LR: 1.000e-03
Final parameters: [[-1.6675334   0.8975835  -1.1228527  -1.4673119  -1.4349031  -1.6759653
  -1.1139376   1.4055      0.14894584  2.3297148 ]]...
Lambdas: [0.22810116 0.1586107  0.18177626 0.18184382 0.24966803]...
Epoch [40/100], Cumulative Loss: 713.3618, LR: 1.000e-03
Final parameters: [[-3.3408432  -0.9995748  -1.4017767  -0.05888959 -2.854169   -0.38491982
   0.39384905 -0.51047796 -1.499851

## V2 - Input Convex Combination

In [2]:
class LSTMConcurrent_Pre(nn.Module):
    """
    LSTM optimizer fed with a learned convex combination of gradients
    ("input convex combination").

    The weight of ``input_layer`` is never applied as a linear map; it is
    used as a vector of logits whose softmax mixes the ``num_optims`` incoming
    gradients into a single gradient per parameter.  That mixed gradient is
    then (optionally preprocessed and) passed to the LSTM, which outputs one
    update value per parameter.

    Args:
        num_optims: number of gradient sources that are mixed.
        hidden_size: width of every LSTM layer.
        preproc: if True, the mixed gradient is log/sign preprocessed, which
            makes the LSTM input width 2 instead of 1.
        preproc_factor: constant ``p`` of the preprocessing; gradients whose
            magnitude is below ``exp(-p)`` take the "small gradient" branch.
        num_layers: number of stacked LSTM layers (defaults to 2, the value
            that used to be hard-coded).
    """
    def __init__(self, num_optims, hidden_size=20, preproc=True, preproc_factor=10, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.preproc = preproc
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_factor = torch.tensor(preproc_factor, device=self.device)
        # Both constants are loop-invariant, so they are converted to Python
        # floats once here instead of on every preprocessing call.
        self.preproc_threshold = float(torch.exp(-self.preproc_factor))
        self.preproc_scale = float(torch.exp(self.preproc_factor))

        self.input_layer = nn.Linear(num_optims, 1, bias=False).to(self.device)
        self.input_size = 2 if preproc else 1
        self.lstm = nn.LSTM(self.input_size, hidden_size, num_layers, batch_first=True).to(self.device)
        self.output_layer = nn.Linear(hidden_size, 1).to(self.device)


    def forward(self, x, hidden_state):
        """
        Mix the incoming gradients and produce one update per parameter.

        Args:
            x: gradients of shape (num_params, num_optims).
            hidden_state: tuple (h, c), each of shape
                (num_layers, num_params, hidden_size).

        Returns:
            (update, new_hidden_state) where update has shape (num_params, 1).
        """
        x = x.to(self.device)
        alphas = self.input_layer.weight.reshape(-1)  # (num_optims,) logits
        lambdas_ = nn.functional.softmax(alphas, dim=0)

        if self.preproc:
            # The raw gradients are treated as constants on this path (as
            # before), but they are detached *before* mixing so that the graph
            # from the mixed gradient back to the mixing logits survives.
            x = x.detach()

        x = x @ lambdas_.unsqueeze(-1)  # (num_params, num_optims) @ (num_optims, 1) = (num_params, 1)

        if self.preproc:
            # detach=False: detaching here would cut input_layer off from the
            # loss, so the convex weights would never be trained.
            x = self.preprocess_gradients(x, detach=False)  # (num_params, 2)

        # Each parameter is one batch element with a sequence length of 1.
        x = x.unsqueeze(1)  # (num_params, 1, input_size)

        out, new_hidden_state = self.lstm(x, hidden_state)
        out = self.output_layer(out).squeeze(1)  # (num_params, 1, 1) -> (num_params, 1)
        return out, new_hidden_state


    def preprocess_gradients(self, gradients, detach=True):
        """
        Apply the log / sign transformation to a batch of gradients.

        A 1-D input is treated as a single gradient column.

        Args:
            gradients: tensor of shape (num_params, k) or (num_params,).
            detach: if True (default), the input is detached first so the
                result carries no autograd graph; ``forward`` passes False to
                keep the result differentiable w.r.t. the mixing weights.

        Returns:
            Tensor of shape (num_params, 2 * k).  For column ``i``, output
            column ``2*i`` holds ``log(|g| + 1e-8) / p`` and column
            ``2*i + 1`` holds ``sign(g)`` when ``|g| >= exp(-p)``; otherwise
            the two columns hold ``-1`` and ``exp(p) * g``.
        """
        if detach:
            gradients = gradients.detach()
        gradients = gradients.to(self.device)
        if gradients.dim() == 1:
            gradients = gradients.unsqueeze(-1)

        magnitude = torch.abs(gradients)
        keep_grads = magnitude >= self.preproc_threshold

        # torch.where selects the branch per element for all columns at once.
        # The +1e-8 keeps log (and its derivative) finite for the elements
        # that end up taking the other branch.
        log_channel = torch.where(
            keep_grads,
            torch.log(magnitude + 1e-8) / self.preproc_factor,
            torch.full_like(gradients, -1.0),
        )
        sign_channel = torch.where(
            keep_grads,
            torch.sign(gradients),
            self.preproc_scale * gradients,
        )

        # (num_params, k, 2) -> (num_params, 2 * k) keeps the two channels of
        # each input column interleaved in adjacent output columns.
        return torch.stack((log_channel, sign_channel), dim=-1).flatten(start_dim=1)


    def initialize_hidden_state(self, num_params):
        """Return zero-filled (h0, c0), each of shape (num_layers, num_params, hidden_size)."""
        h0 = torch.zeros(self.num_layers, num_params, self.hidden_size, device=self.device)
        c0 = torch.zeros(self.num_layers, num_params, self.hidden_size, device=self.device)
        return (h0, c0)



In [7]:
def train_LSTM_Pre(lstm_optimizer, meta_optimizer, initializer, num_epochs=500, time_horizon=200, discount=1, scheduler = None, writer=None):
    """
    Meta-train an input-convex-combination optimizer by unrolling it on fresh optimizees.

    Every epoch draws new optimizees from ``initializer``, unrolls
    ``time_horizon`` update steps on the parameters of the first optimizee,
    and back-propagates the resulting loss into ``lstm_optimizer``.  Training
    stops early once the softmax weight of the first gradient source exceeds
    0.99.

    Args:
        lstm_optimizer: module returning ``(update, hidden_state)``; must
            expose ``input_layer`` and ``initialize_hidden_state``.
        meta_optimizer: torch optimizer over ``lstm_optimizer``'s parameters.
        initializer: object providing ``initialize()`` (list of optimizees)
            and ``get_num_optims()``; only the first optimizee contributes to
            the training loss.
        num_epochs: maximum number of meta-training episodes.
        time_horizon: number of unrolled steps per episode (must be >= 1).
        discount: if truthy, the loss at step ``t`` is weighted by
            ``discount ** (time_horizon - t - 1)`` and summed; if falsy, only
            the loss of the last step is used.
        scheduler: optional LR scheduler stepped once per epoch; a constant
            schedule is used when omitted.
        writer: optional TensorBoard-style writer; receives the scalar "Loss".

    Returns:
        (lstm_optimizer, lambdas_): the trained optimizer and a gradient-free
        tensor of shape (num_optims, num_epochs) whose column ``e`` holds the
        mixing weights after epoch ``e``.  When training stops early, the
        columns of the epochs that never ran remain zero.

    Raises:
        ValueError: if ``time_horizon`` is smaller than 1.
    """
    if time_horizon < 1:
        # Without at least one step there is no loss to back-propagate.
        raise ValueError("time_horizon must be at least 1")

    lstm_optimizer.train()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.ConstantLR(meta_optimizer, factor=1.0, total_iters=num_epochs)

    # Report about ten times per run.  max(1, ...) keeps the modulo below from
    # dividing by zero when fewer than ten epochs are requested.
    print_every = max(1, num_epochs // 10)

    lambdas_ = torch.zeros((initializer.get_num_optims(), num_epochs), device=device)
    with tqdm(range(num_epochs), desc="Training Progress") as pbar:
        for epoch in pbar:
            optimizees = initializer.initialize()
            optimizees[0].set_params()
            params = optimizees[0].all_parameters().to(device)
            hidden_state = lstm_optimizer.initialize_hidden_state(params.size(0))
            cumulative_loss = None
            for t in range(time_horizon):
                gradients = []
                for i, optimizee in enumerate(optimizees):
                    loss, grad_params = optimizee.compute_loss(params, return_grad=True)
                    if i == 0:
                        # Only the first optimizee defines the meta-objective.
                        if discount:
                            weighted = loss * discount ** (time_horizon - t - 1)
                            cumulative_loss = weighted if cumulative_loss is None else cumulative_loss + weighted
                        else:
                            cumulative_loss = loss
                    gradients.append(grad_params.squeeze().to(device))

                grad_params = torch.stack(gradients).T  # (num_params, num_optims)
                update, hidden_state = lstm_optimizer(grad_params, hidden_state)
                params = params + update
                optimizees[0].set_params(params)

            # Backpropagation through time (BPTT)
            if writer: writer.add_scalar("Loss", cumulative_loss.item(), epoch)
            meta_optimizer.zero_grad()
            cumulative_loss.backward()
            meta_optimizer.step()
            scheduler.step()

            # Update progress bar
            pbar.set_postfix(loss=cumulative_loss.item())

            # Snapshot of the mixing weights after this epoch's meta-step.
            # Taken without autograd so the history tensor does not chain a
            # graph node per epoch.
            with torch.no_grad():
                lam = nn.functional.softmax(nn.utils.parameters_to_vector(lstm_optimizer.input_layer.parameters()), dim=0)

            if (epoch + 1) % print_every == 0:
                current_lr = meta_optimizer.param_groups[0]['lr']
                print(f"Epoch [{epoch+1}/{num_epochs}], Cumulative Loss: {cumulative_loss.item():.4f}, LR: {current_lr:.3e}")
                print(f"Final parameters: {(params.detach().cpu().numpy().T)[:10]}...")
                print(f"Input Weights: {lam[:10]}...")

            # Record before the convergence check so the converged weights
            # are part of the returned history.
            lambdas_[:, epoch] = lam
            if lam[0]>0.99:
                print("Stopping at epoch", epoch+1, "due to convergence.")
                break

    print("\nTraining complete!")
    return lstm_optimizer, lambdas_




def test_LSTM_Pre(lstm_optimizer, initializer, time_horizon=200, writer=None):
    lstm_optimizer.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optimizees = initializer.initialize()
    optimizees[0].set_params()
    params = optimizees[0].all_parameters().to(device)
    hidden_state = lstm_optimizer.initialize_hidden_state(params.size(0))

    for t in range(time_horizon):
        gradients = []
        for i in range(len(optimizees)):
            optimizee = optimizees[i]
            loss, grad_params = optimizee.compute_loss(params)
            if writer and i==0: writer.add_scalar("Loss", loss, t)
            gradients.append(grad_params.squeeze().to(device))

        grad_params = torch.stack(gradients).T
        # if len(grad_params.shape)==1: grad_params = grad_params.unsqueeze(-1)

        update, hidden_state = lstm_optimizer(grad_params, hidden_state)
        params = params + update
        optimizees[0].set_params(params)

    final_loss = optimizees[0].compute_loss(params, return_grad=False)
    print(f"Final Loss: {final_loss}")
    print(f"Final parameters: {(params.detach().cpu().numpy().T)[:10]}...")
    return params

In [8]:
# Experiment V2: train and evaluate the input-convex-combination optimizer.
# Seeds are fixed first so that W below is the same on every run.
torch.manual_seed(1)  # Set random seed for reproducibility
np.random.seed(1)
n=10
W = torch.randn(n, n)  # Random weights for the linear model
theta0 = torch.ones(n,1)  # All-ones theta for the linear model (deterministic, not random)

# Arguments handed to Param_Initializer; presumably one optimizee is built per
# noise_std value, all sharing W and theta0 - TODO confirm against src.initializer.
kwargs = {"W": [W], "theta0": [theta0], "noise_std": [0.01, 0.1, 0.2, 0.3, 0.4]}

initializer = Param_Initializer(QuadraticOptimizee, kwargs)

# preproc=False: the mixed gradient is fed to the LSTM without log/sign preprocessing.
lstm_optimizer = LSTMConcurrent_Pre(num_optims=initializer.get_num_optims(), preproc=False)
meta_optimizer = optim.Adam(lstm_optimizer.parameters(), lr=0.01)

# lambdas_ holds the per-epoch softmax of the input-layer weights returned by train_LSTM_Pre.
lstm_optimizer, lambdas_ = train_LSTM_Pre(lstm_optimizer, meta_optimizer, initializer, num_epochs=10000, time_horizon=50, discount=0.9)
params = test_LSTM_Pre(lstm_optimizer, initializer, time_horizon=50)

# Final mixing weights: softmax over the flattened input-layer parameters.
v = nn.utils.parameters_to_vector(lstm_optimizer.input_layer.parameters())
print(nn.functional.softmax(v, dim=0))

Training Progress:   0%|          | 0/10000 [00:00<?, ?it/s]

Epoch [1000/10000], Cumulative Loss: 4.2635, LR: 1.000e-02
Final parameters: [[1.0497031  0.75563484 1.058384   1.2414     0.8961162  0.90815145
  0.9265303  1.0820944  0.63031507 1.0479609 ]]...
Input Weights: tensor([0.8565, 0.0708, 0.0319, 0.0206, 0.0201], grad_fn=<SliceBackward0>)...
Epoch [2000/10000], Cumulative Loss: 1.6800, LR: 1.000e-02
Final parameters: [[0.8665524  1.368811   0.8716192  0.5881127  1.2063721  1.2046357
  1.1790186  0.8200624  1.7120963  0.79581356]]...
Input Weights: tensor([0.9394, 0.0295, 0.0135, 0.0090, 0.0085], grad_fn=<SliceBackward0>)...
Epoch [3000/10000], Cumulative Loss: 2.0430, LR: 1.000e-02
Final parameters: [[1.0513903  0.8425856  1.0537324  1.1599139  0.9206704  0.92245924
  0.92868984 1.0703173  0.69660276 1.0521542 ]]...
Input Weights: tensor([0.9643, 0.0171, 0.0080, 0.0055, 0.0050], grad_fn=<SliceBackward0>)...
Epoch [4000/10000], Cumulative Loss: 2.0883, LR: 1.000e-02
Final parameters: [[0.91533995 1.3589557  0.9013477  0.6310213  1.1695914  

In [9]:
# Persist the mixing-weight history from the V2 training cell as a NumPy array;
# detach + cpu are needed before a tensor can be converted with .numpy().
np.save("lambdas_.npy", lambdas_.detach().cpu().numpy())