In [1]:
from diffusers import DiffusionPipeline
import torch

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load the Stable Diffusion v1.5 pipeline with the safety checker disabled.
pipeline = DiffusionPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    safety_checker=None,
)

# Only the tokenizer and the noise scheduler are kept from the pipeline.
tokenizer = pipeline.tokenizer
scheduler = pipeline.scheduler

# Drop the pipeline reference; the two objects above stay alive on their own.
del pipeline

# Display what was selected / loaded.
device, scheduler, tokenizer

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


('cuda',
 PNDMScheduler {
   "_class_name": "PNDMScheduler",
   "_diffusers_version": "0.27.0.dev0",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "timestep_spacing": "leading",
   "trained_betas": null
 },
 CLIPTokenizer(name_or_path='C:\Users\37026\.cache\huggingface\hub\models--runwayml--stable-diffusion-v1-5\snapshots\1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9\tokenizer', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
 	49406: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=T

In [2]:
def read_text(file, encoding='utf-8'):
    """Read a text file and return its lines with surrounding whitespace stripped.

    Args:
        file: Path of the text file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's default code page.

    Returns:
        A list with one string per line of the file, in file order. Each line
        has leading and trailing whitespace (including the newline) removed;
        blank lines are kept as empty strings.
    """
    with open(file, encoding=encoding) as f:
        # str.strip() removes the trailing newline and any other padding.
        return [line.strip() for line in f]

def read_images(dir):
    """Open every ``.png`` file directly inside ``dir`` as a PIL image.

    Args:
        dir: Path of the directory to scan (not recursive).

    Returns:
        A list of ``PIL.Image`` objects, ordered by filename.

    Note:
        ``os.listdir`` returns entries in arbitrary order, so the names are
        sorted to make the result deterministic. The sort is plain
        lexicographic: ``img10.png`` comes before ``img2.png``.
    """
    import os
    from PIL import Image

    png_names = sorted(name for name in os.listdir(dir) if name.endswith('.png'))
    return [Image.open(os.path.join(dir, name)) for name in png_names]

In [3]:
from datasets import Dataset
def create_dataset(ori_dir,tune_dir,ori_txt,tune_txt):
    """Build a Hugging Face ``Dataset`` of paired images and prompts.

    Args:
        ori_dir: Directory holding the "original" images (read via read_images).
        tune_dir: Directory holding the "tune" images (read via read_images).
        ori_txt: Text file with one prompt per line for the original images.
        tune_txt: Text file with one prompt per line for the tune images.

    Returns:
        A ``Dataset`` with the columns ``image``, ``text``, ``image_tune`` and
        ``text_tune``.

    Raises:
        ValueError: If the four columns do not all have the same number of
            entries. Rows are paired purely by index, so a mismatch means the
            inputs are out of sync.
    """
    columns = {
        'image': read_images(ori_dir),
        'text': read_text(ori_txt),
        'image_tune': read_images(tune_dir),
        'text_tune': read_text(tune_txt),
    }

    # Fail early with a readable message instead of a low-level table error.
    sizes = {name: len(values) for name, values in columns.items()}
    if len(set(sizes.values())) > 1:
        raise ValueError(
            f'All dataset columns must have the same number of entries, got {sizes}')

    return Dataset.from_dict(columns)
# Raw (untokenized, untransformed) dataset built from the Snorlax folders and prompt files.
pika = create_dataset(
    ori_dir='./Snorlax_original_promt_1',
    tune_dir='./Snorlax',
    ori_txt='./prompt/snorlax.txt',
    tune_txt='./prompt/snorlax_tune.txt',
)


In [4]:
from datasets import load_dataset
import torchvision

# Image preprocessing pipeline applied to every PIL image before training.
_transforms = torchvision.transforms
compose = _transforms.Compose([
    # Force three-channel RGB.
    _transforms.Lambda(lambda picture: picture.convert('RGB')),
    # Resize to 512 with bilinear interpolation, then take the central 512x512 crop.
    _transforms.Resize(512, interpolation=_transforms.InterpolationMode.BILINEAR),
    _transforms.CenterCrop(512),
    #torchvision.transforms.RandomHorizontalFlip(),
    # Convert to a tensor and normalize with mean 0.5 / std 0.5.
    _transforms.ToTensor(),
    _transforms.Normalize([0.5], [0.5]),
])


    
def f(data):
    """Preprocess one batch of raw rows for ``Dataset.map(..., batched=True)``.

    Args:
        data: Mapping with the columns ``image``, ``text``, ``image_tune`` and
            ``text_tune``, each holding a list of values for the batch.

    Returns:
        A dict with ``pixel_values`` / ``pixel_values_tune`` (each image passed
        through the module-level ``compose`` pipeline) and ``input_ids`` /
        ``input_ids_tune`` (token ids from the module-level ``tokenizer``,
        padded or truncated to 77 tokens).
    """
    def encode(texts):
        # Calling the tokenizer directly replaces the deprecated
        # batch_encode_plus; 77 matches the tokenizer's model_max_length
        # shown in the first cell's output.
        return tokenizer(texts,
                         padding='max_length',
                         truncation=True,
                         max_length=77).input_ids

    return {
        'pixel_values': [compose(image) for image in data['image']],
        'input_ids': encode(data['text']),
        'pixel_values_tune': [compose(image) for image in data['image_tune']],
        'input_ids_tune': encode(data['text_tune']),
    }


# Run the preprocessing function over the raw dataset in batches and drop the
# raw columns, leaving only the tensors-to-be produced by `f`.
dataset = pika.map(
    f,
    batched=True,
    batch_size=100,
    num_proc=1,
    remove_columns=['image', 'text', 'image_tune', 'text_tune'],
)

# Have indexing return torch tensors.
dataset.set_format('torch')




Map:   0%|          | 0/7 [00:00<?, ? examples/s]

In [5]:
# Define the loader
def collate_fn(data):
    """Stack a list of dataset rows into one batch on the training device.

    Args:
        data: List of rows, each a mapping with the tensors ``pixel_values``,
            ``input_ids``, ``pixel_values_tune`` and ``input_ids_tune``.

    Returns:
        A dict with the same four keys; each value is the per-row tensors
        stacked along a new leading dimension and moved to ``device``.
    """
    keys = ('pixel_values', 'input_ids', 'pixel_values_tune', 'input_ids_tune')
    return {
        key: torch.stack([row[key] for row in data]).to(device)
        for key in keys
    }


# Shuffled, single-sample batches assembled by collate_fn.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)

# Show the number of batches together with one sample batch.
len(loader), next(iter(loader))

(7,
 {'pixel_values': tensor([[[[ 0.0667,  0.0902,  0.0824,  ...,  0.1686,  0.1451,  0.1137],
            [ 0.1059,  0.1216,  0.1137,  ...,  0.1686,  0.1373,  0.1137],
            [ 0.1294,  0.1451,  0.1451,  ...,  0.1608,  0.1373,  0.1137],
            ...,
            [-0.0745, -0.0588,  0.0118,  ...,  0.1608,  0.1137,  0.0980],
            [-0.1059, -0.0510, -0.0275,  ...,  0.1216,  0.1294,  0.1373],
            [-0.1373, -0.1451, -0.1373,  ..., -0.0196,  0.0118,  0.0980]],
  
           [[-0.0745, -0.0667, -0.0824,  ..., -0.0353, -0.0510, -0.0588],
            [-0.0588, -0.0431, -0.0510,  ..., -0.0353, -0.0667, -0.0824],
            [-0.0431, -0.0275, -0.0196,  ..., -0.0431, -0.0588, -0.0824],
            ...,
            [-0.2706, -0.2706, -0.1608,  ..., -0.0353, -0.0902, -0.0980],
            [-0.3020, -0.2471, -0.2235,  ..., -0.0745, -0.0588, -0.0431],
            [-0.3333, -0.3255, -0.3098,  ..., -0.1843, -0.1451, -0.0510]],
  
           [[-0.1765, -0.1608, -0.1686,  ..., -0.1

In [6]:
#加载模型
%run encoder.ipynb
%run vae.ipynb
%run unet.ipynb

#准备训练
encoder.requires_grad_(False)
vae.requires_grad_(False)
unet.requires_grad_(True)

encoder.eval()
vae.eval()
unet.train()

encoder.to(device)
vae.to(device)
unet.to(device)

optimizer = torch.optim.AdamW(unet.parameters(),
                              lr=1e-5,
                              betas=(0.9, 0.999),
                              weight_decay=0.01,
                              eps=1e-8)

criterion = torch.nn.MSELoss()

optimizer, criterion

(AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     differentiable: False
     eps: 1e-08
     foreach: None
     fused: None
     lr: 1e-05
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [7]:
# Text-encoder outputs for the prompts of the FIRST dataset row only
# (original prompt and tune prompt).
out_encoder_ori = encoder(dataset[0]['input_ids'].to(device))
out_encoder_tune = encoder(dataset[0]['input_ids_tune'].to(device))

In [8]:
def get_loss(data,tune=False):
    """Compute the noise-prediction MSE loss for one batch.

    The batch's own prompt ids are encoded here, so every image is conditioned
    on its own prompt rather than on one fixed, precomputed encoding.

    Args:
        data: Batch dict as produced by ``collate_fn`` with the keys
            ``pixel_values``, ``input_ids``, ``pixel_values_tune`` and
            ``input_ids_tune``.
        tune: When True use the ``*_tune`` image/prompt pair, otherwise the
            original pair.

    Returns:
        The scalar ``criterion`` loss between the UNet output and the noise
        that was added to the latents.
    """
    suffix = '_tune' if tune else ''

    # The text encoder and VAE are frozen, so no graph is needed for them.
    with torch.no_grad():
        # Shape note carried over from the original code: [1, 77] -> [1, 77, 768]
        # (assumes the encoder accepts a batched id tensor — TODO confirm).
        out_encoder = encoder(data['input_ids' + suffix])

        # Shape note carried over from the original code:
        # [1, 3, 512, 512] -> [1, 4, 64, 64]
        out_vae = vae.sample(vae.encoder(data['pixel_values' + suffix]))

        # 0.18215 = vae.config.scaling_factor
        out_vae = out_vae * 0.18215

    # Noise to be added to the latents; also the regression target.
    noise = torch.randn_like(out_vae)

    # One random timestep per sample, drawn from the scheduler's training range.
    batch_size = out_vae.shape[0]
    num_train_timesteps = scheduler.config.num_train_timesteps
    noise_step = torch.randint(0, num_train_timesteps,
                               (batch_size, )).long().to(out_vae.device)
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    # Predict the noise from the noised latents, the prompt encoding and the timestep.
    out_unet = unet(out_vae=out_vae_noise,
                    out_encoder=out_encoder,
                    time=noise_step)

    # MSE between predicted and actual noise.
    return criterion(out_unet, noise)



In [11]:
from tqdm import tqdm
def train(epochs=200, lambda_=1.0, log_every=10):
    """Fine-tune the UNet on the paired original/tune data.

    For every batch the loss is ``get_loss(data) + lambda_ * get_loss(data,
    tune=True)``. Gradients are clipped to a max norm of 1.0 before each
    optimizer step.

    Args:
        epochs: Number of passes over ``loader``.
        lambda_: Weight of the tune loss relative to the original loss.
        log_every: Print the accumulated loss whenever ``epoch % log_every == 0``
            and reset the accumulator. The first report (epoch 0) therefore
            covers a single epoch, later ones cover ``log_every`` epochs.

    Returns:
        None. The module-level ``unet`` is updated in place through ``optimizer``.
    """
    loss_sum = 0
    for epoch in tqdm(range(epochs)):
        for data in loader:
            loss = get_loss(data) + lambda_ * get_loss(data, tune=True)
            loss.backward()
            loss_sum += loss.item()

            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        if epoch % log_every == 0:
            print(epoch, loss_sum)
            loss_sum = 0


import os

# Create the output directory before training so a long run cannot be lost to
# a missing folder at save time.
os.makedirs('saves', exist_ok=True)

train()

# NOTE: this pickles the whole module object (not just a state_dict) and, as a
# side effect, moves `unet` to the CPU.
torch.save(unet.to('cpu'), 'saves/unet3.pth')

  0%|          | 1/200 [00:18<1:00:01, 18.10s/it]

0 2.3122791722416878


  0%|          | 1/200 [00:24<1:21:26, 24.56s/it]


KeyboardInterrupt: 