In [None]:
!pip3 install dm-haiku
!pip3 install optax

In [15]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import jax
import jax.numpy as jnp
import haiku as hk
import optax
import numpy as np
from typing import NamedTuple

In [7]:
def load_dataset():
    transformn = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    batch_size = 16
    trainset = datasets.MNIST(root='./data', train=True,
                              download=True, transform=transformn)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    testset = datasets.MNIST(root='./data', train=False,
                              download=True, transform=transformn)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                              shuffle=False, num_workers=2)

    x_init = np.random.randn(224, 224, 3)
    return trainloader, x_init

In [41]:
class ResNetBlock(nn.Module):
    def __init__(self, out_channels, stride=1, use_1x1_conv=False):
        super().__init__()
        self.conv1 = hk.Conv2D(output_channels=out_channels, kernel_shape=3, stride=stride, padding='VALID', bias=False)
        self.bn1 = hk.BatchNorm(out_channels)
        self.conv2 = hk.Conv2D(output_channels=out_channels, kernel_shape=3, stride=1, padding='VALID', bias=False)
        if use_1x1_conv:
            self.conv3 = hk.Conv2D(output_channels=out_channels, kernel_shape=1, stride=stride, bias=False)
        else:
            self.conv3 = None
        self.bn2 = hk.BatchNorm(out_channels)

    def forward(self, x):
        out = jax.nn.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.conv3:
            x = self.conv3(x)
        return jax.nn.relu((out + x))

In [42]:
class ResNet18(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = hk.Conv2D(output_channels=64, kernel_shape=7, stride=1, padding='VALID')
        self.bn1 = hk.BatchNorm(64)
        self.maxpool = hk.MaxPool2D(kernel_size=3, stride=2, padding=1)
        self.l1 = self._make_layer(64, 2, first_layer=True)
        self.l2 = self._make_layer(128, 2)
        self.l3 = self._make_layer(256, 2)
        self.l4 = self._make_layer(512, 2)
        self.fc = hk.Linear(10)

    def _make_layer(self, channels, num_blocks, first_layer=False):
        layers = []
        for b in range(num_blocks):
            if b == 0 and not first_layer:
                layers.append(ResNetBlock(channels, stride=2, use_1x1_conv=True))
            else:
                layers.append(ResNetBlock(channels))
        return hk.Sequential(layers)

    def forward(self, x):
        out = jax.nn.relu(self.bn1(self.conv1(x)))
        out = self.l1(out)
        out = self.l2(out)
        out = self.l3(out)
        out = self.l4(out)
        out = out.mean(3).mean(2)
        out = self.fc(out)
        return out

In [43]:
class TrainingState(NamedTuple):
    params: hk.Params
    opt_state: optax.OptState

def net_fn(x):
    return ResNet18()(x)

lossfn = optax.softmax_cross_entropy_with_integer_labels

def main():
    network = hk.without_apply_rng(hk.transform(net_fn))
    optimizer = optax.sgd(1e-3)
    init_rng = jax.random.PRNGKey(42)

    def loss_fn(params, x, y):
        out = network.apply(params, x)
        loss = lossfn(out, y)
        return loss

    def update_weights(params, x, y):
        grads, loss = jax.value_and_grad(loss_fn)(params, x, y)
        params, opt_state = optimizer.update(grads, state.opt_state)
        return TrainingState(params, opt_state), loss

    trainloader, x_init = load_dataset()
    init_params = network.init(init_rng, x_init)
    init_opt_state = optimizer.init(init_params)
    state = nn.TrainingState(params=init_params, opt_state=init_opt_state)

    for epoch in range(10):
        epoch_loss = 0
        for x, y in trainloader:
            x = np.array(x)
            y = np.array(y)
            state, loss = update_weights(state.params, x, y)
            epoch_loss += loss

        print(f"Loss on epoch: {epoch} was {epoch_loss}")

In [44]:
main()

TypeError: ignored