In [None]:
def bind_forward_double_loss(
    self,
    prompt: Union[str, List[str]] = None,
    audio_length_in_s: Optional[float] = None,
    num_inference_steps: int = 10,
    guidance_scale: float = 2.5,
    learning_rate: float = 0.1,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_waveforms_per_prompt: Optional[int] = 1,
    clip_duration: float = 2.0,
    clips_per_video: int = 5,
    num_optimization_steps: int = 1,
    optimization_starting_point: float = 0.2,
    eta: float = 0.0,
    video_paths: Union[str, List[str]] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    output_type: Optional[str] = "np",
):
    """Generate audio with an AudioLDM-style pipeline, steering latents with an ImageBind dual loss.

    Intended to be bound as a method of an AudioLDM pipeline. It uses ``self.vocoder``,
    ``self.unet``, ``self.scheduler``, ``self.vae_scale_factor``, ``self._encode_prompt``,
    ``self.prepare_latents``, ``self.decode_latents`` and ``self.mel_spectrogram_to_waveform``.

    Guidance runs for every timestep index ``i > int(len(timesteps) * optimization_starting_point)``.
    After the scheduler step, the latents are optimised for ``num_optimization_steps`` Adam
    steps on ``(1 - cos(e_text, e_audio)) + (1 - cos(e_vision, e_audio))``, summed over
    the batch. The audio embedding comes from decoding a clean-latent estimate through the
    VAE decoder and the vocoder.

    Args:
        prompt: Text prompt(s). Required whenever ImageBind guidance actually runs, because
            the text term of the loss needs the raw string(s); ``prompt_embeds`` alone is
            not enough for that term.
        audio_length_in_s: Target audio length. Defaults to the length implied by the UNet
            sample size and the vocoder upsampling factor.
        num_inference_steps: Number of scheduler timesteps.
        guidance_scale: Classifier-free guidance weight; CFG is enabled when > 1.
        learning_rate: Adam learning rate for the latent update.
        negative_prompt: Negative prompt(s) for classifier-free guidance.
        num_waveforms_per_prompt: Number of waveforms generated per prompt.
        clip_duration: Forwarded to the ImageBind video/audio loaders.
        clips_per_video: Forwarded to the ImageBind video/audio loaders.
        num_optimization_steps: Adam steps performed per guided timestep.
        optimization_starting_point: Fraction of the timesteps to run before guidance starts.
        eta: Forwarded to ``prepare_extra_step_kwargs``.
        video_paths: Video path(s) used as the vision condition.
        generator: Forwarded to ``prepare_latents`` / ``prepare_extra_step_kwargs``.
        latents: Optional initial latents, forwarded to ``prepare_latents``.
        prompt_embeds: Optional precomputed prompt embeddings.
        negative_prompt_embeds: Optional precomputed negative prompt embeddings.
        return_dict: If False, return a 1-tuple ``(audio,)``.
        callback: Called as ``callback(i, t, latents)`` every ``callback_steps`` steps.
        callback_steps: Callback frequency.
        cross_attention_kwargs: Forwarded to the UNet.
        output_type: ``"np"`` converts the result to a NumPy array; otherwise the tensor
            is returned.

    Returns:
        ``AudioPipelineOutput(audios=audio)``, or ``(audio,)`` when ``return_dict`` is False.

    Raises:
        ValueError: If ImageBind guidance runs but no raw ``prompt`` was given.
    """
    # --- 0. Audio length (standard AudioLDM latent-diffusion setup) ---
    vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

    if audio_length_in_s is None:
        # Derive the default length from the UNet latent size and the vocoder upsample factor.
        audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

    # height = number of mel-spectrogram frames along the time axis.
    height = int(audio_length_in_s / vocoder_upsample_factor)

    original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
    if height % self.vae_scale_factor != 0:
        # Round height up to a multiple of the VAE stride; the waveform is trimmed back at the end.
        height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
        logger.info(
            f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
            f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
            f"denoising process."
        )

    # --- 1. Input validation (shared diffusion-pipeline checks) ---
    self.check_inputs(
        prompt,
        audio_length_in_s,
        vocoder_upsample_factor,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # --- 2. Batch size and classifier-free guidance ---
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # Classifier-free guidance is only applied when the guidance weight exceeds 1.
    do_classifier_free_guidance = guidance_scale > 1.0

    # --- 3. Encode the text prompt (y = EMB(p)) ---
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # --- 4. Timesteps (denoise from T down to 0) ---
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # --- 5. Initial latents z_T ---
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_waveforms_per_prompt,
        num_channels_latents,
        height,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    latents_dtype = latents.dtype

    # --- 6. Extra scheduler kwargs (e.g. eta for DDIM) ---
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # --- ImageBind inputs that stay constant during denoising (vision modality) ---
    image_bind_video_input = load_and_transform_video_data(
        video_paths, device,
        clip_duration=clip_duration,
        clips_per_video=clips_per_video,
        n_samples_per_clip=2
    )

    # Raw text for the ImageBind text-audio term. Both str and list prompts are supported;
    # without a raw prompt (prompt_embeds only) the text term cannot be computed.
    if isinstance(prompt, str):
        prompt_bind = [prompt]
    elif isinstance(prompt, list):
        prompt_bind = list(prompt)
    else:
        prompt_bind = None
    # Tokenization does not depend on the latents, so do it once instead of per Adam step.
    text_bind_input = load_and_transform_text(prompt_bind, device) if prompt_bind is not None else None

    # --- Pretrained ImageBind, frozen (training-free guidance: ImageBind is never updated) ---
    bind_model = imagebind_model.imagebind_huge(pretrained=True)
    bind_model.eval()
    bind_model.to(device)

    for param in bind_model.parameters():
        param.requires_grad = False

    # --- 7. Denoising loop (for t = T .. 0) ---
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    # ImageBind guidance runs only for step indices strictly greater than this value.
    num_warmup_steps_bind = int(len(timesteps) * optimization_starting_point)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Duplicate latents for the unconditional + conditional CFG branches.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # --- 7-1. UNet noise prediction; no gradients are needed through the UNet ---
            with torch.no_grad():
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=None,
                    class_labels=prompt_embeds,  # AudioLDM conditions on the text embedding via class_labels
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample.to(dtype=latents_dtype)

            # --- 7-2. Classifier-free guidance ---
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # --- 7-3. Scheduler step: z_t -> z_{t-1} ---
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # --- 7-4. ImageBind multimodal guidance with the dual (text + vision) loss ---
            if i > num_warmup_steps_bind:
                if text_bind_input is None:
                    raise ValueError(
                        "ImageBind guidance needs the raw `prompt` for its text-audio term; "
                        "it cannot be computed from `prompt_embeds` alone."
                    )

                # Treat the latent itself as the optimisation variable: z <- z - lr * Adam(grad_z L).
                latents_temp = latents.detach().requires_grad_(True)
                optimizer = torch.optim.Adam([latents_temp], lr=learning_rate)
                # Scheduler coefficient for this timestep; invariant across the Adam steps.
                alpha_prod_t = self.scheduler.alphas_cumprod[t]

                for _ in range(num_optimization_steps):
                    # (1) Clean-latent estimate x0 = (z - sqrt(1 - a_t) * eps) / sqrt(a_t).
                    # NOTE(review): this combines the already-stepped latents with the noise
                    # predicted for timestep t, exactly as the original code did - confirm intended.
                    x0 = 1 / (alpha_prod_t ** 0.5) * (latents_temp - (1 - alpha_prod_t) ** 0.5 * noise_pred)

                    # (2) VAE decoder: latent -> mel spectrogram.
                    x0_mel_spectrogram = self.decode_latents(x0)
                    if x0_mel_spectrogram.dim() == 4:
                        # Drop the singleton channel axis before the vocoder.
                        x0_mel_spectrogram = x0_mel_spectrogram.squeeze(1)

                    # (3) Vocoder: mel spectrogram -> waveform (kept differentiable).
                    x0_waveform = self.vocoder(x0_mel_spectrogram)

                    # (4) Waveform -> ImageBind audio input.
                    x0_imagebind_audio_input = load_and_transform_audio_data_from_waveform(
                        x0_waveform,
                        org_sample_rate=self.vocoder.config.sampling_rate,
                        device=device,
                        target_length=204,
                        clip_duration=clip_duration,
                        clips_per_video=clips_per_video
                    )

                    # (5) Embed vision (e_v), audio (e_a) and text (e_p) with ImageBind.
                    inputs = {
                        ModalityType.VISION: image_bind_video_input,
                        ModalityType.AUDIO: x0_imagebind_audio_input,
                        ModalityType.TEXT: text_bind_input,
                    }
                    embeddings = bind_model(inputs)

                    # (6) Text-audio term: 1 - cos(e_p, e_a).
                    bind_loss_text_audio = 1 - F.cosine_similarity(
                        embeddings[ModalityType.TEXT],
                        embeddings[ModalityType.AUDIO]
                    )
                    # (7) Vision-audio term: 1 - cos(e_v, e_a).
                    bind_loss_vision_audio = 1 - F.cosine_similarity(
                        embeddings[ModalityType.VISION],
                        embeddings[ModalityType.AUDIO]
                    )

                    # (8) Dual loss. cosine_similarity yields one value per sample, so sum over
                    # the batch to get a scalar (identical gradient for a batch of 1).
                    bind_loss = (bind_loss_text_audio + bind_loss_vision_audio).sum()

                    # (9) Gradient step on the latent only. Restricting `inputs` keeps .grad from
                    # accumulating on VAE-decoder / vocoder parameters if they require grad.
                    optimizer.zero_grad()
                    bind_loss.backward(inputs=[latents_temp])
                    optimizer.step()

                # Continue denoising from the optimised latent.
                latents = latents_temp.detach()

            # --- 7-5. Progress bar and user callback ---
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # --- 8. Post-processing: z_0 -> mel spectrogram -> waveform (no autograd graph needed) ---
    with torch.no_grad():
        mel_spectrogram = self.decode_latents(latents)
        audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
    # Trim the padding introduced when height was rounded up to the VAE stride.
    audio = audio[:, :original_waveform_length]

    if output_type == "np":
        # .cpu() is a no-op for CPU tensors and makes .numpy() safe for device tensors.
        audio = audio.detach().cpu().numpy()

    if not return_dict:
        return (audio,)

    return AudioPipelineOutput(audios=audio)