In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
'''finetuned'''
# Image side length fed to the model and the number of images per batch.
image_size = 256
batch_size = 4

# Training split of the butterflies dataset, fetched by name.
dataset_name = 'huggan/smithsonian_butterflies_subset'
dataset = load_dataset(dataset_name, split='train')

# Per-image preprocessing, applied in the order listed.
preprocess = transforms.Compose([
    # Force every image to a square of image_size x image_size.
    transforms.Resize(size=(image_size, image_size)),
    # Random left-right flip as light augmentation.
    transforms.RandomHorizontalFlip(),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # (value - 0.5) / 0.5 on every channel.
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

def transform(examples):
    """Preprocess a batch of dataset examples.

    Args:
        examples: mapping with an 'image' entry holding an iterable of
            images; each image must provide ``convert``.

    Returns:
        A dict with a single key, 'images', holding the result of running
        ``preprocess`` on the RGB-converted version of every input image,
        in the original order.
    """
    processed = []
    for picture in examples['image']:
        # Convert to RGB before handing the image to the preprocessing pipeline.
        rgb_picture = picture.convert('RGB')
        processed.append(preprocess(rgb_picture))
    return {'images': processed}

# Register the preprocessing function on the dataset.
dataset.set_transform(transform)

# Shuffled mini-batches of `batch_size` examples for training.
train_dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=batch_size,
)



In [3]:
print('previewing batch')

# Take one batch from the loader and tile its images into a single grid.
batch = next(iter(train_dataloader))
grid = torchvision.utils.make_grid(batch['images'], nrow=4)

# Move the channel axis last for matplotlib, clamp to [-1, 1], then undo the
# Normalize([0.5], [0.5]) step so values land in [0, 1].
plt.imshow(
    grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
)

previewing batch


<matplotlib.image.AxesImage at 0x2047fe94070>

: 

待续，详见第五章

In [None]:
'''CLIP GUIDANCE'''

#skip some basic lines
import open_clip

# clip_model = open_clip.create_...  (model creation elided)
# tfms = torchvision.transforms.Compose(...)  # image transforms that normalize the data to suit CLIP
def clip_loss(image, text_features):
    """Score how far the image embeddings are from the text embeddings.

    The image is passed through ``tfms`` and ``clip_model.encode_image``.
    Image and text features are L2-normalized, compared pairwise, and each
    Euclidean distance ``d`` is mapped to ``2 * arcsin(d / 2) ** 2``.

    Args:
        image: input handed to ``tfms`` and then to the CLIP image encoder.
        text_features: text embeddings, one row per prompt.

    Returns:
        Scalar tensor: the mean of the pairwise distances.
    """
    normalize = torch.nn.functional.normalize

    image_features = clip_model.encode_image(tfms(image))

    # Insert singleton axes so every image row is compared with every text row.
    image_unit = normalize(image_features.unsqueeze(1), dim=2)
    text_unit = normalize(text_features.unsqueeze(0), dim=2)

    # Straight-line distance between the unit vectors.
    chord = (image_unit - text_unit).norm(dim=2)
    dists = torch.arcsin(chord / 2) ** 2 * 2
    return dists.mean()

# Encode the text prompt once; the embedding stays fixed for the whole
# sampling run, so no gradients are tracked here.
prompt = 'Red car, ...'
text = open_clip.tokenize([prompt]).to(device)
with torch.no_grad(), torch.cuda.amp.autocast():
    text_features = clip_model.encode_text(text)

# Start from pure Gaussian noise: 4 RGB images of 256x256.
x = torch.randn(4,3,256,256).to(device)

for i,t in tqdm(enumerate(scheduler.timesteps)):
    # The two steps below are elided in this sketch. `noise_pred` must be
    # computed here before the guidance loop can run.
    #model_input = scheduler...(x,t)
    #noise_pred = image_pipe.unet(model_input, t)...

    # Accumulates the (negative) guidance gradient averaged over n_cuts.
    cond_grad = 0

    for cut in range(n_cuts):
        # requires_grad_() with the trailing underscore is the in-place method;
        # `requires_grad` without it is a bool attribute and cannot be called.
        x = x.detach().requires_grad_()

        # Predicted fully denoised sample for the current x.
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample

        loss = clip_loss(x0, text_features)*guidance_scale

        # Subtract so that adding cond_grad to x lowers the loss;
        # dividing by n_cuts turns the sum into an average.
        cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts

    # Update x with the guidance gradient.
    # NOTE(review): alphas_cumprod is indexed by loop position i, not by
    # timestep t -- confirm this is the intended scaling.
    alpha_bar = scheduler.alphas_cumprod[i]    # scaling factor
    x = (x.detach()+cond_grad*alpha_bar.sqrt())

    # Regular scheduler step from the guided x to the previous timestep.
    x = scheduler.step(noise_pred, t, x).prev_sample

In [11]:
'''类别条件扩散模型基于MNIST'''
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device}')

# MNIST training split, downloaded into mnist/ if needed and converted to tensors.
dataset = torchvision.datasets.MNIST(
    root='mnist/',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

# Small shuffled batches, used here only to preview the data.
train_dataloader = DataLoader(dataset, shuffle=True, batch_size=8)

# Inspect one batch: tensor shape, labels, and the digits tiled side by side.
x, y = next(iter(train_dataloader))
print('Input shape:',x.shape)
print('Labels:',y)
plt.imshow(
    torchvision.utils.make_grid(x)[0],
    cmap='Greys',
)


Using cpu
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST\raw\train-images-idx3-ubyte.gz


 68%|██████▊   | 6720343/9912422 [00:01<00:00, 6638773.05it/s]





RuntimeError: File not found or corrupted.

In [7]:
'''创建类别条件的UNET模型'''
class ClassConditionedUnet(nn.Module):
    """UNet conditioned on a class label through extra input channels.

    Each label is mapped to a learned embedding vector. The vector is
    broadcast over the spatial dimensions and concatenated to the image
    along the channel axis before being passed to the wrapped UNet2DModel.

    Args:
        num_classes: number of rows in the class-embedding table.
        class_emb_size: length of each class embedding; this many channels
            are added to the single image channel.
    """

    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()

        # Lookup table: class label -> vector of length class_emb_size.
        self.class_emb = nn.Embedding(num_classes, class_emb_size)

        unet_config = dict(
            sample_size=28,  # image size
            in_channels=1 + class_emb_size,  # image channel + embedding channels
            out_channels=1,
            layers_per_block=2,  # residual layers per block
            block_out_channels=(32, 64, 64),
            down_block_types=(
                'DownBlock2D',
                'AttnDownBlock2D',
                'AttnDownBlock2D',
            ),
            up_block_types=(
                'AttnUpBlock2D',
                'AttnUpBlock2D',
                'UpBlock2D',
            ),
        )
        self.model = UNet2DModel(**unet_config)

    def forward(self, x, t, class_labels):
        """Run the wrapped UNet on ``x`` with the class embedding appended.

        Args:
            x: 4-D input batch; the embedding is concatenated along dim 1.
            t: timestep argument forwarded unchanged to the wrapped model.
            class_labels: label indices, one per batch element.

        Returns:
            The ``.sample`` attribute of the wrapped model's output.
        """
        batch, _, dim2, dim3 = x.shape

        embedding = self.class_emb(class_labels)
        emb_channels = embedding.shape[1]

        # Reshape to (batch, emb, 1, 1), then broadcast over both trailing dims.
        embedding_map = embedding.view(batch, emb_channels, 1, 1)
        embedding_map = embedding_map.expand(batch, emb_channels, dim2, dim3)

        stacked = torch.cat((x, embedding_map), 1)
        return self.model(stacked, t).sample

In [10]:
# Single source of truth for the number of diffusion steps, shared by the
# scheduler and the timestep sampler below.
num_train_timesteps = 1000
noise_scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps, beta_schedule='squaredcos_cap_v2')
train_dataloader = DataLoader(dataset, batch_size=128,shuffle=True)
n_epochs = 10
net = ClassConditionedUnet().to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(),lr=1e-3)
losses = []  # per-batch loss history across all epochs

for epoch in range(n_epochs):
    for x,y in tqdm(train_dataloader):
        # Rescale images from [0, 1] (ToTensor output) to [-1, 1].
        x=x.to(device)*2-1
        y=y.to(device)
        noise = torch.randn_like(x)
        # torch.randint excludes the upper bound, so pass num_train_timesteps
        # to cover every timestep 0..num_train_timesteps-1 (the previous
        # bound of 999 never sampled the final timestep).
        timesteps = torch.randint(0,num_train_timesteps,(x.shape[0],)).long().to(device)
        noisy_x = noise_scheduler.add_noise(x,noise,timesteps)

        # The network predicts the noise that was added, given the class label.
        pred = net(noisy_x,timesteps,y)
        loss = loss_fn(pred, noise)
        opt.zero_grad()
        loss.backward()
        opt.step()

        losses.append(loss.item())

    # Average over the most recent (up to 100) batches; divide by the real
    # window length so a short history is not under-reported.
    recent_losses = losses[-100:]
    avg_loss = sum(recent_losses)/max(len(recent_losses), 1)
    print(f'Finished epoch{epoch}, ave loss:{avg_loss:05f}')

# Plot the loss curve once after training instead of redrawing the whole
# history on the same axes at the end of every epoch.
plt.plot(losses)
        

NameError: name 'dataset' is not defined

In [None]:
# Start from pure noise: 80 single-channel 28x28 images, 8 per class label.
x = torch.randn(80,1,28,28).to(device)
# Labels 0..9, each repeated 8 times, so each grid row shares one label.
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)

# Iterate the timesteps directly (the index was unused) so tqdm can read the
# length and show a real progress bar.
for t in tqdm(noise_scheduler.timesteps):
    # Inference only: no gradients needed for the noise prediction.
    with torch.no_grad():
        residual = net(x,t,y)
    # Step from the current sample to the previous timestep.
    x = noise_scheduler.step(residual, t, x).prev_sample

# plt.subplots (plural) returns (figure, axes) and accepts figsize;
# plt.subplot (singular) does neither, so the original call failed.
fig, ax = plt.subplots(1,1,figsize=(12,12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1,1),nrow=8)[0],cmap='Greys')