In [None]:
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch
import os
import PIL
from metrics import LPIPSMetric, SSIM , PSNR, F1_score
from torchvision import transforms
import numpy as np
import tifffile as tiff
from PIL import Image



In [None]:
# Directory name (relative to the working directory) passed to
# ControlNetModel.from_pretrained when the ControlNet weights are loaded.
controlnet_path = "Saved Model"

# Model identifiers. `remote_sensing_model` is the checkpoint the pipeline is
# built from; `base_model_path` is kept for switching back to vanilla SD 1.5
# and is not referenced by the other cells visible in this notebook.
remote_sensing_model = "tjisousa/sd-remote-sensing-model-256"
base_model_path = "runwayml/stable-diffusion-v1-5"

In [None]:
# Load the ControlNet weights found at `controlnet_path`, in float16.
controlnet = ControlNetModel.from_pretrained(
    controlnet_path,
    torch_dtype=torch.float16,
)

# Assemble the Stable Diffusion + ControlNet pipeline from the remote-sensing
# checkpoint, also in float16. "RS_Saved Model" is handed to diffusers as the
# cache directory for that checkpoint.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    remote_sensing_model,
    torch_dtype=torch.float16,
    controlnet=controlnet,
    cache_dir="RS_Saved Model",
)

In [None]:
# Swap the pipeline's scheduler for UniPC, configured from the scheduler the
# pipeline was loaded with.
unipc_scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.scheduler = unipc_scheduler

# Optional: xformers memory-efficient attention. Leave this disabled when
# xformers is not installed or when running on Torch 2.0.
# pipe.enable_xformers_memory_efficient_attention()

# Memory optimization: enable model CPU offload on the pipeline.
pipe.enable_model_cpu_offload()

In [None]:
# NOTE: only run the cells below once the original ControlNet Stable Diffusion
# model has been brought back.

# Raw strings keep the Windows backslashes literal; sequences such as "\S" or
# "\C" in a normal string are invalid escapes that newer Python versions warn
# about (and will eventually reject).
image_directory = r"F:\Shahmir\ControlNet Satellite Imagery\Masks"
# os.listdir returns entries in arbitrary order. Sorting makes iteration
# deterministic and keeps the index-based pairing of masks with ground truths
# (gt_files[index] in the evaluation loop) stable across runs.
image_files = sorted(os.listdir(image_directory))

# Where the evaluation loop writes the generated images.
output_directory = r"F:\Shahmir\ControlNet Satellite Imagery\Vanilla Test Images"

# Reference images the generated outputs are scored against.
gt_directory = r"F:\Shahmir\ControlNet Satellite Imagery\Ground Truths"
gt_files = sorted(os.listdir(gt_directory))

# Scratch directory used by the single-image smoke test.
out_directory = r"F:\Shahmir\ControlNet Satellite Imagery\Test"

# Alternative configuration (remote-sensing outputs scored against vanilla outputs):
# output_directory = r"F:\Shahmir\ControlNet Satellite Imagery\RSI Test Images"
# gt_directory = r"F:\Shahmir\ControlNet Satellite Imagery\Vanilla Test Images"

In [None]:
# Sanity check: show the filename of the first mask (nothing is printed when
# the mask directory is empty).
for image in image_files[:1]:
    image_path = os.path.join(image_directory, image)
    print(image)

In [None]:
# Smoke test: run the pipeline on the first mask only and write the result
# into `out_directory` under the mask's filename.
for image in image_files[:1]:
    image_path = os.path.join(image_directory, image)
    control_image = load_image(image_path)

    # Earlier, longer prompt kept for reference:
    # prompt =" Generate a realistic high- resolution satellite image of a which is zoomed out city with very little vegetation.\
    # # Houses should have brown roofs.\
    # # Focus on surrounding vegetation which is deep green with varied shades but majorly contains barren landscape. \
    # # Roads should be sharply defined against the landscape. Image should be ultra high defination."
    prompt ="A greener image. Segmentation area should have construction only"

    # Fixed seed so repeated runs produce the same image.
    generator = torch.manual_seed(0)
    output = pipe(
        prompt,
        image=control_image,
        num_inference_steps=100,
        generator=generator,
    ).images[0]

    output_path = os.path.join(out_directory, image)
    print(output_path)
    # To save the conditioning mask instead of the generated image:
    # control_image.save(output_path)
    output.save(output_path)

Evaluation Method 

In [None]:
# Preprocessing applied to both generated and ground-truth images before
# scoring: resize with size=256, then convert the PIL image to a tensor.
# NOTE(review): with an integer size, torchvision's Resize scales the shorter
# edge and keeps the aspect ratio, so non-square inputs do not come out
# 256x256 -- confirm the inputs are square.
data_transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.ToTensor(),
    ]
)

In [None]:
# Evaluate the pipeline on every mask: generate an image conditioned on the
# mask, save it to `output_directory`, then score it against the ground truth
# at the same list index with LPIPS, SSIM and PSNR. Per-image scores are
# collected in the three lists below and averaged once the loop finishes.
Lpips_list = []
ssim_list = []
psnr_list = []
for index, image in enumerate(image_files):
    # New metric objects for each image, so every compute() call below covers
    # only the image that was just passed to update().
    Lpips_metric = LPIPSMetric()
    ssim_metric = SSIM()
    psnr_metric = PSNR()

    image_path = os.path.join(image_directory, image)
    control_image = load_image(image_path)
    # The prompt text (including its spelling and embedded whitespace) is left
    # untouched on purpose: altering it would alter the generated images.
    prompt =" Generate a high- resolution  aerial satellite image of a city with lots of trees and brown landscape \
    Houses should have brown roofs.\
    Focus on surrounding vegetation which is deep green with varied shades. \
    Roads should be sharply defined agaisnt the landscape. Image should be ultra high defination."
    # Ground truth is matched to the mask purely by position in the listings.
    gt_path = os.path.join(gt_directory, gt_files[index])

    # Re-seed for every image so each generation is reproducible on its own.
    generator = torch.manual_seed(0)
    output = pipe(prompt, num_inference_steps=100, generator=generator, image=control_image).images[0]
    output_path = os.path.join(output_directory, image)
    output.save(output_path)

    # Keep at most 3 channels and a 256x256 window. The same crop is applied
    # to the ground truth below so both tensors always share a shape.
    output = data_transform(output)
    output = output[:3, :256, :256]

    # Context manager closes the ground-truth file handle once it is decoded.
    with Image.open(gt_path) as gt_image:
        gt = data_transform(gt_image)
    gt = gt[:3, :256, :256]

    # Score the image as generated. The earlier `.round()` on `output`
    # collapsed every pixel to 0 or 1 while the ground truth stayed continuous,
    # which distorted all three metrics.
    output = output.detach().cpu()
    gt = gt.detach().cpu()

    Lpips_metric.update(output, gt)
    # SSIM is fed a leading batch dimension; the other two are not.
    ssim_metric.update(output.unsqueeze(dim=0), gt.unsqueeze(dim=0))
    psnr_metric.update(output, gt)

    # Compute each score once and reuse it for both the list and the log line.
    lpips_value = Lpips_metric.compute()
    ssim_value = ssim_metric.compute()
    psnr_value = psnr_metric.compute()

    Lpips_list.append(lpips_value)
    ssim_list.append(ssim_value)
    psnr_list.append(psnr_value)

    print(f"At index {index},LPIPS: {lpips_value}, SSIM: {ssim_value}, PSNR: {psnr_value}")

avg_lpip = np.mean(Lpips_list)
avg_ssim = np.mean(ssim_list)
avg_psnr = np.mean(psnr_list)

print(f"Average LPIPS: {avg_lpip}, Average SSIM: {avg_ssim}, Average PSNR: {avg_psnr}")

In [None]:
# Re-print the dataset-level means of the per-image scores collected by the
# evaluation loop.
print(
    "LPIPS: {} SSIM: {} PSNR: {}".format(
        np.mean(Lpips_list),
        np.mean(ssim_list),
        np.mean(psnr_list),
    )
)

In [None]:
# Generate one more image. This cell relies on `prompt` and `control_image`
# still holding the values assigned by the last loop iteration that ran, so
# it only works after one of the generation loops has been executed.
generator = torch.manual_seed(0)
image = pipe(
    prompt, num_inference_steps=100, generator=generator, image=control_image
).images[0]
# Save next to the other test outputs. The previous target "/output.png"
# pointed at the root of the filesystem (or current drive), which is usually
# unintended and often not writable.
image.save(os.path.join(out_directory, "output.png"))