Commit a699bc4

author
algo-scope
committed
🎉auto update by Gmeek action
1 parent 834073f commit a699bc4

1 file changed

+14
-14
lines changed

docs/post/fu-xian-Diffusion Policy,-sheng-cheng-shi-de-jue-ce-kuang-jia.html

Lines changed: 14 additions & 14 deletions
@@ -170,31 +170,31 @@ <h3>Command line</h3>
170170
</code></pre>
171171
<h3>Training</h3>
172172
<h4>How the conditioning is added</h4>
173-
<div class="highlight highlight-source-python"><pre class="notranslate"><span class="pl-k">if</span> <span class="pl-s1">self</span>.<span class="pl-s1">obs_as_global_cond</span>:     <span class="pl-c"># true</span>
173+
<div class="highlight highlight-source-python"><pre class="notranslate"><span class="pl-k">if</span> <span class="pl-s1">self</span>.<span class="pl-c1">obs_as_global_cond</span>:     <span class="pl-c"># true</span>
174174
<span class="pl-c"># reshape B, T, ... to B*T</span>
175175
<span class="pl-s1">this_nobs</span> <span class="pl-c1">=</span> <span class="pl-en">dict_apply</span>(<span class="pl-s1">nobs</span>,
176-
<span class="pl-k">lambda</span> <span class="pl-s1">x</span>: <span class="pl-s1">x</span>[:,:<span class="pl-s1">self</span>.<span class="pl-s1">n_obs_steps</span>,...].<span class="pl-en">reshape</span>(<span class="pl-c1">-</span><span class="pl-c1">1</span>,<span class="pl-c1">*</span><span class="pl-s1">x</span>.<span class="pl-s1">shape</span>[<span class="pl-c1">2</span>:]))
177-
<span class="pl-s1">nobs_features</span> <span class="pl-c1">=</span> <span class="pl-s1">self</span>.<span class="pl-en">obs_encoder</span>(<span class="pl-s1">this_nobs</span>)
176+
<span class="pl-k">lambda</span> <span class="pl-s1">x</span>: <span class="pl-s1">x</span>[:,:<span class="pl-s1">self</span>.<span class="pl-c1">n_obs_steps</span>,...].<span class="pl-c1">reshape</span>(<span class="pl-c1">-</span><span class="pl-c1">1</span>,<span class="pl-c1">*</span><span class="pl-s1">x</span>.<span class="pl-c1">shape</span>[<span class="pl-c1">2</span>:]))
177+
<span class="pl-s1">nobs_features</span> <span class="pl-c1">=</span> <span class="pl-s1">self</span>.<span class="pl-c1">obs_encoder</span>(<span class="pl-s1">this_nobs</span>)
178178
<span class="pl-c"># reshape back to B, Do</span>
179-
<span class="pl-s1">global_cond</span> <span class="pl-c1">=</span> <span class="pl-s1">nobs_features</span>.<span class="pl-en">reshape</span>(<span class="pl-s1">batch_size</span>, <span class="pl-c1">-</span><span class="pl-c1">1</span>)</pre></div>
179+
<span class="pl-s1">global_cond</span> <span class="pl-c1">=</span> <span class="pl-s1">nobs_features</span>.<span class="pl-c1">reshape</span>(<span class="pl-s1">batch_size</span>, <span class="pl-c1">-</span><span class="pl-c1">1</span>)</pre></div>
180180
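<p>Here <code class="notranslate">nobs</code> is the input image observations: they pass through an encoder, and the resulting features are reshaped into <code class="notranslate">global_cond</code> and passed to the model.</p>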
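The reshape steps above can be sketched without PyTorch. This is a minimal illustration on nested lists, not the repository's code: `build_global_cond` and the toy `encoder` are hypothetical names, and each observation frame is assumed to be a flat list of numbers. It shows that merging batch and time before encoding and then calling `reshape(batch_size, -1)` yields one row per sample holding the features of all `n_obs_steps` frames concatenated.

```python
def build_global_cond(obs, n_obs_steps, encoder):
    """obs: nested list shaped [B][T][D]; encoder maps one frame to a feature list.

    Hypothetical helper that mimics the tensor reshapes with plain lists.
    """
    batch_size = len(obs)
    # Keep the first n_obs_steps frames and merge batch and time:
    # (B, T, D) -> (B * n_obs_steps, D)
    flat_frames = [frame for sample in obs for frame in sample[:n_obs_steps]]
    # Encode every frame independently: (B * n_obs_steps, D) -> (B * n_obs_steps, Do)
    features = [encoder(frame) for frame in flat_frames]
    # Regroup per sample and concatenate the per-frame features:
    # (B * n_obs_steps, Do) -> (B, n_obs_steps * Do)
    return [
        [v for feat in features[b * n_obs_steps:(b + 1) * n_obs_steps] for v in feat]
        for b in range(batch_size)
    ]


# Toy usage: 2 samples, 3 frames each, 2 values per frame; the encoder just sums a frame.
obs = [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]
global_cond = build_global_cond(obs, n_obs_steps=2, encoder=lambda frame: [sum(frame)])
```

With a feature size of `Do` per frame, each row of the result has `n_obs_steps * Do` entries, so the upstream comment "reshape back to B, Do" is best read as shorthand for that flattened size.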
181181
<p>See the <code class="notranslate">compute_loss</code> function in diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py.<br>
182182
Here <code class="notranslate">pred_type == 'epsilon'</code>, so the model output is the predicted noise.</p>
183183
<div class="highlight highlight-source-python"><pre class="notranslate"><span class="pl-c"># Predict the noise residual</span>
184-
<span class="pl-s1">pred</span> <span class="pl-c1">=</span> <span class="pl-s1">self</span>.<span class="pl-en">model</span>(<span class="pl-s1">noisy_trajectory</span>, <span class="pl-s1">timesteps</span>,
184+
<span class="pl-s1">pred</span> <span class="pl-c1">=</span> <span class="pl-s1">self</span>.<span class="pl-c1">model</span>(<span class="pl-s1">noisy_trajectory</span>, <span class="pl-s1">timesteps</span>,
185185
<span class="pl-s1">local_cond</span><span class="pl-c1">=</span><span class="pl-s1">local_cond</span>, <span class="pl-s1">global_cond</span><span class="pl-c1">=</span><span class="pl-s1">global_cond</span>)
186-
<span class="pl-s1">pred_type</span> <span class="pl-c1">=</span> <span class="pl-s1">self</span>.<span class="pl-s1">noise_scheduler</span>.<span class="pl-s1">config</span>.<span class="pl-s1">prediction_type</span>
186+
<span class="pl-s1">pred_type</span> <span class="pl-c1">=</span> <span class="pl-s1">self</span>.<span class="pl-c1">noise_scheduler</span>.<span class="pl-c1">config</span>.<span class="pl-c1">prediction_type</span>
187187
<span class="pl-k">if</span> <span class="pl-s1">pred_type</span> <span class="pl-c1">==</span> <span class="pl-s">'epsilon'</span>:
188188
<span class="pl-s1">target</span> <span class="pl-c1">=</span> <span class="pl-s1">noise</span>
189189
<span class="pl-k">elif</span> <span class="pl-s1">pred_type</span> <span class="pl-c1">==</span> <span class="pl-s">'sample'</span>:
190190
<span class="pl-s1">target</span> <span class="pl-c1">=</span> <span class="pl-s1">trajectory</span>
191191
<span class="pl-k">else</span>:
192-
<span class="pl-k">raise</span> <span class="pl-v">ValueError</span>(<span class="pl-s">f"Unsupported prediction type <span class="pl-s1"><span class="pl-kos">{</span><span class="pl-s1">pred_type</span><span class="pl-kos">}</span></span>"</span>)
192+
<span class="pl-k">raise</span> <span class="pl-en">ValueError</span>(<span class="pl-s">f"Unsupported prediction type <span class="pl-s1"><span class="pl-kos">{</span><span class="pl-s1">pred_type</span><span class="pl-kos">}</span></span>"</span>)
193193

194-
<span class="pl-s1">loss</span> <span class="pl-c1">=</span> <span class="pl-v">F</span>.<span class="pl-en">mse_loss</span>(<span class="pl-s1">pred</span>, <span class="pl-s1">target</span>, <span class="pl-s1">reduction</span><span class="pl-c1">=</span><span class="pl-s">'none'</span>)
195-
<span class="pl-s1">loss</span> <span class="pl-c1">=</span> <span class="pl-s1">loss</span> <span class="pl-c1">*</span> <span class="pl-s1">loss_mask</span>.<span class="pl-en">type</span>(<span class="pl-s1">loss</span>.<span class="pl-s1">dtype</span>)
194+
<span class="pl-s1">loss</span> <span class="pl-c1">=</span> <span class="pl-c1">F</span>.<span class="pl-c1">mse_loss</span>(<span class="pl-s1">pred</span>, <span class="pl-s1">target</span>, <span class="pl-s1">reduction</span><span class="pl-c1">=</span><span class="pl-s">'none'</span>)
195+
<span class="pl-s1">loss</span> <span class="pl-c1">=</span> <span class="pl-s1">loss</span> <span class="pl-c1">*</span> <span class="pl-s1">loss_mask</span>.<span class="pl-c1">type</span>(<span class="pl-s1">loss</span>.<span class="pl-c1">dtype</span>)
196196
<span class="pl-s1">loss</span> <span class="pl-c1">=</span> <span class="pl-en">reduce</span>(<span class="pl-s1">loss</span>, <span class="pl-s">'b ... -&gt; b (...)'</span>, <span class="pl-s">'mean'</span>)
197-
<span class="pl-s1">loss</span> <span class="pl-c1">=</span> <span class="pl-s1">loss</span>.<span class="pl-en">mean</span>()
197+
<span class="pl-s1">loss</span> <span class="pl-c1">=</span> <span class="pl-s1">loss</span>.<span class="pl-c1">mean</span>()
198198
<span class="pl-k">return</span> <span class="pl-s1">loss</span></pre></div>
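The loss computation above can be mirrored in plain Python to make the reduction order explicit. This is a sketch under assumptions, not the repository's implementation: `select_target` and `masked_mse_loss` are hypothetical names, and each sample is assumed to be already flattened to a list of numbers. It follows the same steps: element-wise squared error, multiply by the mask, mean per sample, then mean over the batch.

```python
def select_target(pred_type, noise, trajectory):
    """Pick the regression target according to the scheduler's prediction type."""
    if pred_type == "epsilon":
        return noise        # the model predicts the noise that was added
    if pred_type == "sample":
        return trajectory   # the model predicts the clean trajectory
    raise ValueError(f"Unsupported prediction type {pred_type}")


def masked_mse_loss(pred, target, loss_mask):
    """pred, target, loss_mask: nested lists shaped [B][N], one flat row per sample."""
    per_sample = []
    for p_row, t_row, m_row in zip(pred, target, loss_mask):
        # Squared error with reduction='none', zeroed wherever the mask is False
        errs = [((p - t) ** 2) * float(m) for p, t, m in zip(p_row, t_row, m_row)]
        # Mean over ALL elements of the sample; masked entries still count in the divisor
        per_sample.append(sum(errs) / len(errs))
    # Final mean over the batch
    return sum(per_sample) / len(per_sample)


# Toy usage with epsilon prediction
noise = [[0.0, 0.0], [0.0, 2.0]]
trajectory = [[5.0, 5.0], [5.0, 5.0]]
pred = [[1.0, 2.0], [0.0, 0.0]]
mask = [[True, False], [True, True]]
loss = masked_mse_loss(pred, select_target("epsilon", noise, trajectory), mask)
```

One detail this makes visible: because the mask is applied by multiplication before a plain mean, masked-out elements contribute zero error but are still counted in the average.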
199199
<h3>Inference</h3>
200200
<p>In diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py, the <code class="notranslate">predict_action</code> function calls the <code class="notranslate">conditional_sample</code> function.<br>
@@ -203,19 +203,19 @@ <h3>Inference</h3>
203203
<a target="_blank" rel="noopener noreferrer nofollow" href="https://raw.githubusercontent.com/algo-scope/imgBed/main/202410/202411111540205.png"><img src="https://raw.githubusercontent.com/algo-scope/imgBed/main/202410/202411111540205.png" alt="image.png" style="max-width: 100%;"></a><br>
204204
So each action (or trajectory step) in the dataset is two-dimensional and represents an xy coordinate.</p>
205205
<div class="highlight highlight-source-python"><pre class="notranslate"><span class="pl-c"># set step values</span>
206-
<span class="pl-s1">scheduler</span>.<span class="pl-en">set_timesteps</span>(<span class="pl-s1">self</span>.<span class="pl-s1">num_inference_steps</span>)
207-
<span class="pl-k">for</span> <span class="pl-s1">t</span> <span class="pl-c1">in</span> <span class="pl-s1">scheduler</span>.<span class="pl-s1">timesteps</span>:
206+
<span class="pl-s1">scheduler</span>.<span class="pl-c1">set_timesteps</span>(<span class="pl-s1">self</span>.<span class="pl-c1">num_inference_steps</span>)
207+
<span class="pl-k">for</span> <span class="pl-s1">t</span> <span class="pl-c1">in</span> <span class="pl-s1">scheduler</span>.<span class="pl-c1">timesteps</span>:
208208
<span class="pl-c"># 1. apply conditioning</span>
209209
<span class="pl-s1">trajectory</span>[<span class="pl-s1">condition_mask</span>] <span class="pl-c1">=</span> <span class="pl-s1">condition_data</span>[<span class="pl-s1">condition_mask</span>]
210210
<span class="pl-c"># 2. predict model output</span>
211211
<span class="pl-s1">model_output</span> <span class="pl-c1">=</span> <span class="pl-en">model</span>(<span class="pl-s1">trajectory</span>, <span class="pl-s1">t</span>,
212212
<span class="pl-s1">local_cond</span><span class="pl-c1">=</span><span class="pl-s1">local_cond</span>, <span class="pl-s1">global_cond</span><span class="pl-c1">=</span><span class="pl-s1">global_cond</span>)
213213
<span class="pl-c"># 3. compute previous image: x_t -&gt; x_t-1</span>
214-
<span class="pl-s1">trajectory</span> <span class="pl-c1">=</span> <span class="pl-s1">scheduler</span>.<span class="pl-en">step</span>(
214+
<span class="pl-s1">trajectory</span> <span class="pl-c1">=</span> <span class="pl-s1">scheduler</span>.<span class="pl-c1">step</span>(
215215
<span class="pl-s1">model_output</span>, <span class="pl-s1">t</span>, <span class="pl-s1">trajectory</span>,
216216
<span class="pl-s1">generator</span><span class="pl-c1">=</span><span class="pl-s1">generator</span>,
217217
<span class="pl-c1">**</span><span class="pl-s1">kwargs</span>
218-
).<span class="pl-s1">prev_sample</span>
218+
).<span class="pl-c1">prev_sample</span>
219219
<span class="pl-c"># finally make sure conditioning is enforced</span>
220220
<span class="pl-s1">trajectory</span>[<span class="pl-s1">condition_mask</span>] <span class="pl-c1">=</span> <span class="pl-s1">condition_data</span>[<span class="pl-s1">condition_mask</span>]        
221221
<span class="pl-k">return</span> <span class="pl-s1">trajectory</span></pre></div>
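The control flow of the sampling loop above can be illustrated with a toy example. Everything here is an assumption made for illustration: `ToyScheduler` is a hypothetical stand-in with the same method names as the scheduler in the snippet, and its `step` is a made-up update, not the real DDPM formula. The point is only the order of operations: conditioned entries are re-imposed before every model call and once more after the last step.

```python
class ToyScheduler:
    """Hypothetical stand-in for the real scheduler; NOT the actual DDPM update."""

    def set_timesteps(self, num_inference_steps):
        # Timesteps run from noisy to clean, e.g. [2, 1, 0]
        self.timesteps = list(range(num_inference_steps - 1, -1, -1))

    def step(self, model_output, t, sample):
        # Made-up update: remove half of the predicted noise from every entry
        return [x - 0.5 * eps for x, eps in zip(sample, model_output)]


def conditional_sample(model, scheduler, trajectory, condition_data,
                       condition_mask, num_inference_steps):
    def apply_conditioning(traj):
        # Overwrite entries where the mask is True with the known values
        return [c if m else x for x, c, m in zip(traj, condition_data, condition_mask)]

    scheduler.set_timesteps(num_inference_steps)
    for t in scheduler.timesteps:
        trajectory = apply_conditioning(trajectory)               # 1. apply conditioning
        model_output = model(trajectory, t)                       # 2. predict model output
        trajectory = scheduler.step(model_output, t, trajectory)  # 3. x_t -> x_t-1
    # Finally make sure conditioning is enforced
    return apply_conditioning(trajectory)


# Toy usage: the "model" claims the whole current value is noise,
# so each step halves the unconditioned entries.
result = conditional_sample(
    model=lambda traj, t: list(traj),
    scheduler=ToyScheduler(),
    trajectory=[16.0, 16.0, 16.0],
    condition_data=[1.0, 0.0, 0.0],
    condition_mask=[True, False, False],
    num_inference_steps=3,
)
```

In this toy run the first entry ends at its conditioned value while the other two are only changed by the denoising steps, which is why the final re-application of the mask matters: without it the last `step` would leave the conditioned entry modified.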
