In [37]:
from diffusers import DiffusionPipeline
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

pipeline = DiffusionPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5', safety_checker=None)

scheduler = pipeline.scheduler
tokenizer = pipeline.tokenizer

del pipeline

device, scheduler, tokenizer

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


('cuda',
 PNDMScheduler {
   "_class_name": "PNDMScheduler",
   "_diffusers_version": "0.27.0.dev0",
   "beta_end": 0.012,
   "beta_schedule": "scaled_linear",
   "beta_start": 0.00085,
   "clip_sample": false,
   "num_train_timesteps": 1000,
   "prediction_type": "epsilon",
   "set_alpha_to_one": false,
   "skip_prk_steps": true,
   "steps_offset": 1,
   "timestep_spacing": "leading",
   "trained_betas": null
 },
 CLIPTokenizer(name_or_path='C:\Users\37026\.cache\huggingface\hub\models--runwayml--stable-diffusion-v1-5\snapshots\1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9\tokenizer', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
 	49406: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=T

In [38]:
def read_text(file):
    with open(file) as f:
        text = []
        lines = f.readlines()
        for line in lines:
            #remove \n
            line = line.strip()
            text.append(line)
    return text

def read_images(dir):
    import os
    from PIL import Image
    return [Image.open(os.path.join(dir, i)) for i in os.listdir(dir) if i.endswith('.png')]

In [39]:
texts=[
    'a photo of a [snorlax_bear]', 'a rendering of a [snorlax_bear]',
    'a dark photo of the [snorlax_bear]', 'a photo of my [snorlax_bear]',
    'a good photo of a [snorlax_bear]', 'a photo of the nice [snorlax_bear]',
    'a photo of the small [snorlax_bear]', 'a photo of the weird [snorlax_bear]',
    'a photo of the cool [snorlax_bear]', 'a close-up photo of a [snorlax_bear]',
    'a bright photo of the [snorlax_bear]', 'a cropped photo of a [snorlax_bear]',
    'a cropped photo of the [snorlax_bear]', 'the photo of a [snorlax_bear]',
    'a photo of a clean [snorlax_bear]', 'a photo of a dirty [snorlax_bear]',
    'a photo of the large [snorlax_bear]', 'a photo of a cool [snorlax_bear]',
    'a photo of a small [snorlax_bear]'
]

In [40]:
from datasets import Dataset
import random
def create_dataset(dir_1):

    images = read_images(dir_1)
    text = random.choices(texts, k=len(images))
 
    # Create a Hugging Face dataset
    dataset = Dataset.from_dict({'image': images, 'text': text})

    return dataset
pika= create_dataset('./Snorlax')


In [41]:
from datasets import load_dataset
import torchvision

#图像增强模块
compose = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(lambda img: img.convert('RGB')), 
    torchvision.transforms.Resize(
        512, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
    torchvision.transforms.CenterCrop(512),
    #torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),
])


    
def f(data):
    #image enhance
    pixel_values = [compose(i) for i in data['image']]

    #text encode
    input_ids = tokenizer.batch_encode_plus(data['text'],
                                            padding='max_length',
                                            truncation=True,
                                            max_length=77).input_ids
    


    return {'pixel_values': pixel_values, 'input_ids': input_ids}


dataset = pika.map(f,
                      batched=True,
                      batch_size=100,
                      num_proc=1,
                      remove_columns=['image', 'text' ])

dataset.set_format(type='torch')




Map:   0%|          | 0/7 [00:00<?, ? examples/s]

In [42]:
#定义loader
def collate_fn(data):
    pixel_values = [i['pixel_values'] for i in data]
    input_ids = [i['input_ids'] for i in data]

    pixel_values = torch.stack(pixel_values).to(device)
    input_ids = torch.stack(input_ids).to(device)

    return {'pixel_values': pixel_values, 'input_ids': input_ids}


loader = torch.utils.data.DataLoader(dataset,
                                     shuffle=True,
                                     collate_fn=collate_fn,
                                     batch_size=1)

len(loader), next(iter(loader))

(7,
 {'pixel_values': tensor([[[[0.9216, 0.9216, 0.9216,  ..., 0.9216, 0.9216, 0.9216],
            [0.9216, 0.9216, 0.9216,  ..., 0.9216, 0.9216, 0.9216],
            [0.9216, 0.9216, 0.9216,  ..., 0.9216, 0.9216, 0.9216],
            ...,
            [0.9216, 0.9216, 0.9216,  ..., 0.9216, 0.9216, 0.9216],
            [0.9216, 0.9216, 0.9216,  ..., 0.9216, 0.9216, 0.9216],
            [0.9216, 0.9216, 0.9216,  ..., 0.9216, 0.9216, 0.9216]],
  
           [[0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765],
            [0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765],
            [0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765],
            ...,
            [0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765],
            [0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765],
            [0.9765, 0.9765, 0.9765,  ..., 0.9765, 0.9765, 0.9765]],
  
           [[0.9137, 0.9137, 0.9137,  ..., 0.9137, 0.9137, 0.9137],
            [0.9137, 0.9137, 0.9137,  ..., 0.9137, 0.9

In [44]:
#加载模型
from transformers import CLIPTextModel
encoder=CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5",subfolder="text_encoder")
%run vae.ipynb
%run unet.ipynb

#准备训练
encoder.requires_grad_(False)
vae.requires_grad_(False)
unet.requires_grad_(False)
#only train the embedding layer
encoder.text_model.embeddings.token_embedding.requires_grad_(True)

encoder.eval()
vae.eval()
unet.train()

encoder.to(device)
vae.to(device)
unet.to(device)

optimizer = torch.optim.AdamW(unet.parameters(),
                              lr=1e-5,
                              betas=(0.9, 0.999),
                              weight_decay=0.01,
                              eps=1e-8)

criterion = torch.nn.MSELoss()

optimizer, criterion

(AdamW (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     differentiable: False
     eps: 1e-08
     foreach: None
     fused: None
     lr: 1e-05
     maximize: False
     weight_decay: 0.01
 ),
 MSELoss())

In [45]:
def init_new_word():
    #字典里添加新词
    tokenizer.add_tokens('[snorlax_bear]')

    #扩展encoder的embed层,添加一个新空间用于容纳新词
    encoder.resize_token_embeddings(len(tokenizer))

    #取新旧两个词的id
    old_id = tokenizer.convert_tokens_to_ids('Bear')
    new_id = tokenizer.convert_tokens_to_ids('[snorlax_bear]')

    embed = encoder.get_input_embeddings().weight.data

    #以旧词来初始化新词
    embed[new_id] = embed[old_id]


init_new_word()

In [46]:
def get_loss(data):
    #[1, 77] -> [1, 77, 768]
    out_encoder = encoder(data['input_ids'])[0]

    #[1, 3, 512, 512] -> [1, 4, 64, 64]
 
    out_vae = vae.encoder(data['pixel_values'])
    out_vae = vae.sample(out_vae)

    #0.18215 = vae.config.scaling_factor
    out_vae = out_vae * 0.18215

    #noise generator
    noise = torch.randn_like(out_vae)

    #add noise
    #1000 = scheduler.num_train_timesteps
    #1 = batch size
    noise_step = torch.randint(0, 1000, (1, )).long().to(device)
    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)

    #calc noise
    out_unet = unet(out_vae=out_vae_noise,
                    out_encoder=out_encoder,
                    time=noise_step)

    #mse loss
    #[1, 4, 64, 64],[1, 4, 64, 64]
    return criterion(out_unet, noise)



In [47]:
from tqdm import tqdm
def train():
    loss_sum = 0
    for epoch in tqdm(range(200)):
        for i, data in enumerate(loader):
            loss = get_loss(data)  
            #print(epoch, i, loss)
            loss.backward()
            loss_sum += loss.item()
            optimizer.step()
            optimizer.zero_grad()

        if epoch % 10 == 0:
            print(epoch, loss_sum)
            loss_sum = 0

    #torch.save(unet.to('cpu'), 'saves/unet.model')


train()
torch.save( encoder.to('cpu'), 'saves/encoder.pth')

  0%|          | 1/200 [00:02<08:13,  2.48s/it]

0 0.6553820911794901


  6%|▌         | 11/200 [00:14<03:52,  1.23s/it]

10 6.480514218565077


 10%|█         | 21/200 [00:26<03:38,  1.22s/it]

20 5.024804863380268


 16%|█▌        | 31/200 [00:38<03:26,  1.22s/it]

30 5.818547120783478


 20%|██        | 41/200 [00:51<03:13,  1.22s/it]

40 5.353631235891953


 26%|██▌       | 51/200 [01:03<03:00,  1.21s/it]

50 6.0187197513878345


 30%|███       | 61/200 [01:15<02:49,  1.22s/it]

60 6.251606137491763


 36%|███▌      | 71/200 [01:27<02:37,  1.22s/it]

70 4.837037515128031


 40%|████      | 81/200 [01:39<02:25,  1.22s/it]

80 5.059074798366055


 46%|████▌     | 91/200 [01:51<02:11,  1.21s/it]

90 5.157424732111394


 50%|█████     | 101/200 [02:04<01:59,  1.21s/it]

100 4.843293134355918


 56%|█████▌    | 111/200 [02:16<01:48,  1.21s/it]

110 5.771075590047985


 60%|██████    | 121/200 [02:28<01:35,  1.21s/it]

120 6.247637171763927


 66%|██████▌   | 131/200 [02:40<01:23,  1.21s/it]

130 6.22042535059154


 70%|███████   | 141/200 [02:52<01:11,  1.21s/it]

140 5.106386477360502


 76%|███████▌  | 151/200 [03:04<00:59,  1.21s/it]

150 5.224767256062478


 80%|████████  | 161/200 [03:16<00:47,  1.21s/it]

160 5.168128897901624


 86%|████████▌ | 171/200 [03:28<00:35,  1.21s/it]

170 4.84971322491765


 90%|█████████ | 181/200 [03:40<00:22,  1.21s/it]

180 4.157948060659692


 96%|█████████▌| 191/200 [03:53<00:10,  1.21s/it]

190 4.7174648102372885


100%|██████████| 200/200 [04:03<00:00,  1.22s/it]
