## Getting started

In [90]:
from __future__ import annotations

import random
import math
import time
import itertools
import functools
from typing import Tuple, List, Mapping, Optional, Union, NamedTuple, Callable
# Many default parameters are included in jnumpy and are optional.
# I only resort to using `Optional` in the type annotations where the
# context does not make this clear.  

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

## jnumpy

In [91]:
# Jnumpy: Jacob's numpy library for machine learning
# Copyright (c) 2021 Jacob F. Valdez. Released under the MIT license.


V = np.array # V is for Value type
Vs = Tuple[V]
Vss = Union[V,Vs]


class ExecutionMode:
    EAGER=1
    STATIC=2  # STATIC execution mode not supported
    
EXECUTION_MODE = ExecutionMode.EAGER


class T:
    """Tensor"""
    
    def __init__(self, val: Optional[V] = None):
        self.val = val
        
        if val is None:
            raise 'STATIC execution mode not supported'

    def __getitem__(self, key):
        return eval("Index")(self, key)

    def __setitem__(self, key, value):
        raise NotImplementedError('slice assign not yet supported')
    
    def __add__(self, other):
        return eval("Add")(self, other)
    
    def __neg__(self):
        return eval('Neg')(self)
    
    def __sub__(self, other):
        return eval("Sub")(self, other)
    
    def __mul__(self, other):
        return eval("Mul")(self, other)
    
    def __pow__(self, other):
        return eval("Pow")(self, other)
    
    def __matmul__(self, other):
        return eval("MatMul")(self, other)

    @property
    def shape(self):
        return self.val.shape

    @property
    def ndim(self):
        return self.val.ndim

    @property
    def size(self):
        return self.val.size

    @property
    def dtype(self):
        return self.val.dtype

    @property
    def T(self, axes: Tuple[int] = None):
        return eval("Transpose")(self, axes=axes)

    def __repr__(self):
        return f"Tensor({self.val})"

    def __str__(self):
        return f"Tensor({self.val})"

    def __eq__(self, other):
        return self.val == other.val

    def __hash__(self):
        return hash(self.val)

    def __iter__(self):
        return iter(self.val)

    def __len__(self):
        return len(self.val)

    def __getstate__(self):
        return self.val.__getstate__()

    def __setstate__(self, state):
        self.val = state

    def __array__(self):
        return self.val.__array__()


Ts = Tuple[T]
Tss = Union[T,Ts]


class Var(T):
    """Variable Tensor"""
    def __init__(self, val: Optional[V] = None, trainable: bool = True):
        
        self.trainable = trainable
        super().__init__(val=val)


class Op(T):
    """Operation-backed Tensor"""
    
    def __init__(self, *inputs: T):
        """Make sure to set any variables you might need in `forward` 
        before initializing when the graph is in eager execution mode
        """
        
        self.input_ts = inputs
        
        if EXECUTION_MODE == ExecutionMode.EAGER:
            val = self.forward(tuple(i.val for i in inputs))
        else:
            val = None
        
        super().__init__(val=val)
        
    def forward(self, inputs: Vs) -> V:
        raise NotImplementedError('subclasses should implement this method')
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        raise NotImplementedError('subclasses should implement this method')


class Transpose(Op):
    
    def __init__(self, t: T, axes: Tuple[int] = None):
        
        self.forward_kwargs = dict()
        self.reverse_kwargs = dict()
        
        if axes is not None:
            self.forward_kwargs['axes'] = axes
            self.reverse_kwargs['axes'] = tuple(reversed(axes))
            
        super().__init__(t)
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Y = X.transpose(**self.forward_kwargs)
        
        return Y
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        dY = top_grad

        dX = dY.transpose(**self.reverse_kwargs)
        
        return (dX,)


class Reshape(Op):
    
    def __init__(self, t: T, shape: Tuple[int]):
        
        self.reshape_shape = shape
            
        super().__init__(t)
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Y = X.reshape(self.reshape_shape)
        
        return Y
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        dY = top_grad
        
        dX = dY.reshape(tuple(reversed(self.reshape_shape)))
        
        return (dX,)
 

class Concat(Op):
    
    def __init__(self, ts: List[T], axis: int = 0):
        """Concatenates input tensors along an axis

        Args:
            t (T): [description]
            axis (int, optional): Axis to concatenate along. Defaults to 0.
        """
        
        self.axis = axis
        self.orig_axis_lens = [t.shape[axis] for t in ts]

        super().__init__(*ts)
    
    def forward(self, inputs: Vs) -> V:
        Xs = inputs
        
        Y = np.concatenate(Xs, axis=self.axis)
        
        return Y
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        dY = top_grad
        
        dXs = np.split(dY, self.orig_axis_dims, axis=self.axis)[0]
        
        return dXs


class Index(Op):
    
    def __init__(self, t: T, indices):
        """Slices a tensor along all axes.

        Args:
            t (T): The tensor to slice
            indices (Tuple[slice]):  The partial or full indices to slice on `t`.
                Can be an index, single slice, tuple of slices, or Ellipsis.
                `None` is not allowed.
        """
        if not isinstance(indices, tuple):
            indices = (indices,)

        self.indices = indices
        
        super().__init__(t)
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Y = X[self.indices]
        
        return Y
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        X = inputs[0]
        dY = top_grad
        
        dX = np.zeros(X.shape)
        dX[self.indices] = dY
        
        return (dX,)


class ReduceSum(Op):
    
    def __init__(self, t: T, axis: int):
        self.axis = axis
            
        super().__init__(t)
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Y = X.sum(axis=self.axis)
        
        return Y
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        X = inputs[0]
        dY = top_grad
        
        dX = np.repeat(
            np.expand_dims(dY, axis=self.axis),
            X.shape[self.axis],
            axis=self.axis
        )
        
        return (dX,)


class ReduceMax(Op):
    """Differentiable max operator"""
    
    def __init__(self, t: T, axis: int):
        self.axis = axis
            
        super().__init__(t)
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Y = X.max(axis=self.axis)
        
        return Y
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        X = inputs[0]
        dY = top_grad
        
        print(X.shape, dY.shape)

        dX = np.zeros_like(X)
        dX[np.argmax(X, axis=self.axis)] = dY

        print(dX.shape)
        
        return (dX,)

class ReduceMin(Op):
    """Differentiable min operator"""
    
    def __init__(self, t: T, axis: int):
        self.axis = axis
            
        super().__init__(t)
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Y = X.min(axis=self.axis)
        
        return Y
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        X = inputs[0]
        dY = top_grad
        
        dX = np.zeros_like(X)
        dX[np.argmin(X, axis=self.axis)] = dY
        
        return (dX,)


class NaN2Num(Op):
    
    def __init__(self, t: T, posinf: float = 1e3, neginf: float = -1e3):
        self.posinf = posinf
        self.neginf = neginf
            
        super().__init__(t)

    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Z = np.nan_to_num(X, posinf=self.posinf, neginf=self.neginf)
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        dZ = top_grad
        
        dX = np.nan_to_num(dZ, posinf=10., neginf=-10.)
        
        return (dX,)


class Linear(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Z = X
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        dZ = top_grad
        
        dX = dZ
        
        return (dX,)


class StopGrad(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Z = X
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        dZ = top_grad
        
        dX = np.zeros_like(dZ)
        
        return (dX,)


class Neg(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Z = -X
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        dZ = top_grad
        
        dX = -dZ
        
        return (dX,)


class Add(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        Y = inputs[1]
        
        Z = X + Y
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        dZ = top_grad
        
        dX = dZ
        dY = dZ
        
        return dX, dY


class Sub(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        Y = inputs[1]
        
        Z = X - Y
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        dZ = top_grad
        
        dX = dZ
        dY = -dZ
        
        return dX, dY


class Mul(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        Y = inputs[1]
        
        Z = X * Y
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        X = inputs[0]
        Y = inputs[1]
        dZ = top_grad
        
        dX = Y * dZ
        dY = X * dZ
        
        return dX, dY


class MatMul(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        Y = inputs[1]
        
        Z = X @ Y
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        X = inputs[0]  # [A,B]
        Y = inputs[1]  # [B,C]
        dZ = top_grad  # [A,C]
        
        dX = dZ @ Y.transpose()
        dY = X.transpose() @ dZ
        
        return dX, dY


class Exp(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Z = np.exp(X)
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        Z = output
        dZ = top_grad
        
        dX = Z * dZ
        
        return (dX,)


class Sigm(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Z = 1 / (1 + np.exp(-X))
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        Z = output
        dZ = top_grad
        
        dX = Z * (1 - Z) * dZ
        
        return (dX,)


class Tanh(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Z = np.tanh(X)
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        Z = output
        dZ = top_grad
        
        dX = ((1 - Z)**2) * dZ
        
        return (dX,)


class Relu(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Z = (X > 0) * X
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        X = inputs[0]
        dZ = top_grad
        
        dX = (X > 0) * dZ
        
        return (dX,)


class Threshold(Op):
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        
        Z = (X >= 0)
        
        return Z
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        dZ = top_grad
        
        dX = dZ
        
        return (dX,)


class Pow(Op):
    
    def __init__(self, x: T, power: int):
        
        self.power = power
        
        super().__init__(x)
    
    def forward(self, inputs: Vs) -> V:
        X = inputs[0]
        p = self.power
        
        Y = X ** p
        
        return Y
    
    def reverse_grad(self, inputs: Vs, output: V, top_grad: V) -> Vs:
        X = inputs[0]
        p = self.power
        dY = top_grad
        
        dX = p * X ** (p-1) * dY
        dX = np.nan_to_num(dX, posinf=1e3, neginf=-1e3)
        
        return (dX,)


class Optimizer:
    
    def minimize(self, t: T):
        pass


class SGD(Optimizer):
    
    def __init__(self, lr: float = 0.001):
        
        self.lr = lr
        self.debug = False
        
        super().__init__()
    
    def minimize(self, t: T):
        
        if EXECUTION_MODE == ExecutionMode.STATIC:
            raise 'STATIC execution mode not enabled'
        
        self.bprop(t_out=t, output_grad=-np.ones_like(t.val))
        
    def bprop(self, t_out: T, output_grad: V):
        
        output_grad = np.nan_to_num(output_grad, posinf=10., neginf=-10.)
        
        assert isinstance(t_out, (Var, Op))
        
        if self.debug:
            print(f'bp {t_out} output_grad:')
            print(output_grad)
        
        """
        This approach does not efficiently handle weights that are consumed by multiple nodes
        It would be better to treat backpropagation from a spreading-network-delta perspective
        than assume everything is a tree (That's also how I should do STATIC execution refresh)
        This should still work though, but it's just going to set the same weight multiple times
        for every downstream consumer.

        Actually, the whole thesis of minibatch gradient descent is that we can approximate a global
        gradient by updates on local subsets of data, so it might be sufficient to leave the code
        as is.

        However this approach will still take unnecessary walks down the tree in depth-first fashion.
        Innefficient: Yes; Works: Yes.
        """
        
        # iteratively called
        if isinstance(t_out, Var):
            if t_out.trainable:
                #print('output_grad', output_grad.shape)
                if self.debug:                    
                    print('t_out.val (old)', t_out.shape)
                    print(t_out.val)
                
                # yucky duct tape to handle batch size differences
                if t_out.shape[0] == 1 and output_grad.shape[0] > 1:
                    output_grad = np.sum(output_grad, axis=0)[None, ...]

                t_out.val = t_out.val + (self.lr * output_grad)
                if self.debug:
                    print('t_out.val (new)', t_out.shape)
                    print(t_out.val)
            
        elif isinstance(t_out, Op):
            input_grads = t_out.reverse_grad(
                inputs=tuple(t.val for t in t_out.input_ts),
                output=t_out.val, top_grad=output_grad)
            
            for input_t, input_grad in zip(t_out.input_ts, input_grads):
                self.bprop(t_out=input_t, output_grad=input_grad)
        
        return

## Neural Network

In [92]:
class Layer:

    def __init__(self) -> None:
        self._built = False
        self._loss = Var(np.zeros(()), trainable=False)

    @property
    def loss(self) -> T:
        return self._loss

    @property
    def trainable_variables(self) -> List[T]:
        pass

    def build(self, input_shape):
        pass

    def forward(self, X_T: T) -> T:
        pass

    def __call__(self, X_T: T) -> T:
        if not self._built:
            self.build(X_T.shape)
            self._built = True

        # reset the regularization loss
        self._loss = Var(0, trainable=False)

        return self.forward(X_T)

In [93]:
class Dense(Layer):

    def __init__(self, 
        units: int, 
        activation: Op = None, 
        use_bias: bool = True,
        activity_L2: float = None, 
        weight_L2: float = None,  
        bias_L2: float = None
    ):
        super(Dense, self).__init__()

        if activation is None:
            activation = Linear
            
        self.units = units
        self.activation = activation
        self.use_bias = use_bias

        self.activity_L2 = Var(activity_L2, trainable=False) if activity_L2 is not None else None
        self.weight_L2 = Var(weight_L2, trainable=False) if weight_L2 is not None else None
        self.bias_L2 = Var(bias_L2, trainable=False) if bias_L2 is not None else None

    @property
    def trainable_variables(self) -> List[T]:
        return [self.W_T] + ([self.B_T] if self.use_bias else [])

    def build(self, input_shape):
        W = np.random.uniform(low=-0.05, high=0.05, size=(input_shape[-1], self.units))
        self.W_T = Var(val=W, trainable=True)
        if self.use_bias:
            B = np.random.uniform(low=-0.05, high=0.05, size=(1, self.units))
            self.B_T = Var(val=B, trainable=True)

    def forward(self, X_T: T) -> T:

        # compute presynaptic input
        Z_T = X_T @ self.W_T

        # maybe add bias
        if self.use_bias:
            Z_T = Z_T + self.B_T

        # apply activation
        Y_T = self.activation(Z_T)
        
        # track regularization losses
        if self.activity_L2 is not None:
            self._loss += self.activity_L2 * ReduceSum(ReduceSum(Y_T**2, 1), 0)
        if self.weight_L2 is not None:
            self._loss += self.weight_L2 * ReduceSum(ReduceSum(self.W_T**2, 1), 0)
        if self.use_bias and self.bias_L2 is not None:
            self._loss += self.bias_L2 * ReduceSum(ReduceSum(self.B_T**2, 1), 0)

        return Y_T


layer = Dense(10, Relu, 0.1, 0.1, 0.1)
X_T = Var(np.random.uniform(0, 1, size=(3, 5)), trainable=False)
Y_T = layer(X_T)
X_T.val, Y_T.val, layer.loss, layer.trainable_variables

(array([[0.77489121, 0.77122773, 0.57652188, 0.4192426 , 0.96431482],
        [0.31910157, 0.79432582, 0.41088662, 0.77494032, 0.14994998],
        [0.61753794, 0.98515798, 0.88242563, 0.90604408, 0.35057589]]),
 array([[-0.        , -0.        , -0.        ,  0.00262531, -0.        ,
         -0.        , -0.        ,  0.06937857, -0.        , -0.        ],
        [-0.        ,  0.02030562, -0.        ,  0.01614764, -0.        ,
         -0.        , -0.        ,  0.03847847, -0.        , -0.        ],
        [-0.        ,  0.01156763, -0.        ,  0.02133865, -0.        ,
         -0.        , -0.        ,  0.05267905, -0.        , -0.        ]]),
 Tensor(0.004767513934823068),
 [Tensor([[-0.04720139 -0.02048058  0.00204957  0.00587311 -0.02101643  0.04068757
    -0.02559744  0.03319881 -0.04795947  0.03194449]
   [-0.0102456  -0.01168267 -0.04813369  0.02581874 -0.02411819 -0.04077488
     0.02897201 -0.00573817 -0.03563188 -0.03216296]
   [-0.0282607   0.01285564 -0.01693905 -0.

In [94]:
class Conv2D(Layer):
    """Standard 2D Conv layer.
    I.E. convolves over Tensors shaped [B, H, W, D]
    to produce [B, H-2*kernel_size, W-2*kernel_size, filters]
    """

    def __init__(self,
        filters: int, 
        kernel_size: Union[int, Tuple[int, int]] = 3,
        strides: Union[int, Tuple[int, int]] = 1,
        padding: str = 'valid',  # 'valid' or 'same'
        activation: Op = None,
        use_bias: bool = False,
        activity_L2: float = None, 
        weight_L2: float = None,  
        bias_L2: float = None,
    ):
        super(Conv2D, self).__init__()

        if activation is None:
            activation = Linear
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        assert kernel_size[0] % 2 == 1 and kernel_size[1] % 2 == 1, 'kernel_size must be odd'
        if isinstance(strides, int):
            strides = (strides, strides)
        assert strides[0] > 0 and strides[1] > 0, 'strides must be positive'
        padding = padding.lower()
        assert padding in ('valid', 'same'), 'padding must be valid or same'

        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.activation = activation
        self.use_bias = use_bias

        self.activity_L2 = Var(activity_L2, trainable=False) if activity_L2 is not None else None
        self.weight_L2 = Var(weight_L2, trainable=False) if weight_L2 is not None else None
        self.bias_L2 = Var(bias_L2, trainable=False) if bias_L2 is not None else None

    @property
    def trainable_variables(self) -> List[T]:
        return [self.W_T] + ([self.B_T] if self.use_bias else [])

    def build(self, input_shape):
        W = np.random.uniform(low=-0.05, high=0.05, size=(
            self.kernel_size[0]*self.kernel_size[1]*input_shape[-1], 
            self.filters))
        self.W_T = Var(val=W, trainable=True)
        
        if self.use_bias:
            B = np.random.uniform(low=-0.05, high=0.05, size=(1, self.filters))
            self.B_T = Var(val=B, trainable=True)

    def forward(self, X_T: T) -> T:

        # maybe pad input
        if self.padding == 'same':

            # various padding sizes, strides, and offsets
            # 0   1   2   3   4
            #     0 1 2 3 4
            #         0 1 2 3 4 5 6 7 8 9
            #         0 1 2 3 4
            #         0   1   2   3   4

            pad_top = self.strides[0]*(self.kernel_size[0]-1)//2
            pad_bottom = pad_top
            pad_left = self.strides[1]*(self.kernel_size[1]-1)//2
            pad_right = pad_left
            B, H_orig, W_orig, C = X_T.shape

            # pad height
            X_T = Concat([
                Var(np.zeros((B, pad_top, W_orig, C)), trainable=False),
                X_T,
                Var(np.zeros((B, pad_bottom, W_orig, C)), trainable=False),
            ], axis=1)

            # pad width
            X_T = Concat([
                Var(np.zeros((B, H_orig+pad_top+pad_bottom, pad_left, C)), trainable=False),
                X_T,
                Var(np.zeros((B, H_orig+pad_top+pad_bottom, pad_right, C)), trainable=False),
            ], axis=2)

        elif self.padding == 'valid':
            pass

        # stack the input tensor along the channel axis
        # but shifted by all possible kernel shifts
        stack = []
        for shift in itertools.product(range(0, self.strides[0]*self.kernel_size[0], self.strides[0]),
                                       range(0, self.strides[1]*self.kernel_size[1], self.strides[1])):
            stack.append(X_T[
                :,
                shift[0]:,
                shift[1]:,
                :
            ])

        # clip stack to greatest common shape
        min_shape = np.min(np.array([s.shape for s in stack]), axis=0)
        stack = [s[:, :min_shape[1], :min_shape[2], :] for s in stack]

        # stack the shifted tensors along the channel axis
        stacked = Concat(stack, axis=3)  # [B, H-k_h//2, W-k_w//2, C*k_h*k_w]

        # convolve over the stacked tensors
        Z_T = stacked @ self.W_T  
    
        # maybe add bias
        if self.use_bias:
            Z_T = Z_T + self.B_T

        # apply the activation function
        Y_T = self.activation(Z_T)

        # track regularization losses
        if self.activity_L2 is not None:
            self._loss += self.activity_L2 * ReduceSum(ReduceSum(Y_T**2, 1), 0)
        if self.weight_L2 is not None:
            self._loss += self.weight_L2 * ReduceSum(ReduceSum(self.W_T**2, 1), 0)
        if self.use_bias and self.bias_L2 is not None:
            self._loss += self.bias_L2 * ReduceSum(ReduceSum(self.B_T**2, 1), 0)

        return Y_T


layer = Conv2D(filters=64, kernel_size=5, strides=5, padding='same', weight_L2=0.1)
X_T = Var(np.random.uniform(0, 1, size=(2, 7, 70, 5)), trainable=False)
Y_T = layer(X_T)
X_T.shape, Y_T.shape, layer.loss, layer.trainable_variables

((2, 7, 70, 5),
 (2, 7, 70, 64),
 Tensor(0.6650852939497751),
 [Tensor([[-0.04690391  0.03219781  0.03133849 ... -0.01250094 -0.02040534
     0.02855088]
   [-0.04776165  0.02981423 -0.04664671 ...  0.01683485 -0.02068035
     0.0417308 ]
   [-0.04366009  0.00428372  0.01177435 ... -0.04615517 -0.02045305
    -0.019969  ]
   ...
   [ 0.01918657 -0.03537107  0.00954096 ...  0.00848717  0.01232981
    -0.00966575]
   [-0.04994427  0.00382309 -0.01930179 ... -0.04597156 -0.00578884
    -0.00101414]
   [-0.03620059 -0.00242729  0.02865895 ...  0.04867044 -0.0362522
     0.04736328]])])

In [95]:
class Flatten(Layer):
    """Flattens all non-batch dimensions into a single axis"""

    def __init__(self):
        super(Flatten, self).__init__()

    @property
    def trainable_variables(self) -> List[T]:
        return []

    def forward(self, X_T: T) -> T:
        flat_dims = functools.reduce(lambda x, y: x*y, X_T.shape[1:])
        Y_T = Reshape(X_T, (X_T.shape[0], flat_dims))
        return Y_T


layer = Flatten()
X_T = Var(np.random.uniform(0, 1, size=(2, 3, 4, 5)), trainable=False)
Y_T = layer(X_T)
X_T.shape, Y_T.shape, layer.loss, layer.trainable_variables

((2, 3, 4, 5), (2, 60), Tensor(0), [])

In [96]:
class Sequential(Layer):

    def __init__(self, layers):
        self.layers = layers
        super(Sequential, self).__init__()

    def forward(self, X_T: T) -> T:
        for layer in self.layers:
            X_T = layer(X_T)
        return X_T

    @property
    def loss(self) -> T:
        return sum(layer.loss for layer in self.layers)

    @property
    def trainable_variables(self) -> List[T]:
        trainable_vars = []
        for layer in self.layers:
            trainable_vars += layer.trainable_variables
        return trainable_vars


net = Sequential([
    Dense(10, Relu),
    Dense(128, Relu),
    Dense(512, Sigm),
    Dense(1, lambda x: x)
])
net(X_T)

img_T = Var(np.random.uniform(0, 1, size=(1, 28, 28, 1)), trainable=False)
net = Sequential([
    Conv2D(32, 3, 2, activation=Relu),
    Conv2D(64, 3, 2, activation=Relu),
    Flatten(),
    Dense(512, Sigm),
    Dense(1, lambda x: x)
])
net(img_T)

Tensor([[0.3228232]])

## Reinforcement Learning

Standards:
- `Step`: uses batch dimension (except `done` which is always a bool)
- `Agent` uses batch dimension
- `Environment` doesn't use batch dimension

This means you will have to use `Step.batch` and `Step.unbatch` in your training/running loop

In [97]:
class Step(NamedTuple):
    """Single step."""
    
    obs: np.ndarray
    next_obs: np.ndarray
    action: np.ndarray
    reward: np.ndarray
    done: bool
    info: any

    @staticmethod
    def unbatch(step: Step) -> List[Step]:
        return [
            Step(
                obs=step.obs[i:i+1],
                next_obs=step.next_obs[i:i+1],
                action=step.action[i:i+1],
                reward=step.reward[i:i+1],
                done=step.done,
                info=step.info[i:i+1],
            )
            for i in range(step.obs.shape[0])
        ]

    @staticmethod
    def batch(steps: List[Step]) -> Step:
        return Step(
            obs=np.concatenate([step.obs for step in steps], axis=0),
            next_obs=np.concatenate([step.next_obs for step in steps], axis=0),
            action=np.concatenate([step.action for step in steps], axis=0),
            reward=np.concatenate([step.reward for step in steps], axis=0),
            done=any(step.done for step in steps),
            info=[step.info for step in steps])

    @staticmethod
    def from_no_batch_axis(step: NoBatchStep) -> Step:
        return Step(
            obs=step.obs[None, ...],
            next_obs=step.next_obs[None, ...],
            action=step.action[None, ...],
            reward=step.reward[None, ...],
            done=step.done,
            info=[step.info]
        )

# dimensional type hinting
BatchStep = Step
NoBatchStep = Step

Traj = List[BatchStep]

TODO
- replace `list` with List and same for dict/mapping
- use one-hot encodings for all actions and rewire the policies to use them

In [98]:
class Environment:
    """RL environment."""

    def __init__(self):
        pass

    def reset(self) -> Step:
        """Resets the environment

        Returns:
            Step: Initial step. The `next_obs` attribute should be set 
                with an initial observation. `done` should be False. 
                `obs` and `action` should not be used.
        """
        pass

    def step(self, action: np.ndarray) -> Step:
        """Computes one logical step in the environment

        Args:
            action (np.ndarray): The action to take.

        Returns:
            Step: Step resulting from taking `action`. The `next_obs` attribute
                should be set with the observation resulting from taking the `action`
                in the current environment state. `obs` should not be used. If the 
                environment is turn-based, then the reward should correspond to the 
                agent that just acted (not the next agent in line to act).
        """
        pass

    def render(self):
        pass


class NoBatchEnv(Environment):
    """Environment with no batch dimension."""

    def reset(self) -> NoBatchStep:
        pass

    def step(self, action: np.ndarray) -> NoBatchStep:
        pass


class BatchEnv(Environment):
    """Environment with batch dimension."""

    def reset(self) -> BatchStep:
        pass

    def step(self, action: np.ndarray) -> BatchStep:
        pass


In [99]:
class Batch1Env(BatchEnv):
    """Adds a batch axis to all outgoing Steps and strips it off incoming Steps."""

    def __init__(self, env: NoBatchEnv):
        self.env = env

    def reset(self) -> BatchStep:
        return Step.from_no_batch_axis(self.env.reset())

    def step(self, action: np.array) -> BatchStep:
        return Step.from_no_batch_axis(self.env.step(action[0]))

    def render(self):
        self.env.render()

In [100]:
class ParallelEnv(BatchEnv):
    """Keeps reseting the same environment in a batch.
    
    Declares itself done when a total of `batch_size` individual environment
    dones are experienced.

    NOTE: Individual environments should be `NoBatchEnv`'s.
    """

    def __init__(self, batch_size: int, env_init_fn: Callable[[], NoBatchEnv]):
        self.batch_size = batch_size
        self.env_init_fn = env_init_fn

    def reset(self) -> Step:
        self.dones = 0
        self.envs = [self.env_init_fn() for _ in range(self.batch_size)]
        steps = [env.reset() for env in self.envs]
        steps = [Step.from_no_batch_axis(step) for step in steps]
        return Step.batch(steps)

    def step(self, action: np.array) -> Step:
        steps = []
        for i, (env, single_action) in enumerate(zip(self.envs, action)):
            step = env.step(single_action)
            if step.done:
                self.envs[i] = self.env_init_fn()
                new_step = env.reset()
                step.next_obs = new_step.next_obs
                self.dones += 1
            steps.append(Step.from_no_batch_axis(step))

        batched_step = Step.batch(steps)
        batched_step.done = self.dones >= self.batch_size
        return batched_step

    def render(self):
        for env in self.envs:
            env.render()

In [101]:
class ReplayBuffer:
    """Replay buffer.

    Expects the following hyperparameters:
        - `epoch`: The current epoch.
        - `batch_size`: Number of trajectories to return at each call.
        - `min_sample_len`: Minimum length of trajectories to sample.
        - `max_sample_len`: Maximum length of trajectories to sample.
        - `num_steps_replay_coef`: Sampling coefficient based on trajectory length.
        - `success_replay_coef`: Sampling coefficient based on trajectory success.
        - `age_replay_coef`: Sampling coefficient based on trajectory age.
    """
  
    def __init__(self, hparams: dict):
        self.hparams = hparams
        self.trajs = dict()

    @property
    def flat_traj(self):
        return [step for traj in self.trajs for step in traj]

    def add(self, traj: Traj, epoch: int):
        """Add a new trajectory to the buffer.

        Args:
            traj (Traj): the trajectory to add.
            epoch (int): the epoch that the trajectory was experience in.
        """
        if epoch not in self.trajs:
            self.trajs[epoch] = []
        self.trajs[epoch] += traj

    def sample(self) -> Traj:
        """Samples a batched trajectory from the buffer stochastically based on:
            - the number of steps in the trajectory (num_steps_replay_coef)
            - how well the agent did in the trajectory (success_replay_coef)
            - how long ago the trajectory was experienced (age_replay_coef)

        Returns:
            Traj: a trajectory of batched steps experienced.
        """
        
        weights = {
            epoch: self.hparams['num_steps_replay_coef'] * len(traj) +
                   self.hparams['success_replay_coef'] * sum(step.reward for step in traj) +
                   self.hparams['age_replay_coef'] * (epoch - self.hparams['epoch'])
            for epoch, traj in self.trajs.items()
        }
        epochs = np.random.choice(list(weights.keys()), size=(), replace=True, p=list(weights.values()))
        trajs = [self.trajs[epoch] for epoch in epochs]
        batched_traj = [BatchStep.batch(steps) for steps in zip(*trajs)]
        return batched_traj

In [102]:
class Agent:

    def __init__(self, policy: Callable):
        self.policy = policy

    def forward(self, step: BatchStep) -> np.ndarray:
        """Generates an action for a given observation using `self.policy`. 
        Override if you want to give your policy more information such as
        recurrent state or previous reward.

        Args:
            step (Step): last step output by the environment. This means the agent
                should feed `step.next_obs`, not `step.obs` to its policy. If the 
                environment is multi-agent, then the `reward` attribute has already
                updated to reflect the reward for this agent by the driver.

        Returns:
            np.ndarray: The action to take
        """
        return self.policy(step.next_obs)

    def reward(self, traj: Traj) -> float:
        """Evaluates the cumulative reward for your agent as the sum of 
        individual rewards experienced. 
        
        If your agent uses intrinsic rewards, be sure to add them in here.
        Do not introduce Q-values or predicted rewards here.

        Args:
            traj (Traj): timestep trajectory.

        Returns:
            float: Cumulative (sum) reward over the entire sequence.
        """
        return sum(step.reward for step in traj)

    def train(self, traj: Traj):
        """Train your agent on a sequence of Timestep's.

        Args:
            traj (Traj): timestep trajectory.
        """
        raise NotImplemented('Method `train` must be implemented by subclass')

In [103]:
class ParallelDriver:
    """Drives batched turn-based `BatchedEnv` environments with multiple agents.
    Also supports single-agent environments as a special case.
    """

    def __init__(self):
        pass

    def drive(self, agents: Mapping[str, Agent], env: Environment) -> Mapping[str, Traj]:
        """Drives a batched environment with multiple agents.
        
        Args:
            agents (Mapping[str, Agent]): A dictionary of agents to drive.
            env (Environment): The environment to drive.

        Returns:
            Mapping[str, Traj]: A dictionary of trajectories for each agent.
                Each trajectory is completely disengaged from the other agent's.
                (i.e.: the obs, next_obs, action, reward, done attributes are
                individual to each agent for each trajectory.)
        """

        names_it = itertools.cycle(agents.keys())
        trajs = {agent_name: [] for agent_name in agents}
        prev_rewards = {agent_name: 0. for agent_name in agents}

        step = env.reset()
        while not step.done:
            agent_name = next(names_it)
            
            # `Agent.forward` only looks at `step.next_obs` and `step.reward`
            # but I'm assigning defaults just to be safe.
            action = agents[agent_name].forward(Step(
                obs=step.obs,  # what the previous agent saw before acting
                next_obs=step.next_obs,  # what the current agent sees before acting
                reward=prev_rewards[agent_name],  # the reward this agent experienced following its last action
                done=step.done,  # whether the environment was done after the previous agent acted
                info=step.info  # any extra information the environment might have output
            )) 

            prev_step = step
            step = env.step(action)  # `Environment.step` produces a Step with all fields except `step.obs` set
            step.obs = prev_step.next_obs  # the current agent's observation is the previous agent's next observation
            prev_rewards[agent_name] = step.reward  # the reward for the action the current agent just took
            trajs[agent_name].append(step)  # Step completely corresponding to this agent (obs before action, obs after action, action, reward, done, info)
        
        return trajs

In [104]:
class ParallelTrainer:
    """Trains `BatchedEnv` environments and mutliple agents 
    (with N=1 single-agent supported as a special case).
    
    Uses following hyperparameters:
    - `epoch`: the current epoch. Reads and writes to this variable.
    - `epochs`: the number of epochs to train for.
    - `min_steps_per_epoch`: the minimum number of steps to train for each epoch.
    """

    def __init__(self, hparams: dict, callbacks: List[Callable]):
        self.hparams = hparams
        self.callbacks = callbacks
        
    def train(self, 
        agents: Mapping[str, Agent], 
        env: Environment,
        test_env: Environment = None,
        buffers: Mapping[str, ReplayBuffer] = None,
        collect_driver: ParallelDriver = None,
        test_driver: ParallelDriver = None,
        histories: Mapping[str, Mapping[int, Mapping[str, any]]] = None,
        ) -> Mapping[int, Mapping[str, any]]:

        agent_names = list(agents.keys())
        
        # initialize defaults
        if test_env is None:
            test_env = env
        if buffers is None:
            buffers = dict()
        if collect_driver is None:
            collect_driver = ParallelDriver()
        if test_driver is None:
            test_driver = collect_driver
        if histories is None:
            histories = dict()  # {agent_name: {epoch: {...data}}}

        # build uninitialized agent-specific objects
        for agent_name in agent_names:
            if agent_name not in buffers:
                buffers[agent_name] = ReplayBuffer()
            if agent_name not in histories:
                histories[agent_name] = dict()

        # run training loop
        for epoch in range(self.hparams['epoch'], self.hparams['epochs']):
            self.hparams['epoch'] = epoch

            # collect trajectories
            steps = 0
            while steps < self.hparams['steps_per_epoch']:
                collect_trajs = collect_driver.drive(agents, env)
                steps += min(len(traj) for _, traj in collect_trajs.items())
                for agent_name in agent_names:
                    buffers[agent_name].add(collect_trajs[agent_name])

            # train
            train_trajs = {agent_name: buffers[agent_name].sample() for agent_name in agent_names}
            for agent_name in agent_names:
                agents[agent_name].train(train_trajs[agent_name])
                
            # test
            test_trajs = test_driver.drive(agents, env)

            # record history and run callbacks
            for agent_name in agent_names:
                histories[agent_name][epoch] = {
                    'epoch': epoch,
                    'agent': agents[agent_name],
                    'all_agents': agents,
                    'env': env,
                    'test_env': test_env,
                    'collect_traj': collect_trajs[agent_name],
                    'train_traj': train_trajs[agent_name],
                    'test_traj': test_trajs[agent_name],
                    'buffer': buffers[agent_name],
                }
                for callback in self.callbacks:
                    callback(histories[agent_name][epoch])

        return histories

In [105]:
class PrintCallback:

    def __init__(self, hparams: dict, print_hparam_keys: List[str] = None, print_data_keys: List[str] = None):
        if print_hparam_keys is None:
            print_hparam_keys = ['epoch']
        if print_data_keys is None:
            print_data_keys = []
        
        self.hparams = hparams
        self.print_hparam_keys = print_hparam_keys
        self.print_data_keys = print_data_keys

    def __call__(self, data: Mapping[str, any]):
        for key in self.print_hparam_keys:
            print(f'{key}: {self.hparams[key]}', end='\t')
        for key in self.print_data_keys:
            print(f'{key}: {data[key]}', end='\t')

In [106]:
class QEvalCallback:

    def __init__(self, 
        eval_on_collect: bool = True, 
        eval_on_train: bool = False, 
        eval_on_test: bool = False):

        self.eval_on_collect = eval_on_collect
        self.eval_on_train = eval_on_train
        self.eval_on_test = eval_on_test

    def __call__(self, data: Mapping[str, any]):
        agent = data['agent']
        if not hasattr(agent, 'q_eval'):
            return

        if self.eval_on_collect:
            traj = data['collect_traj']
            q_val = agent.q_eval(traj)
            data['q_collect'] = q_val

        if self.eval_on_train:
            traj = data['train_traj']
            q_val = agent.q_eval(traj)
            data['q_train'] = q_val

        if self.eval_on_test:
            traj = data['test_traj']
            q_val = agent.q_eval(traj)
            data['q_test'] = q_val

TODO
- make the reward optionally an advantage computation over last round
- also make a recurrent DQN agent (estimate q function of a sequence of states)
- make a simple greedy connect4 agent
- make the preprocessor perform a columnwise mean pool before flattening
- train the preprocessor on an auxillary objective to estimate the max connected for each length for self and for oponent
- add padding='SAME'|'VALID' to conv2d

### Agent implementations

In [113]:
class RandomAgent(Agent):
    """Takes a random action on each timestep."""

    def __init__(self, num_actions: int):
        super(RandomAgent, self).__init__(policy=self._policy)
        self.num_actions = num_actions

    def _policy(self, obs: np.ndarray) -> np.ndarray:
        choices = np.random.randint(0, self.num_actions, (obs.shape[0],))
        onehots = np.eye(self.num_actions)[choices]
        return onehots

    def train(self, traj: Traj):
        """Train your agent on a sequence of Timestep's.

        Args:
            traj (Traj): timestep trajectory.
        """
        pass

agent = RandomAgent(num_actions=5)
obs = Var(np.array([[1, 2], [3, 4]]))
step = Step(obs=None, next_obs=obs, action=None, reward=None, done=None, info=None)
agent.forward(step)

array([[0., 0., 0., 0., 1.],
       [0., 1., 0., 0., 0.]])

In [None]:
class RealDQN(Agent):
    """'Classic' Deep Q-learning agent.
    Implements the approach in https://arxiv.org/pdf/1312.5602.pdf.
    """

    def __init__(self, num_actions: int, encoder: Layer, hparams: dict):
        ####### TODO TODO TODO #######
        # make encoder perform columnwise max pooling
        # make the qfunction column-specific 
        # (no X-row connections outside conv layers)
        # output max q value over all rows
        # implement corresponding changes in categorical agent
        # Actually the categorical agent just a per-column MLP
        # with no max Q pooling.
        ####### TODO TODO TODO #######

        self.num_actions = num_actions
        self.encoder = encoder
        self.head = Sequential([
            Dense(512, Sigm), 
            Dense(1, lambda x: x)
        ]) # [B, L+|A|] -> [B, 1]
        self.hparams = hparams

        super(RealDQN, self).__init__(policy=self._policy)

    def _policy(self, obs: np.ndarray) -> np.ndarray:

        # Maybe take greedy step
        epsilon = self.hparams['epsilon_start'] * self.hparams['epsilon_decay']**self.hparams['epoch']
        epsilon = min(epsilon, self.hparams['min_epsilon'])
        if random.random() < epsilon:
            return np.random.rand(self.num_actions)
        
        # Otherwise take the action with the highest Q-value
        actions = np.arange(self.num_actions)  # [self.num_actions]
        q_vals = np.zeros((self.num_actions,))  # [self.num_actions]
        for i, action in enumerate(actions):
            
            # prepare inputs
            obs_T = Var(obs[None, ...])  # [1, H, W, 2]
            action_T = Var(action[None, None], trainable=False)  # [1, 1]

            # run the network
            enc_T = self.encoder(obs_T)  # [1, d_enc]
            cat_T = Concat([enc_T, action_T], axis=1)  # [1, d_enc+1]
            q_T = self.head(cat_T)  # [1, 1]

            # store q-value
            q_vals[i] = q_T.val[0,0]  # []
        
        # select the action with the highest Q-value
        action = actions[q_vals.argmax()]
        return action

    def train(self, traj: Traj):
        """Train your agent on a sequence of Timestep's.

        Args:
            traj (Traj): batched timestep trajectory.
        """

        optimizer = self.hparams['optimizer']
        discount_T = Var(self.hparams['discount'], trainable=False)  # []
        for step in traj:

            obs_T = Var(step.obs, trainable=False)  # [B, H, W, 2]
            action_T = Var(step.action[None], trainable=False)  # [B, 1]
            action_next_T = Var(self.policy(step.next_obs), trainable=False)  # [B, 1]
            r_T = Var(step.reward, trainable=False)  # [B]

            # compute previous Q value using the actual (not necesarily optimal) action selected
            enc_T = self.encoder(obs_T)  # [B, d_enc]
            cat_T = Concat([enc_T, action_T], axis=1)  # [B, d_enc+1]
            Qnow_T = self.head(cat_T)[:, 0]  # [B]
            reg_loss_now_T = self.encoder.loss + self.head.loss  # []

            # compute the maximum possible next step Q-value
            enc_T = self.encoder(obs_T)  # [B, d_enc]
            cat_T = Concat([enc_T, action_next_T], axis=1)  # [B, d_enc+1]
            Qnext_T = self.head(cat_T)[:, 0]  # [B]
            reg_loss_next_T = self.encoder.loss + self.head.loss  # []

            # train the current policy to estimate new Q-value 
            # i.e.: approx Qnew_T <= (1-lr)*Qnow_T + lr*(r_T+discount_T*Qnext_T)  # [B, 1]
            # using gradient descent (let lr=1 in the above; small updates are handled in the SGD step)
            loss_T = ReduceSum(((r_T+discount_T*StopGrad(Qnext_T)) - Qnow_T)**2,
                axis=0) + reg_loss_now_T + reg_loss_next_T  # []
            optimizer.minimize(loss_T)

    def q_eval(self, obs: np.ndarray, action: np.ndarray) -> np.ndarray:
        """Evaluate the Q-value of a given state-action pair.

        Args:
            obs (np.ndarray): observation.
            action (np.ndarray): action.

        Returns:
            q_val (np.ndarray): Q-value of the given state-action pair.
        """

        # prepare inputs
        obs_T = Var(obs)  # [B, H, W, 2]
        action_T = Var(action)  # [B, 1]
        enc_T = self.encoder(obs_T)  # [B, d_enc]
        cat_T = Concat([enc_T, action_T], axis=1)  # [B, d_enc+1]
        Qnow_T = self.head(cat_T)[:, 0]  # [B]

        return Qnow_T.val


encoder = Sequential([
    Conv2D(32, 3, 2, 'same', Relu),
    Conv2D(64, 3, 2, 'same', Relu),
    Flatten(),
])  # [B, H, W, 1] -> [B, L]

agent = RealDQN(num_actions=HPARAMS['board_size'], encoder=encoder, hparams=HPARAMS)
agent

NameError: name 'HPARAMS' is not defined

In [None]:
class CategoricalDQN(Agent):
    """Categorical deep Q-network agent.
    I never read the paper for this architecture, so my implementation
    might be different from the origonal researchers.
    """

    def __init__(self, num_actions: int, encoder: Layer, hparams: dict):

        self.num_actions = num_actions
        self.encoder = encoder  # [B, H, W, C] -> [B, d_enc]
        self.head = Sequential([
            Dense(2*hparams['board_size']*num_actions, Tanh),
            Dense(num_actions, Linear)
        ]) # [B, d_enc] -> [B, num_actions]
        self.hparams = hparams

        super(CategoricalDQN, self).__init__(policy=self._policy)

    def _policy(self, obs: np.ndarray) -> np.ndarray:

        # Maybe take greedy step
        epsilon = self.hparams['epsilon_start'] * self.hparams['epsilon_decay']**self.hparams['epoch']
        epsilon = min(epsilon, self.hparams['min_epsilon'])
        if random.random() < epsilon:
            return np.random.rand(self.num_actions)
        
        # Otherwise take the action with the highest Q-value
        # compute q-values for all actions
        obs_T = Var(obs, trainable=False)  # [B, H, W, C]
        enc_T = self.encoder(obs_T)  # [B, d_enc]
        qvals_T = self.head(enc_T)  # [B, A]
        return np.argmax(qvals_T.val, axis=1)  # [B]

    
    def train(self, traj: Traj):
        """Train your agent on a sequence of Timestep's.

        Args:
            traj (Traj): batched timestep trajectory.
        """

        optimizer = self.hparams['optimizer']
        discount_T = Var(self.hparams['discount'], trainable=False)  # []
        for step in traj:

            obs_T = Var(step.obs, trainable=False)  # [B, H, W, 2]
            obs_next_T = Var(step.next_obs, trainable=False)  # [B, H, W, 2]
            r_T = Var(step.reward, trainable=False)  # [B]

            # compute previous Q value using the actual (not necesarily optimal) action selected
            enc_T = self.encoder(obs_T)  # [B, d_enc]
            qvals_T = self.head(enc_T)  # [B, A]
            Q_now_T = qvals_T[step.action]  # [B]
            reg_loss_now_T = self.encoder.loss + self.head.loss  # []

            # compute the maximum possible next step Q-value
            enc_T = self.encoder(obs_next_T)  # [B, d_enc]
            qvals_T = self.head(enc_T)  # [B, A]
            Q_next_T = np.max(qvals_T.val, axis=1)  # [B]
            reg_loss_next_T = self.encoder.loss + self.head.loss  # []

            # train the current policy to estimate new Q-value 
            # i.e.: approx Qnew_T <= (1-lr)*Qnow_T + lr*(r_T+discount_T*Qnext_T)  # [B, 1]
            # using gradient descent (let lr=1 in the above; small updates are handled in the SGD step)
            # but only update the targets that were actually selected for action at `step_now`.
            loss_T = ReduceSum(((r_T+discount_T*StopGrad(Q_next_T)) - Q_now_T)**2,
                axis=0) + reg_loss_now_T + reg_loss_next_T  # []
            optimizer.minimize(loss_T)


    def q_eval(self, obs: np.ndarray, action: np.ndarray) -> np.ndarray:
        """Evaluate the Q-value of a given state-action pair.

        Args:
            obs (np.ndarray): observation.
            action (np.ndarray): action.

        Returns:
            q_val (np.ndarray): Q-value of the given state-action pair.
        """

        # prepare inputs
        obs_T = Var(obs)  # [B, H, W, 2]
        enc_T = self.encoder(obs_T)  # [B, d_enc]
        qvals_T = self.head(enc_T)  # [B, A]

        return qvals_T.val[action]


agent = CategoricalDQN(num_actions=HPARAMS['board_size'], encoder=encoder, hparams=HPARAMS)
agent

<__main__.CategoricalDQN at 0x7fe1ec8b5cd0>

## Connect 4

In [None]:
class HardwiredConnect4Agent(Agent):

    def __init__(self, board_size: int, hparams: dict):
        self.board_size = board_size
        self.hparams = hparams
        super(HardwiredConnect4Agent, self).__init__(policy=self._policy)

    def _policy(self, obs: np.ndarray) -> np.ndarray:
        B = obs.shape[0]
        action = np.zeros(B, dtype=np.int32)
        for b in range(B):
            o = obs[b]
            ## TODO: make a greedy agent
            action[b] = random.randint(0, self.board_size-1)
        return action

    def train(self, traj: Traj):
        pass

In [None]:
class Board:
    """Drafted by copilot with minor human edits"""

    def __init__(self, size=7, win_length=4):
        self.size = size
        self.win_length = win_length
        self.board = np.zeros((size, size))
        self.turn = 1
        self.winner = 0

    def __str__(self):
        return f'{self.board}\nTurn: {self.turn}\nWinner: {self.winner}'

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        return self.board == other.board

    def __hash__(self):
        return hash(self.board.tostring())

    def is_full(self):
        return np.count_nonzero(self.board) == self.size**2

    def is_empty(self, col):
        return self.board[0, col] == 0

    def is_valid_move(self, col):
        return 0 <= col < self.size and self.is_empty(col)

    def make_move(self, col):
        if self.is_valid_move(col):
            highest_row = np.where(self.board[:, col] == 0)[0][-1]
            self.board[highest_row, col] = self.turn
            self.turn *= -1

    def undo_move(self, col):
        if self.is_valid_move(col) and self.board[0, col] != 0:
            highest_row = np.where(self.board[:, col] == 0)[0][-2]
            self.board[highest_row, col] = 0
            self.turn *= -1

    def check_win(self) -> int:
        for turn in [-1, 1]:
            if self.num_connected(self.win_length, turn) > 0:
                self.winner = turn
                return True
        return self.winner

    def num_connected(self, length, turn):
        num_connected = 0
        # Check horizontal
        for row in range(self.size):
            for col in range(self.size-length+1):
                if np.all(self.board[row, col:col+length] == turn):
                    num_connected += 1
        # Check vertical
        for col in range(self.size):
            for row in range(self.size-length+1):
                if np.all(self.board[row:row+length, col] == turn):
                    num_connected += 1
        # Check diagonal
        for row in range(self.size-length+1):
            for col in range(self.size-length+1):
                if all(self.board[row+i, col+i] == turn for i in range(length)):
                    num_connected += 1
        # Check anti-diagonal
        for row in range(self.size-length+1):
            for col in range(length-1, self.size):
                if all(self.board[row+i, col-i] == turn for i in range(length)):
                    num_connected += 1
        return num_connected

board = Board()
board.make_move(0)
board.make_move(0)
board.make_move(0)
board.make_move(0)
board

[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [-1.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.]
 [-1.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.]]
Turn: 1
Winner: 0

In [114]:
class BoardEnv:

    def __init__(self, board_size=7, win_length=4, reward_mode: str = 'sparse'):
        """RL environment for Connect 4. 

        Args:
            board_size (int, optional): The size of the board. Defaults to 7.
            win_length (int, optional): The minimum connected length to win. Defaults to 4.
            reward_mode (str, optional): One of 'sparse', 'dense_stateless', 'dense_advantage'. 
                - For 'sparse', the reward is 1 if the player has attained a connect `win_length`,
                    and is 0 otherwise.
                - For 'dense_stateless', the reward increases linearly with the number of N-in-a-row's
                    for all values of N from 0 to board_size weighted logarithmically by N.
                - For 'dense_advantage', the reward is determined by the difference between the
                    previous and current dense reward for each player individually.
                `reward_mode` defaults to 'sparse'.
        """
        self.board_size = board_size
        self.win_length = win_length
        self.reward_mode = reward_mode

        self.reset()

    def reset(self) -> NoBatchStep:
        self.board = Board(self.board_size, self.win_length)

        if self.reward_mode == 'dense_advantage':
            self.prev_dense_reward = [0., 0.]

        return NoBatchStep(
            obs=np.zeros_like(obs),
            next_obs=self._make_obs(), 
            reward=0., 
            done=False, 
            info=dict()
        )

    def step(self, action: np.ndarray) -> NoBatchStep:
        """Apply agent X's action to the board and 
        returns the next agent's timestep.

        Args:
            action (np.ndarray): array shaped (board_size,). The arg max
                action index is the column where next piece is placed.

        Returns:
            tuple: NoBatchStep with values:
                obs (np.ndarray[H, W, 2]): None
                next_obs (np.ndarray[H, W, 2]): the next board state with
                    self's entered squares represented in channel 0 and
                    opponent's squares represented in channel 1
                action (np.ndarray[board_size]): None
                reward (float): the reward for the agent
                    If sparse_reward is True, then reward is -1, 0, or +1.
                    If sparse_reward is False, then reward is:
                        ego_dense_reward - opponent_dense_reward.
                done (bool): whether the game is over
                info (dict): extra information
        """
        # Apply action
        action_index = np.argmax(action)  # []
        self.board.make_move(action_index)  # this flips `board.turn`
        # Temporarily unflip `board.turn`
        self.board.turn *= -1

        # Make egocentric observation
        obs = self._make_obs()

        # Compute reward
        # This is the lazy way to do it, but it's fast enough

        # Sparse reward
        winner = self.board.check_win()
        sparse_reward = self.board.turn * winner

        # Dense reward
        def dense_reward_for_turn(board, turn):
            r = 0
            for length in range(2, self.board_size):
                r += math.log(length) * board.num_connected(length, turn)
            return r
        ego_dense_reward = dense_reward_for_turn(self.board, self.board.turn)
        opponent_dense_reward = dense_reward_for_turn(self.board, -self.board.turn)
        dense_reward = ego_dense_reward - opponent_dense_reward  

        if self.reward_mode == 'sparse':
            reward = sparse_reward
        elif self.reward_mode == 'dense_stateless':
            reward = dense_reward
        elif self.reward_mode == 'dense_advantage':
            turn_index = (self.board.turn+1)//2
            reward = dense_reward - self.prev_dense_reward[turn_index]
            self.prev_dense_reward[turn_index] = dense_reward
        else:
            raise ValueError(f'Invalid reward_mode: {self.reward_mode}')

        # Evaluate whether game is over
        winner = self.board.check_win()
        done = winner != 0

        # Record debugging info
        info = dict()

        # Revert temporary flip on `board.turn`
        self.board.turn *= -1

        return NoBatchStep(
            obs=None,
            next_obs=obs, 
            action=None,
            reward=reward, 
            done=done, 
            info=info
        )

    def render(self):
        print(self.board)

    def _make_obs(self) -> np.ndarray:
        """Only show ego values on first channel and opponent values on second channel"""
        obs = np.stack([
            self.board.turn * self.board.board, 
            -self.board.turn * self.board.board
            ], axis=-1)  # [board_size, board_size, 2]
        obs[obs<0] = 0  # rectify negative values
        return obs

board_size = 10
env = BoardEnv(board_size=board_size, win_length=5, reward_mode='dense_advantage')

step = env.reset()
while not step.done:
    action = np.random.uniform(0, 1, (board_size,))
    step = env.step(action)

    print(f'Action: {action}')
    env.render()
    print(f'Reward: {step.reward}\n')

print(f'Winner: {env.board.winner}')

TypeError: __new__() missing 1 required positional argument: 'action'

In [None]:
INITIAL_HPARAMS = dict(
    board_size=8,           # Board size
    discount=0.99,          # Discount factor
    learning_rate=0.001,    # Learning rate
    batch_size=32,          # Number of samples per training batch
    train_freq=10000,       # Number of timesteps between training steps
    epoch=0,                # Current epoch
    epochs=10,              # Number of training epochs
    epsilon_start=1.0,      # Starting value for epsilon
    min_epsilon=0.01,       # Final value for epsilon
    epsilon_decay=0.95,     # Decay rate for epsilon per epoch
)

INITIAL_HPARAMS['optimizer'] = \
    SGD(INITIAL_HPARAMS['learning_rate'])   # Optimizer


HPARAMS = INITIAL_HPARAMS.copy()