In [None]:
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
import torch
import os
import gc
from PIL import Image

from src.linfusion import LinFusion

In [None]:
# Text-to-image pipeline for DreamShaper 8: fp16 weight variant, loaded as
# float16 tensors and moved onto the GPU.
pipeline = StableDiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-8",
    variant="fp16",
    torch_dtype=torch.float16,
).to(torch.device("cuda"))

In [None]:
# Attach LinFusion (SD-1.5 checkpoint) to the text-to-image pipeline.
# NOTE(review): `construct_for` presumably patches `pipeline` in place; the
# returned object is only kept as a reference - confirm against LinFusion docs.
linfusion = LinFusion.construct_for(
    pipeline,
    pretrained_model_name_or_path="Yuanshi/LinFusion-1-5",
)

In [None]:
# Output directory for every stage's image. `exist_ok=True` makes the cell
# idempotent and avoids the check-then-create race of exists() + mkdir().
os.makedirs('results', exist_ok=True)

In [None]:
# Base text-to-image pass at 1024x512. Seeding torch's RNG and passing the
# returned generator makes this starting image reproducible.
generator = torch.manual_seed(3)
image = pipeline(
    "A photo of the Milky Way galaxy",
    generator=generator,
    width=1024,
    height=512,
).images[0]
image.save('results/output_1k.jpg')
image

In [None]:
# Switch to an img2img pipeline for the upscaling passes.
# Release the text-to-image pipeline, and the LinFusion object built for it,
# BEFORE loading the new weights. Otherwise the old pipeline stays referenced
# (and resident on the GPU) while the second copy loads, because `pipeline` is
# only rebound once `from_pretrained(...).to(...)` has finished.
# NOTE(review): `linfusion` may also hold references to the old UNet's modules.
# Assigning None (rather than `del`) keeps this cell safe to re-run.
pipeline = linfusion = None
gc.collect()
torch.cuda.empty_cache()

pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
    "Lykon/dreamshaper-8", torch_dtype=torch.float16, variant="fp16"
).to(torch.device("cuda"))

In [None]:
# Attach LinFusion (SD-1.5 checkpoint) to the freshly loaded img2img pipeline.
# NOTE(review): presumably patches `pipeline` in place, as for the
# text-to-image pipeline - confirm against LinFusion docs.
linfusion = LinFusion.construct_for(
    pipeline,
    pretrained_model_name_or_path="Yuanshi/LinFusion-1-5",
)

In [None]:
# 2k pass: upsample the previous result to 2048x1024 and refine it with
# img2img at strength 0.4, using the same seed.
# NOTE: reads and then rebinds `image`, so it expects `image` to still hold
# the 1k output from the cell above (re-running it alone chains off 2k).
generator = torch.manual_seed(3)
init_image = image.resize((2048, 1024))
image = pipeline(
    "A photo of the Milky Way galaxy",
    image=init_image,
    strength=0.4,
    generator=generator,
).images[0]
image.save('results/output_2k.jpg')
image

In [None]:
# Enable tiled VAE encoding/decoding for the 4k+ passes, so the VAE does not
# have to process the whole image at once. Inputs whose sides fit within
# VAE_TILE_SIZE pixels are still handled in a single tile.
# Tiling is enabled on the VAE itself: the pipeline-level
# `enable_vae_tiling()` only forwards to it, and newer diffusers releases flag
# that wrapper as deprecated.
VAE_TILE_SIZE = 2048
pipeline.vae.enable_tiling()
pipeline.vae.tile_sample_min_size = VAE_TILE_SIZE
# The latent-space threshold must mirror the pixel-space one. Derive it from
# the pipeline's VAE downsampling factor instead of hard-coding 8.
pipeline.vae.tile_latent_min_size = VAE_TILE_SIZE // pipeline.vae_scale_factor

In [None]:
# 4k pass: upsample to 4096x2048 and refine at strength 0.3, same seed.
# NOTE: expects `image` to hold the 2k output from the previous pass.
generator = torch.manual_seed(3)
init_image = image.resize((4096, 2048))
image = pipeline(
    "A photo of the Milky Way galaxy",
    image=init_image,
    strength=0.3,
    generator=generator,
).images[0]
image.save('results/output_4k.jpg')
image

In [None]:
# 8k pass: upsample to 8192x4096 and refine at strength 0.2, same seed.
# The result is only saved, not displayed, to keep the notebook output small.
# NOTE: expects `image` to hold the 4k output from the previous pass.
generator = torch.manual_seed(3)
init_image = image.resize((8192, 4096))
image = pipeline(
    "A photo of the Milky Way galaxy",
    image=init_image,
    strength=0.2,
    generator=generator,
).images[0]
image.save('results/output_8k.jpg')

In [None]:
# Reclaim memory before the 16k pass. First collect unreachable Python
# objects, so any CUDA tensors they kept alive are released. Then hand
# PyTorch's cached-but-unused CUDA blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# 16k pass: upsample to 16384x8192 and refine at strength 0.1, same seed.
# The result is only saved, not displayed.
# NOTE: expects `image` to hold the 8k output from the earlier pass.
generator = torch.manual_seed(3)
init_image = image.resize((16384, 8192))
image = pipeline(
    "A photo of the Milky Way galaxy",
    image=init_image,
    strength=0.1,
    generator=generator,
).images[0]
image.save('results/output_16k.jpg')