In [1]:
import math
import os
import warnings
import random
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# (Optional) Silence Pillow's warning about palette images with transparency,
# which can be raised when such PNGs are converted to RGB.
warnings.filterwarnings("ignore", message="Palette images with Transparency")

##################################################
# 1) Hyperparameters & Config
##################################################

# Diffusion schedule: T steps, betas spaced linearly from beta_start to beta_end.
T = 1000
beta_start = 0.0001
beta_end = 0.02

# Training / data settings.
epochs = 50
batch_size = 64
learning_rate = 1e-4
image_size = 32
num_channels = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Precomputed schedule tensors, each of length T and stored on `device`:
#   alphas[t]              = 1 - betas[t]
#   alphas_cumprod[t]      = alphas[0] * ... * alphas[t]            (alpha_bar_t)
#   alphas_cumprod_prev[t] = alphas_cumprod[t - 1], and 1.0 at t = 0
betas = torch.linspace(beta_start, beta_end, T).to(device)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
alphas_cumprod_prev = torch.cat((alphas_cumprod.new_ones(1), alphas_cumprod[:-1]))
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

##################################################
# 2) Define Forward Diffusion Function
##################################################
def forward_diffusion(x0, t, noise=None):
    """
    Sample x_t ~ q(x_t | x_0) in closed form:
        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise

    x0:    [B, ...] clean images (this file uses [B, 3, H, W])
    t:     [B] integer timesteps used to index the schedule tensors
    noise: optional tensor shaped like x0. When None (default), fresh
           standard-normal noise is drawn; passing it makes the call
           deterministic.
    Returns: x_t, noise
       x_t:   Noisy version of x0 at step t
       noise: The actual noise added (the training target)
    """
    if noise is None:
        noise = torch.randn_like(x0)

    # Per-sample coefficients reshaped to [B, 1, 1, ...] so they broadcast over
    # every non-batch dimension of x0, whatever its rank.
    coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
    sqrt_alpha = sqrt_alphas_cumprod[t].reshape(coeff_shape)
    sqrt_one_minus_alpha = sqrt_one_minus_alphas_cumprod[t].reshape(coeff_shape)

    x_t = sqrt_alpha * x0 + sqrt_one_minus_alpha * noise
    return x_t, noise

##################################################
# 3) Dataset Classes (Recursive Scanning)
##################################################
class PngDataset(Dataset):
    """
    Training dataset of .png images found by recursively scanning directories.

    On construction the discovered files are split into two lists:
      * ``test_files``:  every .png under ``test_dir`` plus 20% of the rest
      * ``train_files``: the remaining 80%; this is what the dataset serves
    Items are ``(image, 0)``: the RGB image (after ``transform``) and a dummy
    label.
    """

    def __init__(self, root_dirs, transform=None, test_dir=None, seed=None):
        """
        root_dirs: list of directories to scan (recursively) for images
        transform: torchvision transforms applied to each PIL image
        test_dir:  directory whose .png images go into 'test_files'; every file
                   located inside it is excluded from the 80/20 split pool
        seed:      optional int seeding the 80/20 shuffle. None (default) uses
                   the global `random` state, so the split varies between runs.
        """
        self.transform = transform
        self.files = []
        self.test_files = []

        # If a test_dir is specified, gather its images for the test set
        if test_dir:
            self.test_files.extend(self._find_pngs(test_dir))

        # Gather all .png files from each root_dir, skipping anything inside
        # test_dir. Containment is checked on whole path components, so a
        # sibling folder such as ".../images2" is not mistaken for ".../images"
        # (a plain substring test would drop it).
        test_root = os.path.abspath(test_dir) if test_dir else None
        for root_dir in root_dirs:
            for full_path in self._find_pngs(root_dir):
                if test_root and self._is_inside(full_path, test_root):
                    continue
                self.files.append(full_path)

        # Shuffle the big file list so we can do an 80/20 split. Sorting first
        # makes the outcome depend only on the seed, not on os.walk order.
        self.files.sort()
        rng = random.Random(seed) if seed is not None else random
        rng.shuffle(self.files)

        # 80% -> train_files, 20% -> appended to test_files
        split_idx = int(0.8 * len(self.files))
        self.train_files = self.files[:split_idx]
        self.test_files.extend(self.files[split_idx:])

    @staticmethod
    def _find_pngs(root):
        """Return the sorted paths of all files under `root` ending in .png (any case)."""
        found = []
        for folder, _, filenames in os.walk(root):
            for fname in filenames:
                if fname.lower().endswith(".png"):
                    found.append(os.path.join(folder, fname))
        found.sort()
        return found

    @staticmethod
    def _is_inside(path, root):
        """Return True if `path` is `root` or lies below it (`root` must be absolute)."""
        try:
            return os.path.commonpath([os.path.abspath(path), root]) == root
        except ValueError:  # e.g. paths on different drives (Windows)
            return False

    def __len__(self):
        return len(self.train_files)

    def __getitem__(self, idx):
        img_path = self.train_files[idx]
        # Open via a context manager so the file handle is released right away
        # rather than whenever the Image object happens to be garbage-collected.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, 0  # dummy label

class TestPngDataset(Dataset):
    """
    Dataset over an explicit list of image paths, typically
    ``PngDataset.test_files``. Items are ``(image, 0)``: the RGB image (after
    ``transform``) and a dummy label.
    """

    def __init__(self, test_files, transform=None):
        """
        test_files: list of image file paths
        transform:  torchvision transforms applied to each PIL image
        """
        self.test_files = test_files
        self.transform = transform

    def __len__(self):
        return len(self.test_files)

    def __getitem__(self, idx):
        img_path = self.test_files[idx]
        # Context manager closes the underlying file as soon as the pixels
        # have been converted, instead of leaking the handle until GC.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, 0  # dummy label

##################################################
# 4) U-Net Model (Predict Noise)
##################################################
class DoubleConv(nn.Module):
    """
    Two stacked 3x3 convolutions (padding=1, so H and W are preserved), each
    followed by an in-place ReLU: in_channels -> out_channels -> out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # The first conv maps in_channels -> out_channels; the second keeps out_channels.
        for conv_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(conv_in, out_channels, 3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class UNet(nn.Module):
    """
    A simplified U-Net noise predictor: maps an image batch x to a tensor of
    the same shape (3 channels in, 3 channels out).

    There is no timestep input / time embedding, so the network has to infer
    the noise level from x alone. Three 2x max-pools followed by three 2x
    transposed convs, with skip concatenation at each level, mean H and W must
    be divisible by 8 for the shapes to line up.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 3 -> 64 -> 128 -> 256 -> 512 channels, halving H/W between levels.
        self.conv1 = DoubleConv(3, 64)
        self.conv2 = DoubleConv(64, 128)
        self.conv3 = DoubleConv(128, 256)
        self.conv4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)

        # Decoder: each transposed conv doubles H/W; the following DoubleConv
        # fuses it with the concatenated encoder skip (hence doubled in_channels).
        self.uptrans3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.uconv3 = DoubleConv(512, 256)
        self.uptrans2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.uconv2 = DoubleConv(256, 128)
        self.uptrans1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.uconv1 = DoubleConv(128, 64)
        self.out_conv = nn.Conv2d(64, 3, kernel_size=1)

    def forward(self, x):
        # Encoder: keep each level's features for the skip connections.
        skips = [self.conv1(x)]
        for encoder in (self.conv2, self.conv3):
            skips.append(encoder(self.pool(skips[-1])))
        features = self.conv4(self.pool(skips[-1]))  # bottleneck

        # Decoder: upsample, concatenate the matching skip on channels, fuse.
        stages = (
            (self.uptrans3, self.uconv3),
            (self.uptrans2, self.uconv2),
            (self.uptrans1, self.uconv1),
        )
        for (upsample, fuse), skip in zip(stages, reversed(skips)):
            features = fuse(torch.cat([upsample(features), skip], dim=1))

        return self.out_conv(features)

##################################################
# 5) Instantiate Model & Optimizer
##################################################
# Module-level noise-prediction network and its Adam optimizer; the train and
# sample functions below use these globals directly.
model = UNet()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

##################################################
# 6) Sample Function (Reverse Diffusion)
##################################################
@torch.no_grad()
def sample_model(n=8, out_path='samples/diffusion_samples.png', nrow=4):
    """
    Generate n images with DDPM ancestral sampling and save them as a grid.

    Starting from x ~ N(0, I), for i = T-1, ..., 0:
        x <- 1/sqrt(alpha_i) * (x - beta_i / sqrt(1 - alpha_bar_i) * eps)
        x <- x + sigma_i * z,  z ~ N(0, I)        (skipped at i == 0)
    with sigma_i^2 = (1 - alpha_bar_{i-1}) / (1 - alpha_bar_i) * beta_i
    (DDPM, Ho et al. 2020, Eq. 11 / Algorithm 2).

    NOTE: `model` is called as model(x) without the timestep, so `eps` is not
    conditioned on i.

    n:        number of images to generate
    out_path: file the image grid is written to (its directory is created)
    nrow:     number of images per row in the grid
    Returns the generated batch as a CPU tensor of shape
    [n, num_channels, image_size, image_size]. The model's train/eval mode is
    restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        x = torch.randn(n, num_channels, image_size, image_size, device=device)

        for i in reversed(range(T)):
            beta_t = betas[i]
            alpha_t = alphas[i]
            alpha_t_bar = alphas_cumprod[i]

            # The model predicts the noise contained in x
            pred_noise = model(x)
            # Posterior mean (DDPM Eq. 11)
            x = (1.0 / torch.sqrt(alpha_t)) * (
                x - (beta_t / torch.sqrt(1.0 - alpha_t_bar)) * pred_noise
            )

            if i > 0:
                # alphas_cumprod_prev[i] is alpha_bar_{i-1}. torch.sqrt keeps
                # this on-device; math.sqrt would force a host sync every step.
                sigma_t = torch.sqrt(
                    (1.0 - alphas_cumprod_prev[i]) / (1.0 - alpha_t_bar) * beta_t
                )
                x = x + sigma_t * torch.randn_like(x)
    finally:
        model.train(was_training)

    samples = x.cpu()
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    img_grid = utils.make_grid(samples, nrow=nrow)
    utils.save_image(img_grid, out_path)
    print(f"Samples saved to {out_path}")
    return samples

##################################################
# 7) Train Function (Saves Weights)
##################################################
def train_model(loader=None, num_epochs=None,
                save_path="/kaggle/working/latest_weights.pth"):
    """
    Train `model` with the simplified DDPM objective and save its weights.

    For each batch, sample t uniformly from {0, ..., T-1} and noise the clean
    images with forward_diffusion. Then minimise the MSE between the model's
    output and the noise that was added.

    loader:     iterable of (images, labels) batches. Defaults to the
                module-level `train_loader` built in the __main__ section.
    num_epochs: passes over `loader`. Defaults to the global `epochs`.
    save_path:  file the state_dict is written to. It is checkpointed after
                every epoch, so an interrupted run keeps its progress, and is
                saved again at the end. Its directory is created if missing.
    """
    if loader is None:
        loader = train_loader
    if num_epochs is None:
        num_epochs = epochs

    print("===== Training Model from Scratch =====")
    model.train()

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    def _save_weights():
        # Legacy (non-zip) serialization, as in the original code.
        torch.save(model.state_dict(), save_path,
                   _use_new_zipfile_serialization=False)

    for epoch in range(num_epochs):
        for step, (x0, _) in enumerate(loader):
            x0 = x0.to(device)
            # torch.randint already returns int64, usable as an index tensor
            t = torch.randint(0, T, (x0.shape[0],), device=device)

            # forward diffusion
            x_t, noise = forward_diffusion(x0, t)

            # model predicts noise
            noise_pred = model(x_t)
            loss = F.mse_loss(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print(f"Epoch {epoch+1}/{num_epochs} | Step {step} | Loss: {loss.item():.4f}")

        # Per-epoch checkpoint; the last epoch is covered by the final save.
        if epoch + 1 < num_epochs:
            _save_weights()

    # Final save
    _save_weights()
    print(f"Final weights saved to {save_path}")

##################################################
# 8) Dataset / DataLoader Setup, Training & Sampling
##################################################
if __name__ == "__main__":
    # Example usage:
    # train_dirs -> directories scanned recursively for PNG files
    # test_dir   -> a dedicated folder whose PNGs always go to the test set
    train_dirs = ["/kaggle/input"]
    test_dir   = "/kaggle/input/pokemon-images-and-types/images"

    # Set to True to load weights from a previous run (if the file exists)
    # instead of retraining. This replaces the former interactive prompt.
    use_existing_weights = False
    weight_path = "/kaggle/working/latest_weights.pth"

    # Seed the shuffle behind PngDataset's 80/20 split so the same images end
    # up in the test set on every run. This matters when reusing saved weights.
    split_seed = 0
    random.seed(split_seed)

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])

    dataset = PngDataset(root_dirs=train_dirs, transform=transform, test_dir=test_dir)
    test_dataset = TestPngDataset(dataset.test_files, transform=transform)

    # Quick checks
    if len(dataset) == 0:
        raise ValueError("Training dataset is empty. Check /kaggle/input for PNG images!")
    if len(test_dataset) == 0:
        raise ValueError("Test dataset is empty. Check your test_dir for PNG images!")

    # Build loaders (train_loader is read as a global by train_model)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True,
                              num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=2, pin_memory=True)

    print(f"Loaded {len(dataset)} training images and {len(test_dataset)} test images.")

    if use_existing_weights and os.path.exists(weight_path):
        print(f"Loading existing weights from {weight_path}...")
        # weights_only=True restricts unpickling to tensors/plain containers
        state_dict = torch.load(weight_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    else:
        print("Training from scratch...")
        train_model()

    # Sample
    sample_model(n=8)


Loaded 35822 training images and 9765 test images.
Training from scratch...
===== Training Model from Scratch =====
Epoch 1/50 | Step 0 | Loss: 1.0041
Epoch 1/50 | Step 100 | Loss: 0.2530
Epoch 1/50 | Step 200 | Loss: 0.1604
Epoch 1/50 | Step 300 | Loss: 0.0952
Epoch 1/50 | Step 400 | Loss: 0.1022
Epoch 1/50 | Step 500 | Loss: 0.1106
Epoch 2/50 | Step 0 | Loss: 0.0590
Epoch 2/50 | Step 100 | Loss: 0.0733
Epoch 2/50 | Step 200 | Loss: 0.0897
Epoch 2/50 | Step 300 | Loss: 0.0706
Epoch 2/50 | Step 400 | Loss: 0.0666
Epoch 2/50 | Step 500 | Loss: 0.0500
Epoch 3/50 | Step 0 | Loss: 0.0401
Epoch 3/50 | Step 100 | Loss: 0.0425
Epoch 3/50 | Step 200 | Loss: 0.0469
Epoch 3/50 | Step 300 | Loss: 0.0416
Epoch 3/50 | Step 400 | Loss: 0.0674
Epoch 3/50 | Step 500 | Loss: 0.0443
Epoch 4/50 | Step 0 | Loss: 0.0594
Epoch 4/50 | Step 100 | Loss: 0.0440
Epoch 4/50 | Step 200 | Loss: 0.0291
Epoch 4/50 | Step 300 | Loss: 0.0715
Epoch 4/50 | Step 400 | Loss: 0.0507
Epoch 4/50 | Step 500 | Loss: 0.0372
Epoc