In [1]:
import math
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
import torch.optim
import pdb
import logging
import os
import torch.distributed as dist

if TYPE_CHECKING:
    from torch.optim.optimizer import _params_t
else:
    _params_t = Any

class DAdaptAdam(torch.optim.Optimizer):
    """Adam variant whose step size is estimated on the fly (D-Adaptation).

    The effective step size is ``d * lr`` (times an optional bias correction).
    ``d`` starts at ``d0`` and is re-estimated after every step from a running
    weighted numerator and the L1 norm of the ``s`` buffers. ``d`` never
    decreases and grows by at most a factor of ``growth_rate`` per step.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: step-size multiplier (must be > 0). Every parameter group must use
            the same ``lr`` or ``0.0``; per-layer scaling goes through a group's
            ``layer_scale`` value instead.
        betas: coefficients of the running averages of the gradient and of its
            square (each in ``[0, 1)``).
        eps: term added to the denominator (must be > 0).
        weight_decay: weight decay coefficient (must be >= 0).
        log_every: if > 0, emit ``logging.info`` statistics when the step
            counter ``k`` is a multiple of it.
        decouple: if True, apply weight decay directly to the parameters
            (scaled by the current step size) instead of adding it to the
            gradient.
        use_bias_correction: multiply the step size by the Adam-style factor
            ``sqrt(1 - beta2**(k+1)) / (1 - beta1**(k+1))``.
        d0: initial value of ``d`` (must be > 0).
        growth_rate: maximum multiplicative growth of ``d`` per step.
        fsdp_in_use: all-reduce the d statistics across ranks. This is also
            switched on automatically when a parameter has ``_fsdp_flattened``.
    """

    def __init__(self, params, lr=1.0,
                betas=(0.9, 0.999), eps=1e-8,
                weight_decay=0, log_every=0,
                decouple=False,
                use_bias_correction=False,
                d0=1e-6, growth_rate=float('inf'),
                fsdp_in_use=False):
        if not 0.0 < d0:
            raise ValueError("Invalid d0 value: {}".format(d0))
        if not 0.0 < lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 < eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        if decouple:
            print("Using decoupled weight decay")

        # Optimizer-wide D-Adaptation state (d, k, numerator_weighted) is kept
        # in every param group; step() reads it from group 0 and writes it
        # back to all groups.
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        d = d0,
                        k=0,
                        layer_scale=1.0,
                        numerator_weighted=0.0,
                        log_every=log_every,
                        growth_rate=growth_rate,
                        use_bias_correction=use_bias_correction,
                        decouple=decouple,
                        fsdp_in_use=fsdp_in_use)
        self.d0 = d0
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return False

    @property
    def supports_flat_params(self):
        return True

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        sk_l1 = 0.0

        # Optimizer-wide settings and D-Adaptation state come from group 0.
        group = self.param_groups[0]
        use_bias_correction = group['use_bias_correction']
        numerator_weighted = group['numerator_weighted']
        beta1, beta2 = group['betas']
        k = group['k']

        d = group['d']
        lr = max(g['lr'] for g in self.param_groups)

        if use_bias_correction:
            bias_correction = ((1-beta2**(k+1))**0.5)/(1-beta1**(k+1))
        else:
            bias_correction = 1

        # Step size for this iteration, computed from d *before* it is re-estimated.
        dlr = d*lr*bias_correction

        growth_rate = group['growth_rate']
        decouple = group['decouple']
        log_every = group['log_every']
        fsdp_in_use = group['fsdp_in_use']

        sqrt_beta2 = beta2**(0.5)

        numerator_acum = 0.0

        # Pass 1: update moment buffers and accumulate the statistics for d.
        for group in self.param_groups:
            decay = group['weight_decay']
            k = group['k']
            eps = group['eps']
            group_lr = group['lr']
            r = group['layer_scale']

            if group_lr not in [lr, 0.0]:
                raise RuntimeError(f"Setting different lr values in different parameter groups "
                                "is only supported for values of 0. To scale the learning "
                                "rate differently for each layer, set the 'layer_scale' value instead.")

            for p in group['params']:
                if p.grad is None:
                    continue
                if hasattr(p, "_fsdp_flattened"):
                    fsdp_in_use = True

                grad = p.grad.data

                if decay != 0 and not decouple:
                    # Out-of-place: the caller's p.grad must not be modified.
                    grad = grad.add(p.data, alpha=decay)

                state = self.state[p]

                if 'step' not in state:
                    state['step'] = 0
                    state['s'] = torch.zeros_like(p.data).detach()
                    state['exp_avg'] = torch.zeros_like(p.data).detach()
                    state['exp_avg_sq'] = torch.zeros_like(p.data).detach()

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                s = state['s']

                if group_lr > 0.0:
                    # The numerator uses s and exp_avg_sq from *before* this step's update.
                    denom = exp_avg_sq.sqrt().add_(eps)
                    numerator_acum += r * dlr * torch.dot(grad.flatten(), s.div(denom).flatten()).item()

                    exp_avg.mul_(beta1).add_(grad, alpha=r*dlr*(1-beta1))
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                    s.mul_(sqrt_beta2).add_(grad, alpha=dlr*(1-sqrt_beta2))
                    sk_l1 += r * s.abs().sum().item()

        numerator_weighted = sqrt_beta2*numerator_weighted + (1-sqrt_beta2)*numerator_acum
        d_hat = d
        # Nothing was accumulated (no gradients, or every group has lr == 0).
        # NOTE(review): this check is per-rank and happens before the FSDP
        # all_reduce below; if only some ranks returned here, the others would
        # wait in all_reduce. Confirm this cannot happen in your setup.
        if sk_l1 == 0:
            return loss

        if lr > 0.0:
            if fsdp_in_use:
                dist_tensor = torch.zeros(2).cuda()
                dist_tensor[0] = numerator_weighted
                dist_tensor[1] = sk_l1
                dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
                # Back to Python floats so d (stored in param_groups and used
                # as a scalar alpha on the next step) never becomes a tensor.
                global_numerator_weighted = dist_tensor[0].item()
                global_sk_l1 = dist_tensor[1].item()
            else:
                global_numerator_weighted = numerator_weighted
                global_sk_l1 = sk_l1

            d_hat = global_numerator_weighted/((1-sqrt_beta2)*global_sk_l1)
            # d is monotone non-decreasing and its growth is capped by growth_rate.
            d = max(d, min(d_hat, d*growth_rate))

        if log_every > 0 and k % log_every == 0:
            logging.info(f"lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. sk_l1={global_sk_l1:1.1e} numerator_weighted={global_numerator_weighted:1.1e}")

        # Pass 2: store the new shared state and apply the Adam-style update.
        for group in self.param_groups:
            group['numerator_weighted'] = numerator_weighted
            group['d'] = d

            decay = group['weight_decay']
            k = group['k']
            eps = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1

                denom = exp_avg_sq.sqrt().add_(eps)

                if decay != 0 and decouple:
                    p.data.add_(p.data, alpha=-decay * dlr)

                # exp_avg already carries the dlr scaling, so the step has unit coefficient.
                p.data.addcdiv_(exp_avg, denom, value=-1)

            group['k'] = k + 1

        return loss

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
import numpy as np
import os

class SimpleCNN(nn.Module):
    """Three conv/pool stages followed by a three-layer MLP head with 100 outputs.

    ``forward`` flattens to ``128 * 4 * 4`` features after three 2x2 poolings,
    so it expects 3-channel 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept as-is: it fixes parameter order (and
        # therefore optimizer state indices) and the state_dict keys that
        # checkpoints rely on.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 100)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> 2x2 max-pool, three times.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        h = x.view(-1, 128 * 4 * 4)
        # Hidden layers: linear -> ReLU -> dropout.
        for linear in (self.fc1, self.fc2):
            h = self.dropout(F.relu(linear(h)))
        return self.fc3(h)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ToTensor then per-channel Normalize(mean=0.5, std=0.5), i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data = datasets.CIFAR100(root='data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR100(root='data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

num_epochs = 175   # training epochs per run
num_runs = 5       # independent runs (fresh model + optimizer each)
save_every = 10    # write a checkpoint every `save_every` epochs

checkpoint_dir = 'checkpoints'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(checkpoint_dir, exist_ok=True)

def save_checkpoint(run, epoch, model, optimizer, train_losses, train_accuracies):
    """Save training state for `run` at `epoch` under `checkpoint_dir`.

    Writes ``checkpoint_run{run}_epoch{epoch}.pth`` containing the run/epoch
    indices, the model and optimizer state dicts, and the per-epoch training
    loss and accuracy histories, then prints the path that was written.
    """
    payload = dict(
        run=run,
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_losses=train_losses,
        train_accuracies=train_accuracies,
    )
    target = os.path.join(checkpoint_dir, f'checkpoint_run{run}_epoch{epoch}.pth')
    torch.save(payload, target)
    print(f'Saved checkpoint: {target}')

def load_checkpoint(filepath, map_location=None):
    """Load a checkpoint dict written by ``save_checkpoint``.

    Args:
        filepath: path to the ``.pth`` file.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load
            a checkpoint saved from GPU on a machine without CUDA. The default
            ``None`` keeps the previous behavior.

    Returns:
        ``(run, epoch, model_state_dict, optimizer_state_dict, train_losses,
        train_accuracies)``.

    Raises:
        KeyError: if the file lacks one of the expected keys.
    """
    # torch.load unpickles its input (unless weights_only is enabled), so only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(filepath, map_location=map_location)
    return (
        checkpoint['run'],
        checkpoint['epoch'],
        checkpoint['model_state_dict'],
        checkpoint['optimizer_state_dict'],
        checkpoint['train_losses'],
        checkpoint['train_accuracies'],
    )

all_train_losses = []
all_train_accuracies = []
all_test_accuracies = []


def _latest_checkpoint_path(run):
    """Return the highest-epoch checkpoint path saved for `run`, or None.

    ``save_checkpoint`` writes ``checkpoint_run{run}_epoch{epoch}.pth`` every
    ``save_every`` epochs and once after the run finishes, so scan epochs from
    ``num_epochs`` downwards and return the first file that exists.
    """
    for epoch in range(num_epochs, 0, -1):
        path = os.path.join(checkpoint_dir, f'checkpoint_run{run}_epoch{epoch}.pth')
        if os.path.exists(path):
            return path
    return None


for run in range(num_runs):
    print(f"Starting run {run + 1}/{num_runs}")

    model = SimpleCNN().to(device)
    optimizer = DAdaptAdam(model.parameters(), lr=1)

    # Resume from the most recent checkpoint of this run, periodic or final.
    # Previously only the final checkpoint was considered, so an interrupted
    # run always restarted from epoch 0.
    checkpoint_file = _latest_checkpoint_path(run)
    if checkpoint_file is not None:
        print(f"Resuming from checkpoint: {checkpoint_file}")
        # The run index is encoded in the filename; don't rebind the loop variable.
        (_, start_epoch, model_state_dict, optimizer_state_dict,
         train_losses, train_accuracies) = load_checkpoint(checkpoint_file)
        model.load_state_dict(model_state_dict)
        optimizer.load_state_dict(optimizer_state_dict)
    else:
        start_epoch = 0
        train_losses = []
        train_accuracies = []

    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        avg_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total

        train_losses.append(avg_loss)
        train_accuracies.append(train_accuracy)

        print(f"Run {run + 1}, Epoch {epoch + 1}, Loss: {avg_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%")

        if (epoch + 1) % save_every == 0:
            save_checkpoint(run, epoch + 1, model, optimizer, train_losses, train_accuracies)

    # Evaluate once per run on the held-out test split.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_accuracy = 100 * correct / total
    all_test_accuracies.append(test_accuracy)

    print(f"\nResults after Run {run + 1}:")
    print("Training Losses:", train_losses)
    print("Training Accuracies:", train_accuracies)
    print("Test Accuracy:", test_accuracy)

    save_checkpoint(run, num_epochs, model, optimizer, train_losses, train_accuracies)

    all_train_losses.append(train_losses)
    all_train_accuracies.append(train_accuracies)

# One row per run (columns are epochs) for the per-epoch histories.
np.savetxt('all_train_losses.txt', np.array(all_train_losses), fmt='%f')
np.savetxt('all_train_accuracies.txt', np.array(all_train_accuracies), fmt='%f')
np.savetxt('all_test_accuracies.txt', np.array(all_test_accuracies), fmt='%f')

print("\nAggregated Results:")
print("All Training Losses over Epochs for each run:", all_train_losses)
print("All Training Accuracies over Epochs for each run:", all_train_accuracies)
print("All Test Accuracies for each run:", all_test_accuracies)


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:12<00:00, 14025456.53it/s]


Extracting data/cifar-100-python.tar.gz to data
Files already downloaded and verified
Starting run 1/5
Run 1, Epoch 1, Loss: 4.2039, Training Accuracy: 4.70%
Run 1, Epoch 2, Loss: 3.6583, Training Accuracy: 12.36%
Run 1, Epoch 3, Loss: 3.3401, Training Accuracy: 18.28%
Run 1, Epoch 4, Loss: 3.1015, Training Accuracy: 22.90%
Run 1, Epoch 5, Loss: 2.9238, Training Accuracy: 26.33%
Run 1, Epoch 6, Loss: 2.7818, Training Accuracy: 28.87%
Run 1, Epoch 7, Loss: 2.6651, Training Accuracy: 31.17%
Run 1, Epoch 8, Loss: 2.5717, Training Accuracy: 33.12%
Run 1, Epoch 9, Loss: 2.4905, Training Accuracy: 34.72%
Run 1, Epoch 10, Loss: 2.4062, Training Accuracy: 36.44%
Saved checkpoint: checkpoints/checkpoint_run0_epoch10.pth
Run 1, Epoch 11, Loss: 2.3432, Training Accuracy: 37.83%
Run 1, Epoch 12, Loss: 2.2876, Training Accuracy: 38.84%
Run 1, Epoch 13, Loss: 2.2284, Training Accuracy: 40.02%
Run 1, Epoch 14, Loss: 2.1757, Training Accuracy: 41.12%
Run 1, Epoch 15, Loss: 2.1295, Training Accuracy: 4