diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index 663bfe6..e8f1756 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,5 +1,5 @@ -source-sha: 8d73de367a7f160dac777aa557f1c26069f84ea5 -synced-at: "2026-04-12" +source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f +synced-at: "2026-04-13" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 diff --git a/.translate/state/numba.md.yml b/.translate/state/numba.md.yml index fbbdfa3..c7d5b0a 100644 --- a/.translate/state/numba.md.yml +++ b/.translate/state/numba.md.yml @@ -1,5 +1,5 @@ -source-sha: be6eeaee8db0c8bfea65b89d57ca8aecf7f96dff -synced-at: "2026-04-12" +source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f +synced-at: "2026-04-13" model: claude-sonnet-4-6 mode: UPDATE section-count: 5 diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index 93adba6..34ec88f 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,5 +1,5 @@ -source-sha: 94dd7d22385ec46d740db1fc2cddf05c29377594 -synced-at: "2026-04-12" +source-sha: 95378b8382b4dbd1cd3e0ffe0e152811894c357f +synced-at: "2026-04-13" model: claude-sonnet-4-6 mode: UPDATE section-count: 3 diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 96b45e2..fdd359b 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -31,10 +31,9 @@ translation: Random numbers::Why explicit random state?::NumPy's approach: رویکرد NumPy Random numbers::Why explicit random state?::JAX's approach: رویکرد JAX JIT Compilation: کامپایل JIT - JIT Compilation::Evaluating a more complicated function: ارزیابی یک تابع پیچیده‌تر - JIT Compilation::Evaluating a more complicated function::With NumPy: با NumPy - JIT Compilation::Evaluating a more complicated function::With JAX: با JAX - JIT Compilation::Compiling the whole function: کامپایل کل تابع + JIT Compilation::With NumPy: با NumPy + JIT Compilation::With JAX: با JAX + JIT Compilation::Compiling the Whole Function: کامپایل کل تابع JIT Compilation::How JIT compilation works: نحوه کار کامپایل JIT JIT Compilation::Compiling non-pure functions: کامپایل توابع غیرخالص Vectorization with `vmap`: برداری‌سازی با `vmap` @@ -638,11 +637,7 @@ random_sum_jax(key) ما قدرت کامپایلر JIT JAX را در ترکیب با سخت‌افزار موازی {ref}`در بالا ` مشاهده کردیم، هنگامی که `cos` را روی یک آرایه بزرگ اعمال کردیم. -بیایید همان کار را با یک تابع پیچیده‌تر امتحان کنیم. - -### ارزیابی یک تابع پیچیده‌تر - -تابع زیر را در نظر بگیرید +بیایید همان کار را با یک تابع پیچیده‌تر امتحان کنیم: ```{code-cell} def f(x): @@ -650,7 +645,7 @@ def f(x): return y ``` -#### با NumPy +### با NumPy ابتدا با NumPy امتحان خواهیم کرد @@ -665,7 +660,7 @@ with qe.Timer(): y = f(x) ``` -#### با JAX +### با JAX اکنون بیایید دوباره با JAX امتحان کنیم. @@ -701,7 +696,7 @@ with qe.Timer(): نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در اجرای دوم پس از کامپایل JIT. -علاوه بر این، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم *کل* تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد. +علاوه بر این، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد. ### کامپایل کل تابع diff --git a/lectures/numba.md b/lectures/numba.md index 35c1fa7..a421c01 100644 --- a/lectures/numba.md +++ b/lectures/numba.md @@ -130,7 +130,7 @@ n = 10_000_000 with qe.Timer() as timer1: # Time Python base version - x = qm(0.1, int(n)) + x = qm(0.1, n) ``` @@ -158,7 +158,7 @@ qm_numba = jit(qm) ```{code-cell} ipython3 with qe.Timer() as timer2: # Time jitted version - x = qm_numba(0.1, int(n)) + x = qm_numba(0.1, n) ``` این یک افزایش سرعت قابل توجه است. @@ -170,7 +170,7 @@ with qe.Timer() as timer2: ```{code-cell} ipython3 with qe.Timer() as timer3: # Second run - x = qm_numba(0.1, int(n)) + x = qm_numba(0.1, n) ``` در اینجا میزان افزایش سرعت نشان داده شده است: diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index b3b9709..8578b55 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -148,13 +148,13 @@ for x in grid: در اینجا از `np.meshgrid` برای ایجاد شبکه‌های ورودی دوبعدی `x` و `y` استفاده می‌کنیم به گونه‌ای که `f(x, y)` تمام ارزیابی‌ها را روی شبکه حاصلضرب تولید می‌کند. -(این استراتژی به MATLAB بازمی‌گردد.) +(این استراتژی به Matlab بازمی‌گردد.) ```{code-cell} ipython3 grid = np.linspace(-3, 3, 3_000) x, y = np.meshgrid(grid, grid) -with qe.Timer(precision=8): +with qe.Timer(): z_max_numpy = np.max(f(x, y)) print(f"NumPy result: {z_max_numpy:.6f}") @@ -177,13 +177,17 @@ def compute_max_numba(grid): for x in grid: for y in grid: z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > m: - m = z + m = max(m, z) return m +``` +بیایید آن را آزمایش کنیم: + +```{code-cell} ipython3 grid = np.linspace(-3, 3, 3_000) -with qe.Timer(precision=8): +with qe.Timer(): + # First run z_max_numba = compute_max_numba(grid) print(f"Numba result: {z_max_numba:.6f}") @@ -192,13 +196,16 @@ print(f"Numba result: {z_max_numba:.6f}") بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود. ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run compute_max_numba(grid) ``` -بسته به دستگاه شما، نسخه Numba می‌تواند کمی کندتر یا کمی سریعتر از NumPy باشد. +بسته به دستگاه شما، نسخه Numba ممکن است کندتر یا سریعتر از NumPy باشد. -از یک طرف، NumPy محاسبات کارآمد (مانند Numba) را با مقداری چندنخی (برخلاف این کد Numba) ترکیب می‌کند که مزیتی فراهم می‌کند. +در اکثر موارد، Numba کمی بهتر است. + +از یک طرف، NumPy محاسبات کارآمد را با مقداری چندنخی ترکیب می‌کند که مزیتی فراهم می‌کند. از طرف دیگر، روال Numba از حافظه بسیار کمتری استفاده می‌کند، زیرا ما فقط با یک شبکه یک‌بعدی کار می‌کنیم. @@ -206,43 +213,6 @@ with qe.Timer(precision=8): حالا بیایید موازی‌سازی با Numba را با استفاده از `prange` امتحان کنیم: -در اینجا یک تلاش ساده و *نادرست* آمده است. - -```{code-cell} ipython3 -@numba.jit(parallel=True) -def compute_max_numba_parallel(grid): - n = len(grid) - m = -np.inf - for i in numba.prange(n): - for j in range(n): - x = grid[i] - y = grid[j] - z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > m: - m = z - return m - -``` - -این `-inf` برمی‌گرداند --- مقدار اولیه `m`، انگار که هرگز به‌روزرسانی نشده است: - -```{code-cell} ipython3 -z_max_parallel_incorrect = compute_max_numba_parallel(grid) -print(f"Numba result: {z_max_parallel_incorrect} 😱") -``` - -برای درک چرایی این موضوع، به یاد بیاورید که `prange` حلقه بیرونی را بین نخ‌ها تقسیم می‌کند. - -هر نخ یک نسخه خصوصی از `m` دارد که با مقدار `-np.inf` مقداردهی اولیه شده و آن را در بازه تکرارهای خود به درستی به‌روزرسانی می‌کند. - -اما در پایان حلقه، Numba باید نسخه‌های هر نخ از `m` را در یک مقدار واحد ترکیب کند --- یک **تقلیل (reduction)**. - -برای الگوهایی که تشخیص می‌دهد، مانند `m += z` (جمع) یا `m = max(m, z)` (max)، Numba عملگر ترکیب را می‌شناسد. - -اما الگوی `if z > m: m = z` را به عنوان یک تقلیل max تشخیص نمی‌دهد، بنابراین نتایج هر نخ هرگز ترکیب نمی‌شوند و `m` مقدار اولیه خود را حفظ می‌کند. - -ساده‌ترین راه‌حل جایگزینی شرط با `max` است که Numba آن را می‌شناسد: - ```{code-cell} ipython3 @numba.jit(parallel=True) def compute_max_numba_parallel(grid): @@ -257,36 +227,21 @@ def compute_max_numba_parallel(grid): return m ``` -یک روش جایگزین این است که بدنه حلقه را بین `i` ها کاملاً مستقل کنیم و تقلیل را خودمان انجام دهیم: +در اینجا یک اجرای گرم‌کننده و آزمایش آمده است. ```{code-cell} ipython3 -@numba.jit(parallel=True) -def compute_max_numba_parallel_v2(grid): - n = len(grid) - row_maxes = np.empty(n) - for i in numba.prange(n): - row_max = -np.inf - for j in range(n): - x = grid[i] - y = grid[j] - z = np.cos(x**2 + y**2) / (1 + x**2 + y**2) - if z > row_max: - row_max = z - row_maxes[i] = row_max - return np.max(row_maxes) -``` - -در اینجا هر نخ به یک عنصر جداگانه از `row_maxes` می‌نویسد، بنابراین تقلیل را خودمان از طریق `np.max` انجام می‌دهیم. +with qe.Timer(): + # First run + z_max_parallel = compute_max_numba_parallel(grid) -```{code-cell} ipython3 -z_max_parallel = compute_max_numba_parallel(grid) print(f"Numba result: {z_max_parallel:.6f}") ``` -در اینجا زمان‌بندی آمده است. +در اینجا زمان‌بندی برای نسخه از پیش کامپایل شده آمده است. ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run compute_max_numba_parallel(grid) ``` @@ -300,7 +255,7 @@ with qe.Timer(precision=8): اما تفاوت‌هایی نیز وجود دارد که در اینجا آنها را برجسته می‌کنیم. -بیایید با تابع شروع کنیم. +بیایید با تابع شروع کنیم که `np` را به `jnp` تغییر می‌دهد و `jax.jit` را اضافه می‌کند. ```{code-cell} ipython3 @@ -315,9 +270,15 @@ def f(x, y): ```{code-cell} ipython3 grid = jnp.linspace(-3, 3, 3_000) x_mesh, y_mesh = jnp.meshgrid(grid, grid) +``` + +حالا بیایید اجرا و زمان‌بندی کنیم -with qe.Timer(precision=8): +```{code-cell} ipython3 +with qe.Timer(): + # First run z_max = jnp.max(f(x_mesh, y_mesh)) + # Hold interpreter z_max.block_until_ready() print(f"Plain vanilla JAX result: {z_max:.6f}") @@ -326,8 +287,10 @@ print(f"Plain vanilla JAX result: {z_max:.6f}") بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود. ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): + # Second run z_max = jnp.max(f(x_mesh, y_mesh)) + # Hold interpreter z_max.block_until_ready() ``` @@ -337,7 +300,7 @@ with qe.Timer(precision=8): ### JAX به علاوه vmap -یک مشکل با کد NumPy و کد JAX فوق وجود دارد: +یک مشکل با کد NumPy و کد JAX وجود دارد: در حالی که آرایه‌های تخت حافظه کمی دارند @@ -371,7 +334,7 @@ f_vec = jax.vmap(f_vec_x) بیایید زمان‌بندی را ببینیم: ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): z_max = jnp.max(f_vec(grid)) z_max.block_until_ready() @@ -379,7 +342,7 @@ print(f"JAX vmap v1 result: {z_max:.6f}") ``` ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): z_max = jnp.max(f_vec(grid)) z_max.block_until_ready() ``` @@ -419,7 +382,7 @@ def compute_max_vmap(grid): بیایید آن را امتحان کنیم. ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): z_max = compute_max_vmap(grid).block_until_ready() print(f"JAX vmap result: {z_max:.6f}") @@ -428,7 +391,7 @@ print(f"JAX vmap result: {z_max:.6f}") بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود: ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): z_max = compute_max_vmap(grid).block_until_ready() ``` @@ -473,14 +436,14 @@ def qm(x0, n, α=4.0): ```{code-cell} ipython3 n = 10_000_000 -with qe.Timer(precision=8): +with qe.Timer(): x = qm(0.1, n) ``` بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود: ```{code-cell} ipython3 -with qe.Timer(precision=8): +with qe.Timer(): x = qm(0.1, n) ``` @@ -499,7 +462,7 @@ Numba این عملیات ترتیبی را به طور بسیار کارآمد ```{code-cell} ipython3 cpu = jax.devices("cpu")[0] -@partial(jax.jit, static_argnums=(1,), device=cpu) +@partial(jax.jit, static_argnames=('n',), device=cpu) def qm_jax(x0, n, α=4.0): def update(x, t): x_new = α * x * (1 - x) @@ -512,27 +475,27 @@ def qm_jax(x0, n, α=4.0): این کد خواندن آسانی ندارد اما، در اصل، `lax.scan` به طور مکرر `update` را فراخوانی می‌کند و بازگشت‌های `x_new` را در یک آرایه جمع می‌کند. ```{note} -خوانندگان تیزبین متوجه خواهند شد که ما `device=cpu` را در decorator `jax.jit` مشخص می‌کنیم. - -محاسبه از بسیاری عملیات ترتیبی کوچک تشکیل شده است که فرصت کمی برای بهره‌برداری GPU از موازی‌سازی باقی می‌گذارد. - -در نتیجه، سربار راه‌اندازی kernel تمایل دارد روی GPU غالب شود و CPU را متناسب‌تر برای این بار کاری می‌کند. - -خوانندگان کنجکاو می‌توانند حذف این گزینه را امتحان کنند تا ببینند چگونه عملکرد تغییر می‌کند. +ما `device=cpu` را در decorator `jax.jit` مشخص می‌کنیم زیرا این محاسبه از بسیاری عملیات ترتیبی کوچک تشکیل شده است که فرصت کمی برای بهره‌برداری GPU از موازی‌سازی باقی می‌گذارد. در نتیجه، سربار راه‌اندازی kernel تمایل دارد روی GPU غالب شود و CPU را متناسب‌تر برای این بار کاری می‌کند. ``` بیایید آن را با همان پارامترها زمان‌بندی کنیم: ```{code-cell} ipython3 -with qe.Timer(precision=8): - x_jax = qm_jax(0.1, n).block_until_ready() +with qe.Timer(): + # First run + x_jax = qm_jax(0.1, n) + # Hold interpreter + x_jax.block_until_ready() ``` بیایید دوباره اجرا کنیم تا سربار کامپایل حذف شود: ```{code-cell} ipython3 -with qe.Timer(precision=8): - x_jax = qm_jax(0.1, n).block_until_ready() +with qe.Timer(): + # Second run + x_jax = qm_jax(0.1, n) + # Hold interpreter + x_jax.block_until_ready() ``` JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است. @@ -569,12 +532,11 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد کد طبیعی و خوانا است --- صرفاً یک حلقه پایتون با یک decorator --- و کارایی آن عالی است. -JAX می‌تواند مسائل ترتیبی را از طریق `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است. +JAX می‌تواند مسائل ترتیبی را از طریق `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است و برای کارهای کاملاً ترتیبی، بهره‌وری اضافی ناچیز است. + +با این حال، `lax.scan` یک مزیت مهم دارد: از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کند، که Numba قادر به انجام آن نیست. -```{note} -یک مزیت مهم `lax.scan` این است که از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کند، که Numba قادر به انجام آن نیست. اگر نیاز دارید از طریق یک محاسبه ترتیبی مشتق بگیرید (مثلاً محاسبه حساسیت‌های یک مسیر نسبت به پارامترهای مدل)، JAX علی‌رغم نحو کمتر طبیعی‌اش، انتخاب بهتری است. -``` در عمل، بسیاری از مسائل ترکیبی از هر دو الگو هستند.