@@ -195,16 +195,26 @@ let%op depthwise_separable_conv2d ~label ?(kernel_size = 3) ?(stride = 1) () x =
195195
196196(* * Max pooling for 2D spatial data - reduces spatial dimensions by taking maximum values. *)
197197let % op max_pool2d ?(stride = 2 ) ?(window_size = 2 ) () x =
198- (* Although there is no kernel participating, we need to set the iteration size. *)
199198 Shape. set_dim wh window_size;
200199 Shape. set_dim ww window_size;
201- x @^^ " ... | stride*oh+wh, stride*ow+ww, ..c.. => ... | oh, ow, ..c.." [ " wh" ; " ww" ]
200+ (* NOTE: projections inference runs per-assignment in a distinct phase from shape inference, so
201+ for it to know about the window size, we use a constant kernel = 1 to propagate the shape.
202+ We use a trick to create a shape-inferred constant tensor, equivalently we could write
203+ "NTDSL.term ~fetch_op:(Constant 1.) ()" but that's less concise. See:
204+ https://github.com/ahrefs/ocannl/discussions/381 *)
205+ x
206+ @^+ " ... | stride*oh+wh, stride*ow+ww, ..c..; wh, ww => ... | oh, ow, ..c.." [ " wh" ; " ww" ]
207+ (0.5 + 0.5 )
202208
203209(* * Average pooling for 2D spatial data - reduces spatial dimensions by averaging values. *)
204210let % op avg_pool2d ?(stride = 2 ) ?(window_size = 2 ) () x =
205211 Shape. set_dim wh window_size;
206212 Shape. set_dim ww window_size;
207- let sum = x ++ " ... | stride*oh+wh, stride*ow+ww, ..c.. => ... | oh, ow, ..c.." [ " wh" ; " ww" ] in
213+ let sum =
214+ x
215+ +++ " ... | stride*oh+wh, stride*ow+ww, ..c..; wh, ww => ... | oh, ow, ..c.." [ " wh" ; " ww" ]
216+ (0.5 + 0.5 )
217+ in
208218 sum /. (dim wh *. dim ww)
209219
210220(* * Global average pooling - reduces each feature map to a single value by averaging. Commonly used
0 commit comments