diff --git a/README.md b/README.md index c7e621de..33ae7a51 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ MaxDiffusion supports We recommend starting with a single TPU host and then moving to multihost. -Minimum requirements: Ubuntu Version 22.04, Python 3.10 and Tensorflow >= 2.12.0. +Minimum requirements: Ubuntu Version 22.04, Python 3.12 and Tensorflow >= 2.12.0. ## Getting Started: diff --git a/docs/getting_started/first_run.md b/docs/getting_started/first_run.md index 7f4ef9c0..aae4736a 100644 --- a/docs/getting_started/first_run.md +++ b/docs/getting_started/first_run.md @@ -10,19 +10,27 @@ multiple hosts. 1. [Create and SSH to a single-host TPU (v6-8). ](https://cloud.google.com/tpu/docs/users-guide-tpu-vm#creating_a_cloud_tpu_vm_with_gcloud) * You can find here [here](https://cloud.google.com/tpu/docs/regions-zones) the list of zones that support the v6(Trillium) TPUs -* We recommend using the base VM image "v2-alpha-tpuv6e", which meets the version requirements: Ubuntu Version 22.04, Python 3.10 and Tensorflow >= 2.12.0 +* We recommend using the base VM image "v2-alpha-tpuv6e", which meets the version requirements: Ubuntu Version 22.04, Python 3.12 and Tensorflow >= 2.12.0 1. Clone MaxDiffusion in your TPU VM. -```bash +``` git clone https://github.com/AI-Hypercomputer/maxdiffusion.git cd maxdiffusion ``` 1. Within the root directory of the MaxDiffusion `git` repo, install dependencies by running: -```bash +``` +# If a Python 3.12+ virtual environment doesn't already exist, you'll need to run the install command twice. bash setup.sh MODE=stable DEVICE=tpu ``` +1. Active your virtual environment: +``` +# Replace with your virtual environment name if not using this default name +venv_name="maxdiffusion_venv" +source ~/$venv_name/bin/activate +``` + ## Getting Starting: Multihost development [GKE, recommended] [Running MaxDiffusion with xpk](run_maxdiffusion_via_xpk.md) - Quick Experimentation and Production support diff --git a/requirements.txt b/requirements.txt index d51653fa..6d5e2902 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ grain google-cloud-storage>=2.17.0 absl-py datasets -flax>=0.10.2 +flax>=0.11.0 optax>=0.2.3 torch>=2.6.0 torchvision>=0.20.1 diff --git a/setup.sh b/setup.sh index 15932df8..24c6b80b 100644 --- a/setup.sh +++ b/setup.sh @@ -23,6 +23,46 @@ set -e export DEBIAN_FRONTEND=noninteractive +echo "Checking Python version..." +# This command will fail if the Python version is less than 3.12 +if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; then + # If the command fails, print an error + CURRENT_VERSION=$(python3 --version 2>&1) # Get the full version string + echo -e "\n\e[31mERROR: Outdated Python Version! You are currently using $CURRENT_VERSION, but MaxDiffusion requires Python version 3.12 or higher.\e[0m" + # Ask the user if they want to create a virtual environment with uv + read -p "Would you like to create a Python 3.12 virtual environment using uv? (y/n) " -n 1 -r + echo # Move to a new line after input + if [[ $REPLY =~ ^[Yy]$ ]]; then + # Check if uv is installed first; if not, install uv + if ! command -v uv &> /dev/null; then + pip install uv + fi + maxdiffusion_dir=$(pwd) + cd + # Ask for the venv name + read -p "Please enter a name for your new virtual environment (default: maxdiffusion_venv): " venv_name + # Use a default name if the user provides no input + if [ -z "$venv_name" ]; then + venv_name="maxdiffusion_venv" + echo "No name provided. Using default name: '$venv_name'" + fi + echo "Creating virtual environment '$venv_name' with Python 3.12..." + uv venv --python 3.12 "$venv_name" --seed + printf '%s\n' "$(realpath -- "$venv_name")" >> /tmp/venv_created + echo -e "\n\e[32mVirtual environment '$venv_name' created successfully!\e[0m" + echo "To activate it, run the following command:" + echo -e "\e[33m source ~/$venv_name/bin/activate\e[0m" + echo "After activating the environment, please re-run this script." + cd $maxdiffusion_dir + else + echo "Exiting. Please upgrade your Python environment to continue." + fi + # Exit the script since the initial Python check failed + exit 1 +fi +echo "Python version check passed. Continuing with script." +echo "--------------------------------------------------" + (sudo bash || bash) <<'EOF' mkdir -p /etc/needrestart/conf.d echo '$nrconf{restart} = "a";' > /etc/needrestart/conf.d/99-noninteractive.conf diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 3cbb0cce..df152133 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -795,8 +795,8 @@ def __init__( self.drop_out = nnx.Dropout(dropout) - self.norm_q = None - self.norm_k = None + self.norm_q = nnx.data(None) + self.norm_k = nnx.data(None) if qk_norm is not None: self.norm_q = nnx.RMSNorm( num_features=self.inner_dim, diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index 19244f72..0226a859 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -225,7 +225,7 @@ def __init__( ): self.dim = dim self.mode = mode - self.time_conv = None + self.time_conv = nnx.data(None) if mode == "upsample2d": self.resample = nnx.Sequential( @@ -554,8 +554,8 @@ def __init__( precision=precision, ) ) - self.attentions = attentions - self.resnets = resnets + self.attentions = nnx.data(attentions) + self.resnets = nnx.data(resnets) def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]): x = self.resnets[0](x, feat_cache, feat_idx) @@ -601,10 +601,10 @@ def __init__( ) ) current_dim = out_dim - self.resnets = resnets + self.resnets = nnx.data(resnets) # Add upsampling layer if needed. - self.upsamplers = None + self.upsamplers = nnx.data(None) if upsample_mode is not None: self.upsamplers = [ WanResample( @@ -710,6 +710,7 @@ def __init__( ) ) scale /= 2.0 + self.down_blocks = nnx.data(self.down_blocks) # middle_blocks self.mid_block = WanMidBlock( @@ -873,6 +874,7 @@ def __init__( # Update scale for next iteration if upsample_mode is not None: scale *= 2.0 + self.up_blocks = nnx.data(self.up_blocks) # output blocks self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 718b5015..48ed7b8e 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -209,7 +209,7 @@ def __init__( inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - self.act_fn = None + self.act_fn = nnx.data(None) if activation_fn == "gelu-approximate": self.act_fn = ApproximateGELU( rngs=rngs, dim_in=dim, dim_out=inner_dim, bias=bias, dtype=dtype, weights_dtype=weights_dtype, precision=precision