导入包

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import optax

from jax import grad, jit, vmap
from jax.nn import relu, one_hot
from jax.random import PRNGKey, normal, randint



In [2]:
import alpa
from alpa.model.model_util import TrainState

  from .autonotebook import tqdm as notebook_tqdm
2024-12-25 08:36:49,464	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


定义模型参数

In [3]:
def initialize_params(key, layer_sizes):
    """Sample one unscaled standard-normal weight matrix per pair of adjacent layers.

    Args:
        key: JAX PRNG key; it is split once per entry of ``layer_sizes``.
        layer_sizes: sequence of layer widths, e.g. ``[in, hidden, ..., out]``.

    Returns:
        A list with ``len(layer_sizes) - 1`` arrays; entry ``i`` has shape
        ``(layer_sizes[i], layer_sizes[i + 1])``. No biases are created.
    """
    subkeys = jax.random.split(key, len(layer_sizes))
    shapes = list(zip(layer_sizes[:-1], layer_sizes[1:]))
    weights = []
    for idx, shape in enumerate(shapes):
        # subkey idx is paired with the idx-th (fan_in, fan_out) pair
        weights.append(normal(subkeys[idx], shape))
    return weights

In [4]:
def create_model(input_dim: int, hidden_dim: int, output_dim: int, key=None):
    """Create parameters for a simple MLP made of 4 fully connected layers.

    Layer widths: input_dim -> hidden_dim -> hidden_dim -> hidden_dim -> output_dim.

    Args:
        input_dim: number of input features.
        hidden_dim: width of each of the three hidden layers.
        output_dim: number of outputs (logits).
        key: optional JAX PRNG key used for weight initialization. Defaults to
            ``jax.random.PRNGKey(0)``, the seed that was previously hard-coded,
            so existing callers get identical parameters.

    Returns:
        dict mapping ``'fc1'``..``'fc4'`` to ``(weight, bias)`` tuples, where
        weight has shape ``(fan_in, fan_out)`` and bias is zeros of shape ``(fan_out,)``.
    """
    def init_layer(layer_key, fan_in, fan_out):
        # Only the first half of the split is used; it is kept so the weights
        # match the original initialization exactly.
        w_key, _ = jax.random.split(layer_key)
        # Scale by sqrt(1 / fan_in) so each weight has variance 1 / fan_in.
        w = jax.random.normal(w_key, (fan_in, fan_out)) * jnp.sqrt(1 / fan_in)
        b = jnp.zeros(fan_out)
        return w, b

    if key is None:
        key = jax.random.PRNGKey(0)
    key1, key2, key3, key4 = jax.random.split(key, 4)

    params = {
        'fc1': init_layer(key1, input_dim, hidden_dim),
        'fc2': init_layer(key2, hidden_dim, hidden_dim),
        'fc3': init_layer(key3, hidden_dim, hidden_dim),
        'fc4': init_layer(key4, hidden_dim, output_dim),
    }
    return params

定义前向传播

In [5]:
def forward(params, x):
    """Bias-free MLP forward pass over a sequence of weight matrices.

    Every matrix except the last is followed by a ReLU; the last one produces
    the raw outputs. Expects the list layout returned by ``initialize_params``.
    """
    *hidden_layers, final_layer = params
    h = x
    for weight in hidden_layers:
        h = relu(jnp.dot(h, weight))
    return jnp.dot(h, final_layer)

In [6]:
def forward_pass(params, x):
    """Forward pass of the 4-layer MLP built by ``create_model``.

    Layers ``fc1``..``fc3`` apply ``x @ W + b`` followed by ReLU; ``fc4``
    applies ``x @ W + b`` with no activation and returns the logits.

    Args:
        params: dict mapping 'fc1'..'fc4' to (weight, bias) pairs.
        x: input batch.
    """
    h = x
    for layer_name in ('fc1', 'fc2', 'fc3'):
        layer = params[layer_name]
        h = jax.nn.relu(jnp.dot(h, layer[0]) + layer[1])
    head = params['fc4']
    return jnp.dot(h, head[0]) + head[1]

定义损失函数

In [7]:
def loss_fn(params, x, y):
    """Mean softmax cross-entropy between the MLP's logits and the targets.

    Args:
        params: parameter dict accepted by ``forward_pass``.
        x: input batch.
        y: target distributions with the same shape as the logits (one-hot
           labels in this notebook).

    Returns:
        Scalar loss averaged over the batch.
    """
    return optax.softmax_cross_entropy(forward_pass(params, x), y).mean()


创建训练状态（TrainState：参数 + 优化器）

In [8]:
def create_train_state(rng, input_dim, hidden_dim, output_dim, learning_rate):
    """Build the Alpa ``TrainState`` holding the MLP parameters and an AdamW optimizer.

    Args:
        rng: accepted but currently NOT used, because ``create_model`` is called
            without it. The initial parameters therefore do not depend on this argument.
        input_dim, hidden_dim, output_dim: layer widths forwarded to ``create_model``.
        learning_rate: AdamW learning rate.

    Returns:
        A ``TrainState`` with ``apply_fn=forward_pass`` and no dynamic loss scaling.
    """
    optimizer = optax.adamw(learning_rate=learning_rate)
    model_params = create_model(input_dim, hidden_dim, output_dim)
    return TrainState.create(
        apply_fn=forward_pass,
        params=model_params,
        tx=optimizer,
        dynamic_scale=None,
    )

定义训练步骤

In [9]:
def train_step(state, batch):
    """Run one optimization step on a single ``(x, y)`` batch.

    Gradients are computed with ``alpa.value_and_grad`` rather than
    ``jax.value_and_grad``. Per the Alpa docs, this lets Alpa's pipeline
    parallelism locate the forward/backward boundary.

    Returns:
        ``(new_state, loss)``, where ``new_state`` has the gradients applied.
    """
    def batch_loss(params):
        inputs, targets = batch
        return loss_fn(params, inputs, targets)

    loss, grads = alpa.value_and_grad(batch_loss)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss

In [11]:

# Set the random seed. Every random draw below uses its own subkey split from
# this root key, because JAX PRNG keys must not be reused.
rng = PRNGKey(0)

# Model dimensions
input_dim = 784
hidden_dim = 128
output_dim = 10

# Synthetic dataset: Gaussian features and uniformly random one-hot labels.
# Features and labels get distinct subkeys. Previously the same `rng` was passed
# to both `normal` and `randint`, so both draws consumed identical random bits.
num_samples = 1000
rng, x_rng, y_rng = jax.random.split(rng, 3)
x_train = normal(x_rng, (num_samples, input_dim))
y_train = randint(y_rng, (num_samples, ), 0, output_dim, dtype=int)
y_train = one_hot(y_train, output_dim)

num_train_epochs = 1
train_batch_size = 32
# Only full batches are used: the trailing num_samples % train_batch_size
# samples (8 with the values above) are skipped each epoch.
num_batches = num_samples // train_batch_size
# Size of each pipeline micro-batch; it must evenly divide the batch.
micro_batch_size = 8
assert train_batch_size % micro_batch_size == 0, "micro_batch_size must divide train_batch_size"
learning_rate = 5e-5

rng, init_rng = jax.random.split(rng)

state = create_train_state(init_rng, input_dim, hidden_dim, output_dim, learning_rate)


JAX 训练（pjit 版本，当前已停用：整个单元被包在字符串中）

In [None]:
import os
import time

from flax.training.common_utils import shard
from jax.experimental import maps, pjit

from tqdm import tqdm

In [None]:
"""
# Create parallel version of the train step
devices = jax.devices()
state = shard(state.replicate(), devices)
train_step_p = pjit(
    train_step,
    in_shardings=(None, P('batch')),
    out_shardings=(None, None),
    donate_argnums=(0,),
)


train_time = 0
last_time = time.time()
epochs = tqdm(range(num_train_epochs), desc=f"Epoch ... (1/{num_train_epochs})", position=0)

for epoch in epochs:
    # ======================== Training ================================
    train_start = time.time()

    # Create sampling rng
    rng, input_rng = jax.random.split(rng)
    train_metrics = []

    train_step_progress_bar = tqdm(total=num_batches, desc="Training...", position=1, leave=False)
    # train
    for step in range(num_batches):
        start_idx = step * train_batch_size
        end_idx = (step + 1) * train_batch_size
        batch = (x_train[start_idx:end_idx], y_train[start_idx:end_idx])
        state, loss = train_step_p(state, batch)
        train_metrics.append(loss)

        train_step_progress_bar.update(1)

    latency = time.time() - last_time
    images_per_second = num_samples / latency
    train_time += time.time() - train_start
    last_time = time.time()

    train_step_progress_bar.close()
    epochs.write(
        f"Epoch... ({epoch + 1}/{num_train_epochs} | Loss: {jnp.mean(train_metrics)}, "
        f"Throughput: {images_per_second:.2f} images/s"
    )"""


AttributeError: 'TrainState' object has no attribute 'replicate'

## alpa_heter

In [12]:
alpa.init(cluster="ray")

2024-12-25 08:37:38,253	INFO worker.py:1636 -- Connecting to existing Ray cluster at address: 172.17.0.6:6379...
2024-12-25 08:37:38,312	INFO worker.py:1821 -- Connected to Ray cluster.


Initializing ray with address auto
GPUlets enabled, getting GPU info from node: 172.17.0.6

all_host_info: [{'NodeID': 'efb40bc566dcb61bd44a8056a21b0412ed15a6a8f44ad1256c05cb72', 'Alive': True, 'NodeManagerAddress': '172.17.0.6', 'NodeManagerHostname': '7b1e0a1f1e75', 'NodeManagerPort': 37155, 'ObjectManagerPort': 44293, 'ObjectStoreSocketName': '/tmp/ray/session_2024-12-25_08-24-58_226897_50998/sockets/plasma_store', 'RayletSocketName': '/tmp/ray/session_2024-12-25_08-24-58_226897_50998/sockets/raylet', 'MetricsExportPort': 53761, 'NodeName': '172.17.0.6', 'RuntimeEnvAgentPort': 62597, 'DeathReason': 0, 'DeathReasonMessage': '', 'alive': True, 'Resources': {'node:__internal_head__': 1.0, 'node:172.17.0.6': 1.0, 'GPU': 2.0, 'accelerator_type:G': 1.0, 'CPU': 40.0, 'object_store_memory': 39348381696.0, 'memory': 81812890624.0}, 'Labels': {'ray.io/node_id': 'efb40bc566dcb61bd44a8056a21b0412ed15a6a8f44ad1256c05cb72'}}]
all_host_num_devices: [2]


In [13]:
# Pipeshard (pipeline + intra-operator) parallelism with automatic stage
# construction. Each batch is split into micro-batches of `micro_batch_size`
# so pipeline stages can overlap. Previously `micro_batch_size` was defined but
# never passed, which left Alpa's default of num_micro_batches=1.
num_micro_batches = train_batch_size // micro_batch_size
train_method = alpa.parallel_method.PipeshardParallel(
    num_micro_batches=num_micro_batches,
    stage_option="auto",
)
# NOTE(review): `donate_argnums=(0,)` was commented out in the original cell.
# Donating `state` lets Alpa reuse its buffers; consider re-enabling it once
# training runs end to end.
p_train_step = alpa.parallelize(train_step, method=train_method)

In [14]:
# When True, dump the compiled executable's debug info once, after the first step.
dump_debug_info_train_step = True
# NOTE(review): the recorded run of this cell failed with
# "AssertionError: no solution in auto stage construction". During stage
# profiling, the workers logged cuBLAS handle-creation failures (possibly OOM).
# Every profile result reporting ~1.9 GB available memory had compute_cost=inf,
# while those reporting ~20.8 GB did not. Check free GPU memory on both devices
# before retrying.
for epoch in range(num_train_epochs):
    # ======================== Training ================================

    # Advance the root key once per epoch. `input_rng` is not used yet
    # (batches are taken in a fixed order, with no shuffling).
    rng, input_rng = jax.random.split(rng)
    train_metrics = []

    # train
    for step in range(num_batches):
        start_idx = step * train_batch_size
        end_idx = start_idx + train_batch_size
        batch = (x_train[start_idx:end_idx], y_train[start_idx:end_idx])
        state, loss = p_train_step(state, batch)
        train_metrics.append(loss)

        if dump_debug_info_train_step:
            # Dump only once, right after the first step has been executed.
            dump_debug_info_train_step = False
            executable = p_train_step.get_last_executable()
            executable.sync()
            executable.dump_debug_info("alpa_debug_info")

    # Report the epoch's mean loss (train_metrics was previously collected but
    # never used). np.asarray converts each per-step loss to a host array.
    if train_metrics:
        epoch_loss = float(np.mean([np.asarray(step_loss) for step_loss in train_metrics]))
    else:
        epoch_loss = float("nan")
    print(f"Epoch {epoch + 1}/{num_train_epochs} | Loss: {epoch_loss:.4f}")


    

INFO:alpa.pipeline_parallel.stage_construction:num_devices = 1



-*-*-*-*-VirtualPhysicalMesh-*-*-*-*-
host_ids:  [0]
host_info:  [{'NodeID': 'efb40bc566dcb61bd44a8056a21b0412ed15a6a8f44ad1256c05cb72', 'Alive': True, 'NodeManagerAddress': '172.17.0.6', 'NodeManagerHostname': '7b1e0a1f1e75', 'NodeManagerPort': 37155, 'ObjectManagerPort': 44293, 'ObjectStoreSocketName': '/tmp/ray/session_2024-12-25_08-24-58_226897_50998/sockets/plasma_store', 'RayletSocketName': '/tmp/ray/session_2024-12-25_08-24-58_226897_50998/sockets/raylet', 'MetricsExportPort': 53761, 'NodeName': '172.17.0.6', 'RuntimeEnvAgentPort': 62597, 'DeathReason': 0, 'DeathReasonMessage': '', 'alive': True, 'Resources': {'node:__internal_head__': 1.0, 'node:172.17.0.6': 1.0, 'GPU': 2.0, 'accelerator_type:G': 1.0, 'CPU': 40.0, 'object_store_memory': 39348381696.0, 'memory': 81812890624.0}, 'Labels': {'ray.io/node_id': 'efb40bc566dcb61bd44a8056a21b0412ed15a6a8f44ad1256c05cb72'}}]
num_devices_per_host:  2
devices:  [[0, 1]]
num_gpus:  2
node efb40bc566dcb61bd44a8056a21b0412ed15a6a8f44ad1256c

100%|██████████| 2/2 [00:00<00:00,  4.92it/s]


profile_results before check: {}
get_compute_cost() in stage_profiling.py:
==auto_sharding_configs:
config: (<alpa.shard_parallel.auto_sharding.LogicalDeviceMesh object at 0x7f4b25f5d370>, {'force_batch_dim_to_mesh_dim': 0})
config: (<alpa.shard_parallel.auto_sharding.LogicalDeviceMesh object at 0x7f4b25f5d9a0>, {})
==stages to profile:
stage_idx: (0, 0, 0, 0)
stage_idx: (0, 0, 0, 1)
stage_idx: (0, 1, 0, 0)
stage_idx: (0, 1, 0, 1)
stage_idx: (1, 1, 0, 0)
stage_idx: (1, 1, 0, 1)
==sliced_virtual_meshes:
sliced_virtual_mesh 0: host_ids->[0]                         num_devices_per_host->1  devices->[[0]]
sliced_virtual_mesh 1: host_ids->[0]                         num_devices_per_host->1  devices->[[1]]
- Compile all stages


100%|██████████| 6/6 [00:01<00:00,  3.65it/s]


- Profile all stages


  8%|▊         | 1/12 [00:05<00:59,  5.41s/it]

[36m(ProfileWorker pid=51994)[0m INFO:alpa.device_mesh:num_devices_per_host: 1
[36m(ProfileWorker pid=51994)[0m INFO:alpa.device_mesh:num_gpus: 1
[36m(ProfileWorker pid=51994)[0m INFO:alpa.device_mesh:Launching workers on hosts: [0]
[36m(MeshHostWorker pid=54495)[0m INFO:alpa.mesh_executable:num_devices: 1, len(worker.backend.devices()): 1
[36m(MeshHostWorker pid=54495)[0m 2024-12-25 08:38:01.252178: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:219] failed to create cublas handle: cublas error
[36m(MeshHostWorker pid=54495)[0m 2024-12-25 08:38:01.252230: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:221] Failure to initialize cublas may be due to OOM (cublas needs some free memory when you initialize it, and your deep-learning framework may have preallocated more than its fair share), or may be because this binary was not built with support for the GPU in your machine.
[36m(MeshHostWorker pid=5449

result[(0, 0, 0, 0), 0] = ModuleProfileResult(compute_cost=inf, peak_memory=0.001 GB, invar_size=0.001 GB, outvar_size=0.000 GB, temp_buffer_size=0.000 GB, available_memory=1.869 GB)
result[(0, 0, 0, 1), 0] = ModuleProfileResult(compute_cost=inf, peak_memory=0.001 GB, invar_size=0.001 GB, outvar_size=0.000 GB, temp_buffer_size=0.000 GB, available_memory=1.869 GB)


 25%|██▌       | 3/12 [00:05<00:13,  1.48s/it]

result[(0, 0, 0, 1), 1] = ModuleProfileResult(compute_cost=inf, peak_memory=0.001 GB, invar_size=0.000 GB, outvar_size=0.000 GB, temp_buffer_size=0.000 GB, available_memory=1.869 GB)


 50%|█████     | 6/12 [00:06<00:03,  1.72it/s]

result[(0, 1, 0, 0), 0] = ModuleProfileResult(compute_cost=inf, peak_memory=0.001 GB, invar_size=0.001 GB, outvar_size=0.000 GB, temp_buffer_size=0.000 GB, available_memory=1.869 GB)
result[(0, 0, 0, 0), 1] = ModuleProfileResult(compute_cost=0.000, peak_memory=0.001 GB, invar_size=0.000 GB, outvar_size=0.000 GB, temp_buffer_size=0.000 GB, available_memory=20.767 GB)
result[(0, 1, 0, 0), 1] = ModuleProfileResult(compute_cost=inf, peak_memory=0.001 GB, invar_size=0.000 GB, outvar_size=0.001 GB, temp_buffer_size=0.000 GB, available_memory=1.869 GB)


 67%|██████▋   | 8/12 [00:06<00:01,  2.71it/s]

result[(0, 1, 0, 1), 1] = ModuleProfileResult(compute_cost=inf, peak_memory=0.001 GB, invar_size=0.000 GB, outvar_size=0.001 GB, temp_buffer_size=0.000 GB, available_memory=1.869 GB)
result[(0, 1, 0, 1), 0] = ModuleProfileResult(compute_cost=0.000, peak_memory=0.001 GB, invar_size=0.001 GB, outvar_size=0.000 GB, temp_buffer_size=0.000 GB, available_memory=20.767 GB)


 92%|█████████▏| 11/12 [00:06<00:00,  4.32it/s]

result[(1, 1, 0, 0), 1] = ModuleProfileResult(compute_cost=0.000, peak_memory=0.000 GB, invar_size=0.000 GB, outvar_size=0.000 GB, temp_buffer_size=0.000 GB, available_memory=20.767 GB)
result[(1, 1, 0, 0), 0] = ModuleProfileResult(compute_cost=inf, peak_memory=0.000 GB, invar_size=0.000 GB, outvar_size=0.000 GB, temp_buffer_size=0.000 GB, available_memory=1.869 GB)
result[(1, 1, 0, 1), 1] = ModuleProfileResult(compute_cost=inf, peak_memory=0.000 GB, invar_size=0.000 GB, outvar_size=0.000 GB, temp_buffer_size=0.000 GB, available_memory=1.869 GB)


100%|██████████| 12/12 [00:06<00:00,  1.81it/s]


result[(1, 1, 0, 1), 0] = ModuleProfileResult(compute_cost=0.000, peak_memory=0.000 GB, invar_size=0.000 GB, outvar_size=0.000 GB, temp_buffer_size=0.000 GB, available_memory=20.767 GB)
Profiling for submesh 0 (1, 1) takes 9.12 seconds
--------------------------------------------------
Profile result saved to: profile-results-2024-12-25-08-38-02.npy
----------------------------------------------------------------------


[33m(raylet)[0m A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: fffffffffffffffff0a612bae68291825611299701000000 Worker ID: b781bbdd396ece6446c185ac244083f45de663c363da738b7e330e1b Node ID: efb40bc566dcb61bd44a8056a21b0412ed15a6a8f44ad1256c05cb72 Worker IP address: 172.17.0.6 Worker port: 10043 Worker PID: 54495 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker exits unexpectedly by a signal. SystemExit is raised (sys.exit is called). Exit code: 1. The process receives a SIGTERM.


AssertionError: no solution in auto stage construction.

In [None]:
import os
import time

from flax.training.common_utils import shard
from jax.experimental import maps, pjit

from tqdm import tqdm

In [None]:
import os
import time

from flax.training.common_utils import shard
from jax.experimental import maps, pjit

from tqdm import tqdm

In [None]:
import os
import time

from flax.training.common_utils import shard
from jax.experimental import maps, pjit

from tqdm import tqdm

In [None]:
import os
import time

from flax.training.common_utils import shard
from jax.experimental import maps, pjit

from tqdm import tqdm