<a href="https://colab.research.google.com/github/bhavnicksm/nadir/blob/main/examples/ranger.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Nadir Example: Easy Ranger Optimizer Implementation

This Colab notebook showcases how you can use Nadir to build optimizers such as Ranger out of the box, with almost no effort!

Implementing the Ranger optimizer in Nadir is as simple as enabling Nesterov momentum in RAdam.

```
import nadir as nd
from nadir import Radam, RadamConfig

config = RadamConfig(nesterov=True)
optim = Radam(params=..., config=config)
```

In [None]:
!pip install nadir



In [None]:
import os, random
from typing import Any, List, Tuple, Dict
import argparse

import numpy as np

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms, utils

from tqdm.notebook import tqdm

In [None]:
# Seed every RNG this notebook touches (torch CPU + CUDA, NumPy, stdlib random)
# so that repeated runs produce the same numbers.
SEED = 42
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_rng(SEED)
del _seed_rng  # keep the notebook namespace clean

# Trade cuDNN autotuning speed for deterministic convolution algorithms.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Make a Namespace object to store all the experiment values
args = argparse.Namespace()

# NOTE(review): learning_rate / betas / eps / optimizer are not read by the Radam
# optimizer cell below (it passes its own RadamConfig); they only document intent.
args.learning_rate: float = 1e-3

args.batch_size: int = 64
args.test_batch_size: int = 1000
args.gamma: float = 0.7  # StepLR decay factor applied after every epoch
args.device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
args.log_interval: int = 10
args.epochs: int = 10
args.betas: Tuple[float, float] = (0.9, 0.99)
args.eps: float = 1e-16
# Referenced through `torch.optim` rather than the bare `optim` name: a later cell
# rebinds `optim` to an optimizer instance, which would break re-running this cell.
args.optimizer: Any = torch.optim.Adam

# Seeds for (potentially multi-seed) runs; only the first one is used below.
args.random_seeds: List[int] = [42]

args.seed: int = args.random_seeds[0]

# writing the logging args as a namespace obj
largs = argparse.Namespace()
# NOTE(review): the run name says "Adam" although this notebook trains with Radam.
largs.run_name: str = 'DoE-Adam'
largs.run_seed: int = args.seed


In [None]:
class MNISTestNet(nn.Module):
    """Small two-convolution CNN classifier for single-channel 28x28 images.

    Pipeline: conv1 -> ReLU -> conv2 -> 2x2 max-pool -> dropout(0.25)
    -> flatten -> fc1 -> ReLU -> dropout(0.5) -> fc2.
    ``forward`` returns raw logits (no softmax); the training code in this
    notebook feeds them to ``F.cross_entropy``.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this fixed order so that seeded initialisation
        # yields the same weights; attribute names are the state_dict keys.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 28 -> 26 (conv1) -> 24 (conv2) -> 12 (pool): 64 * 12 * 12 = 9216 features.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        features = self.conv2(F.relu(self.conv1(x)))
        # NOTE(review): unlike the reference PyTorch MNIST example, there is no
        # ReLU between conv2 and the pooling step -- confirm this is intended.
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        return self.fc2(self.dropout2(hidden))

In [None]:
def prepare_loaders(args, use_cuda=False, data_dir='./data'):
    """Build seeded MNIST train/test DataLoaders.

    Args:
        args: namespace providing ``seed``, ``batch_size`` and ``test_batch_size``.
        use_cuda: when True, load with one worker process and pinned memory.
        data_dir: directory where MNIST is stored (downloaded there if missing).

    Returns:
        ``(train_loader, test_loader)``.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    # One seeded generator drives the shuffling of both loaders.
    generator = torch.Generator()
    generator.manual_seed(args.seed)

    def seed_worker(worker_id):
        # DataLoader calls worker_init_fn(worker_id); the previous zero-argument
        # signature raised TypeError as soon as num_workers > 0. Seed NumPy and
        # stdlib random from the per-worker torch seed.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    # Identical preprocessing for both splits (presumably MNIST's mean/std).
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )
    loader_options = dict(worker_init_fn=seed_worker, generator=generator, **kwargs)

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True, transform=transform),
        batch_size=args.batch_size,
        shuffle=True,
        **loader_options,
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, download=True, transform=transform),
        batch_size=args.test_batch_size,
        shuffle=True,
        **loader_options,
    )
    return train_loader, test_loader

In [None]:
# Module-level loaders used by the warm-up training cell; use_cuda=False means no
# extra DataLoader worker / pin-memory options are passed.
train_loader, test_loader = prepare_loaders(args, use_cuda=False)

In [None]:
def train(args, model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        args: experiment namespace (unused here; kept for call compatibility).
        model: network whose outputs are passed to ``F.cross_entropy`` as logits.
        device: device every batch is moved to.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer that updates ``model``'s parameters.
        epoch: epoch index (unused here; kept for call compatibility).

    Returns:
        ``(mean_loss, last_batch_loss)``: the running mean of the per-batch
        cross-entropy over the epoch, and the loss of the final batch.

    Raises:
        ValueError: if ``train_loader`` yields no batches (previously this
            surfaced as an UnboundLocalError on ``loss``).
    """
    model.train()
    running_loss = 0.0
    last_loss = None
    pbar = tqdm(train_loader)
    for count, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        last_loss = loss.item()
        # Incremental mean over the batches seen so far.
        running_loss = (running_loss * count + last_loss) / (count + 1)

        pbar.set_description(f"Running Loss: {running_loss : .5f}")

    if last_loss is None:
        raise ValueError("train_loader yielded no batches")
    return running_loss, last_loss

In [None]:
def test(model, device, test_loader):
    """Evaluate ``model`` on every batch of ``test_loader``.

    Args:
        model: network whose outputs are treated as class logits.
        device: device every batch is moved to.
        test_loader: DataLoader of ``(data, target)`` batches exposing ``.dataset``.

    Returns:
        ``(accuracy_percent, mean_loss)`` where ``mean_loss`` is the average
        per-sample cross-entropy over the whole dataset.

    Raises:
        ValueError: if ``test_loader.dataset`` is empty.
    """
    model.eval()
    num_samples = len(test_loader.dataset)
    if num_samples == 0:
        raise ValueError("test_loader.dataset is empty")

    test_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in (pbar := tqdm(test_loader)):
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) per batch, so that dividing by the dataset size below
            # gives the per-sample mean. Summing batch *means* under-reported the
            # loss by a factor of the batch size.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            seen += target.size(0)

            # Running accuracy over the samples evaluated so far.
            pbar.set_description(f"Accuracy: {100 * correct / seen : .4f}")
    test_loss /= num_samples
    test_accuracy = 100 * correct / num_samples
    return test_accuracy, test_loss

In [None]:
def mnist_tester(optimizer=None, model = None, args = None, largs = None):
    """Train ``model`` for ``args.epochs`` epochs, evaluating after every epoch.

    Args:
        optimizer: optimizer already bound to ``model``'s parameters (required).
        model: network, expected to already live on ``args.device`` (required).
        args: experiment namespace (required). This function reads
            ``random_seeds``, ``device``, ``gamma`` and ``epochs``; the helpers it
            calls read further fields.
        largs: logging namespace (currently unused).

    Returns:
        ``(train_losses, test_results)``. ``train_losses`` holds the mean training
        loss of each epoch. ``test_results`` holds the ``(accuracy, loss)`` tuple
        returned by ``test`` for each epoch.

    Raises:
        TypeError: if ``optimizer``, ``model`` or ``args`` is not supplied.
    """
    required = {'optimizer': optimizer, 'model': model, 'args': args}
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise TypeError(f"mnist_tester() missing required argument(s): {', '.join(missing)}")

    train_loss = []
    test_loss = []

    torch.manual_seed(args.random_seeds[0])
    device = args.device
    # NOTE(review): args.device is a plain string, and comparing it with a
    # torch.device may never be True. In that case the loaders are built without
    # workers or pinned memory. Left unchanged on purpose: enabling workers also
    # requires prepare_loaders' worker_init_fn to accept a worker_id argument.
    use_cuda = bool(device == torch.device('cuda'))
    train_loader, test_loader = prepare_loaders(args, use_cuda)

    # Multiply the learning rate by args.gamma after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    for epoch in (pbar := tqdm(range(1, args.epochs + 1))):
        loss, _ = train(args, model, device, train_loader, optimizer, epoch)
        test_result = test(model, device, test_loader)
        scheduler.step()
        train_loss.append(loss)
        test_loss.append(test_result)
        pbar.set_description(f"Loss: {loss : .5f}")
    return train_loss, test_loss

In [None]:
# Place the model on the device chosen in `args` (CUDA when available, else CPU).
# Hard-coding "cuda" here crashed on CPU-only runtimes.
device = args.device
model = MNISTestNet()
model = model.to(device)

In [None]:
from nadir import Radam, RadamConfig

# Ranger-style optimizer in Nadir: RAdam with Nesterov momentum switched on.
# The learning rate is fixed here (1e-4); args.learning_rate is not consulted.
config = RadamConfig(nesterov=True, lr=1e-4)
# NOTE(review): this rebinds `optim`, shadowing the `torch.optim` module imported
# at the top of the notebook; any later `optim.<Class>` lookup will fail.
optim = Radam(config=config, params=model.parameters())

In [None]:
optim  # display the optimizer repr: the hyper-parameters of each parameter group

Radam (
Parameter Group 0
    amsgrad: False
    beta_1: 0.9
    beta_2: 0.99
    bias_correction: True
    eps: 1e-08
    lr: 0.0001
    nesterov: True
    weight_decay: 0.0
)

In [None]:
# Warm-up pass: one epoch on the module-level loader *before* mnist_tester, which
# then continues training the same model/optimizer for args.epochs more epochs.
train(args=args, model=model, device=device, train_loader=train_loader,
      optimizer=optim, epoch=1)

  0%|          | 0/938 [00:00<?, ?it/s]

(0.5081453292068632, 0.11233509331941605)

In [None]:
# Full run: args.epochs epochs of training, each followed by evaluation on the test split.
mnist_tester(optimizer=optim, model=model, args=args, largs=largs)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

([0.1549495553244341,
  0.10099813504926941,
  0.08061444159668658,
  0.06912140672637236,
  0.06317276834323605,
  0.059352496512301726,
  0.056001899444730455,
  0.05465932715280071,
  0.05261306171052293,
  0.052126462286614245],
 [(97.9, 6.848291680216789e-05),
  (98.43, 5.060526058077812e-05),
  (98.53, 4.4991660118103026e-05),
  (98.7, 4.0964876487851146e-05),
  (98.67, 4.139606337994337e-05),
  (98.79, 3.801739551126957e-05),
  (98.72, 3.7209724262356756e-05),
  (98.77, 3.686356116086244e-05),
  (98.79, 3.658476080745459e-05),
  (98.78, 3.5766651481389997e-05)])

## Fin.