In [None]:
# Mount Google Drive into the Colab filesystem so files stored on Drive can be read
# by later cells. NOTE(review): presumably saved_model_dir points under this mount — confirm.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!pip install git+https://github.com/huggingface/diffusers
!pip install accelerate wand
!pip install -r https://raw.githubusercontent.com/huggingface/diffusers/main/examples/text_to_image/requirements.txt

!accelerate config default

In [None]:
import os
import safetensors.torch
from diffusers import DiffusionPipeline, UNet2DConditionModel
from diffusers.models.lora import LoRACompatibleLinear
from peft import LoraConfig
import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module

In [None]:
    # Rebuild the LoRA-adapted UNet used in training, wrap it in a diffusion pipeline,
    # then restore the trained weights from the accelerate checkpoint.
    #
    # saved_model_dir is used twice below: as the pipeline location for
    # DiffusionPipeline.from_pretrained and as the checkpoint directory for
    # accelerator.load_state.
    saved_model_dir = ''  # insert path for trained generation model, along with the checkpoint that performs best
    BASE_MODEL_ID = 'stable-diffusion-v1-5/stable-diffusion-v1-5'

    # Fail fast: accelerator.load_state needs an existing local directory, and checking
    # here avoids downloading the base UNet only to fail later on an unset path.
    if not os.path.isdir(saved_model_dir):
        raise FileNotFoundError(
            f"saved_model_dir={saved_model_dir!r} is not an existing directory; "
            "set it to the trained model/checkpoint directory before running this cell."
        )

    accelerator = Accelerator()  # (mixed_precision=None)

    # Base UNet with every parameter frozen (inference only).
    unet = UNet2DConditionModel.from_pretrained(BASE_MODEL_ID, subfolder="unet")
    unet.requires_grad_(False)

    # LoRA adapter on the attention projections.
    # NOTE(review): presumably these hyperparameters must match the training run so the
    # checkpoint tensors fit the adapter — confirm against the training configuration.
    unet_lora_config = LoraConfig(
        r=4,
        lora_alpha=4,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    unet.add_adapter(unet_lora_config)
    # Register the UNet with accelerate (this also places it on accelerator.device)
    # so that load_state below restores its weights.
    unet = accelerator.prepare(unet)

    # Build the pipeline around the prepared UNet and put it on the same device
    # accelerate chose, rather than hard-coding "cuda".
    pipeline = DiffusionPipeline.from_pretrained(
        saved_model_dir,
        unet=unet,
        torch_dtype=torch.float32,
    ).to(accelerator.device)
    # Turn off the pipeline's safety checker for generation.
    pipeline.safety_checker = None
    pipeline.requires_safety_checker = False

    # Restore the trained weights into the prepared `unet` (the same object the pipeline
    # was built with). Essential: without this call the pipeline would sample with the
    # base UNet plus the freshly initialized, untrained LoRA adapter added above.
    accelerator.load_state(saved_model_dir)

In [None]:
from IPython.display import display
import matplotlib.pyplot as plt

import math
import textwrap

# One prompt per tissue class. NOTE(review): the template presumably mirrors the
# captions used during training — confirm.
TISSUE_CLASSES = [
    "adipose",
    "mucus",
    "cancer-associated stroma",
    "smooth muscle",
    "colorectal adenocarcinoma epithelium",
    "lymphocytes",
    "debris",
    "background",
    "normal colon mucosa",
]
prompts = [f"a histopathological image of an area with {tissue}" for tissue in TISSUE_CLASSES]

SEED = 0     # base seed; prompt i is sampled with SEED + i so every panel is reproducible
N_COLS = 3   # grid width; the number of rows follows from len(prompts)
n_rows = math.ceil(len(prompts) / N_COLS)

# squeeze=False keeps `axes` 2-D even when there is a single row.
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(4 * N_COLS, 4 * n_rows), squeeze=False)

for i, prompt in enumerate(prompts):
    # A CPU generator gives the same initial noise regardless of the pipeline's device.
    generator = torch.Generator(device="cpu").manual_seed(SEED + i)
    image = pipeline(prompt, generator=generator).images[0]
    ax = axes[i // N_COLS, i % N_COLS]
    ax.imshow(image)
    # Wrap long prompts so titles fit above their panel instead of overlapping neighbours.
    ax.set_title(textwrap.fill(prompt, width=30), fontsize=10)

# Hide ticks/frames on every panel, including unused cells when len(prompts) % N_COLS != 0.
for ax in axes.flat:
    ax.axis('off')

fig.tight_layout()
plt.show()