## Forward

In [None]:
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import os
import shutil

# Linear schedule for diffusion process
def choose_schedule(time_steps, beta_start=0.0001, beta_end=0.02):
    """Return a linearly spaced beta schedule.

    Args:
        time_steps: Number of diffusion steps (length of the returned tensor).
        beta_start: Value of the first beta. Defaults to the previously
            hard-coded 0.0001.
        beta_end: Value of the last beta. Defaults to the previously
            hard-coded 0.02.

    Returns:
        A 1-D tensor with ``time_steps`` values evenly spaced from
        ``beta_start`` to ``beta_end`` (both inclusive).
    """
    return torch.linspace(beta_start, beta_end, time_steps)

# Sample noisy image based on current timestep
def q_sample(x_start, t, noise=None):
    """Blend ``x_start`` with noise according to the cumulative schedule.

    Computes ``sqrt(a_t) * x_start + sqrt(1 - a_t) * noise`` where ``a_t`` is
    the module-level ``alphas_cumprod`` looked up at ``t``.

    Args:
        x_start: Tensor to be noised; the first axis is treated as the batch
            axis when ``t`` holds one index per sample.
        t: Index (or tensor of per-sample indices) into ``alphas_cumprod``.
        noise: Optional noise tensor shaped like ``x_start``; drawn from a
            standard normal when omitted.

    Returns:
        The noised tensor, same shape as ``x_start``.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    # Index on the schedule's own device: a CPU schedule cannot be indexed
    # with an index tensor that lives on another device.
    if isinstance(t, torch.Tensor):
        t = t.to(alphas_cumprod.device)
    alpha_bar_t = alphas_cumprod[t].to(x_start.device)

    # Reshape (batch,) -> (batch, 1, ..., 1) so each coefficient scales its
    # own sample. Without this, a (batch,) vector broadcasts against the LAST
    # axis of x_start, which is silently wrong when batch == width and an
    # error otherwise.
    alpha_bar_t = alpha_bar_t.reshape(-1, *((1,) * (x_start.dim() - 1)))

    sqrt_alphas_t = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alphas_t = torch.sqrt(1. - alpha_bar_t)
    return sqrt_alphas_t * x_start + sqrt_one_minus_alphas_t * noise

class SimpleUNet(nn.Module):
    """Small convolutional network mapping a 3-channel image to 3 channels.

    Despite the name there is no down/up-sampling path: the network is two
    3x3 convolutions (3 -> 64 -> 3 channels) with a ReLU in between. Padding
    of 1 keeps the spatial size unchanged.
    """

    def __init__(self):
        super().__init__()
        hidden_channels = 64
        # Kept inside an nn.Sequential named ``conv`` so parameter names stay
        # ``conv.0.*`` / ``conv.2.*``.
        layers = [
            nn.Conv2d(3, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, 3, kernel_size=3, padding=1),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x`` and return the result."""
        prediction = self.conv(x)
        return prediction

# Directory setup
# Both paths are empty placeholders and must be filled in before running.
source_dir = ''  # Path to your source data
temp_dir = ''  # Path to temporary storage

# exist_ok=True replaces the separate exists() check, which could race with
# another process creating the directory between the check and the call.
os.makedirs(temp_dir, exist_ok=True)

# Move every regular file found directly inside source_dir into temp_dir;
# sub-directories are left where they are.
for file_name in os.listdir(source_dir):
    full_file_name = os.path.join(source_dir, file_name)
    if os.path.isfile(full_file_name):
        shutil.move(full_file_name, temp_dir)

# Data loading and transformation
# Every image is resized to 32x32 and then converted to a tensor.
preprocessing_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# The root path is an empty placeholder and must be filled in before running.
dataset = ImageFolder(root='', transform=transform)
# Shuffled mini-batches of 32 samples.
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

# Model, optimizer, and loss setup
# Use CUDA when it is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
unet = SimpleUNet().to(device)
optimizer = optim.Adam(params=unet.parameters(), lr=1e-3)
# Mean-squared error between the network output and the injected noise.
criterion = nn.MSELoss()

# Diffusion schedule setup
time_steps = 10
betas = choose_schedule(time_steps)
alphas = 1.0 - betas
# Running product of the alphas; q_sample reads this module-level tensor.
alphas_cumprod = alphas.cumprod(dim=0)

# Training loop
# Each batch is noised at a random timestep and the network is trained to
# reproduce the noise that was added.
epochs = 1
for epoch in range(epochs):
    for i, (images, _) in enumerate(dataloader):
        images = images.to(device)
        samples_in_batch = images.size(0)

        # One uniformly drawn timestep index per image, then the noise itself.
        t = torch.randint(0, time_steps, (samples_in_batch,), device=device, dtype=torch.long)
        noise = torch.randn_like(images)
        x_noisy = q_sample(images, t, noise)

        # Standard optimisation step on the MSE between output and noise.
        optimizer.zero_grad()
        output = unet(x_noisy)
        loss = criterion(output, noise)
        loss.backward()
        optimizer.step()

        # Report progress on every 100th batch (including the first).
        if not i % 100:
            print(f'Epoch {epoch+1}/{epochs}, Step {i}/{len(dataloader)}, Loss: {loss.item()}')

print('Training complete')

# Save the model
# Only the parameters (state_dict) are written, not the module object.
torch.save(unet.state_dict(), "./simple_diffusion.pth")


## Backward

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import math
# Define the device: the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Linear schedule for diffusion process
def choose_schedule(time_steps, beta_start=0.0001, beta_end=0.02):
    """Build a linear beta schedule.

    Args:
        time_steps: How many betas to generate.
        beta_start: First beta; defaults to 0.0001, the value that used to be
            hard-coded here.
        beta_end: Last beta; defaults to 0.02, the value that used to be
            hard-coded here.

    Returns:
        1-D tensor of ``time_steps`` evenly spaced values covering
        ``[beta_start, beta_end]``.
    """
    return torch.linspace(beta_start, beta_end, time_steps)

# Sample noisy image based on current timestep
def q_sample(x_start, t, noise=None):
    """Return ``x_start`` noised to timestep ``t``.

    The result is ``sqrt(a_t) * x_start + sqrt(1 - a_t) * noise`` with ``a_t``
    read from the module-level ``alphas_cumprod`` at index ``t``.

    Args:
        x_start: Tensor to noise; axis 0 is the batch axis when ``t`` carries
            one index per sample.
        t: Integer index or tensor of per-sample indices into
            ``alphas_cumprod``.
        noise: Optional tensor shaped like ``x_start``; sampled from a
            standard normal when not given.

    Returns:
        Tensor with the same shape as ``x_start``.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    # The lookup must happen on the device that holds the schedule.
    if isinstance(t, torch.Tensor):
        t = t.to(alphas_cumprod.device)
    alpha_bar_t = alphas_cumprod[t].to(x_start.device)

    # Give the per-sample coefficients trailing singleton axes so they scale
    # whole samples. A bare (batch,) vector would broadcast over the last
    # axis of x_start instead.
    trailing_axes = (1,) * (x_start.dim() - 1)
    alpha_bar_t = alpha_bar_t.reshape(-1, *trailing_axes)

    sqrt_alphas_t = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alphas_t = torch.sqrt(1. - alpha_bar_t)
    return sqrt_alphas_t * x_start + sqrt_one_minus_alphas_t * noise

class SimpleUNet(nn.Module):
    """Two-layer convolutional denoiser (3 -> 64 -> 3 channels, 3x3 kernels).

    Padding of 1 on both convolutions keeps the spatial size of the input.
    The module layout (an ``nn.Sequential`` called ``conv``) is unchanged, so
    parameter names stay ``conv.0.*`` / ``conv.2.*``.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, x, t=None):
        """Return the network output for ``x``.

        Args:
            x: Input image batch with 3 channels.
            t: Optional timestep tensor. It is accepted because ``p_sample``
                and ``p_losses`` call the model as ``model(x, t)``; this
                network has no time conditioning, so the value is ignored.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.conv(x)


class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of a 1-D batch of timesteps.

    The first half of each embedding holds sines and the second half cosines
    of the timestep multiplied by geometrically spaced frequencies running
    from 1 down to 1/10000.
    """

    def __init__(self, dim):
        """Store the embedding width ``dim``."""
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time`` (shape ``(batch,)``) into shape ``(batch, dim)``."""
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when half_dim == 1
        # (dim of 2 or 3); for larger dims the value is unchanged.
        scale = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -scale)
        angles = time[:, None] * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        # sin/cos halves only fill 2 * half_dim columns; zero-pad one column
        # for odd dims so the output width always equals self.dim.
        if self.dim % 2 == 1:
            embeddings = nn.functional.pad(embeddings, (0, 1))
        return embeddings


def get_index_from_list(vals, t, x_shape):
    """Pick ``vals[t]`` per sample and shape it for broadcasting.

    Args:
        vals: 1-D tensor of per-timestep values.
        t: 1-D integer tensor holding one index per batch element.
        x_shape: Shape of the tensor the result will be combined with; only
            its length is used.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` axes,
        placed on the device of ``t``.
    """
    n_samples = t.shape[0]
    # The lookup uses a CPU copy of the indices.
    picked = vals.gather(-1, t.cpu())
    singleton_axes = (1,) * (len(x_shape) - 1)
    reshaped = picked.reshape(n_samples, *singleton_axes)
    return reshaped.to(t.device)


def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    """Compute the noise-prediction loss for one batch.

    ``x_start`` is noised with ``q_sample``, the model predicts the noise
    from the noised input and ``t``, and the prediction is compared with the
    noise that was actually added.

    Previously this returned the raw tuple ``(noise, predicted_noise)`` and
    ignored ``loss_type``; it now returns the scalar loss.

    Args:
        denoise_model: Callable invoked as ``denoise_model(x_noisy, t)``.
        x_start: Clean input batch.
        t: Per-sample timestep indices, forwarded to ``q_sample`` and the
            model.
        noise: Optional noise shaped like ``x_start``; drawn from a standard
            normal when omitted.
        loss_type: ``"l1"`` (mean absolute error), ``"l2"`` (mean squared
            error) or ``"huber"`` (smooth L1).

    Returns:
        Scalar tensor holding the selected loss.

    Raises:
        NotImplementedError: If ``loss_type`` is not one of the names above.
    """
    loss_functions = {
        "l1": nn.functional.l1_loss,
        "l2": nn.functional.mse_loss,
        "huber": nn.functional.smooth_l1_loss,
    }
    # Validate before doing any work so a typo fails fast.
    if loss_type not in loss_functions:
        raise NotImplementedError(f"Unknown loss_type: {loss_type!r}")

    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)

    predicted_noise = denoise_model(x_noisy, t)
    loss = loss_functions[loss_type](noise, predicted_noise)

    return loss


@torch.no_grad()
def p_sample(model, x, t, t_index):
    """Perform one reverse-diffusion step on ``x``.

    Args:
        model: Callable invoked as ``model(x, t)``; its output is used as the
            predicted noise.
        x: Current sample batch.
        t: Tensor of per-sample timestep indices used to look up the schedule.
        t_index: Integer timestep; no noise is added when it is 0.

    Returns:
        The predicted mean when ``t_index == 0``, otherwise the mean plus
        Gaussian noise scaled by the square root of the posterior variance.
    """
    target_shape = x.shape
    beta = get_index_from_list(betas, t, target_shape)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, target_shape)
    recip_sqrt_alpha = get_index_from_list(sqrt_recip_alphas, t, target_shape)

    # Remove the scaled noise estimate from x, then rescale.
    predicted_noise = model(x, t)
    model_mean = recip_sqrt_alpha * (x - beta * predicted_noise / noise_scale)

    # The final step returns the mean without adding fresh noise.
    if t_index == 0:
        return model_mean

    variance = get_index_from_list(posterior_variance, t, target_shape)
    fresh_noise = torch.randn_like(x)
    return model_mean + torch.sqrt(variance) * fresh_noise


def p_sample_loop(model, shape, n_steps):
    """Run the reverse process from Gaussian noise for ``n_steps`` steps.

    Args:
        model: Denoising model passed through to ``p_sample``.
        shape: Shape of the batch to generate; ``shape[0]`` is the batch size.
        n_steps: Number of reverse steps, executed from ``n_steps - 1`` down
            to 0.

    Returns:
        List with one NumPy array per step, in the order the steps ran.
    """
    batch = shape[0]
    # start from pure noise (for each example in the batch)
    current = torch.randn(shape, device=device)
    history = []

    for step in reversed(range(n_steps)):
        timestep = torch.full((batch,), step, device=device, dtype=torch.long)
        current = p_sample(model, current, timestep, step)
        history.append(current.cpu().numpy())
    return history


def show_denoising_images(images):
    """Display ``images`` in a grid with 10 columns.

    Args:
        images: Sequence of channel-first arrays (each is transposed with
            ``(1, 2, 0)`` before being drawn).
    """
    num_images = len(images)
    n_cols = 10
    # Ceiling division: the old ``num_images // 10 + 1`` added an empty row
    # whenever the count was an exact multiple of 10. Keep at least one row
    # so an empty input still yields a valid grid.
    n_rows = max(1, (num_images + n_cols - 1) // n_cols)
    _, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15))
    axes = axes.flatten()
    # Hide every axis up front so unused grid cells do not show empty frames.
    for ax in axes:
        ax.axis('off')
    for img, ax in zip(images, axes):
        ax.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()

def get_index_from_list(vals, t, x_shape):
    """Gather ``vals`` at indices ``t`` and reshape for broadcasting.

    Args:
        vals: 1-D tensor of per-timestep values.
        t: 1-D integer tensor with one index per batch element.
        x_shape: Shape of the tensor the result will be combined with; only
            its number of axes is used.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` axes on
        the device of ``t``.
    """
    batch_size = t.shape[0]
    # gather() needs the index on the same device as vals. Moving it to
    # vals.device (instead of always to the CPU) also works when the schedule
    # tensors live on a GPU.
    out = vals.gather(-1, t.to(vals.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
@torch.no_grad()
def p_sample(model, x, t, t_index):
    """Run a single reverse-diffusion step.

    Decorated with ``torch.no_grad()`` like the earlier definition of this
    function that this one replaces, so a direct call never records an
    autograd graph.

    Args:
        model: Callable invoked as ``model(x, t)``; its output is treated as
            the predicted noise.
        x: Current sample batch.
        t: Tensor of per-sample timestep indices for the schedule lookups.
        t_index: Integer timestep; when 0 the mean is returned without noise.

    Returns:
        The predicted mean for ``t_index == 0``; otherwise the mean plus
        Gaussian noise scaled by ``sqrt(posterior_variance[t])``.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Subtract the scaled noise estimate from x and rescale to get the mean.
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean

    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    noise = torch.randn_like(x)
    return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def p_sample_loop(model, shape, n_steps):
    """Generate samples by iterating ``p_sample`` from timestep
    ``n_steps - 1`` down to 0, starting from Gaussian noise.

    Args:
        model: Denoising model forwarded to ``p_sample``.
        shape: Shape of the generated batch; the first entry is the batch
            size.
        n_steps: Number of reverse steps to run.

    Returns:
        List of NumPy arrays, one snapshot of the batch after each step.
    """
    batch = shape[0]
    sample = torch.randn(shape, device=device)
    snapshots = []

    step_index = n_steps - 1
    while step_index >= 0:
        # Every element of the batch is at the same timestep.
        t_batch = torch.full((batch,), step_index, device=device, dtype=torch.long)
        sample = p_sample(model, sample, t_batch, step_index)
        snapshots.append(sample.cpu().numpy())
        step_index -= 1
    return snapshots



model = SimpleUNet().to(device)
# The checkpoint path is an empty placeholder and must be filled in.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine still loads on a CPU-only host.
model.load_state_dict(torch.load('', map_location=device))
# Switch to inference mode before sampling.
model.eval()

# Schedule constants read as module-level names by p_sample and q_sample.
# NOTE(review): the training cell builds its schedule with time_steps = 10;
# confirm that sampling with 500 steps is intended.
n_steps = 500
betas = choose_schedule(n_steps)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Cumulative product shifted right by one, with 1.0 as the first entry.
alphas_cumprod_prev = torch.cat((torch.ones(1), alphas_cumprod[:-1]), dim=0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
# Variance of the noise that p_sample adds at each step.
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

# Generate a single 3-channel 32x32 image.
batch_size = 1
shape = (batch_size, 3, 32, 32)


# p_sample_loop returns one snapshot per reverse step, so the list has
# n_steps entries and entry k is the batch after k + 1 denoising steps.
generated_images = p_sample_loop(model, shape, n_steps)


# Show every 30th snapshot.
steps_to_show = list(range(0, n_steps+1, 30))
last_frame = len(generated_images) - 1


fig, axs = plt.subplots(1, len(steps_to_show), figsize=(30, 5))


for idx, step in enumerate(steps_to_show):
    img_index = (n_steps - step)
    # Index by the step itself, clamped to the last snapshot. The previous
    # ``step + n_steps // 10`` offset ran past the end of the list
    # (IndexError) and did not match the step shown in the title.
    img = generated_images[min(step, last_frame)]
    # Take the first image of the batch and convert it to channel-last.
    axs[idx].imshow(np.transpose(img[0], (1, 2, 0)))
    axs[idx].set_title(f'Step {img_index}')
    axs[idx].axis('off')

plt.tight_layout()
plt.show()




