@@ -170,31 +170,31 @@ <h3>命令行</h3>
170170</ code > </ pre >
171171< h3 > 训练</ h3 >
172172< h4 > 条件控制怎么加的</ h4 >
173- < div class ="highlight highlight-source-python "> < pre class ="notranslate "> < span class ="pl-k "> if</ span > < span class ="pl-s1 "> self</ span > .< span class ="pl-s1 "> obs_as_global_cond</ span > : < span class ="pl-c "> # true</ span >
173+ < div class ="highlight highlight-source-python "> < pre class ="notranslate "> < span class ="pl-k "> if</ span > < span class ="pl-s1 "> self</ span > .< span class ="pl-c1 "> obs_as_global_cond</ span > : < span class ="pl-c "> # true</ span >
174174 < span class ="pl-c "> # reshape B, T, ... to B*T</ span >
175175 < span class ="pl-s1 "> this_nobs</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-en "> dict_apply</ span > (< span class ="pl-s1 "> nobs</ span > ,
176- < span class ="pl-k "> lambda</ span > < span class ="pl-s1 "> x</ span > : < span class ="pl-s1 "> x</ span > [:,:< span class ="pl-s1 "> self</ span > .< span class ="pl-s1 "> n_obs_steps</ span > ,...].< span class ="pl-en "> reshape</ span > (< span class ="pl-c1 "> -</ span > < span class ="pl-c1 "> 1</ span > ,< span class ="pl-c1 "> *</ span > < span class ="pl-s1 "> x</ span > .< span class ="pl-s1 "> shape</ span > [< span class ="pl-c1 "> 2</ span > :]))
177- < span class ="pl-s1 "> nobs_features</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> self</ span > .< span class ="pl-en "> obs_encoder</ span > (< span class ="pl-s1 "> this_nobs</ span > )
176+ < span class ="pl-k "> lambda</ span > < span class ="pl-s1 "> x</ span > : < span class ="pl-s1 "> x</ span > [:,:< span class ="pl-s1 "> self</ span > .< span class ="pl-c1 "> n_obs_steps</ span > ,...].< span class ="pl-c1 "> reshape</ span > (< span class ="pl-c1 "> -</ span > < span class ="pl-c1 "> 1</ span > ,< span class ="pl-c1 "> *</ span > < span class ="pl-s1 "> x</ span > .< span class ="pl-c1 "> shape</ span > [< span class ="pl-c1 "> 2</ span > :]))
177+ < span class ="pl-s1 "> nobs_features</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> self</ span > .< span class ="pl-c1 "> obs_encoder</ span > (< span class ="pl-s1 "> this_nobs</ span > )
178178 < span class ="pl-c "> # reshape back to B, Do</ span >
179- < span class ="pl-s1 "> global_cond</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> nobs_features</ span > .< span class ="pl-en "> reshape</ span > (< span class ="pl-s1 "> batch_size</ span > , < span class ="pl-c1 "> -</ span > < span class ="pl-c1 "> 1</ span > )</ pre > </ div >
179+ < span class ="pl-s1 "> global_cond</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> nobs_features</ span > .< span class ="pl-c1 "> reshape</ span > (< span class ="pl-s1 "> batch_size</ span > , < span class ="pl-c1 "> -</ span > < span class ="pl-c1 "> 1</ span > )</ pre > </ div >
180180< p > 这里面nobs就是输入图像,经过一个encoder以后reshape,传入model</ p >
181181< p > diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py,< code class ="notranslate "> compute_loss</ code > 函数。< br >
182182这里< code class ="notranslate "> pred_type == 'epsilon'</ code > ,所以模型输出的是noise。</ p >
183183< div class ="highlight highlight-source-python "> < pre class ="notranslate "> < span class ="pl-c "> # Predict the noise residual</ span >
184- < span class ="pl-s1 "> pred</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> self</ span > .< span class ="pl-en "> model</ span > (< span class ="pl-s1 "> noisy_trajectory</ span > , < span class ="pl-s1 "> timesteps</ span > ,
184+ < span class ="pl-s1 "> pred</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> self</ span > .< span class ="pl-c1 "> model</ span > (< span class ="pl-s1 "> noisy_trajectory</ span > , < span class ="pl-s1 "> timesteps</ span > ,
185185 < span class ="pl-s1 "> local_cond</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> local_cond</ span > , < span class ="pl-s1 "> global_cond</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> global_cond</ span > )
186- < span class ="pl-s1 "> pred_type</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> self</ span > .< span class ="pl-s1 "> noise_scheduler</ span > .< span class ="pl-s1 "> config</ span > .< span class ="pl-s1 "> prediction_type</ span >
186+ < span class ="pl-s1 "> pred_type</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> self</ span > .< span class ="pl-c1 "> noise_scheduler</ span > .< span class ="pl-c1 "> config</ span > .< span class ="pl-c1 "> prediction_type</ span >
187187< span class ="pl-k "> if</ span > < span class ="pl-s1 "> pred_type</ span > < span class ="pl-c1 "> ==</ span > < span class ="pl-s "> 'epsilon'</ span > :
188188 < span class ="pl-s1 "> target</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> noise</ span >
189189< span class ="pl-k "> elif</ span > < span class ="pl-s1 "> pred_type</ span > < span class ="pl-c1 "> ==</ span > < span class ="pl-s "> 'sample'</ span > :
190190 < span class ="pl-s1 "> target</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> trajectory</ span >
191191< span class ="pl-k "> else</ span > :
192- < span class ="pl-k "> raise</ span > < span class ="pl-v "> ValueError</ span > (< span class ="pl-s "> f"Unsupported prediction type < span class ="pl-s1 "> < span class ="pl-kos "> {</ span > < span class ="pl-s1 "> pred_type</ span > < span class ="pl-kos "> }</ span > </ span > "</ span > )
192+ < span class ="pl-k "> raise</ span > < span class ="pl-en "> ValueError</ span > (< span class ="pl-s "> f"Unsupported prediction type < span class ="pl-s1 "> < span class ="pl-kos "> {</ span > < span class ="pl-s1 "> pred_type</ span > < span class ="pl-kos "> }</ span > </ span > "</ span > )
193193
194- < span class ="pl-s1 "> loss</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-v "> F</ span > .< span class ="pl-en "> mse_loss</ span > (< span class ="pl-s1 "> pred</ span > , < span class ="pl-s1 "> target</ span > , < span class ="pl-s1 "> reduction</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s "> 'none'</ span > )
195- < span class ="pl-s1 "> loss</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> loss</ span > < span class ="pl-c1 "> *</ span > < span class ="pl-s1 "> loss_mask</ span > .< span class ="pl-en "> type</ span > (< span class ="pl-s1 "> loss</ span > .< span class ="pl-s1 "> dtype</ span > )
194+ < span class ="pl-s1 "> loss</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-c1 "> F</ span > .< span class ="pl-c1 "> mse_loss</ span > (< span class ="pl-s1 "> pred</ span > , < span class ="pl-s1 "> target</ span > , < span class ="pl-s1 "> reduction</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s "> 'none'</ span > )
195+ < span class ="pl-s1 "> loss</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> loss</ span > < span class ="pl-c1 "> *</ span > < span class ="pl-s1 "> loss_mask</ span > .< span class ="pl-c1 "> type</ span > (< span class ="pl-s1 "> loss</ span > .< span class ="pl-c1 "> dtype</ span > )
196196< span class ="pl-s1 "> loss</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-en "> reduce</ span > (< span class ="pl-s1 "> loss</ span > , < span class ="pl-s "> 'b ... -> b (...)'</ span > , < span class ="pl-s "> 'mean'</ span > )
197- < span class ="pl-s1 "> loss</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> loss</ span > .< span class ="pl-en "> mean</ span > ()
197+ < span class ="pl-s1 "> loss</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> loss</ span > .< span class ="pl-c1 "> mean</ span > ()
198198< span class ="pl-k "> return</ span > < span class ="pl-s1 "> loss</ span > </ pre > </ div >
199199< h3 > 推理</ h3 >
200200< p > diffusion_policy/policy/diffusion_unet_hybrid_image_policy.py,< code class ="notranslate "> predict_action</ code > 函数调用< code class ="notranslate "> conditional_sample</ code > 函数。< br >
@@ -203,19 +203,19 @@ <h3>推理</h3>
203203< a target ="_blank " rel ="noopener noreferrer nofollow " href ="https://raw.githubusercontent.com/algo-scope/imgBed/main/202410/202411111540205.png "> < img src ="https://raw.githubusercontent.com/algo-scope/imgBed/main/202410/202411111540205.png " alt ="image.png " style ="max-width: 100%; "> </ a > < br >
204204所以数据集中的action或者trajectory是二维的,代表xy坐标。</ p >
205205< div class ="highlight highlight-source-python "> < pre class ="notranslate "> < span class ="pl-c "> # set step values</ span >
206- < span class ="pl-s1 "> scheduler</ span > .< span class ="pl-en "> set_timesteps</ span > (< span class ="pl-s1 "> self</ span > .< span class ="pl-s1 "> num_inference_steps</ span > )
207- < span class ="pl-k "> for</ span > < span class ="pl-s1 "> t</ span > < span class ="pl-c1 "> in</ span > < span class ="pl-s1 "> scheduler</ span > .< span class ="pl-s1 "> timesteps</ span > :
206+ < span class ="pl-s1 "> scheduler</ span > .< span class ="pl-c1 "> set_timesteps</ span > (< span class ="pl-s1 "> self</ span > .< span class ="pl-c1 "> num_inference_steps</ span > )
207+ < span class ="pl-k "> for</ span > < span class ="pl-s1 "> t</ span > < span class ="pl-c1 "> in</ span > < span class ="pl-s1 "> scheduler</ span > .< span class ="pl-c1 "> timesteps</ span > :
208208 < span class ="pl-c "> # 1. apply conditioning</ span >
209209 < span class ="pl-s1 "> trajectory</ span > [< span class ="pl-s1 "> condition_mask</ span > ] < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> condition_data</ span > [< span class ="pl-s1 "> condition_mask</ span > ]
210210 < span class ="pl-c "> # 2. predict model output</ span >
211211 < span class ="pl-s1 "> model_output</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-en "> model</ span > (< span class ="pl-s1 "> trajectory</ span > , < span class ="pl-s1 "> t</ span > ,
212212 < span class ="pl-s1 "> local_cond</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> local_cond</ span > , < span class ="pl-s1 "> global_cond</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> global_cond</ span > )
213213 < span class ="pl-c "> # 3. compute previous image: x_t -> x_t-1</ span >
214- < span class ="pl-s1 "> trajectory</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> scheduler</ span > .< span class ="pl-en "> step</ span > (
214+ < span class ="pl-s1 "> trajectory</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> scheduler</ span > .< span class ="pl-c1 "> step</ span > (
215215 < span class ="pl-s1 "> model_output</ span > , < span class ="pl-s1 "> t</ span > , < span class ="pl-s1 "> trajectory</ span > ,
216216 < span class ="pl-s1 "> generator</ span > < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> generator</ span > ,
217217 < span class ="pl-c1 "> **</ span > < span class ="pl-s1 "> kwargs</ span >
218- ).< span class ="pl-s1 "> prev_sample</ span >
218+ ).< span class ="pl-c1 "> prev_sample</ span >
219219< span class ="pl-c "> # finally make sure conditioning is enforced</ span >
220220< span class ="pl-s1 "> trajectory</ span > [< span class ="pl-s1 "> condition_mask</ span > ] < span class ="pl-c1 "> =</ span > < span class ="pl-s1 "> condition_data</ span > [< span class ="pl-s1 "> condition_mask</ span > ]
221221< span class ="pl-k "> return</ span > < span class ="pl-s1 "> trajectory</ span > </ pre > </ div >
0 commit comments