<a href="https://colab.research.google.com/github/Loki-33/Optimizer/blob/main/Optimizers_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import tqdm
import numpy as np

In [None]:
# Seed the Python, NumPy and PyTorch random number generators with the same
# value so that repeated runs of the notebook start from the same RNG state.
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)

In [None]:
# Normalization statistics fed to transforms.Normalize for the single MNIST
# channel (presumably computed over the training images -- confirm).
mean, std = 0.13066048920154572, 0.30810779333114624

In [None]:
# Training-time pipeline: small random rotation, random 28x28 crop after
# padding by 2 pixels, conversion to a tensor, then normalization with
# `mean` / `std`.
train_transforms = transforms.Compose([
  transforms.RandomRotation(degrees=5),
  transforms.RandomCrop(size=28, padding=2),
  transforms.ToTensor(),
  transforms.Normalize(mean=[mean], std=[std]),
])

# MNIST training split, downloaded into `.data` when not already present.
train_data = datasets.MNIST(
  root='.data',
  train=True,
  transform=train_transforms,
  download=True,
)

In [None]:
# Mini-batch loader over the training set; samples are reshuffled each epoch.
batch_size = 128
train_iterator = data.DataLoader(
  dataset=train_data,
  batch_size=batch_size,
  shuffle=True,
)

In [None]:
class MLP(nn.Module):
  """Fully connected network: two ReLU hidden layers and a linear output layer.

  Args:
    input_dim: number of features per sample after flattening.
    hid_dim: width of both hidden layers.
    output_dim: number of output units (raw scores, no softmax applied).
  """

  def __init__(self, input_dim, hid_dim, output_dim):
    super().__init__()
    self.layer1 = nn.Linear(input_dim, hid_dim)
    self.layer2 = nn.Linear(hid_dim, hid_dim)
    self.layer3 = nn.Linear(hid_dim, output_dim)

  def init_params(self):
    """Re-initialize in place: Kaiming-normal weights, zero biases."""
    for name, param in self.named_parameters():
      if 'weight' in name:
        nn.init.kaiming_normal_(param, nonlinearity='relu')
        continue
      if 'bias' in name:
        nn.init.constant_(param.data, 0.0)

  def forward(self, x):
    """Flatten every sample to a vector and return the output-layer scores."""
    hidden = x.view(x.shape[0], -1)
    for layer in (self.layer1, self.layer2):
      hidden = F.relu(layer(hidden))
    return self.layer3(hidden)

In [None]:
# Network sizes: 28x28 flattened input pixels, 256 hidden units, 10 outputs.
input_dim, hid_dim, output_dim = 28 * 28, 256, 10

model = MLP(input_dim=input_dim, hid_dim=hid_dim, output_dim=output_dim)

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Cross-entropy loss over the raw scores produced by the model.
criterion = nn.CrossEntropyLoss()

In [None]:
# Move the loss module and the model onto the selected device.
criterion = criterion.to(device)
model = model.to(device)

In [None]:
def train_epoch(iterator, model, optimizer, criterion, device):
  """Run one pass over `iterator`, updating `model` after every batch.

  Args:
    iterator: iterable yielding `(images, labels)` batches.
    model: callable mapping an image batch to predictions.
    optimizer: object exposing `zero_grad()` and `step()`.
    criterion: loss callable taking `(predictions, labels)`.
    device: device the batches are moved to before the forward pass.

  Returns:
    List with the scalar loss (Python float) of every batch, in order.
  """
  step_losses = []
  for batch_images, batch_labels in tqdm.tqdm(iterator):
    batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
    # Clear old gradients before computing the ones for this batch.
    optimizer.zero_grad()
    batch_loss = criterion(model(batch_images), batch_labels)
    batch_loss.backward()
    optimizer.step()
    step_losses.append(batch_loss.item())
  return step_losses

In [None]:
def train(train_iterator, model, optimizer, criterion, device, n_epochs=3):
  """Re-initialize `model`, then train it for `n_epochs` epochs.

  The parameters are reset through `model.init_params()` before the first
  epoch, so each call starts from a fresh initialization.

  Returns:
    Flat list of per-batch losses across all epochs, in update order.
  """
  model.init_params()
  all_losses = []
  for _ in range(n_epochs):
    all_losses += train_epoch(train_iterator, model, optimizer, criterion, device)
  return all_losses

In [None]:
def plot_loss(loss, title=None, ymin=0, ymax=None, figsize=(15,5)):
  """Draw a single loss curve against the update-step index.

  Args:
    loss: sequence of loss values, one per update step.
    title: axes title.
    ymin, ymax: y-axis limits passed to `set_ylim`.
    figsize: figure size passed to `plt.subplots`.
  """
  _, axes = plt.subplots(figsize=figsize)
  axes.plot(loss)
  axes.set_title(title)
  axes.set_xlabel('Update Steps')
  axes.set_ylabel('Loss')
  axes.set_ylim(bottom=ymin, top=ymax)
  axes.grid()

In [None]:
def plot_losses(losses, labels, title=None, ymin=0, ymax=None, figsize=(15,5)):
  """Overlay several loss curves on one set of axes, with a legend.

  Args:
    losses: iterable of loss sequences, one per curve.
    labels: legend labels, paired with `losses` by position.
    title: axes title.
    ymin, ymax: y-axis limits passed to `set_ylim`.
    figsize: figure size passed to `plt.subplots`.
  """
  _, axes = plt.subplots(figsize=figsize)
  for curve, curve_label in zip(losses, labels):
    axes.plot(curve, label=curve_label)
  axes.set_title(title)
  axes.set_xlabel('Update Steps')
  axes.set_ylabel('Loss')
  # `bottom` / `top` are the canonical names for the `ymin` / `ymax` aliases.
  axes.set_ylim(bottom=ymin, top=ymax)
  axes.grid()
  axes.legend(loc='upper right')


In [None]:
class SGD:
  """Plain stochastic gradient descent: `p <- p - lr * grad`.

  Args:
    model_params: iterable of tensors to optimize (e.g. `model.parameters()`).
      It is copied into a list so it can be traversed on every call.
    lr: learning rate applied to every parameter.
  """

  def __init__(self, model_params, lr=1e-3):
    self.model_params = list(model_params)
    self.lr = lr

  def zero_grad(self):
    """Discard stored gradients by setting every `p.grad` to `None`."""
    for p in self.model_params:
      p.grad = None

  @torch.no_grad()
  def step(self):
    """Update, in place, every parameter that currently has a gradient."""
    for p in self.model_params:
      if p.grad is None:
        # Frozen or unused parameter: nothing to apply, and multiplying by
        # None would raise a TypeError.
        continue
      p.sub_(self.lr * p.grad)

In [None]:
# Plain SGD over all model parameters, with the class-default learning rate.
optimizer = SGD(model_params=model.parameters())

In [None]:
# Train from a fresh initialization and keep the per-step losses.
sgd_loss = train(
  train_iterator=train_iterator,
  model=model,
  optimizer=optimizer,
  criterion=criterion,
  device=device,
)

In [None]:
# Loss curve for the plain SGD run.
plot_loss(sgd_loss, title='SGD with lr=1e-3')

In [None]:
class SGD_Momentum:
  """SGD with momentum: `v <- momentum * v + grad`, then `p <- p - lr * v`.

  Args:
    model_params: iterable of tensors to optimize; copied into a list.
    lr: learning rate.
    momentum: decay factor applied to the velocity buffer before each new
      gradient is added.
  """

  def __init__(self, model_params, lr=1e-3, momentum=0.9):
    self.model_params = list(model_params)
    self.lr = lr
    self.momentum = momentum
    # One velocity buffer per parameter, starting at zero.
    self.v = [torch.zeros_like(p) for p in self.model_params]

  def zero_grad(self):
    """Discard stored gradients by setting every `p.grad` to `None`."""
    for p in self.model_params:
      p.grad = None

  @torch.no_grad()
  def step(self):
    """Update, in place, every parameter that currently has a gradient."""
    for p, v in zip(self.model_params, self.v):
      if p.grad is None:
        # Frozen or unused parameter: leave both it and its velocity alone.
        continue
      v.mul_(self.momentum).add_(p.grad)
      p.sub_(self.lr * v)

In [None]:
# SGD with momentum over all model parameters, using the class defaults.
optimizer = SGD_Momentum(model_params=model.parameters())

In [None]:
# Train from a fresh initialization and keep the per-step losses.
SGD_momentum_loss = train(
  train_iterator=train_iterator,
  model=model,
  optimizer=optimizer,
  criterion=criterion,
  device=device,
)

In [None]:
# Loss curve for the SGD-with-momentum run.
plot_loss(SGD_momentum_loss, title="SGD MOMENTUM WITH LR=1e-3 AND MOMENTUM=0.9")

In [None]:
class AdaGrad:
  """AdaGrad: per-element step size scaled by accumulated squared gradients.

  Update: `acc <- acc + grad**2`, then
  `p <- p - lr / (sqrt(acc) + eps) * grad`.

  Args:
    model_params: iterable of tensors to optimize; copied into a list.
    init_acc_sqr_grad: starting value of every accumulator element.
    lr: learning rate.
    eps: constant added to the denominator to avoid division by zero.
  """

  def __init__(self, model_params, init_acc_sqr_grad=0, lr=1e-3, eps=1e-10):
    self.model_params = list(model_params)
    self.lr = lr
    # Running sum of squared gradients, one buffer per parameter.
    self.acc_seq_grads = [torch.full_like(p, init_acc_sqr_grad) for p in self.model_params]
    self.eps = eps

  def zero_grad(self):
    """Discard stored gradients by setting every `p.grad` to `None`."""
    for p in self.model_params:
      p.grad = None

  @torch.no_grad()
  def step(self):
    """Update, in place, every parameter that currently has a gradient."""
    for p, a in zip(self.model_params, self.acc_seq_grads):
      if p.grad is None:
        # Frozen or unused parameter: nothing to accumulate or apply.
        continue
      a.add_(p.grad * p.grad)
      std = a.sqrt().add(self.eps)
      p.sub_((self.lr / std) * p.grad)

In [None]:
# AdaGrad over all model parameters, using the class defaults.
optimizer = AdaGrad(model_params=model.parameters())


In [None]:
# Train from a fresh initialization and keep the per-step losses.
adagrad_loss = train(
  train_iterator=train_iterator,
  model=model,
  optimizer=optimizer,
  criterion=criterion,
  device=device,
)


In [None]:
# The optimizer above is built with the AdaGrad defaults (lr=1e-3), so the
# title reports that learning rate rather than 1e-2.
plot_loss(adagrad_loss, 'Adagrad with lr=1e-3, init_acc_sqr_grad=0, eps=1e-10')

In [None]:
class AdaDelta:
  """AdaDelta: step size from running averages of squared gradients and deltas.

  Per step, with `a` the average squared gradient and `b` the average
  squared delta:
    a     <- rho * a + (1 - rho) * grad**2
    delta <- sqrt(b + eps) / sqrt(a + eps) * grad
    p     <- p - lr * delta
    b     <- rho * b + (1 - rho) * delta**2

  Args:
    model_params: iterable of tensors to optimize; copied into a list.
    lr: multiplier applied to `delta` before it is subtracted.
    rho: decay rate of both running averages.
    eps: constant added under both square roots.
  """

  def __init__(self, model_params, lr=1.0, rho=0.9, eps=1e-9):
    self.model_params = list(model_params)
    self.lr = lr
    self.rho = rho
    self.avg_sqr_grads = [torch.zeros_like(p) for p in self.model_params]
    self.avg_sqr_deltas = [torch.zeros_like(p) for p in self.model_params]
    self.eps = eps

  def zero_grad(self):
    """Discard stored gradients by setting every `p.grad` to `None`."""
    for p in self.model_params:
      p.grad = None

  @torch.no_grad()
  def step(self):
    """Update, in place, every parameter that currently has a gradient."""
    for p, a, b in zip(self.model_params, self.avg_sqr_grads, self.avg_sqr_deltas):
      if p.grad is None:
        # Frozen or unused parameter: leave it and its averages untouched.
        continue
      a.mul_(self.rho).add_(p.grad * p.grad * (1 - self.rho))
      std = a.add(self.eps).sqrt()
      # `b` still holds the previous average here; it is refreshed only after
      # the parameter has been updated.
      delta = b.add(self.eps).sqrt().div(std).mul(p.grad)
      p.sub_(self.lr * delta)
      b.mul_(self.rho).add_(delta * delta * (1 - self.rho))

In [None]:
# AdaDelta over all model parameters, using the class defaults.
optimizer = AdaDelta(model_params=model.parameters())

In [None]:
# Train from a fresh initialization and keep the per-step losses.
adadelta_loss = train(
  train_iterator=train_iterator,
  model=model,
  optimizer=optimizer,
  criterion=criterion,
  device=device,
)

In [None]:
# The optimizer above is built with the AdaDelta defaults (eps=1e-9), so the
# title reports that epsilon rather than 1e-6.
plot_loss(adadelta_loss, 'Adadelta with lr=1.0, rho=0.9, eps=1e-9')

In [None]:
class RMSprop:
  """RMSprop: step size scaled by a running average of squared gradients.

  Update: `a <- alpha * a + (1 - alpha) * grad**2`, then
  `p <- p - lr / (sqrt(a) + eps) * grad`.

  Args:
    model_params: iterable of tensors to optimize (e.g. `model.parameters()`).
    lr: learning rate.
    alpha: decay rate of the running average (called rho in AdaDelta).
    eps: constant added to the denominator to avoid division by zero.
  """

  def __init__(self, model_params, lr=1e-2, alpha=0.99, eps=1e-8):
    # Materialize the iterable: `model.parameters()` is a one-shot generator,
    # and it must be traversed again on every zero_grad() / step() call.
    self.model_params = list(model_params)
    self.eps = eps
    self.lr = lr
    self.alpha = alpha  # rho in case of AdaDelta
    self.avg_sqr_grads = [torch.zeros_like(p) for p in self.model_params]

  def zero_grad(self):
    """Discard stored gradients by setting every `p.grad` to `None`."""
    for p in self.model_params:
      p.grad = None

  @torch.no_grad()
  def step(self):
    """Update, in place, every parameter that currently has a gradient."""
    for p, a in zip(self.model_params, self.avg_sqr_grads):
      if p.grad is None:
        # Frozen or unused parameter: leave it and its average untouched.
        continue
      a.mul_(self.alpha).add_(p.grad * p.grad * (1 - self.alpha))
      std = a.sqrt().add(self.eps)
      p.sub_((self.lr / std) * p.grad)

In [None]:
# RMSprop over all model parameters, using the class defaults.
optimizer = RMSprop(model_params=model.parameters())

In [None]:
# Train from a fresh initialization and keep the per-step losses.
rmsprop_loss = train(
  train_iterator=train_iterator,
  model=model,
  optimizer=optimizer,
  criterion=criterion,
  device=device,
)

In [None]:
# Loss curve for the RMSprop run.
plot_loss(rmsprop_loss, title='RMSprop with lr=1e-2, alpha=0.99, eps=1e-8')

In [None]:
class Adam:
  """Adam: bias-corrected running averages of the gradient and its square.

  Per step `t` (counted once per `step()` call, shared by all parameters):
    a <- beta1 * a + (1 - beta1) * grad
    b <- beta2 * b + (1 - beta2) * grad**2
    p <- p - lr * (a / (1 - beta1**t)) / (sqrt(b / (1 - beta2**t)) + eps)

  Args:
    model_params: iterable of tensors to optimize (e.g. `model.parameters()`).
    lr: learning rate.
    beta1: decay rate of the gradient average.
    beta2: decay rate of the squared-gradient average.
    eps: constant added to the denominator to avoid division by zero.
  """

  def __init__(self, model_params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Materialize the iterable: `model.parameters()` is a one-shot generator,
    # and it must be traversed again on every zero_grad() / step() call.
    self.model_params = list(model_params)
    self.lr = lr
    self.beta1 = beta1
    self.beta2 = beta2
    self.eps = eps
    self.n_steps = 0  # number of step() calls performed so far
    self.avg_grads = [torch.zeros_like(p) for p in self.model_params]
    self.avg_sqr_grads = [torch.zeros_like(p) for p in self.model_params]

  def zero_grad(self):
    """Discard stored gradients by setting every `p.grad` to `None`."""
    for p in self.model_params:
      p.grad = None

  @torch.no_grad()
  def step(self):
    """Update, in place, every parameter that currently has a gradient."""
    # The step counter advances once per optimizer step, not once per
    # parameter; otherwise the bias correction differs between parameters.
    self.n_steps += 1
    bias_correction1 = 1 - self.beta1 ** self.n_steps
    bias_correction2 = 1 - self.beta2 ** self.n_steps
    for p, a, b in zip(self.model_params, self.avg_grads, self.avg_sqr_grads):
      if p.grad is None:
        # Frozen or unused parameter: leave it and its averages untouched.
        continue
      a.mul_(self.beta1).add_(p.grad * (1 - self.beta1))
      b.mul_(self.beta2).add_(p.grad * p.grad * (1 - self.beta2))
      avg_corrected_grad = a.div(bias_correction1)
      avg_corrected_sqr_grad = b.div(bias_correction2)
      std = avg_corrected_sqr_grad.sqrt().add(self.eps)
      p.sub_(self.lr * avg_corrected_grad / std)

In [None]:
# Adam over all model parameters, using the class defaults.
optimizer = Adam(model_params=model.parameters())

In [None]:
# Train from a fresh initialization and keep the per-step losses.
adam_loss = train(
  train_iterator=train_iterator,
  model=model,
  optimizer=optimizer,
  criterion=criterion,
  device=device,
)

In [None]:
# Loss curve for the Adam run.
plot_loss(adam_loss, title='Adam with lr=1e-3, betas=(0.9, 0.999), eps=1e-8')