In [39]:
import jax.numpy as jnp
import jax
from flax import nnx
import orbax.checkpoint as ocp

In [3]:
class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b

In [7]:
model = Linear(din = 3, dout = 2, rngs=nnx.Rngs(params=0))
model(x = jnp.ones((1, 3)))

Array([[1.9137444, 2.068491 ]], dtype=float32)

'''


        def conv3x3(in_channels, out_channels, stride=1, 
            padding=1, bias=True, groups=1):    
                
        return nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=padding,
                bias=bias,
                groups=groups)

'''

In [21]:
def conv3x3(in_channels, out_channels, stride = 1, padding = 1, bias = True):
    return nnx.Conv(in_features= in_channels, 
                    out_features= out_channels, 
                    kernel_size= 3, 
                    strides= stride, 
                    padding= padding, 
                    use_bias= bias,
                    rngs= nnx.Rngs(params=jax.random.key(0)))

In [22]:
def upconv2x2(in_channels, out_channels, mode='transpose'):
    '''
    Upsample not implemented, need to find alternative
    '''
    return nnx.ConvTranspose(in_features= in_channels,
                                          out_features= out_channels,
                                          kernel_size= 2,
                                          strides= 2,
                                          rngs= nnx.Rngs(params=jax.random.key(0)))

In [23]:
def conv1x1(in_channels, out_channels):
    return nnx.Conv(in_features= in_channels, 
                    out_features= out_channels, 
                    kernel_size= 1, 
                    strides= 1,
                    rngs= nnx.Rngs(params=jax.random.key(0)))

In [24]:
x = jnp.ones((1, 3))
conv3x3(3, 3)(x)


Array([[-0.12897272,  0.27571398,  1.1490479 ]], dtype=float32)

In [None]:
nnx.relu()

In [49]:
class DownConv(nnx.Module):
 
    def __init__(self, in_channels, out_channels, pooling = True):

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)
        self.conv3 = conv3x3(self.out_channels, self.out_channels)

        # if self.pooling:
        #     self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def __call__(self, x):
        
        xskip = self.conv1(x)
        x = nnx.relu(self.conv2(xskip))
        x = nnx.relu(self.conv3(x) + xskip)
        # x = F.dropout(x)
        before_pool = x
        x = jax.lax.cond(self.pooling, 
                         lambda x : nnx.max_pool(x, window_shape= (2, 2), strides= (2, 2)),
                         lambda x : x,
                         x)
        # if self.pooling:
        #     x = nnx.max_pool(x, window_shape= (2, 2), strides= (2, 2))
        return x, before_pool


In [43]:
class UpConv(nnx.Module):
   
    def __init__(self, in_channels, out_channels, 
                 merge_mode='concat', up_mode='transpose'):

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels, 
            mode=self.up_mode)

        self.conv1 = jax.lax.cond(self.merge_mode == 'concat',
                                  conv3x3(2*self.out_channels, self.out_channels),
                                  conv3x3(self.out_channels, self.out_channels))
        
        # if self.merge_mode == 'concat':
        #     self.conv1 = conv3x3(2*self.out_channels, self.out_channels)
        # else:
        #     self.conv1 = conv3x3(self.out_channels, self.out_channels)

        self.conv2 = conv3x3(self.out_channels, self.out_channels)
        self.conv3 = conv3x3(self.out_channels, self.out_channels)


    def __call__(self, from_down, from_up):
        
        from_up = self.upconv(from_up)
        x = jax.lax.cond(self.merge_mode == 'concat',
                         lambda _ : jnp.concatenate([from_up, from_down], axis = 1),
                         lambda _ : from_up + from_down)
        
        # if self.merge_mode == 'concat':
        #     x = jnp.concatenate([from_up, from_down], axis = 1) # check axis channel is last for jax
        # else:
        #     x = from_up + from_down

        xskip = self.conv1(x)
        x = nnx.relu(self.conv2(xskip))
        x = nnx.relu(self.conv3(x) + xskip)
        # x = F.dropout(x)
        return x


In [50]:
model = DownConv(3, 3)
model(jnp.ones((1,3, 3)))

TypeError: true_fun and false_fun output must have identical types, got
DIFFERENT ShapedArray(float32[0,1,3]) vs. ShapedArray(float32[1,3,3]).

In [45]:
model = UpConv(3, 3)
model(jnp.ones((1, 3)), jnp.ones((1, 3)))

TypeError: Conv.__call__() missing 1 required positional argument: 'inputs'

In [None]:
class UN(nnx.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).
    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the tranpose convolution (specified by upmode='transpose')
    """
    def __init__(self, levels, channels=3, depth=5, 
                 start_filts=64, up_mode='transpose', 
                 merge_mode='add'):
        
        self.up_mode = jax.lax.cond((up_mode == 'transpose') | (up_mode == 'upsample'),
                                    lambda _ : up_mode,
                                    lambda _ : None)
        if self.up_mode is None:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        # if up_mode in ('transpose', 'upsample'):
        #     self.up_mode = up_mode
        # else:
        #     raise ValueError("\"{}\" is not a valid mode for "
        #                      "upsampling. Only \"transpose\" and "
        #                      "\"upsample\" are allowed.".format(up_mode))

        self.merge_mode = jax.lax.cond((merge_mode == 'concat') | (merge_mode == 'add'),
                                    lambda _ : merge_mode,
                                    lambda _ : None)
        if self.merge_mode is None:
            raise ValueError("\"{}\" is not a valid mode for"
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(up_mode))

        # if merge_mode in ('concat', 'add'):
        #     self.merge_mode = merge_mode
        # else:
        #     raise ValueError("\"{}\" is not a valid mode for"
        #                      "merging up and down paths. "
        #                      "Only \"concat\" and "
        #                      "\"add\" are allowed.".format(up_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.levels = levels
        self.channels = channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []
        ## vmap? 
        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.channels * self.levels if i == 0 else outs
            outs = self.start_filts*(2**i)
        #    outs = self.start_filts
            pooling = True if i < depth-1 else False

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
        #    outs = ins
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        self.conv_final = conv1x1(outs, self.channels)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            nnx.initializers.xavier_normal(m.weight)
            nnx.initializers.constant(m.bias, 0)


    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)


    def __call__(self, x):
        epsilon = 1
        
        stack = None
        ## use of fourier transform to avoid transform
        ## no matter what value it lives in the sine wave
        ## the input is represented by sin transformation of different frequencies
        factor = 10.0
        for i in range (self.levels):
            scale = x.clone()*(factor**(-i))
            scale = torch.sin(scale)
            if stack is None:
                stack = scale
            else:
                stack = torch.cat((stack,scale),1)
        
        x = stack
        
        encoder_outs = []
         
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)
        
        x = self.conv_final(x)
        return x



    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
        
        return {
           'optimizer': optimizer,
           'lr_scheduler': scheduler, 
           'monitor': 'val_loss'
        }
    
    
    def photonLoss(self,result, target):
        expEnergy = torch.exp(result)
        perImage =  -torch.mean(result*target, dim =(-1,-2,-3), keepdims = True )
        perImage += torch.log(torch.mean(expEnergy, dim =(-1,-2,-3), keepdims = True ))*torch.mean(target, dim =(-1,-2,-3), keepdims = True )
        return torch.mean(perImage)
    
    def MSELoss(self,result, target):
        expEnergy = torch.exp(result)
        expEnergy /= (torch.mean(expEnergy, dim =(-1,-2,-3), keepdims = True ))
        target = target / (torch.mean(target, dim =(-1,-2,-3), keepdims = True ))
        return torch.mean((expEnergy-target)**2)
    
    def training_step(self, batch, batch_idx):
        loss = self.photonLoss(self(batch[:,self.channels:,...]),batch[:,:self.channels,...] )
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.photonLoss(self(batch[:,self.channels:,...]),batch[:,:self.channels,...] )
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self.photonLoss(self(batch[:,self.channels:,...]),batch[:,:self.channels,...] )
        self.log("test_loss", loss)
