In [None]:
import math
import os
import shutil
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

In [None]:
class AbstractArtDataset(Dataset):
    """Dataset of RGB images stored directly inside one directory.

    Each item is resized and center-cropped to ``image_size`` x ``image_size``,
    randomly flipped horizontally when ``train`` is True, converted to a tensor
    and normalized with mean/std 0.5 per channel (i.e. mapped to [-1, 1]).

    Args:
        root_dir: directory containing the image files (not searched recursively).
        image_size: output height/width in pixels.
        train: enable the random horizontal flip augmentation.
    """

    # Matched case-insensitively, so "photo.JPG" is picked up as well.
    VALID_EXTENSIONS = (".jpg", ".jpeg", ".png")
    # Upper bound on fallback attempts when images fail to decode, so a
    # directory full of broken files raises instead of recursing forever.
    MAX_LOAD_ATTEMPTS = 10

    def __init__(self, root_dir, image_size=256, train=True):
        self.root_dir = Path(root_dir)
        self.train = train

        steps = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        # Append the flip only when training instead of inserting an identity
        # ``transforms.Lambda(lambda ...)``: lambdas cannot be pickled, which
        # breaks DataLoader workers on spawn-based platforms (Windows/macOS).
        if train:
            steps.append(transforms.RandomHorizontalFlip())
        steps += [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(steps)

        self.image_files = self._collect_images(self.root_dir)
        print(f"Found {len(self.image_files)} images in {root_dir}")

    @classmethod
    def _collect_images(cls, root_dir):
        """Return the image files directly inside ``root_dir``, sorted by path.

        Sorting makes the index -> file mapping reproducible across runs (set
        iteration and directory-listing order are not). A missing directory
        yields an empty list.
        """
        root = Path(root_dir)
        if not root.is_dir():
            return []
        return sorted(
            path
            for path in root.iterdir()
            if path.is_file() and path.suffix.lower() in cls.VALID_EXTENSIONS
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load and transform image ``idx``.

        If an image fails to load, a randomly chosen other image is tried
        instead, up to ``MAX_LOAD_ATTEMPTS`` times in total.

        Raises:
            RuntimeError: if every attempt fails.
        """
        img_path = self.image_files[idx]
        for _ in range(self.MAX_LOAD_ATTEMPTS):
            try:
                # ``with`` closes the underlying file handle; ``convert`` returns
                # a fully loaded copy, so using it after close is safe.
                with Image.open(img_path) as img:
                    return self.transform(img.convert("RGB"))
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                img_path = self.image_files[np.random.randint(len(self))]
        raise RuntimeError(
            f"Failed to load an image after {self.MAX_LOAD_ATTEMPTS} attempts "
            f"(starting from index {idx})"
        )

In [None]:

def prepare_dataset(
    download_path, output_path, train_split=0.9, image_size=256, seed=None
):
    """Shuffle the images in ``download_path`` and copy them into train/val folders.

    Creates ``output_path/train`` and ``output_path/val`` and copies the first
    ``int(len(files) * train_split)`` shuffled files into ``train`` and the rest
    into ``val``. Only files directly in ``download_path`` with a .jpg/.jpeg/.png
    extension (any case) are used.

    Args:
        download_path: directory holding the raw images.
        output_path: destination root for the ``train``/``val`` folders.
        train_split: fraction of images assigned to the training split, in [0, 1].
        image_size: unused; kept for backward compatibility (resizing happens in
            ``AbstractArtDataset``).
        seed: if given, makes the split reproducible via
            ``np.random.default_rng(seed)``; if None, numpy's global RNG is used.

    Returns:
        ``(n_train, n_val)`` — the number of files copied into each split.

    Raises:
        ValueError: if ``train_split`` is outside [0, 1].
        FileNotFoundError: if ``download_path`` is not a directory.
    """
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be in [0, 1], got {train_split}")

    source_dir = Path(download_path)
    if not source_dir.is_dir():
        raise FileNotFoundError(f"Source directory not found: {source_dir}")

    output_path = Path(output_path)
    train_dir = output_path / "train"
    val_dir = output_path / "val"

    for split_dir in (train_dir, val_dir):
        split_dir.mkdir(parents=True, exist_ok=True)
        # Files left over from an earlier (differently shuffled) run are not
        # removed, so the same image could end up in both splits.
        if any(split_dir.iterdir()):
            print(
                f"Warning: {split_dir} is not empty; files from an earlier split "
                "may leak between train and val"
            )

    valid_extensions = {".jpg", ".jpeg", ".png"}
    # Sort first so that, for a fixed seed, the shuffle result does not depend
    # on the filesystem's directory-listing order.
    image_files = sorted(
        path
        for path in source_dir.iterdir()
        if path.is_file() and path.suffix.lower() in valid_extensions
    )

    rng = np.random.default_rng(seed) if seed is not None else np.random
    rng.shuffle(image_files)

    split_idx = int(len(image_files) * train_split)
    train_files = image_files[:split_idx]
    val_files = image_files[split_idx:]

    def copy_files(files, dest_dir):
        for file in tqdm(files, desc=f"Copying to {dest_dir.name}"):
            shutil.copy2(file, dest_dir / file.name)

    copy_files(train_files, train_dir)
    copy_files(val_files, val_dir)

    return len(train_files), len(val_files)


In [None]:
def get_dataloaders(data_dir, batch_size=32, image_size=256, num_workers=4):
    """Create training and validation dataloaders.

    Expects ``data_dir/train`` and ``data_dir/val`` (as produced by
    ``prepare_dataset``). The training loader shuffles; the validation loader
    does not.

    Returns:
        ``(train_loader, val_loader)``.
    """
    data_dir = Path(data_dir)

    train_dataset = AbstractArtDataset(
        data_dir / "train", image_size=image_size, train=True
    )
    val_dataset = AbstractArtDataset(
        data_dir / "val", image_size=image_size, train=False
    )

    # Pinned host memory only helps host->GPU copies; without CUDA it just
    # triggers a warning.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader



In [None]:
def verify_dataset(dataloader):
    """Print the shape and summary statistics of the first batch of ``dataloader``.

    Intended as a quick sanity check: with the dataset's normalization the
    values should lie in roughly [-1, 1].

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    try:
        batch = next(iter(dataloader))
    except StopIteration:
        # A bare StopIteration is confusing (and turns into RuntimeError inside
        # generators); report the actual problem instead.
        raise ValueError("dataloader is empty; check the dataset directory") from None

    print(f"Batch shape: {batch.shape}")
    print(f"Value range: [{batch.min():.2f}, {batch.max():.2f}]")
    print(f"Mean: {batch.mean():.2f}")
    print(f"Std: {batch.std():.2f}")

In [None]:
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer SiLU MLP.

    Maps a tensor of timesteps of shape ``(B,)`` (a 0-dim scalar is treated as
    ``B == 1``) to features of shape ``(B, dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim)
        )

    def forward(self, t):
        t = t.reshape(-1).float()

        half_dim = self.dim // 2
        # Frequencies are spaced geometrically from 1 down to 1/10000. Using
        # max(..., 1) avoids a division by zero when half_dim <= 1.
        exponent = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=torch.float32) * -exponent
        )
        args = t[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)

        # sin/cos give 2 * (dim // 2) features; pad one zero column for odd dim
        # so the width matches the first Linear layer's input size.
        if self.dim % 2:
            embeddings = F.pad(embeddings, (0, 1))

        # Match the MLP's parameter dtype (e.g. for half-precision models).
        param = next(self.proj.parameters(), None)
        if param is not None:
            embeddings = embeddings.to(param.dtype)

        return self.proj(embeddings)


In [None]:

class AttentionBlock(nn.Module):
    def __init__(self, dim, heads=4):
        super().__init__()
        self.heads = heads
        self.scale = dim**-0.5
        self.norm = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1)  # B, H, W, C
        x = self.norm(x)

        qkv = self.qkv(x).reshape(B, H * W, 3, self.heads, C // self.heads)
        q, k, v = qkv.unbind(2)  # B, H*W, heads, C//heads

        # Attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).reshape(B, H, W, C)

        x = self.proj(x)
        return x.permute(0, 3, 1, 2)  # Back to B, C, H, W


In [None]:

class UNet(nn.Module):
    """Time-conditioned U-Net that predicts the noise added to an image.

    Channel widths per level are ``[dim * m for m in dim_mults]``. Each
    down level applies attention, a 3x3 conv (plus the projected time
    embedding), and a stride-2 conv. Each up level applies a 2x transposed
    conv, concatenation with the matching skip, a 3x3 conv (plus the projected
    time embedding), and attention.

    ``forward(x, time)`` takes ``x`` of shape ``(B, in_channels, H, W)`` and
    timesteps ``time`` of shape ``(B,)``, and returns a tensor shaped like ``x``.
    ``H`` and ``W`` must be divisible by ``2 ** (len(dim_mults) - 1)``.

    NOTE(review): attention runs at every level, including full input
    resolution, which is very costly for large images — consider restricting
    it to low-resolution levels.
    """

    def __init__(self, in_channels=3, dim=64, dim_mults=(1, 2, 4, 8)):
        super().__init__()

        dims = [dim * m for m in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))

        self.time_mlp = TimeEmbedding(dim)
        # Output dims[0] (not dim) so dim_mults[0] != 1 is also consistent.
        self.init_conv = nn.Conv2d(in_channels, dims[0], 3, padding=1)

        # Downsampling: attn, conv (dim_in -> dim_out), stride-2 conv, time proj.
        self.downs = nn.ModuleList()
        for dim_in, dim_out in in_out:
            self.downs.append(
                nn.ModuleList(
                    [
                        AttentionBlock(dim_in),
                        nn.Conv2d(dim_in, dim_out, 3, padding=1),
                        nn.Conv2d(dim_out, dim_out, 3, stride=2, padding=1),
                        nn.Linear(dim, dim_out),
                    ]
                )
            )

        # Middle
        mid_dim = dims[-1]
        self.mid_block1 = AttentionBlock(mid_dim)
        self.mid_block2 = nn.Conv2d(mid_dim, mid_dim, 3, padding=1)
        self.mid_time = nn.Linear(dim, mid_dim)

        # Upsampling mirrors the down path: the incoming feature map has
        # dim_out channels and the skip from the same level also has dim_out,
        # so the conv maps 2 * dim_out -> dim_in.
        self.ups = nn.ModuleList()
        for dim_in, dim_out in reversed(in_out):
            self.ups.append(
                nn.ModuleList(
                    [
                        nn.ConvTranspose2d(dim_out, dim_out, 4, stride=2, padding=1),
                        nn.Conv2d(dim_out * 2, dim_in, 3, padding=1),
                        AttentionBlock(dim_in),
                        nn.Linear(dim, dim_in),
                    ]
                )
            )

        # Final conv
        self.final_conv = nn.Conv2d(dims[0], in_channels, 3, padding=1)

    def forward(self, x, time):
        factor = 2 ** len(self.downs)
        height, width = x.shape[-2:]
        if height % factor or width % factor:
            # Stride-2 convs round odd sizes up, so the up path would not
            # line up with the skip connections.
            raise ValueError(
                f"Input size {height}x{width} must be divisible by {factor}"
            )

        t = F.silu(self.time_mlp(time))
        x = self.init_conv(x)

        # Downsampling
        skips = []
        for attn, conv1, conv2, time_proj in self.downs:
            x = attn(x)
            x = conv1(x) + time_proj(t)[:, :, None, None]
            x = F.silu(x)
            skips.append(x)
            x = F.silu(conv2(x))

        # Middle
        x = self.mid_block1(x)
        x = self.mid_block2(x) + self.mid_time(t)[:, :, None, None]
        x = F.silu(x)

        # Upsampling
        for up, conv, attn, time_proj in self.ups:
            x = up(x)
            x = torch.cat((x, skips.pop()), dim=1)
            x = conv(x) + time_proj(t)[:, :, None, None]
            x = F.silu(x)
            x = attn(x)

        return self.final_conv(x)



In [None]:

class DiffusionModel:
    """DDPM with a linear beta schedule: forward noising and ancestral sampling.

    Args:
        timesteps: number of diffusion steps ``T``.
        beta_start, beta_end: endpoints of the linear ``beta`` schedule.
    """

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        self.timesteps = timesteps
        self.beta = torch.linspace(beta_start, beta_end, timesteps)
        self.alpha = 1.0 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

    def _gather(self, schedule, t, x):
        """Index a 1-D schedule at ``t`` and reshape it to broadcast against ``x``.

        The schedule is moved to ``x``'s device and dtype first, because the
        schedule tensors live on CPU while ``x``/``t`` may be on a GPU.
        """
        idx = torch.as_tensor(t, device=x.device)
        values = schedule.to(device=x.device, dtype=x.dtype)[idx]
        return values.view(-1, *([1] * (x.dim() - 1)))

    def add_noise(self, x, t, noise=None):
        """Sample ``x_t ~ q(x_t | x_0)`` for a batch of images.

        Args:
            x: clean inputs with the batch on dim 0 (any number of trailing dims).
            t: integer timesteps, one per batch element.
            noise: optional Gaussian noise shaped like ``x``; drawn if None.

        Returns:
            ``(noisy_x, noise)``.
        """
        alpha_bar = self._gather(self.alpha_bar, t, x)
        if noise is None:
            noise = torch.randn_like(x)
        return alpha_bar.sqrt() * x + (1 - alpha_bar).sqrt() * noise, noise

    @torch.no_grad()
    def sample(self, model, n_samples, size, device, channels=3):
        """Generate samples by running the reverse process from pure noise.

        ``model(x, t)`` must predict the noise in ``x`` at timestep ``t``. The
        model is put in eval mode for sampling and its previous train/eval mode
        is restored afterwards.

        Returns:
            Tensor of shape ``(n_samples, channels, *size)`` with values in [0, 1].
        """
        was_training = model.training
        model.eval()
        x = torch.randn(n_samples, channels, *size, device=device)

        try:
            for t in reversed(range(self.timesteps)):
                t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)
                predicted_noise = model(x, t_batch)
                alpha = self.alpha[t]
                alpha_bar = self.alpha_bar[t]
                beta = self.beta[t]

                # No noise is added at the last step (t == 0).
                noise = torch.randn_like(x) if t > 0 else 0

                x = (
                    1
                    / torch.sqrt(alpha)
                    * (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * predicted_noise)
                    + torch.sqrt(beta) * noise
                )
        finally:
            model.train(was_training)

        # Map from [-1, 1] (the training normalization) to [0, 1].
        return (x.clamp(-1, 1) + 1) / 2



In [None]:

def train_step(model, diffusion, optimizer, images, device):
    """Run one noise-prediction optimization step and return the loss.

    Draws a uniform random timestep per image, noises the batch with
    ``diffusion.add_noise``, and minimizes the MSE between the model's
    predicted noise and the true noise.

    Args:
        model: callable as ``model(noisy_images, t)``.
        diffusion: object exposing ``timesteps`` and ``add_noise(x, t)``.
        optimizer: optimizer over ``model``'s parameters.
        images: batch of clean images; moved to ``device`` if not already there.
        device: device for the images and timesteps.

    Returns:
        The scalar loss as a Python float.
    """
    images = images.to(device)
    t = torch.randint(0, diffusion.timesteps, (images.shape[0],), device=device)

    noisy_images, noise = diffusion.add_noise(images, t)
    loss = F.mse_loss(model(noisy_images, t), noise)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


In [None]:
def _diffusion_val_loss(model, diffusion, loader, device):
    """Mean noise-prediction MSE over ``loader`` without gradients; None if empty."""
    if len(loader) == 0:
        return None
    was_training = model.training
    model.eval()
    total = 0.0
    try:
        with torch.no_grad():
            for batch in loader:
                images = batch.to(device)
                t = torch.randint(
                    0, diffusion.timesteps, (images.shape[0],), device=device
                )
                noisy_images, noise = diffusion.add_noise(images, t)
                total += F.mse_loss(model(noisy_images, t), noise).item()
    finally:
        model.train(was_training)
    return total / len(loader)


def train_diffusion(
    num_epochs=100,
    batch_size=16,
    image_size=256,
    device="cuda",
    save_dir="models",
    data_dir="data/processed_abstract_art",
    log_dir="runs/diffusion_training",
    sample_every=10,
):
    """Train a ``UNet`` noise predictor with ``DiffusionModel`` and log to TensorBoard.

    Every epoch logs the mean training loss (``Loss/train``) and, if the
    validation split is non-empty, the validation loss (``Loss/val``). Every
    ``sample_every`` epochs and on the final epoch, 4 samples are logged
    (``Generated``) and a checkpoint is written to
    ``save_dir/checkpoint_epoch_{epoch}.pt``.

    Raises:
        ValueError: if ``sample_every < 1`` or the training split is empty.
    """
    if sample_every < 1:
        raise ValueError(f"sample_every must be >= 1, got {sample_every}")

    model = UNet().to(device)
    diffusion = DiffusionModel()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    train_loader, val_loader = get_dataloaders(
        data_dir, batch_size=batch_size, image_size=image_size
    )
    if len(train_loader) == 0:
        raise ValueError(f"No training images found in {Path(data_dir) / 'train'}")

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    writer = SummaryWriter(log_dir)
    try:
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0.0
            for batch in train_loader:
                total_loss += train_step(
                    model, diffusion, optimizer, batch.to(device), device
                )

            avg_loss = total_loss / len(train_loader)
            print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
            writer.add_scalar("Loss/train", avg_loss, epoch)

            val_loss = _diffusion_val_loss(model, diffusion, val_loader, device)
            if val_loss is not None:
                print(f"Epoch {epoch}: Val Loss = {val_loss:.4f}")
                writer.add_scalar("Loss/val", val_loss, epoch)

            # Also handle the final epoch so the last few epochs of training
            # are not lost when num_epochs - 1 is not a multiple of sample_every.
            if epoch % sample_every == 0 or epoch == num_epochs - 1:
                samples = diffusion.sample(
                    model, n_samples=4, size=(image_size, image_size), device=device
                )
                writer.add_images("Generated", samples, epoch)

                torch.save(
                    {
                        "epoch": epoch,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "loss": avg_loss,
                    },
                    save_dir / f"checkpoint_epoch_{epoch}.pt",
                )
    finally:
        # Flush pending events and release the log file.
        writer.close()


In [None]:

def generate_images(
    checkpoint_path, num_images=4, image_size=256, device="cuda", output_dir="."
):
    """Load a checkpoint from ``train_diffusion`` and save freshly sampled images.

    The checkpoint must contain ``"model_state_dict"`` for a default-configured
    ``UNet``. Images are written as ``output_dir/generated_image_{i}.png``.

    Returns:
        List of the written file paths.
    """
    model = UNet().to(device)
    diffusion = DiffusionModel()

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    samples = diffusion.sample(
        model, n_samples=num_images, size=(image_size, image_size), device=device
    )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    paths = []
    for i, sample in enumerate(samples):
        path = output_dir / f"generated_image_{i}.png"
        save_image(sample, str(path))
        paths.append(path)
    return paths

