In [2]:
from EvolutionStrategy import ESModel

## 1. The `ESModel` Class

`ESModel` can be used with any `torch` model, as long as the model is a subclass of `nn.Module`. Here we define a simple MLP:

In [8]:
import torch
import torch.nn as nn
# Use the CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class MLP(nn.Module):
    """Small fully connected network: 10 inputs -> 20 hidden units (ReLU) -> 1 output."""

    def __init__(self):
        super().__init__()
        # Layers in application order; nn.Sequential registers them as 0, 1, 2.
        layers = [
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 1),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        out = self.fc(x)
        return out


To use `ESModel`, we pass the model class itself (here `MLP`, not an instance of it) directly to the constructor:

In [4]:
# Wrap the MLP class (the class itself, not an instance) in an ESModel.
es_model = ESModel(
    Model=MLP,
    param_std=0.01,
    Optimizer=torch.optim.Adam,
)

The `samples()` method returns an iterator that lets us loop over models whose parameters are drawn from a normal distribution:

In [11]:
# Draw two models from the ESModel and print each one's module structure.
for model in es_model.samples(sample_size=2):
    print(model)
    

MLP(
  (fc): Sequential(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=1, bias=True)
  )
)
MLP(
  (fc): Sequential(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=1, bias=True)
  )
)


After sampling, we need to compute a loss for each sampled model in order to estimate the gradient. This is done outside of the `ESModel` class. The `gradient_descent(loss)` method then takes a tensor of shape `[nb_samples]`, in which each entry is the loss of one sample. It estimates the gradient of the loss with respect to the parameters, $\nabla_\theta L$, and performs a gradient descent step:

In [10]:
# One placeholder loss per sampled model (the sample size used above was 2).
fake_loss = torch.randn(2).to(device)
es_model.gradient_descent(fake_loss)

We then repeat these two operations (sampling and gradient descent) to train the model. The `get_best_model()` method returns a model whose parameters are set to their estimated means; it can be used for validation:

In [None]:
# Fetch the model produced by get_best_model() so it can be evaluated.
best_model = es_model.get_best_model()

## 2. Example: Training a CNN on the MNIST Dataset using `ESModel`

### 2.1 Defining model

In [12]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Normalize
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

class CNN(nn.Module):
    """Two-convolution classifier that emits log-probabilities over 10 classes.

    For a 1x28x28 input, each 5x5 convolution followed by 2x2 max pooling
    shrinks the feature map 28 -> 24 -> 12 -> 8 -> 4, so the flattened
    feature vector has 20 * 4 * 4 = 320 entries.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 10 -> 20 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Fully connected head: 320 flattened features -> 50 -> 10 outputs.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-softmax scores (taken over dim 1) for the input ``x``."""
        # Stage 1: convolution, 2x2 max pooling, then ReLU.
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Stage 2: same pattern with the second convolution.
        stage2 = F.relu(F.max_pool2d(self.conv2(stage1), 2))
        # Collapse everything into rows of 320 features for the linear head.
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

### 2.2 Training

In [16]:
def train_loop(es_model, dataloader, loss_fn, nb_model_samples = 30):
    """Run one pass over ``dataloader``, performing one ES update per batch.

    For every batch, ``nb_model_samples`` models are drawn from
    ``es_model.samples(...)``, each is evaluated on the batch, and the
    stacked per-sample losses are handed to ``es_model.gradient_descent``.
    The mean of the per-sample losses is printed after each update.

    Args:
        es_model: Object exposing ``samples(n)`` (an iterable of callables
            mapping an input batch to predictions) and
            ``gradient_descent(losses)``.
        dataloader: Iterable of ``(x, y)`` tensor batches.
        loss_fn: Callable ``loss_fn(pred, y)``; assumed to return a scalar
            tensor so that the stacked losses have one entry per sample.
        nb_model_samples: Number of models drawn per batch.
    """
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        # Only the loss values are needed, so the forward passes run without
        # autograd tracking even when the caller has not disabled it. This
        # avoids building one graph per sampled model.
        with torch.no_grad():
            samples_loss = torch.stack([
                loss_fn(model(x), y)
                for model in es_model.samples(nb_model_samples)
            ])
        # gradient_descent is left in the caller's grad mode, as before.
        es_model.gradient_descent(samples_loss)

        print(f"loss: {samples_loss.mean():>7f}")
        
def test_loop(es_model, dataloader, loss_fn):
    model = es_model.get_best_model()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        test_loss += loss_fn(pred, y).item()
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

def train(epochs=9, nb_model_samples=1000, param_std=0.05):
    """Train a ``CNN`` on MNIST with ``ESModel`` and return the trained wrapper.

    MNIST is read from (and, if missing, downloaded to) the ``data``
    directory. After every training epoch the current best model is
    evaluated on the test split.

    Args:
        epochs: Number of passes over the training set. The default of 9
            matches the original hard-coded ``range(1, 10)``.
        nb_model_samples: Number of models sampled per training batch.
        param_std: Value forwarded to ``ESModel`` as ``param_std``.

    Returns:
        The ``ESModel`` instance after training, so that callers can keep
        using it (for example via ``get_best_model()``).
    """
    # Both splits share one preprocessing pipeline. The Normalize constants
    # are kept from the original; presumably the MNIST mean/std (TODO confirm).
    transform = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
    ])

    train_dataloader = DataLoader(
        datasets.MNIST('data', train=True, download=True, transform=transform),
        batch_size=256, shuffle=True)

    # Evaluation metrics are averages over the whole split, so shuffling the
    # test set is unnecessary.
    test_dataloader = DataLoader(
        datasets.MNIST('data', train=False, download=True, transform=transform),
        batch_size=1000, shuffle=False)

    # CNN.forward already returns log-probabilities; CrossEntropyLoss applies
    # log_softmax once more, which leaves log-probabilities unchanged, so the
    # resulting value is the usual negative log-likelihood.
    loss_fn = nn.CrossEntropyLoss()

    with torch.no_grad(): # ES doesn't need gradient tracking
        es_model = ESModel(Model=CNN, param_std=param_std, Optimizer=optim.Adam)
        for epoch in range(1, epochs + 1):
            print(f"Epoch {epoch}\n-------------------------------")
            # Train for one epoch, then validate the current best model.
            train_loop(es_model, train_dataloader, loss_fn, nb_model_samples=nb_model_samples)
            test_loop(es_model, test_dataloader, loss_fn)
    return es_model

# Keep the trained wrapper so later cells can evaluate or reuse it.
mnist_es_model = train()

Epoch 1
-------------------------------
loss: 2.344096
loss: 2.365517
loss: 2.346449
loss: 2.361961
loss: 2.386756
loss: 2.364064
loss: 2.371230
loss: 2.377149
loss: 2.401134
loss: 2.363524
loss: 2.374210
loss: 2.383887
loss: 2.389943
loss: 2.411423
loss: 2.409988
loss: 2.418119
loss: 2.437249
loss: 2.437269
loss: 2.420153
loss: 2.428493
loss: 2.448111
loss: 2.442569
loss: 2.423686
loss: 2.454573
loss: 2.448188
loss: 2.448636
loss: 2.450336
loss: 2.445827
loss: 2.434328
loss: 2.430256
loss: 2.446907
loss: 2.458442
loss: 2.464206
loss: 2.460428
loss: 2.455112
loss: 2.487108
loss: 2.464961
loss: 2.475064
loss: 2.490706
loss: 2.494599
loss: 2.486852
loss: 2.506617
loss: 2.485558
loss: 2.485729
loss: 2.488230
loss: 2.512548
loss: 2.467071
loss: 2.501066
loss: 2.503813
loss: 2.499095
loss: 2.506607
loss: 2.494600
loss: 2.527209
loss: 2.539393
loss: 2.513906
loss: 2.537012
loss: 2.520734
loss: 2.557217
loss: 2.505289
loss: 2.528536
loss: 2.550517
loss: 2.502917
loss: 2.493002
loss: 2.501419
