Trying to Show that on toy data set we can walk on the contour, reach abot L distance and still keeps the overall perforemence the same

In [2]:
import torch
from torch.optim.optimizer import Optimizer, required

class ContourWalkingOptimizer(Optimizer):
    def __init__(self, params, lr=required):
        # Initialize the optimizer with the parameters and hyperparameters (e.g., learning rate)
        defaults = dict(lr=lr)
        super(ContourWalkingOptimizer, self).__init__(params, defaults)
        self._is_first_time_calc = True
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # Loop over parameter groups (usually one group, but could be more)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                # Get gradient data for the parameter
                grad = p.grad.data
                # If this is the first call, initialize a random direction for p
                if self._is_first_time_calc:
                    d = torch.randn_like(grad)
                    self.state.setdefault(p, {})['d'] = d
                else:
                    d = self.state[p]['d']

                ## take here d from self
                d_proj = self.__act(grad,d)
                ## store here d_proj into self
                self.state[p]['d'] = d_proj.clone()
                # Custom update: Here, we perform a simple gradient descent step.
                # This is equivalent to: p = p - lr * grad
                p.data.add_(-group['lr'], d_proj)
        if self._is_first_time_calc:
            self._is_first_time_calc = False
        return loss

    def __act(self, grad, d):
        grad_norm = torch.norm(grad)
        if grad_norm < 1e-8:
            return d  # or return d unchanged if gradient is nearly zero

        # Compute dot product (flatten the tensors in case they are not 1D)
        dot = torch.dot(grad.view(-1), d.view(-1))

        # Project d onto the null space of grad
        d_proj = d - (dot / (grad_norm ** 2)) * grad
        norm_d_proj = torch.norm(d_proj)
        if norm_d_proj < 1e-8:
            return d  # or return d if the projection is negligible

        d_proj = d_proj / norm_d_proj  # normalize the projected direction
        return d_proj


# Example usage:
if __name__ == '__main__':
    # A simple linear model for demonstration
    model = torch.nn.Linear(10, 1)
    # Use Mean Squared Error as our loss function
    criterion = torch.nn.MSELoss()

    # Instantiate our custom optimizer
    L = 0.2
    eta = 0.00001  # step size
    T = int(L / eta)  # number of steps
    optimizer_contour = ContourWalkingOptimizer(model.parameters(), lr=eta)
    optimizer_score = torch.optim.SGD(model.parameters(), lr=0.01)

    num_samples = 100
    # Dummy data: input x and target y
    x = torch.randn(num_samples, 10)
    y = torch.randn(num_samples, 1)

    # Training loop (simplified)
    for epoch in range(100):
        optimizer_score.zero_grad()  # reset gradients to zero
        output = model(x)
        loss = criterion(output, y)
        loss.backward()  # compute gradients

        optimizer_score.step()  # update parameters
        print(f"Epoch {epoch}: loss = {loss.item()}")
    optimized_params = [p.detach().clone() for p in model.parameters()]
    for t in range(T):
        optimizer_contour.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer_contour.step()
        if t % 1000 == 0:
            print(f"Step {t}: loss = {loss.item()}")
    curr_params = model.parameters()
    for i, (param_score, param_contour) in enumerate(zip(optimized_params, curr_params)):
        print(f"Parameter {i}: score = {torch.norm(param_score - param_contour)}")


Epoch 0: loss = 1.4185031652450562
Epoch 1: loss = 1.3880826234817505
Epoch 2: loss = 1.3594682216644287
Epoch 3: loss = 1.3325499296188354
Epoch 4: loss = 1.3072237968444824
Epoch 5: loss = 1.283393383026123
Epoch 6: loss = 1.2609676122665405
Epoch 7: loss = 1.239861011505127
Epoch 8: loss = 1.2199935913085938
Epoch 9: loss = 1.201290249824524
Epoch 10: loss = 1.1836808919906616
Epoch 11: loss = 1.1670989990234375
Epoch 12: loss = 1.1514828205108643
Epoch 13: loss = 1.1367741823196411
Epoch 14: loss = 1.1229183673858643
Epoch 15: loss = 1.1098644733428955
Epoch 16: loss = 1.0975641012191772
Epoch 17: loss = 1.0859721899032593
Epoch 18: loss = 1.0750465393066406
Epoch 19: loss = 1.0647472143173218
Epoch 20: loss = 1.0550369024276733
Epoch 21: loss = 1.0458807945251465
Epoch 22: loss = 1.037245512008667
Epoch 23: loss = 1.029100775718689
Epoch 24: loss = 1.0214170217514038
Epoch 25: loss = 1.014167308807373
Epoch 26: loss = 1.0073258876800537
Epoch 27: loss = 1.000868797302246
Epoch 28:

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha = 1) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1642.)
  p.data.add_(-group['lr'], d_proj)


Step 1000: loss = 0.8880020976066589
Step 2000: loss = 0.8880211114883423
Step 3000: loss = 0.8879920244216919
Step 4000: loss = 0.8880394697189331
Step 5000: loss = 0.8880000114440918
Step 6000: loss = 0.8880088925361633
Step 7000: loss = 0.8880508542060852
Step 8000: loss = 0.8880193829536438
Step 9000: loss = 0.8880576491355896
Step 10000: loss = 0.888070821762085
Step 11000: loss = 0.888102650642395
Step 12000: loss = 0.8881121277809143
Step 13000: loss = 0.8881509304046631
Step 14000: loss = 0.8881320357322693
Step 15000: loss = 0.8881598711013794
Step 16000: loss = 0.8881928324699402
Step 17000: loss = 0.8881887197494507
Step 18000: loss = 0.8881677389144897
Step 19000: loss = 0.8881662487983704
Parameter 0: score = 0.1527235060930252
Parameter 1: score = 0.007789388298988342


Lets try on a real dataset

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import random_split,DataLoader
# Simple MLP for MNIST
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# MNIST dataset and dataloader
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_test  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Split the training set into training and validation sets (e.g., 80/20 split)
num_val = len(mnist_test)
num_train = len(train_dataset) - num_val
mnist_train, val_set = random_split(train_dataset, [num_train, num_val])
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True)
val_loader   = DataLoader(val_set, batch_size=64, shuffle=True)
val_loader_no_contour  = DataLoader(mnist_test, batch_size=64, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = SimpleMLP().to(device)
criterion = nn.CrossEntropyLoss()


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:11<00:00, 898kB/s] 


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 134kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:06<00:00, 242kB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 7.42MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

cuda


In [4]:
optimizer_score = optim.SGD(model.parameters(), lr=0.01)
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer_score.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer_score.step()
        running_loss += loss.item() * data.size(0)
    print(f"Epoch {epoch}: training loss = {running_loss / len(train_dataset)}")

# Save the optimized parameters
optimized_params = [p.detach().clone() for p in model.parameters()]
torch.save(model, 'only_opt_no_contour.pth')
model.eval()
running_loss = 0.0
for data, target in train_loader:
    data, target = data.to(device), target.to(device)
    with torch.no_grad():
      output = model(data)
    loss = criterion(output, target)
    running_loss += loss.item() * data.size(0)
print(f"Final training loss = {running_loss / len(train_dataset)}")


Epoch 0: training loss = 1.2476548559029896
Epoch 1: training loss = 0.49958251489003497
Epoch 2: training loss = 0.39451914879481
Epoch 3: training loss = 0.3528783009211222
Epoch 4: training loss = 0.32766643747488655
Epoch 5: training loss = 0.3090274195273717
Epoch 6: training loss = 0.2938690779685974
Epoch 7: training loss = 0.2804592679818471
Epoch 8: training loss = 0.26876455256938936
Epoch 9: training loss = 0.2577093495607376
Epoch 10: training loss = 0.24751803018252055
Epoch 11: training loss = 0.23812339059114457
Epoch 12: training loss = 0.2295622441093127
Epoch 13: training loss = 0.2215165986975034
Epoch 14: training loss = 0.2140719194014867
Epoch 15: training loss = 0.20715455749829612
Epoch 16: training loss = 0.20052204308509827
Epoch 17: training loss = 0.19454826732873917
Epoch 18: training loss = 0.18878573688666025
Epoch 19: training loss = 0.18333962559103967
Final training loss = 0.17940671562751134


In [5]:
model = torch.load('only_opt_no_contour.pth')
# Set contour-walking hyperparameters:
L = 0.01
eta = 1e-5  # step size for contour walking
T = int(L / eta)

optimizer_contour = ContourWalkingOptimizer(model.parameters(), lr=eta)

# Get one batch from the training loader:

t = 0
model.train()
running_loss = 0
while t < T:
  for data, target in train_loader:
      if t < T:
        data, target = data.to(device), target.to(device)
        optimizer_contour.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer_contour.step()
        running_loss += loss.item()
        t+=1
print(f"Contour Step {t}: loss = {running_loss/ T}")


  model = torch.load('only_opt_no_contour.pth')


Contour Step 999: loss = 0.17870589348303903


In [6]:
for i, (p_orig, p_contour) in enumerate(zip(optimized_params, model.parameters())):
    diff_norm = torch.norm(p_orig - p_contour)
    print(f"Parameter {i}: Norm difference = {diff_norm.item()}")


Parameter 0: Norm difference = 0.009988917037844658
Parameter 1: Norm difference = 0.009783073328435421
Parameter 2: Norm difference = 0.009955196641385555
Parameter 3: Norm difference = 0.009936198592185974


Lets check if it is o.k to walk on the contour of the validation set

In [7]:
model = torch.load('only_opt_no_contour.pth')

  model = torch.load('only_opt_no_contour.pth')


In [8]:
from torch.utils.data import random_split,DataLoader

In [14]:

# Set contour-walking hyperparameters:
L = 0.1
eta = 1e-5  # step size for contour walking
T = int(L / eta)

optimizer_contour = ContourWalkingOptimizer(model.parameters(), lr=eta)
t = 0
#### The main difference here is that we walk on the contour relative to validation and not train #####
T = int(1 + T / len(val_loader)) * len(val_loader)
num_exp = 10
valdation_contour = list()
validation_no_contour = list()
loader_for_contour = val_loader
loader_for_no_contour = val_loader_no_contour
for ind in range(num_exp):

  running_loss = 0
  model = torch.load('only_opt_no_contour.pth')
  while t < T:
    model.train()
    for data, target in loader_for_contour:
      if t < T:
          data, target = data.to(device), target.to(device)
          optimizer_contour.zero_grad()
          output = model(data)
          loss = criterion(output, target)
          loss.backward()
          optimizer_contour.step()
          running_loss += loss.item()
          t +=1
  print(f"Contour Step {t}: loss = {running_loss / T}")
  model.eval()
  running_loss = 0
  for data, target in loader_for_contour:
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
          output = model(data)
        loss = criterion(output, target)
        running_loss += loss.item()
  print(f"Validation loss = {running_loss/ len(loader_for_contour)}")
  valdation_contour.append(running_loss/ len(loader_for_contour))
  model.eval()
  running_loss = 0
  for data, target in loader_for_no_contour:
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
          output = model(data)
        loss = criterion(output, target)
        running_loss += loss.item()
  print(f"Validation loss no contour = {running_loss/ len(loader_for_no_contour)}")
  validation_no_contour.append(running_loss/ len(loader_for_no_contour))
  if loader_for_contour == val_loader:
    loader_for_contour = val_loader_no_contour
    loader_for_no_contour = val_loader
  else:
    loader_for_contour = val_loader
    loader_for_no_contour = val_loader_no_contour


  model = torch.load('only_opt_no_contour.pth')


Contour Step 10048: loss = 0.17940636844412056
Validation loss = 0.17905799375408016
Validation loss no contour = 0.1803577360548791
Contour Step 10048: loss = 0.0
Validation loss = 0.17896598136159264
Validation loss no contour = 0.1786386588600221
Contour Step 10048: loss = 0.0
Validation loss = 0.1796294016063593
Validation loss no contour = 0.17953905681515955
Contour Step 10048: loss = 0.0
Validation loss = 0.17789380366255522
Validation loss no contour = 0.18016036993758694
Contour Step 10048: loss = 0.0
Validation loss = 0.17941977688746089
Validation loss no contour = 0.1784275173666371
Contour Step 10048: loss = 0.0
Validation loss = 0.17843117835415398
Validation loss no contour = 0.1789207886311279
Contour Step 10048: loss = 0.0
Validation loss = 0.17863997748228394
Validation loss no contour = 0.17800669811049086
Contour Step 10048: loss = 0.0
Validation loss = 0.17773591760237506
Validation loss no contour = 0.17924970333838158
Contour Step 10048: loss = 0.0
Validation los

In [15]:
from scipy.stats import ttest_rel

# Suppose you collected paired samples:
A_v_before = valdation_contour  # validation accuracies before contour walk
A_v_after = validation_no_contour   # validation accuracies after contour walk

t_stat, p_value = ttest_rel(A_v_before, A_v_after)

print(f"Paired t-test result: t-stat={t_stat}, p-value={p_value}")
if p_value > 0.05:
    print("Fail to reject H0: no significant bias from contour walk.")
else:
    print("Reject H0: significant bias detected.")


Paired t-test result: t-stat=-0.7386165766956186, p-value=0.4789577488087644
Fail to reject H0: no significant bias from contour walk.
