Commit f65c702
Fix builtins and random numbers in the metal backend, by Claude Sonnet
Summary 1, by Claude:
I updated the builtins.metal file as follows:
Removed the uint4x32_t wrapper struct - Now using Metal's native uint4 type directly
Updated the arrayjit_threefry4x32 function - Changed signature from uint4x32_t parameters to uint4
Fixed naming inconsistencies - Renamed functions to match the expected patterns:
uint4x32_to_i64_uniform → uint4x32_to_int64_uniform
uint4x32_to_u32_uniform → uint4x32_to_uint32_uniform
uint4x32_to_u64_uniform → uint4x32_to_uint64_uniform
uint4x32_to_i8_uniform → uint4x32_to_byte_uniform
uint4x32_to_u16_uniform_vec → uint4x32_to_uint16_uniform_vec
uint4x32_to_i8_uniform_vec → uint4x32_to_byte_uniform_vec
uint4x32_to_i64_uniform_vec → uint4x32_to_int64_uniform_vec
Added missing functions that exist in the C version:
uint4x32_to_uint16_uniform (scalar version)
uint4x32_to_fp8_uniform (scalar version)
uint4x32_to_uint32_uniform_vec (vectorized version)
uint4x32_to_uint64_uniform_vec (vectorized version)
Updated all function signatures to use uint4 directly instead of the wrapper struct
The changes improve consistency between backends and remove unnecessary indirection. The test failure is due to the neural network using double precision, which Metal doesn't support. To resolve this, you could:
Change the test to use single precision (--ocannl_default_prec=single)
Or test with a different backend that supports double precision
Or configure the model to use single precision from the start
The builtins.metal file is now properly updated and consistent with the naming conventions in ops.ml.
Summary 2, by Claude:
The Metal backend was using simple type casts (uint4)(value) instead of proper bit-preserving conversion functions for the Threefry4x32 random number generator. This meant:
(uint4)(23.0) became uint4(23, 0, 0, 0), keeping only the integer part of the value
It should have been the full 32-bit bit representation of the float
🛠️ Fix Applied
Added the missing conversion functions to builtins.metal.
Updated the Metal backend's convert_precision function to use these bit-preserving conversions.
✅ Result
Before: Metal backend had completely different random sequences → poor training
After: Metal backend matches C backend results → proper training convergence
The fix ensures that the random number generators in both backends receive identical seed bit patterns, producing consistent training behavior across all supported backends.
1 parent dc9ff01, commit f65c702
File tree (3 files changed: +149 additions, -61 deletions)
- arrayjit/lib
- test/training
Lines changed: 117 additions & 56 deletions