Run wan2.2 model on GPU A100 #412

@kliuchevsky

Is it possible to run wan2.2 on an A100 GPU? So far I have run into this problem:
/mnt/nvme0/abeliaev/projects/wan/maxdiffusion$ CUDA_VISIBLE_DEVICES=2 && HF_HUB_CACHE=/mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ && python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=40 num_frames=81 width=832 height=480 jax_cache_dir=/mnt/nvme0/abeliaev/projects/wan/maxdiffusion/jax_cache/ per_device_batch_size=1 ici_data_parallelism=1 ici_context_parallelism=1 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=/mnt/nvme0/abeliaev/projects/wan/maxdiffusion/output fps=16 flash_min_seq_length=0 skip_jax_distributed_system=True flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=42
✓ Mock Transformer Engine loaded (inference mode)
Skipping jax distributed system due to skip_jax_distributed_system=True flag.
Adding sequence sharding to q and kv if not already present because attention='flash' requires it or attention_sharding_uniform=True is set.
Initial logical axis rules: [('batch', ('data', 'fsdp')), ('activation_batch', ('data', 'fsdp')), ('activation_self_attn_heads', ('context', 'tensor')), ('activation_cross_attn_q_length', ('context', 'tensor')), ('activation_length', 'context'), ('activation_heads', 'tensor'), ('mlp', 'tensor'), ('embed', ('context', 'fsdp')), ('heads', 'tensor'), ('norm', 'tensor'), ('conv_batch', ('data', 'context', 'fsdp')), ('out_channels', 'tensor'), ('conv_out', 'context')]
Adding key/value sequence axis rule ('activation_kv_length', 'context')
Adding sequence parallel attention axis rule ['activation_self_attn_heads', None]
Adding sequence parallel attention axis rule ['activation_self_attn_q_length', 'context']
Adding sequence parallel attention axis rule ['activation_self_attn_kv_length', None]
Adding sequence parallel attention axis rule ['activation_cross_attn_heads', None]
Adding sequence parallel attention axis rule ['activation_cross_attn_q_length', 'context']
Adding sequence parallel attention axis rule ['activation_cross_attn_kv_length', None]
Final logical axis rules: (['activation_self_attn_heads', None], ['activation_self_attn_q_length', 'context'], ['activation_self_attn_kv_length', None], ['activation_cross_attn_heads', None], ['activation_cross_attn_q_length', 'context'], ['activation_cross_attn_kv_length', None], ('batch', ('data', 'fsdp')), ('activation_batch', ('data', 'fsdp')), ('activation_self_attn_heads', ('context', 'tensor')), ('activation_cross_attn_q_length', ('context', 'tensor')), ('activation_length', 'context'), ('activation_heads', 'tensor'), ('mlp', 'tensor'), ('embed', ('context', 'fsdp')), ('heads', 'tensor'), ('norm', 'tensor'), ('conv_batch', ('data', 'context', 'fsdp')), ('out_channels', 'tensor'), ('conv_out', 'context'), ('activation_kv_length', 'context'))
Failed to find host bounds for accelerator type: WARNING: could not determine TPU accelerator type, please set env var TPU_ACCELERATOR_TYPE manually, otherwise libtpu.so may not properly initialize.
WARNING: Logging before InitGoogle() is written to STDERR
E0000 00:00:1779285642.983642 3181942 common_lib.cc:530] INVALID_ARGUMENT: Error: unexpected worker hostname 'WARNING: could not determine TPU worker hostnames or IP addresses' from env var TPU_WORKER_HOSTNAMES. Expecting a valid hostname or IP address without port number, or hostname:port:address triple. (Full TPU workers' addr string: WARNING: could not determine TPU worker hostnames or IP addresses, please set env var TPU_WORKER_HOSTNAMES manually, otherwise libtpu.so may not properly initialize.)
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/libtpu_init_utils.cc:310
INFO:2026-05-20 14:00:43,117:jax._src.xla_bridge:834: Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: No jellyfish device found.
I0520 14:00:43.117704 133612851323072 xla_bridge.py:834] Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: No jellyfish device found.
Config param activations_dtype: bfloat16
Config param adam_b1: 0.9
Config param adam_b2: 0.999
Config param adam_eps: 1e-08
Config param adam_weight_decay: 0.0
Config param allow_split_physical_axes: False
Config param attention: flash
Config param attention_sharding_uniform: True
Config param base_output_directory:
Config param boundary_ratio: 0.875
Config param cache_latents_text_encoder_outputs: True
Config param caption_column: text
Config param center_crop: False
Config param checkpoint_dir: /mnt/nvme0/abeliaev/projects/wan/maxdiffusion/output/wan-inference-testing-720p/checkpoints/
Config param checkpoint_every: -1
Config param compile_topology_num_slices: -1
Config param controlnet_conditioning_scale: 0.5
Config param controlnet_from_pt: True
Config param controlnet_image: https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png
Config param controlnet_model_name_or_path: diffusers/controlnet-canny-sdxl-1.0
Config param data_sharding: (('data', 'fsdp', 'context', 'tensor'),)
Config param dataset_config_name:
Config param dataset_name: diffusers/pokemon-gpt4-captions
Config param dataset_save_location:
Config param dataset_type: tfrecord
Config param dcn_context_parallelism: -1
Config param dcn_data_parallelism: 1
Config param dcn_fsdp_parallelism: 1
Config param dcn_tensor_parallelism: 1
Config param diffusion_scheduler_config: {'_class_name': 'FlaxEulerDiscreteScheduler', 'prediction_type': 'epsilon', 'rescale_zero_terminal_snr': False, 'timestep_spacing': 'trailing'}
Config param do_classifier_free_guidance: True
Config param dropout: 0.0
Config param enable_data_shuffling: True
Config param enable_eval_timesteps: False
Config param enable_generate_video_for_eval: False
Config param enable_jax_named_scopes: False
Config param enable_lora: False
Config param enable_ml_diagnostics: False
Config param enable_mllog: False
Config param enable_ondemand_xprof: False
Config param enable_profiler: True
Config param enable_single_replica_ckpt_restoring: False
Config param enable_ssim: False
Config param eval_data_dir:
Config param eval_every: -1
Config param eval_max_number_of_samples_in_bucket: 60
Config param flash_block_sizes: {'block_q': 3024, 'block_kv_compute': 1024, 'block_kv': 2048, 'block_q_dkv': 3024, 'block_kv_dkv': 2048, 'block_kv_dkv_compute': 2048, 'block_q_dq': 3024, 'block_kv_dq': 2048}
Config param flash_min_seq_length: 0
Config param flow_shift: 5.0
Config param fps: 16
Config param from_pt: True
Config param gcs_metrics: False
Config param global_batch_size: 0
Config param global_batch_size_to_load: 1
Config param global_batch_size_to_train_on: 1
Config param guidance_rescale: 0.0
Config param guidance_scale_high: 4.0
Config param guidance_scale_low: 3.0
Config param hardware: tpu
Config param height: 480
Config param hf_access_token: None
Config param hf_data_dir:
Config param hf_train_files: None
Config param ici_context_parallelism: 1
Config param ici_data_parallelism: 1
Config param ici_fsdp_parallelism: 1
Config param ici_tensor_parallelism: 1
Config param image_column: image
Config param jax_cache_dir: /mnt/nvme0/abeliaev/projects/wan/maxdiffusion/jax_cache/
Config param jit_initializers: True
Config param learning_rate: 1e-05
Config param learning_rate_schedule_steps: 1500
Config param lightning_ckpt:
Config param lightning_from_pt: True
Config param lightning_repo:
Config param load_tfrecord_cached: True
Config param log_period: 100
Config param logical_axis_rules: (['activation_self_attn_heads', None], ['activation_self_attn_q_length', 'context'], ['activation_self_attn_kv_length', None], ['activation_cross_attn_heads', None], ['activation_cross_attn_q_length', 'context'], ['activation_cross_attn_kv_length', None], ('batch', ('data', 'fsdp')), ('activation_batch', ('data', 'fsdp')), ('activation_self_attn_heads', ('context', 'tensor')), ('activation_cross_attn_q_length', ('context', 'tensor')), ('activation_length', 'context'), ('activation_heads', 'tensor'), ('mlp', 'tensor'), ('embed', ('context', 'fsdp')), ('heads', 'tensor'), ('norm', 'tensor'), ('conv_batch', ('data', 'context', 'fsdp')), ('out_channels', 'tensor'), ('conv_out', 'context'), ('activation_kv_length', 'context'), ('layers_per_stage', None))
Config param lora_config: {'rank': [64], 'lora_model_name_or_path': ['lightx2v/Wan2.2-Distill-Loras'], 'high_noise_weight_name': ['wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors'], 'low_noise_weight_name': ['wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors'], 'adapter_name': ['wan22-distill-lora'], 'scale': [1.0], 'from_pt': []}
Config param mask_padding_tokens: True
Config param max_grad_norm: 1.0
Config param max_grad_value: 1.0
Config param max_train_samples: -1
Config param max_train_steps: 1500
Config param mesh_axes: ['data', 'fsdp', 'context', 'tensor']
Config param metrics_dir: /mnt/nvme0/abeliaev/projects/wan/maxdiffusion/output/wan-inference-testing-720p/metrics/
Config param metrics_file:
Config param model_name: wan2.2
Config param model_type: T2V
Config param names_which_can_be_offloaded: []
Config param names_which_can_be_saved: []
Config param negative_prompt: Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
Config param no_records_per_shard: 0
Config param norm_num_groups: 32
Config param num_eval_samples: 420
Config param num_frames: 81
Config param num_inference_steps: 40
Config param num_slices: 1
Config param num_train_epochs: 1
Config param opt_enable_grad_clipping: False
Config param opt_enable_grad_global_norm_clipping: False
Config param output_dir: /mnt/nvme0/abeliaev/projects/wan/maxdiffusion/output
Config param per_device_batch_size: 1.0
Config param precision: DEFAULT
Config param pretrained_model_name_or_path: /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/
Config param profiler_gcs_path:
Config param profiler_steps: 10
Config param prompt: A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window.
Config param prompt_2: A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window.
Config param quantization:
Config param quantization_calibration_method: absmax
Config param quantization_local_shard_count: 1
Config param qwix_module_path: .*
Config param random_flip: False
Config param remat_policy: NONE
Config param replicate_vae: False
Config param resolution: 1024
Config param reuse_example_batch: False
Config param revision:
Config param run_name: wan-inference-testing-720p
Config param save_config_to_gcs: False
Config param save_final_checkpoint: False
Config param save_optimizer: False
Config param scale_lr: False
Config param scan_diffusion_loop: False
Config param scan_layers: True
Config param seed: 42
Config param skip_first_n_steps_for_profiler: 5
Config param skip_jax_distributed_system: True
Config param snr_gamma: -1.0
Config param split_head_dim: True
Config param tensorboard_dir: /mnt/nvme0/abeliaev/projects/wan/maxdiffusion/output/wan-inference-testing-720p/tensorboard/
Config param text_encoder_learning_rate: 4.25e-06
Config param tfrecords_dir:
Config param timestep_bias: {'strategy': 'none', 'multiplier': 1.0, 'begin': 0, 'end': 1000, 'portion': 0.25}
Config param timesteps_list: [125, 250, 375, 500, 625, 750, 875]
Config param timing_metrics_file:
Config param tokenize_captions_num_proc: 4
Config param tokenizer_model_name_or_path: /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/
Config param total_train_batch_size: 1.0
Config param train_data_dir:
Config param train_split: train
Config param train_text_encoder: False
Config param transform_images_num_proc: 4
Config param unet_checkpoint:
Config param use_base2_exp: True
Config param use_batched_text_encoder: False
Config param use_cfg_cache: False
Config param use_experimental_scheduler: True
Config param use_kv_cache: False
Config param use_qwix_quantization: False
Config param use_sen_cache: False
Config param vae_logical_axis_rules: (('activation_batch', 'redundant'), ('activation_length', 'vae_spatial'), ('activation_heads', None), ('activation_kv_length', None), ('embed', None), ('heads', None), ('norm', None), ('conv_batch', 'redundant'), ('out_channels', 'vae_spatial'), ('conv_out', 'vae_spatial'), ('conv_in', 'vae_spatial'))
Config param vae_spatial: 1
Config param wan_transformer_pretrained_model_name_or_path: /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/
Config param warmup_steps_fraction: 0.1
Config param weights_dtype: bfloat16
Config param width: 832
Config param write_metrics: True
Config param write_timing_metrics: True
TensorBoard logs will be written to: /mnt/nvme0/abeliaev/projects/wan/maxdiffusion/output/wan-inference-testing-720p/tensorboard/
Git Commit Hash: 19d4e4d
Creating checkpoing manager...
checkpoint dir: /mnt/nvme0/abeliaev/projects/wan/maxdiffusion/output/wan-inference-testing-720p/checkpoints/
I0520 14:00:47.052260 133612851323072 pytree_checkpoint_handler.py:589] save_device_host_concurrent_bytes=None
I0520 14:00:47.053372 133612851323072 base_pytree_checkpoint_handler.py:415] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7984b6dab0b0>, enable_pinned_host_transfer=True, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0520 14:00:47.053521 133612851323072 pytree_checkpoint_handler.py:589] save_device_host_concurrent_bytes=None
I0520 14:00:47.053648 133612851323072 base_pytree_checkpoint_handler.py:415] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7984b6dab0b0>, enable_pinned_host_transfer=True, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
I0520 14:00:47.053748 133612851323072 pytree_checkpoint_handler.py:589] save_device_host_concurrent_bytes=None
I0520 14:00:47.053846 133612851323072 base_pytree_checkpoint_handler.py:415] Created BasePyTreeCheckpointHandler: use_ocdbt=True, use_zarr3=False, pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False), array_metadata_store=<orbax.checkpoint._src.metadata.array_metadata_store.Store object at 0x7984b6dab0b0>, enable_pinned_host_transfer=True, save_concurrent_bytes: 96000000000 (89.4 GiB), restore_concurrent_bytes: 96000000000 (89.4 GiB)
item_names: ('low_noise_transformer_state', 'high_noise_transformer_state', 'wan_state', 'wan_config')
I0520 14:00:47.054114 133612851323072 checkpoint_manager.py:709] [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=('low_noise_transformer_state', 'high_noise_transformer_state', 'wan_state', 'wan_config'), item_handlers={'wan_config': <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79839a996720>, 'wan_state': <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79839cae2330>, 'low_noise_transformer_state': <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79836f7022d0>, 'high_noise_transformer_state': <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79839a9c42f0>}, handler_registry=None
I0520 14:00:47.055115 133612851323072 composite_checkpoint_handler.py:237] Deferred registration for item: "low_noise_transformer_state". Adding handler <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79836f7022d0> for item "low_noise_transformer_state" and save args <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'> and restore args <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'> to _handler_registry.
I0520 14:00:47.055238 133612851323072 composite_checkpoint_handler.py:237] Deferred registration for item: "high_noise_transformer_state". Adding handler <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79839a9c42f0> for item "high_noise_transformer_state" and save args <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'> and restore args <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'> to _handler_registry.
I0520 14:00:47.055323 133612851323072 composite_checkpoint_handler.py:237] Deferred registration for item: "wan_state". Adding handler <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79839cae2330> for item "wan_state" and save args <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'> and restore args <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'> to _handler_registry.
I0520 14:00:47.055402 133612851323072 composite_checkpoint_handler.py:237] Deferred registration for item: "wan_config". Adding handler <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79839a996720> for item "wan_config" and save args <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'> and restore args <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'> to _handler_registry.
I0520 14:00:47.055481 133612851323072 composite_checkpoint_handler.py:237] Deferred registration for item: "metrics". Adding handler <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79839c910560> for item "metrics" and save args <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'> and restore args <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'> to _handler_registry.
I0520 14:00:47.055576 133612851323072 composite_checkpoint_handler.py:505] Initialized registry DefaultCheckpointHandlerRegistry({('low_noise_transformer_state', <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>): <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79836f7022d0>, ('low_noise_transformer_state', <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'>): <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79836f7022d0>, ('high_noise_transformer_state', <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>): <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79839a9c42f0>, ('high_noise_transformer_state', <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'>): <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79839a9c42f0>, ('wan_state', <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'>): <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79839cae2330>, ('wan_state', <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'>): <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x79839cae2330>, ('wan_config', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79839a996720>, ('wan_config', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79839a996720>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): 
<orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79839c910560>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x79839c910560>}).
I0520 14:00:47.056149 133612851323072 abstract_checkpointer.py:35] orbax-checkpoint version: 0.11.33
I0520 14:00:47.056293 133612851323072 async_checkpointer.py:177] [process=0][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.. at 0x79836f7d6de0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I0520 14:00:47.057314 133612851323072 checkpoint_manager.py:1818] Found 0 checkpoint steps in /mnt/nvme0/abeliaev/projects/wan/maxdiffusion/output/wan-inference-testing-720p/checkpoints
I0520 14:00:47.057696 133612851323072 checkpoint_manager.py:929] [process=0][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=1, max_to_keep=None, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, todelete_full_path=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None, save_decision_policy=None, preservation_policy=None, prevent_write_metrics=False, enable_should_save_is_saving_in_progress_check=True, enable_per_process_directory_creation=False, lightweight_initialize=False), root_directory=/mnt/nvme0/abeliaev/projects/wan/maxdiffusion/output/wan-inference-testing-720p/checkpoints: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x79839a9f0500>
Checkpoint manager created!
Latest WAN checkpoint step: None
No WAN checkpoint found.
No checkpoint found, loading default pipeline.
Devices: [CudaDevice(id=0)] (num_devices: 1)
Decided on mesh: [[[[CudaDevice(id=0)]]]]
Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of 1.
/mnt/nvme0/abeliaev/maxdiffusion_venv/lib/python3.12/site-packages/maxdiffusion/configuration_utils.py:262: FutureWarning: It is deprecated to pass a pretrained model name or path to from_config.
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ VAE on TFRT_CPU_0
torch_dtype is deprecated! Use dtype instead!
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 130.10it/s]
The config attributes {'rescale_betas_zero_snr': False, 'time_shift_type': 'exponential', 'use_dynamic_shifting': False} were passed to FlaxUniPCMultistepScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0
Load and port /mnt/nvme2/ckpt/Wan2.2-T2V-A14B-Diffusers/ transformer_2 on TFRT_CPU_0

After this last line the run gets stuck and prints nothing further.
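
A side note on the command line itself, separate from the stall: it sets CUDA_VISIBLE_DEVICES and HF_HUB_CACHE as `VAR=value && ...`. In a POSIX shell that form creates plain shell variables, which are not passed to the python process unless they were already exported earlier in the session, so the GPU selection and cache path may not have taken effect. Here is a minimal sketch of the difference, using a hypothetical variable DEMO_DEVICE that is not part of CUDA or maxdiffusion:

```shell
# DEMO_DEVICE is a made-up name used only for this demonstration.
# Make sure it is neither set nor exported before we start.
unset DEMO_DEVICE

# Form used in the command above: assignment, then && and the command.
# The variable exists in the current shell but is not exported,
# so the child process does not see it.
DEMO_DEVICE=2 && chained=$(sh -c 'echo "${DEMO_DEVICE:-unset}"')

unset DEMO_DEVICE

# Prefix form: the assignment is placed in the environment of that one child.
prefixed=$(DEMO_DEVICE=2 sh -c 'echo "${DEMO_DEVICE:-unset}"')

echo "chained:  $chained"
echo "prefixed: $prefixed"
```

If this applies here, dropping the `&&` after each assignment, or running `export` on both variables first, would be the usual way to make sure python sees them.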
