In [29]:
import torch
from torch import Tensor
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Callable

In [30]:
class AdaptiveSGQ:
    """Minimise the mean of several component functions with single-component gradient steps.

    Every step picks one component function, takes a plain gradient step on it
    with a fixed learning rate, and then measures the mean loss over all
    components. How the component is picked depends on ``algorithm_mode``:

    * ``'SGQ'`` - through :meth:`strategic_selection` (uniformly at random with
      probability ``current_p``, otherwise the component whose gradient deviates
      most from the mean gradient);
    * ``'SGD'`` - uniformly at random.

    While in ``'SGQ'`` mode, :meth:`Adaptive_p_update` adapts ``current_p`` after
    every step and may change the mode.
    """

    def __init__(self, functions: List[Callable[[Tensor], Tensor]], learning_rate: float = 0.1,
                 initial_p: float = 0.3, switch_threshold: float = 0.1, dim: int = 2):
        """Set up the optimizer state and draw a random starting point.

        Args:
            functions: Component functions; each maps the parameter tensor to a
                scalar loss tensor.
            learning_rate: Step size of the gradient update.
            initial_p: Initial probability of a uniformly random pick in SGQ mode.
            switch_threshold: Mean-loss level below which SGQ mode hands over to SGD
                (only checked after iteration 100).
            dim: Length of the parameter vector ``x``.

        Raises:
            ValueError: If ``functions`` is empty.
        """
        if len(functions) == 0:
            # Fail early: an empty list would otherwise surface later as a division
            # by zero or a failed torch.stack call.
            raise ValueError('functions must contain at least one callable')

        self.functions = functions
        self.n_functions = len(functions)
        self.lr = learning_rate
        self.current_p = initial_p
        self.min_p = 0.05
        self.max_p = 0.8
        self.switch_threshold = switch_threshold
        self.algorithm_mode = 'SGQ'
        self.iteration = 0
        self.loss_history = []
        self.p_history = []
        self.mode_history = []
        # Sliding window (at most 10 entries) of the improvements passed to Adaptive_p_update.
        self.improvement_history = []
        self.best_loss = float('inf')
        self.stagnation_count = 0
        self.x = torch.nn.Parameter(torch.randn(dim) * 0.5, requires_grad=True)

    def Adaptive_p_update(self, current_loss: float, improvement: float):
        """Adapt ``current_p`` and decide whether to change ``algorithm_mode``.

        NOTE: ``optimize_step`` only calls this method while the mode is ``'SGQ'``,
        so the switch back to ``'SGQ'`` below is only reached when the method is
        called directly.

        Args:
            current_loss: Mean loss over all components after the latest step.
            improvement: Improvement credited to the latest step (``optimize_step``
                passes the gap between the previous best loss and ``current_loss``).
        """
        # Keep only the 10 most recent improvements.
        self.improvement_history.append(improvement)
        if len(self.improvement_history) > 10:
            self.improvement_history.pop(0)

        # Good progress -> fewer random picks; poor progress -> more random picks.
        # Improvements between the two thresholds leave p unchanged.
        if improvement > 0.02:
            self.current_p = max(self.min_p, self.current_p * 0.9)
        elif improvement < 0.005:
            self.current_p = min(self.max_p, self.current_p * 1.1)

        # With a full window, compare its newer half against its older half.
        if len(self.improvement_history) >= 10:
            recent_avg = np.mean(self.improvement_history[-5:])
            older_avg = np.mean(self.improvement_history[:5])

            if recent_avg < older_avg * 0.3:
                if self.algorithm_mode == 'SGQ':
                    self.algorithm_mode = 'SGD'
                    print(f'🔄 Switching to SGD at iteration {self.iteration} (improvement stalled)')
                else:
                    self.algorithm_mode = 'SGQ'
                    print(f'🔄 Switching back to SGQ at iteration {self.iteration}')

                # Reset improvement history after switch
                self.improvement_history = []

        # Independent of stalling: once the loss is low enough, finish in SGD mode.
        if (current_loss < self.switch_threshold
                and self.algorithm_mode == 'SGQ'
                and self.iteration > 100):
            self.algorithm_mode = 'SGD'
            print(f'📉 Switching to SGD at iteration {self.iteration} (loss: {current_loss:.6f})')

    def strategic_selection(self) -> int:
        """Pick the index of the component function to step on in SGQ mode.

        With probability ``current_p`` the index is drawn uniformly at random.
        Otherwise the gradient of every component is evaluated at ``x`` and the
        index whose gradient is farthest (Euclidean norm) from the mean gradient
        is returned; ties go to the lowest index.

        Returns:
            Index into ``self.functions`` as a plain Python ``int``.
        """
        # Draw first: the random branch does not need any gradient, so deciding
        # up front avoids one backward pass per component.
        if np.random.random() < self.current_p:
            return int(np.random.randint(self.n_functions))

        gradients = []
        for function in self.functions:
            self.x.grad = None
            # A fresh graph is built by every call, so it need not be retained.
            function(self.x).backward()
            if self.x.grad is not None:
                gradients.append(self.x.grad.clone())
            else:
                gradients.append(torch.zeros_like(self.x))

        full_gradient = torch.mean(torch.stack(gradients), dim=0)
        deviations = [torch.norm(grad - full_gradient).item() for grad in gradients]

        # np.argmax yields a NumPy integer; convert to honour the annotation.
        return int(np.argmax(deviations))

    def compute_current_loss(self) -> float:
        """Return the mean of all component losses at the current ``x``."""
        # Only the values are needed, so skip building autograd graphs.
        with torch.no_grad():
            total_loss = sum(function(self.x).item() for function in self.functions)
        return total_loss / self.n_functions

    def optimize_step(self):
        """Run one update step and record its outcome.

        Selects a component according to the current mode, applies
        ``x -= lr * grad`` for that component, updates the best-loss and
        stagnation bookkeeping, lets ``Adaptive_p_update`` react (SGQ mode only)
        and appends to the loss, p and mode histories.

        Returns:
            Mean loss over all components after the update.
        """
        if self.algorithm_mode == 'SGQ':
            selected_index = self.strategic_selection()
        else:
            selected_index = np.random.randint(self.n_functions)

        self.x.grad = None
        loss = self.functions[selected_index](self.x)
        loss.backward()

        with torch.no_grad():
            if self.x.grad is not None:
                self.x -= self.lr * self.x.grad

        current_loss = self.compute_current_loss()

        # Improvement is measured against the best loss seen before this step;
        # on the very first step there is no reference yet, so it counts as 0.
        if self.best_loss != float('inf'):
            improvement = self.best_loss - current_loss
            if improvement < 1e-5:
                self.stagnation_count += 1
            else:
                self.stagnation_count = 0
        else:
            improvement = 0

        if current_loss < self.best_loss:
            self.best_loss = current_loss

        if self.algorithm_mode == 'SGQ':
            self.Adaptive_p_update(current_loss, improvement)

        self.loss_history.append(current_loss)
        self.p_history.append(self.current_p)
        self.mode_history.append(self.algorithm_mode)
        self.iteration += 1

        return current_loss

    def optimize(self, iterations: int = 1000, patience: int = 100, verbose: bool = True) -> dict:
        """Run up to ``iterations`` steps with early stopping and return a summary.

        Stops early when the mean loss drops below ``1e-6`` or when the loss has
        not set a new minimum (within this call) for ``patience`` consecutive steps.

        Args:
            iterations: Maximum number of steps to run in this call.
            patience: Number of consecutive non-improving steps tolerated.
            verbose: If true, print a progress line every 100 steps.

        Returns:
            Dict with ``final_loss``, ``best_loss``, ``final_mode``, ``final_p``,
            ``iterations`` (total steps taken by this optimizer), a detached copy
            of ``x`` and copies of ``loss_history``, ``p_history`` and
            ``mode_history``.
        """
        print(f'Starting optimization with {self.algorithm_mode}')
        print(f'Initial p: {self.current_p:.3f}')
        print(f'Switch threshold: {self.switch_threshold}')

        no_improvement_count = 0
        best_loss = float('inf')

        for i in range(iterations):
            loss = self.optimize_step()

            if loss < best_loss:
                best_loss = loss
                no_improvement_count = 0
            else:
                no_improvement_count += 1

            if verbose and i % 100 == 0:
                mode_indicator = '🔵' if self.algorithm_mode == 'SGQ' else '🔴'
                print(f'{mode_indicator} Iteration {i}: Loss = {loss:.6f}, p = {self.current_p:.3f}')

            # Early stopping
            if loss < 1e-6:
                print(f'Converged at iteration {i}')
                break

            if no_improvement_count >= patience:
                print(f'Early stopping at iteration {i} (no improvement for {patience} iterations)')
                break

        # The histories are empty when no step has ever been taken (iterations=0
        # on a fresh optimizer); fall back to the current state instead of crashing.
        final_loss = self.loss_history[-1] if self.loss_history else self.compute_current_loss()
        final_p = self.p_history[-1] if self.p_history else self.current_p

        print(f'Optimization completed in {self.iteration} iterations')
        print(f'Final mode: {self.algorithm_mode}')
        print(f'Final loss: {final_loss:.6f}')
        print(f'Final p: {final_p:.3f}')

        return {
            'final_loss': final_loss,
            'best_loss': self.best_loss,
            'final_mode': self.algorithm_mode,
            'final_p': final_p,
            'iterations': self.iteration,
            'x': self.x.detach().clone(),
            'loss_history': list(self.loss_history),
            'p_history': list(self.p_history),
            'mode_history': list(self.mode_history),
        }

In [31]:
def create_heterogeneous_functions(n_functions: int = 4) -> List[Callable[[Tensor], Tensor]]:
    """Build scaled squared-distance test functions with different minimisers.

    Function ``i`` evaluates ``scales[i] * ||x - centers[i]||^2`` for the fixed
    2-D centers and scales listed below.

    Args:
        n_functions: How many of the predefined functions to return (0 to 4).

    Returns:
        The first ``n_functions`` functions, each mapping a tensor to a scalar tensor.

    Raises:
        ValueError: If ``n_functions`` is negative or exceeds the number of
            predefined centers.
    """
    centers = [
        torch.tensor([2.0, 2.0]),
        torch.tensor([-2.0, -2.0]),
        torch.tensor([1.5, -1.5]),
        torch.tensor([-1.5, 1.5]),
    ]
    scales = [1.0, 0.8, 1.2, 0.9]

    # A negative count would be interpreted as a slice bound (e.g. -1 -> 3 functions)
    # and a count above len(centers) would be truncated silently; reject both.
    if not 0 <= n_functions <= len(centers):
        raise ValueError(
            f'n_functions must be between 0 and {len(centers)}, got {n_functions}'
        )

    def make_quadratic(center: Tensor, scale: float) -> Callable[[Tensor], Tensor]:
        # Factory so that every function is bound to its own center and scale.
        def quadratic(x: Tensor) -> Tensor:
            # Sum of squares instead of norm(...)**2: same quantity without the
            # square-root round trip, and with a well-defined gradient at x == center.
            return scale * (x - center).pow(2).sum()
        return quadratic

    return [
        make_quadratic(center, scale)
        for center, scale in zip(centers[:n_functions], scales)
    ]

In [48]:
# Seed both generators so the run is reproducible: the starting point is drawn
# with torch.randn and the component indices with np.random.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

test_functions = create_heterogeneous_functions()

optimizer = AdaptiveSGQ(
    functions=test_functions,
    learning_rate=0.005,
    initial_p=0.4,
    switch_threshold=0.5
)

results = optimizer.optimize(iterations=2000)

Starting optimization with SGQ
Initial p: 0.400
Switch threshold: 0.5
🔵 Iteration 0: Loss = 6.792745, p = 0.440
🔄 Switching to SGD at iteration 34 (improvement stalled)
🔴 Iteration 100: Loss = 6.065800, p = 0.157
🔴 Iteration 200: Loss = 5.978937, p = 0.157
🔴 Iteration 300: Loss = 5.920213, p = 0.157
🔴 Iteration 400: Loss = 5.953159, p = 0.157
Early stopping at iteration 425 (no improvement for 100 iterations)
Optimization completed in 426 iterations
Final mode: SGD
Final loss: 5.941225
Final p: 0.157


In [49]:
# Summary of the finished run, read back from the state recorded on the optimizer.
final_report = (
    f'loss: {optimizer.loss_history[-1]:.6f}',
    f'best loss: {optimizer.best_loss:.6f}',
    f'final mode: {optimizer.algorithm_mode}',
    f'final p: {optimizer.p_history[-1]:.3f}',
    f'total iteration: {optimizer.iteration}',
)

print('final results')
for report_line in final_report:
    print(report_line)

final results
loss: 5.941225
best loss: 5.916612
final mode: SGD
final p: 0.157
total iteration: 426


In [50]:
# Relative loss reduction over the run, in percent.
# Note: loss_history[0] is the loss recorded after the first update step.
initial_loss, final_loss = optimizer.loss_history[0], optimizer.loss_history[-1]

relative_drop = (initial_loss - final_loss) / initial_loss
improvement = relative_drop * 100

print(f'improvement: {improvement:.1f}%')
print(f'initial loss: {initial_loss:.6f}')

improvement: 12.5%
initial loss: 6.792745
