## Imports and Classes

In [2]:
from tqdm.notebook import tqdm  # Import tqdm for Jupyter Notebook

from src.optimizee import *
from src.initializer import *

from torch.utils.tensorboard import SummaryWriter

## V1 - Output Convex Combination

In [None]:
class LSTMConcurrent_Post(nn.Module):
    """
    LSTM-based optimizer as described in the paper.
    """
    def __init__(self, num_optims, hidden_size=20, preproc=True, preproc_factor=10, learning_rate=0.001):
        super().__init__()
        self.hidden_size = hidden_size
        self.preproc = preproc
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_factor = torch.tensor(preproc_factor, device=self.device)
        self.preproc_threshold = float(torch.exp(-self.preproc_factor))
        self.lr = learning_rate
        
        self.input_size = 2 * num_optims if preproc else 1 * num_optims
        self.lstm = nn.LSTM(self.input_size, hidden_size, 2, batch_first=True).to(self.device)
        self.output_layer = nn.Linear(hidden_size, num_optims).to(self.device)


    def forward(self, x, hidden_state):
        """
        Forward pass of the LSTM optimizer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, sequence_length, input_size).
            hidden_state (tuple): Hidden state of the LSTM (h, c).

        Returns:
            torch.Tensor: Output updates of shape (batch_size, sequence_length, 1).
            tuple: Updated hidden state.
        """
        x = x.to(self.device)
        hidden_state = tuple(h.to(self.device) for h in hidden_state)

        if self.preproc:
            x = self.preprocess_gradients(x)

        out, hidden_state = self.lstm(x, hidden_state)
        out = self.output_layer(out)

        return out, hidden_state


    def preprocess_gradients(self, gradients):
        """ Applies log transformation & sign extraction to gradients. """
        gradients = gradients.data.to(self.device)
        if len(gradients.size()) == 1:
            gradients = gradients.unsqueeze(-1)

        param_size = gradients.size(0)
        num_optims = gradients.size(1)

        preprocessed = torch.zeros(param_size, 2 * num_optims, device=self.device)

        for i in range(num_optims):
            gradient = gradients[:, i]
            keep_grads = (torch.abs(gradient) >= self.preproc_threshold)

            # Log transformation for large gradients
            preprocessed[keep_grads, 2 * i] = (torch.log(torch.abs(gradient[keep_grads]) + 1e-8) / self.preproc_factor)
            preprocessed[keep_grads, 2 * i + 1] = torch.sign(gradient[keep_grads])

            # Direct scaling for small gradients
            preprocessed[~keep_grads, 2 * i] = -1
            preprocessed[~keep_grads, 2 * i + 1] = (float(torch.exp(self.preproc_factor)) * gradient[~keep_grads])

        return preprocessed


    def initialize_hidden_state(self):
        # Initialize hidden & cell states for LSTM (one per parameter)
        h0 = torch.zeros(2, self.hidden_size, device=self.device)
        c0 = torch.zeros(2, self.hidden_size, device=self.device)
        return (h0, c0)


In [79]:
def train_LSTM_Post(lstm_optimizer, meta_optimizer, initializer, num_epochs=500, time_horizon=200, discount=1, scheduler = None, writer=None):
    
    lstm_optimizer.train()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.ConstantLR(meta_optimizer, factor=1.0, total_iters=num_epochs)


    with tqdm(range(num_epochs), desc="Training Progress") as pbar:
        for epoch in pbar:
            optimizees = initializer.initialize()
            optimizees[0].set_params()
            params = optimizees[0].all_parameters().to(device)
            hidden_state = lstm_optimizer.initialize_hidden_state()
            cumulative_loss = None
            for t in range(time_horizon):
                gradients = []
                for i in range(len(optimizees)):
                    optimizee = optimizees[i]
                    loss, grad_params = optimizee.compute_loss(params, return_grad=True)
                    if i == 0 and discount: cumulative_loss = loss*discount**(time_horizon-1) if cumulative_loss is None else cumulative_loss + loss*discount**(time_horizon-t-1)
                    elif i==0: cumulative_loss = loss
                    gradients.append(grad_params.squeeze().to(device))
                    # if writer and i==0 and epoch==1: writer.add_scalar("Grad", grad_params.squeeze().mean(), t)

                grad_params = torch.stack(gradients).T
                # print(grad_params.shape, len(optimizees))
                output, hidden_state = lstm_optimizer(grad_params, hidden_state)
                alpha = output.mean(dim=0)
                lambda_ = nn.functional.softmax(alpha, dim=0)
                update = grad_params @ lambda_
                params = params + lstm_optimizer.lr * update.unsqueeze(-1)
                # if writer and epoch==1: writer.add_scalar("Update", update.mean(), t)
                optimizees[0].set_params(params)


            # Backpropagation through time (BPTT)
            if writer: writer.add_scalar("Loss", cumulative_loss, epoch)
            meta_optimizer.zero_grad()
            cumulative_loss.backward()
            # torch.nn.utils.clip_grad_norm_(lstm_optimizer.parameters(), 1)
            meta_optimizer.step()
            scheduler.step()

            # Update progress bar
            pbar.set_postfix(loss=cumulative_loss.item())

            num_prints = num_epochs // 10
            if (epoch + 1) % num_prints == 0:
                current_lr = meta_optimizer.param_groups[0]['lr']
                print(f"Epoch [{epoch+1}/{num_epochs}], Cumulative Loss: {cumulative_loss.item():.4f}, LR: {current_lr:.3e}")
                print(f"Final parameters: {(params.detach().cpu().numpy().T)[:10]}...")

    print("\nTraining complete!")
    return lstm_optimizer




def test_LSTM_Post(lstm_optimizer, initializer, time_horizon=200, writer=None):
    lstm_optimizer.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optimizees = initializer.initialize()
    optimizees[0].set_params()
    params = optimizees[0].all_parameters().to(device)
    hidden_state = lstm_optimizer.initialize_hidden_state()

    lambdas = []
    for t in range(time_horizon):
        gradients = []
        for i in range(len(optimizees)):
            optimizee = optimizees[i]
            loss, grad_params = optimizee.compute_loss(params)
            if writer and i==0: writer.add_scalar("Loss", loss, t)
            gradients.append(grad_params.squeeze().to(device))

        grad_params = torch.stack(gradients).T
        # if len(grad_params.shape)==1: grad_params = grad_params.unsqueeze(-1)

        output, hidden_state = lstm_optimizer(grad_params, hidden_state)
        alpha = output.sum(dim=0)
        lambda_ = nn.functional.softmax(alpha, dim=0)
        lambdas.append(lambda_)        
        update = grad_params @ lambda_
        params = params + lstm_optimizer.lr * update.unsqueeze(-1)
        optimizees[0].set_params(params)

    final_loss = optimizees[0].compute_loss(params, return_grad=False)
    print(f"Final Loss: {final_loss}")
    print(f"Final parameters: {(params.detach().cpu().numpy().T)[:10]}...")
    lambdas = torch.stack(lambdas).T
    return params, lambdas

In [None]:
torch.manual_seed(1)  # Set random seed for reproducibility
np.random.seed(1)
n=10
W = torch.randn(n, n)  # Random weights for the linear model
theta0 = torch.ones(n,1)  # Random theta for the linear model

kwargs = {"W": [W], "theta0": [theta0], "noise_std": [0.01, 0.03, 0.05, 0.02, 0.04]}

initializer = Initializer(QuadraticOptimizee, kwargs)

lstm_optimizer = LSTMConcurrent_Post(num_optims=initializer.get_num_optims(), preproc=True, learning_rate=1e-4)
meta_optimizer = optim.Adam(lstm_optimizer.parameters(), lr=0.01)

lstm_optimizer = train_LSTM_Post(lstm_optimizer, meta_optimizer, initializer, num_epochs=100, time_horizon=500, discount=0.9)
params, lambdas = test_LSTM_Post(lstm_optimizer, initializer, time_horizon=1000)

Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]

  self.W = torch.tensor(W, dtype=torch.float32)
  self.theta0 = torch.tensor(theta0, dtype=torch.float32)


Epoch [10/100], Cumulative Loss: 39511.5156, LR: 1.000e-02
Final parameters: [[-6.0194206   1.7483896  -8.478009   -3.5153334  -4.4033933  -5.215141
   0.16786733  0.09905227 -3.0846622  -0.86051357]]...
Epoch [20/100], Cumulative Loss: 11393.5645, LR: 1.000e-02
Final parameters: [[ 1.0221771  -0.28389767 -4.0830293  -1.2902454  -2.0648732  -3.5640893
  -2.3760672  -2.4591022   0.19383122 -0.7780546 ]]...
Epoch [30/100], Cumulative Loss: 33961.9570, LR: 1.000e-02
Final parameters: [[-7.1104503   2.1884773  -6.0318775  -3.5795996  -4.606734   -4.7308726
  -1.8622217   0.98911506 -2.1911707   0.9513819 ]]...
Epoch [40/100], Cumulative Loss: 55718.1289, LR: 1.000e-02
Final parameters: [[-10.690538     0.7943404   -7.4738164   -2.609048    -6.8816843
   -3.9297535   -0.43910682  -0.64680797  -4.6128397   -1.8278034 ]]...
Epoch [50/100], Cumulative Loss: 102859.8125, LR: 1.000e-02
Final parameters: [[-12.168209     1.5275154  -12.236728    -3.9083304  -11.193118
   -7.2349505    0.41631803 

In [75]:
lambdas

tensor([[2.9690e-04, 1.1443e-04, 1.1356e-04,  ..., 1.2278e-04, 1.2277e-04,
         1.2276e-04],
        [1.1508e-07, 2.3006e-08, 2.2740e-08,  ..., 2.5501e-08, 2.5498e-08,
         2.5495e-08],
        [6.7602e-04, 4.0678e-04, 4.0516e-04,  ..., 4.3399e-04, 4.3398e-04,
         4.3398e-04],
        [3.7265e-04, 1.7844e-04, 1.7737e-04,  ..., 1.8482e-04, 1.8481e-04,
         1.8480e-04],
        [9.9865e-01, 9.9930e-01, 9.9930e-01,  ..., 9.9926e-01, 9.9926e-01,
         9.9926e-01]], grad_fn=<PermuteBackward0>)

## V2 - Input Convex Combination

In [89]:
class LSTMConcurrent_Pre(nn.Module):
    """
    LSTM-based optimizer as described in the paper.
    """
    def __init__(self, num_optims, hidden_size=20, preproc=True, preproc_factor=10):
        super().__init__()
        self.hidden_size = hidden_size
        self.preproc = preproc
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preproc_factor = torch.tensor(preproc_factor, device=self.device)
        self.preproc_threshold = float(torch.exp(-self.preproc_factor))
        
        self.input_layer = nn.Linear(num_optims, 1, bias=False).to(self.device)
        self.input_size = 2 if preproc else 1
        self.lstm = nn.LSTM(self.input_size, hidden_size, 2, batch_first=True).to(self.device)
        self.output_layer = nn.Linear(hidden_size, 1).to(self.device)


    def forward(self, x, hidden_state):
        """
        Forward pass of the LSTM optimizer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, sequence_length, input_size).
            hidden_state (tuple): Hidden state of the LSTM (h, c).

        Returns:
            torch.Tensor: Output updates of shape (batch_size, sequence_length, 1).
            tuple: Updated hidden state.
        """
        x = x.to(self.device)
        hidden_state = tuple(h.to(self.device) for h in hidden_state)
        x = self.input_layer(x)

        if self.preproc:
            x = self.preprocess_gradients(x)

        out, hidden_state = self.lstm(x, hidden_state)
        out = self.output_layer(out)

        return out, hidden_state


    def preprocess_gradients(self, gradients):
        """ Applies log transformation & sign extraction to gradients. """
        gradients = gradients.data.to(self.device)
        if len(gradients.size()) == 1:
            gradients = gradients.unsqueeze(-1)

        param_size = gradients.size(0)
        num_optims = gradients.size(1)

        preprocessed = torch.zeros(param_size, 2 * num_optims, device=self.device)

        for i in range(num_optims):
            gradient = gradients[:, i]
            keep_grads = (torch.abs(gradient) >= self.preproc_threshold)

            # Log transformation for large gradients
            preprocessed[keep_grads, 2 * i] = (torch.log(torch.abs(gradient[keep_grads]) + 1e-8) / self.preproc_factor)
            preprocessed[keep_grads, 2 * i + 1] = torch.sign(gradient[keep_grads])

            # Direct scaling for small gradients
            preprocessed[~keep_grads, 2 * i] = -1
            preprocessed[~keep_grads, 2 * i + 1] = (float(torch.exp(self.preproc_factor)) * gradient[~keep_grads])

        return preprocessed


    def initialize_hidden_state(self):
        # Initialize hidden & cell states for LSTM (one per parameter)
        h0 = torch.zeros(2, self.hidden_size, device=self.device)
        c0 = torch.zeros(2, self.hidden_size, device=self.device)
        return (h0, c0)


In [90]:
def train_LSTM_Pre(lstm_optimizer, meta_optimizer, initializer, num_epochs=500, time_horizon=200, discount=1, scheduler = None, writer=None):
    
    lstm_optimizer.train()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.ConstantLR(meta_optimizer, factor=1.0, total_iters=num_epochs)


    with tqdm(range(num_epochs), desc="Training Progress") as pbar:
        for epoch in pbar:
            optimizees = initializer.initialize()
            optimizees[0].set_params()
            params = optimizees[0].all_parameters().to(device)
            hidden_state = lstm_optimizer.initialize_hidden_state()
            cumulative_loss = None
            for t in range(time_horizon):
                gradients = []
                for i in range(len(optimizees)):
                    optimizee = optimizees[i]
                    loss, grad_params = optimizee.compute_loss(params, return_grad=True)
                    if i == 0 and discount: cumulative_loss = loss*discount**(time_horizon-1) if cumulative_loss is None else cumulative_loss + loss*discount**(time_horizon-t-1)
                    elif i==0: cumulative_loss = loss
                    gradients.append(grad_params.squeeze().to(device))
                    # if writer and i==0 and epoch==1: writer.add_scalar("Grad", grad_params.squeeze().mean(), t)

                grad_params = torch.stack(gradients).T
                # print(grad_params.shape, len(optimizees))
                update, hidden_state = lstm_optimizer(grad_params, hidden_state)
                params = params + update
                # if writer and epoch==1: writer.add_scalar("Update", update.mean(), t)
                optimizees[0].set_params(params)


            # Backpropagation through time (BPTT)
            if writer: writer.add_scalar("Loss", cumulative_loss, epoch)
            meta_optimizer.zero_grad()
            cumulative_loss.backward()
            # torch.nn.utils.clip_grad_norm_(lstm_optimizer.parameters(), 1)
            meta_optimizer.step()
            scheduler.step()

            # Update progress bar
            pbar.set_postfix(loss=cumulative_loss.item())

            num_prints = num_epochs // 10
            if (epoch + 1) % num_prints == 0:
                current_lr = meta_optimizer.param_groups[0]['lr']
                print(f"Epoch [{epoch+1}/{num_epochs}], Cumulative Loss: {cumulative_loss.item():.4f}, LR: {current_lr:.3e}")
                print(f"Final parameters: {(params.detach().cpu().numpy().T)[:10]}...")

    print("\nTraining complete!")
    return lstm_optimizer




def test_LSTM_Pre(lstm_optimizer, initializer, time_horizon=200, writer=None):
    lstm_optimizer.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    optimizees = initializer.initialize()
    optimizees[0].set_params()
    params = optimizees[0].all_parameters().to(device)
    hidden_state = lstm_optimizer.initialize_hidden_state()

    for t in range(time_horizon):
        gradients = []
        for i in range(len(optimizees)):
            optimizee = optimizees[i]
            loss, grad_params = optimizee.compute_loss(params)
            if writer and i==0: writer.add_scalar("Loss", loss, t)
            gradients.append(grad_params.squeeze().to(device))

        grad_params = torch.stack(gradients).T
        # if len(grad_params.shape)==1: grad_params = grad_params.unsqueeze(-1)

        update, hidden_state = lstm_optimizer(grad_params, hidden_state)
        params = params + update
        optimizees[0].set_params(params)

    final_loss = optimizees[0].compute_loss(params, return_grad=False)
    print(f"Final Loss: {final_loss}")
    print(f"Final parameters: {(params.detach().cpu().numpy().T)[:10]}...")
    return params, lambdas

In [94]:
torch.manual_seed(1)  # Set random seed for reproducibility
np.random.seed(1)
n=10
W = torch.randn(n, n)  # Random weights for the linear model
theta0 = torch.ones(n,1)  # Random theta for the linear model

kwargs = {"W": [W], "theta0": [theta0], "noise_std": list(np.arange(0.01, 0.11, 0.01))}

initializer = Initializer(QuadraticOptimizee, kwargs)

lstm_optimizer = LSTMConcurrent_Pre(num_optims=initializer.get_num_optims(), preproc=True)
meta_optimizer = optim.Adam(lstm_optimizer.parameters(), lr=0.01)

lstm_optimizer = train_LSTM_Pre(lstm_optimizer, meta_optimizer, initializer, num_epochs=100, time_horizon=500, discount=0.9)
params, lambdas = test_LSTM_Pre(lstm_optimizer, initializer, time_horizon=1000)

nn.utils.parameters_to_vector(lstm_optimizer.input_layer.parameters())

Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch [10/100], Cumulative Loss: 14.4609, LR: 1.000e-02
Final parameters: [[1.2612435  0.76025265 1.4085405  1.5638216  0.44637197 0.6904131
  0.6698243  1.4570692  0.1904761  2.0461488 ]]...
Epoch [20/100], Cumulative Loss: 19.9983, LR: 1.000e-02
Final parameters: [[1.0428268  1.4419596  0.9861763  0.52929705 1.2703655  0.9480345
  0.61638314 1.4384967  1.4138949  1.1917597 ]]...
Epoch [30/100], Cumulative Loss: 20.0624, LR: 1.000e-02
Final parameters: [[1.017815  2.0292325 1.0270126 0.5879456 1.2848508 1.353821  1.0104824
  0.9744402 1.4916581 1.2162242]]...
Epoch [40/100], Cumulative Loss: 13.9848, LR: 1.000e-02
Final parameters: [[0.9706728  0.98461735 0.91066253 0.9710672  1.0461549  0.9121424
  0.6231784  1.3609481  1.2046943  1.1919951 ]]...
Epoch [50/100], Cumulative Loss: 13.7419, LR: 1.000e-02
Final parameters: [[ 1.1615869   0.14044592  1.0822572   1.9291571   0.4783302   0.5552066
   0.93265265  1.3953425  -0.17320907  1.6908815 ]]...
Epoch [60/100], Cumulative Loss: 25.562

tensor([ 0.2291,  0.1965, -0.2289, -0.2277, -0.1912,  0.0397,  0.3152, -0.1998,
         0.1685, -0.1750], grad_fn=<CatBackward0>)

In [95]:
v = nn.utils.parameters_to_vector(lstm_optimizer.input_layer.parameters())
v**2 / torch.norm(v)**2

tensor([0.1216, 0.0894, 0.1213, 0.1201, 0.0847, 0.0037, 0.2301, 0.0924, 0.0658,
        0.0710], grad_fn=<DivBackward0>)