In [1]:
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DDPMScheduler,
    StableDiffusionControlNetPipeline,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class declared by a checkpoint's ``text_encoder`` config.

    Loads the config stored in the checkpoint's ``text_encoder`` subfolder and maps
    the first entry of its ``architectures`` list to the matching model class.

    Args:
        pretrained_model_name_or_path: Hub model id or local path of the checkpoint.
        revision: Revision to load the config from; callers in this notebook pass
            ``None`` to use the default.

    Returns:
        The model class itself (not an instance), e.g. ``CLIPTextModel``.

    Raises:
        ValueError: If the config declares no architecture, or declares one that
            is not supported here.
    """
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    architectures = text_encoder_config.architectures
    # `architectures` may be None/empty in a config; fail with a clear message instead
    # of an opaque TypeError/IndexError from indexing it.
    if not architectures:
        raise ValueError(
            f"Text encoder config of {pretrained_model_name_or_path!r} does not declare any architectures."
        )
    model_class = architectures[0]

    # Imports live inside each branch so a module is only imported when that
    # architecture is actually requested.
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    raise ValueError(f"{model_class} is not supported.")
    

# Base checkpoint that supplies the tokenizer, text encoder, scheduler, VAE and UNet.
PRETRAINED_MODEL_NAME_OR_PATH = "stabilityai/stable-diffusion-2-1-base"
# Revision / weight variant passed to the component loaders below; None selects the defaults.
REVISION = None
VARIANT = None

tokenizer = AutoTokenizer.from_pretrained(
    PRETRAINED_MODEL_NAME_OR_PATH,  # the original was missing this comma -> SyntaxError
    subfolder="tokenizer",
    revision=REVISION,
    use_fast=False,
)

# Import the text encoder class named in the checkpoint's text_encoder config.
text_encoder_cls = import_model_class_from_model_name_or_path(PRETRAINED_MODEL_NAME_OR_PATH, REVISION)

# Load scheduler and models.
noise_scheduler = DDPMScheduler.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
    PRETRAINED_MODEL_NAME_OR_PATH, subfolder="text_encoder", revision=REVISION, variant=VARIANT
)
vae = AutoencoderKL.from_pretrained(
    PRETRAINED_MODEL_NAME_OR_PATH, subfolder="vae", revision=REVISION, variant=VARIANT
)
unet = UNet2DConditionModel.from_pretrained(
    PRETRAINED_MODEL_NAME_OR_PATH, subfolder="unet", revision=REVISION, variant=VARIANT
)

# Build the ControlNet from the loaded UNet (configuration and, by diffusers' default,
# the UNet's encoder weights are reused as the starting point).
controlnet = ControlNetModel.from_unet(unet)