Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ MaxDiffusion supports

We recommend starting with a single TPU host and then moving to multihost.

Minimum requirements: Ubuntu Version 22.04, Python 3.10 and Tensorflow >= 2.12.0.
Minimum requirements: Ubuntu Version 22.04, Python 3.12 and Tensorflow >= 2.12.0.

## Getting Started:

Expand Down
14 changes: 11 additions & 3 deletions docs/getting_started/first_run.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,27 @@ multiple hosts.

1. [Create and SSH to a single-host TPU (v6-8). ](https://cloud.google.com/tpu/docs/users-guide-tpu-vm#creating_a_cloud_tpu_vm_with_gcloud)
* You can find here [here](https://cloud.google.com/tpu/docs/regions-zones) the list of zones that support the v6(Trillium) TPUs
* We recommend using the base VM image "v2-alpha-tpuv6e", which meets the version requirements: Ubuntu Version 22.04, Python 3.10 and Tensorflow >= 2.12.0
* We recommend using the base VM image "v2-alpha-tpuv6e", which meets the version requirements: Ubuntu Version 22.04, Python 3.12 and Tensorflow >= 2.12.0

1. Clone MaxDiffusion in your TPU VM.
```bash
```
git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
cd maxdiffusion
```

1. Within the root directory of the MaxDiffusion `git` repo, install dependencies by running:
```bash
```
# If a Python 3.12+ virtual environment doesn't already exist, you'll need to run the install command twice.
bash setup.sh MODE=stable DEVICE=tpu
```

1. Active your virtual environment:
```
# Replace with your virtual environment name if not using this default name
venv_name="maxdiffusion_venv"
source ~/$venv_name/bin/activate
```

## Getting Starting: Multihost development

[GKE, recommended] [Running MaxDiffusion with xpk](run_maxdiffusion_via_xpk.md) - Quick Experimentation and Production support
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ grain
google-cloud-storage>=2.17.0
absl-py
datasets
flax>=0.10.2
flax>=0.11.0
optax>=0.2.3
torch>=2.6.0
torchvision>=0.20.1
Expand Down
40 changes: 40 additions & 0 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,46 @@
set -e
export DEBIAN_FRONTEND=noninteractive

echo "Checking Python version..."
# This command will fail if the Python version is less than 3.12
if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; then
# If the command fails, print an error
CURRENT_VERSION=$(python3 --version 2>&1) # Get the full version string
echo -e "\n\e[31mERROR: Outdated Python Version! You are currently using $CURRENT_VERSION, but MaxDiffusion requires Python version 3.12 or higher.\e[0m"
# Ask the user if they want to create a virtual environment with uv
read -p "Would you like to create a Python 3.12 virtual environment using uv? (y/n) " -n 1 -r
echo # Move to a new line after input
if [[ $REPLY =~ ^[Yy]$ ]]; then
# Check if uv is installed first; if not, install uv
if ! command -v uv &> /dev/null; then
pip install uv
fi
maxdiffusion_dir=$(pwd)
cd
# Ask for the venv name
read -p "Please enter a name for your new virtual environment (default: maxdiffusion_venv): " venv_name
# Use a default name if the user provides no input
if [ -z "$venv_name" ]; then
venv_name="maxdiffusion_venv"
echo "No name provided. Using default name: '$venv_name'"
fi
echo "Creating virtual environment '$venv_name' with Python 3.12..."
uv venv --python 3.12 "$venv_name" --seed
printf '%s\n' "$(realpath -- "$venv_name")" >> /tmp/venv_created
echo -e "\n\e[32mVirtual environment '$venv_name' created successfully!\e[0m"
echo "To activate it, run the following command:"
echo -e "\e[33m source ~/$venv_name/bin/activate\e[0m"
echo "After activating the environment, please re-run this script."
cd $maxdiffusion_dir
else
echo "Exiting. Please upgrade your Python environment to continue."
fi
# Exit the script since the initial Python check failed
exit 1
fi
echo "Python version check passed. Continuing with script."
echo "--------------------------------------------------"

(sudo bash || bash) <<'EOF'
mkdir -p /etc/needrestart/conf.d
echo '$nrconf{restart} = "a";' > /etc/needrestart/conf.d/99-noninteractive.conf
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,8 +795,8 @@ def __init__(

self.drop_out = nnx.Dropout(dropout)

self.norm_q = None
self.norm_k = None
self.norm_q = nnx.data(None)
self.norm_k = nnx.data(None)
if qk_norm is not None:
self.norm_q = nnx.RMSNorm(
num_features=self.inner_dim,
Expand Down
12 changes: 7 additions & 5 deletions src/maxdiffusion/models/wan/autoencoder_kl_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def __init__(
):
self.dim = dim
self.mode = mode
self.time_conv = None
self.time_conv = nnx.data(None)

if mode == "upsample2d":
self.resample = nnx.Sequential(
Expand Down Expand Up @@ -554,8 +554,8 @@ def __init__(
precision=precision,
)
)
self.attentions = attentions
self.resnets = resnets
self.attentions = nnx.data(attentions)
self.resnets = nnx.data(resnets)

def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
x = self.resnets[0](x, feat_cache, feat_idx)
Expand Down Expand Up @@ -601,10 +601,10 @@ def __init__(
)
)
current_dim = out_dim
self.resnets = resnets
self.resnets = nnx.data(resnets)

# Add upsampling layer if needed.
self.upsamplers = None
self.upsamplers = nnx.data(None)
if upsample_mode is not None:
self.upsamplers = [
WanResample(
Expand Down Expand Up @@ -710,6 +710,7 @@ def __init__(
)
)
scale /= 2.0
self.down_blocks = nnx.data(self.down_blocks)

# middle_blocks
self.mid_block = WanMidBlock(
Expand Down Expand Up @@ -873,6 +874,7 @@ def __init__(
# Update scale for next iteration
if upsample_mode is not None:
scale *= 2.0
self.up_blocks = nnx.data(self.up_blocks)

# output blocks
self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim

self.act_fn = None
self.act_fn = nnx.data(None)
if activation_fn == "gelu-approximate":
self.act_fn = ApproximateGELU(
rngs=rngs, dim_in=dim, dim_out=inner_dim, bias=bias, dtype=dtype, weights_dtype=weights_dtype, precision=precision
Expand Down
Loading