In [39]:
import jax.numpy as jnp
import jax
from flax import nnx
import orbax.checkpoint as ocp

In [3]:
class Linear(nnx.Module):
    """Affine layer ``y = x @ w + b``.

    ``w`` is drawn from ``jax.random.uniform`` (default range [0, 1)) and ``b``
    starts at zero.
    """

    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        # A single key from the ``params`` stream seeds the weight matrix.
        weight_key = rngs.params()
        self.w = nnx.Param(jax.random.uniform(weight_key, (din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))
        # Layer sizes kept for introspection.
        self.din = din
        self.dout = dout

    def __call__(self, x: jax.Array):
        """Apply the affine map; ``x``'s last axis must have size ``din``."""
        return x @ self.w + self.b

In [7]:
# Build the layer from a fixed ``params`` seed and run it on one all-ones row.
model = Linear(
    din=3,
    dout=2,
    rngs=nnx.Rngs(params=0),
)
model(x=jnp.ones((1, 3)))

Array([[1.9137444, 2.068491 ]], dtype=float32)

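'''
PyTorch reference for conv3x3 (ported to Flax NNX in the next cell;
note that the port does not support `groups`):

def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)
'''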

In [21]:
def conv3x3(in_channels, out_channels, stride = 1, padding = 1, bias = True):
    return nnx.Conv(in_features= in_channels, 
                    out_features= out_channels, 
                    kernel_size= 3, 
                    strides= stride, 
                    padding= padding, 
                    use_bias= bias,
                    rngs= nnx.Rngs(params=jax.random.key(0)))

In [22]:
def upconv2x2(in_channels, out_channels, mode='transpose'):
    '''
    Upsample not implemented, need to find alternative
    '''
    return nnx.ConvTranspose(in_features= in_channels,
                                          out_features= out_channels,
                                          kernel_size= 2,
                                          strides= 2,
                                          rngs= nnx.Rngs(params=jax.random.key(0)))

In [23]:
def conv1x1(in_channels, out_channels):
    return nnx.Conv(in_features= in_channels, 
                    out_features= out_channels, 
                    kernel_size= 1, 
                    strides= 1,
                    rngs= nnx.Rngs(params=jax.random.key(0)))

In [24]:
x = jnp.ones((1, 3))
conv3x3(3, 3)(x)


Array([[-0.12897272,  0.27571398,  1.1490479 ]], dtype=float32)

In [None]:
# nnx.relu needs an input array; calling it with no arguments raises a TypeError.
nnx.relu(jnp.array([-1.0, 0.0, 2.0]))

In [37]:
class DownConv(nnx.Module):
    """Encoder stage: three 3x3 convs with a residual skip, then optional 2x2 max-pooling.

    Args (constructor):
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by every conv in this stage.
        pooling: if True, halve the spatial size with a 2x2, stride-2 max-pool.
    """

    def __init__(self, in_channels, out_channels, pooling=True):
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Static Python configuration: branch on it with a plain ``if`` in
        # ``__call__``; ``jax.lax.cond`` requires callable branches with
        # matching output shapes, which pooled/unpooled tensors do not have.
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)
        self.conv3 = conv3x3(self.out_channels, self.out_channels)

    def __call__(self, x):
        """Run the stage.

        Args:
            x: input feature map; assumed channels-last (N, H, W, C), since the
                max-pool window is (2, 2) over the two axes before channels.
                TODO confirm at the call site.

        Returns:
            ``(x, before_pool)``: the (possibly pooled) output and the feature
            map before pooling (presumably used as the decoder skip input).
        """
        xskip = self.conv1(x)
        x = nnx.relu(self.conv2(xskip))
        x = nnx.relu(self.conv3(x) + xskip)
        # TODO(review): dropout (the commented-out ``F.dropout(x)``) is not ported.
        before_pool = x
        if self.pooling:
            x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        return x, before_pool


In [38]:
class UpConv(nnx.Module):
    """Decoder stage: up-convolve, merge with the encoder tensor, then three 3x3 convs with a residual skip.

    Args (constructor):
        in_channels: channels of the decoder tensor fed to ``upconv``.
        out_channels: channels produced by the up-convolution and every conv here.
        merge_mode: ``'concat'`` concatenates the two tensors on the channel
            axis; any other value adds them element-wise.
        up_mode: forwarded to ``upconv2x2`` as ``mode``.
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels,
                                mode=self.up_mode)

        # merge_mode is static Python configuration, so pick the layer with a
        # plain ``if``: ``jax.lax.cond`` needs callable branches (not Modules)
        # and would have constructed both candidate convs.
        if self.merge_mode == 'concat':
            # Concatenation doubles the channel count that conv1 sees.
            self.conv1 = conv3x3(2 * self.out_channels, self.out_channels)
        else:
            self.conv1 = conv3x3(self.out_channels, self.out_channels)

        self.conv2 = conv3x3(self.out_channels, self.out_channels)
        self.conv3 = conv3x3(self.out_channels, self.out_channels)

    def __call__(self, from_down, from_up):
        """Forward pass.

        Args:
            from_down: tensor from the encoder pathway.
            from_up: tensor from the decoder pathway; up-convolved here first.

        Returns:
            The stage output after merging and the three convolutions.
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            # Flax/JAX tensors are channels-last, so channels live on axis -1
            # (for NHWC inputs axis=1 is the height axis).
            x = jnp.concatenate([from_up, from_down], axis=-1)
        else:
            x = from_up + from_down

        xskip = self.conv1(x)
        x = nnx.relu(self.conv2(xskip))
        x = nnx.relu(self.conv3(x) + xskip)
        # TODO(review): dropout (the commented-out ``F.dropout(x)``) is not ported.
        return x
