Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,17 @@ print(torch.norm(Z))
```
**Note**: you don't need PyTorch Geometric to use our kernels. When
`deterministic=False`, the `sender` and `receiver` indices can have
arbitrary order.
arbitrary order.

**New:** If you're working in FP32 precision and want
higher accuracy during graph convolution, we offer a Kahan
summation variant of our deterministic algorithm:

```python
tp_conv_kahan = oeq.TensorProductConv(problem, torch_op=True, deterministic=True, kahan=True)
Z = tp_conv_kahan.forward(X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm)
print(torch.norm(Z))
```

## Installation
We currently support Linux systems only.
Expand Down Expand Up @@ -172,6 +182,7 @@ python tests/benchmark.py -o outputs/uvu uvu --plot
python tests/benchmark.py -o outputs/uvw uvw --plot
python tests/benchmark.py -o outputs/roofline roofline --plot
python tests/benchmark.py -o outputs/conv conv --plot --data data/molecular_structures
python tests/benchmark.py -o outputs/kahan_conv kahan_conv --data data/molecular_structures/
```

If your GPU has limited memory, you might want to try
Expand Down
39 changes: 22 additions & 17 deletions openequivariance/benchmark/ConvBenchmarkSuite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,22 @@ def __init__(self, configs,
num_warmup = 10,
num_iter = 30,
reference_impl=None,
torch_op=True,
test_name=None,
prng_seed = 12345):
prng_seed = 12345,
correctness_threshold = 1e-5):
self.configs = configs
self.num_warmup = num_warmup
self.num_iter = num_iter
self.reference_impl = reference_impl
self.prng_seed = 12345
self.correctness_threshold = 1e-5
self.torch_op = torch_op
self.correctness_threshold = correctness_threshold
self.exp_count = 0
self.test_name = test_name

self.millis_since_epoch = round(time.time() * 1000)

def run(self, graph, implementations, direction, output_folder=None, correctness=True, double_backward_correctness=False, benchmark=True):
def run(self, graph, implementations, direction, output_folder=None,
correctness=True, benchmark=True, high_precision_ref=False):
if output_folder is None:
if oeq._check_package_editable():
output_folder = oeq._editable_install_output_path / f"{self.millis_since_epoch}"
Expand All @@ -65,20 +65,15 @@ def run(self, graph, implementations, direction, output_folder=None, correctness
for impl in implementations:
tc_name = f"{config}, {impl.name()}"
logger.info(f'Starting {tc_name}, graph {graph.name}, {direction}')
conv = impl(config, torch_op=self.torch_op)

if double_backward_correctness:
double_backward_correctness = conv.test_correctness_double_backward(self.graph,
thresh=self.correctness_threshold,
prng_seed=self.prng_seed,
reference_implementation=self.reference_impl)
conv = impl(config)

if direction == "forward":
if correctness:
correctness = conv.test_correctness_forward(graph,
thresh=self.correctness_threshold,
prng_seed=self.prng_seed,
reference_implementation=self.reference_impl)
reference_implementation=self.reference_impl,
high_precision_ref=high_precision_ref)

if benchmark:
benchmark = conv.benchmark_forward(self.num_warmup,
Expand All @@ -90,23 +85,33 @@ def run(self, graph, implementations, direction, output_folder=None, correctness
correctness = conv.test_correctness_backward(graph,
thresh=self.correctness_threshold,
prng_seed=self.prng_seed,
reference_implementation=self.reference_impl)
reference_implementation=self.reference_impl,
high_precision_ref=high_precision_ref)

if benchmark:
benchmark = conv.benchmark_backward(self.num_warmup,
self.num_iter, graph, prng_seed=12345)

if direction == "double_backward":
if correctness:
correctness = conv.test_correctness_double_backward(self.graph,
thresh=self.correctness_threshold,
prng_seed=self.prng_seed,
reference_implementation=self.reference_impl,
high_precision_ref=high_precision_ref)

assert not benchmark

result = {
"config": str(config),
"irrep_dtype": str(config.irrep_dtype),
"weight_dtype": str(config.weight_dtype),
"torch_overhead_included": self.torch_op,
"torch_overhead_included": conv.torch_op,
"direction": direction,
"graph": graph.name,
"name": impl.name(),
"correctness": correctness,
"benchmark": benchmark,
"double_backward_correctness": double_backward_correctness
"benchmark": benchmark
}

fname = pathlib.Path(f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json")
Expand Down
32 changes: 27 additions & 5 deletions openequivariance/implementations/ComputationSchedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from openequivariance.implementations.TensorProductBase import *
logger = getLogger()

class SMEMCapacityException(Exception):
    """Raised when a tensor-product segment cannot fit in the available shared memory budget."""

    def __init__(self, message):
        # Initialize the base Exception first so str(exc) reports the message,
        # then keep a named attribute for callers that read exc.message directly.
        super().__init__(message)
        self.message = message

class IrrepMapping:
'''
Maps irreps from a source to a destination set.
Expand Down Expand Up @@ -104,7 +109,7 @@ def create_schedule_case2(instructions, memory_per_warp, calculate_smem, directi
segments.append((cL1, cL2, cL3, cinst))
cL3, cinst = set(), []
else:
raise Exception(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!")
raise SMEMCapacityException(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!")
else:
cL3.add(w)
cinst.append(inst_idx)
Expand All @@ -130,7 +135,7 @@ def create_schedule_case3(instructions, memory_per_warp, calculate_smem, directi
segments.append((cL1, cL2, cL3, cinst))
cL1, cL2, cL3, cinst = set(), set(), set(), []
else:
raise Exception(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!")
raise SMEMCapacityException(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!")
else:
cL1.add(u)
cL2.add(v)
Expand Down Expand Up @@ -245,10 +250,15 @@ def __init__(self,
weight_dtype,
include_scratch=False,
stream_weights=False,
schedule_type=2):
schedule_type=2,
kahan=False):
'''
smem_limit: size of available shared memory in bytes
'''
self.kahan = kahan
if kahan:
assert irrep_dtype == weight_dtype == np.float32

# Note: does not work with variances for irreps; easy to add that in
self.total_warps = warps_per_block * block_count

Expand Down Expand Up @@ -288,10 +298,16 @@ def calculate_forward_smem(L1_set, L2_set, L3_set, inst_idxs):
"L1": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
"L2": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
"L3": {"size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
"L3_kahan": {"size": 0, "dtype": self.irrep_dtype_cstr},
"weights": {"size": 0, "dtype": self.weight_dtype_cstr},
"scratch": {"size": 0, "dtype": self.weight_dtype_cstr}
}

if kahan:
smem["L3_kahan"]["size"] = smem["L3"]["size"]
else:
smem.pop("L3_kahan")

weights_smem = 0
for inst_idx in inst_idxs:
inst = self.new_instructions[inst_idx]
Expand Down Expand Up @@ -325,6 +341,7 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs,
smem = {
"L1": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
"L1_grad": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
"L1_kahan": {"size": 0, "dtype": self.irrep_dtype_cstr},
"L2": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
"L2_grad": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
"L3_grad": {"size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
Expand All @@ -333,6 +350,11 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs,
"scratch": {"size": 0, "dtype": self.weight_dtype_cstr}
}

if kahan:
smem["L1_kahan"]["size"] = smem["L1"]["size"]
else:
smem.pop("L1_kahan")

if L2_dgrad:
smem["L2_dgrad"] = {"size": smem["L2"]["size"], "dtype": self.irrep_dtype_cstr}

Expand Down Expand Up @@ -376,11 +398,11 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs,
schedule2_succeeded = False
try:
if schedule_type != 2:
raise Exception("Asked for schedule case 3.")
raise SMEMCapacityException("Asked for schedule case 3.")
self.segments = create_schedule_case2(self.new_instructions, self.memory_per_warp, calculate_smem, direction)
logger.info(f"{direction.title()} case 2 scheduling succeeded with {len(self.segments)} segments.")
schedule2_succeeded = True
except Exception as e:
except SMEMCapacityException as e:
self.segments = create_schedule_case3(self.new_instructions, self.memory_per_warp, calculate_smem, direction)
logger.info(f"{direction.title()} case 3 scheduling succeeded with {len(self.segments)} segments.")

Expand Down
Loading