In [1]:
import torch

class Diffusion:
    """Forward and reverse diffusion process with a linear beta schedule.

    The schedule runs linearly from ``beta_start`` to ``beta_end`` over
    ``noise_steps`` steps; ``alpha = 1 - beta`` and ``alpha_hat`` is the
    cumulative product of ``alpha``. All three tensors live on ``device``.

    Args:
        noise_steps: Number of diffusion steps in the schedule.
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
        img_size: Height and width of the square images produced by ``sample``.
        device: Device on which the schedule and sampled images are placed.
        channels: Number of image channels produced by ``sample``. Defaults to
            3 so that sampling matches the 3-channel model and normalisation
            used elsewhere in this file.
    """

    def __init__(self, noise_steps=500, beta_start=1e-3, beta_end=0.02, img_size=32, device="cuda", channels=3):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device
        self.channels = channels
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        """Return the linear beta schedule as a 1-D tensor of length ``noise_steps``."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise a batch of images to the given timesteps.

        Args:
            x: Batch of images; the ``[:, None, None, None]`` broadcast below
                requires a 4-D tensor whose first axis matches ``t``.
            t: 1-D integer tensor of timestep indices, one per image.

        Returns:
            ``(x_t, noise)`` where
            ``x_t = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * noise``
            and ``noise`` is standard normal with the shape of ``x``.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        noise = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def sample_timesteps(self, n):
        """Draw ``n`` random integer timesteps in ``[1, noise_steps)``.

        The tensor is created on PyTorch's default device; callers move it to
        the training device themselves.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, channels=None):
        """Generate ``n`` images by running the reverse process from pure noise.

        Args:
            model: Callable ``model(x, t)`` returning the predicted noise for
                the batch ``x`` at integer timesteps ``t``.
            n: Number of images to generate.
            channels: Optional per-call override of the number of channels;
                falls back to ``self.channels`` when ``None``.

        Returns:
            A ``uint8`` tensor of shape ``(n, channels, img_size, img_size)``
            with values in ``[0, 255]``.
        """
        channels = self.channels if channels is None else channels
        # Remember the caller's mode so an eval-mode model is not silently
        # switched back to training once sampling has finished.
        was_training = model.training
        model.eval()
        try:
            with torch.inference_mode():
                x = torch.randn((n, channels, self.img_size, self.img_size)).to(self.device)
                for i in reversed(range(1, self.noise_steps)):
                    t = torch.full((n,), i, dtype=torch.long, device=self.device)
                    predicted_noise = model(x, t)
                    alpha = self.alpha[t][:, None, None, None]
                    alpha_hat = self.alpha_hat[t][:, None, None, None]
                    beta = self.beta[t][:, None, None, None]
                    # No fresh noise is injected on the final step.
                    noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)

                    x = (1 / torch.sqrt(alpha)) * (
                        x - ((1 - alpha) / torch.sqrt(1 - alpha_hat)) * predicted_noise
                    ) + torch.sqrt(beta) * noise
        finally:
            model.train(was_training)
        # Map from [-1, 1] to [0, 255] bytes.
        x = (x + 1) / 2
        x = x.clamp(0, 1)
        x = (x * 255).type(torch.uint8)
        return x


In [2]:
import torchvision
import os
import torch
from PIL import Image

def save_images(images, path, **kwargs):
    """Tile a batch of images into one grid and write it to ``path``.

    Args:
        images: Batch of image tensors passed to ``torchvision.utils.make_grid``.
            Presumably ``uint8`` in channel-first layout, as produced by
            ``Diffusion.sample`` — TODO confirm for other callers.
        path: Destination file path handed to ``PIL.Image.save``.
        **kwargs: Extra keyword arguments forwarded to ``make_grid``.
    """
    grid = torchvision.utils.make_grid(images, **kwargs)
    # Tensor.numpy() only works on CPU tensors, so bring the grid back to the
    # host (the previous .cuda() call made this conversion fail every time).
    ndarr = grid.permute(1, 2, 0).detach().cpu().numpy()
    im = Image.fromarray(ndarr)
    im.save(path)

def setup_logging(run_name):
    """Create the ``models`` and ``results`` folders plus a ``run_name`` subfolder in each.

    Existing directories are left untouched.
    """
    top_level = ("models", "results")
    for folder in top_level:
        os.makedirs(folder, exist_ok=True)
    for folder in top_level:
        os.makedirs(os.path.join(folder, run_name), exist_ok=True)


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Every spatial position becomes one token of width ``channels``. The block
    applies LayerNorm, 2-head attention with a residual connection, then a
    feed-forward sub-block with a second residual connection, and finally
    restores the channel-first image layout.

    Args:
        channels: Number of feature channels (the attention embedding width).
        size: Nominal spatial side length. Stored for backward compatibility;
            ``forward`` reads the actual height and width from its input, so a
            mismatched resolution no longer gets folded into the batch axis.
    """

    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 2, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        """Apply attention to a ``(batch, channels, height, width)`` tensor.

        Returns a tensor with the same shape as ``x``.
        """
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = x.flatten(2).swapaxes(1, 2)
        x_ln = self.ln(tokens)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + tokens
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).reshape(batch, channels, height, width)


class DoubleConv(nn.Module):
    """Two bias-free 3x3 convolutions, each followed by a single-group GroupNorm.

    A GELU sits between the two stages. With ``residual=True`` the input is
    added to the stack's output and passed through a final GELU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when not given (or falsy).
        residual: Whether to wrap the stack in a residual connection.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden = mid_channels or out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the conv stack, adding the residual path when enabled."""
        out = self.double_conv(x)
        if not self.residual:
            return out
        return F.gelu(x + out)


class Down(nn.Module):
    """Downsampling stage: 2x max-pool, two ``DoubleConv`` blocks, plus a time-embedding bias.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned feature map.
        emb_dim: Width of the time embedding passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        pool_then_conv = [
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        ]
        self.maxpool_conv = nn.Sequential(*pool_then_conv)

        # Projects the time embedding to one value per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        """Downsample ``x`` and add the projected embedding ``t`` at every spatial position."""
        features = self.maxpool_conv(x)
        height, width = features.shape[-2], features.shape[-1]
        projected = self.emb_layer(t)
        time_bias = projected[:, :, None, None].repeat(1, 1, height, width)
        return features + time_bias


class Up(nn.Module):
    """Upsampling stage: 2x bilinear upsample, skip concatenation, two ``DoubleConv`` blocks, time bias.

    Args:
        in_channels: Channel count after concatenating the skip tensor with
            the upsampled input.
        out_channels: Channels of the returned feature map.
        emb_dim: Width of the time embedding passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        refine = [
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        ]
        self.conv = nn.Sequential(*refine)

        # Projects the time embedding to one value per output channel.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        """Upsample ``x``, merge it with ``skip_x`` along channels, and add the projected embedding ``t``."""
        upsampled = self.up(x)
        merged = torch.cat([skip_x, upsampled], dim=1)
        features = self.conv(merged)
        height, width = features.shape[-2], features.shape[-1]
        projected = self.emb_layer(t)
        time_bias = projected[:, :, None, None].repeat(1, 1, height, width)
        return features + time_bias

class UNet(nn.Module):
    """U-Net noise predictor with self-attention after every down/up stage.

    The ``size`` values handed to the ``SelfAttention`` layers (64, 32, 16 on
    the way down; 32, 64, 128 on the way up) correspond to a 128x128 input,
    since each ``Down`` halves and each ``Up`` doubles the resolution.

    Args:
        c_in: Channels of the input image.
        c_out: Channels of the predicted noise.
        time_dim: Width of the sinusoidal timestep embedding; also forwarded
            to every ``Down``/``Up`` stage so the two always agree.
        device: Stored on the instance for backward compatibility. The
            positional encoding now follows the device of its input instead.
        num_classes: When given, a label embedding table is created so that
            ``forward`` can be called with class labels ``y``.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda", num_classes=None):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 32)
        self.down1 = Down(32, 64, emb_dim=time_dim)
        self.sa1 = SelfAttention(64, 64)
        self.down2 = Down(64, 128, emb_dim=time_dim)
        self.sa2 = SelfAttention(128, 32)
        self.down3 = Down(128, 128, emb_dim=time_dim)
        self.sa3 = SelfAttention(128, 16)

        self.bot1 = DoubleConv(128, 256)
        self.bot2 = DoubleConv(256, 256)
        self.bot3 = DoubleConv(256, 128)

        self.up1 = Up(256, 64, emb_dim=time_dim)
        self.sa4 = SelfAttention(64, 32)
        self.up2 = Up(128, 32, emb_dim=time_dim)
        self.sa5 = SelfAttention(32, 64)
        self.up3 = Up(64, 32, emb_dim=time_dim)
        self.sa6 = SelfAttention(32, 128)
        self.outc = nn.Conv2d(32, c_out, kernel_size=1)

        # Created last so the initialisation order of the layers above is
        # unchanged for unconditional models.
        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def pos_encoding(self, t, channels):
        """Return the sinusoidal encoding of timesteps.

        Args:
            t: Float tensor of shape ``(batch, 1)`` holding the timesteps.
            channels: Width of the encoding.

        Returns:
            Tensor of shape ``(batch, channels)``: sines in the first half,
            cosines in the second half.
        """
        # Build the frequencies on the input's device so the model works
        # wherever it has been moved, independent of ``self.device``.
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t, y=None):
        """Predict the noise contained in ``x`` at timesteps ``t``.

        Args:
            x: Batch of noisy images.
            t: 1-D tensor of integer timesteps, one per image.
            y: Optional class labels; requires the model to have been built
                with ``num_classes``.

        Returns:
            Output of the final 1x1 convolution, with ``c_out`` channels.

        Raises:
            ValueError: If ``y`` is given but no label embedding exists.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            if not hasattr(self, "label_emb"):
                raise ValueError(
                    "Class labels were passed but the model was built without num_classes."
                )
            t = t + self.label_emb(y)

        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        # Bottleneck
        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        # Decoder with skip connections from the encoder
        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output



In [4]:
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# def train(args, dataset):
#     setup_logging(args.run_name)
#     device = args.device
#     model = UNet(c_in=3, c_out=3).to(device)  # Adjust input/output channels for RGB
#     optimizer = optim.AdamW(model.parameters(), lr=args.lr)
#     mse = torch.nn.MSELoss()
#     diffusion = Diffusion(noise_steps=args.noise_steps, img_size=args.image_size, device=device)
#     dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

#     loss_history = []

#     for epoch in range(args.epochs):
#         pbar = tqdm(dataloader)
#         epoch_loss = 0
#         for i, (images,) in enumerate(pbar):
#             images = images.to(device)
#             t = diffusion.sample_timesteps(images.shape[0]).to(device)
#             x_t, noise = diffusion.noise_images(images, t)
#             predicted_noise = model(x_t, t)
#             loss = mse(noise, predicted_noise)

#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()

#             epoch_loss += loss.item()
#             pbar.set_postfix(loss=loss.item())

#         avg_loss = epoch_loss / len(dataloader)
#         loss_history.append(avg_loss)
#         sampled_images = diffusion.sample(model, n=images.shape[0])
#         save_images(sampled_images, f"results/{args.run_name}/{epoch}.jpg")

#     torch.save(model.state_dict(), f"models/{args.run_name}/model.pth")
#     return loss_history

def train(args, train_loader):
    """Train a 3-channel UNet noise predictor and return the per-epoch mean loss.

    After every epoch a batch of images is sampled and written to
    ``results/<run_name>/<epoch>.jpg``; the final weights are saved to
    ``models/<run_name>/model.pth``.

    Args:
        args: Namespace providing ``run_name``, ``device``, ``lr``,
            ``noise_steps``, ``image_size`` and ``epochs``. ``beta_start`` and
            ``beta_end`` are honoured when present and otherwise fall back to
            the ``Diffusion`` defaults.
        train_loader: Iterable yielding ``(images, labels)`` batches; the
            labels are ignored.

    Returns:
        List with the mean MSE loss of each epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    setup_logging(args.run_name)
    device = args.device
    # Tell the model which device it runs on so its timestep encoding is
    # built there too, instead of on the constructor's "cuda" default.
    model = UNet(c_in=3, c_out=3, device=device).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = torch.nn.MSELoss()
    # Forward the beta range so grid-searched schedules actually take effect.
    diffusion = Diffusion(
        noise_steps=args.noise_steps,
        beta_start=getattr(args, "beta_start", 1e-3),
        beta_end=getattr(args, "beta_end", 0.02),
        img_size=args.image_size,
        device=device,
    )

    loss_history = []

    for epoch in range(args.epochs):
        pbar = tqdm(train_loader)
        epoch_loss = 0
        num_batches = 0
        last_batch_size = 0
        for images, _ in pbar:
            images = images.to(device)
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
            last_batch_size = images.shape[0]
            pbar.set_postfix(loss=loss.item())

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; cannot train on an empty dataset.")

        avg_loss = epoch_loss / num_batches
        loss_history.append(avg_loss)
        # Sample as many images as the last training batch contained.
        sampled_images = diffusion.sample(model, n=last_batch_size)
        save_images(sampled_images, f"results/{args.run_name}/{epoch}.jpg")

    torch.save(model.state_dict(), f"models/{args.run_name}/model.pth")
    return loss_history


def test_model(model, diffusion, device, n):
    """Sample ``n`` images from ``model`` and save them to ``results/test_output.jpg``.

    Args:
        model: Trained noise-prediction model handed to ``diffusion.sample``.
        diffusion: Object providing ``sample(model, n)``.
        device: Unused; kept so existing call sites continue to work.
        n: Number of images to sample.
    """
    # The output folder may not exist yet when this runs without a prior
    # training run having created it.
    os.makedirs("results", exist_ok=True)
    model.eval()
    with torch.no_grad():
        sampled_images = diffusion.sample(model, n)
        save_images(sampled_images, "results/test_output.jpg")


In [5]:
import itertools
import torch
import plotly.graph_objects as go
import os

def grid_search(args, train_loader):
    """Train once per hyperparameter combination and report the best run.

    For each combination the fields ``lr``, ``batch_size``, ``noise_steps``,
    ``beta_start``, ``beta_end`` and ``run_name`` of ``args`` are overwritten
    in place before ``train`` is called.

    Args:
        args: Namespace passed on to ``train``; mutated as described above.
        train_loader: Loader used for training. When it exposes ``dataset``
            and its ``batch_size`` differs from the grid value, a fresh
            shuffled ``DataLoader`` with the grid's batch size is built for
            that run. Other settings of the original loader (workers,
            sampler, collate function) are not carried over.

    Returns:
        ``(results, best_run)``: ``results`` maps each run name to its loss
        history; ``best_run`` is the name of the run with the lowest epoch
        loss, or ``None`` if no run produced a loss below infinity.
    """
    from torch.utils.data import DataLoader

    learning_rates = [1e-4]
    batch_sizes = [32]
    noise_steps = [500]
    beta_starts = [5e-4]
    beta_ends = [0.02]

    grid = list(itertools.product(learning_rates, batch_sizes, noise_steps, beta_starts, beta_ends))
    results = {}
    best_run = None
    best_loss = float('inf')

    for i, (lr, bs, ns, beta_start, beta_end) in enumerate(grid):
        args.lr = lr
        args.batch_size = bs
        args.noise_steps = ns
        args.beta_start = beta_start
        args.beta_end = beta_end
        args.run_name = f"run_{i}_lr_{lr}_bs_{bs}_ns_{ns}_beta_{beta_start}_to_{beta_end}"

        # The incoming loader was built with a fixed batch size; rebuild it
        # when the grid asks for a different one so the run matches its name.
        run_loader = train_loader
        dataset = getattr(train_loader, "dataset", None)
        if dataset is not None and getattr(train_loader, "batch_size", None) != bs:
            run_loader = DataLoader(dataset, batch_size=bs, shuffle=True)

        loss_history = train(args, run_loader)
        results[args.run_name] = loss_history

        # An empty history (e.g. zero epochs) simply never becomes the best run.
        run_best = min(loss_history, default=float('inf'))
        if run_best < best_loss:
            best_loss = run_best
            best_run = args.run_name

    return results, best_run


def save_results_as_html(results, best_run, html_filename="grid_search_results.html"):
    """Plot every run's loss curve into one HTML file and print the best run.

    Args:
        results: Mapping of run name to its list of per-epoch losses.
        best_run: Name of the best run, or ``None`` when no run qualified.
        html_filename: Path of the HTML file to write.
    """
    fig = go.Figure()
    for run_name, losses in results.items():
        fig.add_trace(go.Scatter(y=losses, mode='lines+markers', name=run_name))

    fig.update_layout(title="Grid Search Results", xaxis_title="Epoch", yaxis_title="MSE Loss")
    fig.write_html(html_filename)

    # best_run can be None (no runs, or every loss was NaN), and a run may
    # have an empty history; report that instead of raising.
    best_losses = results.get(best_run) if best_run is not None else None
    if best_losses:
        print(f"Best run: {best_run} with loss: {min(best_losses)}")
    else:
        print("Best run: none (no run produced a usable loss history)")


In [7]:
import argparse
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

def load_image_dataset(image_folder, batch_size=32, image_size=128):
    """Build a shuffled ``DataLoader`` over an ``ImageFolder`` dataset.

    Each image is resized to ``image_size`` x ``image_size``, converted to a
    tensor, and normalised with mean 0.5 and std 0.5 on each of three channels.

    Args:
        image_folder: Root directory handed to ``datasets.ImageFolder``.
        batch_size: Number of images per batch.
        image_size: Side length the images are resized to.

    Returns:
        The ``DataLoader`` yielding batches from the dataset.
    """
    target_hw = (image_size, image_size)
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)
    pipeline = transforms.Compose([
        transforms.Resize(target_hw),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    image_dataset = datasets.ImageFolder(image_folder, transform=pipeline)
    return DataLoader(image_dataset, batch_size=batch_size, shuffle=True)

def main():
    """Parse the run options, load the dataset, run the grid search and save the report.

    The dataset folder comes from ``--data_dir``. Its default is taken from
    the ``DIFFUSION_DATA_DIR`` environment variable when set, and otherwise
    falls back to the original hard-coded location.

    Raises:
        FileNotFoundError: If the resolved data directory does not exist.
    """
    import os

    # Keep the historical default, but allow it to be overridden without
    # editing the source on machines where that path does not exist.
    default_data_dir = os.environ.get(
        "DIFFUSION_DATA_DIR",
        r"C:\Users\vimle\Desktop\Diffusion_repo\128x128Dataset\128x128",
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--image_size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=25)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--run_name", type=str, default="default")
    parser.add_argument("--data_dir", type=str, required=False, default=default_data_dir, help="Path to the folder containing images")

    # parse_known_args tolerates extra arguments that this parser does not define.
    args, _unknown = parser.parse_known_args()

    if not os.path.isdir(args.data_dir):
        raise FileNotFoundError(
            f"Data directory not found: {args.data_dir!r}. "
            "Pass --data_dir or set DIFFUSION_DATA_DIR."
        )

    # Load the image dataset from the folder
    train_loader = load_image_dataset(args.data_dir, batch_size=args.batch_size, image_size=args.image_size)

    # Run grid search for hyperparameter optimization
    results, best_run = grid_search(args, train_loader)

    # Save results as HTML
    save_results_as_html(results, best_run, "grid_search_results.html")

    print(f"Best run: {best_run}")
    print("Results saved to grid_search_results.html")
    print(f"Running with the following arguments: {args}")

# Entry point: start the full grid-search run when this code is executed
# with __name__ set to "__main__".
if __name__ == "__main__":
    main()


  0%|          | 0/255 [01:12<?, ?it/s]


KeyboardInterrupt: 