In [31]:
import argparse
import logging
import math
import os
import random
from pathlib import Path
from typing import Optional

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset, Dataset
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)

from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

# Test Dataset

In [18]:
# Logger obtained from accelerate.logging.get_logger for this module.
# NOTE(review): no live code in the visible cells calls it; remove it if it stays unused.
logger = get_logger(__name__)

def make_train_dataset(
    metadata_path="/home/v-yuancwang/DiffAudioImg/metadata/vgg_train_2.json",
    image_dir="/blob/v-yuancwang/DiffAudioImg/VGGSound/data/vggsound/img_spilt",
    mel_dir="/blob/v-yuancwang/DiffAudioImg/VGGSound/data/vggsound/mel",
    resolution=256,
):
    """Build the VGGSound image / mel-spectrogram training dataset.

    The JSON metadata is loaded into a ``datasets.Dataset`` and a transform is
    attached with ``with_transform``, so image and mel files are only read
    when rows are accessed.

    Column roles are chosen by position in the metadata:
    column 0 is the target image filename and column 1 is the conditioning
    mel (``.npy``) filename. The caption column (position 2) is currently not
    used because caption tokenization is disabled.

    Args:
        metadata_path: JSON file with one record per training example.
        image_dir: Directory that the image filenames are relative to.
        mel_dir: Directory that the mel filenames are relative to.
        resolution: Side length that both images and mels are resized to.

    Returns:
        A ``Dataset`` whose accessed rows additionally contain
        ``pixel_values`` (3 x resolution x resolution, normalized to [-1, 1])
        and ``conditioning_pixel_values`` (the mel resized to
        resolution x resolution; a 2-D array gains a leading channel dim).
    """
    dataset = Dataset.from_json(metadata_path)
    column_names = dataset.column_names

    # NOTE(review): positional column selection depends on the JSON key order;
    # confirm the metadata keeps (image, mel, caption) ordering.
    image_column = column_names[0]
    conditioning_image_column = column_names[1]
    # TODO: caption tokenization (column_names[2]) is disabled; restore it if
    # text conditioning is needed again.

    size = (resolution, resolution)

    image_transforms = transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            # Maps ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    # Mels are numpy arrays: convert to a tensor first, then resize the tensor.
    conditioning_image_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        ]
    )

    def _load_image(filename):
        # The context manager releases the file handle even if decoding fails;
        # convert() returns a fully loaded in-memory copy.
        with Image.open(os.path.join(image_dir, filename)) as img:
            return img.convert("RGB")

    def preprocess_train(examples):
        # Load and transform one file at a time, so that not every decoded
        # image in the batch is held in memory at once.
        examples["pixel_values"] = [
            image_transforms(_load_image(name)) for name in examples[image_column]
        ]
        examples["conditioning_pixel_values"] = [
            conditioning_image_transforms(np.load(os.path.join(mel_dir, name)))
            for name in examples[conditioning_image_column]
        ]
        return examples

    return dataset.with_transform(preprocess_train)

dataset = make_train_dataset()
# with_transform runs on every access, so fetch the row once instead of
# re-reading and re-decoding the image and mel for each printed field.
sample = dataset[0]
print(sample["pixel_values"].shape)
print(sample["conditioning_pixel_values"].shape)
# Both transforms resize to 256 x 256; the RGB image yields 3 channels.
assert tuple(sample["pixel_values"].shape) == (3, 256, 256)
assert tuple(sample["conditioning_pixel_values"].shape[-2:]) == (256, 256)

Found cached dataset json (/home/v-yuancwang/.cache/huggingface/datasets/json/default-410212d31ddd5a3a/0.0.0)


torch.Size([3, 256, 256])
torch.Size([1, 256, 256])


# Test ControlNet Format

In [19]:
# Hugging Face Hub id of the base Stable Diffusion checkpoint; the scheduler,
# VAE and UNet are loaded from subfolders of this repository.
pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"

In [None]:
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet"
)

In [None]:
# Build a ControlNet from the loaded UNet via ControlNetModel.from_unet.
# NOTE(review): the mel conditioning tensors produced by the dataset cell are
# 1-channel ([1, 256, 256]). Confirm that this ControlNet's conditioning
# embedding accepts 1 input channel: newer diffusers versions expose a
# `conditioning_channels` argument on from_unet. Otherwise, tile the mel to 3
# channels before passing it as the control input.
controlnet = ControlNetModel.from_unet(unet)

# Try ControlNet