Skip to content

Commit

Permalink
MLIR: complex
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 19, 2024
1 parent e0ced08 commit 5bb3600
Show file tree
Hide file tree
Showing 14 changed files with 269 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class FloatTypeInterface
Value b) const {
return builder.create<arith::AddFOp>(loc, a, b);
}
Value createConjOp(Type self, OpBuilder &builder, Location loc,
Value a) const {
return a;
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
Expand Down Expand Up @@ -72,6 +76,15 @@ class TensorTypeInterface
return builder.create<arith::AddFOp>(loc, a, b);
}

Value createConjOp(Type self, OpBuilder &builder, Location loc,
Value a) const {
auto tenType = self.cast<TensorType>();
auto ET = tenType.getElementType();
auto iface = cast<AutoDiffTypeInterface>(ET);
auto added = iface.createConjOp(builder, loc, a);
return added;
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
Expand Down Expand Up @@ -100,6 +113,11 @@ class IntegerTypeInterface
return builder.create<arith::AddIOp>(loc, a, b);
}

Value createConjOp(Type self, OpBuilder &builder, Location loc,
Value a) const {
return a;
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
Expand Down
7 changes: 7 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ set(LLVM_TARGET_DEFINITIONS MemRefDerivatives.td)
enzyme_tablegen(MemRefDerivatives.inc -gen-mlir-derivatives)
add_public_tablegen_target(MemRefDerivativesIncGen)

set(LLVM_TARGET_DEFINITIONS ComplexDerivatives.td)
enzyme_tablegen(ComplexDerivatives.inc -gen-mlir-derivatives)
add_public_tablegen_target(ComplexDerivativesIncGen)

set(LLVM_TARGET_DEFINITIONS MathDerivatives.td)
enzyme_tablegen(MathDerivatives.inc -gen-mlir-derivatives)
add_public_tablegen_target(MathDerivativesIncGen)
Expand All @@ -41,6 +45,7 @@ add_mlir_library(MLIREnzymeImplementations
LLVMAutoDiffOpInterfaceImpl.cpp
NVVMAutoDiffOpInterfaceImpl.cpp
MemRefAutoDiffOpInterfaceImpl.cpp
ComplexAutoDiffOpInterfaceImpl.cpp
FuncAutoDiffOpInterfaceImpl.cpp
LinalgAutoDiffOpInterfaceImpl.cpp
BuiltinAutoDiffTypeInterfaceImpl.cpp
Expand All @@ -52,6 +57,7 @@ add_mlir_library(MLIREnzymeImplementations
MLIRAutoDiffOpInterfaceIncGen
AffineDerivativesIncGen
ArithDerivativesIncGen
ComplexDerivativesIncGen
LLVMDerivativesIncGen
FuncDerivativesIncGen
NVVMDerivativesIncGen
Expand All @@ -63,6 +69,7 @@ add_mlir_library(MLIREnzymeImplementations
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRFuncDialect
MLIRComplexDialect
MLIRLLVMDialect
MLIRMemRefDialect
MLIREnzymeAutoDiffInterface
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,17 @@ class ConstantFP<string val, string dialect_, string op_, string type_=""> : Ope
string type = type_;
}

class ConjIfComplex<string dialect_, string op_> : Operation</*primal*/1, /*shadow*/0> {
string dialect = dialect_;
string opName = op_;
}

def ResultTypes : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op->getResultTypes()">;

def TypeOf : Operation</*primal*/0, /*shadow*/0> {
}

class ComplexInst<string m> : Inst<m, "complex">;
class ArithInst<string m> : Inst<m, "arith">;
class MathInst<string m> : Inst<m, "math">;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//===- ComplexAutoDiffOpInterfaceImpl.cpp - Interface external model -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR complex dialect.
//
//===----------------------------------------------------------------------===//

#include "Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::enzyme;

namespace {
#include "Implementations/ComplexDerivatives.inc"


class ComplexTypeInterface
: public AutoDiffTypeInterface::ExternalModel<ComplexTypeInterface,
ComplexType> {
public:
Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
auto fltType = self.cast<ComplexType>().getElementType().cast<FloatType>();
mlir::Attribute attrs[2] = {
builder.getFloatAttr(fltType, APFloat(fltType.getFloatSemantics(), 0)),
builder.getFloatAttr(fltType, APFloat(fltType.getFloatSemantics(), 0))
};
return builder.create<complex::ConstantOp>(
loc, self, builder.getArrayAttr(attrs));
}

Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
Value b) const {
return builder.create<complex::AddOp>(loc, a, b);
}
Value createConjOp(Type self, OpBuilder &builder, Location loc, Value a) const {
return builder.create<complex::ConjOp>(loc, a);
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
}

bool isMutable(Type self) const { return false; }
LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
Value val) const {
return failure();
}
};
} // namespace

void mlir::enzyme::registerComplexDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, complex::ComplexDialect *) {
registerInterfaces(context);
ComplexType::attachInterface<ComplexTypeInterface>(*context);
});
}
17 changes: 17 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
include "Common.td"

def : MLIRDerivative<"complex", "AddOp", (Op $x, $y),
[
(DiffeRet),
(DiffeRet),
]
>;

def CMul : ComplexInst<"MulOp">;

def : MLIRDerivative<"complex", "MulOp", (Op $x, $y),
[
(CMul (DiffeRet), $y),
(CMul (DiffeRet), $x)
]
>;
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces(
enzyme::registerNVVMDialectAutoDiffInterface(registry);
enzyme::registerMathDialectAutoDiffInterface(registry);
enzyme::registerMemRefDialectAutoDiffInterface(registry);
enzyme::registerComplexDialectAutoDiffInterface(registry);
enzyme::registerSCFDialectAutoDiffInterface(registry);
enzyme::registerCFDialectAutoDiffInterface(registry);
enzyme::registerLinalgDialectAutoDiffInterface(registry);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry);
void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry);
void registerNVVMDialectAutoDiffInterface(DialectRegistry &registry);
void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry);
void registerComplexDialectAutoDiffInterface(DialectRegistry &registry);
void registerSCFDialectAutoDiffInterface(DialectRegistry &registry);
void registerCFDialectAutoDiffInterface(DialectRegistry &registry);
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ class PointerTypeInterface
llvm_unreachable("TODO");
}

Value createConjOp(Type self, OpBuilder &builder, Location loc,
Value a) const {
llvm_unreachable("TODO");
}

Type getShadowType(Type self, unsigned width) const {
assert(width == 1 && "unsupported width != 1");
return self;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ class MemRefTypeInterface
return self;
}

Value createConjOp(Type self, OpBuilder &builder, Location loc,
Value a) const {
llvm_unreachable("TODO");
}

bool isMutable(Type self) const { return true; }

LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
Expand Down
8 changes: 8 additions & 0 deletions enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> {
/*methodName=*/"getShadowType",
/*args=*/(ins "unsigned":$width)
>,
InterfaceMethod<
/*desc=*/[{
Construct complex conjugate for the given type.
}],
/*retTy=*/"::mlir::Value",
/*methodName=*/"createConjOp",
/*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::Location":$loc, "::mlir::Value":$a)
>,
InterfaceMethod<
/*desc=*/[{
Returns whether the type is mutable in place or not.
Expand Down
57 changes: 56 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "PassDetails.h"
#include "Passes/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -64,13 +65,67 @@ struct SubSimplify : public OpRewritePattern<arith::SubFOp> {
}
};

bool isZero(mlir::Value v) {
ArrayAttr lhs;
matchPattern(v, m_Constant(&lhs));
if (lhs) {
for (auto e : lhs) {
if (!e.cast<FloatAttr>().getValue().isZero())
return false;
}
return true;
}
return false;
}

struct CAddSimplify : public OpRewritePattern<complex::AddOp> {
using OpRewritePattern<complex::AddOp>::OpRewritePattern;

LogicalResult matchAndRewrite(complex::AddOp op,
PatternRewriter &rewriter) const final {

if (isZero(op.getLhs())) {
rewriter.replaceOp(op, op.getRhs());
return success();
}

if (isZero(op.getRhs())) {
rewriter.replaceOp(op, op.getLhs());
return success();
}

return failure();
}
};

struct CSubSimplify : public OpRewritePattern<complex::SubOp> {
using OpRewritePattern<complex::SubOp>::OpRewritePattern;

LogicalResult matchAndRewrite(complex::SubOp op,
PatternRewriter &rewriter) const final {

if (isZero(op.getRhs())) {
rewriter.replaceOp(op, op.getLhs());
return success();
}

if (isZero(op.getLhs())) {
rewriter.replaceOpWithNewOp<complex::NegOp>(op, op.getRhs());
return success();
}

return failure();
}
};

struct MathematicSimplification
: public enzyme::MathematicSimplificationPassBase<
MathematicSimplification> {
void runOnOperation() override {

RewritePatternSet patterns(&getContext());
patterns.insert<AddSimplify, SubSimplify>(&getContext());
patterns.insert<AddSimplify, SubSimplify, CAddSimplify, CSubSimplify>(
&getContext());

GreedyRewriteConfig config;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -55,6 +56,7 @@ int main(int argc, char **argv) {
registry.insert<mlir::LLVM::LLVMDialect>();
registry.insert<mlir::memref::MemRefDialect>();
registry.insert<mlir::async::AsyncDialect>();
registry.insert<mlir::complex::ComplexDialect>();
registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::arith::ArithDialect>();
registry.insert<mlir::cf::ControlFlowDialect>();
Expand Down
21 changes: 21 additions & 0 deletions enzyme/test/MLIR/ReverseMode/csquare.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s

module {
func.func @square(%x: complex<f64>) -> complex<f64> {
%next = complex.mul %x, %x : complex<f64>
return %next : complex<f64>
}

func.func @dsquare(%x: complex<f64>, %dr: complex<f64>) -> complex<f64> {
%r = enzyme.autodiff @square(%x, %dr) { activity=[#enzyme<activity enzyme_active>], ret_activity=[#enzyme<activity enzyme_activenoneed>] } : (complex<f64>, complex<f64>) -> complex<f64>
return %r : complex<f64>
}
}

// CHECK: func.func private @diffesquare(%arg0: complex<f64>, %arg1: complex<f64>) -> complex<f64>
// CHECK-NEXT: %0 = complex.conj %arg1 : complex<f64>
// CHECK-NEXT: %1 = complex.mul %0, %arg0 : complex<f64>
// CHECK-NEXT: %2 = complex.conj %1 : complex<f64>
// CHECK-NEXT: %3 = complex.add %2, %2 : complex<f64>
// CHECK-NEXT: return %3 : complex<f64>
// CHECK-NEXT: }
Loading

0 comments on commit 5bb3600

Please sign in to comment.