In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

# =====================================================
# CONFIG
# =====================================================
# Image geometry
IMG_SIZE = 32          # images are resized to IMG_SIZE x IMG_SIZE
CHANNELS = 3           # RGB

# Optimisation
BATCH = 128
EPOCHS = 100

# Diffusion
TIMESTEPS = 1_000      # number of steps in the forward-noising schedule

# Paths / hardware
DATA_DIR = r"D:\pk\dataset\real"
MODEL_PATH = "diffusion_unet.pth"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# =====================================================
# DATASET
# =====================================================
class ImageDataset(Dataset):
    """Images from a single folder, returned as tensors normalized to [-1, 1].

    Only files directly inside ``folder`` (no recursion) whose names end in
    ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) are used, in sorted
    order so indices are reproducible across runs. Each item is converted to
    RGB, resized to ``IMG_SIZE x IMG_SIZE`` (aspect ratio is not preserved)
    and returned as a float tensor of shape ``(3, IMG_SIZE, IMG_SIZE)``.

    Raises:
        FileNotFoundError: if ``folder`` does not exist (from ``os.listdir``).
        ValueError: if ``folder`` contains no matching image files.
    """

    # Leading dot so names such as "notes_png" are not mistaken for images.
    _EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, folder):
        self.files = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(self._EXTENSIONS)
        )
        if not self.files:
            # Fail here with a clear message instead of later inside DataLoader.
            raise ValueError(f"No .jpg/.jpeg/.png images found in {folder!r}")
        self.transform = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),                       # [0, 255] -> [0, 1]
            transforms.Normalize([0.5] * 3, [0.5] * 3),  # [0, 1] -> [-1, 1]
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Image.open is lazy and keeps the file handle open; convert() loads the
        # pixels into a new image, so the source file can be closed right away
        # instead of waiting for garbage collection.
        with Image.open(self.files[idx]) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

# Shuffled mini-batches of normalized training images.
dataset = ImageDataset(DATA_DIR)
loader = DataLoader(
    dataset,
    batch_size=BATCH,
    shuffle=True,
)

# =====================================================
# DIFFUSION SCHEDULE
# =====================================================
# Linearly spaced per-step noise variances, from 1e-4 up to 0.02.
beta = torch.linspace(start=1e-4, end=0.02, steps=TIMESTEPS)
alpha = 1.0 - beta
# alpha_cum[t] = alpha[0] * ... * alpha[t]. Only this tensor is moved to
# DEVICE, because it is indexed with on-device timestep tensors.
alpha_cum = alpha.cumprod(dim=0).to(DEVICE)

# =====================================================
# TIME EMBEDDING
# =====================================================
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of diffusion timesteps.

    Maps a 1-D tensor ``t`` of length B to a ``(B, dim)`` float tensor.
    The first ``dim // 2`` columns are ``sin(t * f_k)`` and the next
    ``dim // 2`` are ``cos(t * f_k)``. The frequencies ``f_k`` are spaced
    geometrically from 1 down to 1/10000. If ``dim`` is odd, a zero column
    is appended so the output always has exactly ``dim`` columns.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # max(..., 1) avoids a division by zero (which produced NaN embeddings)
        # when half == 1, i.e. dim == 2 or 3; larger dims are unaffected.
        scale = -np.log(10000.0) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * scale)
        args = t[:, None].float() * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if self.dim % 2:
            # Odd dim: pad one zero column so the width matches Linear(dim, ...).
            emb = F.pad(emb, (0, 1))
        return emb

# =====================================================
# UNET MODEL
# =====================================================
class ConvBlock(nn.Module):
    """3x3 convolution, then batch norm, then SiLU; spatial size is preserved.

    ``padding=1`` with a 3x3 kernel keeps H and W unchanged. Only the channel
    count changes, from ``in_c`` to ``out_c``. The layers live in
    ``self.block`` so the parameter names are ``block.0.*``, ``block.1.*``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        conv = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        norm = nn.BatchNorm2d(out_c)
        act = nn.SiLU()
        self.block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        out = self.block(x)
        return out

class UNet(nn.Module):
    """Three-level U-Net that predicts the noise in a noised image.

    Encoder: ConvBlock pairs at 64, 128 and 256 channels, each level followed
    by 2x max-pooling, then a 512-channel bottleneck at 1/8 resolution.
    Decoder: nearest-neighbour 2x upsampling, concatenation with the encoder
    feature map of the same resolution (skip connection), then one ConvBlock.
    A final 1x1 convolution maps 64 channels to ``out_channels``.

    The timestep goes through a 512-wide sinusoidal embedding and a 2-layer
    MLP. It is then added as a per-channel bias to the bottleneck features
    only, which is why both the bottleneck and the embedding are 512 wide.

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the output. ``None`` (default) means the
            same as ``in_channels``, because the output is a noise estimate
            for the input.

    forward(x, t):
        x: (B, in_channels, H, W) with H and W divisible by 8.
        t: 1-D tensor of B integer timesteps.
        Returns a (B, out_channels, H, W) tensor.
    """

    def __init__(self, in_channels=3, out_channels=None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels

        # Registration order and attribute names are unchanged so existing
        # checkpoints (state_dict keys) still load.
        self.time_emb = TimeEmbedding(512)
        self.time_fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, 512)
        )

        self.c1 = nn.Sequential(ConvBlock(in_channels, 64), ConvBlock(64, 64))
        self.c2 = nn.Sequential(ConvBlock(64, 128), ConvBlock(128, 128))
        self.c3 = nn.Sequential(ConvBlock(128, 256), ConvBlock(256, 256))

        self.bottleneck = nn.Sequential(
            ConvBlock(256, 512),
            ConvBlock(512, 512)
        )

        # Decoder input = upsampled features + skip features (channel concat).
        self.u1 = ConvBlock(512 + 256, 256)
        self.u2 = ConvBlock(256 + 128, 128)
        self.u3 = ConvBlock(128 + 64, 64)

        self.out = nn.Conv2d(64, out_channels, 1)

        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x, t):
        height, width = x.shape[-2:]
        # Three 2x poolings followed by three 2x upsamplings only line up with
        # the skip connections when both sides are multiples of 8; otherwise
        # torch.cat below fails with an opaque size-mismatch error.
        if height % 8 or width % 8:
            raise ValueError(
                f"UNet expects height and width divisible by 8, got {height}x{width}"
            )

        # (B, 512) time features -> (B, 512, 1, 1) so they broadcast over H, W.
        temb = self.time_fc(self.time_emb(t))[:, :, None, None]

        c1 = self.c1(x)
        c2 = self.c2(self.pool(c1))
        c3 = self.c3(self.pool(c2))

        b = self.bottleneck(self.pool(c3)) + temb

        u1 = self.u1(torch.cat([self.up(b), c3], dim=1))
        u2 = self.u2(torch.cat([self.up(u1), c2], dim=1))
        u3 = self.u3(torch.cat([self.up(u2), c1], dim=1))

        return self.out(u3)

# =====================================================
# MODEL, OPTIMIZER
# =====================================================
model = UNet().to(DEVICE)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# =====================================================
# MODEL SUMMARY
# =====================================================
print("\nMODEL SUMMARY\n")
print(model)
# Element count over every parameter tensor (trainable or not).
total_params = sum(map(torch.numel, model.parameters()))
print(f"\nTotal Parameters: {total_params:,}\n")

# =====================================================
# FORWARD DIFFUSION
# =====================================================
def q_sample(x0, t, noise):
    """Noise clean samples ``x0`` directly to timestep ``t`` (closed form).

    Returns ``sqrt(alpha_cum[t]) * x0 + sqrt(1 - alpha_cum[t]) * noise``.
    The per-sample coefficient is reshaped to (B, 1, 1, 1), so ``x0`` and
    ``noise`` are assumed to be 4-D, presumably (B, C, H, W). ``t`` is
    assumed to be a 1-D integer tensor of length B, on the same device as
    ``alpha_cum``.
    """
    abar = alpha_cum[t][:, None, None, None]
    signal = abar.sqrt() * x0
    corruption = (1 - abar).sqrt() * noise
    return signal + corruption

# =====================================================
# TRAINING
# =====================================================
# Explicitly enter training mode so BatchNorm uses per-batch statistics.
# The inference code later calls model.eval(), and this loop would otherwise
# silently train in eval mode if it were re-run afterwards.
model.train()
for epoch in range(EPOCHS):
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for x in pbar:
        x = x.to(DEVICE)
        b = x.size(0)

        # One random timestep per image; noise the batch to that timestep.
        t = torch.randint(0, TIMESTEPS, (b,), device=DEVICE)
        noise = torch.randn_like(x)
        xt = q_sample(x, t, noise)

        # Epsilon-prediction objective: regress the noise that was added.
        pred = model(xt, t)
        loss = F.mse_loss(pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())

    # Checkpoint after every epoch. Otherwise the weights are written only
    # once all EPOCHS finish, and an interrupted run loses all progress.
    torch.save(model.state_dict(), MODEL_PATH)

print("✔ Training Complete")

# =====================================================
# SAVE MODEL
# =====================================================
torch.save(model.state_dict(), MODEL_PATH)
print(f"✔ Model saved to {MODEL_PATH}")

# =====================================================
# LOAD MODEL (FOR INFERENCE)
# =====================================================
# Round-trip the weights through the file on disk, then switch to eval mode
# so BatchNorm layers use their running statistics while sampling.
loaded_state = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(loaded_state)
model.eval()

# =====================================================
# DDIM SAMPLER (INFERENCE)
# =====================================================
DDIM_STEPS = 50
# DDIM_STEPS timesteps spaced evenly from 0 to TIMESTEPS - 1 (truncated to
# integers). They form the deterministic sampling schedule.
ddim_t = torch.linspace(0, TIMESTEPS - 1, DDIM_STEPS).long().to(DEVICE)


@torch.no_grad()
def sample_ddim(n=16, clip_x0=True):
    """Generate ``n`` images with deterministic DDIM sampling (eta = 0).

    Starts from Gaussian noise and walks backwards through ``ddim_t``. At each
    step the model's noise prediction ``eps`` gives an estimate ``x0`` of the
    clean image, which is then re-noised to the previous (smaller) timestep.
    The last step targets alpha_cum = 1, so the result is the model's final
    clean-image estimate rather than a still slightly noisy sample at t=0.

    Args:
        n: number of images to generate.
        clip_x0: clamp each intermediate x0 estimate to [-1, 1], the range
            the training images were normalized to (default True).

    Returns:
        Tensor of shape (n, 3, IMG_SIZE, IMG_SIZE) on DEVICE, values in [0, 1].

    The global ``model`` is switched to eval mode for the duration of the
    call, so BatchNorm uses (and does not update) its running statistics even
    if sampling is triggered mid-training. Its previous mode is restored on
    exit.
    """
    was_training = model.training
    model.eval()
    try:
        x = torch.randn(n, 3, IMG_SIZE, IMG_SIZE, device=DEVICE)

        for i in range(DDIM_STEPS - 1, -1, -1):
            t = ddim_t[i].repeat(n)
            a_t = alpha_cum[t][:, None, None, None]
            if i > 0:
                a_prev = alpha_cum[ddim_t[i - 1].repeat(n)][:, None, None, None]
            else:
                # Final step: jump to alpha_cum = 1, i.e. x becomes x0 itself.
                a_prev = torch.ones_like(a_t)

            eps = model(x, t)
            x0 = (x - torch.sqrt(1 - a_t) * eps) / torch.sqrt(a_t)
            if clip_x0:
                x0 = x0.clamp(-1.0, 1.0)
            x = torch.sqrt(a_prev) * x0 + torch.sqrt(1 - a_prev) * eps
    finally:
        model.train(was_training)

    # Map from [-1, 1] back to [0, 1] for display.
    return torch.clamp((x + 1) / 2, 0, 1)

# =====================================================
# GENERATE & SHOW IMAGES
# =====================================================
samples = sample_ddim(16).cpu()

# 4 x 4 grid; each sample is (C, H, W) and imshow expects (H, W, C).
plt.figure(figsize=(5, 5))
for pos, img in enumerate(samples[:16], start=1):
    plt.subplot(4, 4, pos)
    plt.imshow(img.permute(1, 2, 0))
    plt.axis("off")
plt.show()



MODEL SUMMARY

UNet(
  (time_emb): TimeEmbedding()
  (time_fc): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
  )
  (c1): Sequential(
    (0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU()
      )
    )
    (1): ConvBlock(
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU()
      )
    )
  )
  (c2): Sequential(
    (0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (

Epoch 1/100:  18%|█▊        | 14/79 [01:00<04:38,  4.28s/it, loss=0.658]