In [1]:
import torch
import argparse
import wandb
from typing import Any, List, Tuple

from dataclasses import dataclass, field
from typing import Dict, Tuple, Any, Optional
from torch.optim.optimizer import Optimizer, required


import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import nadir as optim
import mnist

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [4]:
args = argparse.Namespace()
args.random_seeds : List[int] = [42]
args.seed : int = args.random_seeds[0]

args.learning_rate : float = 1e-3
args.batch_size : int = 64
args.test_batch_size : int = 1000
args.gamma : float = 0.7
args.device : bool = 'cuda' if torch.cuda.is_available() else 'cpu'
args.log_interval : int = 10
args.epochs : int = 10
args.betas : Tuple[float, float] = (0.9, 0.99)
args.eps : float = 1e-16
args.optimizer : Any = optim.Adam


In [5]:
largs = argparse.Namespace()
largs.run_name : str = 'DoE-Adam'
largs.run_seed : str = args.seed

In [6]:
run = wandb.init(project="MNIST", entity="dawn-of-eve")
run.name = f'{largs.run_name}'
run.config.update(args)
run.config.update(largs)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbhavnicksm[0m ([33mdawn-of-eve[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
mnist.mnist_tester(args.optimizer, args, largs)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 13304585.23it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 8737427.43it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 8981198.92it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 18057373.24it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



Loss:  0.08657: 100%|██████████| 938/938 [00:14<00:00, 63.40it/s]
Accuracy:  0.9707: 100%|██████████| 10/10 [00:01<00:00,  9.72it/s]
Loss:  0.01439: 100%|██████████| 938/938 [00:13<00:00, 71.93it/s]
Accuracy:  0.9793: 100%|██████████| 10/10 [00:01<00:00,  9.63it/s]
Loss:  0.38176: 100%|██████████| 938/938 [00:13<00:00, 71.88it/s]
Accuracy:  0.9805: 100%|██████████| 10/10 [00:01<00:00,  9.84it/s]
Loss:  0.00020: 100%|██████████| 938/938 [00:12<00:00, 74.73it/s]
Accuracy:  0.9812: 100%|██████████| 10/10 [00:01<00:00,  9.83it/s]
Loss:  0.00016: 100%|██████████| 938/938 [00:12<00:00, 74.60it/s]
Accuracy:  0.9817: 100%|██████████| 10/10 [00:01<00:00,  9.62it/s]
Loss:  0.42026: 100%|██████████| 938/938 [00:12<00:00, 73.26it/s]
Accuracy:  0.9834: 100%|██████████| 10/10 [00:01<00:00,  9.92it/s]
Loss:  0.21644: 100%|██████████| 938/938 [00:12<00:00, 74.19it/s]
Accuracy:  0.9834: 100%|██████████| 10/10 [00:01<00:00,  9.48it/s]
Loss:  0.00589: 100%|██████████| 938/938 [00:12<00:00, 72.69it/s]
Acc

([tensor(0.0866, device='cuda:0', grad_fn=<NllLossBackward0>),
  tensor(0.0144, device='cuda:0', grad_fn=<NllLossBackward0>),
  tensor(0.3818, device='cuda:0', grad_fn=<NllLossBackward0>),
  tensor(0.0002, device='cuda:0', grad_fn=<NllLossBackward0>),
  tensor(0.0002, device='cuda:0', grad_fn=<NllLossBackward0>),
  tensor(0.4203, device='cuda:0', grad_fn=<NllLossBackward0>),
  tensor(0.2164, device='cuda:0', grad_fn=<NllLossBackward0>),
  tensor(0.0059, device='cuda:0', grad_fn=<NllLossBackward0>),
  tensor(0.0138, device='cuda:0', grad_fn=<NllLossBackward0>),
  tensor(0.0026, device='cuda:0', grad_fn=<NllLossBackward0>)],
 [0.0001387216441333294,
  0.00011632492169737816,
  0.00010801393166184425,
  0.00010309783965349198,
  0.00010692187622189521,
  0.0001015650101006031,
  8.860463909804821e-05,
  9.646954238414765e-05,
  9.636803269386291e-05,
  9.87303152680397e-05])