In [None]:
import sys
from pathlib import Path
import os

# Put the parent directory on the import path.
# NOTE(review): presumably this is what lets the `diffusionsat` import below
# resolve when the notebook is run from its own folder — confirm the repo layout.
sys.path.append('..')

import torch
from diffusers import UniPCMultistepScheduler

from diffusionsat import SatUNet, DiffusionSatPipeline, metadata_normalize

In [None]:
# Optional: redirect the Hugging Face cache to a custom directory.
# NOTE(review): huggingface_hub appears to resolve HF_HOME when it is first
# imported, and `diffusers` has already been imported by this point — confirm
# this assignment still takes effect, or move it ahead of those imports.
os.environ["HF_HOME"] = "path/to/.cache/"

# Use the GPU in half precision when CUDA is available; otherwise fall back
# to the CPU in full precision.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# Change caption/metadata here

In [None]:
# Option A: metadata that has already been normalized.
caption = "a fmow satellite image of a amusement park in Australia"
metadata = [
    925.8798,  # longitude
    345.2111,  # latitude
    411.4541,  # gsd
    0.0000,    # cloud cover
    308.3333,  # year
    166.6667,  # month
    354.8387,  # day
]

# Option B: raw metadata values, rescaled by `metadata_normalize`
# (presumably listed in the same field order as Option A — confirm).
# These assignments run last, so Option B is what the later cells receive;
# comment this part out to use Option A instead.
caption = "a fmow satellite image of a electric substation in India"
metadata = metadata_normalize(
    [76.5712666476, 28.6965307997, 0.929417550564, 0.0765712666476, 2015, 2, 27]
).tolist()

# Pipe 1: No finetuning

In [None]:
# Baseline pipeline: weights come straight from the Stable Diffusion 2.1 repo,
# and the UNet is built with `use_metadata=False`.
base_model_id = "stabilityai/stable-diffusion-2-1"

unet1 = SatUNet.from_pretrained(
    base_model_id,
    subfolder="unet",
    use_metadata=False,
    torch_dtype=DTYPE,
)
pipe1 = DiffusionSatPipeline.from_pretrained(
    base_model_id,
    unet=unet1,
    torch_dtype=DTYPE,
)

# Swap in a UniPC scheduler built from the existing scheduler's config,
# then move the pipeline to the selected device.
pipe1.scheduler = UniPCMultistepScheduler.from_config(pipe1.scheduler.config)
pipe1 = pipe1.to(DEVICE)

#### Prompt pipe 1

In [None]:
# Generate one 512x512 sample from the caption alone (no metadata is passed
# to this pipeline) and keep the first returned image.
image = pipe1(
    caption,
    num_inference_steps=50,
    guidance_scale=7.5,
    height=512,
    width=512,
).images[0]

# Bare expression so the notebook renders the image as the cell output.
image

# Pipe: Finetuning with metadata, SNR 5

In [None]:
# Load the fine-tuned, metadata-conditioned pipeline from a local checkpoint.
# When CUDA is unavailable, `DTYPE` falls back to float32 and `DEVICE` is "cpu",
# so the final `.to(DEVICE)` keeps the pipeline on the CPU.

checkpoint_root = Path("./checkpoints_diffusionsat")
model_dir = checkpoint_root / "finetune_sd21_sn-satlas-fmow_snr5_md7norm_bs64"
checkpoint_dir = model_dir / "checkpoint-150000"

# Fail fast with a readable message when the weights are not on disk.
# Otherwise `from_pretrained` may treat the missing path as a Hub repo id and
# fail with an error that does not point at the real cause.
for required_dir in (model_dir, checkpoint_dir / "unet"):
    if not required_dir.is_dir():
        raise FileNotFoundError(
            f"Expected DiffusionSat weights at '{required_dir.resolve()}'. "
            "Download the checkpoint or update `checkpoint_root`."
        )

# The UNet comes from the specific training checkpoint; the remaining
# pipeline components are loaded from the run directory itself.
unet = SatUNet.from_pretrained(checkpoint_dir, subfolder="unet", torch_dtype=DTYPE)
pipe = DiffusionSatPipeline.from_pretrained(model_dir, unet=unet, torch_dtype=DTYPE)

# Swap in a UniPC scheduler built from the existing scheduler's config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(DEVICE)

#### Prompt Pipe

Play around with the guidance scale and the number of inference steps to generate images.

In [None]:
# Generate one 512x512 sample conditioned on both the caption and the
# metadata vector, and keep the first returned image.
image = pipe(
    caption,
    metadata=metadata,
    num_inference_steps=20,
    guidance_scale=7.5,
    height=512,
    width=512,
).images[0]

# Bare expression so the notebook renders the image as the cell output.
image