In [None]:
import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path

import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.utils import ContextManagers
import torchvision.transforms as T

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DDIMScheduler, ControlNetModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
import matplotlib.pyplot as plt

import torchvision

In [None]:
# Directory of the fine-tuned checkpoint. Every component below except
# `noise_scheduler` is loaded from a subfolder of this directory, so switching
# checkpoints only requires changing this one line.
# root_folder = '/scratch/bbut/prathi3/unconditional/10_50000'
root_folder = 'models/10gen'

# DDPM scheduler taken from the public x4-upscaler repository.
noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(
    root_folder, subfolder="tokenizer"
)
text_encoder = CLIPTextModel.from_pretrained(
    root_folder, subfolder="text_encoder"
).cuda()
vae = AutoencoderKL.from_pretrained(
    root_folder, subfolder="vae"
).cuda()
unet = UNet2DConditionModel.from_pretrained(
    root_folder, subfolder="unet"
).cuda()
# NOTE: `scheduler` is rebound further down in this cell to a
# DPMSolverSinglestepScheduler, which is the one sampling actually uses.
scheduler = DDIMScheduler.from_pretrained(
    root_folder, subfolder="scheduler"
)
# Load the ControlNet from the same checkpoint directory as the UNet so the
# two can never come from different runs.
controlnet = ControlNetModel.from_pretrained(
    root_folder, subfolder="controlnet"
).cuda()

# Alternative: initialise a fresh ControlNet from the UNet weights.
# controlnet = ControlNetModel.from_unet(unet).cuda()

def get_tokenized_caption(caption, tokenizer):
    """Tokenize a single caption and return its token ids.

    The caption is wrapped in a one-element list, then padded / truncated to
    ``tokenizer.model_max_length`` with ``return_tensors="pt"``.

    Args:
        caption: The caption text to tokenize.
        tokenizer: A callable tokenizer exposing ``model_max_length``.

    Returns:
        The ``input_ids`` attribute of the tokenizer output.
    """
    tokenizer_options = dict(
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    encoded = tokenizer([caption], **tokenizer_options)
    return encoded.input_ids

# Preprocessing for conditioning images: convert to a tensor, then normalise
# with mean 0.5 and std 0.5.
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

from diffusers import (
    StableDiffusionUpscalePipeline,
    DPMSolverSinglestepScheduler,
)

# The full upscaler pipeline is loaded here; its scheduler config seeds the
# DPM-Solver scheduler below.
pipeline = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler"
)
# This rebinds `scheduler`, replacing the DDIM scheduler assigned earlier in
# this cell.
scheduler = DPMSolverSinglestepScheduler.from_config(
    pipeline.scheduler.config
)

In [None]:
# Module-level converter from image tensors to PIL images (torchvision ToPILImage).
transform = T.ToPILImage()
def _encode_prompt(text):
    """Encode one prompt with the module-level tokenizer and text encoder.

    Returns the first element of the text encoder output for a batch that
    contains only ``text``.
    """
    # get_tokenized_caption already wraps the caption in a batch of one, so
    # the ids can be fed to the encoder as they are.
    token_ids = get_tokenized_caption(text, tokenizer).to('cuda')
    return text_encoder(token_ids)[0]


def get_controlnet_image(cond=None, use_lr=False, use_ctrlnet=True, use_neg=True, use_pos=False):
    """Sample one image with the module-level UNet / ControlNet / VAE.

    Runs a 50-step denoising loop with the module-level ``scheduler``, decodes
    the final latents with ``vae`` and returns the first decoded sample as a
    PIL image.

    Args:
        cond: Conditioning image tensor handed to ``controlnet`` as
            ``controlnet_cond``. Required when ``use_ctrlnet`` is true; it is
            tiled along the batch axis to match the number of prompts.
            Presumably a normalised (3, H, W) or (1, 3, H, W) tensor on CUDA
            -- confirm against the caller.
        use_lr: If true, sample 128x128 latents and concatenate a constant
            ``-1`` image of shape (3, 128, 128) on the channel axis of the
            model input. If false, sample 64x64 latents and concatenate
            nothing.
        use_ctrlnet: If true, run ``controlnet`` and feed its residuals to the
            UNet; otherwise call the UNet on its own.
        use_neg: If true, add the negative prompt
            "blurry, lowres, low quality" and apply guidance against it.
        use_pos: If true, additionally add the prompt "fall colors" and steer
            towards it. Only supported together with ``use_neg``.

    Returns:
        A PIL image built from the decoded sample after mapping it with
        ``x * 0.5 + 0.5`` and clamping to ``[0, 1]``.

    Raises:
        ValueError: If ``use_ctrlnet`` is true but ``cond`` is ``None``, or if
            ``use_pos`` is requested without ``use_neg``.
    """
    # Validate up front so misuse fails with a clear message instead of an
    # AttributeError / shape mismatch deep inside the sampling loop.
    if use_ctrlnet and cond is None:
        raise ValueError("cond must be provided when use_ctrlnet=True")
    if use_pos and not use_neg:
        raise ValueError("use_pos=True is only supported together with use_neg=True")

    transform = torchvision.transforms.ToPILImage()
    noise_level = torch.tensor([20], dtype=torch.long, device='cuda')
    weight_dtype = torch.float32

    # One batch entry per prompt: main prompt, optional negative, optional positive.
    repeats = 1 + int(use_neg) + int(use_pos)

    latent_size = 128 if use_lr else 64
    latents = torch.randn((1, 4, latent_size, latent_size)).to('cuda')

    # Constant low-resolution input, only needed when use_lr is set.
    lrim = None
    if use_lr:
        lrim = -torch.ones((repeats, 3, 128, 128)).cuda()

    if cond is not None:
        cond = cond.repeat((repeats, 1, 1, 1))

    with torch.no_grad():
        # Batch order must match the chunk() calls below: text, negative, positive.
        embeds = [_encode_prompt("satellite photo")]
        if use_neg:
            embeds.append(_encode_prompt("blurry, lowres, low quality"))
        if use_pos:
            embeds.append(_encode_prompt("fall colors"))
        prompt_embeds = torch.cat(embeds)

        scheduler.set_timesteps(50, device='cuda')

        for t in scheduler.timesteps:
            latent_inp = torch.cat([latents] * repeats)
            latent_inp = scheduler.scale_model_input(latent_inp, t)

            if use_lr:
                latent_inp = torch.cat([latent_inp, lrim], dim=1)

            if use_ctrlnet:
                down_block_res, mid_block_res = controlnet(
                    latent_inp,
                    t,
                    prompt_embeds,
                    class_labels=noise_level,
                    controlnet_cond=cond,
                    return_dict=False,
                )
                noise_pred = unet(
                    latent_inp,
                    t,
                    prompt_embeds,
                    class_labels=noise_level,
                    down_block_additional_residuals=[
                        sample.to(dtype=weight_dtype) for sample in down_block_res
                    ],
                    mid_block_additional_residual=mid_block_res.to(dtype=weight_dtype),
                ).sample
            else:
                noise_pred = unet(latent_inp, t, prompt_embeds, class_labels=noise_level, return_dict=False)[0]

            # Collapse the per-prompt predictions back into a single one.
            if use_neg and use_pos:
                noise_pred_text, noise_pred_neg, noise_pred_pos = noise_pred.chunk(3)
                noise_pred = (
                    noise_pred_text
                    + 3 * (noise_pred_text - noise_pred_neg)
                    + 5 * (noise_pred_pos - noise_pred_neg)
                )
            elif use_neg:
                noise_pred_text, noise_pred_neg = noise_pred.chunk(2)
                noise_pred = noise_pred_text + 3.5 * (noise_pred_text - noise_pred_neg)

            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        result = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

        return transform(torch.clamp(result[0] * 0.5 + 0.5, 0, 1))


In [None]:
import matplotlib.pyplot as plt
import os
import random
from PIL import Image
import numpy as np

searchdir = '/scratch/bbut/bing_datasets/bing_20/'
output_dir = "consistency"

# Image.save does not create missing directories, so make sure it exists.
os.makedirs(output_dir, exist_ok=True)

# List the dataset once instead of re-scanning the directory every iteration.
example_dirs = os.listdir(searchdir)

# For 100 randomly chosen examples, save the cropped map next to the image
# generated from it as one side-by-side JPEG.
for k in range(100):
    example_imdir = random.choice(example_dirs)

    cond_og = Image.open(os.path.join(searchdir, example_imdir, "10_map.jpg"))
    # Central crop spanning 3/8 .. 5/8 of each axis. The vertical bounds are
    # derived from the height so non-square inputs are cropped around their
    # centre as well.
    src_w, src_h = cond_og.size
    cond_og = cond_og.crop(((src_w * 3) // 8, (src_h * 3) // 8, (src_w * 5) // 8, (src_h * 5) // 8))
    # The conditioning input is the crop upsampled by 2x in each dimension.
    cond_image = cond_og.resize((cond_og.width * 2, cond_og.height * 2))

    cond = train_transforms(cond_image).cuda()

    res = get_controlnet_image(cond, use_lr=True, use_neg=False)
    images = [cond_og, res]

    # Canvas wide enough for all tiles side by side, as tall as the tallest one.
    total_width = sum(img.width for img in images)
    combined_image = Image.new('RGB', (total_width, max(img.height for img in images)))

    # Paste the tiles left to right.
    x_offset = 0
    for img in images:
        combined_image.paste(img, (x_offset, 0))
        x_offset += img.width

    save_path = os.path.join(output_dir, f"{k}.jpg")
    combined_image.save(save_path)

In [None]:
import matplotlib.pyplot as plt
import os
import random
from PIL import Image
import numpy as np

searchdir = '/scratch/bbut/bing_datasets/bing_20/'
output_dir = "consistency_gts"
# Set to True to append the satellite crop of a second, randomly chosen
# example as a third tile.
include_random_example = False


def _load_center_crop(path):
    """Open the image at ``path`` and return its central crop.

    The crop spans 3/8 .. 5/8 of the width horizontally and 3/8 .. 5/8 of the
    height vertically.
    """
    image = Image.open(path)
    src_w, src_h = image.size
    return image.crop(((src_w * 3) // 8, (src_h * 3) // 8, (src_w * 5) // 8, (src_h * 5) // 8))


# Image.save does not create missing directories, so make sure it exists.
os.makedirs(output_dir, exist_ok=True)

# List the dataset once instead of re-scanning the directory every iteration.
example_dirs = os.listdir(searchdir)

# For 50 randomly chosen examples, save the cropped map next to the cropped
# real image as one side-by-side JPEG.
for i in range(50):
    example_imdir = random.choice(example_dirs)
    cond_og = _load_center_crop(os.path.join(searchdir, example_imdir, "10_map.jpg"))
    actual = _load_center_crop(os.path.join(searchdir, example_imdir, "10.jpg"))

    images = [cond_og, actual]

    if include_random_example:
        # Use a separate name so `example_imdir` keeps identifying the pair
        # that is saved and logged below.
        random_imdir = random.choice(example_dirs)
        randomim = _load_center_crop(os.path.join(searchdir, random_imdir, "10.jpg"))
        images.append(randomim)

    # Canvas wide enough for all tiles side by side, as tall as the tallest one.
    total_width = sum(img.width for img in images)
    combined_image = Image.new('RGB', (total_width, max(img.height for img in images)))

    # Paste the tiles left to right.
    x_offset = 0
    for img in images:
        combined_image.paste(img, (x_offset, 0))
        x_offset += img.width

    save_path = os.path.join(output_dir, f"{i}.jpg")
    combined_image.save(save_path)
    # Log which example directory the saved pair came from.
    print(searchdir, example_imdir)