-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
272 additions
and
0 deletions.
There are no files selected for viewing
142 changes: 142 additions & 0 deletions
142
enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_mod1_split.ll
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s | ||
|
||
;#include <cblas.h> | ||
; | ||
;extern double __enzyme_autodiff(void *, double *, double *, double *, double *, double *, double*, double, double); | ||
; | ||
;void outer(double *Afoo, double *Bfoo, double *Cfoo, double alpha, double beta) { | ||
; cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 4, 3, 2, alpha, Afoo, 4, Bfoo, 2, beta, Cfoo, 4); | ||
;} | ||
; | ||
;void g(double *A, double *B, double *C, double alpha, double beta) { | ||
; outer(A, B, C, alpha, beta); | ||
; B[0] = 10; | ||
;} | ||
; | ||
;int main() { | ||
; double A[] = {0.11, 0.12, 0.13, 0.14, | ||
; 0.21, 0.22, 0.23, 0.24}; | ||
; double B[] = {1011, 1012, | ||
; 1021, 1022, | ||
; 1031, 1032}; | ||
; double C[] = {0.00, 0.00, 0.00, 0.00, | ||
; 0.00, 0.00, 0.00, 0.00, | ||
; 0.00, 0.00, 0.00, 0.00}; | ||
; double A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; | ||
; double B1[] = {0, 0, 0, 0, 0, 0}; | ||
; double C1[] = {1, 3, 7, 11, | ||
; 0.00, 0.00, 0.00, 0.00, | ||
; 0.00, 0.00, 0.00, 0.00}; | ||
; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 3.0); | ||
;} | ||
|
||
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" | ||
target triple = "x86_64-unknown-linux-gnu" | ||
|
||
@__const.main.A = private unnamed_addr constant [8 x double] [double 1.100000e-01, double 1.200000e-01, double 1.300000e-01, double 1.400000e-01, double 2.100000e-01, double 2.200000e-01, double 2.300000e-01, double 2.400000e-01], align 16 | ||
|
||
define dso_local void @outer(double* %Afoo, double* %Bfoo, double* %Cfoo, double %alpha, double %beta) { | ||
entry: | ||
call void @cblas_dgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, double %alpha, double* %Afoo, i32 4, double* %Bfoo, i32 2, double %beta, double* %Cfoo, i32 4) | ||
ret void | ||
} | ||
|
||
declare dso_local void @cblas_dgemm(i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32, double, double*, i32) | ||
|
||
define dso_local void @g(double* %A, double* %B, double* %C, double %alpha, double %beta) { | ||
entry: | ||
call void @outer(double* %A, double* %B, double* %C, double %alpha, double %beta) | ||
%arrayidx = getelementptr inbounds double, double* %B, i64 0 | ||
store double 1.000000e+01, double* %arrayidx, align 8 | ||
ret void | ||
} | ||
|
||
define dso_local i32 @main() { | ||
entry: | ||
%A = alloca [8 x double], align 16 | ||
%B = alloca [6 x double], align 16 | ||
%C = alloca [12 x double], align 16 | ||
%A1 = alloca [8 x double], align 16 | ||
%B1 = alloca [6 x double], align 16 | ||
%C1 = alloca [12 x double], align 16 | ||
%0 = bitcast [8 x double]* %A to i8* | ||
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x double]* @__const.main.A to i8*), i64 64, i1 false) | ||
%1 = bitcast [6 x double]* %B to i8* | ||
call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 48, i1 false) | ||
%2 = bitcast i8* %1 to [6 x double]* | ||
%3 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 0 | ||
store double 1.011000e+03, double* %3, align 16 | ||
%4 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 1 | ||
store double 1.012000e+03, double* %4, align 8 | ||
%5 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 2 | ||
store double 1.021000e+03, double* %5, align 16 | ||
%6 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 3 | ||
store double 1.022000e+03, double* %6, align 8 | ||
%7 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 4 | ||
store double 1.031000e+03, double* %7, align 16 | ||
%8 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 5 | ||
store double 1.032000e+03, double* %8, align 8 | ||
%9 = bitcast [12 x double]* %C to i8* | ||
call void @llvm.memset.p0i8.i64(i8* align 16 %9, i8 0, i64 96, i1 false) | ||
%10 = bitcast [8 x double]* %A1 to i8* | ||
call void @llvm.memset.p0i8.i64(i8* align 16 %10, i8 0, i64 64, i1 false) | ||
%11 = bitcast [6 x double]* %B1 to i8* | ||
call void @llvm.memset.p0i8.i64(i8* align 16 %11, i8 0, i64 48, i1 false) | ||
%12 = bitcast [12 x double]* %C1 to i8* | ||
call void @llvm.memset.p0i8.i64(i8* align 16 %12, i8 0, i64 96, i1 false) | ||
%13 = bitcast i8* %12 to <{ double, double, double, double, [8 x double] }>* | ||
%14 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 0 | ||
store double 1.000000e+00, double* %14, align 16 | ||
%15 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 1 | ||
store double 3.000000e+00, double* %15, align 8 | ||
%16 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 2 | ||
store double 7.000000e+00, double* %16, align 16 | ||
%17 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 3 | ||
store double 1.100000e+01, double* %17, align 8 | ||
%arraydecay = getelementptr inbounds [8 x double], [8 x double]* %A, i32 0, i32 0 | ||
%arraydecay1 = getelementptr inbounds [8 x double], [8 x double]* %A1, i32 0, i32 0 | ||
%arraydecay2 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 0 | ||
%arraydecay3 = getelementptr inbounds [6 x double], [6 x double]* %B1, i32 0, i32 0 | ||
%arraydecay4 = getelementptr inbounds [12 x double], [12 x double]* %C, i32 0, i32 0 | ||
%arraydecay5 = getelementptr inbounds [12 x double], [12 x double]* %C1, i32 0, i32 0 | ||
%call = call double @__enzyme_autodiff(i8* bitcast (void (double*, double*, double*, double, double)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3, double* %arraydecay4, double* %arraydecay5, double 2.000000e+00, double 3.000000e+00) | ||
ret i32 0 | ||
} | ||
|
||
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1) | ||
|
||
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) | ||
|
||
declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*, double*, double*, double, double) | ||
|
||
;CHECK:define internal { double*, double* } @augmented_outer(double* %Afoo, double* %"Afoo'", double* %Bfoo, double* %"Bfoo'", double* %Cfoo, double* %"Cfoo'", double %alpha, double %beta) { | ||
;CHECK-NEXT:entry: | ||
;CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 mul (i64 ptrtoint (double* getelementptr (double, double* null, i32 1) to i64), i64 8)) | ||
;CHECK-NEXT: %0 = bitcast i8* %malloccall to double* | ||
;CHECK-NEXT: %1 = bitcast double* %Afoo to i8* | ||
;CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %malloccall, i8* %1, i32 8, i1 false) | ||
;CHECK-NEXT: %malloccall2 = tail call i8* @malloc(i64 mul (i64 ptrtoint (double* getelementptr (double, double* null, i32 1) to i64), i64 6)) | ||
;CHECK-NEXT: %2 = bitcast i8* %malloccall2 to double* | ||
;CHECK-NEXT: %3 = bitcast double* %Bfoo to i8* | ||
;CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %malloccall2, i8* %3, i32 6, i1 false) | ||
;CHECK-NEXT: %4 = insertvalue { double*, double* } undef, double* %0, 0 | ||
;CHECK-NEXT: %5 = insertvalue { double*, double* } %4, double* %2, 1 | ||
;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, double %alpha, double* nocapture readonly %Afoo, i32 4, double* nocapture readonly %Bfoo, i32 2, double %beta, double* %Cfoo, i32 4) | ||
;CHECK-NEXT: ret { double*, double* } %5 | ||
;CHECK-NEXT:} | ||
|
||
;CHECK:define internal { double, double } @diffeouter(double* %Afoo, double* %"Afoo'", double* %Bfoo, double* %"Bfoo'", double* %Cfoo, double* %"Cfoo'", double %alpha, double %beta, { double*, double* }) { | ||
;CHECK-NEXT:entry: | ||
;CHECK-NEXT: %1 = extractvalue { double*, double* } %0, 0 | ||
;CHECK-NEXT: %2 = extractvalue { double*, double* } %0, 1 | ||
;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 112, i32 4, i32 2, i32 3, double %alpha, double* nocapture readonly %"Cfoo'", i32 4, double* nocapture readonly %2, i32 3, double 1.000000e+00, double* %"Afoo'", i32 4) | ||
;CHECK-NEXT: %3 = bitcast double* %1 to i8* | ||
;CHECK-NEXT: tail call void @free(i8* %3) | ||
;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 112, i32 111, i32 2, i32 3, i32 4, double %alpha, double* nocapture readonly %1, i32 4, double* nocapture readonly %"Cfoo'", i32 4, double 1.000000e+00, double* %"Bfoo'", i32 2) | ||
;CHECK-NEXT: %4 = bitcast double* %2 to i8* | ||
;CHECK-NEXT: tail call void @free(i8* %4) | ||
;CHECK-NEXT: call void @cblas_dscal(i32 12, double %beta, double* %"Cfoo'", i32 1) | ||
;CHECK-NEXT: ret { double, double } zeroinitializer | ||
;CHECK-NEXT:} | ||
|
||
;CHECK:declare void @cblas_dscal(i32, double, double*, i32) |
130 changes: 130 additions & 0 deletions
130
enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_mod1_split.ll
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s | ||
|
||
;#include <cblas.h> | ||
; | ||
;extern float __enzyme_autodiff(void *, float *, float *, float *, float *, float *, float*, float, float); | ||
; | ||
;void outer(float *Afoo, float *Bfoo, float *Cfoo, float alpha, float beta) { | ||
; cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 4, 3, 2, alpha, Afoo, 4, Bfoo, 2, beta, Cfoo, 4); | ||
;} | ||
; | ||
;void g(float *A, float *B, float *C, float alpha, float beta) { | ||
; outer(A, B, C, alpha, beta); | ||
; B[0] = 10; | ||
;} | ||
; | ||
;int main() { | ||
; float A[] = {0.11, 0.12, 0.13, 0.14, | ||
; 0.21, 0.22, 0.23, 0.24}; | ||
; float B[] = {1011, 1012, | ||
; 1021, 1022, | ||
; 1031, 1032}; | ||
; float C[] = {0.00, 0.00, 0.00, 0.00, | ||
; 0.00, 0.00, 0.00, 0.00, | ||
; 0.00, 0.00, 0.00, 0.00}; | ||
; float A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; | ||
; float B1[] = {0, 0, 0, 0, 0, 0}; | ||
; float C1[] = {1, 3, 7, 11, | ||
; 0.00, 0.00, 0.00, 0.00, | ||
; 0.00, 0.00, 0.00, 0.00}; | ||
; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 3.0); | ||
;} | ||
|
||
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" | ||
target triple = "x86_64-unknown-linux-gnu" | ||
|
||
@__const.main.A = private unnamed_addr constant [8 x float] [float 0x3FBC28F5C0000000, float 0x3FBEB851E0000000, float 0x3FC0A3D700000000, float 0x3FC1EB8520000000, float 0x3FCAE147A0000000, float 0x3FCC28F5C0000000, float 0x3FCD70A3E0000000, float 0x3FCEB851E0000000], align 16 | ||
@__const.main.B = private unnamed_addr constant [6 x float] [float 1.011000e+03, float 1.012000e+03, float 1.021000e+03, float 1.022000e+03, float 1.031000e+03, float 1.032000e+03], align 16 | ||
|
||
define dso_local void @outer(float* %Afoo, float* %Bfoo, float* %Cfoo, float %alpha, float %beta) { | ||
entry: | ||
call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, float %alpha, float* %Afoo, i32 4, float* %Bfoo, i32 2, float %beta, float* %Cfoo, i32 4) | ||
ret void | ||
} | ||
|
||
declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32) | ||
|
||
define dso_local void @g(float* %A, float* %B, float* %C, float %alpha, float %beta) { | ||
entry: | ||
call void @outer(float* %A, float* %B, float* %C, float %alpha, float %beta) | ||
%arrayidx = getelementptr inbounds float, float* %B, i64 0 | ||
store float 1.000000e+01, float* %arrayidx, align 4 | ||
ret void | ||
} | ||
|
||
define dso_local i32 @main() { | ||
entry: | ||
%A = alloca [8 x float], align 16 | ||
%B = alloca [6 x float], align 16 | ||
%C = alloca [12 x float], align 16 | ||
%A1 = alloca [8 x float], align 16 | ||
%B1 = alloca [6 x float], align 16 | ||
%C1 = alloca [12 x float], align 16 | ||
%0 = bitcast [8 x float]* %A to i8* | ||
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x float]* @__const.main.A to i8*), i64 32, i1 false) | ||
%1 = bitcast [6 x float]* %B to i8* | ||
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %1, i8* align 16 bitcast ([6 x float]* @__const.main.B to i8*), i64 24, i1 false) | ||
%2 = bitcast [12 x float]* %C to i8* | ||
call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 48, i1 false) | ||
%3 = bitcast [8 x float]* %A1 to i8* | ||
call void @llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 32, i1 false) | ||
%4 = bitcast [6 x float]* %B1 to i8* | ||
call void @llvm.memset.p0i8.i64(i8* align 16 %4, i8 0, i64 24, i1 false) | ||
%5 = bitcast [12 x float]* %C1 to i8* | ||
call void @llvm.memset.p0i8.i64(i8* align 16 %5, i8 0, i64 48, i1 false) | ||
%6 = bitcast i8* %5 to <{ float, float, float, float, [8 x float] }>* | ||
%7 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 0 | ||
store float 1.000000e+00, float* %7, align 16 | ||
%8 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 1 | ||
store float 3.000000e+00, float* %8, align 4 | ||
%9 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 2 | ||
store float 7.000000e+00, float* %9, align 8 | ||
%10 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 3 | ||
store float 1.100000e+01, float* %10, align 4 | ||
%arraydecay = getelementptr inbounds [8 x float], [8 x float]* %A, i32 0, i32 0 | ||
%arraydecay1 = getelementptr inbounds [8 x float], [8 x float]* %A1, i32 0, i32 0 | ||
%arraydecay2 = getelementptr inbounds [6 x float], [6 x float]* %B, i32 0, i32 0 | ||
%arraydecay3 = getelementptr inbounds [6 x float], [6 x float]* %B1, i32 0, i32 0 | ||
%arraydecay4 = getelementptr inbounds [12 x float], [12 x float]* %C, i32 0, i32 0 | ||
%arraydecay5 = getelementptr inbounds [12 x float], [12 x float]* %C1, i32 0, i32 0 | ||
%call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float*, float, float)* @g to i8*), float* %arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float* %arraydecay4, float* %arraydecay5, float 2.000000e+00, float 3.000000e+00) | ||
ret i32 0 | ||
} | ||
|
||
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) | ||
|
||
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) | ||
|
||
declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float*, float*, float, float) | ||
|
||
;CHECK:define internal { float*, float* } @augmented_outer(float* %Afoo, float* %"Afoo'", float* %Bfoo, float* %"Bfoo'", float* %Cfoo, float* %"Cfoo'", float %alpha, float %beta) { | ||
;CHECK-NEXT:entry: | ||
;CHECK-NEXT: %malloccall = tail call i8* @malloc(i64 mul (i64 ptrtoint (float* getelementptr (float, float* null, i32 1) to i64), i64 8)) | ||
;CHECK-NEXT: %0 = bitcast i8* %malloccall to float* | ||
;CHECK-NEXT: %1 = bitcast float* %Afoo to i8* | ||
;CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %malloccall, i8* %1, i32 8, i1 false) | ||
;CHECK-NEXT: %malloccall2 = tail call i8* @malloc(i64 mul (i64 ptrtoint (float* getelementptr (float, float* null, i32 1) to i64), i64 6)) | ||
;CHECK-NEXT: %2 = bitcast i8* %malloccall2 to float* | ||
;CHECK-NEXT: %3 = bitcast float* %Bfoo to i8* | ||
;CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %malloccall2, i8* %3, i32 6, i1 false) | ||
;CHECK-NEXT: %4 = insertvalue { float*, float* } undef, float* %0, 0 | ||
;CHECK-NEXT: %5 = insertvalue { float*, float* } %4, float* %2, 1 | ||
;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, float %alpha, float* nocapture readonly %Afoo, i32 4, float* nocapture readonly %Bfoo, i32 2, float %beta, float* %Cfoo, i32 4) | ||
;CHECK-NEXT: ret { float*, float* } %5 | ||
;CHECK-NEXT:} | ||
|
||
;CHECK:define internal { float, float } @diffeouter(float* %Afoo, float* %"Afoo'", float* %Bfoo, float* %"Bfoo'", float* %Cfoo, float* %"Cfoo'", float %alpha, float %beta, { float*, float* }) { | ||
;CHECK-NEXT:entry: | ||
;CHECK-NEXT: %1 = extractvalue { float*, float* } %0, 0 | ||
;CHECK-NEXT: %2 = extractvalue { float*, float* } %0, 1 | ||
;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 112, i32 4, i32 2, i32 3, float %alpha, float* nocapture readonly %"Cfoo'", i32 4, float* nocapture readonly %2, i32 3, float 1.000000e+00, float* %"Afoo'", i32 4) | ||
;CHECK-NEXT: %3 = bitcast float* %1 to i8* | ||
;CHECK-NEXT: tail call void @free(i8* %3) | ||
;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 112, i32 111, i32 2, i32 3, i32 4, float %alpha, float* nocapture readonly %1, i32 4, float* nocapture readonly %"Cfoo'", i32 4, float 1.000000e+00, float* %"Bfoo'", i32 2) | ||
;CHECK-NEXT: %4 = bitcast float* %2 to i8* | ||
;CHECK-NEXT: tail call void @free(i8* %4) | ||
;CHECK-NEXT: call void @cblas_sscal(i32 12, float %beta, float* %"Cfoo'", i32 1) | ||
;CHECK-NEXT: ret { float, float } zeroinitializer | ||
;CHECK-NEXT:} | ||
|
||
;CHECK:declare void @cblas_sscal(i32, float, float*, i32) |