|
(* [get_ident] is [Low_level.get_ident_within_code] partially applied to
   [~no_dots:true] and to the array of [llc] fields taken from the first
   component of every pair in [B.procs]. Elsewhere in this file it is applied
   as [get_ident tn], and the result is printed with a %s directive, i.e. it
   yields the identifier string used for a tensor node in the generated code.
   NOTE(review): the partial application is evaluated once, when this binding
   is evaluated; whether any per-call work remains depends on
   [Low_level.get_ident_within_code] -- confirm there. *)
32 | 32 | let get_ident = |
33 | 33 | Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun (l, _) -> l.llc) |
34 | 34 |
|
(* [in_ctx tn] asks the [Tn] module whether the tensor node [tn] is in a
   context, passing along [use_host_memory] as resolved under the local open
   of [B].
   Removed row: called [Tn.is_in_context_force], whose result was used as a
   plain boolean (see the removed [if in_ctx tn then ...] rows further down).
   The literal 341 is presumably a call-site identifier for diagnostics --
   TODO confirm against [Tn.is_in_context_force].
   Added row: calls [Tn.is_in_context], whose result is consumed as an
   optional boolean by the added rows further down: [Some true], [Some false],
   and [None] when the answer is not yet known. Call sites that need a
   definite answer default the unknown case to [true] via
   [Option.value ~default:true]. *)
35 | | - let in_ctx tn = B.(Tn.is_in_context_force ~use_host_memory tn 341) |
| 35 | + let in_ctx tn = B.(Tn.is_in_context ~use_host_memory tn) |
36 | 36 |
|
(* [pp_zero_out ppf tn] prints on the formatter [ppf] a C statement of the
   form memset(IDENT, 0, SIZE); where IDENT is [get_ident tn] and SIZE is
   [Tn.size_in_bytes tn] printed as a decimal integer. The statement is
   wrapped in a Format box with indentation 2 and is followed by a
   break hint, so the next item starts after a space or a line break.
   Nothing is zeroed here: only the text of the statement is emitted. *)
37 | 37 | let pp_zero_out ppf tn = |
38 | 38 | Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn |
@@ -72,15 +72,16 @@ struct |
72 | 72 | if not @@ Hash_set.mem is_global tn then |
73 | 73 | let ctx_ptr = B.hardcoded_context_ptr in |
74 | 74 | let mem : (Tn.memory_mode * int) option = tn.memory_mode in |
75 | | - match (in_ctx tn, ctx_ptr, ctx_arrays, B.use_host_memory, mem) with |
76 | | - | true, Some get_ptr, Some ctx_arrays, _, _ -> |
| 75 | + match (in_ctx tn, ctx_ptr, ctx_arrays, mem) with |
| 76 | + | Some true, Some get_ptr, Some ctx_arrays, _ -> |
77 | 77 | let ident = get_ident tn in |
78 | 78 | let ctx_array = |
79 | 79 | Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays tn |
80 | 80 | in |
81 | 81 | fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array (Lazy.force tn.prec); |
82 | 82 | Hash_set.add is_global tn |
83 | | - | false, _, _, true, Some (Hosted _, _) -> |
| 83 | + | Some false, _, _, Some (Hosted _, _) |
| 84 | + when B.(Tn.known_shared_with_host ~use_host_memory tn) -> |
84 | 85 | let nd = Option.value_exn ~here:[%here] @@ Lazy.force tn.array in |
85 | 86 | fprintf ppf "#define %s (%s)@," (get_ident tn) (Ndarray.c_ptr_to_string nd); |
86 | 87 | Hash_set.add is_global tn |
@@ -294,14 +295,19 @@ struct |
294 | 295 | (* A rough approximation to the type Gccjit_backend.mem_properties. *) |
295 | 296 | let backend_info = |
296 | 297 | Sexp.Atom |
297 | | - (if in_ctx tn then "Ctx" |
298 | | - else if Hash_set.mem is_global tn then "Host" |
| 298 | + (if Hash_set.mem is_global tn then "Host" |
299 | 299 | else if Tn.is_virtual_force tn 3331 then "Virt" |
300 | | - else "Local") |
| 300 | + else |
| 301 | + match in_ctx tn with |
| 302 | + | Some true -> "Ctx" |
| 303 | + | Some false -> "Local" |
| 304 | + | None -> "Unk") |
301 | 305 | in |
302 | 306 | if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then |
303 | 307 | tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info; |
304 | | - if in_ctx tn && not (Hash_set.mem is_global tn) then |
| 308 | + (* We often don't know ahead of linking with relevant contexts what the stream sharing |
| 309 | + mode of the node will become. Conservatively, use passing as argument. *) |
| 310 | + if Option.value ~default:true (in_ctx tn) && not (Hash_set.mem is_global tn) then |
305 | 311 | (B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params |
306 | 312 | else params) |
307 | 313 | in |
@@ -367,7 +373,12 @@ struct |
367 | 373 | params); |
368 | 374 | fprintf ppf "/* Local declarations and initialization. */@ "; |
369 | 375 | Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node -> |
370 | | - if not (Tn.is_virtual_force tn 333 || in_ctx tn || Hash_set.mem is_global tn) then |
| 376 | + if |
| 377 | + not |
| 378 | + (Tn.is_virtual_force tn 333 |
| 379 | + || Option.value ~default:true (in_ctx tn) |
| 380 | + || Hash_set.mem is_global tn) |
| 381 | + then |
371 | 382 | fprintf ppf "%s %s[%d]%s;@ " |
372 | 383 | (B.typ_of_prec @@ Lazy.force tn.prec) |
373 | 384 | (get_ident tn) (Tn.num_elems tn) |
|
0 commit comments