In [1]:
import math
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
import torch.optim
import pdb
import logging
import os
import torch.distributed as dist

if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any

class DAdaptAdam(torch.optim.Optimizer):
    """Adam variant whose step size is set by D-Adaptation.

    The step size of each update is ``d * lr`` (times an optional Adam
    bias-correction factor), where ``d`` is estimated online from the gradient
    history. ``d`` starts at ``d0`` and is never decreased; per step it grows by
    at most a factor of ``growth_rate``.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: multiplier on the adapted step size; must be > 0. Each group's
            ``lr`` must equal the largest group ``lr`` or be 0; groups with
            ``lr == 0`` do not contribute to the estimate of ``d`` and their
            moment buffers are not updated. Use the ``layer_scale`` group
            option to scale groups relative to each other.
        betas: coefficients of the running averages of the gradient and of
            its square; each must be in [0, 1).
        eps: term added to the denominator for numerical stability; > 0.
        weight_decay: weight-decay coefficient.
        log_every: if > 0, emit a ``logging.info`` record every
            ``log_every`` steps.
        decouple: apply AdamW-style decoupled weight decay instead of adding
            ``weight_decay * p`` to the gradient.
        use_bias_correction: include Adam's bias correction in the step size.
        d0: initial value of ``d``; > 0.
        growth_rate: cap on the per-step multiplicative growth of ``d``.
        fsdp_in_use: all-reduce the statistics that determine ``d`` across
            ranks. Also switched on for the current step whenever a parameter
            carries the ``_fsdp_flattened`` attribute.

    Raises:
        ValueError: if ``d0``, ``lr``, ``eps`` or ``betas`` are out of range.
    """

    def __init__(self, params, lr=1.0,
                betas=(0.9, 0.999), eps=1e-8,
                weight_decay=0, log_every=0,
                decouple=False,
                use_bias_correction=False,
                d0=1e-6, growth_rate=float('inf'),
                fsdp_in_use=False):
        if not 0.0 < d0:
            raise ValueError("Invalid d0 value: {}".format(d0))
        if not 0.0 < lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 < eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        if decouple:
            print(f"Using decoupled weight decay")

        # The D-Adaptation state (d, k, numerator_weighted) lives in the group
        # dicts, so it is carried through state_dict()/load_state_dict().
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        d = d0,
                        k=0,
                        layer_scale=1.0,
                        numerator_weighted=0.0,
                        log_every=log_every,
                        growth_rate=growth_rate,
                        use_bias_correction=use_bias_correction,
                        decouple=decouple,
                        fsdp_in_use=fsdp_in_use)
        self.d0 = d0
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return False

    @property
    def supports_flat_params(self):
        return True

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            Whatever ``closure`` returned, or ``None`` without a closure.

        Raises:
            RuntimeError: if a group's ``lr`` is neither the maximum group
                ``lr`` nor 0.
        """
        loss = None
        if closure is not None:
            loss = closure()

        sk_l1 = 0.0

        # Optimizer-wide settings and the shared D-Adaptation state are read
        # from the first group; updated values are written back to every
        # group in the second pass below.
        group = self.param_groups[0]
        use_bias_correction = group['use_bias_correction']
        numerator_weighted = group['numerator_weighted']
        beta1, beta2 = group['betas']
        k = group['k']

        d = group['d']
        lr = max(group['lr'] for group in self.param_groups)

        if use_bias_correction:
            bias_correction = ((1-beta2**(k+1))**0.5)/(1-beta1**(k+1))
        else:
            bias_correction = 1

        # Step size for this iteration. It uses the d from the previous step,
        # before d is re-estimated below.
        dlr = d*lr*bias_correction

        growth_rate = group['growth_rate']
        decouple = group['decouple']
        log_every = group['log_every']
        fsdp_in_use = group['fsdp_in_use']

        sqrt_beta2 = beta2**(0.5)

        numerator_acum = 0.0

        # Pass 1: update moment buffers and accumulate the statistics that
        # determine the new estimate of d.
        for group in self.param_groups:
            decay = group['weight_decay']
            k = group['k']
            eps = group['eps']
            group_lr = group['lr']
            r = group['layer_scale']

            if group_lr not in [lr, 0.0]:
                raise RuntimeError(f"Setting different lr values in different parameter groups "
                                "is only supported for values of 0. To scale the learning "
                                "rate differently for each layer, set the 'layer_scale' value instead.")

            for p in group['params']:
                if p.grad is None:
                    continue
                if hasattr(p, "_fsdp_flattened"):
                    fsdp_in_use = True

                grad = p.grad.data

                # Coupled (L2) weight decay. Out-of-place, so the caller's
                # p.grad is left untouched (as torch.optim.Adam does).
                if decay != 0 and not decouple:
                    grad = grad.add(p.data, alpha=decay)

                state = self.state[p]

                if 'step' not in state:
                    state['step'] = 0
                    state['s'] = torch.zeros_like(p.data).detach()
                    state['exp_avg'] = torch.zeros_like(p.data).detach()
                    state['exp_avg_sq'] = torch.zeros_like(p.data).detach()

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                s = state['s']

                if group_lr > 0.0:
                    denom = exp_avg_sq.sqrt().add_(eps)
                    numerator_acum += r * dlr * torch.dot(grad.flatten(), s.div(denom).flatten()).item()

                    # exp_avg is pre-scaled by r*dlr, so pass 2 applies it
                    # with unit weight.
                    exp_avg.mul_(beta1).add_(grad, alpha=r*dlr*(1-beta1))
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                    s.mul_(sqrt_beta2).add_(grad, alpha=dlr*(1-sqrt_beta2))
                    sk_l1 += r * s.abs().sum().item()

        numerator_weighted = sqrt_beta2*numerator_weighted + (1-sqrt_beta2)*numerator_acum
        d_hat = d
        # No gradient signal was accumulated (e.g. no parameter has a grad, or
        # every group has lr 0): leave parameters, d and k unchanged.
        if sk_l1 == 0:
            return loss

        if lr > 0.0:
            if fsdp_in_use:
                dist_tensor = torch.zeros(2).cuda()
                dist_tensor[0] = numerator_weighted
                dist_tensor[1] = sk_l1
                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
                # Convert back to Python floats so d (stored in param_groups,
                # hence in state_dict) never becomes a device tensor.
                global_numerator_weighted = dist_tensor[0].item()
                global_sk_l1 = dist_tensor[1].item()
            else:
                global_numerator_weighted = numerator_weighted
                global_sk_l1 = sk_l1

            d_hat = global_numerator_weighted/((1-sqrt_beta2)*global_sk_l1)
            # d is monotone non-decreasing and its growth is capped.
            d = max(d, min(d_hat, d*growth_rate))

        if log_every > 0 and k % log_every == 0:
            logging.info(f"lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. sk_l1={global_sk_l1:1.1e} numerator_weighted={global_numerator_weighted:1.1e}")

        # Pass 2: share the new d / numerator with every group and apply the
        # parameter update.
        for group in self.param_groups:
            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

            decay = group['weight_decay']
            k = group['k']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1

                denom = exp_avg_sq.sqrt().add_(eps)

                # Decoupled (AdamW-style) weight decay, scaled by this step's dlr.
                if decay != 0 and decouple:
                    p.data.add_(p.data, alpha=-decay * dlr)

                p.data.addcdiv_(exp_avg, denom, value=-1)

            group['k'] = k + 1

        return loss

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import os

class SimpleCNN(nn.Module):
    """Small two-conv-layer classifier for 3-channel 32x32 images, 10 classes.

    Layout: conv(3->16) -> ReLU -> 2x2 max-pool -> conv(16->32) -> ReLU ->
    2x2 max-pool -> fc(32*8*8 -> 256) -> ReLU -> fc(256 -> 10). ``forward``
    returns unnormalized logits.
    """

    def __init__(self):
        super().__init__()
        # padding=1 keeps the spatial size through each 3x3 conv and each pool
        # halves it, so a 32x32 input reaches fc1 as 32 channels of 8x8.
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Map a ``(batch, 3, 32, 32)`` tensor to ``(batch, 10)`` logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, 32*8*8), this never spreads one
        # sample's features over several "batch" rows when the input is not
        # 32x32; fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 then maps
# them to [-1, 1]. The single-element mean/std is applied to every channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# CIFAR-10 is fetched into ./data if missing (train split first, then test).
train_data = datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='data', train=False, download=True, transform=transform)

# Batches of 64; only the training loader reshuffles every epoch.
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)

def save_checkpoint(config, state, filename='checkpoint.pth.tar'):
    """Write *state* with ``torch.save`` to ``checkpoints/config_<config>/<filename>``.

    The directory is created if needed; an existing file is overwritten.
    """
    directory = f'checkpoints/config_{config}'
    os.makedirs(directory, exist_ok=True)
    torch.save(state, os.path.join(directory, filename))

def load_checkpoint(config, filename='checkpoint.pth.tar'):
    """Return the object saved for *config* via ``torch.load``, or ``None`` if absent."""
    path = os.path.join(f'checkpoints/config_{config}', filename)
    if not os.path.exists(path):
        return None
    return torch.load(path)

num_epochs = 175

# One DAdaptAdam configuration per entry; each is trained from a fresh model.
hyperparams = [
    dict(lr=1.0, betas=(0.9, 0.999), weight_decay=0, d0=1e-6),
    dict(lr=0.5, betas=(0.9, 0.999), weight_decay=0.01, d0=1e-5),
    dict(lr=1.0, betas=(0.85, 0.999), weight_decay=0.001, d0=1e-7),
]

# Per-configuration result histories, appended to by the training loop.
all_train_accuracies, all_test_accuracies, all_losses = [], [], []

# Train and evaluate one fresh SimpleCNN per hyperparameter configuration.
# Each configuration writes checkpoints to checkpoints/config_<id>/ (every 10
# epochs, and once more after the test evaluation) and resumes from that file
# if it exists, so re-running this cell continues where it left off.
# NOTE(review): nothing here moves the model or batches to a GPU, so training
# presumably runs on the CPU — confirm this is intended.
for config_id, hparams in enumerate(hyperparams):
    print(f"Starting training for hyperparameter configuration {config_id + 1}/{len(hyperparams)}")

    # Only lr, betas, weight_decay and d0 are varied; all other DAdaptAdam
    # options keep their constructor defaults.
    model = SimpleCNN()
    optimizer = DAdaptAdam(
        model.parameters(),
        lr=hparams['lr'],
        betas=hparams['betas'],
        weight_decay=hparams['weight_decay'],
        d0=hparams['d0']
    )

    # Resume: restore model/optimizer state and the per-epoch histories, then
    # continue from the epoch after the checkpointed one. The checkpoint saved
    # after evaluation has epoch == num_epochs - 1, so in that case training
    # is skipped and only the test evaluation below runs again.
    start_epoch = 0
    checkpoint = load_checkpoint(config_id)
    if checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        train_accuracies = checkpoint['train_accuracies']
        losses = checkpoint['losses']
        print(f"Resuming from epoch {start_epoch}")
    else:
        train_accuracies = []
        losses = []

    for epoch in range(start_epoch, num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        model.train()
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # Training accuracy uses the logits from the forward pass above,
            # i.e. the weights from before this batch's optimizer step.
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Recorded loss = mean of the per-batch mean losses for this epoch.
        train_accuracy = 100 * correct / total
        train_accuracies.append(train_accuracy)
        losses.append(running_loss / len(train_loader))
        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}, Training Accuracy: {train_accuracy:.2f}%")

        # Periodic checkpoint every 10 completed epochs.
        if (epoch + 1) % 10 == 0:
            save_checkpoint(config_id, {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_accuracies': train_accuracies,
                'losses': losses,
            })

    # Single evaluation on the full test split after training finishes.
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Each configuration's test entry is a one-element list.
    test_accuracy = 100 * correct / total
    all_test_accuracies.append([test_accuracy])
    print(f'Accuracy of the network on the 10000 test images: {test_accuracy:.2f}%')

    all_train_accuracies.append(train_accuracies)
    all_losses.append(losses)

    # Final checkpoint. 'test_accuracy' is stored here but is not read back
    # by the resume logic above.
    save_checkpoint(config_id, {
        'epoch': num_epochs - 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_accuracies': train_accuracies,
        'losses': losses,
        'test_accuracy': test_accuracy
    })

# Shallow-copy each per-configuration history into a plain list, then print.
all_train_accuracies = list(map(list, all_train_accuracies))
all_test_accuracies = list(map(list, all_test_accuracies))
all_losses = list(map(list, all_losses))

print("All Training Accuracies over Epochs for each config:", all_train_accuracies)
print("All Test Accuracies after each config:", all_test_accuracies)
print("All Losses over Epochs for each config:", all_losses)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:11<00:00, 15469334.14it/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Starting training for hyperparameter configuration 1/3
Epoch 1, Loss: 1.8479507211833963, Training Accuracy: 30.64%
Epoch 2, Loss: 1.5265347352418144, Training Accuracy: 43.35%
Epoch 3, Loss: 1.419721240735115, Training Accuracy: 47.86%
Epoch 4, Loss: 1.3564787116806831, Training Accuracy: 50.18%
Epoch 5, Loss: 1.3204516827907709, Training Accuracy: 51.68%
Epoch 6, Loss: 1.298067217592693, Training Accuracy: 52.95%
Epoch 7, Loss: 1.2745878688057366, Training Accuracy: 53.90%
Epoch 8, Loss: 1.2567446345410993, Training Accuracy: 54.18%
Epoch 9, Loss: 1.2421544243765, Training Accuracy: 55.23%
Epoch 10, Loss: 1.2313954300435304, Training Accuracy: 55.47%
Epoch 11, Loss: 1.221115350951929, Training Accuracy: 55.98%
Epoch 12, Loss: 1.2071787481722625, Training Accuracy: 56.33%
Epoch 13, Loss: 1.2027929135600623, Training Accuracy: 56.63%
Epoch 14, Loss: 1.1948651246859898, Training Accuracy: 56.82%
Epoch 1