---
title: "Implementing a ResNet34 in JAX for fun ✨"
author:
  - name: "Tugdual Kerjan"
    url: https://tugdual.fr
    email: tkerjan@outlook.com
date: "November 9, 2024"
number-sections: true
reference-location: margin
toc: true
format: 
  html:
    standalone: true
    embed-resources: true
    self-contained-math: true
    code-fold: false
    code-tools: true
execute:
  output:
    false
bibliography: assets/bib.bibtex
theme: united
github: "https://github.com/TugdualKerjan/ResNet-for-JAX"
lightbox: true
---

For the full project, visit the [GitHub repository](https://github.com/TugdualKerjan/ResNet-for-JAX).

# Context 👀

I'm trying to rewrite XTTS in JAX to understand how it works. 


We are going to implement ResNet, first introduced in [@he2015deepresiduallearningimage]. It is used as part of a text-to-speech model written by the now-defunct company Coqui [@casanova2024xttsmassivelymultilingualzeroshot]. The goal of this model is to take in speech, which carries information about the voice of the person we're trying to reproduce, and output a latent representation that conditions HiFiGAN, the model that takes the high-level representation of mumbo jumbo coming out of the GPT2 and spits out audio! ResNet is relatively straightforward; it's:

::: {.column-margin}

![A high level overview of a ResNet from [@dumakude2023automated_image]](assets/encode.png)

:::

__A lot of convolutional layers__

The first layer would look for dots that compose a line, the second for lines that compose a shape, the third for a series of shapes that compose a face... Basically these can be seen as layers of abstraction !

__Residual networks__

A residual connection takes the input of a layer and adds it to that layer's output. What the layer has left to learn is the residual, i.e. the "rest". This also allows the network to set the weights of that layer to 0 if it considers that the added layer isn't necessary!

![Image showing a possible configuration where the input is passed to the output [@he2015deepresiduallearningimage]](assets/res.png)
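To make the idea concrete, here is a minimal sketch in plain JAX. The toy `layer` and `residual_layer` functions and the 3-dimensional input are made up for illustration; they are not part of the ResNet built later in this post.

```python
import jax
import jax.numpy as jnp


def layer(params, x):
    """A toy layer f(x) = relu(W x + b)."""
    w, b = params
    return jax.nn.relu(w @ x + b)


def residual_layer(params, x):
    """The same layer with a skip connection: x + f(x)."""
    return x + layer(params, x)


x = jnp.array([1.0, -2.0, 3.0])

# With all weights at zero, f(x) = 0 and the residual layer is the identity.
zero_params = (jnp.zeros((3, 3)), jnp.zeros(3))
out_zero = residual_layer(zero_params, x)

# With non-zero weights, the layer only has to model the difference to x.
key = jax.random.PRNGKey(0)
params = (0.1 * jax.random.normal(key, (3, 3)), jnp.zeros(3))
out = residual_layer(params, x)
residual = out - x  # what the layer actually contributed
```

Here `out_zero` equals `x`, and `residual` is exactly the output of the inner layer.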

__Squeeze and Excite__

The ResNet we're implementing here has been spiced up, as it also uses squeeze and excite modules [@hu2019squeezeandexcitationnetworks] to exchange information between "channels" (think red, green and blue in images, for example). This allows more context to be shared in the network for better abstraction of what's really important!

 ![Image showing how the Squeeze and Excite adds information [@hu2019squeezeandexcitationnetworks]](assets/se.png)
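As a rough sketch of the mechanism, the three steps (squeeze, excite, rescale) can be written in a few lines of plain JAX. The function name `squeeze_excite`, the sizes `C` and `r`, and the random weights are assumptions chosen for illustration only.

```python
import jax
import jax.numpy as jnp


def squeeze_excite(x, w1, b1, w2, b2):
    """Rescale each channel of x, shape (C, H, W), by a gate between 0 and 1."""
    s = jnp.mean(x, axis=(1, 2))        # squeeze: one number per channel, (C,)
    h = jax.nn.relu(w1 @ s + b1)        # bottleneck, (C // r,)
    gate = jax.nn.sigmoid(w2 @ h + b2)  # excite: one gate per channel, (C,)
    return x * gate[:, None, None], gate


C, r = 8, 4  # hypothetical channel count and reduction ratio
k_x, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (C, 5, 5))
w1 = jax.random.normal(k1, (C // r, C))
w2 = jax.random.normal(k2, (C, C // r))

out, gate = squeeze_excite(x, w1, jnp.zeros(C // r), w2, jnp.zeros(C))
print(out.shape, gate.shape)
```

Every channel is multiplied by a single gate, so the layer can only turn channels down, never up; which channels get turned down depends on the global average of all channels.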

# Goal 🎯

Get a ResNet that can classify MNIST data as a proof that it works !



# Model

We code from bottom to top: first the Squeeze and Excite layer, then the residual "block" it's a part of, and finally the network that stacks these blocks.

We start by importing our favorite libraries:

In [1]:
import jax
import jax.numpy as jnp
import equinox as eqx
import equinox.nn as nn
import typing as tp

## SEBlock

We start by implementing the Squeeze and Excite layer. A good explanation for this is provided here: https://amaarora.github.io/fastexplain/2020/07/24/SeNet.html

This excerpt might help to make more sense of this:

    We expect the learning of convolutional features to be enhanced by explicitly modelling channel interdependencies, so that the network is able to increase its sensitivity to informative features which can be exploited by subsequent transformations. Consequently, we would like to provide it with access to global information and recalibrate filter responses in two steps, squeeze and excitation, before they are fed into the next transformation.

In [4]:
import equinox as eqx
import jax


class SELayer(eqx.Module):
    fc1: eqx.nn.Linear
    fc2: eqx.nn.Linear

    def __init__(self, channel, reduction=8, key=None):
        key1, key2 = jax.random.split(key, 2)
        self.fc1 = eqx.nn.Linear(channel, channel // reduction, use_bias=True, key=key1)
        self.fc2 = eqx.nn.Linear(channel // reduction, channel, use_bias=True, key=key2)

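    def __call__(self, x):
        # x is a single example of shape (C, H, W); use jax.vmap for batches.
        y = eqx.nn.AdaptiveAvgPool2d(1)(x)  # squeeze: (C, 1, 1)
        y = jax.numpy.squeeze(y, axis=(1, 2))  # (C,), safe even when C == 1
        y = self.fc1(y)
        y = jax.nn.relu(y)
        y = self.fc2(y)
        y = jax.nn.sigmoid(y)  # excite: one gate between 0 and 1 per channel
        y = jax.numpy.expand_dims(y, (1, 2))  # back to (C, 1, 1) for broadcasting

        return x * y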

We can compare this against a PyTorch version of the same layer with some quick code below.

In [None]:
# | code-fold: true

import torch
import torch.nn as nn
import numpy as np


# Define PyTorch version of SELayer
class SELayerTorch(nn.Module):
    def __init__(self, channel, reduction=8):
        super(SELayerTorch, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

In [None]:
# | code-fold: true
import jax
import numpy as np
import torch

x_key, model_key = jax.random.split(jax.random.PRNGKey(0), 2)

# Example data: a batch of one input with 100 channels on a 4x4 grid.
x = jax.random.normal(x_key, (1, 100, 4, 4))

tor = SELayerTorch(100, reduction=2)
model = SELayer(100, reduction=2, key=model_key)

res_torch = tor(torch.from_numpy(np.array(x)))
res = torch.from_numpy(np.array(jax.vmap(model)(x)))

# Both layers are randomly initialised, so this comparison fails
# until the weights of one are copied into the other.
torch.testing.assert_close(res_torch, res)

AssertionError: Tensor-likes are not close!

Mismatched elements: 1599 / 1600 (99.9%)
Greatest absolute difference: 0.17174112796783447 at index (0, 39, 0, 1) (up to 1e-05 allowed)
Greatest relative difference: 0.20674952864646912 at index (0, 43, 0, 0) (up to 1.3e-06 allowed)

The shapes agree but the values don't. That is expected here: the PyTorch and Equinox layers were initialised with different random weights, so the comparison can only pass once the weights are copied from one model to the other.

## ResBlock

We can move on to the ResBlock, which uses the SEBlock and implements the concept we saw earlier about residuals.

The issue with sufficiently deep models is that they can fail to approximate even simple functions because of **vanishing gradients** and the **curse of dimensionality**, while simple shallow ones work fine. So why not skip some layers to match the accuracy of the shallow ones? Residual blocks can do this easily by setting the weights of a layer to 0 and simply letting the input be passed to the output.

::: {.column-margin}

![The image we saw previously illustrating the concept of Residual networks](assets/res.png)

:::

Looking at the image in the margin, we notice that the network can learn the identity function by simply setting $f(x) = 0$, since the block then outputs $f(x) + x = x$.
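The skip connection also helps with the vanishing gradients mentioned above: the derivative of $x + f(x)$ is $1 + f'(x)$, so there is always a direct path for the gradient. The toy comparison below is only a sketch under strong assumptions (a scalar input, a single shared weight `W`, `tanh` layers, and a made-up depth of 50), not the actual network.

```python
import jax
import jax.numpy as jnp

DEPTH = 50  # hypothetical number of layers
W = 0.5     # hypothetical weight shared by all layers


def plain_net(x):
    """A chain of layers without skip connections."""
    for _ in range(DEPTH):
        x = jnp.tanh(W * x)
    return x


def residual_net(x):
    """The same chain where each layer adds its input to its output."""
    for _ in range(DEPTH):
        x = x + jnp.tanh(W * x)
    return x


# Gradient of the output with respect to the input.
g_plain = jax.grad(plain_net)(1.0)
g_residual = jax.grad(residual_net)(1.0)
print(g_plain, g_residual)
```

In the plain chain every layer multiplies the gradient by at most 0.5, so after 50 layers almost nothing is left. In the residual chain every layer multiplies it by at least 1, so the signal survives.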

In [22]:
import jax
import equinox as eqx


class SEBasicBlock(eqx.Module):
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    bn1: eqx.nn.BatchNorm
    bn2: eqx.nn.BatchNorm
    se: SELayer
    downsample: tp.Optional[eqx.nn.Sequential]

    def __init__(self, channels_in, channels_out, stride=1, downsample=None, key=None):
        key1, key3, key5 = jax.random.split(key, 3)

        # TODO Understand why bias isn't added.
        # TODO Do we want to have a state or simply do GroupNorm instead ?

        self.conv1 = eqx.nn.Conv2d(
            channels_in,
            channels_out,
            kernel_size=(3, 3),
            stride=stride,
            padding=1,
            use_bias=False,
            key=key1,
        )
        self.bn1 = eqx.nn.BatchNorm(channels_out, axis_name="batch")
        self.conv2 = eqx.nn.Conv2d(
            channels_out,
            channels_out,
            kernel_size=(3, 3),
            padding=1,
            use_bias=False,
            key=key3,
        )
        self.bn2 = eqx.nn.BatchNorm(channels_out, axis_name="batch")

        self.se = SELayer(channels_out, key=key5)
        self.downsample = downsample

    def __call__(self, x, state):
        residual = x

        y = self.conv1(x)

        y = jax.nn.relu(y)
        y, state = self.bn1(y, state)

        y = self.conv2(y)
        y, state = self.bn2(y, state)

        y = self.se(y)

        if self.downsample is not None:
            residual, state = self.downsample(x, state)

        y = y + residual  # Residual
        y = jax.nn.relu(y)

        return y, state

## ResNet

We can now move on to building the actual network! We're going to create a ResNet for audio using the SEBasicBlock we created. XTTS bases its model on [@heo2020clovabaselinevoxcelebspeaker]. In the paper, they show how they combine multiple layers to embed an image (a Mel spectrogram in our case) into a latent vector.

![The layers of the architecture proposed in [@heo2020clovabaselinevoxcelebspeaker]](assets/clove.png)

Compared to ResNet-34, the stride at the first layer is removed. Attentive Statistics Pooling [@Okabe_2018] is used to aggregate temporal frames: an attention mechanism assigns a weight to each frame, and the channel-wise __weighted standard deviation__ is calculated in addition to the __weighted mean__. This builds on results showing that the standard deviation remains useful information when the frames are weighted by attention!
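The pooling step itself fits in a few lines. The sketch below assumes the attention scores are already given; in the model they come from a small convolutional network. The function name `attentive_stats_pool` and the toy inputs are made up for illustration.

```python
import jax
import jax.numpy as jnp


def attentive_stats_pool(x, scores):
    """x, scores: (channels, frames). Returns mean and std stacked, (2 * channels,)."""
    w = jax.nn.softmax(scores, axis=1)         # weights sum to 1 over frames
    mu = jnp.sum(x * w, axis=1)                # weighted mean
    var = jnp.sum((x**2) * w, axis=1) - mu**2  # weighted variance
    sg = jnp.sqrt(jnp.maximum(var, 1e-5))      # clip before the sqrt to avoid NaNs
    return jnp.concatenate([mu, sg])


x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [0.0, 0.0, 10.0, 10.0]])

# Uniform scores give the plain mean and standard deviation of each channel.
uniform = attentive_stats_pool(x, jnp.zeros_like(x))

# A large score on the last frame makes the pooling focus on that frame.
peaked = attentive_stats_pool(x, jnp.array([[0.0, 0.0, 0.0, 20.0]] * 2))
print(uniform, peaked)
```

With uniform scores the first channel pools to a mean of 2.5, while with the peaked scores it pools to roughly 4, the value of the last frame.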

Below, we add a `create_layer` method directly taken from XTTS to help with the initialization process.



In [None]:
import jax
import equinox as eqx


class ResNet(eqx.Module):
    conv1: eqx.nn.Conv2d
    batch_norm: eqx.nn.BatchNorm

    layer1: list
    layer2: list
    layer3: list
    layer4: list

    instance_norm: eqx.nn.GroupNorm

    attention_conv1: eqx.nn.Conv1d
    attention_batch_norm: eqx.nn.BatchNorm
    attention_conv2: eqx.nn.Conv1d

    fc: eqx.nn.Linear

    def create_layer(self, channels_in, channels_out, layers, stride=1, key=None):

        downsample = None
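        # A projection on the skip path is only needed when the block changes
        # the spatial resolution or the number of channels.
        if stride != 1 or channels_in != channels_out: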
            key, grab = jax.random.split(key, 2)
            downsample = eqx.nn.Sequential(
                [
                    eqx.nn.Conv2d(
                        channels_in,
                        channels_out,
                        kernel_size=1,
                        stride=stride,
                        use_bias=False,
                        key=grab,
                    ),
                    eqx.nn.BatchNorm(
                        channels_out, axis_name="batch", channelwise_affine=False
                    ),
                ]
            )

        stack_of_blocks = []
        # print(key)
        key, grab = jax.random.split(key, 2)

        stack_of_blocks.append(
            SEBasicBlock(channels_in, channels_out, stride, downsample, key=grab)
        )
        for _ in range(1, layers):

            key, grab = jax.random.split(key, 2)
            stack_of_blocks.append(
                SEBasicBlock(channels_out, channels_out, stride=1, key=grab)
            )

        return stack_of_blocks

    def __init__(
        self,
        input_dims,
        proj_dim,
        layers=[3, 4, 6, 3],
        num_filters=[32, 64, 128, 256],
        key=None,
    ):
        # he_init = jax.nn.initializers.variance_scaling(scale=2.0, mode="fan_out", distribution="truncated_normal")

        key, grab = jax.random.split(key, 2)
        # TODO self.conv1 = eqx.nn.Conv2d(1, num_filters[0], key=grab, weight_init=he_init)
        self.conv1 = eqx.nn.Conv2d(
            1, num_filters[0], kernel_size=3, padding=1, key=grab
        )
        self.batch_norm = eqx.nn.BatchNorm(
            num_filters[0], axis_name="batch", channelwise_affine=False
        )

        key, key1, key2, key3, key4 = jax.random.split(key, 5)
        self.layer1 = self.create_layer(
            num_filters[0], num_filters[0], layers[0], key=key2
        )
        self.layer2 = self.create_layer(
            num_filters[0], num_filters[1], layers[1], stride=(2, 2), key=key3
        )
        self.layer3 = self.create_layer(
            num_filters[1], num_filters[2], layers[2], stride=(2, 2), key=key4
        )
        self.layer4 = self.create_layer(
            num_filters[2], num_filters[3], layers[3], stride=(2, 2), key=key1
        )

        # Instance norm is group norm with one group per channel; the input has a single channel, so groups=1.
        self.instance_norm = eqx.nn.GroupNorm(1, channelwise_affine=False)

        key, key1, key2 = jax.random.split(key, 3)

        # Basically a FFN but without needing to deal with the channel dimensions.
        # doesn't really explain the lowering of dimensions in the middle though...
        current_channel_size = int(num_filters[3] * input_dims / 8)
        self.attention_conv1 = eqx.nn.Conv1d(
            current_channel_size, 128, kernel_size=1, key=key1
        )
        self.attention_batch_norm = eqx.nn.BatchNorm(
            128, axis_name="batch", channelwise_affine=False
        )
        self.attention_conv2 = eqx.nn.Conv1d(
            128, current_channel_size, kernel_size=1, key=key2
        )
        # TODO  nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

        # Encoder type is ASP: the three stride-2 stages shrink the frequency axis to input_dims / 8,
        # giving num_filters[3] * input_dims / 8 features, doubled because mean and std are concatenated.
        self.fc = eqx.nn.Linear(current_channel_size * 2, proj_dim, key=key)

    def __call__(self, x, state):
        # We expect a mel spectrogram of shape (1, n_mels, frames) as input for now.
        y = self.instance_norm(x)

        y = self.conv1(y)
        y = jax.nn.relu(y)
        y, state = self.batch_norm(y, state)

        for block in self.layer1:
            y, state = block(y, state)

        for block in self.layer2:
            y, state = block(y, state)

        for block in self.layer3:
            y, state = block(y, state)

        for block in self.layer4:
            y, state = block(y, state)

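        # Merge channels and frequency bins: (C, F, T) -> (C * F, T).
        y = jax.numpy.reshape(y, (-1, y.shape[-1]))

        # TODO not really justified...
        w = self.attention_conv1(y)
        w = jax.nn.relu(w)
        w, state = self.attention_batch_norm(w, state)
        w = self.attention_conv2(w)  # one attention score per feature and frame
        w = jax.nn.softmax(w, axis=1)  # normalise the scores over frames

        mu = jax.numpy.sum(y * w, axis=1)  # weighted mean
        # Weighted variance is E[y^2] - E[y]^2, so y has to be squared here.
        var = jax.numpy.sum((y**2) * w, axis=1) - mu**2
        # Clip before the square root so round-off cannot produce NaNs.
        sg = jax.numpy.sqrt(jax.numpy.maximum(var, 1e-5))
        sg = jax.numpy.minimum(sg, 10.0)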

        y = jax.lax.concatenate((mu, sg), 0)

        y = self.fc(y)

        return y, state

In [None]:
@eqx.filter_jit
def loss(model, state, x, y):
    res, state = jax.vmap(
        model, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(x, state)
    return jax.numpy.mean((res - y) ** 2), state


loss = eqx.filter_grad(loss, has_aux=True)

key = jax.random.PRNGKey(seed=69)
key1, key2, key3 = jax.random.split(key, 3)
x = jax.random.normal(key1, (1, 1, 64, 32)).astype(jax.numpy.float32)
y = jax.random.normal(key2, (1, 512)).astype(jax.numpy.float32)

model, state = eqx.nn.make_with_state(ResNet)(64, 512, key=key3)
grads, state = loss(model, state, x, y)
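To see where `current_channel_size` and the input size of `fc` come from, we can follow the shapes by hand. This is a plain Python sketch; the helper `conv_out` is made up for this check and uses the standard output-size formula for a convolution.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Output length of a convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


n_mels, frames = 64, 32  # the (1, 64, 32) example fed to the model above
num_filters = [32, 64, 128, 256]
strides = [1, 2, 2, 2]  # layer1 keeps the resolution, layers 2 to 4 halve it

h, w = conv_out(n_mels), conv_out(frames)  # first conv, stride 1
for s in strides:
    h, w = conv_out(h, stride=s), conv_out(w, stride=s)

flattened = num_filters[3] * h  # rows after the reshape before the attention
fc_inputs = 2 * flattened       # mean and std are concatenated

print((num_filters[3], h, w), flattened, fc_inputs)
# → (256, 8, 4) 2048 4096
```

The 2048 matches `num_filters[3] * input_dims / 8` in `__init__`, and the final linear layer maps the 4096 pooled statistics down to `proj_dim`.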

Other things that I could implement to get a better version:

::: {.column-margin}

![Improvements suggested in @bello2021revisitingresnetsimprovedtraining](assets/improve.png)

:::