# Text-guided image-to-image generation

The [StableDiffusionImg2ImgPipeline](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusionImg2ImgPipeline) lets you pass a text prompt and an initial image to condition the generation of new images.

Before you begin, make sure you have all the necessary libraries installed:

In [None]:
#@title Imports
!pip install diffusers huggingface-hub accelerate
!pip install transformers ftfy torch torchaudio ipython typing
!pip install torch fadtk



Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting typing
  Downloading typing-3.7.4.3.tar.gz (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.6/78.6 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting jedi>=0.16 (from ipython)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m48.4 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: typing
  Building wheel for typing (setup.py) ... [?25l[?25hdone
  Created wheel for typing: filename=typing-3.7.4.3-py3-none-any.whl size=26304 sha256=79db6d24afe9a6a2a38a6cd5f9dfbd8f1f12aac

^C


In [None]:
# Mount the user's Google Drive into the Colab filesystem at /content/gdrive.
# NOTE(review): none of the cells visible here read from this mount point — confirm
# that a later cell needs it before keeping this step.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Importing necessary classes and functions
import inspect
from typing import Callable, List, Optional, Union

import numpy as np
import PIL
import torch
import torchvision.transforms
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import subprocess
import os
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    PIL_INTERPOLATION,
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
    replace_example_docstring,
)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

# Importing StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionImg2ImgPipeline

In [None]:
import locale
import numpy as np
from PIL import Image, ImageEnhance, ImageChops
import torch
import requests
from io import BytesIO
import ipywidgets as widgets
from IPython.display import display
from typing import *
import torchaudio
from IPython.display import Audio
from torchvision import transforms
import locale
locale.getpreferredencoding = lambda: "UTF-8"




In [None]:
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """Create a tensor of standard-normal noise on `device` with the requested `dtype`.

    Args:
        shape: Shape of the tensor to create. Lists, tuples and `torch.Size` are accepted.
        generator: A single generator, or a list holding one generator per batch entry
            (entry `i` seeds row `i` of the result). When the generator lives on the CPU
            but another device was requested, the noise is sampled on the CPU and then
            moved to `device`.
        device: Target device, either a `torch.device` or a device string. Defaults to CPU.
        dtype: Data type forwarded to `torch.randn`.
        layout: Tensor layout forwarded to `torch.randn`. Defaults to `torch.strided`.

    Returns:
        A tensor of shape `shape` located on `device`.

    Raises:
        ValueError: If a CUDA generator is passed while a tensor on a different device
            type is requested.
    """
    # Normalise `shape` so that the per-generator branch below can build `(1, ...)`
    # with tuple concatenation even when the caller passed a list.
    shape = tuple(shape)

    # The device used for sampling defaults to the requested device as given.
    rand_device = device

    layout = layout or torch.strided
    # Accept device strings as well as `torch.device` objects; `.type` is needed below.
    device = torch.device(device) if device is not None else torch.device("cpu")

    if generator is not None:
        first_generator = generator[0] if isinstance(generator, list) else generator
        gen_device_type = first_generator.device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare device *types*: a `torch.device` never equals the string "mps".
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # One generator per batch entry: sample each row separately, then stack.
        batch_size = shape[0]
        row_shape = (1,) + shape[1:]
        rows = [
            torch.randn(row_shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        return torch.cat(rows, dim=0).to(device)

    return torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

In [None]:
# Module-level logger; the helper functions and the pipeline class in this notebook
# emit their info/warning messages through it.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Usage example for the stock `StableDiffusionImg2ImgPipeline`, kept as a plain string.
# NOTE(review): presumably meant to be attached to a `__call__` docstring through
# `replace_example_docstring`; no visible code references it — confirm.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from io import BytesIO

        >>> from diffusers import StableDiffusionImg2ImgPipeline

        >>> device = "cuda"
        >>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
        >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
        >>> pipe = pipe.to(device)

        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

        >>> response = requests.get(url)
        >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
        >>> init_image = init_image.resize((768, 512))

        >>> prompt = "A fantasy landscape, trending on artstation"

        >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
        >>> images[0].save("fantasy_landscape.png")
        ```
"""



In [None]:
def preprocess(image):
    """Turn the pipeline's `image` input into a single batched tensor.

    * A `torch.Tensor` is returned unchanged.
    * A PIL image (or a list of them) is resized so that both sides are integer
      multiples of 8 (using the size of the first image), stacked, divided by 255,
      moved to channels-first order and mapped through `2 * x - 1`.
    * A list of tensors is concatenated along dimension 0.
    * Any other sequence is returned as it was received.
    """
    # Tensors are handed back untouched.
    if isinstance(image, torch.Tensor):
        return image

    # Promote a single PIL image to a one-element batch.
    if isinstance(image, PIL.Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        # round each side down to an integer multiple of 8
        width -= width % 8
        height -= height % 8

        frames = []
        for pil_image in image:
            resized = pil_image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
            frames.append(np.array(resized)[None, :])

        batch = np.concatenate(frames, axis=0)
        batch = np.array(batch).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        batch = 2.0 * batch - 1.0
        return torch.from_numpy(batch)

    if isinstance(first, torch.Tensor):
        return torch.cat(image, dim=0)

    return image

def preprocess_map(map):
    """Convert a change map image into a single-channel tensor on the global `device`.

    The image is converted to mode "L", center-cropped so that both sides are
    multiples of 64, converted with torchvision's `ToTensor` and moved to `device`.

    NOTE(review): relies on a module-level `device` that is not defined in the cells
    visible here, and assumes `map.size` follows PIL's (width, height) ordering —
    confirm both.
    """
    grayscale = map.convert("L")
    width, height = grayscale.size
    # CenterCrop expects (height, width); round each side down to a multiple of 64.
    crop = transforms.CenterCrop((height // 64 * 64, width // 64 * 64))
    # convert to tensor
    as_tensor = transforms.ToTensor()(crop(grayscale))
    return as_tensor.to(device)

def preprocess_image(image):
    """Convert an input image into a batched tensor on the global `device`.

    The image is converted to RGB, center-cropped so that both sides are multiples
    of 64, converted with torchvision's `ToTensor`, mapped through `x * 2 - 1`,
    given a leading batch dimension and moved to `device`.

    NOTE(review): relies on a module-level `device` that is not defined in the cells
    visible here, and assumes `image.size` follows PIL's (width, height) ordering —
    confirm both.
    """
    rgb = image.convert("RGB")
    width, height = rgb.size
    # CenterCrop expects (height, width); round each side down to a multiple of 64.
    cropped = transforms.CenterCrop((height // 64 * 64, width // 64 * 64))(rgb)
    tensor = transforms.ToTensor()(cropped)
    rescaled = tensor * 2 - 1
    return rescaled.unsqueeze(0).to(device)


In [None]:
class MaskStableDiffusionImg2ImgPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        """Validate the components, patch outdated configs and register all modules.

        Args:
            vae: Autoencoder; its `block_out_channels` config determines `vae_scale_factor`.
            text_encoder: CLIP text encoder.
            tokenizer: CLIP tokenizer.
            unet: Conditional UNet.
            scheduler: Scheduler; an outdated `steps_offset` / `clip_sample` config is
                rewritten in place after a deprecation notice.
            safety_checker: Optional safety checker; `None` disables it.
            feature_extractor: Feature extractor, required whenever a safety checker is given.
            requires_safety_checker: Stored in the pipeline config; when true and no safety
                checker is given, a warning is logged.

        Raises:
            ValueError: If a safety checker is passed without a feature extractor.
        """
        super().__init__()

        # Scheduler configs with `steps_offset != 1` are outdated: warn and patch.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Scheduler configs with `clip_sample=True` are likewise warned about and patched.
        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # The first fragment needs the `f` prefix so that `{self.__class__}` is interpolated.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        # Old UNet checkpoints (< 0.9.0) that declare `sample_size < 64` get patched to 64.
        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Spatial down-scaling factor between image space and latent space.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offload the unet, text encoder, vae and (when present) the safety checker with accelerate's
        `cpu_offload`, using `cuda:{gpu_id}` as the execution device. A pipeline that is not already on the
        CPU is moved there first. Compared to `enable_model_cpu_offload`, memory savings are higher but
        performance is lower.

        Raises `ImportError` when accelerate is missing or older than 0.14.0.
        """
        accelerate_ok = is_accelerate_available() and is_accelerate_version(">=", "0.14.0")
        if not accelerate_ok:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
        from accelerate import cpu_offload

        target = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            # otherwise we don't see the memory savings (but they probably exist)
            torch.cuda.empty_cache()

        for component in (self.unet, self.text_encoder, self.vae):
            cpu_offload(component, target)

        checker = self.safety_checker
        if checker is not None:
            cpu_offload(checker, execution_device=target, offload_buffers=True)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offload whole models with accelerate's `cpu_offload_with_hook`, using `cuda:{gpu_id}` as the
        execution device. The text encoder, unet, vae and (when present) the safety checker are chained
        in that order, each hook receiving the previous one as `prev_module_hook`. The last hook is kept
        in `self.final_offload_hook` so the final model can be offloaded manually. Compared to
        `enable_sequential_cpu_offload`, memory savings are lower but performance is much better.

        Raises `ImportError` when accelerate is missing or older than 0.17.0.
        """
        accelerate_ok = is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0")
        if not accelerate_ok:
            raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
        from accelerate import cpu_offload_with_hook

        target = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            # otherwise we don't see the memory savings (but they probably exist)
            torch.cuda.empty_cache()

        chain = [self.text_encoder, self.unet, self.vae]
        if self.safety_checker is not None:
            chain.append(self.safety_checker)

        previous_hook = None
        for component in chain:
            _, previous_hook = cpu_offload_with_hook(component, target, prev_module_hook=previous_hook)

        # We'll offload the last model manually.
        self.final_offload_hook = previous_hook

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
             prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
            )
            prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        """Run the safety checker on `image` when one is configured.

        Returns a tuple `(image, has_nsfw_concept)`. Without a safety checker the image is
        returned untouched and `has_nsfw_concept` is `None`.
        """
        if self.safety_checker is None:
            return image, None

        # Features for the checker are extracted from the PIL version of the images.
        checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=checker_input.pixel_values.to(dtype)
        )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        """Decode latents with the vae and return the images as a float32 NumPy array.

        The latents are divided by the vae's `scaling_factor`, decoded, mapped through
        `x / 2 + 0.5`, clamped to [0, 1] and permuted from (0, 1, 2, 3) to (0, 2, 3, 1).
        """
        unscaled = 1 / self.vae.config.scaling_factor * latents
        decoded = self.vae.decode(unscaled).sample
        pixels = (decoded / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        channels_last = pixels.cpu().permute(0, 2, 3, 1)
        return channels_last.float().numpy()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        """Collect the optional keyword arguments accepted by `self.scheduler.step`.

        Not all schedulers share the same signature, so `eta` and `generator` are only
        forwarded when the scheduler's `step` declares a parameter of that name.
        eta (η) is only used with the DDIMScheduler and corresponds to η in the DDIM
        paper: https://arxiv.org/abs/2010.02502 ; it should be between [0, 1].
        """
        step_parameters = set(inspect.signature(self.scheduler.step).parameters.keys())

        candidates = (("eta", eta), ("generator", generator))
        return {name: value for name, value in candidates if name in step_parameters}

    def check_inputs(
        self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
    ):
        """Validate the user-facing call arguments; raise `ValueError` on the first problem.

        Checks, in order: `strength` lies in [0, 1]; `callback_steps` is a positive int;
        exactly one of `prompt` / `prompt_embeds` is given and `prompt` is a `str` or
        `list`; `negative_prompt` and `negative_prompt_embeds` are not both given; and,
        when both embedding tensors are given, their shapes match.
        """
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        steps_valid = isinstance(callback_steps, int) and callback_steps > 0
        if callback_steps is None or not steps_valid:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is None:
            if prompt_embeds is None:
                raise ValueError(
                    "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
                )
        elif prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        both_embeds_given = prompt_embeds is not None and negative_prompt_embeds is not None
        if both_embeds_given and prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    def get_timesteps(self, num_inference_steps, strength, device):
        """Select the tail of the scheduler's timesteps that corresponds to `strength`.

        Returns a tuple `(timesteps, steps)`: the slice of `self.scheduler.timesteps`
        starting at the computed offset, and the number of inference steps left after
        that offset. `device` is accepted but not used here.
        """
        # number of steps that will actually be run for this strength
        steps_to_run = min(int(num_inference_steps * strength), num_inference_steps)
        first_index = max(num_inference_steps - steps_to_run, 0)

        remaining = self.scheduler.timesteps[first_index:]
        return remaining, num_inference_steps - first_index

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        """Encode `image` with the vae and add scheduler noise for `timestep`.

        The image is moved to `device`/`dtype`, encoded (with one generator per sample
        when a list of generators is given), scaled by the vae's `scaling_factor`,
        duplicated when fewer images than prompts were passed (deprecated behaviour), and
        finally noised through `self.scheduler.add_noise`.

        Raises:
            ValueError: On an unsupported `image` type, a generator list whose length
                differs from the effective batch size, or an image batch that cannot be
                duplicated evenly to the effective batch size.
        """
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        # Effective batch size: one latent per prompt and per requested image.
        batch_size = batch_size * num_images_per_prompt
        per_sample_generators = isinstance(generator, list)
        if per_sample_generators and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if per_sample_generators:
            encoded = [
                self.vae.encode(image[index : index + 1]).latent_dist.sample(sample_generator)
                for index, sample_generator in enumerate(generator)
            ]
            init_latents = torch.cat(encoded, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0]:
            if batch_size % init_latents.shape[0] != 0:
                raise ValueError(
                    f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
                )
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            copies = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * copies, dim=0)
        else:
            init_latents = torch.cat([init_latents], dim=0)

        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)

        # get latents
        return self.scheduler.add_noise(init_latents, noise, timestep)

    @torch.no_grad()

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        strength: float = 1,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        map:torch.FloatTensor = None,
    ):
        r"""
        Run text-guided image-to-image generation in which `map` controls, per
        latent location, for how many denoising steps the noised original is
        re-injected.

        For step index ``i`` the threshold is ``i / len(timesteps)``. Wherever
        ``map`` is greater than that threshold, the current latents are replaced
        by the noised original for that step before the UNet is called, so
        locations with larger map values are reset to the original for more
        steps.

        Args:
            prompt: A string (batch size 1) or a list of strings (batch size
                ``len(prompt)``). When it is neither, the batch size is read
                from ``prompt_embeds.shape[0]``.
            image: Input image; passed through ``preprocess`` and then to
                ``self.prepare_latents``.
            strength: Forwarded to ``self.check_inputs`` and
                ``self.get_timesteps``.
            num_inference_steps: Number of scheduler steps requested before
                ``self.get_timesteps`` adjusts the schedule.
            guidance_scale: Classifier-free guidance is applied only when this
                is greater than 1.0.
            negative_prompt: Forwarded to ``self._encode_prompt``.
            num_images_per_prompt: Forwarded to ``self._encode_prompt`` and
                ``self.prepare_latents``.
            eta: Forwarded to ``self.prepare_extra_step_kwargs``.
            generator: Forwarded to ``self.prepare_extra_step_kwargs`` and
                ``self.prepare_latents``.
            prompt_embeds: Optional precomputed embeddings, forwarded to
                ``self._encode_prompt``.
            negative_prompt_embeds: Optional precomputed negative embeddings,
                forwarded to ``self._encode_prompt``.
            output_type: When equal to ``"pil"`` the decoded result is converted
                with ``self.numpy_to_pil``; otherwise the output of
                ``self.decode_latents`` is returned as is.
            return_dict: When false, a plain tuple is returned.
            callback: Optional function called as ``callback(i, t, latents)``.
            callback_steps: The callback fires on progress updates where
                ``i % callback_steps == 0``.
            map: Change map tensor. Despite the ``None`` default it is passed to
                ``torchvision.transforms.Resize`` unconditionally, so callers
                have to supply it. Presumably values lie in [0, 1] and the
                tensor is shaped (1, H, W) -- TODO confirm against
                ``preprocess_map``.

        Returns:
            ``StableDiffusionPipelineOutput(images=..., nsfw_content_detected=...)``
            when ``return_dict`` is true, otherwise the tuple
            ``(image, has_nsfw_concept)``.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device

        # Guidance is only worth the doubled UNet batch when the scale exceeds 1
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Preprocess image
        image = preprocess(image)

        # 5. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)


        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
        # Resize the change map to the latent resolution (image height/width divided by the VAE scale factor)
        map = torchvision.transforms.Resize(tuple(s // self.vae_scale_factor for s in image.shape[2:]),antialias=None)(map)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # diff diff preparations
        # NOTE(review): the whole `timesteps` tensor is passed where the stock pipeline passes a
        # single timestep; the loop below indexes the result per step, so presumably this yields
        # one noised copy of the original per timestep -- TODO confirm in prepare_latents.
        original_with_noise = self.prepare_latents(
            image, timesteps, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
        )
        # One threshold per step, evenly spaced over [0, 1): step i uses i / len(timesteps)
        thresholds = torch.arange(len(timesteps), dtype=map.dtype) / len(timesteps)
        thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)
        # masks[i] is True where the map value exceeds the threshold of step i
        masks = map > thresholds
        # end diff diff preparations

        with self.progress_bar(total=num_inference_steps) as progress_bar:

            for i, t in enumerate(timesteps):
                # diff diff
                if i == 0:
                    # Start from the first entry of the noised originals
                    latents = original_with_noise[:1]
                else:
                    mask = masks[i].unsqueeze(0)
                    # cast mask to the same type as latents etc
                    mask = mask.to(latents.dtype)
                    mask = mask.unsqueeze(1)  # fit shape
                    # Masked locations are reset to the noised original of this step;
                    # the rest keeps the latents produced by the previous scheduler step
                    latents = original_with_noise[i] * mask + latents * (1 - mask)
                    # end diff diff
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        #has_nsfw_concept = False

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

In [None]:
# Load the fine-tuned checkpoint in half precision with the safety checker
# disabled, then move the assembled pipeline onto the GPU.
device = "cuda"
pipe = MaskStableDiffusionImg2ImgPipeline.from_pretrained(
    "cappiepappie/stable-diffusion-music-25",
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipe = pipe.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/246M [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/322 [00:00<?, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/603 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/737 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/641 [00:00<?, ?B/s]

diffusion_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]

diffusion_pytorch_model.bin:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

An error occurred while trying to fetch /root/.cache/huggingface/hub/models--cappiepappie--stable-diffusion-music-25/snapshots/4f3317ca7f5fe21251f059246ec24c80ea8e15a3/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /root/.cache/huggingface/hub/models--cappiepappie--stable-diffusion-music-25/snapshots/4f3317ca7f5fe21251f059246ec24c80ea8e15a3/unet.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch /root/.cache/huggingface/hub/models--cappiepappie--stable-diffusion-music-25/snapshots/4f3317ca7f5fe21251f059246ec24c80ea8e15a3/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /root/.cache/huggingface/hub/models--cappiepappie--stable-diffusion-music-25/snapshots/4f3317ca7f5fe21251f059246ec24c80ea8e15a3/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Expected types for feature_extractor: (<class 'transformer

In [None]:
import numpy as np
from PIL import Image

# Spectrogram image geometry and audio clip settings
image_width = 512
sample_rate = 44100  # [Hz]
clip_duration_ms = 5000  # [ms]

bins_per_image = 512
n_mels = 512

# FFT parameters
window_duration_ms = 100  # [ms]
padded_duration_ms = 400  # [ms]
step_size_ms = 10  # [ms]

# Derived parameters
# The clip duration is in milliseconds, so it has to be converted to seconds
# before multiplying by the sample rate (5 s * 44100 Hz = 220500 samples).
num_samples = int(image_width / float(bins_per_image) * clip_duration_ms / 1000.0 * sample_rate)
n_fft = int(padded_duration_ms / 1000.0 * sample_rate)
hop_length = int(step_size_ms / 1000.0 * sample_rate)
win_length = int(window_duration_ms / 1000.0 * sample_rate)

In [None]:
def spectrogram_from_image(
    image: Image.Image, max_volume: float = 50, power_for_image: float = 0.25
) -> np.ndarray:
    """
    Turn a spectrogram picture back into spectrogram magnitudes.

    Only the first channel of `image` is read. Rows are reversed, pixel values
    are inverted (255 - value), rescaled so that 255 corresponds to
    `max_volume`, and finally raised to the power 1 / `power_for_image`.
    """
    pixels = np.array(image).astype(np.float32)
    # First channel only, with the row order reversed
    first_channel_flipped = pixels[::-1, :, 0]
    rescaled = (255 - first_channel_flipped) * max_volume / 255
    return np.power(rescaled, 1 / power_for_image)

In [None]:


def waveform_from_spectrogram(
    Sxx: np.ndarray,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    num_samples=num_samples,
    sample_rate=sample_rate,
    mel_scale: bool = True,
    n_mels: int = 512,
    max_mel_iters: int = 200,
    num_griffin_lim_iters: int = 32,
    device: str = "cuda",
) -> np.ndarray:
    """
    Reconstruct audio from a magnitude spectrogram with Griffin-Lim.

    When `mel_scale` is true the input is first mapped back from the mel scale
    with torchaudio's InverseMelScale (HTK scale, 0-10000 Hz). The result is
    returned as a NumPy array on the CPU.

    `num_samples` and `max_mel_iters` are accepted but not read by this body.
    """
    magnitudes = torch.from_numpy(Sxx).to(device)

    if mel_scale:
        inverse_mel = torchaudio.transforms.InverseMelScale(
            n_stft=n_fft // 2 + 1,
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=0,
            f_max=10000,
            norm=None,
            mel_scale="htk",
        )
        magnitudes = inverse_mel.to(device)(magnitudes)

    phase_reconstructor = torchaudio.transforms.GriffinLim(
        n_fft=n_fft,
        n_iter=num_griffin_lim_iters,
        win_length=win_length,
        hop_length=hop_length,
        power=1.0,
    )
    phase_reconstructor = phase_reconstructor.to(device)

    return phase_reconstructor(magnitudes).cpu().numpy()

In [None]:
def spectrogram_from_waveform(
    waveform: np.ndarray,
    sample_rate=sample_rate,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    mel_scale: bool = True,
    n_mels: int = 512,
) -> np.ndarray:
    """
    Compute the magnitude spectrogram of a waveform.

    The waveform is cast to float32 and treated as a single channel. The
    complex STFT is reduced to magnitudes and, when `mel_scale` is true,
    projected onto `n_mels` mel bins (HTK scale, 0-10000 Hz).
    """
    stft = torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        power=None,  # keep the complex values; magnitudes are taken below
    )

    mono = torch.from_numpy(waveform.astype(np.float32)).reshape(1, -1)
    magnitudes = np.abs(stft(mono).numpy()[0])

    if not mel_scale:
        return magnitudes

    to_mel = torchaudio.transforms.MelScale(
        n_stft=n_fft // 2 + 1,
        n_mels=n_mels,
        sample_rate=sample_rate,
        f_min=0,
        f_max=10000,
        norm=None,
        mel_scale="htk",
    )
    return to_mel(torch.from_numpy(magnitudes)).numpy()


In [None]:
def image_from_spectrogram(spectrogram: np.ndarray, max_volume: float = 50, power_for_image: float = 0.25) -> Image.Image:
    """
    Render spectrogram magnitudes as an RGB image.

    The magnitudes are compressed with the exponent `power_for_image`, scaled so
    that `max_volume` maps to 255, inverted (255 - value), converted to 8-bit,
    flipped top to bottom and expanded to RGB.

    Args:
        spectrogram: 2-D array of magnitudes.
        max_volume: Compressed magnitude that maps to the full 8-bit range.
        power_for_image: Exponent applied to the magnitudes before scaling.

    Returns:
        The rendered PIL image in "RGB" mode.
    """
    data = np.power(spectrogram, power_for_image)
    data = data * 255 / max_volume
    data = 255 - data
    # Magnitudes louder than `max_volume` push values below 0; casting such
    # floats to uint8 wraps around (or is platform-dependent), so saturate first.
    data = np.clip(data, 0, 255)
    image = Image.fromarray(data.astype(np.uint8))
    image = image.transpose(Image.FLIP_TOP_BOTTOM)
    image = image.convert("RGB")
    return image

## Mask creator module
- Enter image parameters and adjust mask as required


In [None]:
# ====================================================================
#
#       COMPLETE & CORRECTED ABLATION STUDY SCRIPT (V3)
#
# This version fixes the critical spectrogram normalization issue by
# adaptively calculating `max_volume` for each audio file,
# ensuring input spectrograms are visually correct.
#
# ====================================================================

import os
import numpy as np
import pandas as pd
from PIL import Image, ImageChops, ImageEnhance
import torch
import librosa
import soundfile as sf
from skimage.metrics import structural_similarity as ssim
import shutil

print("Ablation study script started.")

# ====================================================================
# 1. SETUP AND CONFIGURATION
# ====================================================================

# --- Google Drive locations: input clips and per-song outputs ---
BASE_DRIVE_PATH = "/content/gdrive/MyDrive/MusicGen"
INPUT_NOISE_FOLDER = os.path.join(BASE_DRIVE_PATH, "RemainingNoise")
OUTPUT_ABLATION_FOLDER = os.path.join(BASE_DRIVE_PATH, "RemainingAblation")

# --- Mask-construction hyperparameters used by the study ---
BRIGHTNESS_CONTROL = 0.05
CONTRAST_CONTROL = 200.0
DOWNSHIFT_FIXED = 40
INVERT_FACTOR = 1.0  # blend factor 1.0 means the fully inverted image
POWER_FOR_IMAGE = 0.25  # exponent used when rendering spectrogram images

# --- Diffusion settings shared by every run ---
PROMPT = "ambient, saxophone"
STRENGTH = 0.8
GUIDANCE_SCALE = 8.0
GENERATOR_SEED = 1040
# A single CUDA generator, seeded once here
generator = torch.Generator(device="cuda")
generator.manual_seed(GENERATOR_SEED)

# --- One dict per (song, test) run is appended here ---
all_results = list()

# --- Make sure the base output directory exists ---
os.makedirs(OUTPUT_ABLATION_FOLDER, exist_ok=True)
print(f"Base output folder ready at: {OUTPUT_ABLATION_FOLDER}")


# ====================================================================
# 2. HELPER FUNCTIONS (UNCHANGED, BUT ASSUMED FROM YOUR NOTEBOOK)
# All functions from your final_loss.py notebook are assumed to be defined here.
# For clarity, I'm including the mask functions again.
# ====================================================================
def invert_image(image: Image.Image, factor: float = 1.0) -> Image.Image:
    """Blend the grayscale version of `image` with its negative, weighted by `factor`."""
    grayscale = image.convert("L")
    negative = ImageChops.invert(grayscale)
    blended = Image.blend(grayscale, negative, factor)
    return blended

def bright_image(image: Image.Image, factor: float) -> Image.Image:
    """Return `image` with its brightness enhanced by `factor` (PIL ImageEnhance.Brightness)."""
    return ImageEnhance.Brightness(image).enhance(factor)

def contrast_image(image: Image.Image, factor: float) -> Image.Image:
    """Return `image` with its contrast enhanced by `factor` (PIL ImageEnhance.Contrast)."""
    return ImageEnhance.Contrast(image).enhance(factor)

def downshift_image(image: Image.Image, downshift_value: int) -> Image.Image:
    """
    Subtract `downshift_value` from every grayscale pixel, flooring at zero.

    Pixels that are not strictly greater than `downshift_value` become 0.
    """
    pixels = np.array(image.convert("L"))
    above_floor = pixels > downshift_value
    shifted = np.where(above_floor, pixels - downshift_value, 0)
    return Image.fromarray(shifted.astype(np.uint8))

def generate_mask_and_intermediates(original_image, brightness, contrast, downshift, perform_brightness_scaling=False, target_brightness=0):
    """
    Build the change mask for one spectrogram image and keep every intermediate.

    Pipeline: invert (using INVERT_FACTOR) -> optional brightness adjustment ->
    optional rescaling of the mean brightness towards `target_brightness` ->
    optional contrast adjustment -> downshift.

    Args:
        original_image: Source spectrogram image.
        brightness: Brightness factor, or None to skip that stage.
        contrast: Contrast factor, or None to skip that stage.
        downshift: Value subtracted from every pixel in the final stage.
        perform_brightness_scaling: When true, rescale the mean grayscale level
            to `target_brightness` (only if the current mean is above 1).
        target_brightness: Mean grayscale level aimed for by the rescaling.

    Returns:
        (final_mask, stages) where `stages` maps a numbered stage label to the
        image produced at that stage, in processing order.
    """
    stages = {"01_Original_Image": original_image}

    current = invert_image(original_image, INVERT_FACTOR)
    stages["02_After_Inversion"] = current

    if brightness is not None:
        current = bright_image(current, brightness)
        stages["03_After_Brightness_Adjustment"] = current

    if perform_brightness_scaling:
        mean_level = np.mean(np.array(current.convert("L")))
        # Nearly black images are left alone rather than scaled by a huge factor
        if mean_level > 1:
            current = bright_image(current, target_brightness / mean_level)
        stages["04_After_Brightness_Scaling_(Fair_Contrast)"] = current

    if contrast is not None:
        current = contrast_image(current, contrast)
        stages["05_After_Contrast_Adjustment"] = current

    mask = downshift_image(current, downshift)
    stages[f"06_Final_Mask_(Downshift_{downshift})"] = mask
    return mask, stages

def compute_ssim_score(img1, img2):
    """Return the structural similarity of two images compared in 8-bit grayscale."""
    first, second = (np.array(picture.convert("L")) for picture in (img1, img2))
    # full=True makes ssim return (mean score, per-pixel map); only the score is used
    return ssim(first, second, full=True, data_range=255.0)[0]

def compute_cens_score(input_spec_img, output_spec_img, sample_rate=44100):
    """
    Cosine similarity between the CENS chroma features of two spectrogram images.

    Both images are converted back to audio first. Returns 0.0 when either
    feature vector has zero norm, and NaN (after printing a warning) when any
    step raises.
    """
    try:
        waveforms = [
            waveform_from_spectrogram(spectrogram_from_image(spec_img))
            for spec_img in (input_spec_img, output_spec_img)
        ]
        chroma_in, chroma_out = [
            librosa.feature.chroma_cens(y=audio, sr=sample_rate) for audio in waveforms
        ]
        vec_in = chroma_in.flatten()
        vec_out = chroma_out.flatten()
        len_in = np.linalg.norm(vec_in)
        len_out = np.linalg.norm(vec_out)
        if len_in == 0 or len_out == 0:
            return 0.0
        return np.dot(vec_in, vec_out) / (len_in * len_out)
    except Exception as e:
        print(f"  [Warning] Could not compute CENS score: {e}")
        return np.nan



def compute_fad_folders(test_folders, output_dir):
    """
    Computes FAD scores using fixed model (clap-laion-audio) and reference (fma_pop).

    Runs the `fadtk` command line tool once per folder. A folder whose run
    fails is reported and skipped, so the remaining folders are still
    processed. This includes the case where the `fadtk` executable cannot be
    started at all (for example, when it is not installed).

    Parameters:
    - test_folders: List[str] of folders with generated audio
    - output_dir: str, path where the FAD score CSVs will be saved
      (created if it does not exist)

    Returns:
    - List[str]: Paths to the generated FAD score CSVs (successful runs only)
    """
    MODEL = "clap-laion-audio"
    REFERENCE = "fma_pop"  # assumes 'fma_pop' is registered as a known dataset in fadtk

    os.makedirs(output_dir, exist_ok=True)
    score_files = []

    for folder in test_folders:
        # normpath strips a trailing slash so basename is never empty
        folder_name = os.path.basename(os.path.normpath(folder))
        output_csv = os.path.join(output_dir, f"{folder_name}_fad_score.csv")

        command = [
            "fadtk", MODEL, REFERENCE, folder,
            output_csv, "--indiv"
        ]

        print(f"[FAD] Computing for: {folder}")
        try:
            subprocess.run(command, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # CalledProcessError: fadtk ran and returned a non-zero status.
            # OSError (e.g. FileNotFoundError): fadtk could not be launched.
            print(f"[FAD] Failed for {folder}: {e}")
            continue

        print(f"[FAD] Saved: {output_csv}")
        score_files.append(output_csv)

    return score_files


# ====================================================================
# 3. MAIN ABLATION LOOP
# ====================================================================

# For every audio clip in INPUT_NOISE_FOLDER: render the input spectrogram,
# build the three mask variants, run each through the diffusion pipeline, and
# store images, audio and SSIM/CENS scores under OUTPUT_ABLATION_FOLDER.
if not os.path.exists(INPUT_NOISE_FOLDER):
    print(f"[ERROR] Input folder not found: {INPUT_NOISE_FOLDER}")
else:
    # os.listdir returns entries in arbitrary order; sort for a reproducible run order
    song_files = sorted(f for f in os.listdir(INPUT_NOISE_FOLDER) if f.endswith(('.mp3', '.wav', '.flac')))
    print(f"Found {len(song_files)} audio files to process.")

    for i, song_file in enumerate(song_files):
        song_name = os.path.splitext(song_file)[0]
        print(f"\n--- Processing {i+1}/{len(song_files)}: {song_name} ---")

        song_output_dir = os.path.join(OUTPUT_ABLATION_FOLDER, song_name)
        os.makedirs(song_output_dir, exist_ok=True)

        try:
            audio_path = os.path.join(INPUT_NOISE_FOLDER, song_file)
            audio_cust_whole, sr = librosa.load(audio_path, sr=44100)
            # Keep only the first 225500 samples (about 5.1 s at 44.1 kHz)
            audio_cust_data = audio_cust_whole[:225500]

            # --- ADAPTIVE NORMALIZATION ---
            # Step 1: Generate raw spectrogram data
            spec_cust_data = spectrogram_from_waveform(audio_cust_data)

            # Step 2: Derive max_volume from THIS spectrogram's peak so that the
            # rendered image uses the 8-bit range without overflowing it
            max_volume = np.ceil(np.power(np.max(spec_cust_data), POWER_FOR_IMAGE))

            # Step 3: Create the normalized input image
            input_spec_image = image_from_spectrogram(spec_cust_data, max_volume=max_volume, power_for_image=POWER_FOR_IMAGE)

            sf.write(os.path.join(song_output_dir, "input_audio.wav"), audio_cust_data, sr)
            input_spec_image.save(os.path.join(song_output_dir, "input_spectrogram.png"))
        except Exception as e:
            print(f"  [ERROR] Failed to load or process {song_file}. Skipping. Error: {e}")
            continue

        # Mean brightness of the control path after inversion + brightness; the
        # no-contrast variant is rescaled to this level
        target_img_for_brightness = bright_image(invert_image(input_spec_image), BRIGHTNESS_CONTROL)
        TARGET_BRIGHTNESS = np.mean(np.array(target_img_for_brightness.convert("L")))

        tests_to_run = ["control", "no_brightness_ablation", "no_contrast_ablation"]
        # Mask-building arguments for each variant (None disables a stage)
        mask_variants = {
            "control": dict(brightness=BRIGHTNESS_CONTROL, contrast=CONTRAST_CONTROL),
            "no_brightness_ablation": dict(brightness=None, contrast=CONTRAST_CONTROL),
            "no_contrast_ablation": dict(
                brightness=BRIGHTNESS_CONTROL,
                contrast=None,
                perform_brightness_scaling=True,
                target_brightness=TARGET_BRIGHTNESS,
            ),
        }

        for test_type in tests_to_run:
            print(f"  -> Running Test: {test_type}")
            test_output_dir = os.path.join(song_output_dir, test_type)
            os.makedirs(test_output_dir, exist_ok=True)

            final_mask_pil, intermediates = generate_mask_and_intermediates(
                input_spec_image, downshift=DOWNSHIFT_FIXED, **mask_variants[test_type]
            )

            for name, img in intermediates.items():
                img.save(os.path.join(test_output_dir, f"{name}.png"))

            try:
                final_map_tensor = preprocess_map(final_mask_pil)
                # Re-seed for every run: a generator shared across runs would give each
                # variant different starting noise and confound the ablation comparison.
                run_generator = torch.Generator(device="cuda").manual_seed(GENERATOR_SEED)
                output_image = pipe(
                    prompt=PROMPT, image=input_spec_image, strength=STRENGTH, guidance_scale=GUIDANCE_SCALE,
                    map=final_map_tensor, generator=run_generator
                ).images[0]

                output_audio = waveform_from_spectrogram(spectrogram_from_image(output_image))
                max_val = np.max(np.abs(output_audio))

                # Peak-normalize; silent output is written unchanged to avoid dividing by zero
                normalized_audio = output_audio / max_val if max_val > 0 else output_audio

                output_image.save(os.path.join(test_output_dir, "output_spectrogram.png"))
                sf.write(os.path.join(test_output_dir, "output_audio.wav"), normalized_audio, sr)

                ssim_score = compute_ssim_score(input_spec_image, output_image)
                cens_score = compute_cens_score(input_spec_image, output_image)

                with open(os.path.join(test_output_dir, "scores.txt"), "w") as f:
                    f.write(f"SSIM_Score: {ssim_score}\n")
                    f.write(f"CENS_Score: {cens_score}\n")

                all_results.append({"song_name": song_name, "test_type": test_type, "ssim": ssim_score, "cens": cens_score})
                print(f"    SSIM: {ssim_score:.4f}, CENS: {cens_score:.4f}")

            except Exception as e:
                # Record the failed run as NaN so the summary still has a row for it
                print(f"    [ERROR] Diffusion or scoring failed for {test_type}. Error: {e}")
                all_results.append({"song_name": song_name, "test_type": test_type, "ssim": np.nan, "cens": np.nan})

# ====================================================================
# 4. GENERATE FINAL SUMMARY CSV
# ====================================================================

# Pivot the per-run rows into one row per song, with one column per
# (metric, test type) pair, and write the table next to the per-song folders.
if not all_results:
    print("\nNo results were generated. CSV file not created.")
else:
    print("\n--- Generating final summary CSV ---")
    df = pd.DataFrame(all_results)

    pivot_df = df.pivot(index='song_name', columns='test_type', values=['ssim', 'cens'])
    # Flatten the (metric, test_type) column pairs into names like "ssim_control"
    pivot_df.columns = ["_".join(pair) for pair in pivot_df.columns]
    pivot_df = pivot_df.reset_index()

    csv_path = os.path.join(OUTPUT_ABLATION_FOLDER, "ablation_study_summary.csv")
    pivot_df.to_csv(csv_path, index=False)
    print(f"Successfully saved summary to: {csv_path}")

print("\n--- Ablation study script finished! ---")

Ablation study script started.
Base output folder ready at: /content/gdrive/MyDrive/MusicGen/RemainingAblation
Found 10 audio files to process.

--- Processing 1/10: laser ---
  -> Running Test: control


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.8398, CENS: 0.7339
  -> Running Test: no_brightness_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.8381, CENS: 0.6724
  -> Running Test: no_contrast_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.7419, CENS: 0.7603

--- Processing 2/10: iphone-alarm ---
  -> Running Test: control


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.6749, CENS: 0.6839
  -> Running Test: no_brightness_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.7941, CENS: 0.7630
  -> Running Test: no_contrast_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.6025, CENS: 0.7242

--- Processing 3/10: iphone-notif ---
  -> Running Test: control


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.8384, CENS: 0.6633
  -> Running Test: no_brightness_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.8477, CENS: 0.6208
  -> Running Test: no_contrast_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.7992, CENS: 0.4540

--- Processing 4/10: gong ---
  -> Running Test: control


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.5387, CENS: 0.8089
  -> Running Test: no_brightness_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.5912, CENS: 0.9403
  -> Running Test: no_contrast_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.3721, CENS: 0.7304

--- Processing 5/10: bicicle-bell ---
  -> Running Test: control


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.7022, CENS: 0.4725
  -> Running Test: no_brightness_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.7102, CENS: 0.1787
  -> Running Test: no_contrast_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.7714, CENS: 0.3096

--- Processing 6/10: discord-leave ---
  -> Running Test: control


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.8290, CENS: 0.7085
  -> Running Test: no_brightness_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.9099, CENS: 0.7639
  -> Running Test: no_contrast_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.8876, CENS: 0.5043

--- Processing 7/10: door-knocking ---
  -> Running Test: control


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.6201, CENS: 0.7870
  -> Running Test: no_brightness_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.6333, CENS: 0.8131
  -> Running Test: no_contrast_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.5343, CENS: 0.2669

--- Processing 8/10: doorbell ---
  -> Running Test: control


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.5761, CENS: 0.7921
  -> Running Test: no_brightness_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.8007, CENS: 0.8745
  -> Running Test: no_contrast_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.6730, CENS: 0.6378

--- Processing 9/10: alarm ---
  -> Running Test: control


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.4909, CENS: 0.5662
  -> Running Test: no_brightness_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.6057, CENS: 0.5632
  -> Running Test: no_contrast_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.4863, CENS: 0.5103

--- Processing 10/10: water-splash ---
  -> Running Test: control


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.6816, CENS: 0.7452
  -> Running Test: no_brightness_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.6760, CENS: 0.8295
  -> Running Test: no_contrast_ablation


  0%|          | 0/40 [00:00<?, ?it/s]

    SSIM: 0.6023, CENS: 0.5672

--- Generating final summary CSV ---
Successfully saved summary to: /content/gdrive/MyDrive/MusicGen/RemainingAblation/ablation_study_summary.csv

--- Ablation study script finished! ---


In [None]:
import pandas as pd
import numpy as np
from scipy import stats
import os

# ====================================================================
# 1. LOAD THE ABLATION STUDY RESULTS
# ====================================================================

# NOTE: this cell reads the summary from "Ablation_Final", whereas the ablation
# script earlier in this notebook writes its summary under "RemainingAblation".
BASE_DRIVE_PATH = "/content/gdrive/MyDrive/MusicGen"
OUTPUT_ABLATION_FOLDER = os.path.join(BASE_DRIVE_PATH, "Ablation_Final")
CSV_PATH = os.path.join(OUTPUT_ABLATION_FOLDER, "ablation_study_summary.csv")

# Stop with SystemExit instead of the interactive `exit()` helper: that helper
# is meant for the interpreter prompt, and in a notebook kernel it can end the
# whole session rather than just this cell.
try:
    df = pd.read_csv(CSV_PATH)
except FileNotFoundError:
    print(f"[ERROR] The file was not found at the specified path: {CSV_PATH}")
    raise SystemExit(1) from None  # nothing to analyze without the file
except Exception as e:
    print(f"An unexpected error occurred while loading the file: {e}")
    raise SystemExit(1) from None
else:
    print(f"Successfully loaded the summary from: {CSV_PATH}\n")

# ====================================================================
# 2. HELPER FUNCTION FOR PAIRED ANALYSIS
# ====================================================================

def perform_paired_analysis(df, col1, col2):
    """
    Performs a paired t-test and calculates Cohen's d for two columns in a DataFrame.

    Rows with a missing value in either column are dropped first. With fewer
    than three complete pairs a notice is printed and nothing else is done.
    All results are printed; the function returns None.

    Parameters:
    - df: DataFrame that contains both columns
    - col1: name of the first (reference) column
    - col2: name of the second column; differences are computed as col1 - col2
    """
    # Remove rows with missing data in either column to ensure a fair paired comparison
    analysis_df = df[[col1, col2]].dropna()

    if len(analysis_df) < 3: # Need at least 3 pairs for a meaningful test
        print(f"Not enough data points to analyze {col1} vs {col2}.\n")
        return

    group1 = analysis_df[col1]
    group2 = analysis_df[col2]

    # --- Perform the paired t-test ---
    # The null hypothesis is that the true mean difference between the pairs is zero.
    t_statistic, p_value = stats.ttest_rel(group1, group2)

    # --- Calculate Effect Size (Cohen's d for paired samples) ---
    differences = group1 - group2
    mean_diff = np.mean(differences)
    std_diff = np.std(differences, ddof=1) # ddof=1 for sample standard deviation
    # Falls back to 0 when every difference is identical (zero spread)
    cohens_d = mean_diff / std_diff if std_diff != 0 else 0

    # Fixed-point formatting would print very small p-values as "0.0000", which
    # reads as "exactly zero"; switch to scientific notation below 1e-4.
    # (NaN fails the range test and is printed by the fixed-point branch as "nan".)
    p_value_text = f"{p_value:.2e}" if 0 <= p_value < 1e-4 else f"{p_value:.4f}"

    # --- Interpretation ---
    print(f"Comparing '{col1}' vs. '{col2}':")
    print(f"  - Mean of '{col1}': {np.mean(group1):.4f}")
    print(f"  - Mean of '{col2}': {np.mean(group2):.4f}")
    print(f"  - Mean Difference: {mean_diff:.4f}\n")
    print(f"  - Paired T-test Results:")
    print(f"    - T-statistic: {t_statistic:.4f}")
    print(f"    - P-value: {p_value_text}")

    # Significance conclusion
    if p_value < 0.05:
        print("    - Conclusion: The difference is STATISTICALLY SIGNIFICANT (p < 0.05).")
        print("      This suggests the ablation had a real effect on the score.")
    else:
        print("    - Conclusion: The difference is NOT statistically significant (p >= 0.05).")
        print("      We cannot conclude that the ablation had a real effect; the observed difference could be due to chance.")

    # Effect size conclusion (conventional 0.2 / 0.5 / 0.8 cut-offs)
    print(f"\n  - Effect Size (Cohen's d): {cohens_d:.4f}")
    if abs(cohens_d) >= 0.8:
        print("    - Interpretation: The effect size is LARGE.")
    elif abs(cohens_d) >= 0.5:
        print("    - Interpretation: The effect size is MEDIUM.")
    elif abs(cohens_d) >= 0.2:
        print("    - Interpretation: The effect size is SMALL.")
    else:
        print("    - Interpretation: The effect size is negligible.")
    print("------------------------------------------------------------------\n")

# ====================================================================
# 3. RUN THE SCIENTIFIC ANALYSIS
# ====================================================================

# Compare the control configuration against each ablation, metric by metric.
print("========= Scientific Ablation Study Analysis =========\n")

report_sections = (
    ("### Part 1: Structural Similarity (SSIM) Analysis ###\n", "ssim"),
    ("### Part 2: Chroma Energy Normalized (CENS) Analysis ###\n", "cens"),
)
ablation_names = ("no_brightness_ablation", "no_contrast_ablation")

for heading, metric in report_sections:
    print(heading)
    for ablation in ablation_names:
        perform_paired_analysis(df, f"{metric}_control", f"{metric}_{ablation}")

Successfully loaded the summary from: /content/gdrive/MyDrive/MusicGen/Ablation_Final/ablation_study_summary.csv


### Part 1: Structural Similarity (SSIM) Analysis ###

Comparing 'ssim_control' vs. 'ssim_no_brightness_ablation':
  - Mean of 'ssim_control': 0.4378
  - Mean of 'ssim_no_brightness_ablation': 0.5023
  - Mean Difference: -0.0645

  - Paired T-test Results:
    - T-statistic: -8.1687
    - P-value: 0.0000
    - Conclusion: The difference is STATISTICALLY SIGNIFICANT (p < 0.05).
      This suggests the ablation had a real effect on the score.

  - Effect Size (Cohen's d): -1.1670
    - Interpretation: The effect size is LARGE.
------------------------------------------------------------------

Comparing 'ssim_control' vs. 'ssim_no_contrast_ablation':
  - Mean of 'ssim_control': 0.4378
  - Mean of 'ssim_no_contrast_ablation': 0.3934
  - Mean Difference: 0.0443

  - Paired T-test Results:
    - T-statistic: 4.8233
    - P-value: 0.0000
    - Conclusion: The difference is STATI