In [None]:
'''
mean = sqrt(a_hat_t) * x_0
variance = sqrt(1-a_hat_t) * random_noise
'''

In [None]:
import math
import os
import urllib
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

In [None]:
def forward_diffusion(x_0, t, betas = torch.linspace(0.0, 1.0, 5)):
    """Sample a noised version of ``x_0`` at timestep ``t`` in closed form.

    Computes ``x_t = sqrt(a_hat_t) * x_0 + sqrt(1 - a_hat_t) * noise`` where
    ``a_hat`` is the cumulative product of ``1 - betas``.

    Args:
        x_0: Batch of clean inputs of shape ``(B, ...)``; any number of
            trailing dimensions is supported.
        t: Integer (long) tensor of shape ``(B,)`` holding one schedule index
            per sample.
        betas: 1-D noise schedule that ``t`` indexes into.

    Returns:
        Tuple ``(x_t, noise)``: the noised batch and the standard-normal noise
        that was mixed in, both shaped like ``x_0``.
    """
    noise = torch.randn_like(x_0)
    # Keep the schedule and the indices on the same device as the data.
    alphas = 1 - betas.to(x_0.device)
    alphas_hat = torch.cumprod(alphas, dim=0)
    # One coefficient per sample, shaped (B, 1, ..., 1) so that it broadcasts
    # over every trailing dimension of x_0 instead of assuming a 4-D input.
    trailing_ones = (1,) * (x_0.dim() - 1)
    alphas_hat_t = alphas_hat.gather(-1, t.to(x_0.device)).reshape(-1, *trailing_ones)

    mean = alphas_hat_t.sqrt() * x_0
    variance = torch.sqrt(1 - alphas_hat_t) * noise

    return mean + variance, noise

In [None]:
# Remote location of the sample image that is downloaded and noised below.
url = 'https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTZmJy3aSZ1Ix573d2MlJXQowLCLQyIUsPdniOJ7rBsgG4XJb04g9ZFA9MhxYvckeKkVmo&usqp=CAU'

In [None]:
# Local path the sample image is written to and later read back from.
filename = 'racoon.jpg'

In [None]:
# Fetch the sample image only when it is not already on disk, so re-running
# the notebook neither needs the network nor overwrites the cached file.
if not os.path.exists(filename):
    urllib.request.urlretrieve(url, filename)

In [None]:
# Open the downloaded file as a PIL image.
image = PIL.Image.open(filename)

In [None]:
# Bare expression: shows the PIL image as the cell output.
image

# Transfer Image to Tensor

In [None]:
# PIL image -> float tensor scaled to [-1, 1] at 32x32 resolution.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),                      # values from 0 to 1
    transforms.Lambda(lambda t: (t * 2) - 1),   # 0 -> -1, 1 -> 1
])

# (C, H, W) tensor nominally in [-1, 1] -> PIL image.
reverse_transform = transforms.Compose([
    transforms.Lambda(lambda t: (t + 1) / 2),   # -1 -> 0, 1 -> 1
    # Noised samples routinely leave [-1, 1]; clamp so that the uint8 cast
    # below saturates instead of wrapping around.
    transforms.Lambda(lambda t: t.clamp(0, 1)),
    transforms.Lambda(lambda t: t.permute(1, 2, 0)),   # channels last
    transforms.Lambda(lambda t: t * 255.),
    # detach()/cpu() make this usable for GPU tensors and tensors with grad.
    transforms.Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),
    transforms.ToPILImage(),
])

In [None]:
# Convert the PIL sample image with the `transform` pipeline defined above.
torch_image = transform(image)

In [None]:
# Bare expression: shows the converted tensor as the cell output.
torch_image

In [None]:
# Round-trip check: map the tensor back to an image and draw it.
plt.imshow(reverse_transform(torch_image))

In [None]:
# Noise five copies of the sample image, one per index 0..4 of the default
# five-step schedule of forward_diffusion.
t = torch.arange(5)
batch_images = torch.stack([torch_image for _ in range(5)])
noisy_images, _ = forward_diffusion(batch_images, t)

In [None]:
# Draw each noised copy in its own figure, in order of increasing timestep index.
for img in noisy_images:
    plt.imshow(reverse_transform(img))
    plt.show()

In [None]:
# Environment probe: is a CUDA device usable from this process?
torch.cuda.is_available()

In [None]:
# Environment probe: number of CUDA devices visible to PyTorch.
torch.cuda.device_count()

In [None]:
# Environment probe: index of the currently selected CUDA device.
# NOTE(review): presumably fails on a machine without CUDA — confirm before running on CPU.
torch.cuda.current_device()

In [None]:
# Environment probe: the result is only displayed; it is not assigned or used afterwards.
torch.cuda.device(0)

In [None]:
# Environment probe: human-readable name of CUDA device 0.
# NOTE(review): presumably fails on a machine without CUDA — confirm before running on CPU.
torch.cuda.get_device_name(0)

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the rest of the notebook still runs.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
def get_sample_image() -> PIL.Image.Image:
    """Open the sample image stored at the module-level ``filename``.

    No download happens here; the file must already exist on disk.

    Returns:
        The opened PIL image.
    """
    sample = PIL.Image.open(filename)
    return sample

In [None]:
def plot_noise_distribution(noise, predicted_noise):
    """Overlay density-normalised histograms of true and predicted noise values.

    Args:
        noise: Tensor with the noise that was actually added.
        predicted_noise: Tensor with the model's noise estimate.
    """
    # detach() first: Tensor.numpy() rejects tensors that still require grad,
    # which is the case for a prediction that lives on the CPU.
    true_values = noise.detach().cpu().numpy().flatten()
    predicted_values = predicted_noise.detach().cpu().numpy().flatten()

    plt.hist(true_values, density=True, alpha=0.8, label= "ground truth noise")
    plt.hist(predicted_values, density=True, alpha=0.8, label= "predicted noise")
    plt.legend()
    plt.show()

In [None]:
def plot_noise_prediction(noise, predicted_noise):
    """Show one true noise sample next to the model's prediction for it.

    Args:
        noise: Single noise tensor accepted by ``reverse_transform``.
        predicted_noise: The model's estimate, same layout as ``noise``.
    """
    # One subplots() call creates the only figure needed; a separate
    # plt.figure() beforehand would just leave an empty figure behind.
    fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(5,5))
    ax_true.imshow(reverse_transform(noise))
    ax_true.set_title("ground truth noise", fontsize=10)
    ax_pred.imshow(reverse_transform(predicted_noise))
    ax_pred.set_title("predicted noise", fontsize=10)
    plt.show()

In [None]:
class DiffusionModel:
    """Gaussian diffusion with a linear beta schedule.

    Provides the closed-form forward (noising) step and a single reverse
    (denoising) step that relies on an external noise-prediction model.

    Attributes:
        betas: Linear schedule from ``start_schedule`` to ``end_schedule``
            with ``timesteps`` entries.
        alphas: ``1 - betas``.
        alphas_hat: Cumulative product of ``alphas``.
    """

    def __init__(self, start_schedule=0.0001, end_schedule=0.02, timesteps=300):
        super().__init__()
        self.start_schedule = start_schedule
        self.end_schedule = end_schedule
        self.timesteps = timesteps
        self.betas = torch.linspace(start_schedule, end_schedule, timesteps)
        self.alphas = 1 - self.betas
        self.alphas_hat = torch.cumprod(self.alphas, dim=0)

    def forward(self, x_0, t, device):
        '''
        Noise ``x_0`` to timestep ``t``.

        x_0 : (B, C, H, W)
        t: (B,)
        device: device the returned tensors are moved to

        Returns ``(x_t, noise)`` with
        ``x_t = sqrt(a_hat_t) * x_0 + sqrt(1 - a_hat_t) * noise``.
        '''
        # Plain standard-normal noise (a stray unary minus used to negate it).
        noise = torch.randn_like(x_0)
        sqrt_alphas_hat_t = self.get_index_from_list(self.alphas_hat.sqrt(), t, x_0.shape)
        sqrt_one_minus_alpha_hat_t = self.get_index_from_list(torch.sqrt(1. - self.alphas_hat), t, x_0.shape)

        mean = sqrt_alphas_hat_t.to(device) * x_0.to(device)
        variance = sqrt_one_minus_alpha_hat_t.to(device) * noise.to(device)

        return mean + variance, noise.to(device)

    @torch.no_grad()
    def backward(self, x, t, model, **kwargs):
        """
        Calls the model to predict the noise in the image and returns the denoised image.
        Applies noise to this image, if we are not the last step yet.

        x: (B, ...) current noisy samples
        t: (B,) long tensor of timestep indices; entries may differ per sample
        model: callable ``model(x, t, **kwargs)`` returning the predicted noise
        """

        betas_t = self.get_index_from_list(self.betas, t, x.shape)
        sqrt_one_minus_alphas_hat_t = self.get_index_from_list(torch.sqrt(1. - self.alphas_hat), t, x.shape)
        sqrt_recip_alphas_t = self.get_index_from_list(torch.sqrt(1.0 / self.alphas), t, x.shape)
        mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t, **kwargs) / sqrt_one_minus_alphas_hat_t)

        posterior_variance_t = betas_t

        # Decide per sample whether noise is added, so a batch with more than
        # one entry works (a plain `if t == 0` is ambiguous for such a tensor).
        is_not_last = t != 0
        if not bool(is_not_last.any()):
            return mean

        noise = torch.randn_like(x)
        noise_mask = is_not_last.reshape(-1, *((1,) * (x.dim() - 1))).to(x.dtype)
        variance = torch.sqrt(posterior_variance_t) * noise
        return mean + noise_mask * variance

    @staticmethod
    def get_index_from_list(values, t, x_shape):
        '''
        Pick the entries of ``values`` at the indices stored in ``t`` and
        reshape them so they broadcast against a tensor of shape ``x_shape``.

        if x_shape = (5,3,64,64)
            -> len(x_shape) = 4
            -> len(x_shape) - 1 = 3
        and thus the result is reshaped to (batch_size, 1, 1, 1)

        The lookup happens on the CPU (``t.cpu()``); the result is then moved
        to the device of ``t``.
        '''
        batch_size = t.shape[0]
        result = values.gather(-1, t.cpu())
        return result.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

In [None]:
# Spatial size (height, width) of generated samples; matches the 32x32 resize used in `transform`.
IMAGE_SHAPE= (32, 32)

In [None]:
# PIL image -> float tensor scaled to [-1, 1] at 32x32 resolution.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),                      # values from 0 to 1
    transforms.Lambda(lambda t: (t * 2) - 1),   # 0 -> -1, 1 -> 1
])

# (C, H, W) tensor nominally in [-1, 1] -> PIL image; accepts GPU tensors.
reverse_transform = transforms.Compose([
    transforms.Lambda(lambda t: (t + 1) / 2),   # -1 -> 0, 1 -> 1
    # Noise and intermediate samples routinely leave [-1, 1]; clamp so that
    # the uint8 cast below saturates instead of wrapping around.
    transforms.Lambda(lambda t: t.clamp(0, 1)),
    transforms.Lambda(lambda t: t.permute(1, 2, 0)),   # channels last
    transforms.Lambda(lambda t: t * 255.),
    # detach() keeps the conversion valid for tensors that require grad.
    transforms.Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),
    transforms.ToPILImage(),
])

In [None]:
# Reload the sample image and convert it to a [-1, 1] tensor for training.
pil_image = get_sample_image()
torch_image = transform(pil_image)

In [None]:
# Diffusion process with the class defaults (300 timesteps, betas from 0.0001 to 0.02).
diffusion_model = DiffusionModel()

In [None]:
class SinusoidalPositionalEmbedding(nn.Module):
    """Encode scalar timesteps as a vector of sines and cosines.

    For an input of shape ``(B,)`` the output has shape
    ``(B, 2 * (dim // 2))``: the sine half followed by the cosine half, with
    frequencies spaced geometrically from 1 down to 1/10000.
    """

    def __init__(self,dim):
        super().__init__()
        self.dim = dim

    def forward(self,time):
        n_freqs = self.dim // 2
        # Log-spacing between neighbouring frequencies.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=time.device) * -log_step)
        # Outer product: one row per timestep, one column per frequency.
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [None]:
class Block(nn.Module):
    """One U-Net stage: two convolutions with time (and optional label)
    conditioning, followed by a stride-2 resampling layer.

    Args:
        channels_in: Input channels. An upsampling block expects
            ``2 * channels_in`` channels because its input is concatenated
            with a skip connection.
        channels_out: Output channels.
        time_embedding_dims: Size of the sinusoidal timestep embedding.
        labels: When truthy, ``forward`` requires a ``labels`` keyword tensor
            of shape ``(B, 1)``.
        num_filters: Kernel size of the first convolution (odd values keep
            the spatial size unchanged).
        downsample: ``True`` halves the spatial size with a strided
            convolution; ``False`` doubles it with a transposed convolution.
    """

    def __init__(self, channels_in, channels_out, time_embedding_dims, labels, num_filters=3, downsample=True):
        super().__init__()

        self.time_embedding_dims = time_embedding_dims
        self.time_embedding = SinusoidalPositionalEmbedding(time_embedding_dims)
        self.labels = labels

        if labels:
            self.label_mlp = nn.Linear(1, channels_out)

        self.downsample = downsample

        # Derive the padding from the kernel size so that a non-default odd
        # `num_filters` still preserves height and width (3 -> padding 1).
        padding = num_filters // 2

        if downsample:
            self.conv1 = nn.Conv2d(channels_in, channels_out, num_filters, padding=padding)
            self.final = nn.Conv2d(channels_out, channels_out, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(2*channels_in, channels_out, num_filters, padding=padding)
            self.final = nn.ConvTranspose2d(channels_out, channels_out, 4, 2, 1)

        self.bnorm1 = nn.BatchNorm2d(channels_out)
        self.bnorm2 = nn.BatchNorm2d(channels_out)

        self.conv2 = nn.Conv2d(channels_out, channels_out, 3, padding=1)
        self.time_mlp = nn.Linear(time_embedding_dims, channels_out)
        self.relu = nn.ReLU()

    def forward(self, x, t, **kwargs):
        """Run the block.

        Args:
            x: Feature map of shape ``(B, C, H, W)``.
            t: Timestep tensor of shape ``(B,)``.
            **kwargs: Must contain ``labels`` when the block was built with
                ``labels`` enabled.

        Raises:
            ValueError: If label conditioning is enabled but no ``labels``
                keyword argument was passed.
        """
        o = self.bnorm1(self.relu(self.conv1(x)))
        o_time = self.relu(self.time_mlp(self.time_embedding(t)))
        # (B, C) -> (B, C, 1, 1) so the conditioning broadcasts over H and W.
        o = o + o_time[(...,) + (None,) * 2]
        if self.labels:
            label = kwargs.get('labels')
            if label is None:
                raise ValueError("Block was built with labels enabled but no 'labels' keyword argument was passed")
            o_label = self.relu(self.label_mlp(label))
            o = o + o_label[(...,) + (None,) * 2]

        o = self.bnorm2(self.relu(self.conv2(o)))
        return self.final(o)
        


In [None]:
class UNet(nn.Module):
    '''
    A Simplified variant of the U-Net architecture.

    The network predicts a tensor with the same shape as its input. Each
    consecutive pair in ``sequence_channels`` becomes one downsampling block;
    the reversed pairs become the upsampling blocks, each of which receives
    the current features concatenated with the matching downsampling output.

    Args:
        img_channels: Channels of the input and of the output.
        time_embedding_dims: Size of the timestep embedding used by each block.
        labels: Enables label conditioning in every block.
        sequence_channels: Channel widths along the downsampling path.
    '''

    def __init__(self, img_channels=3, time_embedding_dims=128, labels=False, sequence_channels=(64, 128, 256, 512, 1024)):
        super().__init__()
        self.time_embedding_dims = time_embedding_dims
        channels_up = sequence_channels[::-1]

        self.downsampling = nn.ModuleList([Block(channels_in, channels_out, time_embedding_dims, labels) for channels_in, channels_out in zip(sequence_channels, sequence_channels[1:])])
        self.upsampling = nn.ModuleList([Block(channels_in, channels_out, time_embedding_dims, labels, downsample=False) for channels_in, channels_out in zip(channels_up, channels_up[1:])])
        self.conv1 = nn.Conv2d(img_channels, sequence_channels[0], 3, padding=1)
        self.conv2 = nn.Conv2d(sequence_channels[0], img_channels, 1)

    def forward(self, x, t, **kwargs):
        """Predict the noise in ``x`` at timesteps ``t``.

        Extra keyword arguments (e.g. ``labels``) are forwarded to every block.
        """
        residuals = []
        o = self.conv1(x)

        for ds in self.downsampling:
            o = ds(o, t, **kwargs)
            residuals.append(o)

        # Walk the skip connections from the deepest level back up.
        for us, res in zip(self.upsampling, reversed(residuals)):
            o = us(torch.cat((o, res), dim=1), t, **kwargs)

        return self.conv2(o)

In [None]:
# Hyper-parameters for the single-image training run below.
NO_EPOCHS = 2000
PRINT_FREQUENCY = 400  # report every PRINT_FREQUENCY epochs
LR = 0.001
BATCH_SIZE = 128
VERBOSE = True  # also draw the noise diagnostics when reporting

# U-Net without label conditioning, moved to `device`, plus its Adam optimizer.
unet = UNet(labels=False)
unet.to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)

In [None]:
# Train on a single image: every step uses BATCH_SIZE copies of `torch_image`,
# each noised at its own randomly drawn timestep.
for epoch in range(NO_EPOCHS):
    mse_epoch_loss = []

    batch = torch.stack([torch_image] * BATCH_SIZE)
    # One timestep index in [0, timesteps) per sample.
    t = torch.randint(0, diffusion_model.timesteps, (BATCH_SIZE,)).long().to(device)

    batch_noisy, noise = diffusion_model.forward(batch, t, device)
    predicted_noise = unet(batch_noisy, t)

    optimizer.zero_grad()

    # The network is trained to regress the noise that was mixed in.
    loss = torch.nn.functional.mse_loss(noise, predicted_noise)
    mse_epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()
    if epoch % PRINT_FREQUENCY == 0:
        print('---')
        # The list holds exactly one value per epoch, so the mean is that step's loss.
        print(f"Epoch: {epoch} | Train Loss {np.mean(mse_epoch_loss)}")
        if VERBOSE:
            with torch.no_grad():
                plot_noise_prediction(noise[0], predicted_noise[0])
                plot_noise_distribution(noise, predicted_noise)

In [None]:
# Reverse diffusion
# Switch to inference mode once, instead of calling eval() on every step.
unet.eval()
with torch.no_grad():
    # Start from pure Gaussian noise (the sample at the last timestep).
    img = torch.randn((1, 3) + IMAGE_SHAPE).to(device)
    for i in reversed(range(diffusion_model.timesteps)):
        t = torch.full((1,), i, dtype=torch.long, device=device)
        img = diffusion_model.backward(img, t, unet)

        # Show an intermediate sample every 50 steps; i == 0 is the final image.
        if i % 50 == 0:
            plt.figure(figsize=(2, 2))
            plt.imshow(reverse_transform(img[0]))
            plt.show()

# With Label - Training on CIFAR-10

In [None]:
# Hyper-parameters for the label-conditioned CIFAR-10 run (rebinds the earlier values).
BATCH_SIZE = 256
NO_EPOCHS = 100
PRINT_FREQUENCY = 10  # report and checkpoint every PRINT_FREQUENCY epochs
LR = 0.001
VERBOSE = True

# Fresh U-Net with label conditioning enabled; replaces the single-image model.
unet = UNet(labels=True)
unet.to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=LR)

In [None]:
# transforms
def minus_one_to_one(x):
    return (x * 2) - 1
data_transform = [ # PIL -> Torch
    transforms.Resize((32,32)),
    transforms.ToTensor(), # from 0 to 1
    transforms.Lambda(minus_one_to_one)# 0 -> -1 , 1 -> 1    
]

reverse_transform_dataset = transforms.Compose([ # Torch -> PIL
    transforms.Lambda(lambda t: (t + 1) / 2),
    transforms.Lambda(lambda t: t.permute(1, 2, 0)),
    transforms.Lambda(lambda t: t * 255.),
    transforms.Lambda(lambda t: t.cpu().numpy().astype(np.uint8)),
    transforms.ToPILImage(),
])

In [None]:
interested_classes = ['bird', 'cat', 'dog', 'horse']
# Convert the class names to their corresponding label indices.
cifar10_classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
interested_class_indices = [cifar10_classes.index(cls) for cls in interested_classes]

cifar10_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=data_transform)
cifar10_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=data_transform)

# Positions of the samples whose label belongs to one of the selected classes.
train_indices = [
    idx for idx, target in enumerate(cifar10_train.targets)
    if target in interested_class_indices
]
test_indices = [
    idx for idx, target in enumerate(cifar10_test.targets)
    if target in interested_class_indices
]

trainset = torch.utils.data.Subset(cifar10_train, train_indices)
testset = torch.utils.data.Subset(cifar10_test, test_indices)

# drop_last=True: every batch delivered by the loaders has exactly BATCH_SIZE samples.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, drop_last=True
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, drop_last=True
)

In [None]:
# Pull the first element (the image tensor) of one training batch.
# Note: this rebinds `image`, which earlier held the PIL sample image.
image = next(iter(trainloader))[0]

In [None]:
# Label-conditioned training on the CIFAR-10 subset with a validation pass per epoch.
for epoch in range(NO_EPOCHS):
    mean_epoch_loss = []
    mean_epoch_loss_val = []

    print(f"EPOCH {epoch} Train")
    unet.train()
    for batch, label in trainloader:
        batch = batch.to(device)
        # One timestep per sample; sized from the batch itself rather than BATCH_SIZE.
        t = torch.randint(0, diffusion_model.timesteps, (batch.shape[0],)).long().to(device)

        batch_noisy, noise = diffusion_model.forward(batch, t, device)
        predicted_noise = unet(batch_noisy, t, labels=label.reshape(-1, 1).float().to(device))

        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(noise, predicted_noise)
        # Store a plain float so the autograd graph of every step is not kept alive.
        mean_epoch_loss.append(loss.item())
        # Without backward() the optimizer step has no gradients to apply.
        loss.backward()
        optimizer.step()

    print(f"EPOCH {epoch} Test")
    # Validation: no gradient tracking, BatchNorm in inference mode.
    unet.eval()
    with torch.no_grad():
        for batch, label in testloader:
            batch = batch.to(device)
            t = torch.randint(0, diffusion_model.timesteps, (batch.shape[0],)).long().to(device)

            # Noise the clean validation batch, then let the U-Net predict that noise.
            batch_noisy, noise = diffusion_model.forward(batch, t, device)
            predicted_noise = unet(batch_noisy, t, labels=label.reshape(-1, 1).float().to(device))

            loss = torch.nn.functional.mse_loss(noise, predicted_noise)
            mean_epoch_loss_val.append(loss.item())

    if epoch % PRINT_FREQUENCY == 0:
        print('---')
        print(f"Epoch: {epoch} | Train Loss: {np.mean(mean_epoch_loss)} | Val Loss: {np.mean(mean_epoch_loss_val)}")
        if VERBOSE:
            with torch.no_grad():
                # Diagnostics use the last validation batch.
                plot_noise_prediction(noise[0], predicted_noise[0])
                plot_noise_distribution(noise, predicted_noise)

        # NOTE(review): ':' in the file name is presumably not portable to Windows — confirm.
        torch.save(unet.state_dict(), f"epoch:{epoch}")