Skip to content

Commit 4e9fda6

Browse files
committed
%cd: Safeguard more uses of inline declared tensors in declaring other tensors
-- `.value`, `.grad`, `.merge` cases.
1 parent 20e5c0a commit 4e9fda6

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

lib/ppx_cd.ml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,13 @@ let guess_pun_hint ~punned filler_typ filler =
242242
Hashtbl.find punned name
243243
| (Tensor | Unknown), { pexp_desc = Pexp_ident _; _ } -> Some (hint, true)
244244
| (Tensor | Unknown), _ -> Some (hint, false)
245+
| ( ( Value_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
246+
| Grad_of_tensor { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
247+
| Merge_value { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ }
248+
| Merge_grad { pexp_desc = Pexp_ident { txt = Lident name; _ }; _ } ),
249+
_ )
250+
when Hashtbl.mem punned name ->
251+
Hashtbl.find punned name
245252
| (Value_of_tensor t | Grad_of_tensor t | Merge_value t | Merge_grad t), _ -> (
246253
let hint = [%expr [%e t].Tensor.value.Arrayjit.Tnode.label] in
247254
match t with { pexp_desc = Pexp_ident _; _ } -> Some (hint, true) | _ -> Some (hint, false))

0 commit comments

Comments
 (0)