Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions csrc/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
// clang-format on
#pragma once

#include <concepts>
#include <coroutine>
#include <deque>
#include <iterator>
Expand Down Expand Up @@ -281,25 +280,27 @@ SPECIALIZE_PRINTER(VoidStar);
SPECIALIZE_PRINTER(uint32_t);
SPECIALIZE_PRINTER(int64_t);
SPECIALIZE_PRINTER(uint64_t);
SPECIALIZE_PRINTER(DataType);
SPECIALIZE_PRINTER(MemoryType);
SPECIALIZE_PRINTER(UnaryOpType);

SPECIALIZE_PRINTER(BinaryOpType);
SPECIALIZE_PRINTER(TernaryOpType);
SPECIALIZE_PRINTER(LoadStoreOpType);
SPECIALIZE_PRINTER(CircularBufferLoopStage);
SPECIALIZE_PRINTER(tma::TensorMapInterleave);
SPECIALIZE_PRINTER(tma::TensorMapL2Promotion);
SPECIALIZE_PRINTER(tma::TensorMapFloatOOBFill);
SPECIALIZE_PRINTER(DataType);
SPECIALIZE_PRINTER(LoadStoreOpType);
SPECIALIZE_PRINTER(MemoryType);
SPECIALIZE_PRINTER(MmaInputSmemSwizzle);
SPECIALIZE_PRINTER(SwizzleType);
SPECIALIZE_PRINTER(ParallelType);
SPECIALIZE_PRINTER(Swizzle2DType);
SPECIALIZE_PRINTER(SwizzleMode);
SPECIALIZE_PRINTER(SwizzleType);
SPECIALIZE_PRINTER(TernaryOpType);
SPECIALIZE_PRINTER(UnaryOpType);
SPECIALIZE_PRINTER(std::optional<bool>);
SPECIALIZE_PRINTER(std::vector<int64_t>);
SPECIALIZE_PRINTER(std::vector<int>);
SPECIALIZE_PRINTER(std::vector<uint32_t>);
SPECIALIZE_PRINTER(std::vector<int64_t>);
SPECIALIZE_PRINTER(std::vector<uint64_t>);
SPECIALIZE_PRINTER(std::optional<bool>);
SPECIALIZE_PRINTER(tma::TensorMapFloatOOBFill);
SPECIALIZE_PRINTER(tma::TensorMapInterleave);
SPECIALIZE_PRINTER(tma::TensorMapL2Promotion);

#undef SPECIALIZE_PRINTER

Expand Down
12 changes: 2 additions & 10 deletions csrc/host_ir/lower_to_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,8 @@ void lowerToBroadcast(
const DeviceMesh& sender_mesh = input_tv->getDeviceMesh();
const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh();

NVF_ERROR_EQ(
sender_mesh.rank(),
1,
"Broadcast only supports a 1D sender mesh. Given ",
sender_mesh);
NVF_ERROR_EQ(
receiver_mesh.rank(),
1,
"Broadcast only supports a 1D receiver mesh. Given ",
receiver_mesh);
NVF_ERROR_EQ(sender_mesh.rank(), 1, "sender: ", input_tv->toString());
NVF_ERROR_EQ(receiver_mesh.rank(), 1, "receiver: ", output_tv->toString());

DeviceIdxType root = sender_mesh.at(0);
Team team = receiver_mesh.vector();
Expand Down
7 changes: 7 additions & 0 deletions csrc/multidevice/propagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <unordered_map>
#include <vector>

#include "base.h"
#include "ir/interface_nodes.h"
#include "ir/internal_base_nodes.h"
#include "ir/internal_nodes.h"
Expand Down Expand Up @@ -255,6 +256,12 @@ void shardLoopLike(
TensorView* tv,
const std::unordered_set<ParallelType>& selected_parallel_types,
PropagateDirection direction) {
if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) {
debug() << "Propagating shardings from " << ref->toString() << " to "
<< tv->toString() << " in " << direction << " for "
<< toDelimitedString(selected_parallel_types) << std::endl;
}

std::unordered_set<IterDomain*> device_or_stream_ids;
const std::unordered_map<IterDomain*, IterDomain*> ref2target =
getRef2TargetMap(ref, tv, direction);
Expand Down
4 changes: 2 additions & 2 deletions csrc/preseg_passes/exact_mapped_extent_substitution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include "debug.h"
#include "id_model/id_model.h"
#include "ir/utils.h"
#include "logical_domain_map.h"
#include "options.h"

namespace nvfuser::preseg_passes {
Expand Down Expand Up @@ -150,7 +149,8 @@ void ExactMappedExtentSubstitutionPass::runPass(Fusion* fusion) {

if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) {
debug() << "ExactLogicalDomainMap after " << name() << ":" << std::endl;
IdModel id_model(fusion, false, false, false);
IdModel id_model(
fusion, /*build_graphs=*/false, /*allow_self_mapping=*/true);
id_model.buildExactGraph();
const ValGraph& exact_graph = id_model.idGraph(IdMappingMode::EXACT);
const DisjointSets<Val*>& id_sets = exact_graph.disjointValSets();
Expand Down
8 changes: 7 additions & 1 deletion csrc/preseg_passes/propagate_shardings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "ir/iostream.h"
#include "ir/utils.h"
#include "multidevice/propagation.h"
#include "multidevice/utils.h"
#include "scheduler/utils.h"

namespace nvfuser::preseg_passes {
Expand Down Expand Up @@ -187,6 +186,13 @@ void PropagateShardingsPass::runPass(Fusion* fusion) {
PropagateDirection::kBackward);
}
}

if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) {
debug() << std::endl
<< "Fusion Transforms after " << name() << ":" << std::endl;
fusion->printTransforms();
debug() << std::endl;
}
}

} // namespace nvfuser::preseg_passes
16 changes: 16 additions & 0 deletions csrc/scheduler/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@
#include <type.h>
#include <val_graph_visitor.h>

namespace nvfuser {

std::ostream& operator<<(std::ostream& os, PropagateDirection direction) {
switch (direction) {
case PropagateDirection::kForward:
os << "Forward";
break;
case PropagateDirection::kBackward:
os << "Backward";
break;
}
return os;
}

} // namespace nvfuser

namespace nvfuser::scheduler_utils {

// Minimal PTX code for a no-op kernel, used for occupancy queries
Expand Down
4 changes: 3 additions & 1 deletion csrc/scheduler/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// clang-format on
#pragma once

#include <base.h>
#include <disjoint_set.h>
#include <exceptions.h>
#include <fusion.h>
Expand All @@ -15,7 +16,6 @@
#include <scheduler/reduction_heuristic.h>
#include <scheduler/tools/maxinfo_propagator.h>
#include <visibility.h>
#include "base.h"

namespace nvfuser {

Expand All @@ -29,6 +29,8 @@ class HeuristicDataCache;
//! BoundedDirectionalTransformPropagator.
enum class PropagateDirection { kBackward = 0, kForward };

std::ostream& operator<<(std::ostream& os, PropagateDirection direction);

namespace scheduler_utils {

// Assume any only half of the register file is available to spend on buffers,
Expand Down
222 changes: 222 additions & 0 deletions tests/python/multidevice/test_alphafold3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause


# This file contains certain building blocks of the AlphaFold3 model.

import pytest
import torch
from dataclasses import dataclass
from enum import Enum, auto

import nvfuser_direct as nvfuser
from nvfuser_direct import FusionDefinition, DataType, TensorView


@dataclass
class ModelConfig:
c_z: int = 128
c_hidden: int = 32
n_heads: int = 4


_DEFAULT_CONFIG = ModelConfig()


class Direction(Enum):
INCOMING = auto() # aka ending node
OUTGOING = auto() # aka starting node


def layer_norm(
fd: FusionDefinition, x: TensorView, w: TensorView, b: TensorView
) -> TensorView:
io_dtype = x.dtype()
x = fd.ops.cast(x, dtype=DataType.Float)
var, mean = fd.ops.var_mean(x, dims=[-1], correction=0, keepdim=True)
y = fd.ops.sub(x, mean)
var = fd.ops.add(var, fd.define_scalar(1e-5))
y = fd.ops.mul(y, fd.ops.rsqrt(var))
shape = fd.ops.shape(x)
w = fd.ops.broadcast_in_dim(w, shape=shape, broadcast_dims=[-1])
y = fd.ops.mul(y, w)
b = fd.ops.broadcast_in_dim(b, shape=shape, broadcast_dims=[-1])
y = fd.ops.add(y, b)
y = fd.ops.cast(y, dtype=io_dtype)
return y


def gating(
fd: FusionDefinition,
z: TensorView,
w_p: TensorView,
z_in: TensorView,
w_g: TensorView,
) -> TensorView:
io_dtype = z.dtype()
p = fd.ops.linear(z, w_p)
g = fd.ops.linear(z_in, w_g)
g = fd.ops.sigmoid(g)
z = fd.ops.mul(p, g)
return fd.ops.cast(z, dtype=io_dtype)


# https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/#triangle-updates
#
# Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
# prediction with AlphaFold. Nature 596, 583–589 (2021).
# https://doi.org/10.1038/s41586-021-03819-2
# (see Supplementary Methods 1.6.5 for details)
@pytest.mark.mpi
@pytest.mark.parametrize(
"direction", [Direction.OUTGOING, Direction.INCOMING], ids=lambda d: d.name.lower()
)
def test_triangle_updates(direction, multidevice_test):
d = multidevice_test.size
cp_size = 1
if d % (cp_size * cp_size) != 0:
pytest.skip(
f"We only support even split, so {d} has to be divisible by {cp_size * cp_size} for {cp_size=}."
)
dp_size = d // (cp_size * cp_size)

c_z = _DEFAULT_CONFIG.c_z

with FusionDefinition() as fd:
z_in_tv = fd.define_tensor(
shape=[-1, -1, -1, c_z],
dtype=DataType.BFloat16,
contiguity=True,
) # [b, i, j, c_z]
w_norm_in = fd.define_tensor(
shape=[c_z], dtype=DataType.BFloat16, contiguity=True
)
b_norm_in = fd.define_tensor(
shape=[c_z], dtype=DataType.BFloat16, contiguity=True
)
w_p_in = fd.define_tensor(
shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
)
w_g_in = fd.define_tensor(
shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
)
w_norm_out = fd.define_tensor(
shape=[c_z], dtype=DataType.BFloat16, contiguity=True
)
b_norm_out = fd.define_tensor(
shape=[c_z], dtype=DataType.BFloat16, contiguity=True
)
w_p_out = fd.define_tensor(
shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
)
w_g_out = fd.define_tensor(
shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
)
# Masking is used in an internal implementation: http://nv/e-4
mask_tv = fd.define_tensor(
shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
) # [b, i, j]

batch_size = fd.ops.size(z_in_tv, 0)
n_tokens = fd.ops.size(z_in_tv, 1)

z_in = layer_norm(fd, z_in_tv, w_norm_in, b_norm_in)
z = gating(fd, z_in_tv, w_p_in, z_in, w_g_in)
mask = fd.ops.broadcast_in_dim(
mask_tv,
shape=[batch_size, n_tokens, n_tokens, c_z],
broadcast_dims=[0, 1, 2],
)
z = fd.ops.where(mask, z, 0.0)
a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z])
b = fd.ops.slice(z, [0, 0, 0, c_z], [batch_size, n_tokens, n_tokens, c_z * 2])

match direction:
case Direction.OUTGOING:
# z_out = einsum("bikc,bjkc->bijc", a, b)
a = fd.ops.permute(a, [0, 3, 1, 2]) # [b, c, i, k]
b = fd.ops.permute(b, [0, 3, 2, 1]) # [b, c, k, j]
case Direction.INCOMING:
# z_out = einsum("bkic,bkjc->bijc", a, b)
a = fd.ops.permute(a, [0, 3, 2, 1]) # [b, c, i, k]
b = fd.ops.permute(b, [0, 3, 1, 2]) # [b, c, k, j]
z = fd.ops.matmul(a, b) # [b, c, i, j]
z = fd.ops.permute(z, [0, 2, 3, 1]) # [b, i, j, c]

z = layer_norm(fd, z, w_norm_out, b_norm_out)
z = gating(fd, z, w_p_out, z_in, w_g_out)
fd.add_output(z)

mesh = nvfuser.multidevice.DeviceMesh(
torch.arange(d).reshape(dp_size, cp_size, cp_size)
)
for tv in [
z_in_tv,
w_norm_in,
b_norm_in,
w_p_in,
w_g_in,
w_norm_out,
b_norm_out,
w_p_out,
w_g_out,
mask_tv,
]:
tv.set_device_mesh(mesh)

for tv in [z_in_tv, mask_tv]:
tv.outer_split(2, cp_size)
tv.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
tv.outer_split(1, cp_size)
tv.axis(1).parallelize(nvfuser.ParallelType.mesh_y)
tv.outer_split(0, dp_size)
tv.axis(0).parallelize(nvfuser.ParallelType.mesh_z)

batch_per_rank = 3
n_tokens_per_rank = 5
z_in_ref = torch.testing.make_tensor(
batch_per_rank * dp_size,
n_tokens_per_rank * cp_size,
n_tokens_per_rank * cp_size,
c_z,
dtype=torch.bfloat16,
device="cpu",
)
mask_ref = torch.testing.make_tensor(
batch_per_rank * dp_size,
n_tokens_per_rank * cp_size,
n_tokens_per_rank * cp_size,
dtype=torch.bool,
device="cpu",
)

z_in = multidevice_test.shard_tensor(z_in_ref, z_in_tv)
w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
w_p_in = torch.testing.make_tensor(
c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
)
w_g_in = torch.testing.make_tensor(
c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
)
w_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
mask = multidevice_test.shard_tensor(mask_ref, mask_tv)
(z_out,) = fd.execute(
[
z_in,
w_norm_in,
b_norm_in,
w_p_in,
w_g_in,
w_norm_out,
b_norm_out,
w_p_out,
w_g_out,
mask,
]
)
assert z_out.shape == (batch_per_rank, n_tokens_per_rank, n_tokens_per_rank, c_z)