In [None]:
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.utils import ContextManagers
import torchvision.transforms as T

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available

In [None]:
# Stock components (schedulers, tokenizer, text encoder) all come from the
# same upstream checkpoint; the VAE and UNet come from local fine-tuned runs.
_UPSCALER_REPO = "stabilityai/stable-diffusion-x4-upscaler"

noise_scheduler = DDPMScheduler.from_pretrained(_UPSCALER_REPO, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(_UPSCALER_REPO, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(_UPSCALER_REPO, subfolder="text_encoder").cuda()

# Locally stored weights; paths are relative to the notebook's working directory.
vae = AutoencoderKL.from_pretrained("models/12to14", subfolder="vae").cuda()
unet = UNet2DConditionModel.from_pretrained("models/16to18", subfolder="unet").cuda()

# `scheduler` is the module-level sampler that get_image() reads.
scheduler = DDIMScheduler.from_pretrained(_UPSCALER_REPO, subfolder="scheduler")
low_res_scheduler = DDPMScheduler.from_pretrained(_UPSCALER_REPO, subfolder="low_res_scheduler")

def get_tokenized_caption(caption):
    """Tokenize one caption with the module-level ``tokenizer``.

    The caption is wrapped in a single-element list, padded/truncated to
    ``tokenizer.model_max_length`` and returned as PyTorch tensors.

    Args:
        caption: The text prompt to tokenize.

    Returns:
        The ``input_ids`` field of the tokenizer output for the one-caption batch.
    """
    encoded = tokenizer(
        [caption],
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids

# PIL image -> tensor, then normalise with mean 0.5 / std 0.5.
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

In [None]:
def get_image(lrim, unet, timesteps=30, negative=False, guidance_scale_n=0, positive=False, guidance_scale_p=0):
    """Run the denoising loop for one low-res conditioning image and decode the result.

    Reads the module-level ``text_encoder``, ``scheduler``, ``vae`` and
    ``noise_level`` objects, so those must be defined before this is called.

    The UNet is evaluated on one branch for the base prompt ("satellite photo")
    plus one extra branch per enabled guidance prompt. The branch predictions
    are combined as::

        pred = text
               + guidance_scale_n * (text - neg)   # if negative
               + guidance_scale_p * (pos  - text)  # if positive

    Args:
        lrim: Low-res conditioning tensor without a batch dimension. It is
            concatenated to the (1, 4, 128, 128) latents along the channel
            axis, so its spatial size must be 128x128.
        unet: The UNet used for noise prediction.
        timesteps: Number of inference steps passed to ``scheduler.set_timesteps``.
        negative: Add a branch conditioned on the negative quality prompt.
        guidance_scale_n: Weight pushing the prediction away from the negative branch.
        positive: Add a branch conditioned on the positive quality prompt.
        guidance_scale_p: Weight pulling the prediction toward the positive branch.

    Returns:
        The first element of the VAE-decoded batch.
    """

    def _embed(caption):
        # get_tokenized_caption already returns a one-caption batch of ids.
        token_ids = get_tokenized_caption(caption).to('cuda')
        return text_encoder(token_ids)[0]

    with torch.no_grad():
        latents = torch.randn((1, 4, 128, 128)).to('cuda')

        # Branch order is fixed: base prompt, then negative, then positive.
        # The chunk-unpacking in the loop below relies on this order.
        branch_embeds = [_embed("satellite photo")]
        if negative:
            branch_embeds.append(_embed("blurry, lowres, low quality, deformed"))
        if positive:
            branch_embeds.append(_embed("sharp, high quality, detailed, realistic"))
        prompt_embeds = torch.cat(branch_embeds)
        num_branches = len(branch_embeds)

        scheduler.set_timesteps(timesteps, device='cuda')

        # The conditioning image is loop-invariant; replicate it once so its
        # batch size matches the number of prompt branches.
        lr_cond = torch.unsqueeze(lrim, dim=0).repeat((num_branches, 1, 1, 1))

        for t in scheduler.timesteps:
            # Every prompt branch denoises the same latents.
            latent_inp = torch.cat([latents] * num_branches) if num_branches > 1 else latents
            latent_inp = scheduler.scale_model_input(latent_inp, t)
            latent_inp = torch.cat([latent_inp, lr_cond], dim=1)
            noise_pred = unet(latent_inp, t, prompt_embeds, class_labels=noise_level, return_dict=False)[0]

            if num_branches > 1:
                branch_preds = list(noise_pred.chunk(num_branches))
                noise_pred_text = branch_preds.pop(0)
                guided = noise_pred_text
                if negative:
                    noise_pred_neg = branch_preds.pop(0)
                    guided = guided + guidance_scale_n * (noise_pred_text - noise_pred_neg)
                if positive:
                    noise_pred_pos = branch_preds.pop(0)
                    guided = guided + guidance_scale_p * (noise_pred_pos - noise_pred_text)
                noise_pred = guided

            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        result = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

    return result[0]

def get_lrim(row, col, low_res):
    """Cut one tile out of the central region of ``low_res`` and tensorize it.

    The image is first cropped to the square spanning 3/8..5/8 of its width
    (the same bounds are used on both axes). From that crop a window two
    ``tile_dim//8`` steps wide is taken at grid position (``row``, ``col``).
    ``tile_dim`` is read from module scope.

    Args:
        row: Vertical tile index, in units of ``tile_dim//8`` pixels.
        col: Horizontal tile index, in units of ``tile_dim//8`` pixels.
        low_res: PIL image to crop from.

    Returns:
        The tile after ``train_transforms``.
    """
    width = low_res.width
    lower, upper = (width * 3) // 8, (width * 5) // 8
    centre = low_res.crop((lower, lower, upper, upper))

    step = tile_dim // 8
    left, top = step * col, step * row
    right, bottom = step * (col + 2), step * (row + 2)
    tile = centre.crop((left, top, right, bottom))
    return train_transforms(tile)

In [None]:
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, StableDiffusionUpscalePipeline
import torch
import torchvision
from PIL import Image
import numpy as np
import os

import matplotlib.pyplot as plt

# Entries of the bing_20 dataset directory.
image_dirs = os.listdir('/u/ansh/cascaded-generative/bing_20/')

# Target zoom level; input file names are derived from it (e.g. f'{hr_zoom-10}.jpg').
hr_zoom = 20

# Tensor -> PIL converter used by generate_sample().
transform = torchvision.transforms.ToPILImage()

# Passed to the UNet as `class_labels` on every denoising step in get_image().
noise_level = torch.tensor(
    [2],
    dtype=torch.long,
    device='cuda',
)

def generate_sample(lrim, timesteps, uuid, unet, negative=False, guidance_scale_n=0, positive=False, guidance_scale_p=0):
    """Generate one image with ``get_image`` and optionally display and/or save it.

    Reads the module-level flags ``show`` and ``save``. When saving, it also
    reads ``save_dir``, ``hr_zoom`` and the global loop index ``i``, which
    names the output file.

    Args:
        lrim: Low-res conditioning tensor forwarded to ``get_image``.
        timesteps: Number of inference steps; also names the output sub-directory.
        uuid: Sample identifier. Currently unused; kept for call compatibility.
        unet: UNet forwarded to ``get_image``.
        negative, guidance_scale_n, positive, guidance_scale_p: Forwarded to
            ``get_image`` unchanged.

    Returns:
        The generated sample as a PIL image.
    """
    image = get_image(
        lrim,
        unet,
        timesteps=timesteps,
        negative=negative,
        guidance_scale_n=guidance_scale_n,
        positive=positive,
        guidance_scale_p=guidance_scale_p,
    )
    # Map the decoder output from [-1, 1] to [0, 1] before converting to PIL.
    img = transform(torch.clip((image + 1) / 2, 0, 1))

    if show:
        plt.imshow(img)
        plt.show()

    if save:
        # Create the directory the file is actually written to. The previous
        # code created a different path built from undefined names.
        out_dir = save_dir + f'{hr_zoom}/{timesteps}'
        os.makedirs(out_dir, exist_ok=True)
        # NOTE(review): `i` is the caller's global loop index, so every stage
        # of a cascade overwrites the same file — confirm this is intended.
        img.save(out_dir + f'/{i}.png')
    return img

def get_random_image():
    """Pick a random sample directory from one of the two Bing datasets.

    A dataset is drawn uniformly from ``bing_20`` / ``bing_urban``, then one
    entry of that dataset's directory listing is drawn uniformly.

    Returns:
        ``(path, uuid)``: the sample directory (with trailing slash) and the
        name of the chosen entry.
    """
    base = '/scratch/bbut/bing_datasets/'
    chosen_set = random.choice(['bing_20', 'bing_urban'])
    candidates = os.listdir(base + chosen_set)
    uuid = random.choice(candidates)
    return base + chosen_set + '/' + uuid + '/', uuid

def load_random_lr():
    """Load a random low-res conditioning tensor.

    Opens the ``{hr_zoom-6}.jpg`` image of a random sample and crops it to the
    square spanning 3/8..5/8 of its width. The top-left 128x128 patch of that
    crop is then normalised and moved to CUDA.

    Returns:
        The patch after ``train_transforms``, on the GPU.
    """
    image_path, _unused_uuid = get_random_image()
    source = Image.open(image_path + f'/{hr_zoom-6}.jpg').convert("RGB")

    width = source.width
    lower, upper = (width * 3) // 8, (width * 5) // 8
    centre = source.crop((lower, lower, upper, upper))

    patch_side = 512 // 4
    patch = centre.crop((0, 0, patch_side, patch_side))
    return train_transforms(patch).cuda()

def _load_stage_unet(model_dir):
    """Load the ``unet`` subfolder of a local checkpoint directory onto the GPU."""
    return UNet2DConditionModel.from_pretrained(model_dir, subfolder="unet").cuda()


# One UNet per cascade stage; the variable suffix matches the first number in the directory name.
unet10 = _load_stage_unet("models/10to12")
unet12 = _load_stage_unet("models/12to14")
unet14 = _load_stage_unet("models/14to16")
unet16 = _load_stage_unet("models/16to18")
unet18 = _load_stage_unet("models/18to20")

In [None]:
# Pre-select 20 random (path, uuid) samples for the comparison loop.
# The loop variable stays named `i` because generate_sample() reads a global `i`.
images_to_be_used = list()
for i in range(20):
    images_to_be_used += [get_random_image()]

In [None]:
import time
import matplotlib.pyplot as plt
from diffusers import DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDIMScheduler, DPMSolverSDEScheduler, PNDMScheduler, UniPCMultistepScheduler
# Output location and the display/save switches read by generate_sample().
save_dir = '/scratch/bbut/ansh/output_comparisons/'
save = False
show = True

# Scheduler configuration (v-prediction, scaled-linear betas).
config = {'num_train_timesteps': 1000, 'beta_start': 0.0001, 'beta_end': 0.02, 
        'beta_schedule': 'scaled_linear', 'trained_betas': None, 'clip_sample': False, 
        'set_alpha_to_one': False, 'steps_offset': 1, 'prediction_type': 'v_prediction', 'thresholding': False, 
        'dynamic_thresholding_ratio': 0.995, 'clip_sample_range': 1.0, 
        'sample_max_value': 1.0, 'timestep_spacing': 'leading', 
        'rescale_betas_zero_snr': False, 
        '_use_default_values': ['clip_sample_range', 'rescale_betas_zero_snr', 'thresholding', 'sample_max_value', 'timestep_spacing', 'dynamic_thresholding_ratio']}

# from_config is a classmethod, so no throwaway default instance is needed.
# This rebinds the module-level `scheduler` that get_image() samples with.
scheduler = DPMSolverSinglestepScheduler.from_config(config)
timesteps = 50
images = []
images_2 = []
images_5 = []
lr_images = []
hr_images = []

# Window cropped out of every intermediate output before it is fed to the next
# stage. NOTE(review): presumably the central 128 px of a 512 px image — confirm.
CENTER_CROP_BOX = (192, 192, 320, 320)


def _as_next_stage_input(stage_output):
    """Crop CENTER_CROP_BOX from a generated PIL image and return it as a normalised CUDA tensor."""
    return train_transforms(stage_output.crop(CENTER_CROP_BOX)).cuda()


start = time.time()
# Index-based loop on purpose: generate_sample() and the ground-truth save
# below both use the global `i` to name output files.
for i in range(len(images_to_be_used)):
    # Low-res input: the zoom-(hr_zoom-10) image, cropped to the 3/8..5/8
    # square and then to the fixed window.
    image_path, uuid = images_to_be_used[i]
    lr_image = Image.open(image_path+f'/{hr_zoom-10}.jpg').convert("RGB")
    lr_image = lr_image.crop(((lr_image.width*3)//8, (lr_image.width*3)//8, (lr_image.width*5)//8, (lr_image.width*5)//8))
    lr_image = lr_image.crop(CENTER_CROP_BOX)
    lrim = train_transforms(lr_image).cuda()

    print("No negative prompt")
    # Five-stage cascade without guidance. Its results are only displayed
    # (when `show` is set); the guided pass below overwrites img12..img20.
    img12 = _as_next_stage_input(generate_sample(lrim, timesteps, uuid, unet10))
    img14 = _as_next_stage_input(generate_sample(img12, timesteps, uuid, unet12))
    img16 = _as_next_stage_input(generate_sample(img14, timesteps, uuid, unet14))
    img18 = _as_next_stage_input(generate_sample(img16, timesteps, uuid, unet16))
    img20 = generate_sample(img18, timesteps, uuid, unet18)

    print("Negative Conditioning - 52334")
    # Same five-stage cascade with negative-prompt guidance; the per-stage
    # scales 5, 2, 3, 3, 4 are the digits in the label printed above.
    img12 = _as_next_stage_input(generate_sample(lrim, timesteps, uuid, unet10, True, 5))
    img14 = _as_next_stage_input(generate_sample(img12, timesteps, uuid, unet12, True, 2))
    img16 = _as_next_stage_input(generate_sample(img14, timesteps, uuid, unet14, True, 3))
    img18 = _as_next_stage_input(generate_sample(img16, timesteps, uuid, unet16, True, 3))
    img20 = generate_sample(img18, timesteps, uuid, unet18, True, 4)

    if show:
        plt.imshow(lr_image)
        plt.show()

    # Corresponding HR ground truth: fixed 512x512 window of the zoom-hr_zoom image.
    gt = Image.open(image_path+f'/{hr_zoom}.jpg').convert("RGB")
    gt = gt.crop((768, 768, 1280, 1280))
    if save:
        os.makedirs(save_dir + f'{hr_zoom}/gt', exist_ok=True)
        gt.save(save_dir+f'{hr_zoom}/gt/{i}.png')
    if show:
        plt.imshow(gt)
        plt.show()

    # Only the guided cascade's final output is kept for the summary grid.
    images.append(img20)
    lr_images.append(lr_image)
    hr_images.append(gt)

    if(i%100 == 0):
        # Running average of seconds per sample.
        print((time.time() - start)/(i+1))

print(f"timesteps {timesteps}")
# Summary grid of generated outputs, 5 per row.
np_images = np.stack([np.asarray(img) for img in images])
grid_img = torchvision.utils.make_grid(torch.Tensor(np_images).permute(0, 3, 1, 2), nrow=5)
plt.imshow(grid_img.permute(1, 2, 0)/255)
plt.show()
print("high res")
# Matching grid of ground-truth crops.
np_images = np.stack([np.asarray(img) for img in hr_images])
grid_img = torchvision.utils.make_grid(torch.Tensor(np_images).permute(0, 3, 1, 2), nrow=5)
plt.imshow(grid_img.permute(1, 2, 0)/255)
plt.show()