# Diffusion Latent를 ImageBind 멀티모달 공간에 정렬시키는 Cross-Modal Latent Optimization Loop

In [None]:
 def bind_forward_double_loss(
        self,
        prompt: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 10,
        guidance_scale: float = 2.5,
        learning_rate: float = 0.1,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_waveforms_per_prompt: Optional[int] = 1,
        clip_duration: float = 2.0,
        clips_per_video: int = 5,
        num_optimization_steps: int = 1,
        optimization_starting_point: float = 0.2,
        eta: float = 0.0,
        video_paths: Union[str, List[str]] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "np",
    ):
        """Generate audio while steering the diffusion latents toward ImageBind embeddings.

        Runs a classifier-free-guided denoising loop with ``self.unet`` / ``self.scheduler``.
        Once more than ``optimization_starting_point`` of the timesteps have run, each
        denoised latent is refined with ``num_optimization_steps`` Adam steps. The
        objective is ``(1 - cos(text, audio)) + (1 - cos(vision, audio))``, computed in
        ImageBind's embedding space. The audio embedding comes from an x0 estimate decoded
        through the VAE and the vocoder.

        Args:
            prompt: Text prompt(s). Encoded with ``self._encode_prompt`` for UNet
                conditioning and also used as the ImageBind text target. When ``prompt`` is
                ``None`` (only ``prompt_embeds`` given), the text term is skipped.
            audio_length_in_s: Requested audio length. Defaults to the UNet's native sample
                size converted to seconds.
            num_inference_steps: Number of scheduler timesteps.
            guidance_scale: Classifier-free guidance weight; guidance is used when > 1.
            learning_rate: Adam learning rate for the latent refinement.
            negative_prompt: Optional negative prompt(s) for classifier-free guidance.
            num_waveforms_per_prompt: Waveforms generated per prompt.
            clip_duration: Clip length passed to the ImageBind video/audio loaders.
            clips_per_video: Number of clips passed to the ImageBind video/audio loaders.
            num_optimization_steps: Adam steps per denoising step once refinement is active.
            optimization_starting_point: Fraction of timesteps before refinement starts.
            eta: Passed to ``self.prepare_extra_step_kwargs`` together with ``generator``.
            video_paths: Video path(s) used as the ImageBind vision target. When ``None``,
                the vision term is skipped.
            generator, latents, prompt_embeds, negative_prompt_embeds,
            cross_attention_kwargs: Forwarded to the corresponding pipeline helpers.
            return_dict: Return ``AudioPipelineOutput`` when True, otherwise a 1-tuple.
            callback: Called as ``callback(i, t, latents)`` when ``i % callback_steps == 0``
                on steps where the progress bar advances.
            callback_steps: See ``callback``.
            output_type: ``"np"`` converts the waveform to a NumPy array. Any other value
                returns the output of ``self.mel_spectrogram_to_waveform`` unchanged.

        Returns:
            ``AudioPipelineOutput(audios=audio)`` or ``(audio,)``, with the waveform trimmed
            to ``audio_length_in_s`` seconds.
        """
        # Convert the requested audio length into a mel-spectrogram height. In this pipeline
        # the height is derived from the vocoder upsample factor and the VAE scale factor.
        vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

        height = int(audio_length_in_s / vocoder_upsample_factor)

        original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)

        # Round the height up to a multiple of the VAE scale factor; the surplus audio is
        # trimmed after decoding.
        if height % self.vae_scale_factor != 0:
            height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
            logger.info(
                f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
                f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
                f"denoising process."
            )

        self.check_inputs(
            prompt,
            audio_length_in_s,
            vocoder_upsample_factor,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # Batch size and whether classifier-free guidance is used.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode the prompt; the result is fed to the UNet as `class_labels` below.
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_waveforms_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Initial latents (prepared by `prepare_latents`, or the caller-supplied `latents`).
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_waveforms_per_prompt,
            num_channels_latents,
            height,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        latents_dtype = latents.dtype

        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # ImageBind reference targets. They do not depend on the latents, so they are encoded
        # once here (without gradients) instead of on every optimisation step.
        if isinstance(prompt, str):
            text_targets = [prompt]
        elif isinstance(prompt, list):
            text_targets = prompt
        else:
            # Only `prompt_embeds` were given: there is no raw text for ImageBind.
            text_targets = None

        reference_inputs = {}
        if video_paths is not None:
            video_list = [video_paths] if isinstance(video_paths, str) else video_paths
            reference_inputs[ModalityType.VISION] = load_and_transform_video_data(
                video_list, device, clip_duration=clip_duration, clips_per_video=clips_per_video, n_samples_per_clip=2
            )
        if text_targets is not None:
            reference_inputs[ModalityType.TEXT] = load_and_transform_text(text_targets, device)

        bind_model = None
        reference_embeddings = {}
        if reference_inputs:
            bind_model = getattr(self, "_imagebind_model", None)
            if bind_model is None:
                # Loading ImageBind-Huge is expensive, so the frozen model is cached on the
                # pipeline and reused across calls.
                bind_model = imagebind_model.imagebind_huge(pretrained=True)
                bind_model.eval()
                # ImageBind stays frozen: only the diffusion latents are optimised.
                for p in bind_model.parameters():
                    p.requires_grad = False
                self._imagebind_model = bind_model
            bind_model.to(device)

            with torch.no_grad():
                reference_embeddings = bind_model(reference_inputs)
        else:
            logger.warning(
                "Neither a text prompt nor video_paths were given; skipping ImageBind latent optimisation."
            )

        def _expand_to_audio_batch(embedding):
            # One target per prompt/video -> repeat it for each waveform generated from that
            # prompt, matching the prompt-major ordering of the latent batch.
            if num_waveforms_per_prompt > 1 and embedding.shape[0] == batch_size:
                return embedding.repeat_interleave(num_waveforms_per_prompt, dim=0)
            return embedding

        reference_embeddings = {key: _expand_to_audio_batch(emb) for key, emb in reference_embeddings.items()}

        # Denoising loop with latent alignment:
        #   (1) the UNet predicts the noise and the scheduler steps x_t -> x_{t-1};
        #   (2) after the warm-up fraction, the latent is refined so that its decoded audio
        #       matches the text/vision targets in ImageBind space.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        num_warmup_steps_bind = int(len(timesteps) * optimization_starting_point)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Classifier-free guidance doubles the batch (unconditional + conditional).
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=None,
                        class_labels=prompt_embeds,
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample.to(dtype=latents_dtype)

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # Latent refinement starts after the coarse structure has formed.
                if reference_embeddings and i > num_warmup_steps_bind:
                    latents_temp = latents.detach().requires_grad_(True)
                    optimizer = torch.optim.Adam([latents_temp], lr=learning_rate)
                    alpha_prod_t = self.scheduler.alphas_cumprod[t]

                    for _ in range(num_optimization_steps):
                        # 1. x0 estimate from the latent and this step's noise prediction.
                        # NOTE(review): noise_pred was predicted for the pre-step latent at
                        # timestep t but is applied to the post-step latent; kept as in the
                        # original method — confirm this is intended.
                        x0 = 1 / (alpha_prod_t ** 0.5) * (latents_temp - (1 - alpha_prod_t) ** 0.5 * noise_pred)

                        # 2. VAE decode -> mel-spectrogram; 3. vocoder -> waveform.
                        x0_mel_spectrogram = self.decode_latents(x0)
                        if x0_mel_spectrogram.dim() == 4:
                            x0_mel_spectrogram = x0_mel_spectrogram.squeeze(1)
                        x0_waveform = self.vocoder(x0_mel_spectrogram)

                        # 4. Waveform -> ImageBind audio input -> audio embedding.
                        x0_imagebind_audio_input = load_and_transform_audio_data_from_waveform(
                            x0_waveform,
                            org_sample_rate=self.vocoder.config.sampling_rate,
                            device=device,
                            target_length=204,
                            clip_duration=clip_duration,
                            clips_per_video=clips_per_video,
                        )
                        audio_embedding = bind_model({ModalityType.AUDIO: x0_imagebind_audio_input})[
                            ModalityType.AUDIO
                        ]

                        # Cross-modal alignment loss, summed over the batch so backward()
                        # always receives a scalar.
                        bind_loss = 0
                        for modality in (ModalityType.TEXT, ModalityType.VISION):
                            if modality in reference_embeddings:
                                bind_loss = bind_loss + (
                                    1 - F.cosine_similarity(reference_embeddings[modality], audio_embedding)
                                )
                        bind_loss = bind_loss.sum()

                        # Accumulate gradients only into the latent, not into VAE/vocoder weights.
                        bind_loss.backward(inputs=[latents_temp])
                        optimizer.step()
                        optimizer.zero_grad()

                    latents = latents_temp.detach()

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Final decode: latents -> mel-spectrogram -> waveform. No gradients are needed here.
        with torch.no_grad():
            mel_spectrogram = self.decode_latents(latents)
            audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
        # Trim the padding introduced when the height was rounded up.
        audio = audio[:, :original_waveform_length]

        if output_type == "np":
            audio = audio.detach().cpu().numpy()

        if not return_dict:
            return (audio,)

        return AudioPipelineOutput(audios=audio)