Simple policy gradient

In [350]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np


class MLP(nn.Module):
    def __init__(
        self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = mx.maximum(l(x), 0.0)
        return self.layers[-1](x)

In [351]:
def loss_fn(model, X, y):
    return mx.mean(nn.losses.cross_entropy(model(X), y))

In [352]:
def eval_fn(model, X, y):
    return mx.mean(mx.argmax(model(X), axis=1) == y)

In [353]:
def batch_iterate(batch_size, X, y):
    perm = mx.array(np.random.permutation(y.size))
    for s in range(0, y.size, batch_size):
        ids = perm[s : s + batch_size]
        yield X[ids], y[ids]

In [354]:
num_layers = 2
hidden_dim = 32
num_classes = 10
batch_size = 256
num_epochs = 20
learning_rate = 1e-3

# Load the data
import mnist
train_images, train_labels, test_images, test_labels = map(
    mx.array, mnist.mnist()
)

In [355]:
# Load the model
model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
mx.eval(model.parameters())

# Get a function which gives the loss and gradient of the
# loss with respect to the model's trainable parameters
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Instantiate the optimizer
optimizer = optim.Adam(learning_rate=learning_rate)

for e in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        loss, grads = loss_and_grad_fn(model, X, y)

        # Update the optimizer state and model parameters
        # in a single call
        optimizer.update(model, grads)

        # Force a graph evaluation
        mx.eval(model.parameters(), optimizer.state)

    accuracy = eval_fn(model, test_images, test_labels)
    print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")

Epoch 0: Test accuracy 0.929
Epoch 1: Test accuracy 0.940
Epoch 2: Test accuracy 0.943
Epoch 3: Test accuracy 0.949
Epoch 4: Test accuracy 0.953
Epoch 5: Test accuracy 0.956
Epoch 6: Test accuracy 0.957
Epoch 7: Test accuracy 0.959
Epoch 8: Test accuracy 0.959
Epoch 9: Test accuracy 0.959
Epoch 10: Test accuracy 0.962
Epoch 11: Test accuracy 0.963
Epoch 12: Test accuracy 0.962
Epoch 13: Test accuracy 0.963
Epoch 14: Test accuracy 0.962
Epoch 15: Test accuracy 0.963
Epoch 16: Test accuracy 0.965
Epoch 17: Test accuracy 0.964
Epoch 18: Test accuracy 0.965
Epoch 19: Test accuracy 0.966


MLX is under control: it works!

Now, it's time to figure out gymnasium and the environment.

# MLX

In [548]:
X = mx.arange(9).reshape(3, 3)


X[mx.arange(3), mx.arange(3)], X

(array([0, 4, 8], dtype=int32),
 array([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]], dtype=int32))

In [578]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np

model = nn.Sequential(
    nn.Linear(4, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 2)
)

optimizer = optim.Adam(learning_rate=1e-2)

def get_policy(obs):
    logits = model(obs)
    probs = mx.softmax(logits, axis=0)
    return probs, model

def get_action(obs):
    probs, _ = get_policy(obs)
    return np.random.choice(list(range(len(probs))), 1, p=np.array(probs))[0]

def compute_loss(obs, act, weights):
    # calc log probs
    probs, _ = get_policy(obs)
    logp = mx.log(probs[mx.arange(probs.shape[0]), act])
    return -(logp * weights).mean()

def reward_to_go(rews):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0)
    return rtgs

loss_and_grads_fn = nn.value_and_grad(model, compute_loss)

In [581]:
def train_one_epoch():
    epoch_len = 1000
    
    batch_obs = []
    batch_act = []
    
    batch_ret = []
    batch_weights = []
    batch_lens = []
    
    obs, _ = env.reset()
    done = False
    
    ep_rews = []
    
    while True:
        batch_obs.append(obs.copy())
        act = get_action(mx.array(obs))
    
        obs, rew, term, _, _ = env.step(act)
        
        # Append to history
        batch_act.append(act)
        ep_rews.append(rew)
    
        if term:
            # Sum for trajectories
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_ret.append(ep_ret)
            batch_lens.append(len(ep_rews))
            # Create weights
            batch_weights += list(reward_to_go(ep_rews))
            
            if len(batch_obs) > epoch_len:
                break
    
            (obs, _), term, ep_rews = env.reset(), False, []

    loss, grads = loss_and_grads_fn(mx.array(np.array(batch_obs)),
                                    mx.array(np.array(batch_act)), 
                                    mx.array(np.array(batch_weights)))
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
                
    return batch_ret, loss, batch_lens

In [582]:
for e in range(50):
    ret, loss, batch_lens = train_one_epoch()
    print(f"epoch {e}: loss = {loss.item()}; mean batch len = {np.mean(batch_lens)}")

epoch 0: loss = 83.45491790771484; mean batch len = 19.705882352941178
epoch 1: loss = 110.93253326416016; mean batch len = 24.853658536585368
epoch 2: loss = 151.95423889160156; mean batch len = 34.266666666666666
epoch 3: loss = 169.48532104492188; mean batch len = 44.47826086956522
epoch 4: loss = 243.25856018066406; mean batch len = 62.11764705882353
epoch 5: loss = 408.7662048339844; mean batch len = 103.0
epoch 6: loss = 268.99761962890625; mean batch len = 71.85714285714286
epoch 7: loss = 296.6059875488281; mean batch len = 80.84615384615384
epoch 8: loss = 279.02294921875; mean batch len = 68.86666666666666
epoch 9: loss = 220.41543579101562; mean batch len = 59.76470588235294
epoch 10: loss = 248.4477996826172; mean batch len = 68.93333333333334
epoch 11: loss = 254.44683837890625; mean batch len = 65.0
epoch 12: loss = 223.99301147460938; mean batch len = 60.705882352941174
epoch 13: loss = 260.1882629394531; mean batch len = 66.25
epoch 14: loss = 203.06796264648438; mean b

Please do not use it yet...

I think a lot of utility functions existing in PyTorch are not yet implemented in MLX, which makes it quite hard for the development.

The lazy evaluation and compilation were also quite confusing, which made the whole experience quite unpleasent.

# Torch

In [598]:
import torch 
import torch.nn as nn
import torch.optim as optim 
import torch.nn.functional as F
from torch.distributions import Categorical

model = nn.Sequential(
    nn.Linear(4, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 2)
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def get_policy(obs):
    logits = model(obs)
    return logits

# make action selection function (outputs int actions, sampled from policy)
def get_action(obs):
    logits = get_policy(obs)
    probs = F.softmax(logits, dim=0)
    probs = np.array(probs.detach()) 
    
    return np.random.choice(range(2), p=probs)

# make loss function whose gradient, for the right data, is policy gradient
def compute_loss(obs, act, weights):
    probs = F.softmax(get_policy(obs), dim=-1)
    logp = torch.log(probs[range(probs.shape[0]), act])
        
    return -(logp * weights).mean()


def reward_to_go(rews):
    n = len(rews)
    rtgs = np.zeros_like(rews)
    for i in reversed(range(n)):
        rtgs[i] = rews[i] + (rtgs[i+1] if i+1 < n else 0)
    return rtgs


def train_one_epoch():
    epoch_len = 300
    
    batch_obs = []
    batch_act = []
    
    batch_ret = []
    batch_weights = []
    batch_lens = []
    
    obs, _ = env.reset()
    done = False
    
    ep_rews = []
    
    while True:
        batch_obs.append(obs.copy())
        act = get_action(torch.tensor(obs))
    
        obs, rew, term, _, _ = env.step(act)
        
        # Append to history
        batch_act.append(act)
        ep_rews.append(rew)
    
        if term:
            # Sum for trajectories
            ep_ret, ep_len = sum(ep_rews), len(ep_rews)
            batch_ret.append(ep_ret)
            batch_lens.append(len(ep_rews))
            # Create weights
            batch_weights += list(reward_to_go(ep_rews))
            
            if len(batch_obs) > epoch_len:
                break
    
            (obs, _), term, ep_rews = env.reset(), False, []

    optimizer.zero_grad()
    loss = compute_loss(torch.tensor(np.array(batch_obs)),
                        torch.tensor(np.array(batch_act)), 
                        torch.tensor(np.array(batch_weights)))
    loss.backward()
    optimizer.step()
                
    return batch_ret, loss, batch_lens

render = False
epochs = 100
# training loop
for i in range(epochs):
    # batch_loss, batch_rets, batch_lens = train_one_epoch()
    ret, loss, batch_lens = train_one_epoch()
    if (i + 1) % 10 == 0:
        print(f"epoch {i}: loss = {loss.item()}; mean batch len = {np.mean(batch_lens)}")
    # print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
            # (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))


epoch 9: loss = 10.730097514449382; mean batch len = 21.857142857142858
epoch 19: loss = 22.730474894486584; mean batch len = 51.5
epoch 29: loss = 17.588762214376587; mean batch len = 37.875
epoch 39: loss = 15.48880399016439; mean batch len = 39.666666666666664
epoch 49: loss = 18.680067237668904; mean batch len = 52.166666666666664
epoch 59: loss = 22.067168519978928; mean batch len = 62.2
epoch 69: loss = 32.871211015335895; mean batch len = 86.25
epoch 79: loss = 17.16089697937649; mean batch len = 46.0
epoch 89: loss = 36.00716688901583; mean batch len = 103.25
epoch 99: loss = 84.48073412971387; mean batch len = 302.0
