In [1]:
import torch
from dataclasses import dataclass, field
from typing import Dict, Tuple, Any, Optional
from torch.optim.optimizer import Optimizer, required
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Net(nn.Module):
    """Small convolutional classifier that emits log-probabilities over 10 outputs.

    Two 3x3 convolutions (1 -> 32 -> 64 channels, stride 1, no padding) are
    followed by 2x2 max-pooling, dropout, and two linear layers (9216 -> 128 -> 10).
    ``fc1`` expects 9216 = 64 * 12 * 12 flattened features, which is what a
    single-channel 28x28 input produces after the two convolutions and the pooling.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Regularisation applied after pooling and after the hidden linear layer.
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return per-row log-softmax scores (``dim=1``) for the input batch ``x``."""
        # NOTE(review): only conv1 is followed by a ReLU; conv2 feeds the pooling
        # layer directly — confirm this is intentional.
        features = self.conv2(F.relu(self.conv1(x)))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Keep the batch dimension, flatten everything else for the linear layers.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [3]:
@dataclass
class DoEConfig:
    """Field-less root dataclass that the optimizer config dataclasses inherit from."""

In [4]:
@dataclass
class SGDConfig(DoEConfig):
    """Hyperparameter bundle for the SGD optimizer; every field has a default.

    The values are only stored here. In the code in view they end up as the
    optimizer's defaults via the instance ``__dict__``.
    """

    # Learning rate.
    lr: float = 0.001
    # Momentum setting; the default is the int 0, which is what the optimizer repr shows.
    momentum: float = 0
    # Nesterov flag, off by default.
    nesterov: bool = False



In [5]:
class BaseOptimizer(Optimizer):
    """Optimizer base class whose hyperparameters are read from a config object.

    The config's instance attributes become the optimizer's ``defaults`` (and so
    the defaults of every parameter group).

    Args:
        params: Iterable of parameters or parameter-group dicts, forwarded
            unchanged to ``torch.optim.Optimizer``.
        defaults: Extra default hyperparameters. They are merged with the
            config's fields; on a key clash the config value wins. ``None`` or
            an empty dict means "no extras".
        config: Object whose instance attributes (``vars(config)``) are the
            hyperparameters. The annotation is a string so this class can be
            defined before ``DoEConfig`` is in scope.

    Raises:
        ValueError: If ``config`` carries an ``lr``, ``momentum``, ``eps`` or
            ``betas`` attribute whose value is out of range. Attributes the
            config does not define are not checked.
    """

    def __init__(
        self,
        params,
        defaults: Dict[str, Any],
        config: "DoEConfig",
    ):
        self._validate_config(config)

        # Build a fresh dict instead of handing ``config.__dict__`` to the parent:
        # the parent stores the mapping as ``self.defaults``, so passing the
        # config's own ``__dict__`` would make the optimizer and the config share
        # (and be able to corrupt) the same storage.
        merged: Dict[str, Any] = dict(defaults) if defaults else {}
        merged.update(vars(config))

        super().__init__(params, merged)

    @staticmethod
    def _validate_config(config) -> None:
        """Range-check the well-known hyperparameters that ``config`` defines."""
        lr = getattr(config, "lr", None)
        if lr is not None and not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")

        momentum = getattr(config, "momentum", None)
        if momentum is not None and not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid momentum: {momentum}")

        eps = getattr(config, "eps", None)
        if eps is not None and not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")

        betas = getattr(config, "betas", None)
        if betas is not None:
            for index in (0, 1):
                if not 0.0 <= betas[index] < 1.0:
                    raise ValueError(
                        f"Invalid beta parameter at index {index}: {betas[index]}"
                    )

In [6]:
class SGD(BaseOptimizer):
    """Config-driven SGD optimizer shell.

    Only construction is defined here: ``params``, ``config`` and the optional
    ``defaults`` mapping are handed to ``BaseOptimizer``.
    NOTE(review): no ``step`` is defined in this class — confirm where the
    update rule is meant to live.
    """

    def __init__(self, params, config: DoEConfig, defaults: Optional[Dict[str, Any]] = None):
        # Substitute a fresh dict per call when the caller passes nothing.
        if defaults is None:
            defaults = {}
        super().__init__(params, defaults=defaults, config=config)

In [7]:
# Pick the GPU when CUDA reports one, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

# Hyperparameters first, then the model and the optimizer that consumes both.
config = SGDConfig(lr=0.01)
model = Net()
model = model.to(device)
optimizer = SGD(model.parameters(), config)




  return torch._C._cuda_getDeviceCount() > 0


In [8]:
# Show the optimizer repr to check that the config fields (lr, momentum, nesterov)
# appear in parameter group 0.
optimizer

SGD (
Parameter Group 0
    lr: 0.01
    momentum: 0
    nesterov: False
)

In [304]:
# Inspect the attribute dict that a config instance exposes.
# NOTE(review): this rebinds `config`, replacing the lr=0.01 instance built earlier.
config = SGDConfig(lr=0)
vars(config)

{'lr': 0, 'momentum': 0, 'nesterov': False}

In [1]:
import nadir

In [3]:
# Look up the SGD class exposed at the top level of the `nadir` package.
nadir.SGD

nadir.sgd.SGD