## Import libraries

In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import v2
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np

## Dataset loading and preparation

In [None]:
# Folders holding the paired images; a pair is matched by identical file name in both folders.
original_dir = '/kaggle/input/original-and-retouched-faces-images-dataset/original'
retouched_dir = '/kaggle/input/original-and-retouched-faces-images-dataset/retouched'

num_images_to_load = 1000  # Upper bound on the number of image pairs used (before the train/val/test split)
batch_size = 4  # Batch size shared by the train/val/test DataLoaders and by TrainingConfig

In [None]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyper-parameters and output settings for the notebook.

    Every attribute carries a type annotation so that ``@dataclass`` registers it as a
    field. Un-annotated class attributes are ignored by the decorator, which left the
    generated ``__init__`` without parameters and the ``__repr__`` empty. With the
    annotations, values can be overridden per instance
    (``TrainingConfig(image_size=128)``) and the repr lists every setting.

    In the visible notebook only ``image_size`` is read (by the image transform).
    """

    image_size: int = 256  # the generated image resolution
    train_batch_size: int = batch_size
    eval_batch_size: int = batch_size  # how many images to sample during evaluation
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10
    save_model_epochs: int = 30
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "diffuserRet"  # the model name locally and on the HF Hub

    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    overwrite_output_dir: bool = False  # overwrite the old model when re-running the notebook
    seed: int = 42


# Instantiate the settings with their defaults and show the instance as the cell output.
config = TrainingConfig()
config

In [None]:
class PairedImageDataset(Dataset):
    """Dataset of (original, retouched) image pairs that share a file name.

    File names present in both folders are sorted, optionally truncated to
    ``num_images``, and split into train/val/test with a fixed random seed, so
    separate instances built with the same arguments see consistent, disjoint splits.

    Parameters:
        original_dir (str): Folder containing the original images.
        retouched_dir (str): Folder containing the retouched images.
        transform (callable, optional): Applied to each image of a pair.
        num_images (int, optional): Keep only the first ``num_images`` paired names.
        split (str): One of 'train', 'val' or 'test'.
        val_ratio (float): Fraction of all pairs used for validation.
        test_ratio (float): Fraction of all pairs used for testing.

    Raises:
        ValueError: If ``split`` is not 'train', 'val' or 'test'.
    """

    _SPLITS = ('train', 'val', 'test')

    def __init__(self, original_dir, retouched_dir, transform=None, num_images=None, split='train', val_ratio=0.1, test_ratio=0.1):
        # Fail fast on a bad split name, before scanning directories or splitting.
        if split not in self._SPLITS:
            raise ValueError("Split must be 'train', 'val', or 'test'")

        self.original_dir = original_dir
        self.retouched_dir = retouched_dir
        self.transform = transform

        # Keep only names present in both folders. The set makes each membership
        # test O(1) instead of a linear scan of the retouched list per file.
        self.retouched_images = sorted(os.listdir(retouched_dir))
        retouched_names = set(self.retouched_images)
        self.original_images = [img for img in sorted(os.listdir(original_dir)) if img in retouched_names]

        # Dynamically choose number of images to load
        if num_images is not None:
            self.original_images = self.original_images[:num_images]

        # Two-stage split: carve off the test set first, then take the validation set
        # from the remainder. val_ratio is rescaled so it stays a fraction of the full set.
        train_images, test_images = train_test_split(self.original_images, test_size=test_ratio, random_state=42)
        train_images, val_images = train_test_split(train_images, test_size=val_ratio / (1 - test_ratio), random_state=42)

        self.image_list = {'train': train_images, 'val': val_images, 'test': test_images}[split]

    def __len__(self):
        """Return the number of image pairs in the selected split."""
        return len(self.image_list)

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy, and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return the ``(original, retouched)`` pair at ``idx`` (transformed if a transform is set)."""
        file_name = self.image_list[idx]  # same name in both folders
        original_image = self._load_rgb(os.path.join(self.original_dir, file_name))
        retouched_image = self._load_rgb(os.path.join(self.retouched_dir, file_name))

        if self.transform:
            # NOTE(review): the transform is applied to each image separately, so a
            # random augmentation would presumably differ between the two images of a
            # pair. The deterministic transform used in this notebook is unaffected.
            original_image = self.transform(original_image)
            retouched_image = self.transform(retouched_image)

        return original_image, retouched_image

In [None]:
# Per-channel statistics for Normalize: with mean 0.5 and std 0.5 a value v becomes (v - 0.5) / 0.5.
mean, std = [0.5] * 3, [0.5] * 3

# Preprocessing shared by every split: resize to a square of side `config.image_size`,
# convert to a float32 image tensor, then normalise with the statistics above.
transform = v2.Compose(
    [
        v2.Resize(size=(config.image_size, config.image_size)),
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Normalize(mean=mean, std=std),
    ]
)

In [None]:
# Build the three splits with identical settings; only the `split` argument differs.
# The dataset splits with a fixed seed, so the three instances do not overlap.
train_dataset, val_dataset, test_dataset = [
    PairedImageDataset(
        original_dir,
        retouched_dir,
        transform=transform,
        num_images=num_images_to_load,
        split=split_name,
    )
    for split_name in ('train', 'val', 'test')
]

In [None]:
# One loader per split; only the training loader shuffles its samples.
train_loader, val_loader, test_loader = [
    DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    for dataset, shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
]

In [None]:
# Number of batches the training loader yields per pass.
len(train_loader)

In [None]:
# Check that images loaded correctly

def _show_image_row(batch, title):
    """Plot every image of ``batch`` side by side, each captioned with ``title``."""
    # Undo Normalize(mean=0.5, std=0.5). detach()/cpu() lets GPU or grad-tracking
    # tensors through, where a bare .numpy() would raise.
    # assumes batch is channel-first (N, C, H, W) — implied by the transpose below
    images = (batch.detach().cpu() / 2 + 0.5).numpy()

    count = len(images)
    if count == 0:
        return  # nothing to draw; plt.subplots rejects zero columns

    _, axes = plt.subplots(1, count, figsize=(count * 3, 2.5))
    if count == 1:
        axes = [axes]  # subplots returns a bare Axes for a single column

    for ax, img in zip(axes, images):
        ax.imshow(np.transpose(img, (1, 2, 0)))  # channel-first -> channel-last for matplotlib
        ax.set_title(title)
        ax.axis('off')
    plt.show()


def imshow_batch(orig, retouch):
    """Show a batch of original images and, in a second figure, their retouched versions.

    Parameters:
        orig (torch.Tensor): Batch of normalised original images.
        retouch (torch.Tensor): Batch of normalised retouched images.
    """
    _show_image_row(orig, "original")
    _show_image_row(retouch, "retouched")


# Preview only the first training batch. `labels` holds the retouched images here,
# because the dataset yields (original, retouched) pairs.
for images, labels in train_loader:
    imshow_batch(images, labels)
    break

In [None]:
! pip install diffusers

In [None]:
# Pick the compute device in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
! pip install invisible_watermark transformers accelerate safetensors

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

In [None]:
from diffusers import UNet2DConditionModel, DDPMScheduler
import torch
from torch.optim import Adam

# Load the pretrained "google/ddpm-celebahq-256" pipeline and move it to `device`.
# The training loop later fine-tunes this pipeline's UNet (`image_pipe.unet`).
image_pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
image_pipe.to(device)

In [None]:
! export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

In [None]:
# The original cell assigned an undefined `scheduler` (NameError on a fresh kernel; on a
# stale kernel it could pick up the LR scheduler that a later cell binds to that name).
# Build the faster DDIM sampler explicitly from the pipeline's own scheduler config.
ddim_scheduler = DDIMScheduler.from_config(image_pipe.scheduler.config)
image_pipe.scheduler = ddim_scheduler

# Sample with 60 inference steps and display the first generated image.
images = image_pipe(num_inference_steps=60).images
images[0]

In [None]:
# Authenticate with Weights & Biases before the training setup cell calls wandb.init.
# NOTE(review): presumably this needs interactive input or a pre-configured API key —
# confirm it works in the target environment.
! wandb login

In [None]:
import wandb
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt

# The training loop saves under `model_save_name`; only the misspelled
# `model_traind_name` used to be defined, so saving raised NameError.
model_save_name = 'retDiff'
model_traind_name = model_save_name  # original spelling, kept as an alias
num_epochs = 20  # @param
lr = 1e-5  # @param
grad_accumulation_steps = 2  # @param
log_samples_every = 10  # in training steps
save_model_every = 20  # in training steps

# The learning rate actually given to the optimizer is scaled by the number of accumulated steps.
adjusted_lr = lr * grad_accumulation_steps

wandb.init(project='retDiff', config={
    "learning_rate": lr,
    "architecture": "diffuser",
    "epochs": num_epochs,
    "grad_accumulation_steps": grad_accumulation_steps
    })

losses = []

# Single optimizer over the UNet weights (it used to be built twice; the first instance
# was discarded immediately).
optimizer = torch.optim.AdamW(image_pipe.unet.parameters(), lr=adjusted_lr)

# Named `lr_scheduler` so it cannot be confused with, or assigned as, the diffusion noise
# scheduler that other cells call `scheduler`.
# NOTE(review): the training loop in this notebook never calls `.step()` on it, so the
# learning rate stays constant unless that call is added.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

In [None]:
# `model_save_name` was used below but never defined (the setup cell spells it
# `model_traind_name`), so the first checkpoint save raised NameError.
model_save_name = model_traind_name

# NOTE(review): the loop iterates `train_dataset`, not `train_loader`, so each step sees a
# single un-shuffled image pair (hence the unsqueeze below) — confirm this is a deliberate
# memory-saving choice.
for epoch in range(num_epochs):
    for step, batch in tqdm(enumerate(train_dataset), total=len(train_dataset)):
        # The dataset yields (original, retouched); the UNet is fine-tuned on the retouched image.
        clean_images = batch[1].to(device)
        clean_images = clean_images.unsqueeze(0) if clean_images.ndim == 3 else clean_images

        # Sample noise to add to the images (created directly on the images' device)
        noise = torch.randn_like(clean_images)
        bs = clean_images.shape[0]

        # Sample a random timestep for each image. The bound is read from the scheduler
        # config; direct attribute access on the scheduler is deprecated in diffusers.
        timesteps = torch.randint(
            0,
            image_pipe.scheduler.config.num_train_timesteps,
            (bs,),
            device=clean_images.device,
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        noisy_images = image_pipe.scheduler.add_noise(clean_images, noise, timesteps)

        # Get the model prediction for the noise
        noise_pred = image_pipe.unet(noisy_images, timesteps, return_dict=False)[0]

        # Compare the prediction with the actual noise:
        loss = F.mse_loss(noise_pred, noise)

        # Store for later plotting
        loss_value = loss.item()
        losses.append(loss_value)
        wandb.log({'loss': loss_value})

        # Accumulate gradients for this step
        loss.backward()

        # Gradient accumulation: only update every `grad_accumulation_steps` steps
        if (step + 1) % grad_accumulation_steps == 0:
            optimizer.step()  # Update model parameters
            optimizer.zero_grad()  # Reset gradients

        if (step+1)%log_samples_every == 0:
            # Generate `bs` samples from pure noise with the scheduler's current timesteps
            # and log them as one image grid.
            x = torch.randn(bs, 3, 256, 256, device=device)
            with torch.no_grad():
                for t in image_pipe.scheduler.timesteps:
                    model_input = image_pipe.scheduler.scale_model_input(x, t)
                    noise_pred = image_pipe.unet(model_input, t)["sample"]
                    x = image_pipe.scheduler.step(noise_pred, t, x).prev_sample
            grid = torchvision.utils.make_grid(x, nrow=4)
            # Map from [-1, 1] back to [0, 1] before converting to 8-bit
            im = grid.permute(1, 2, 0).cpu().clip(-1, 1)*0.5 + 0.5
            im = Image.fromarray(np.array(im*255).astype(np.uint8))
            wandb.log({'Sample generations': wandb.Image(im)})
        if (step+1)%save_model_every == 0:
            image_pipe.save_pretrained(model_save_name+'_latest')
    # Calculate average loss for the epoch (the last len(train_dataset) logged values)
    avg_loss = sum(losses[-len(train_dataset):]) / len(train_dataset)
    print(f"Epoch {epoch} average loss: {avg_loss}")

image_pipe.save_pretrained(model_save_name)
wandb.finish()

In [None]:
# Show the pipeline object as the cell output.
image_pipe

In [None]:
# Sample from the pipeline again (60 inference steps) and display the first image.
images = image_pipe(num_inference_steps=60).images
images[0]

In [None]:
# Sets CUDA_LAUNCH_BLOCKING=1 in the kernel's environment.
# NOTE(review): presumably a debugging aid; it likely has to be set before CUDA is first
# used to take effect — confirm, or remove if no longer needed.
%env CUDA_LAUNCH_BLOCKING=1

In [None]:
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from diffusers import DDPMScheduler, UNet2DModel

# Reuse the pipeline's UNet and scheduler for the manual noise/denoise experiments below.
# Nothing is loaded here: both names alias objects owned by `image_pipe`, so the
# set_timesteps call also changes the scheduler the pipeline itself uses.
model =  image_pipe.unet # the UNet held by the pipeline
scheduler = image_pipe.scheduler
scheduler.set_timesteps(num_inference_steps=60)

# Load and preprocess the image
def preprocess_image(image_path, image_size=256, normalize=False):
    """Load an image file as a ``(1, 3, image_size, image_size)`` tensor on ``device``.

    Parameters:
        image_path (str): Path of the image file to load.
        image_size (int): Side length the image is resized to.
        normalize (bool): If True, map values from [0, 1] to [-1, 1] with
            mean 0.5 / std 0.5, matching the transform used for the datasets in
            this notebook. The default (False) keeps the original [0, 1] output.

    Returns:
        torch.Tensor: The preprocessed image with a leading batch dimension.
    """
    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize([0.5] * 3, [0.5] * 3))

    # The context manager closes the file handle; convert() returns an independent copy.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return transforms.Compose(steps)(image).unsqueeze(0).to(device)

# Display image
def show_image(tensor_image, title="Image", value_range=(-1.0, 1.0)):
    """Display a channel-first image tensor with matplotlib.

    The tensors shown in this notebook come from the normalised dataset or from the
    model and therefore live in [-1, 1]. They were previously passed to ``imshow``
    unchanged, which expects floats in [0, 1] and clips everything else.

    Parameters:
        tensor_image (torch.Tensor): Image of shape (C, H, W) or (1, C, H, W).
        title: Caption of the figure.
        value_range (tuple): ``(low, high)`` range of the tensor values. Values are
            clamped to it and rescaled to [0, 1]. Pass ``(0, 1)`` for tensors that
            were never normalised.
    """
    low, high = value_range
    image = tensor_image.detach().squeeze(0).permute(1, 2, 0).cpu().float()
    image = ((image.clamp(low, high) - low) / (high - low)).numpy()
    plt.imshow(image)
    plt.title(title)
    plt.axis("off")
    plt.show()

# Add noise to an image
def add_noise(image, scheduler, noise_level=0.5):
    """Noise ``image`` at the timestep that corresponds to ``noise_level``.

    Parameters:
        image (torch.Tensor): Batch of images; the first dimension is the batch size.
        scheduler: Object providing ``add_noise(image, noise, timesteps)`` and
            ``num_train_timesteps`` (read from ``scheduler.config`` when present).
        noise_level (float): Fraction of the training schedule to noise to.

    Returns:
        tuple: ``(noisy_image, timesteps)`` where ``timesteps`` is a 1-D long tensor
        holding the same timestep for every image in the batch.
    """
    # Prefer the config value: direct attribute access on diffusers schedulers is deprecated.
    num_train_timesteps = getattr(scheduler, "config", scheduler).num_train_timesteps

    # int(N * noise_level) reaches N for noise_level == 1.0, one past the last valid
    # timestep, so clamp into [0, N - 1].
    timestep = min(max(int(num_train_timesteps * noise_level), 0), num_train_timesteps - 1)

    noise = torch.randn_like(image)
    timesteps = torch.full((image.shape[0],), timestep, dtype=torch.long, device=image.device)
    noisy_image = scheduler.add_noise(image, noise, timesteps)
    return noisy_image, timesteps

def denoise_image(noisy_image, scheduler, model, start_timestep=None):
    """Iteratively denoise ``noisy_image`` by stepping through ``scheduler.timesteps``.

    Parameters:
        noisy_image (torch.Tensor): Batch of noised images.
        scheduler: Scheduler whose ``timesteps`` are iterated and whose ``step``
            produces the next, less noisy sample.
        model: Network called as ``model(sample, t)``; its ``.sample`` output is
            used as the noise prediction.
        start_timestep (int or torch.Tensor, optional): Timestep the input was noised
            at, e.g. the ``timesteps`` returned by ``add_noise``. When given, schedule
            entries above it are skipped so denoising starts at the matching noise
            level instead of treating the input as pure noise. ``None`` (default)
            keeps the previous behaviour of walking the whole schedule.

    Returns:
        torch.Tensor: The sample after the last scheduler step.
    """
    model.eval()  # note: the model is left in eval mode afterwards

    schedule = scheduler.timesteps
    if start_timestep is not None:
        start = int(torch.as_tensor(start_timestep).max())
        schedule = [t for t in schedule if int(t) <= start]

    with torch.no_grad():
        for t in tqdm(schedule):
            # Predict noise from the noisy image
            noise_pred = model(noisy_image, t).sample
            # Denoise step
            noisy_image = scheduler.step(noise_pred, t, noisy_image).prev_sample
    return noisy_image


# Fetch one example pair from the test split. The dataset is indexed once: every access
# re-opens and re-transforms both image files.
original_image, target_imgage = test_dataset[20]
original_image = original_image.to(device)
# Display original and target images
show_image(original_image, title="Original Image")
show_image(target_imgage, title="Target Image")
# Add the batch dimension expected by the model (no noise is added in this cell)
original_image = original_image.unsqueeze(0) if original_image.ndim == 3 else original_image

In [None]:
# Add noise to the original image
# Sweep ten noise levels from 1% to 50% of the training schedule, showing the noised
# input and the denoised result for each.
# NOTE(review): denoise_image is called without the timestep used for noising, so it walks
# the full 6-step schedule regardless of `noise_level` — confirm that is intended.
for noise_level in np.linspace(0.01, 0.5, num=10):
    print(noise_level)
    scheduler.set_timesteps(num_inference_steps=6)

    noisy_image, timesteps = add_noise(original_image, scheduler, noise_level=noise_level)
    show_image(noisy_image, title="Noisy Image")

    # Denoise the noisy image and display the result
    denoised_image = denoise_image(noisy_image, scheduler, model)
    show_image(denoised_image, title="Denoised Image")

In [None]:
import os
outputs = 'outputs'
# makedirs(exist_ok=True) replaces the check-then-create pair: no race between the
# check and the mkdir, and the cell stays safe to re-run.
os.makedirs(outputs, exist_ok=True)
def save_image(tensor_image, file_path, value_range=(-1.0, 1.0)):
    """
    Save a tensor image as an image file (PNG when ``file_path`` ends in .png).

    The images saved in this notebook come from the model, which works on data
    normalised to [-1, 1]. The previous ``clamp(0, 1)`` discarded the negative half of
    that range; values are now rescaled from ``value_range`` to [0, 1] first.

    Parameters:
        tensor_image (torch.Tensor): The image tensor to save, (C, H, W) or (1, C, H, W).
        file_path (str): The path where the image will be saved.
        value_range (tuple): ``(low, high)`` range of the tensor values. Values are
            clamped to it before rescaling. Pass ``(0, 1)`` for un-normalised tensors.
    """
    low, high = value_range

    # Channel-first tensor -> channel-last float array in [0, 1]
    image = tensor_image.detach().squeeze(0).permute(1, 2, 0).cpu().float()
    image = ((image.clamp(low, high) - low) / (high - low)).numpy()

    # Convert to 8-bit; round instead of truncating so 1.0 maps to 255 reliably
    image = (image * 255).round().astype(np.uint8)
    pil_image = Image.fromarray(image)

    # Save the image
    pil_image.save(file_path)

# Noise every original test image at noise level 0.2, denoise it with the model, then
# save and display the result. The retouched target of each pair is ignored.
for step, (image, _) in tqdm(enumerate(test_dataset), total=len(test_dataset)):
    scheduler.set_timesteps(num_inference_steps=6)

    original_image = image.to(device)
    if original_image.ndim == 3:
        original_image = original_image.unsqueeze(0)  # add the batch dimension

    noisy_image, timesteps = add_noise(original_image, scheduler, noise_level=0.2)
    denoised_image = denoise_image(noisy_image, scheduler, model)

    output_path = os.path.join(outputs, f'img{step}.png')
    save_image(denoised_image, output_path)
    show_image(denoised_image, title=step)

In [None]:
# Archive the generated images into file.zip.
# NOTE(review): the path is absolute while the images are written to the relative
# 'outputs' folder — this assumes the working directory is /kaggle/working; confirm.
!zip -r file.zip /kaggle/working/outputs