In [None]:
import src
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize
from datasets import load_dataset
from PIL import Image
import requests
from pathlib import Path
from torch.optim import Adam

In [None]:
# Download the sample COCO validation image used to illustrate the forward
# (noising) process in the cells below.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'

# The timeout keeps this cell from hanging indefinitely if the host is
# unreachable; raise_for_status() surfaces HTTP errors (404, 5xx, ...) here
# instead of letting PIL fail later while trying to decode an error page.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()

image = Image.open(response.raw)
image

In [None]:
image_size = 128

# Preprocessing pipeline for the demo image; each step feeds the next.
forward_transform = Compose(
    [
        Resize(image_size),
        CenterCrop(image_size),
        # PIL image -> float torch tensor laid out as CHW with values in [0, 1].
        ToTensor(),
        # Rescale from [0, 1] to [-1, 1].
        Lambda(lambda pixels: (pixels * 2) - 1),
    ]
)

# Prepend a batch dimension of size 1 to the transformed image.
x_start = torch.unsqueeze(forward_transform(image), dim=0)
x_start.shape

In [None]:
# Inverse of the forward preprocessing: turns a single CHW tensor with nominal
# range [-1, 1] back into a PIL image for display.
reverse_transform = Compose([
    Lambda(lambda t: (t + 1) / 2),  # [-1, 1] -> [0, 1]
    # Noised tensors can fall outside the nominal range. Without clamping, the
    # uint8 cast below wraps around (and is platform-dependent for negative
    # values), so out-of-range pixels would be rendered with arbitrary values.
    Lambda(lambda t: t.clamp(0, 1)),
    Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
    Lambda(lambda t: t * 255.),
    Lambda(lambda t: t.numpy().astype(np.uint8)),
    ToPILImage(),
])

# Drop only the batch dimension; a bare squeeze() would also drop a size-1
# channel dimension and break the permute above for single-channel images.
reverse_transform(x_start.squeeze(0))

In [None]:
# Hyperparameters handed to src.LinearScheduler: the first and last beta
# values and the number of diffusion timesteps.
beta_start, beta_end = 1e-4, 0.02
timesteps = 1000

scheduler = src.LinearScheduler(
    beta_start=beta_start,
    beta_end=beta_end,
    timesteps=timesteps,
)

In [None]:
# Forward (noising) process, built from the scheduler's precomputed
# coefficients. `reverse_transform` is passed in as well; presumably it is
# used to convert noised tensors back to images - confirm in src.
forward_diffusion = src.ForwardDiffusion(
    sqrt_alphas_cumprod=scheduler.sqrt_alphas_cumprod,
    sqrt_one_minus_alphas_cumprod=scheduler.sqrt_one_minus_alphas_cumprod,
    reverse_transform=reverse_transform,
)

In [None]:
# Show the original image next to noised versions of it at a handful of
# increasing timesteps.
src.plot(
    image=image,
    imgs=[
        forward_diffusion.get_noisy_image(x_start=x_start, t=torch.tensor([step]))
        for step in (0, 50, 100, 150, 199)
    ],
)

In [None]:
# Training-data configuration. Note that `image_size` is rebound here (it was
# 128 for the demo image above).
image_size = 28
num_channels = 1
batch_size = 128
dataset_name = 'fashion_mnist'
num_workers = 0  # DataLoader convention: 0 means "load in the main process"

# `num_workers` is also reused for load_dataset's `num_proc`, but the two
# parameters have different conventions: `num_proc` is a process count whose
# "no multiprocessing" value is None, and 0 is not a meaningful count.
# Map 0 -> None so the single-process setting is valid for both APIs.
dataset = load_dataset(dataset_name, num_proc=num_workers or None)

In [None]:
# Per-image preprocessing for training: random horizontal flip, conversion to
# a float tensor in [0, 1], then rescaling to [-1, 1].
forward_transform = Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1),
])


# This function must NOT be called `transforms`: that name is the
# torchvision.transforms module used just above, and rebinding it would make
# this cell fail with AttributeError the second time it is run.
def apply_forward_transform(examples):
    """Replace the "image" column of a batch with preprocessed "pixel_values".

    Args:
        examples: batch dict with an "image" entry holding an iterable of
            images that support ``.convert("L")`` (PIL-style).

    Returns:
        The same dict, mutated in place: "pixel_values" holds one transformed
        tensor per input image and the "image" key is removed.
    """
    examples["pixel_values"] = [
        forward_transform(img.convert("L")) for img in examples["image"]
    ]
    del examples["image"]

    return examples


# The transform is attached with `with_transform`, and the unused "label"
# column is dropped.
transformed_dataset = dataset.with_transform(apply_forward_transform).remove_columns("label")

dataloader = DataLoader(
    transformed_dataset["train"],
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

In [None]:
# Where results are written and how often (passed to src.train below as
# `save_folder` and `sample_and_save_freq`).
results_folder_name = './results'
sample_and_save_freq = 1000
results_folder = Path(results_folder_name)

# Create the output directory up front so a missing folder cannot cause a
# failure partway through training. Idempotent: safe to re-run, and harmless
# if src.train also creates it.
results_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Denoising network, sized for the training images configured above.
model = src.DDPM(
    n_features=image_size,
    in_channels=num_channels,
    channel_scale_factors=(1, 2, 4),
)
# Move the model to the selected device; as the last expression of the cell,
# the return value of `.to()` is also displayed.
model.to(device)

In [None]:
# Optimisation setup: Adam over all model parameters, with the loss function
# provided by src.
learning_rate = 1e-3
# Backward-compatible alias for the original, misspelled variable name in
# case other cells still reference it.
learninig_rate = learning_rate

optimizer = Adam(model.parameters(), lr=learning_rate)
criterion = src.get_loss

In [None]:
# NOTE(review): `scheduler` was constructed with timesteps=1000 in an earlier
# cell, while `timesteps` is rebound to 200 here. The sampler (and the
# training call that also receives `timesteps`) would then presumably only
# use the first 200 entries of a 1000-step schedule - confirm against
# src.Sampler / src.train whether this truncation is intended.
timesteps = 200

# Sampler for the reverse process, fed with the scheduler's precomputed terms.
sampler = src.Sampler(
    betas=scheduler.betas,
    sqrt_one_minus_alphas_cumprod=scheduler.sqrt_one_minus_alphas_cumprod,
    sqrt_one_by_alphas=scheduler.sqrt_one_by_alphas,
    posterior_variance=scheduler.posterior_variance,
    timesteps=timesteps,
)

In [None]:
epochs = 5

# Launch training. All arguments are passed by keyword, so they are grouped
# here by role rather than by position.
src.train(
    # data
    dataloader=dataloader,
    image_size=image_size,
    num_channels=num_channels,
    # diffusion process
    timesteps=timesteps,
    forward_diffusion_model=forward_diffusion,
    sampler=sampler,
    # model and optimisation
    denoising_model=model,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
    device=device,
    # periodic sampling / saving
    sample_and_save_freq=sample_and_save_freq,
    save_folder=results_folder,
)