
Out of place broadcasting creates unexpected results #278

Closed
ToucheSir opened this issue Apr 13, 2022 · 7 comments · Fixed by #279

Comments


ToucheSir commented Apr 13, 2022

MWE from Slack:

using Enzyme, LinearAlgebra, Zygote

x = Float32[3, 2, 1]
w = Float32[1, 2, 3]
dw = zero(w)

loss(w, x) = sum(w .* x)

Enzyme.autodiff(loss, Active, Duplicated(w, dw), Const(x))
zdw = gradient(w -> loss(w, x), w) |> only
@show w, dw, zdw
@assert isapprox(dw, zdw; atol=1e-3)

Output:

(w, dw, zdw) = (Float32[1.0, 2.0, 3.0], Float32[3.0, 2.0, 1.0], Float32[4.0, 4.0, 4.0])
ERROR: AssertionError: isapprox(dw, zdw; atol = 0.001)
...
wsmoses (Member) commented Apr 13, 2022

Just for fun, can you also print w?

ToucheSir (Author) commented

I didn't save the original seed, but this reproduces with fixed inputs as well, so I've updated the MWE to match.

wsmoses (Member) commented Apr 13, 2022

This looks like what was fixed in EnzymeAD/Enzyme#604.

I'll try it locally to see if it resolves this. If so, I'll do a JLL bump with it and the fix for #277 once that's ready.

wsmoses (Member) commented Apr 15, 2022

Reducing a bit:

using Enzyme
Enzyme.API.printall!(true)
x = Float32[3]
w = Float32[1]
dw = zero(w)

# loss(w, x) = @inbounds (w .* x)[1]

loss(w, x) = @inbounds Base.materialize(Base.broadcasted(*,w,x))[1]

@show w, dw, x
Enzyme.autodiff(loss, Active, Duplicated(w, dw), Const(x))

@show w, dw, x
# Output: (w, dw, x) = (Float32[1.0], Float32[3.0], Float32[4.0])

Incidentally, the bug here is not that Enzyme computes the wrong derivative (dw comes out right) but that activity analysis somewhere incorrectly believes x to be active and +='s its derivative into x itself, mutating the Const input.
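As a plain-Julia cross-check of that explanation against the numbers in the original MWE (no Enzyme or Zygote involved): for sum(w .* x), the gradient with respect to w is x and the gradient with respect to x is w. If the x-derivative is added into x itself, x becomes x .+ w, and a Zygote call made afterwards differentiates against that mutated x. The sketch below only replays this arithmetic under that assumption; the variable names are illustrative.

```julia
# Inputs from the original MWE
w = Float32[1, 2, 3]
x = Float32[3, 2, 1]

# Analytic gradients of loss(w, x) = sum(w .* x)
grad_w = copy(x)   # d/dw: what Enzyme correctly wrote into dw
grad_x = copy(w)   # d/dx: should be discarded, since x is Const

# Assumed bug: the d/dx contribution is +='d into x itself
x_after = x .+ grad_x

# Zygote runs after Enzyme, so its d/dw is just the mutated x
zygote_grad_w = copy(x_after)

@show grad_w x_after zygote_grad_w
```

This reproduces dw = [3, 2, 1] and zdw = [4, 4, 4] from the output above; in the one-element reduction, x = [3] and w = [1] likewise give x = [4].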

As a workaround, making x Duplicated should get you the desired outcome. The bug also requires the broadcast; the same computation works fine without it. Looking into it here (https://fwd.gymni.ch/H2S9ke).

@vchuravy, this relates to our earlier conversation.

wsmoses (Member) commented Apr 15, 2022

Reducing further:

using Enzyme
Enzyme.API.printall!(true)
x = Float32[3]
w = Float32[1]
dw = zero(w)

# loss(w, x) = @inbounds (w .* x)[1]

#loss(w, x) = @inbounds Base.materialize(Base.broadcasted(*,w,x))[1]

function loss0(w, x)
    mid = Float32[0.0]
    mid[1] = w[1] * x[1]
    mid[1]
end

function loss(w, x)
    r = Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(*, (w, x), axes(w))

    dest = Float32[0]

    bcc = Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(r.f, (Base.Broadcast.extrude(w), Base.Broadcast.extrude(Base.Broadcast.broadcast_unalias(dest, x))), axes(w))
    # preprocess_args(dest, r.args), r.axes)
    @inbounds bcc[1]
end

@show w, dw, x
Enzyme.autodiff(loss, Active, Duplicated(w, dw), Const(x))

@show w, dw, x
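The hand-lowered loss above is meant to compute the same primal value as (w .* x)[1]. That can be checked in plain Julia without Enzyme, which separates the derivative problem from the lowering itself. A sketch: lowered, materialized, and dotted are illustrative names, and Broadcasted, DefaultArrayStyle, extrude, and broadcast_unalias are unexported Base.Broadcast internals whose details may change between Julia versions.

```julia
using Base.Broadcast: Broadcasted, DefaultArrayStyle, extrude, broadcast_unalias

w = Float32[1]
x = Float32[3]

# Same lowering as the reduced `loss` above, minus Enzyme
function lowered(w, x)
    r = Broadcasted{DefaultArrayStyle{1}}(*, (w, x), axes(w))
    dest = Float32[0]
    # broadcast_unalias returns x itself here, since dest is a fresh array
    bcc = Broadcasted{DefaultArrayStyle{1}}(r.f, (extrude(w), extrude(broadcast_unalias(dest, x))), axes(w))
    @inbounds bcc[1]
end

# The two higher-level spellings from the earlier comments
materialized(w, x) = @inbounds Base.materialize(Base.broadcasted(*, w, x))[1]
dotted(w, x) = @inbounds (w .* x)[1]

@show lowered(w, x) materialized(w, x) dotted(w, x)
```

All three return 3.0f0 for these inputs, so the primal is unaffected; only the Enzyme-generated derivative code misbehaves.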

https://enzyme.mit.edu/explorer/#z:OYLghAFBqd5TKALEBjA9gEwKYFFMCWALugE4A0BIEAZgQDbYB2AhgLbYgDkAjF%2BTXRMiAZVQtGIHgBYBQogFUAztgAKAD24AGfgCsp5eiyah69AG5tyKxqiIEh1ZpgDC6egFc2TKTwDM5M4AMgRM2AByXgBG2KQgAEw85AAO6ErEDkxunt6%2BAanp9kIhYZFsMXGJ1ti2RUwiRCykRNlePjz%2B1bWZDU1EJRHRsQlJSo3NrbkdAWN9A2UVIwCU1ugepKicXACkfgBCANQAslgejACSACIHu9dgXETY6kT321oAgkprG9gA%2BnSMVgcG5%2Ba7beLxR7PcHxN7vcbAbBEA6YFiNIwATzWyNuNwh2AAtGwQITkvEAOxaEB%2BeLU%2BIEsnknh0ukMim0gBs0hAXIJBC5POkBJoAA4qTx4iKCUwRVIOSzeSIJVKmFQeOLmRKpH4YXCEUiDkRSARkowQWCIeoRRzfrzkqgCfRQh51ATgEwPLqPnDdocAGIeJh2TIHeFGpQgA6kbAsTBMITYOE4VBGaM3cl7bbky4AKjzBwAArozgQWAA6RFEX7JYCoWaoADWECWB3o6HE9F%2BgaB2Ewv1jmFIeL8Wh93v2BwDQbqoaI4cj8cDAHdQpgozG4wmk9gU01sAd40wPWZ05nsznC8WneXUqFHqR/qR0GxfugorpozQIFmz9cB6QlGSFhNggDolhzFs2w7Lsj3YXt%2B0wQdhx4Md3mTVN9x/LM/0QgCgJA9VwMvTsmlIFgMV%2BDBkgxb8M2wg5/0A4DsFArRwMg9sJBgns%2B3/VDfUnQNgyEWd5wOUJgM2JR0iiRgODYIR6AxbddzTeMJFLJRT3oxj8JYwiLyLEizHbftSDIiieEwWjfwY3CmIItic3IcSuQ46Duzg3jcOHWFvQ%2BdC9xRL5figiQDnMdACDXAtfl%2BZgAC8MQ4fsPBIQgaC/MtsqWVCcDoMIIqimLMCUABHDw9xsnT7L01iiMPY96AYp13QOHgORRWJsBobqgxjWSWOkNi8QAVi0FysOzOzBwc/SnIPIQmpagg2o6rrPz6zYWEGiBhpbcFRp4dyuM8jhvKQrCPmYI0MRAOEDgODtmsi6KDggbKyxbWL4qYJKUpYNKsAITKIAIEULyiYhxDGd6aDbNF3qmnDZrqgzJro6bdOY%2BqIMMq9S1CtIlF%2BDpyWkMtQjCUgyx4Q10HEiGVm0rHapxgzFqPM5mo0tbOpwTbo36nbGD2kbDomlmUbw9mFsa7mVr5jaeq2gbRf2sbjoe9dkVezAfWzPKetCfd4fQRGixLFgiek0meHJymj1iWnqtZ1HZYazjkiIdZ93lk9o1jRSMUVph2v57reqF7bdo1iWMds7HHM9rmT15sP1oFlXo7VobxfiI6Ts7M74P/Yc/G067hFIO7tcO%2BJzSeiRmuRvN8atiskWrWt6ybXKPkew6K9xZ7Fo0lgtP9lvMelubceI/sTNQMyLNJ6zkZmmXk4vRcmHyze5ph5FXdzcThBIb2hzBrl2r8ABOcl4i0O%2BX/tjktA/%2BmpYgr%2BN6T%2Bb2IuX5NIRagZ97HRcmAMAmAojAAOFA8kdcC4gNxFDIgR8pYHzRgtIeX8zaI3/qBPwRFCHo3gdA2B5D7ZINGqNRu/8MH4ORKQ4hF5SE4OQXghGzC2YEVYVg9mx1nLkJgXAqB1CB5jU6riBGa4mECL4eBFy8iWEkN4fpIRY1RouXTgcEUkCKFiLAPbAxRAogsBYFQu%2BBip7kOkNrMYZBTbcIOFoMsH8PFaETPEPYH9lEuNUReQ6HIdGtTDrIERlDxHklMeYyxUD4j2MkYdck9D1EYL/uo%2Beh06YkEwaQoRBjRHkIpDQkUaT3abAyTPBRACgkFwrnkzJlSNFKMiUY0pySC5lnjDiUE4lUBsGSAcbAZV8lZLAvU0a%2Bi
xoxPaSUxBkiohDgIHTQ6PT0BEBckYGILcC4E2tm2W2ZMKZU2dk8Yg2ydo1DGkEHgo1KZFKiWATp7w4R3IeVQR63yfm/L%2Bf8gFgKBLJGjKVRuh0bo1xoXfRusjGZAIhmfOc6BL7vWAbfB%2BT8X53zfloUanU8ngwgqE1aYcZlQOKVAvwSRyFmIsXY%2BIBjQjmCaKWYQZY4VQKSe8QeBd1T0L3ozMa1i9FPKMdIblvKHm9JkBsvp1wCCDOGaMoVOTJaSwpc8iV2tlniTWd0mVFNelXN2bc7pBATU3MOgcm2JMTmO2pmWC5Wz5lcu5e881kZAXep9b6kEhwQW9i0riQ6HzKY0I6I3UetjW6Xg7reYQsRHzPlfO%2BT8J9amgU0YdEUB14jUrFXY0aEaG6oOhhPHEM8pmRsJRDYRmrxXFq6UdYe/S4VEoRfmHJjLQ6itdWAaQ2jaVxKsTYpaCsuURpQf06N46TyxstteMsCb7zJpfG%2BD8PUM0FKIikvNBb%2B2DojXQst6CK3f01iA2teZC1cqbTyzW0i23mzXB2xmXa%2BVDt0eSwxRbYn0vESKqBtjJ3NrlfbRuiqhkjLGUSzWARH23rAKNPyD7dWrLGuBuZOyrUFyCH4ckjzWzXL2aNG1Ry7X21OU7GmzqkPv1QvhwjXy/WsbY/6g4gawUhrw%2BaiNqSR7N05stZpW86kL1IuRSiKKaKibnhzGNfLmYNvIRyFCSyhw4dI%2BR4mdsHZnNo%2BoS5/aGP%2BXeDp45VGHXnKM0QL17HfnAtBcG/pobzWTTw/cojELq4Yg86NJj4bm0ss8H8ZISACCU0buFgg4yWnZNGocHJczQ0Eai9sUalxJqJc1v5yFIcMtZZuDlnJ/mw2xcK9lpLfKyvmuK5cCN5TT3VMThMjh40uHmx4fF/waietOXo/eqVPAYW4gYeelREz%2BHsN3Xy8peTJs9emxMwpJmhtjSfrCl9BxFtiaIW03b8nls9ezXykVuiOqDf/fE4xQGwAgYHTQxIW3Yw7e4f4rrmbetTKfiStqESVNQI5EOqBdKbuJKe6W/paCWs1X67N0aIWPBhYi1FhbASpt9b22QwHYByZPdbbPPSjCMdLax/J9rEJOsEMx2wlbbTcf4%2BbYkl7cjSfY9YR9mnZO6cnYR/mv7ZKkPk2u6O8hD3JUbZPc%2B17TCufdY5/ziJuiAe/oQarsHJTJf1yfdcGgbAzhvc%2B/XGl9cQdq7x4sh90ZkTyJ1wbBrZmBJTmEmHMMAFIxNGAPJYOi0aDRj9mkDEQZQFMBXIKlcZgbfrCYCpDCRU3oFjMJYDlIMkQEA4GWWYzRl1aHBtfEBGe2Be87YtcQ3tfb7vdeOf0QkZzu4jAxUg3vsAKSYEpP3AfFpKGD6gUP4e1yR/oNH0gseAo7nj3rQsye2Cp96vYTPzhc/5/RUXkv76y8sAr9GKvqE0RGgIFENK2Bg35q0OCjM65YyHkwviP6yU/gSWDOYbxpbDYfH38aI/jxT/Uov4cXeAfK/TcQqbCOET/Q/Y/X/KHMES/J/KSGSOSVvX3MAj/OcL/KA8uf/PEeIEFN8QkXuXUUEGEUIJ0MIAkCeNgGEdMR3d4CA7/E/YcadWAw4L3H3dvEOeMf3bAQPHvEPQA1cA4IfEfMOVAt5D4KBGfMsBSTAM4bAMseGFgYAFzO4MAH8KBDVYxMQyQiwWfURMsVADwRuKBdQl5AxYtd/cQ94DQ4wtQjMAgGkA4HtEwiES4JcJoGgA4AANViHSCEBhCAUcPsUsPEVsJ/AcIbmcLUNcOwCPzgXOCYEEG8N8MyACPEkcJ1BCJeUbkIDGFCDsHIUuHODcCGQYGwAUFVCIAgCMBMEqkREjEuAAHVfggh3hwgABxX4AAKStmUTKMjCpRcjwNkM2DiBwIOTSIICUAAHlvYM8CAEpexIwjRkcXJSBAwF9sAfCAJMhIxJZW8pi/CmAABpVcSMAYAAFRFhP
2mI4JcmYC8Eby5Rch7CuMGlOL3kjHCATH7msLAEJ0KPOD9DKNoDKJ7EjBhB2jGDIjsEkwxDLF0HoDSMIGjDsDIDuhwLLBhF%2BK5TCJnjhCgWl2uFyPsGnEBJKOSDKIqOIGqOMGADqM4AOCaJaLaM6J6OvD6MkFUyGKfBGOGHGKtkmJmLmLYAWKWMNHWOwDWI2Izy2JSMcFcXuNFOkkyA%2BMwHOJNjeMYBmLuJGQ9DYCeIHReLgi1OwDVK%2BJ%2BIJLAF10BOBNFgBGYDgghIhAAHpvAiAXTLg0QWAXTgBiAXTcB79M9ESXSlBHiyoESkSIQXIUSdwSAa5nT4gsSIQcS8dbCiighOI6hqiTZIxB05klAqJGSoEZlSCTZMB4QBjjEtYJCwAmt%2BkST8jkQoEiiRAPAog8DgAyI2AIBwScD3hzJyJfQ0iyCGwlCIgnT%2BzBzlJoyDhCyUViywARUHSqyRUiAMRkhFz1QXJAI/QjBlCGjzgRBVA9ylDLhjZKiQwswXAmSjyTz9zZiF8xTMAXJAxiAqyh0bcWByzvicBDTpBUzRt%2BkWygSQSHS%2ByYQoh0BNlIzkSCBUT4yMSYRkz4hUz%2BVcQQK2zllsRNSNyWJ1zNz/y0K1kGypjSSCiQLMzxBsyyDGTpAJQdyiyqyGKz5aKKy7MqEAKrSJQciyKmzASsLOzuzezJyYRKxVwnhhzZzRzxzwhRKIRxK95JLZz5zNzmKEMVyqEEMCKtzJZdz9zG9Wz7yzyLyMgRJrzbzjzTzgBHz5jFiXyDg3yOLCS1ikRvywhMBfyT8qyuLaz/B0zQL7SwT5L4g4TYLZzYy0SEzMTsTuLmCUQ%2BKyTKKszMgcywhIxfs5ymKqEh0yyPLKyqEOQ0KiSEq8ikqwBWz2yhL2ARKOBEyKMRzQgxzEQ5K6qBTrxbU9N7EVLsqmpOTFyQlWxcz2o5kdLmK9KiyBhIwTE5zkhrLDK7zrLzyCozLRDyQbyjLrLbLRT7LXzKiPzXLGgfysBvK7E0KbTGzyqMyUqhA0rGSJqFzmLRo0KBNgKKrzgqK0RUraLIw74tAENVKtzSz29yyCrxEay/ieB6ziTEqKL3rBKnwuyaqIKIQVBSScB1ATD9hGqmBmqJy2qYQ0aJLMboie1Ab1L%2BqKbDQ8Lxqdy5qDLDyrL9zlrQhVr0wNrFqHyRTny9r3zyFPy3Ljq/yfK0KgLVDTDMqElJZ0VRwsjNsMK7C9gXD4hdAwcV5yJL40iEkaUZbtCXkSLxaMxlbVa4l1aZyojZzda5aYDyFTCYQTaLEtaXkEMrbaCEkASTCjbSaHaWBdQ5b4rPalaXlzcC4gEb5Za3bg68Sg77a1a4Sj8aAnb4hBrXarTk7o7ja1bURGgk6e1U7ayKQArPqaLhq74/AIlyaSlga2Kwb/i2I07obSryLmz4aqrEbhKUb4hsLYwj4YJx4lApKe0ZKWqQru7MBe7Ax%2B60jK6Ellz%2BiSk1yaaqE9L6alCFqmaTKVqZwLLNquanzdrHL9r%2BbDr3LewvKiK06xbbSwLgqCaIQx6j5wqe1IrELEyULUyRxeKyq4brrqLvrhqerHqcqgEQb8rnLjEiqrS/8Fbf6vrbqfqDgCNGKgGqU87QHexa7y6P6bbLqf6jy270AkaezO7Kw8DNzmhlJsbpKmrZKQrSGnxyH1zp7sqqUNL57WGXIxrl66b5rGbjLgAWbLzzL1rLL%2BHtqebD6%2BaXKdZT7PKTqL7ay/APb3q7SWJwKQq9gJ55DES4KEL0S37YrFH4rcGW7YGS70qDhaEAaWH/jcr0H2KqyCMP6SqTGBKCGiHarGSSClA0EB6qGh6aGR6774gpjfHmGUH/jBrNKqVBquHxEV7eHRGlrTLt6RHd6lDxGD6nKDqZGhbTquUP6bSQLVHQTAQQrPxughAn6Yz4K4z9GYqUyoHXqYbv7TGPqbqmA7rprpAlHkG1LyE/Bq7QbwH
74P7G7XHML3GO6QrRTgAkB0RNJB7tlAn8avGIRZn5mp7AH%2BnBidt2H/jOGl74meGGaknmaUmry0nOaMnuasmj7pGvy8mFG/j74v7m7ATi7/6LH/A74ZkZ6B1pb7Ha76LUzho3n%2BLJmOz27kaQrJ7rw/HMxqHcbaHgm4XFntnFy2GuTdm4njEEnTn0mBGLnhGOaN6bLbnexebwGBajqPLz6RarT6LwWrr2m/74HS7y6%2BnFzJQQGa6Rn67azpAbbA6uUENRWw6QFpab4UIsiemM7va1aMAmAxhc6JXXE9aJUi6Omun2oenrGInBneXhnHG75QWSrkq2XOmEHgcuWfKU6gXwHpBc1GWLrYa2mEbCHpnUXlWWBeppMmAX9mglmhrkWgm1n4hAwlBfW/glXA2iBwmdmB05lomk3DnE3tzZrEnCXBG2ad7rnyX97KXJHqWT6nmGXBXmnr6gqyngmV1nYdGIramoqkKIR37GXxm3WPntWEGOguRbW7ERU8qMHqWBW/jpAr6JnW6oXPWYXa2opE1SBg3h7VnEy63F2MXHHKaBm03dKTm16%2BHkmt7LnSWxGKWHLsnj7cm6X5Hy2/iOtcRJ2zGvnGSZBRoCybGjojWwHHHyRUyjpmW8GPWPHO7s7vzSol2VnWrw3QPooB6N3t29nsWDnqb038X92znN7WbUmT2tqz2qWcnHnr3hazqrSUMtXLWdXpAOQ/mP2XaHWPznrSOATH38Hp3gOQrbhOP/HlnQ2V2cCuOGt4PCSIkU3RoIlcWM39L0Ps3iW1qcO967Ki2L2HnBaiP8mB0/2A6VGb6a3w28CFI6hqmEq9HorkLDG72XHO2LW4GrWAHyQaOIn8Uv3h2PznXaz8UAP3WpnZ3w2saGtuOQ28aoPEy/OE3Fy32t3CTRqjm8W92DyMOiWj2SWEvMmlP7nkNS21PnmouAqSn1HgmFyyJ4ylAjOX76mzPGn3OO3Wmu2KOe2JUHrE3RpB36OBnf3SOJ2rOp3qriGQrfgpjcjdB5342Avl3gucD%2BulBBvhuwvN3EPMWd3abM2CX82c3sOUu8Pi2CPVOz6b2SPaz35yObOdWfmHPE3hpnOHG7EIagcDbavjuEG/AOQRV/mOQ0G%2BWqyORiErS3vPO3G2OvXw2qJKHEWAnePxuYRgfZutLIvjFtKYvJPV74uZOku5ONvC3z30uaXZH6X9u/ivu/vrPzGX36KK6bG%2B3WLjWBnTWfvjGuvKqAefPEy0WJ4ofRvIPYXWB4WoehPkOU34fUO4v17%2BG1vj30fFPMepGMur3dviOCmfvzX3rPn2WLHrQaV/mWKh2rugdIGDvXWauifn3IwvvBrXu5ktfa7qPUyORK2WOgPAfEzfgH6K1fhFKMaIPwe%2BvnexhXf08lL1BofZ7Yf4hF7BflvpPVvZP2bxedq0upfsey28egdqv3nDeVfGSORM/%2B2gcWuPvyFKRrfOuDfuvoXevgnff0bJL2fPfy%2B3flKyabGQ/g/Q/d3w/kfI/Ufo/CXUvJeS2Ze5G5eNOrTKRCelfu3hquQG%2BImbfLva6mRUymRR%2Bn30/Ixmv32InH5Z/wHyRvvazH4l/WOevPHEy6%2BA/q%2Bgu6G/f3feem/5uqzb%2BJO0P2%2ByXRfkvu/NvlPpfCPZf1PfK/id%2BD%2ByvWzhY0STZ9jEdjPPuIl17/8tOy/YAYyTV5gDNerXBBIxz36Wdi%2BcAnVibzAEz9Ke37fPtAIQT69U%2BY/OrhPzwGvdc%2BVPZPgv0rZp94BeZLQGTw35DMCBJZUdgghT4QsS%2BM7MvuGxzAe8L%2BwTQQbz3JBz0kO4gxbtwzb7C9D2WHMXu/wx74dL23/Afr/wX5X1imOnR0sEyYSlcm2r9BpqhStJihAB4/Cxjb0a6LkRQOtFAXWRMG1lbBB/e3kzwm7e8qwp/eKCFkTrn8UW4bJ3k%2BB7ou8vB2AHwYH
xeQSDuWLfJblJ2f4i8o%2BebMlj3xUEqdaWP/bLg4PMHkDLBPTXAeb3sHz8uAKwegNwFGj8AfAXAHQOQHQDcAggQQLwkcAODnAAASnOW%2BCbBkI/AOzFUOKErAGwIAckByDLAih8UTrLQMnQlT4okgpQrgNIAqHaB%2BAtQrgPwAjATQehOgFYHAFgAoAMApRRgBQCoAQA9hlJA4SABRREAJQBIRKA/iJDuUBADAe8BGAgBRBFh5AKGKwATK8B%2BAewjgMIFuJKQ3hOAYvCYEkC9DyAL9AgC/gjDgingO4Y/G8LvA1A3hToZZE0AxBuAcAbwg/MSG%2BErBFCyhLwgQGwBLhZizAbgHwDkDCAxAEgTgDICpGKAVAGgN4foBmF0kQAM%2BQwIfgjCQAVgFwzIDCNWE1A4yuxCAM4EmDtBAge8eYEMDiBJACgq1CUQYAVF1AZR5QYYKMGFGu5egEwdwG0AMA2ARRQgHUf0BNgLANR1gcYC0D1FTBLRcwM0bKKkArAvg6wTYAYCNA8EKRxQwwGUIWHgjlhFw9qA3AgC0IOg0gXsC2AgD1DGhzQloS5HwDEAyAyEFyBSTKJDhwQ1KJYN0MWFbCUA8FSgE4EIDxlzg6pJIIIGpEdg6RsgcsYyLUCaBwRrIroEaJ8Bii94SopIMEAdHqi5RKQNIIqJtGSiVRmQNUYsE1GVN6gVo9sU2O1FWiRxFo7PNaJyCSiFxc4uUUsBKG%2BjyAlQ6ocsOjFNDWh3hYkUuFiAHBAxPFEMUdBkARiIACY%2BMsmIOCpiDhyELMeQA2F9DyAAw/NLTB6ZHQH4fgWhD0z%2BqyBZh8wrcW8OWGrCQA6wnMRuK4DxB%2BAxIL7mWESCjQaQI2GQMDio4xJtxSw7gNmN6FbDEAEAFANiGSBpQCxxw58KcItG9hExPYmsTSMkD0iaxygOsdiIDxejyAS4MiMkC9GwTyhYE/0dwGmJpQyJyIdAJ4TPHBjQxV4zAJGMfEniMxx0fCZsJWBIANwwwZsD6LmEITqQww%2BzlhLFCjDy6qEgIDhJqF4TrAUE18TBO0l%2BA/RO4yyW%2BPXHkBA2RxEANICAA%3D%3D

wsmoses (Member) commented Apr 15, 2022

using Enzyme
Enzyme.API.printall!(true)
x = Float32[3]
w = Float32[1]
dw = zero(w)

function loss(w, x, cond)
    if cond
        x = copy(x)
    end
    @inbounds w[1] * x[1]
end

@show w, dw, x
Enzyme.autodiff(loss, Active, Duplicated(w, dw), Const(x), Const(false))

@show w, dw, x
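In this version the cond branch is a conditional copy of x. That has the same shape as Base.unalias, which broadcast_unalias(dest, x) calls whenever dest !== x in the broadcast reductions: it returns its argument unchanged unless it might alias dest, in which case it returns a copy. The resemblance is an editorial reading, not something stated in the thread. The plain-Julia check below (no Enzyme; pick_x is a hypothetical helper, and Base.unalias and Base.mightalias are unexported internals) shows that with cond = false, x is still the caller's array, so anything accumulated into it lands in the caller's memory, as seen in the outputs above.

```julia
x = Float32[3]
dest = Float32[0]

# Hypothetical helper with the same branch structure as loss(w, x, cond)
pick_x(x, cond) = cond ? copy(x) : x

# cond == false: the very same array object the caller passed in
same_obj = pick_x(x, false) === x
# cond == true: a fresh array with equal contents
fresh_copy = pick_x(x, true) !== x && pick_x(x, true) == x

# Base.unalias: x is returned as-is when it cannot alias dest,
# and an unaliased copy is returned when it might (x against itself)
no_alias_passthrough = !Base.mightalias(dest, x) && Base.unalias(dest, x) === x
alias_copies = Base.unalias(x, x) !== x

@show same_obj fresh_copy no_alias_passthrough alias_copies
```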

wsmoses (Member) commented Apr 16, 2022

An activity-analysis fix in Enzyme proper resolves the minimal case; still investigating the one below:

using Enzyme
Enzyme.API.printall!(true)
Enzyme.API.printactivity!(true)
x = Float32[3]
w = Float32[1]
dw = zero(w)

loss2(w, x, c) = @inbounds Base.materialize(Base.broadcasted(*,w,x))[1]

function loss(w, x, cond)
    r = Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(*, (w, x), axes(w))

    # m = Base.materialize(r)
    # @inbounds m[1]
    dest = Float32[0]

    bcc = Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(r.f, (Base.Broadcast.extrude(Base.Broadcast.broadcast_unalias(dest, w)), Base.Broadcast.extrude(Base.Broadcast.broadcast_unalias(dest, x))), axes(w))
    # preprocess_args(dest, r.args), r.axes)
    dest[1] = bcc[1]
    @inbounds dest[1]
end

@show w, dw, x
Enzyme.autodiff(loss, Active, Duplicated(w, dw), Const(x), Const(false))

@show w, dw, x
