## checking NAdamW behaviour
- Checking whether the NAdamW update can be mathematically reduced to alpha * d (a step size times a direction)

### (1) Define NAdamW

In [6]:
import sys
import os
absolute_path = "/home/suckrowd/Documents/taylor"
if absolute_path not in sys.path:
    sys.path.append(absolute_path)
    
import math
from typing import Dict, Iterator, List, Tuple

from absl import logging
import torch
from torch import Tensor
import torch.distributed.nn as dist_nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LinearLR
from torch.optim.lr_scheduler import SequentialLR
import torch.nn as nn

import random
import time
from algorithmic_efficiency import spec
from algorithmic_efficiency.pytorch_utils import pytorch_setup
from source.adaptive_optimizer import AdaptiveLROptimizer
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from curvlinops import GGNLinearOperator
import csv


# First element of whatever `pytorch_setup()` returns.
# NOTE(review): presumably a flag telling whether DistributedDataParallel is in
# use -- confirm against algorithmic_efficiency.pytorch_utils.
USE_PYTORCH_DDP = pytorch_setup()[0]

# Hyperparameter values kept next to the optimizer definition.
# NOTE(review): nothing in the visible part of this notebook reads HPARAMS (the
# experiment cell constructs NAdamW with lr=0.01 and the default betas / weight
# decay) -- confirm whether another cell depends on it.
HPARAMS = {
    "dropout_rate": 0.1,
    "learning_rate": 0.0017486387539278373,
    "one_minus_beta1": 0.06733926164,
    "beta2": 0.9955159689799007,
    "weight_decay": 0.08121616522670176,
    "warmup_factor": 0.02
}


# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.
class NAdamW(torch.optim.Optimizer):
  r"""NAdamW optimizer: AdamW with a Nesterov-style first-moment look-ahead.

    The NAdam update rule follows Table 1 of https://arxiv.org/abs/1910.05446;
    the functional `nadamw` routine carries a comment marking the single place
    where it departs from AdamW. Weight decay is applied in decoupled form, see
    `Decoupled Weight Decay Regularization`_.

    This class only validates hyperparameters, keeps the per-parameter state
    (`step`, `exp_avg`, `exp_avg_sq`) and hands the tensors of every parameter
    group to `nadamw`, which performs the arithmetic in place.

    Args:
      params (iterable): iterable of parameters to optimize or dicts defining
          parameter groups
      lr (float, optional): learning rate (default: 1e-3)
      betas (Tuple[float, float], optional): coefficients used for computing
          running averages of gradient and its square (default: (0.9, 0.999))
      eps (float, optional): term added to the denominator to improve
          numerical stability (default: 1e-8)
      weight_decay (float, optional): weight decay coefficient (default: 1e-2)

    Raises:
      ValueError: if `lr`, `eps` or `weight_decay` is not >= 0, or if either
          entry of `betas` lies outside [0, 1).

    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
  """

  def __init__(self,
               params,
               lr=1e-3,
               betas=(0.9, 0.999),
               eps=1e-8,
               weight_decay=1e-2):
    # The checks are phrased as "not (x >= 0)" rather than "x < 0" so that a
    # NaN hyperparameter is rejected as well.
    if not lr >= 0.0:
      raise ValueError(f'Invalid learning rate: {lr}')
    if not eps >= 0.0:
      raise ValueError(f'Invalid epsilon value: {eps}')
    for index in (0, 1):
      beta = betas[index]
      if not 0.0 <= beta < 1.0:
        raise ValueError(f'Invalid beta parameter at index {index}: {beta}')
    if not weight_decay >= 0.0:
      raise ValueError(f'Invalid weight_decay value: {weight_decay}')

    super().__init__(
        params,
        dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay))

  def __setstate__(self, state):
    """Restores optimizer state, converting plain-number step counters.

        If the first restored per-parameter state already stores `step` as a
        tensor, nothing is changed; otherwise every `step` entry is wrapped in
        a float tensor.
    """
    super().__setstate__(state)
    restored = list(self.state.values())
    if restored and torch.is_tensor(restored[0]['step']):
      return
    for entry in restored:
      entry['step'] = torch.tensor(float(entry['step']))

  @torch.no_grad()
  def step(self, closure=None):
    """Performs a single optimization step.

        Args:
          closure (callable, optional): A closure that reevaluates the model
              and returns the loss.

        Returns:
          The value returned by `closure`, or None when no closure is given.

        Raises:
          RuntimeError: if a parameter has a sparse gradient.
    """
    self._cuda_graph_capture_health_check()

    # The closure has to run with autograd enabled even though the update
    # itself is executed under no_grad.
    loss = None
    if closure is not None:
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      beta1, beta2 = group['betas']
      active_params, active_grads = [], []
      first_moments, second_moments, step_counters = [], [], []

      for param in group['params']:
        gradient = param.grad
        if gradient is None:
          # Parameters without a gradient are left untouched.
          continue
        active_params.append(param)
        if gradient.is_sparse:
          raise RuntimeError('NAdamW does not support sparse gradients')
        active_grads.append(gradient)

        slot = self.state[param]
        if not slot:
          # Lazily create the state the first time this parameter is updated:
          # a step counter plus running averages of the gradient and of the
          # squared gradient.
          slot['step'] = torch.tensor(0.)
          slot['exp_avg'] = torch.zeros_like(
              param, memory_format=torch.preserve_format)
          slot['exp_avg_sq'] = torch.zeros_like(
              param, memory_format=torch.preserve_format)

        first_moments.append(slot['exp_avg'])
        second_moments.append(slot['exp_avg_sq'])
        step_counters.append(slot['step'])

      nadamw(
          active_params,
          active_grads,
          first_moments,
          second_moments,
          step_counters,
          beta1=beta1,
          beta2=beta2,
          lr=group['lr'],
          weight_decay=group['weight_decay'],
          eps=group['eps'])

    return loss


def nadamw(params: List[Tensor],
           grads: List[Tensor],
           exp_avgs: List[Tensor],
           exp_avg_sqs: List[Tensor],
           state_steps: List[Tensor],
           beta1: float,
           beta2: float,
           lr: float,
           weight_decay: float,
           eps: float) -> None:
  r"""Functional API that performs NAdamW algorithm computation.

    All tensors are updated in place: every entry of `params`, `exp_avgs`,
    `exp_avg_sqs` and `state_steps` is modified; nothing is returned.
    See NAdamW class for details.

    Args:
      params: parameters to update.
      grads: gradients, aligned with `params`.
      exp_avgs: running averages of the gradients (first moments).
      exp_avg_sqs: running averages of the squared gradients (second moments).
      state_steps: singleton tensors counting the updates applied so far.
      beta1: decay rate of the first moment.
      beta2: decay rate of the second moment.
      lr: learning rate.
      weight_decay: decoupled weight decay coefficient.
      eps: constant added to the denominator for numerical stability.

    Raises:
      RuntimeError: if an entry of `state_steps` is not a tensor.
  """

  if not all(isinstance(t, torch.Tensor) for t in state_steps):
    raise RuntimeError(
        'API has changed, `state_steps` argument must contain a list of' +
        ' singleton tensors')

  for param, grad, exp_avg, exp_avg_sq, step_t in zip(
      params, grads, exp_avgs, exp_avg_sqs, state_steps):
    # Update step.
    step_t += 1

    # Perform decoupled weight decay.
    param.mul_(1 - lr * weight_decay)

    # Decay the first and second moment running average coefficient.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Only difference between NAdamW and AdamW in this implementation.
    # The official PyTorch implementation of NAdam uses a different algorithm.
    # The look-ahead average is built in a temporary tensor so the stored
    # first moment is never modified and restored. Undoing an in-place
    # look-ahead needs a division by beta1, which yields NaN for beta1 == 0
    # and accumulates round-off in the state for every other value.
    nesterov_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)

    step = step_t.item()

    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    step_size = lr / bias_correction1

    bias_correction2_sqrt = math.sqrt(bias_correction2)
    denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

    param.addcdiv_(nesterov_avg, denom, value=-step_size)


## Define simple Model

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim

# Simple test model: Single hidden layer
class TestModel(nn.Module):
    """Fully connected network: 10 inputs -> 20 hidden units (ReLU) -> 1 output."""

    def __init__(self):
        super().__init__()
        # Layers are registered in forward order: fc1 first, then fc2.
        self.fc1 = nn.Linear(in_features=10, out_features=20)
        self.fc2 = nn.Linear(in_features=20, out_features=1)

    def forward(self, x):
        """Apply fc1, a ReLU, then fc2 to `x` and return the result."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)


### Setup NAdamW and Data

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleNet(nn.Module):
    """Small regression network: Linear(10, 5) -> ReLU -> Linear(5, 1)."""

    def __init__(self):
        super().__init__()
        # fc1 is built before fc2, so parameter order is fc1.*, then fc2.*.
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        """Return fc2(relu(fc1(x)))."""
        return self.fc2(torch.relu(self.fc1(x)))

# Experiment: take one NAdamW step on two identically initialised networks,
# one with learning rate `lr` and one with `2 * lr`, then compare the steps.


def _snapshot(model):
    """Return {parameter name: cloned tensor} for every parameter of `model`."""
    return {name: param.clone() for name, param in model.named_parameters()}


def _single_update(model, optimizer, inputs, targets, criterion):
    """Run one forward/backward pass and one optimizer step.

    Returns the model output and the loss computed before the step.
    """
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return outputs, loss


def _print_pairwise(left, right):
    """Print each aligned pair of tensors and the norm of their difference."""
    for (label, tensor_a), (_, tensor_b) in zip(left.items(), right.items()):
        print(f"Comparing {label}:")
        print(f"Model 1: {tensor_a}")
        print(f"Model 2: {tensor_b}")
        print(f"Difference: {torch.norm(tensor_a - tensor_b)}")


# Two networks that start from exactly the same weights.
model1 = SimpleNet()
model2 = SimpleNet()
model2.load_state_dict(model1.state_dict())

# model1 is optimised with the base learning rate, model2 with twice that rate.
lr = 0.01
optimizer1 = NAdamW(model1.parameters(), lr=lr)
optimizer2 = NAdamW(model2.parameters(), lr=lr*2)

# Keep the initial parameters so the step can be recovered afterwards.
params1_before = _snapshot(model1)
params2_before = _snapshot(model2)

# Dummy regression data: 20 samples with 10 features and one target each.
# NOTE(review): the seed is set after both models were constructed, so within
# this cell only the data is seeded, not the initial weights -- confirm that is
# intended.
torch.manual_seed(37)
input_data = torch.randn(20, 10)
target_data = torch.randn(20, 1)
print("input_data", input_data)
print("target_data", target_data)

loss_fn = nn.MSELoss()

# One update per model on the same batch; snapshot the parameters right after.
output1, loss1 = _single_update(
    model1, optimizer1, input_data, target_data, loss_fn)
params1_after = _snapshot(model1)

output2, loss2 = _single_update(
    model2, optimizer2, input_data, target_data, loss_fn)
params2_after = _snapshot(model2)

print("Compare the parameters before optimization")
_print_pairwise(params1_before, params2_before)

print("Compare the parameters after optimization")
_print_pairwise(params1_after, params2_after)

# The step each optimizer took is the change of every parameter.
step1 = {name: after - params1_before[name]
         for name, after in params1_after.items()}
step2 = {name: after - params2_before[name]
         for name, after in params2_after.items()}

print("Compare the steps taken by the two optimizers")

# If the update is linear in the learning rate, step1 / step2 is 0.5 everywhere.
# NOTE(review): torch.isclose adds its default rtol term on top of atol, so the
# effective tolerance is wider than the `lr * 1e-7` named in the printed label
# -- confirm which tolerance is intended.
half = torch.tensor(0.5)
tolerance = lr * 1e-7
for name in step1:
    delta1 = step1[name]
    delta2 = step2[name]
    print(f"Comparing {name}:")

    # Full step tensors of both models.
    print(f"Model 1: {delta1}")
    print(f"Model 2: {delta2}")

    # Scalar summary first, then the element-wise difference and ratio.
    difference = delta1 - delta2
    print(f"Norm of Difference: {torch.norm(difference)}")
    print(f"Difference for each individual param: {difference}")

    ratio = delta1 / delta2
    print(f"Ratio for each individual param: {ratio}")

    # One line per weight, including whether its ratio is close to 0.5.
    flat_difference = difference.flatten()
    flat_ratio = ratio.flatten()
    within_tolerance = torch.isclose(flat_ratio, half, atol=tolerance)
    for i in range(flat_ratio.numel()):
        print(f"Weight {i}: Difference = {flat_difference[i].item()}, Ratio = {flat_ratio[i].item()}, Single precision not exceeded (+- lr* e-7): {within_tolerance[i]}")



input_data tensor([[-9.3848e-02, -4.9599e-01,  9.6214e-01,  3.9399e-01, -9.2998e-01,
          1.1855e+00,  1.4126e+00,  1.6326e-01, -1.7673e+00, -1.3445e+00],
        [ 1.5449e+00,  4.5563e-01,  1.1740e-01, -7.2400e-01, -2.1713e+00,
          7.8712e-01, -6.2890e-01, -4.4137e-02, -1.2323e+00, -5.0699e-01],
        [-1.3766e+00, -2.0966e+00, -1.0178e+00,  3.7341e-01, -1.3290e-01,
          5.8631e-02, -1.6333e+00,  6.6700e-01, -1.8709e+00,  1.6859e-01],
        [ 9.3317e-01,  2.0100e-01,  9.6671e-01,  2.2769e-01, -1.7471e+00,
          2.5883e+00,  1.0854e+00,  4.4832e-02, -1.4370e+00,  1.3024e+00],
        [ 8.5050e-01,  1.0503e+00, -1.3052e-01, -5.9629e-01, -2.8378e-01,
         -7.7014e-01,  1.5392e+00,  3.0449e-02,  5.3290e-01,  5.7384e-01],
        [ 9.3549e-01,  1.6361e-01, -7.3672e-01,  1.3144e+00,  9.5261e-01,
          4.4999e-01, -1.3974e+00, -5.1735e-01,  2.9637e-01, -6.8204e-01],
        [ 1.1491e-01,  1.1775e+00, -6.1895e-01, -5.0955e-01,  1.1652e+00,
         -1.9671e+00,

- As per this data, the step the optimizer takes scales linearly with the learning rate.