In [3]:
%reload_ext autoreload
%autoreload 2

from typing import Any
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections

%cd /hildafs/projects/phy230056p/junzhez/AI/maskgit
%ls
import maskgit
from maskgit.utils import visualize_images, read_image_from_url, draw_image_with_bbox, Bbox
from maskgit.inference import ImageNet_class_conditional_generator
from maskgit.nets import layers

/hildafs/projects/phy230056p/junzhez/AI/maskgit
[0m[38;5;33mcheckpoints[0m/     [38;5;33mgmmg[0m/  LICENSE   MaskGIT_demo.ipynb  requirements.txt
CONTRIBUTING.md  [38;5;33mimgs[0m/  [38;5;33mmaskgit[0m/  README.md


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


### Testing Resblock and Encoder

In [4]:
from maskgit.nets import layers
from maskgit.nets.vqgan_tokenizer import ResBlock, Encoder

In [6]:
from ml_collections import ConfigDict
def get_config():
    config = ConfigDict()
    config.vqvae = ConfigDict()
    config.vqvae.filters = 64
    config.vqvae.num_res_blocks = 2
    config.vqvae.channel_multipliers = [1, 2, 4]
    config.vqvae.embedding_dim = 128
    config.vqvae.conv_downsample = True
    config.vqvae.norm_type = 'BN'
    config.vqvae.activation_fn = 'relu'
    return config

config = get_config()

In [8]:
# Create Dummy Input Tensor
batch_size = 32
height = 256
width = 256
channels = 3
dummy_input = jnp.ones((batch_size, height, width, channels), dtype=jnp.float32)

# Initialize and Test the Encoder
encoder = Encoder(config=config, train=False)

# Initialize the parameters
rng = jax.random.PRNGKey(0)
params = encoder.init(rng, dummy_input)

# Apply the encoder to the dummy input
encoded_output = encoder.apply(params, dummy_input)

# Print the shapes of the input and output tensors
print("Input shape:", dummy_input.shape)
print("Output shape:", encoded_output.shape)

Input shape: (32, 256, 256, 3)
Output shape: (32, 64, 64, 128)


In [11]:
print('encoder params:', lambda x: f'{x.keys(), x.shape}', params)

encoder params: <function <lambda> at 0x1502f0119510> {'params': {'Conv_0': {'kernel': Array([[[[-0.10665826,  0.0840695 , -0.06965522, ...,  0.15173201,
          -0.4253455 ,  0.09675467],
         [ 0.17379448, -0.1942932 , -0.37668937, ...,  0.11859114,
          -0.1405779 , -0.20874853],
         [ 0.12931602,  0.20022647, -0.24547762, ..., -0.16336353,
          -0.20042422, -0.17801431]],

        [[ 0.04695836, -0.16743578, -0.28042   , ...,  0.16960709,
          -0.35830006,  0.07735524],
         [-0.05446373,  0.32084453,  0.16373922, ...,  0.00248188,
           0.13764007, -0.28850147],
         [ 0.23693158,  0.09874827,  0.18078464, ...,  0.28776744,
          -0.33542386,  0.09036874]],

        [[ 0.2505507 , -0.14395808, -0.00970777, ...,  0.04533508,
          -0.13314351, -0.1375708 ],
         [ 0.28502318, -0.07675822,  0.10134856, ...,  0.2452984 ,
          -0.06979237,  0.21082537],
         [-0.10497877, -0.36996925, -0.3380151 , ..., -0.24047703,
          

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import ml_collections

# You might need to define or import the layers.get_norm_layer function
# For simplicity, we'll define a basic one here
class NormLayer(nn.Module):
    def __call__(self, x):
        return x

def get_norm_layer(train, dtype, norm_type):
    return NormLayer



In [None]:
class ResBlock(nn.Module):
    """Basic Residual Block."""
    filters: int
    norm_fn: Any
    conv_fn: Any
    dtype: int = jnp.float32
    activation_fn: Any = nn.relu
    use_conv_shortcut: bool = False

    @nn.compact
    def __call__(self, x):
        input_dim = x.shape[-1]
        residual = x
        x = self.norm_fn()(x)
        x = self.activation_fn(x)
        x = self.conv_fn(self.filters, kernel_size=(3, 3), use_bias=False)(x)
        x = self.norm_fn()(x)
        x = self.activation_fn(x)
        x = self.conv_fn(self.filters, kernel_size=(3, 3), use_bias=False)(x)

        if input_dim != self.filters:
            if self.use_conv_shortcut:
                residual = self.conv_fn(self.filters, kernel_size=(3, 3), use_bias=False)(x)
            else:
                residual = self.conv_fn(self.filters, kernel_size=(1, 1), use_bias=False)(x)
        return x + residual

# Example input
x = jnp.ones((1, 32, 32, 64))  # Batch size 1, 32x32 image, 64 channels
res_block = ResBlock(filters=128, norm_fn=NormLayer, conv_fn=nn.Conv)
variables = res_block.init(jax.random.PRNGKey(0), x)
output = res_block.apply(variables, x)
print("ResBlock output shape:", output.shape)


In [None]:
class Encoder(nn.Module):
    """Encoder Blocks."""
    config: ml_collections.ConfigDict
    train: bool
    dtype: int = jnp.float32

    def setup(self):
        self.filters = self.config.vqvae.filters
        self.num_res_blocks = self.config.vqvae.num_res_blocks
        self.channel_multipliers = self.config.vqvae.channel_multipliers
        self.embedding_dim = self.config.vqvae.embedding_dim
        self.conv_downsample = self.config.vqvae.conv_downsample
        self.norm_type = self.config.vqvae.norm_type
        if self.config.vqvae.activation_fn == "relu":
            self.activation_fn = nn.relu
        elif self.config.vqvae.activation_fn == "swish":
            self.activation_fn = nn.swish
        else:
            raise NotImplementedError

    @nn.compact
    def __call__(self, x):
        conv_fn = nn.Conv
        norm_fn = get_norm_layer(train=self.train, dtype=self.dtype, norm_type=self.norm_type)
        block_args = dict(
            norm_fn=norm_fn,
            conv_fn=conv_fn,
            dtype=self.dtype,
            activation_fn=self.activation_fn,
            use_conv_shortcut=False,
        )
        x = conv_fn(self.filters, kernel_size=(3, 3), use_bias=False)(x)
        num_blocks = len(self.channel_multipliers)
        for i in range(num_blocks):
            filters = self.filters * self.channel_multipliers[i]
            for _ in range(self.num_res_blocks):
                x = ResBlock(filters, **block_args)(x)
            if i < num_blocks - 1:
                if self.conv_downsample:
                    x = conv_fn(filters, kernel_size=(4, 4), strides=(2, 2))(x)
                else:
                    x = nn.avg_pool(x, (2, 2), strides=(2, 2))
        for _ in range(self.num_res_blocks):
            x = ResBlock(filters, **block_args)(x)
        x = norm_fn()(x)
        x = self.activation_fn(x)
        x = conv_fn(self.embedding_dim, kernel_size=(1, 1))(x)
        return x

# Example config
config = ml_collections.ConfigDict()
config.vqvae = ml_collections.ConfigDict()
config.vqvae.filters = 64
config.vqvae.num_res_blocks = 2
config.vqvae.channel_multipliers = [1, 2, 4]
config.vqvae.embedding_dim = 256
config.vqvae.conv_downsample = True
config.vqvae.norm_type = 'batch'
config.vqvae.activation_fn = 'relu'

# Example input
x = jnp.ones((1, 64, 64, 3))  # Batch size 1, 64x64 image, 3 channels
encoder = Encoder(config=config, train=True)
variables = encoder.init(jax.random.PRNGKey(0), x)
output = encoder.apply(variables, x)
print("Encoder output shape:", output.shape)


In [None]:
class Decoder(nn.Module):
    """Decoder Blocks."""
    config: ml_collections.ConfigDict
    train: bool
    output_dim: int = 3
    dtype: Any = jnp.float32

    def setup(self):
        self.filters = self.config.vqvae.filters
        self.num_res_blocks = self.config.vqvae.num_res_blocks
        self.channel_multipliers = self.config.vqvae.channel_multipliers
        self.norm_type = self.config.vqvae.norm_type
        if self.config.vqvae.activation_fn == "relu":
            self.activation_fn = nn.relu
        elif self.config.vqvae.activation_fn == "swish":
            self.activation_fn = nn.swish
        else:
            raise NotImplementedError

    @nn.compact
    def __call__(self, x):
        conv_fn = nn.Conv
        norm_fn = get_norm_layer(train=self.train, dtype=self.dtype, norm_type=self.norm_type)
        block_args = dict(
            norm_fn=norm_fn,
            conv_fn=conv_fn,
            dtype=self.dtype,
            activation_fn=self.activation_fn,
            use_conv_shortcut=False,
        )
        num_blocks = len(self.channel_multipliers)
        filters = self.filters * self.channel_multipliers[-1]
        x = conv_fn(filters, kernel_size=(3, 3), use_bias=True)(x)
        for _ in range(self.num_res_blocks):
            x = ResBlock(filters, **block_args)(x)
        for i in reversed(range(num_blocks)):
            filters = self.filters * self.channel_multipliers[i]
            for _ in range(self.num_res_blocks):
                x = ResBlock(filters, **block_args)(x)
            if i > 0:
                x = nn.ConvTranspose(filters, kernel_size=(3, 3), strides=(2, 2))(x)
        x = norm_fn()(x)
        x = self.activation_fn(x)
        x = conv_fn(self.output_dim, kernel_size=(3, 3))(x)
        return x

# Example input
x = jnp.ones((1, 8, 8, 256))  # Batch size 1, 8x8 feature map, 256 channels
decoder = Decoder(config=config, train=True)
variables = decoder.init(jax.random.PRNGKey(0), x)
output = decoder.apply(variables, x)
print("Decoder output shape:", output.shape)
