In [6]:
%reload_ext autoreload
%autoreload 2

from typing import Any
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import random
import ml_collections

# NOTE(review): hard-coded absolute path to one user's checkout on a specific
# cluster; other users will presumably need to edit this line before running.
%cd /hildafs/projects/phy230056p/junzhez/AI/maskgit
# List the repository root to confirm the working directory.
%ls
# Imported after the %cd above — presumably the package is resolved from the
# working directory rather than from an installed distribution; TODO confirm.
import maskgit
from maskgit.nets import layers

# Root PRNG key built from a fixed seed, so the random draws in the cells
# below are reproducible from run to run.
key = jax.random.PRNGKey(42)
# Report which backend devices JAX will execute on.
print(jax.devices())

/hildafs/projects/phy230056p/junzhez/AI/maskgit
[0m[38;5;33mcheckpoints[0m/     [38;5;33mgmmg[0m/  LICENSE   MaskGIT_demo.ipynb  requirements.txt
CONTRIBUTING.md  [38;5;33mimgs[0m/  [38;5;33mmaskgit[0m/  README.md
[cuda(id=0)]


### _l2_normalize function
- each row of x is normalized to have unit L2 norm
$$
\texttt{l2normalize}(x) = \frac{x}{\sqrt{\sum_{i=1}^{n} x_i^2 + \epsilon}}
$$

In [10]:
# Demo of layers._l2_normalize: draw a small random matrix, normalize it along
# the last axis, and print the per-row L2 norms next to the result so that the
# "unit norm" property can be read off directly.
x = random.normal(key, shape=(2, 3))
normalized_x = layers._l2_normalize(x, axis=-1)

print("Original array:")
print(x)
print("L2-normalized array:")
# The second printed value is the L2 norm of each row of the normalized array.
print(normalized_x, jnp.linalg.norm(normalized_x, axis=-1))

Original array:
[[ 0.6122652   1.1225883   1.1373317 ]
 [-0.8127325  -0.890405    0.12623145]]
L2-normalized array:
[[ 0.35777485  0.65598017  0.6645954 ]
 [-0.6704925  -0.73457116  0.1041391 ]] [0.99999994 0.99999994]


### get_norm_layer function
- returns a normalization layer, depending on the `norm_type` specified
- Batch Normalization (BN)
$$
y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta
$$
- Layer Normalization (LN)
$$
y = \frac{x - \text{mean}(x)}{\sqrt{\text{var}(x) + \epsilon}} \cdot \gamma + \beta
$$
- Group Normalization (GN)
$$
y = \frac{x - \text{mean}_{\text{group}}(x)}{\sqrt{\text{var}_{\text{group}}(x) + \epsilon}} \cdot \gamma + \beta
$$

- The input tensor `x` typically has shape `(B, H, W, C)` or `(B, C, H, W)`, where `B` is the batch size, `H` is the height, `W` is the width, and `C` is the number of channels.
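- $\mu$ and $\sigma^2$ are the mean and variance of the input tensor `x`; the three layers differ only in which axes these statistics are computed over. BN reduces over the batch (and spatial) axes separately for each channel, LN reduces over the feature axis of each individual sample, and GN reduces within each group of channels of each individual sample. $\gamma$ and $\beta$ are a learned scale and bias, and $\epsilon$ is a small constant added for numerical stability.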



In [22]:
# Dummy channels-last batch; in this cell it is only used to show the layout
# that the normalization layer is meant to receive.
x = random.normal(key, shape=(2, 4, 4, 3))  # (batch_size, height, width, channels)
print("Input tensor shape")
print(x.shape, '(shape: batch_size x height x width x channels)')

# get_norm_layer hands back a constructor rather than a module instance:
# calling it with no arguments builds the configured layer, whose repr is
# printed to inspect the resulting hyper-parameters.
norm_layer_fn = layers.get_norm_layer(norm_type='BN', train=True, dtype=jnp.float32)
norm_layer = norm_layer_fn()
print(norm_layer)

Input tensor shape
(2, 4, 4, 3) (shape: batch_size x height x width x channels)
BatchNorm(
    # attributes
    use_running_average = False
    axis = -1
    momentum = 0.9
    epsilon = 1e-05
    dtype = float32
    param_dtype = float32
    use_bias = True
    use_scale = True
    bias_init = zeros
    scale_init = ones
    axis_name = None
    axis_index_groups = None
    use_fast_variance = True
    force_float32_reductions = True
)


### tensorflow_style_avg_pooling function
- This function performs average pooling in a manner similar to TensorFlow, excluding padding cells from the average calculation.
- Pooling is a layer in convolutional neural networks (CNNs) that reduces the spatial dimensions of input data while keeping the most important information. Pooling layers are also known as downsample layers.
- **Mathematical Expression:**
$$
y_{i,j} = \frac{1}{|P_{i,j}|} \sum_{(m,n) \in P_{i,j}} x_{m,n}
$$
where $P_{i,j}$ is the set of input positions covered by the pooling window that produces output element $(i,j)$, excluding padding cells, and $|P_{i,j}|$ is the number of elements in $P_{i,j}$.

**In this context, more specifically:**
- Given:
  - Input tensor `x` with shape `(N, H, W, C)`
  - Pooling window shape `(hw, ww)`
  - Strides `(hs, ws)`. Physically, this means the window moves `hs` steps in the height dimension and `ws` steps in the width dimension.
  - Padding mode: either "SAME" or "VALID"
- The function does two main operations:
    1. **Sum the elements within each pooling window**: This is done using `jax.lax.reduce_window` with `jax.lax.add`.
    2. **Count the number of elements within each pooling window**: This is also done using `jax.lax.reduce_window` with `jax.lax.add` on a tensor of ones.
- **Formula Steps:**
    1. **Pooling Window Sum (pool_sum)**:
$$
\text{pool-sum}_{i, j, k, l} = \sum_{m=0}^{h_w-1} \sum_{n=0}^{w_w-1} x_{i, (j \cdot h_s + m), (k \cdot w_s + n), l}
$$
where $i$ indexes the batch, $j$ and $k$ index the spatial dimensions, and $l$ indexes the channels. This sum considers the elements within the pooling window.
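    2. **Pooling Window Denominator (pool_denom)**:
$$
\text{pool-denom}_{i, j, k, l} = \sum_{m=0}^{h_w-1} \sum_{n=0}^{w_w-1} \mathbf{1}\left[\, j \cdot h_s + m < H \ \text{and}\ k \cdot w_s + n < W \,\right]
$$
where $\mathbf{1}[\cdot]$ is 1 when the window position falls inside the input and 0 when it falls on a padding cell (padding cells likewise contribute 0 to pool-sum). The denominator therefore counts only the real elements within each pooling window, excluding any padding cells; for a window that lies entirely inside the input it equals $h_w \cdot w_w$.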
    3. **Average Pooling Calculation**:
$$
y_{i, j, k, l} = \frac{\text{pool-sum}_{i, j, k, l}}{\text{pool-denom}_{i, j, k, l}}
$$
Here, $y$ is the output tensor after average pooling.

In [39]:
# Test tensorflow_style_avg_pooling on a deterministic ramp (0..31) so that the
# pooled values can be checked by hand.
x = jnp.arange(32).reshape(1, 4, 4, 2)
pooled_x = layers.tensorflow_style_avg_pooling(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

# A 2x2 window moved with stride 2 halves each spatial dimension. The 4x4
# input is an exact multiple of the stride, so every window holds four real
# elements and each output is a plain 2x2 mean.
assert pooled_x.shape == (1, 2, 2, 2)
# Channel 0 holds the even numbers 0, 2, ..., 30; e.g. (0 + 2 + 8 + 10) / 4 = 5.
assert jnp.array_equal(pooled_x[0, :, :, 0], jnp.array([[5.0, 9.0], [21.0, 25.0]]))

print("Original tensor:, dimensions: batch_size x height x width x channels")
print(x.shape)
print("Pooled tensor:", "dimensions: batch_size x height x width x channels")
print(pooled_x.shape)
print('first channel of x')
print(x[0, :, :, 0])
print('first channel of pooled_x')
print(pooled_x[0, :, :, 0])

Original tensor:, dimensions: batch_size x height x width x channels
(1, 4, 4, 2)
Pooled tensor: dimensions: batch_size x height x width x channels
(1, 2, 2, 2)
first channel of x
[[ 0  2  4  6]
 [ 8 10 12 14]
 [16 18 20 22]
 [24 26 28 30]]
first channel of pooled_x
[[ 5.  9.]
 [21. 25.]]


### Upsample and Downsample
- Upsample: Increase the spatial dimensions of the input tensor
$$
y = \text{resize}(x, (n, h \cdot \text{factor}, w \cdot \text{factor}, c))
$$
- Downsample: Decrease the spatial dimensions of the input tensor
$$
y_{i,j} = \frac{1}{4} \sum_{(m,n) \in P_{i,j}} x_{m,n}
$$
- In other words, downsampling is `tensorflow_style_avg_pooling` with `hw=2, ww=2, hs=2, ws=2`: a 2x2 window moved with a 2x2 stride.

In [43]:
# Exercise both resampling helpers (upsample and dsample) on the same
# deterministic ramp so that their effect on the values is easy to follow.
x = jnp.arange(32).reshape(1, 4, 4, 2)
upsampled_x = layers.upsample(x, factor=2)
downsampled_x = layers.dsample(x)

# Shapes first: the spatial dimensions should double for the upsampled tensor
# and halve for the downsampled one.
print("Original tensor:, dimensions: batch_size x height x width x channels")
print(x.shape)

print("Upsampled tensor:", "dimensions: batch_size x height x width x channels")
print(upsampled_x.shape)

print("Downsampled tensor:", "dimensions: batch_size x height x width x channels")
print(downsampled_x.shape)

# Then the values of channel 0 for the single batch element.
print('first channel of x')
print(x[0, :, :, 0])
print('first channel of upsampled_x')
print(upsampled_x[0, :, :, 0])
print('first channel of downsampled_x')
print(downsampled_x[0, :, :, 0])

Original tensor:, dimensions: batch_size x height x width x channels
(1, 4, 4, 2)
Upsampled tensor: dimensions: batch_size x height x width x channels
(1, 8, 8, 2)
Downsampled tensor: dimensions: batch_size x height x width x channels
(1, 2, 2, 2)
first channel of x
[[ 0  2  4  6]
 [ 8 10 12 14]
 [16 18 20 22]
 [24 26 28 30]]
first channel of upsampled_x
[[ 0  0  2  2  4  4  6  6]
 [ 0  0  2  2  4  4  6  6]
 [ 8  8 10 10 12 12 14 14]
 [ 8  8 10 10 12 12 14 14]
 [16 16 18 18 20 20 22 22]
 [16 16 18 18 20 20 22 22]
 [24 24 26 26 28 28 30 30]
 [24 24 26 26 28 28 30 30]]
first channel of downsampled_x
[[ 5.  9.]
 [21. 25.]]
