[TRITON] Sync with triton upstream #19

Merged
10 commits merged on Mar 27, 2024
23 changes: 10 additions & 13 deletions .github/workflows/python-app.yml
@@ -12,6 +12,10 @@ on:
permissions:
contents: read

concurrency:
group: ${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

jobs:
build:
runs-on: ubuntu-latest
@@ -32,29 +36,22 @@ jobs:
pip install pre-commit
pre-commit run --all-files

- name: Cache Dependencies
uses: actions/cache@v3
id: cache-pip
with:
path: /opt/hostedtoolcache/Python/3.10.13/x64
key: ${{ runner.os }}-pip-3.10-${{ hashFiles('**/setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install Dependencies
if: steps.cache-pip.outputs.cache-hit != 'true'
run: |
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
pip uninstall pytorch-triton -y

- name: Clone Triton and Install
run: |
git clone https://github.com/openai/triton.git
cd triton/python
pip install -e .

- name: Install Dependencies if Cache Missed
if: steps.cache-pip.outputs.cache-hit != 'true'
- name: Install Triton-Viz
run: |
cd triton_viz
pip install -e .
pre-commit install
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
pip uninstall pytorch-triton -y

- name: Test with pytest
run: |
3 changes: 2 additions & 1 deletion examples/vec_add.py
@@ -31,7 +31,8 @@ def add_kernel(
# multiple of the block size.
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
output = tl.zeros(x.shape, dtype=x.dtype)
output = output + x + y
# Write x + y back to DRAM.
tl.store(output_ptr + offsets, output, mask=mask)

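The new two-line form first materializes a zero block and then accumulates x and y into it, which computes the same values as the previous output = x + y. For context, the kernel is normally launched from the host with a grid sized by the vector length; the sketch below assumes the standard Triton vector-add tutorial signature add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE) and is not part of this diff.

import torch
import triton

from examples.vec_add import add_kernel  # the kernel shown in the diff above

# Hypothetical host-side launch; tensor sizes and BLOCK_SIZE are assumptions.
x = torch.rand(98432, device="cuda")
y = torch.rand(98432, device="cuda")
output = torch.empty_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)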
17 changes: 12 additions & 5 deletions triton_viz/draw.py
@@ -31,8 +31,17 @@
DEFAULT = Color("grey")
BLACK = Color("black")
GREY = Color("grey")
palette = "#f29f05,#f25c05,#d6568c,#4d8584,#a62f03,#400d01,#274001,#828a00,".split(",")
ACTIVE = list([Color(p) for p in palette])
palette = [
"#f29f05",
"#f25c05",
"#d6568c",
"#4d8584",
"#a62f03",
"#400d01",
"#274001",
"#828a00",
]
ACTIVE = [Color(p) for p in palette]

MRATIO = 1 / 3

@@ -52,9 +61,7 @@ def box(d: Diagram, width: float, height: float, outer=0.2) -> Diagram:
def reshape(d: Diagram) -> Diagram:
"Use log-scale if ratio is too sharp"
h, w = d.get_envelope().height, d.get_envelope().width
if h / w > MRATIO:
d = d.scale_y(math.log(h + 1, 2) / h).scale_x(math.log(w + 1, 2) / w)
elif w / h > MRATIO:
if (h / w > MRATIO) or (w / h > MRATIO):
d = d.scale_y(math.log(h + 1, 2) / h).scale_x(math.log(w + 1, 2) / w)
return d

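The merged condition in reshape applies the same log2 compression whenever either aspect ratio exceeds MRATIO: after scaling, an axis of length n ends up at roughly log2(n + 1). A standalone sketch of the arithmetic (illustrative only, not code from the repo):

import math

MRATIO = 1 / 3

def log_rescale(h: float, w: float) -> tuple[float, float]:
    # Mirrors reshape(): when the envelope is too elongated in either
    # direction, both axes are compressed onto a log2 scale.
    if (h / w > MRATIO) or (w / h > MRATIO):
        return math.log(h + 1, 2), math.log(w + 1, 2)
    return h, w

print(log_rescale(256.0, 4.0))  # ~ (8.01, 2.32)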
64 changes: 40 additions & 24 deletions triton_viz/interpreter.py
@@ -18,17 +18,17 @@
GridExecutor,
_implicit_cvt,
RESERVED_KWS,
builder,
interpreter_builder,
)
from triton.runtime.interpreter import _patch_lang as triton_patch_lang
from triton.runtime import JITFunction
from typing import Tuple, List, Optional
from contextlib import contextmanager
from functools import wraps


def _patch_lang(fn):
from triton.runtime.interpreter import _patch_lang as patch_lang

patch_lang(fn)
triton_patch_lang(fn)
tl.sum = _create_reduce(tl.reduce, "sum")
tl.min = _create_reduce(tl.reduce, "min")
tl.max = _create_reduce(tl.reduce, "max")
@@ -169,21 +169,26 @@ def _grid_executor_call(self, *args_dev, **kwargs):
grid = self.grid(call_args) if callable(self.grid) else self.grid
assert len(grid) <= 3
grid = grid + (1,) * (3 - len(grid))
builder.set_grid_dim(*grid)
interpreter_builder.set_grid_dim(*grid)
record_builder.set_grid_dim(*grid)
record_builder.add_tensors(tensors)
record_builder.sort_tensor_handles()
for x in range(grid[0]):
for y in range(grid[1]):
for z in range(grid[2]):
builder.set_grid_idx(x, y, z)
interpreter_builder.set_grid_idx(x, y, z)
record_builder.set_grid_idx(x, y, z)
self.fn(**call_args)
# Copy arguments back to propagate side-effects
self._restore_args_dev(args_dev, args_hst)
_unpatch_lang()


def _jit_function_call(self, *args, **kwargs):
triton_patch_lang(self.fn)
return self.fn(*args, **kwargs)


def check_out_of_bounds_access(ptrs, masks):
first_ptr = np.reshape(ptrs.data, (-1))[0]
tensor_ptr = record_builder.get_tensor_ptr(first_ptr)
@@ -336,26 +341,37 @@ def wrapper(input, axis=None, keep_dims=False):
@contextmanager
def patch():
old_grid_executor_call = GridExecutor.__call__
old_create_make_range = builder.create_make_range
old_create_masked_load = builder.create_masked_load
old_create_expand_dims = builder.create_expand_dims
old_binary_op = builder.binary_op
old_create_dot = builder.create_dot
old_create_masked_store = builder.create_masked_store
old_jit_function_call = JITFunction.__call__
old_create_make_range = interpreter_builder.create_make_range
old_create_masked_load = interpreter_builder.create_masked_load
old_create_expand_dims = interpreter_builder.create_expand_dims
old_binary_op = interpreter_builder.binary_op
old_create_dot = interpreter_builder.create_dot
old_create_masked_store = interpreter_builder.create_masked_store
GridExecutor.__call__ = _grid_executor_call
builder.create_make_range = _create_make_range(builder.create_make_range)
builder.create_masked_load = _create_masked_load(builder.create_masked_load)
builder.create_expand_dims = _create_expand_dims(builder.create_expand_dims)
builder.binary_op = _create_binary_op(builder.binary_op)
builder.create_dot = _create_dot(builder.create_dot)
builder.create_masked_store = _create_masked_store(builder.create_masked_store)
JITFunction.__call__ = _jit_function_call
interpreter_builder.create_make_range = _create_make_range(
interpreter_builder.create_make_range
)
interpreter_builder.create_masked_load = _create_masked_load(
interpreter_builder.create_masked_load
)
interpreter_builder.create_expand_dims = _create_expand_dims(
interpreter_builder.create_expand_dims
)
interpreter_builder.binary_op = _create_binary_op(interpreter_builder.binary_op)
interpreter_builder.create_dot = _create_dot(interpreter_builder.create_dot)
interpreter_builder.create_masked_store = _create_masked_store(
interpreter_builder.create_masked_store
)
try:
yield
finally:
GridExecutor.__call__ = old_grid_executor_call
builder.create_make_range = old_create_make_range
builder.create_masked_load = old_create_masked_load
builder.create_expand_dims = old_create_expand_dims
builder.binary_op = old_binary_op
builder.create_dot = old_create_dot
builder.create_masked_store = old_create_masked_store
JITFunction.__call__ = old_jit_function_call
interpreter_builder.create_make_range = old_create_make_range
interpreter_builder.create_masked_load = old_create_masked_load
interpreter_builder.create_expand_dims = old_create_expand_dims
interpreter_builder.binary_op = old_binary_op
interpreter_builder.create_dot = old_create_dot
interpreter_builder.create_masked_store = old_create_masked_store
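The patch() context manager above is ordinary monkey-patching with guaranteed restoration: each interpreter_builder method is swapped for a wrapped version on entry and put back in the finally block, so the interpreter is left untouched even if the traced kernel raises. A reduced, self-contained sketch of the same pattern (illustrative names only, not the project's API):

from contextlib import contextmanager
from functools import wraps


class Builder:
    # Stand-in for the real interpreter builder.
    def create_masked_load(self, ptrs, mask):
        return [p for p, m in zip(ptrs, mask) if m]


builder = Builder()


def with_recording(original):
    # Wrap a builder method so every call is observed before delegating.
    @wraps(original)
    def wrapper(*args, **kwargs):
        print("recorded:", original.__name__, args)
        return original(*args, **kwargs)
    return wrapper


@contextmanager
def patch():
    old_create_masked_load = builder.create_masked_load
    builder.create_masked_load = with_recording(old_create_masked_load)
    try:
        yield
    finally:
        # Restore the original even if the traced code raises.
        builder.create_masked_load = old_create_masked_load


with patch():
    builder.create_masked_load([0, 1, 2], [True, False, True])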
1 change: 1 addition & 0 deletions triton_viz/trace.py
@@ -21,6 +21,7 @@ def trace(kernel):


def dump(path: str):
# TODO: Dump the record_builder to a file
for launch in record_builder.launches:
print(launch)

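The TODO in dump() is left open by this PR. One possible direction (purely a sketch; it assumes the recorded launches are picklable, which the PR does not establish):

import pickle


def dump(path: str):
    # Hypothetical implementation of the TODO: serialize the recorded
    # launches to a file instead of printing them. record_builder is the
    # module-level recorder already referenced in trace.py above.
    with open(path, "wb") as f:
        pickle.dump(record_builder.launches, f)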