In [8]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Загрузка предобученной модели VAE (свою можно загрузить через .pth)
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(32*7*7, 64)  # Уменьшенный латентный размер
        self.fc_var = nn.Linear(32*7*7, 64)
        self.decoder = nn.Sequential(
            nn.Linear(64, 32*7*7),
            nn.Unflatten(1, (32, 7, 7)),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)
    
    def decode(self, z):
        return self.decoder(z)


from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
pipe.to("cuda")


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Fetching 18 files: 100%|██████████| 18/18 [13:36<00:00, 45.38s/it]
Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  4.94it/s]


StableDiffusionXLPipeline {
  "_class_name": "StableDiffusionXLPipeline",
  "_diffusers_version": "0.32.2",
  "_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
  "feature_extractor": [
    null,
    null
  ],
  "force_zeros_for_empty_prompt": true,
  "image_encoder": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "EulerDiscreteScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "text_encoder_2": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_2": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [9]:
def preprocess_image(path, size=(28, 28)):
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
    _, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
    img = img.astype(np.float32) / 255.0  # Нормализация [0, 1]
    return torch.from_numpy(img).unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]

img1 = preprocess_image("image_test/image1.png")
img2 = preprocess_image("image_test/image2.png")

In [11]:
def interpolate(model, img1, img2, steps=10):
    mu1, _ = model.encode(img1)
    mu2, _ = model.encode(img2)
    
    interpolated = []
    for alpha in np.linspace(0, 1, steps):
        z = (1 - alpha) * mu1 + alpha * mu2
        decoded = model.decode(z).squeeze().detach().numpy()
        decoded = (decoded * 255).astype(np.uint8)  # [0,1] -> [0,255]
        interpolated.append(decoded)
    return interpolated

results = interpolate(vae, img1, img2, steps=10)

RuntimeError: Input type (float) and bias type (struct c10::Half) should be the same

In [None]:
plt.figure(figsize=(15, 3))
for i, img in enumerate(results):
    plt.subplot(1, len(results), i+1)
    plt.imshow(img, cmap='gray')
    plt.title(f"{i*10}%")
    plt.axis('off')
plt.savefig('interpolation.png')
plt.show()

# Сохранение каждого кадра
# for i, img in enumerate(results):
#     cv2.imwrite(f"pipe_step_{i}.png", img)