In [None]:
!pip3 install dm-haiku
!pip3 install optax

In [3]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import jax
import jax.numpy as jnp
import haiku as hk
import optax
import numpy as np
from typing import NamedTuple

In [9]:
def load_dataset(batch_size=8, image_size=224, root='./data'):
    """Build the MNIST training loader and a dummy input for network initialisation.

    Images are resized to ``image_size x image_size`` and expanded from one to
    three channels so they fit a network that expects RGB-sized inputs.

    Args:
        batch_size: number of images per batch yielded by the loader.
        image_size: side length (pixels) the images are resized to.
        root: directory the dataset is downloaded to / read from.

    Returns:
        A tuple ``(trainloader, x_init)`` where ``trainloader`` is a shuffled
        ``torch.utils.data.DataLoader`` over the MNIST training split (batches
        come out of torch in NCHW layout) and ``x_init`` is a random float32
        array of shape ``(16, image_size, image_size, 3)`` (NHWC) meant only
        for initialising network parameters.
    """
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.Grayscale(3),  # a way to get 3-channel MNIST
        transforms.ToTensor(),
    ])

    trainset = datasets.MNIST(root=root, train=True,
                              download=True, transform=preprocess)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    # Only the shape and dtype of this array matter for initialisation; its
    # batch dimension does not have to match the loader's batch size.
    init_batch = 16
    x_init = np.random.randn(init_batch, image_size, image_size, 3).astype(np.float32)
    return trainloader, x_init

In [5]:
class ResNetBlock(hk.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a skip connection.

    The block is a Haiku module (not a ``torch.nn.Module``): every layer it
    owns is a Haiku layer, so it must be built and called inside an
    ``hk.transform``/``hk.transform_with_state`` function.

    Args:
        out_channels: number of output channels of both 3x3 convolutions.
        stride: stride of the first 3x3 convolution (and of the optional 1x1
            convolution on the skip path).
        use_1x1_conv: when True, the skip path goes through a 1x1 convolution
            so its channel count / spatial size match the main path.
        name: optional Haiku module name.
    """

    def __init__(self, out_channels, stride=1, use_1x1_conv=False, name=None):
        super().__init__(name=name)
        self.conv1 = hk.Conv2D(output_channels=out_channels, kernel_shape=3, stride=stride, padding='SAME')
        self.bn1 = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.999)
        self.conv2 = hk.Conv2D(output_channels=out_channels, kernel_shape=3, stride=1, padding='SAME')
        if use_1x1_conv:
            # Projection shortcut: reshapes the input so it can be added to the main path.
            self.conv3 = hk.Conv2D(output_channels=out_channels, kernel_shape=1, stride=stride)
        else:
            self.conv3 = None
        self.bn2 = hk.BatchNorm(create_scale=False, create_offset=False, decay_rate=0.999)

    def __call__(self, x, is_training=True):
        """Apply the block to ``x``.

        Args:
            x: input feature map; it is fed to ``hk.Conv2D`` with default data
                format, so it is assumed to be NHWC.
            is_training: forwarded to both batch-norm layers; defaults to True,
                which is what the original implementation hard-coded.

        Returns:
            ``relu(main_path(x) + shortcut(x))``.
        """
        out = jax.nn.relu(self.bn1(self.conv1(x), is_training=is_training))
        out = self.bn2(self.conv2(out), is_training=is_training)
        shortcut = x if self.conv3 is None else self.conv3(x)
        return jax.nn.relu(out + shortcut)

    def forward(self, x):
        """Backward-compatible alias for calling the block in training mode."""
        return self(x)

In [25]:
class ResNet18(hk.Module):
    """ResNet-18 style classifier with 10 output classes, written with Haiku.

    Layout: 7x7 stride-2 convolution -> batch norm -> ReLU -> 3x3 stride-2
    max-pool -> four stages of two residual blocks (64, 128, 256, 512
    channels) -> global average pool -> linear layer with 10 outputs.

    Must be constructed and called inside ``hk.transform_with_state`` because
    the batch-norm layers keep running statistics as Haiku state.

    Args:
        name: optional Haiku module name.
    """

    def __init__(self, name=None):
        super().__init__(name=name)
        self.conv1 = hk.Conv2D(output_channels=64, kernel_shape=7, stride=2, padding='SAME', name='conv1')
        self.bn1 = hk.BatchNorm(create_scale=False, create_offset=False, decay_rate=0.999)    # 64
        # Haiku pooling layers take 'SAME' or 'VALID' as padding, not an integer.
        self.maxpool = hk.MaxPool(window_shape=3, strides=2, padding='SAME')
        self.l1 = self._make_layer(64, 2, first_layer=True)
        self.l2 = self._make_layer(128, 2)
        self.l3 = self._make_layer(256, 2)
        self.l4 = self._make_layer(512, 2)
        self.fc = hk.Linear(10)

    def _make_layer(self, channels, num_blocks, first_layer=False):
        """Build one stage of ``num_blocks`` residual blocks.

        Except in the first stage, the first block of a stage uses stride 2
        and a 1x1 projection on the skip path; all other blocks keep the
        resolution and channel count unchanged.
        """
        layers = []
        for b in range(num_blocks):
            if b == 0 and not first_layer:
                layers.append(ResNetBlock(channels, stride=2, use_1x1_conv=True))
            else:
                layers.append(ResNetBlock(channels))
        return hk.Sequential(layers)

    def __call__(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: image batch; fed to ``hk.Conv2D`` with default data format, so
                it is assumed to be NHWC.

        Returns:
            Logits of shape ``(N, 1, 1, 10)``. The singleton spatial axes are
            kept so the logits line up with integer labels shaped
            ``(N, 1, 1)``. Batch norm always runs in training mode.
        """
        out = jax.nn.relu(self.bn1(self.conv1(x), is_training=True))
        out = self.maxpool(out)
        for stage in (self.l1, self.l2, self.l3, self.l4):
            out = stage(out)
        # Global average pool over the spatial axes; works for any input size.
        out = jnp.mean(out, axis=(1, 2), keepdims=True)
        return self.fc(out)

In [29]:
# Immutable pair carried through the training loop: the current network
# parameters and the optimizer state that belongs to them.
TrainingState = NamedTuple(
    "TrainingState",
    [
        ("params", hk.Params),            # network weights
        ("opt_state", optax.OptState),    # optimizer bookkeeping for those weights
    ],
)

def net_fn(x):
    """Forward function handed to Haiku: build a ResNet18 and apply it to ``x``."""
    model = ResNet18()
    logits = model(x)
    return logits

# Short alias for Optax's softmax cross-entropy that takes integer class labels
# (rather than one-hot vectors); used by the training loss in main().
lossfn = optax.softmax_cross_entropy_with_integer_labels

def main(num_epochs=10, learning_rate=1e-3):
    """Train the network on MNIST with plain SGD and print the loss per epoch.

    Args:
        num_epochs: number of passes over the training loader.
        learning_rate: SGD step size.

    The printed value is the sum (not the mean) of the per-example losses over
    the whole epoch.
    """
    network = hk.transform_with_state(net_fn)
    optimizer = optax.sgd(learning_rate)
    init_rng = jax.random.PRNGKey(42)

    def loss_fn(params, x, y, state, rng):
        """Summed cross-entropy of the batch; the new network state is returned as aux."""
        out, state = network.apply(params, state, rng, x)
        loss = jnp.sum(lossfn(out, y))
        return loss, state

    @jax.jit
    def update_weights(training_state, x, y, state, rng):
        """Run one SGD step and return (new training state, loss, new network state)."""
        # With has_aux=True the result is ((loss, aux), grads), where aux is
        # the second value returned by loss_fn (the updated network state).
        (loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            training_state.params, x, y, state, rng)
        updates, opt_state = optimizer.update(
            grads, training_state.opt_state, training_state.params)
        # apply_updates is a module-level optax function, not an optimizer method.
        params = optax.apply_updates(training_state.params, updates)
        return TrainingState(params, opt_state), loss, state

    trainloader, x_init = load_dataset()
    init_params, state = network.init(init_rng, x_init)
    init_opt_state = optimizer.init(init_params)
    training_state = TrainingState(params=init_params, opt_state=init_opt_state)

    for epoch in range(num_epochs):
        epoch_loss = 0
        for x, y in trainloader:
            # The loader yields NCHW tensors. permute actually moves the channel
            # axis to the end; view would only relabel the axes and scramble pixels.
            x = np.array(x.permute(0, 2, 3, 1)).astype(np.float32)  # NHWC
            y = np.array(y.view(y.shape[0], 1, 1))
            # The same key is passed on every step, as in the original loop.
            training_state, loss, state = update_weights(training_state, x, y, state, init_rng)
            epoch_loss += loss

        print(f"Loss on epoch: {epoch} was {epoch_loss}")

In [None]:
# Run the full training loop defined above.
main()