In [1]:
import os
import zipfile
import numpy as np
import cv2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler
import lpips
from pytorch_msssim import ssim
from torchvision.transforms.functional import to_tensor
import gc
from matplotlib import pyplot as plt

In [2]:
# Directory holding the evaluation split. The dataset classes defined in this
# notebook look up "<id>.jpg", "<id>_gray.jpg", "<id>_mask.jpg" and "<id>.txt"
# inside it.
data_folder = './test_split'

In [3]:
# Ground-truth images and masks: plain PIL -> tensor conversion.
transform = transforms.Compose([transforms.ToTensor()])

# Grayscale inputs: replicate to three channels first, then convert to a tensor.
grayscale_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
)

class GrayscaleDataset(Dataset):
    """Dataset whose conditioning input is the grayscale image itself.

    For every id found in ``data_folder`` the sample is built from four files:
    ``<id>.jpg`` (ground truth), ``<id>_gray.jpg`` (grayscale input),
    ``<id>_mask.jpg`` (mask) and ``<id>.txt`` (text prompt).

    Args:
        data_folder: Directory containing the files listed above.
        transform: Optional callable applied to the ground-truth image and the mask.
        grayscale_transform: Optional callable applied to the grayscale image.
        extra_prompt: When true, ``"Colorize the whole image. "`` is prepended
            to every prompt.
    """

    def __init__(self, data_folder, transform=None, grayscale_transform=None, extra_prompt=False):
        self.data_folder = data_folder
        self.transform = transform
        self.grayscale_transform = grayscale_transform
        self.extra_prompt = extra_prompt
        # An id is the text before the first underscore of a file name, kept
        # only when it is all digits (e.g. "12_gray.jpg" -> "12"; "12.jpg" does
        # not match because the extension is part of the prefix).
        prefixes = {name.split('_')[0] for name in os.listdir(data_folder)}
        # Sort numerically: iterating a set of strings gives an order that can
        # change from run to run, which made sample order non-reproducible.
        self.ids = sorted((p for p in prefixes if p.isdigit()), key=lambda p: (int(p), p))
        print("Collected IDs:", self.ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return a dict with keys 'image', 'input_image', 'gray_image', 'mask', 'text'."""
        image_id = self.ids[idx]

        image_path = os.path.join(self.data_folder, f"{image_id}.jpg")
        image = Image.open(image_path).convert('RGB')

        gray_image_path = os.path.join(self.data_folder, f"{image_id}_gray.jpg")
        gray_image = Image.open(gray_image_path)

        mask_path = os.path.join(self.data_folder, f"{image_id}_mask.jpg")
        mask = Image.open(mask_path).convert('L')

        text_path = os.path.join(self.data_folder, f"{image_id}.txt")
        # Explicit encoding so prompts decode the same way on every platform.
        with open(text_path, 'r', encoding='utf-8') as file:
            text_data = file.read().strip()

        if self.extra_prompt:
            text_data = "Colorize the whole image. " + text_data

        # Each transform is applied only if it was supplied; previously a
        # missing grayscale_transform raised "NoneType is not callable".
        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)
        if self.grayscale_transform is not None:
            gray_image = self.grayscale_transform(gray_image)

        # The grayscale image doubles as the conditioning input for this dataset.
        return {'image': image, 'input_image': gray_image, 'gray_image': gray_image, 'mask': mask, 'text': text_data}


class CannyDataset(Dataset):
    """Dataset whose conditioning input is a Canny edge map of the grayscale image.

    For every id found in ``data_folder`` the sample is built from four files:
    ``<id>.jpg`` (ground truth), ``<id>_gray.jpg`` (grayscale image, also the
    source of the edge map), ``<id>_mask.jpg`` (mask) and ``<id>.txt`` (prompt).

    Args:
        data_folder: Directory containing the files listed above.
        transform: Optional callable applied to the ground truth, edge map and mask.
        grayscale_transform: Optional callable applied to the grayscale image.
        extra_prompt: When true, ``"Colorize the whole image. "`` is prepended
            to every prompt (previously this argument was accepted but ignored).
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.
    """

    def __init__(self, data_folder, transform=None, grayscale_transform=None, extra_prompt=False,
                 low_threshold=100, high_threshold=200):
        self.data_folder = data_folder
        self.transform = transform
        self.grayscale_transform = grayscale_transform
        self.extra_prompt = extra_prompt
        self.low_threshold = low_threshold
        self.high_threshold = high_threshold
        # An id is the text before the first underscore of a file name, kept
        # only when it is all digits (e.g. "12_gray.jpg" -> "12").
        prefixes = {name.split('_')[0] for name in os.listdir(data_folder)}
        # Sort numerically: iterating a set of strings gives an order that can
        # change from run to run, which made sample order non-reproducible.
        self.ids = sorted((p for p in prefixes if p.isdigit()), key=lambda p: (int(p), p))
        print("Collected IDs:", self.ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return a dict with keys 'image', 'input_image', 'gray_image', 'mask', 'text'."""
        image_id = self.ids[idx]

        image_path = os.path.join(self.data_folder, f"{image_id}.jpg")
        image = Image.open(image_path).convert('RGB')

        gray_image_path = os.path.join(self.data_folder, f"{image_id}_gray.jpg")
        gray_image = Image.open(gray_image_path)

        # Edge detection yields a single-channel map; replicate it to three
        # channels so it can be wrapped as an RGB-shaped PIL image.
        edges = cv2.Canny(np.array(gray_image), self.low_threshold, self.high_threshold)
        canny_image = Image.fromarray(np.repeat(edges[:, :, None], 3, axis=2))

        mask_path = os.path.join(self.data_folder, f"{image_id}_mask.jpg")
        mask = Image.open(mask_path).convert('L')

        text_path = os.path.join(self.data_folder, f"{image_id}.txt")
        # Explicit encoding so prompts decode the same way on every platform.
        with open(text_path, 'r', encoding='utf-8') as file:
            text_data = file.read().strip()

        if self.extra_prompt:
            text_data = "Colorize the whole image. " + text_data

        # Each transform is applied only if it was supplied; previously a
        # missing grayscale_transform raised "NoneType is not callable".
        if self.transform is not None:
            image = self.transform(image)
            canny_image = self.transform(canny_image)
            mask = self.transform(mask)
        if self.grayscale_transform is not None:
            gray_image = self.grayscale_transform(gray_image)

        return {'image': image, 'input_image': canny_image, 'gray_image': gray_image, 'mask': mask, 'text': text_data}



def get_loader(grayscale, extra_prompt=False, batch_size=8):
    """Build a non-shuffling DataLoader over the evaluation folder.

    Args:
        grayscale: If true, use ``GrayscaleDataset`` (grayscale image as the
            conditioning input); otherwise use ``CannyDataset`` (edge map).
        extra_prompt: Forwarded to ``GrayscaleDataset`` only; the Canny branch
            does not receive it.
        batch_size: Number of samples per batch (defaults to the previously
            hard-coded value of 8).

    Returns:
        A ``DataLoader`` that yields batches in dataset order.
    """
    if grayscale:
        dataset = GrayscaleDataset(data_folder, transform=transform, grayscale_transform=grayscale_transform, extra_prompt=extra_prompt)
    else:
        dataset = CannyDataset(data_folder, transform=transform, grayscale_transform=grayscale_transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [4]:
def get_pipe(grayscale=False):
    """Load a Stable Diffusion 1.5 ControlNet pipeline onto the GPU.

    ``grayscale=True`` selects the brightness ControlNet, otherwise the Canny
    ControlNet is used. The pipeline's safety checker is replaced by a
    pass-through that returns the images unchanged and flags none of them.
    """
    if grayscale:
        controlnet_id = "latentcat/control_v1p_sd15_brightness"
    else:
        controlnet_id = "lllyasviel/sd-controlnet-canny"

    controlnet = ControlNetModel.from_pretrained(controlnet_id, use_safetensors=True)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        use_safetensors=True,
    ).to("cuda")

    def _pass_through(images, clip_input):
        # Same contract as the safety checker: (images, per-image flag list).
        return images, [False] * len(images)

    pipe.safety_checker = _pass_through
    return pipe

In [5]:
def adjust_gamma(image, gamma=1.0):
    """Apply gamma correction to an 8-bit image through a 256-entry lookup table.

    Args:
        image: Image array passed to ``cv2.LUT`` (expected to be 8-bit).
        gamma: Positive gamma value; the table maps ``v`` to
            ``255 * (v / 255) ** (1 / gamma)``, truncated to ``uint8``.

    Returns:
        The image remapped through the lookup table.

    Raises:
        ValueError: If ``gamma`` is not strictly positive (a zero gamma used to
            fail with ZeroDivisionError and a negative one produced inf values).
    """
    if gamma <= 0:
        raise ValueError(f"gamma must be positive, got {gamma}")

    # Build the whole table in one vectorised expression instead of a Python loop.
    levels = np.arange(0, 256)
    table = (((levels / 255.0) ** (1.0 / gamma)) * 255).astype("uint8")

    return cv2.LUT(image, table)

def brighten_image(image, brightness=100):
    """Add a constant brightness offset to an 8-bit image.

    Args:
        image: ``uint8`` image array; ``cv2.add`` is called with a ``uint8``
            offset array, so other dtypes are not supported.
        brightness: Offset to add; clipped to the range [0, 255].

    Returns:
        The result of ``cv2.add(image, offset)``.
    """
    # np.clip returns a NumPy scalar; convert it to a plain int so that the
    # offset array below keeps dtype uint8 under every NumPy promotion rule
    # (multiplying a uint8 array by a NumPy int64/float64 scalar can upcast,
    # and cv2.add rejects operands of different types).
    brightness = int(np.clip(brightness, 0, 255))

    # Offset array with the same shape as the image, filled with the brightness value.
    brightness_matrix = np.full(image.shape, brightness, dtype=np.uint8)

    return cv2.add(image, brightness_matrix)


def postprocess(grayscale, colored):
    """Keep the luminance of the grayscale input and the chroma of the colored image.

    Both images are converted to 8-bit Lab; the result takes the L channel from
    the grayscale image and the a/b channels from the colored one.

    Args:
        grayscale: Channel-first array; only channel 0 is used. Floating-point
            input is assumed to be in [0, 1] (as produced by ``ToTensor``) and
            is rescaled to 8-bit; ``uint8`` input is used as is.
        colored: ``uint8`` RGB image in (H, W, 3) layout.

    Returns:
        ``uint8`` RGB image of shape (H, W, 3).

    Raises:
        ValueError: If the two images do not have the same height and width.
    """
    luminance = np.asarray(grayscale)[0]

    # OpenCV uses different Lab scales for float and 8-bit images (L in 0..100
    # versus 0..255). Feeding a float luminance and then mixing its L channel
    # with 8-bit a/b channels produced a much darker image, so bring the
    # luminance to uint8 before converting.
    if np.issubdtype(luminance.dtype, np.floating):
        luminance = np.rint(np.clip(luminance, 0.0, 1.0) * 255.0).astype(np.uint8)
    elif luminance.dtype != np.uint8:
        luminance = np.clip(luminance, 0, 255).astype(np.uint8)
    luminance = np.ascontiguousarray(luminance)

    colored = np.asarray(colored)
    if colored.shape[:2] != luminance.shape[:2]:
        raise ValueError(
            f"grayscale {luminance.shape[:2]} and colored {colored.shape[:2]} images must have the same size"
        )

    # Grayscale image -> Lab (via BGR, since there is no direct gray -> Lab code).
    gray_bgr = cv2.cvtColor(luminance, cv2.COLOR_GRAY2BGR)
    gray_lab = cv2.cvtColor(gray_bgr, cv2.COLOR_BGR2Lab)

    # Colored image -> Lab.
    colored_bgr = cv2.cvtColor(colored, cv2.COLOR_RGB2BGR)
    colored_lab = cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2Lab)

    # L from the grayscale image, a and b from the colored image.
    combined_lab = np.concatenate((gray_lab[:, :, 0:1], colored_lab[:, :, 1:]), axis=2)

    return cv2.cvtColor(combined_lab.astype('uint8'), cv2.COLOR_Lab2RGB)

def adjust_intensity(gray, colored):
    """Rescale each pixel of ``colored`` so its luma matches that of ``gray``.

    Luma is computed as ``0.299 * c0 + 0.587 * c1 + 0.114 * c2``.

    Args:
        gray: Channel-first (3, H, W) array giving the target intensities.
        colored: Channel-last (H, W, 3) array to be rescaled. Both inputs are
            expected to use the same value scale; the result is clipped to
            [0, 255].

    Returns:
        ``uint8`` array of shape (H, W, 3).
    """
    target = np.transpose(np.asarray(gray, dtype=np.float64), (1, 2, 0))
    source = np.asarray(colored, dtype=np.float64)

    # Luma of the source and target images.
    source_intensity = 0.299 * source[..., 0] + 0.587 * source[..., 1] + 0.114 * source[..., 2]
    target_intensity = 0.299 * target[..., 0] + 0.587 * target[..., 1] + 0.114 * target[..., 2]

    # Per-pixel scale factor. Pixels whose source luma is zero are skipped here:
    # dividing by zero used to yield inf/NaN that was then cast to uint8.
    has_signal = source_intensity > 0
    k = np.divide(target_intensity, source_intensity,
                  out=np.zeros_like(target_intensity), where=has_signal)

    adjusted = source * k[..., np.newaxis]

    # A black source pixel carries no chroma to scale, so fall back to a neutral
    # gray at the target intensity.
    no_signal = ~has_signal
    adjusted[no_signal] = target_intensity[no_signal][:, np.newaxis]

    return np.clip(adjusted, 0, 255).astype('uint8')

def color_transfer(gray, colored):
    """Give ``gray`` the per-channel mean and standard deviation of ``colored``.

    Each channel of the target (``gray``) is normalised with its own mean and
    standard deviation and then rescaled with the statistics of the matching
    channel of the source (``colored``).

    The previous version computed the target statistics but never used them,
    and discarded the normalised source channels, so the target was scaled
    without being normalised first.

    Args:
        gray: Channel-first (C, H, W) array, the image whose structure is kept.
        colored: Channel-last (H, W, C) array supplying the colour statistics.

    Returns:
        ``uint8`` array of shape (H, W, C), clipped to [0, 255].
    """
    target = np.transpose(gray, (1, 2, 0)).astype('float32')
    source = np.asarray(colored).astype('float32')

    # Per-channel statistics over the two spatial axes.
    source_mean = source.mean(axis=(0, 1))
    source_std = source.std(axis=(0, 1))
    target_mean = target.mean(axis=(0, 1))
    target_std = target.std(axis=(0, 1))

    # A constant target channel has zero spread; divide by 1 instead so the
    # channel collapses to the source mean rather than to NaN.
    safe_target_std = np.where(target_std > 0, target_std, 1.0)

    # Normalise the target, then apply the source statistics.
    transfer = (target - target_mean) / safe_target_std * source_std + source_mean

    return np.clip(transfer, 0, 255).astype('uint8')

In [6]:
# LPIPS networks already built, keyed by device string, so that the VGG weights
# are loaded once instead of on every metric call.
_LPIPS_MODELS = {}


def _get_lpips_model(device):
    """Return the cached LPIPS (VGG) network for ``device``, building it on first use."""
    key = str(device)
    if key not in _LPIPS_MODELS:
        _LPIPS_MODELS[key] = lpips.LPIPS(net='vgg').to(device)
    return _LPIPS_MODELS[key]


def _to_batch(value, device):
    """Convert an image-like value to a 4-D tensor on ``device``."""
    if not isinstance(value, torch.Tensor):
        value = to_tensor(value)
    value = value.to(device)
    if value.dim() == 3:
        value = value.unsqueeze(0)
    return value


def calculate_metrics(generated, original, mask):
    """Compare a generated image with its ground truth.

    Args:
        generated: Generated image (tensor, PIL image or HWC array).
        original: Ground-truth image, same accepted types.
        mask: Mask image, same accepted types; only its first channel is used.
            Image values are expected in [0, 1] once converted to tensors.

    Returns:
        Tuple ``(full_mse, masked_mse, ssim, lpips)`` of Python floats.
        ``masked_mse`` is the MSE between the two images after multiplying both
        by the mask; it is averaged over all pixels, not only the masked ones.
    """
    # Everything is moved to one device; previously tensor inputs stayed where
    # they were while converted inputs and the LPIPS network went to CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    with torch.no_grad():
        generated = _to_batch(generated, device)
        original = _to_batch(original, device)
        mask = _to_batch(mask, device)

        # Single-channel images are replicated to three channels.
        if generated.shape[1] == 1:
            generated = generated.repeat(1, 3, 1, 1)
        if original.shape[1] == 1:
            original = original.repeat(1, 3, 1, 1)

        mse_loss = torch.nn.MSELoss()
        full_mse = mse_loss(generated, original)

        if mask.shape[1] != 1:
            mask = mask[:, 0:1, :, :]

        masked_mse = mse_loss(generated * mask, original * mask)

        ssim_val = ssim(generated, original, data_range=1, size_average=True)

        # LPIPS works on inputs in [-1, 1]; normalize=True tells it the inputs
        # are in [0, 1] so that it rescales them itself.
        lpips_val = _get_lpips_model(device)(generated, original, normalize=True)

    return full_mse.item(), masked_mse.item(), ssim_val.item(), lpips_val.item()

def run_inference(data_loader, model, guidance_scale=4.0, post=False, device='cuda'):
    """Run the pipeline over a data loader and score every generated image.

    Args:
        data_loader: Iterable of batches with keys 'text', 'input_image',
            'gray_image', 'image' and 'mask'.
        model: Callable pipeline invoked as
            ``model(prompt=..., image=..., guess_mode=False, guidance_scale=...)``
            whose result exposes ``.images``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        post: When true, each generated image is passed through ``postprocess``
            together with its grayscale input before scoring.
        device: Device the batch tensors are moved to (defaults to the
            previously hard-coded ``'cuda'``).

    Returns:
        List of ``(text, generated_image, gt_image, input_image, metrics)``
        tuples, one per sample.
    """
    max_printed = 10  # only the metrics of the first few samples are echoed
    results = []
    printed = 0
    with torch.no_grad():
        for batch in data_loader:
            texts = batch['text']
            input_images = batch['input_image'].to(device)
            # Kept on the CPU: only used by postprocess, which works on NumPy arrays.
            gray_images = batch['gray_image']
            gt_images = batch['image'].to(device)
            masks = batch['mask'].to(device)
            generated_images = model(prompt=texts, image=input_images, guess_mode=False, guidance_scale=guidance_scale).images

            for gen_img, gt_img, input_img, gray_img, mask, text in zip(generated_images, gt_images, input_images, gray_images, masks, texts):
                if post:
                    gen_img = postprocess(gray_img.numpy(), np.array(gen_img))
                metrics = calculate_metrics(gen_img, gt_img, mask)

                if printed < max_printed:
                    print(metrics)
                    printed += 1

                results.append((text, gen_img, gt_img, input_img, metrics))

    return results


In [7]:
# Sample ids "1" .. "49" used by the plotting cells.
ids = list(map(str, range(1, 50)))

# Pix2Pix

In [None]:
# plotting
# Drop any pipeline left over from a previous cell, then collect garbage
# *before* emptying the CUDA cache so the freed memory is actually released.
pipe = None
gc.collect()
torch.cuda.empty_cache()

model_id = "timbrooks/instruct-pix2pix"
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
pipe.to("cuda")
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

os.makedirs("plotting", exist_ok=True)

for image_id in ids:  # renamed from `id`, which shadowed the builtin
    # Force three channels: the file may be stored as a single-channel image.
    gray = Image.open(os.path.join(data_folder, f"{image_id}_gray.jpg")).convert("RGB")
    with open(os.path.join(data_folder, f"{image_id}.txt"), "r", encoding="utf-8") as f:
        # Only the first line is the prompt; strip its trailing newline.
        prompt = f.readline().strip()
    prompt = "Colorize the whole image. " + prompt

    # `guess_mode` is a ControlNet argument and not an InstructPix2Pix one, so
    # it is no longer passed here.
    generated = pipe(prompt, gray, guidance_scale=7.0).images[0]

    generated.save(f"plotting/pix2pix_{image_id}.jpg")

# Brightness

In [None]:
# plotting
# Drop any pipeline left over from a previous cell, then collect garbage
# *before* emptying the CUDA cache so the freed memory is actually released.
pipe = None
gc.collect()
torch.cuda.empty_cache()

pipe = get_pipe(True)

os.makedirs("plotting", exist_ok=True)

for image_id in ids:  # renamed from `id`, which shadowed the builtin
    # Force three channels: the file may be stored as a single-channel image.
    gray = Image.open(os.path.join(data_folder, f"{image_id}_gray.jpg")).convert("RGB")
    with open(os.path.join(data_folder, f"{image_id}.txt"), "r", encoding="utf-8") as f:
        # Only the first line is the prompt; strip its trailing newline.
        prompt = f.readline().strip()

    generated = pipe(prompt, gray, guess_mode=False, guidance_scale=4.0).images[0]

    generated.save(f"plotting/brightness_{image_id}.jpg")

# Ledits

In [None]:
# plotting

# Drop any pipeline left over from a previous cell, then collect garbage
# *before* emptying the CUDA cache so the freed memory is actually released
# (the other plotting cells clean up as well; this one did not).
pipe = None
gc.collect()
torch.cuda.empty_cache()

# Kept local to this cell: nothing else in the notebook uses this pipeline class.
from diffusers import LEditsPPPipelineStableDiffusion
pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

os.makedirs("plotting", exist_ok=True)

for image_id in ids:  # renamed from `id`, which shadowed the builtin
    # Force three channels: the file may be stored as a single-channel image.
    gray = Image.open(os.path.join(data_folder, f"{image_id}_gray.jpg")).convert("RGB")
    with open(os.path.join(data_folder, f"{image_id}.txt"), "r", encoding="utf-8") as f:
        # Only the first line is the prompt; strip its trailing newline.
        prompt = f.readline().strip()

    # The inversion result is stored inside the pipeline; its return value is not needed.
    pipe.invert(gray, num_inversion_steps=50, skip=0.1)

    generated = pipe(editing_prompt=[prompt], edit_guidance_scale=10.0, edit_threshold=0.75).images[0]

    generated.save(f"plotting/ledits_{image_id}.jpg")

In [None]:
# plotting
# Drop any pipeline left over from a previous cell, then collect garbage
# *before* emptying the CUDA cache so the freed memory is actually released.
pipe = None
gc.collect()
torch.cuda.empty_cache()

# ControlNet weights loaded from a local directory, plugged into Stable Diffusion 1.5.
controlnet = ControlNetModel.from_pretrained("../../trained_model", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, use_safetensors=True
).to("cuda")


os.makedirs("plotting", exist_ok=True)

for image_id in ids:  # renamed from `id`, which shadowed the builtin
    # Force three channels: the file may be stored as a single-channel image.
    gray = Image.open(os.path.join(data_folder, f"{image_id}_gray.jpg")).convert("RGB")
    with open(os.path.join(data_folder, f"{image_id}.txt"), "r", encoding="utf-8") as f:
        # Only the first line is the prompt; strip its trailing newline.
        prompt = f.readline().strip()

    generated = pipe(prompt, gray, guess_mode=False, guidance_scale=4.0).images[0]

    generated.save(f"plotting/v1_{image_id}.jpg")