# VAE-GAN 

## Преобразование одного изображения в другое

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os
import torchvision.transforms as T

## Предобработка датасета

In [2]:
import os
from PIL import Image

def crop_center_square(img: Image.Image) -> Image.Image:
    """Return the centred square crop of ``img``.

    The side of the square equals the shorter image dimension; the left/top
    offsets are computed with floor division.
    """
    w, h = img.size
    edge = w if w < h else h
    x0, y0 = (w - edge) // 2, (h - edge) // 2
    box = (x0, y0, x0 + edge, y0 + edge)
    return img.crop(box)

def process_dataset(input_dir, output_dir, size=128):
    """Center-crop and resize every image of ``input_dir`` into ``output_dir``.

    Files are picked by extension (.png, .jpg, .jpeg, .bmp, .tiff, matched
    case-insensitively) and written under the same file name as a
    ``size`` x ``size`` image. Images whose mode is neither RGB nor L are
    converted to RGB before resizing.

    Args:
        input_dir: Directory scanned (non-recursively) for images.
        output_dir: Destination directory; created when missing.
        size: Side length in pixels of the square output images.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    for filename in os.listdir(input_dir):
        if not filename.lower().endswith(extensions):
            continue

        input_path = os.path.join(input_dir, filename)
        output_path = os.path.join(output_dir, filename)

        with Image.open(input_path) as img:
            # Convert BEFORE resizing: Pillow falls back to NEAREST resampling
            # for palette/bilevel images, and palette or alpha modes cannot be
            # written to a .jpg target.
            source = img if img.mode in ('RGB', 'L') else img.convert('RGB')
            img_cropped = crop_center_square(source)
            img_resized = img_cropped.resize((size, size), Image.Resampling.LANCZOS)
            img_resized.save(output_path)

        print(f"Processed {filename}")

# Example usage: source folder with raw images and target folder for the crops.
input_folder = '/home/maksim/develops/python/develops_test/dataset_1/train/'
output_folder = '/home/maksim/develops/python/develops_test/dataset_1/train_2/'

# Left disabled; uncomment to (re)generate the 128x128 crops in output_folder.
# process_dataset(input_folder, output_folder, size=128)


## Функции потерь

In [3]:
# ----- Функции потерь -----

# VAE loss: Reconstruction + KL divergence
def vae_loss_function(recon_x, x, mu, logvar):
    """Compute the VAE objective for one batch.

    Args:
        recon_x: Reconstructed batch.
        x: Target batch, same shape as ``recon_x``.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        Tuple ``(total, reconstruction, kld)`` where ``total`` is the sum of
        the other two terms. Both terms are summed, not averaged.
    """
    # Summed squared error; BCE would be an alternative for binary data.
    reconstruction = F.mse_loss(recon_x, x, reduction='sum')
    # KL divergence to a standard normal prior, element-wise then summed.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * torch.sum(kl_terms)
    total = reconstruction + divergence
    return total, reconstruction, divergence

# GAN loss: binary cross-entropy evaluated on raw logits (the plain BCELoss
# variant below would need probabilities instead).
# adversarial_loss = nn.BCELoss()
adversarial_loss = nn.BCEWithLogitsLoss()


# def vae_loss_function(recon_x, x, mu, logvar):
#     # Используем MSE loss для восстановления
#     recon_loss = F.mse_loss(recon_x, x, reduction='mean')
#     # KL дивергенция среднее по батчу
#     KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
#     total_vae_loss = recon_loss + KLD
#     return total_vae_loss, recon_loss, KLD




In [4]:
def discriminator_hinge_loss(real_out, fake_out):
    """Hinge loss for the discriminator.

    Penalises real scores below +1 and fake scores above -1; each part is
    averaged over all elements before the two are added.
    """
    real_penalty = F.relu(1. - real_out).mean()
    fake_penalty = F.relu(1. + fake_out).mean()
    return real_penalty + fake_penalty



def generator_hinge_loss(fake_out):
    """Hinge loss for the generator: the negated mean discriminator score."""
    mean_score = fake_out.mean()
    return -mean_score


## Датасет

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os

# ----- Класс датасета: загрузка изображений из папки -----
class ImageFolderDataset(Dataset):
    """Unlabelled image dataset backed by a single flat directory.

    Args:
        root_dir: Directory whose .png/.jpg/.jpeg files (extension matched
            case-insensitively, non-recursive) make up the dataset.
        transform: Optional callable applied to every loaded PIL image.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.filenames = sorted(
            f for f in os.listdir(root_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if requested."""
        img_path = os.path.join(self.root_dir, self.filenames[idx])
        # convert() yields a fully loaded copy, so the file can be closed
        # right away instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

## Функция тренировки

In [6]:

# ----- Репараметризация -----
def reparameterize(mu, logvar):
    """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal.

    ``std`` is derived from the log-variance as ``exp(0.5 * logvar)``.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma


# ----- Функция тренировки одной эпохи -----
def train_epoch(encoder, decoder, discriminator, dataloader, optim_enc_dec, optim_disc, device):
    """Run one VAE-GAN training epoch and print the averaged losses.

    Every batch performs two updates: first the discriminator (real batch
    against reconstructions), then the encoder/decoder pair with the VAE loss
    plus a weighted adversarial term.

    Args:
        encoder: Module mapping an image batch to ``(mu, logvar)``.
        decoder: Module mapping latent vectors to images.
        discriminator: Module returning one logit per image.
        dataloader: Iterable of image batches; ``dataloader.dataset`` must
            support ``len()``.
        optim_enc_dec: Optimizer over the encoder and decoder parameters.
        optim_disc: Optimizer over the discriminator parameters.
        device: Device every batch is moved to.

    Returns:
        Tuple ``(avg_vae_loss, avg_gan_disc_loss, avg_gan_gen_loss)``. The VAE
        loss is averaged per sample, the two GAN losses per batch.
    """
    encoder.train()
    decoder.train()
    discriminator.train()

    total_vae_loss = 0
    total_recon_loss = 0
    total_kld_loss = 0
    total_gan_disc_loss = 0
    total_gan_gen_loss = 0

    for imgs in dataloader:
        imgs = imgs.to(device)

        batch_size = imgs.size(0)
        # Smoothed "real" target (0.9 instead of 1.0).
        real_labels = torch.full((batch_size, 1), 0.9, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # === Discriminator update ===
        optim_disc.zero_grad()
        # Real images
        real_out = discriminator(imgs)
        loss_real = adversarial_loss(real_out, real_labels)

        # Reconstructions act as the fake samples. No gradient flows back to
        # the encoder/decoder in this step, so skip building their autograd
        # graph; the values (and BatchNorm statistics updates) are unchanged.
        with torch.no_grad():
            mu, logvar = encoder(imgs)
            z = reparameterize(mu, logvar)
            fake_imgs = decoder(z)

        fake_out = discriminator(fake_imgs)
        loss_fake = adversarial_loss(fake_out, fake_labels)

        loss_disc = (loss_real + loss_fake) * 0.5
        loss_disc.backward()
        optim_disc.step()

        # === Encoder/decoder update (VAE + GAN) ===
        optim_enc_dec.zero_grad()

        mu, logvar = encoder(imgs)
        z = reparameterize(mu, logvar)
        recon_imgs = decoder(z)

        # VAE loss
        vae_loss, recon_loss, kld_loss = vae_loss_function(recon_imgs, imgs, mu, logvar)

        # Adversarial loss for the decoder: reconstructions should be judged
        # "real". The discriminator gradients accumulated here are cleared by
        # optim_disc.zero_grad() at the start of the next iteration.
        # NOTE(review): the smoothed 0.9 target is reused for the generator;
        # smoothing is usually applied to the discriminator only - confirm.
        gen_out = discriminator(recon_imgs)
        gan_gen_loss = adversarial_loss(gen_out, real_labels)

        # Weight of the adversarial term (candidates noted by the author:
        # 1e-4, 1e-3, 1e-2).
        alpha = 0.001
        total_gen_loss = vae_loss + alpha * gan_gen_loss

        total_gen_loss.backward()
        optim_enc_dec.step()

        total_vae_loss += vae_loss.item()
        total_recon_loss += recon_loss.item()
        total_kld_loss += kld_loss.item()
        total_gan_disc_loss += loss_disc.item()
        total_gan_gen_loss += gan_gen_loss.item()

    # VAE terms are sums over samples -> normalise by dataset size;
    # GAN terms are per-batch means -> normalise by number of batches.
    n = len(dataloader.dataset)

    avg_vae_loss = total_vae_loss / n
    avg_gan_disc_loss = total_gan_disc_loss / len(dataloader)
    avg_gan_gen_loss = total_gan_gen_loss / len(dataloader)

    print(f"VAE Loss: {total_vae_loss/n:.4f}, Recon: {total_recon_loss/n:.4f}, KLD: {total_kld_loss/n:.4f}")
    print(f"Discriminator Loss: {total_gan_disc_loss/len(dataloader):.4f}, Generator Loss (GAN): {total_gan_gen_loss/len(dataloader):.4f}")

    return avg_vae_loss, avg_gan_disc_loss, avg_gan_gen_loss




## Модель

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os

# ----- Архитектура VAE-GAN -----

class Encoder(nn.Module):
    """Convolutional encoder returning the posterior mean and log-variance.

    Four stride-2 convolutions halve the spatial size each time; the linear
    heads expect a 512x8x8 feature map, i.e. 128x128 inputs.

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # First stage has no batch norm: 3x128x128 -> 64x64x64.
        stages = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining stages keep halving the resolution: 64 -> 32 -> 16 -> 8.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

        flat_features = 512 * 8 * 8
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` for the image batch ``x``."""
        features = self.conv(x)
        flat = features.view(x.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors to 3x128x128 images.

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 512 * 8 * 8)

        blocks = []
        # Upsampling stages: 512x8x8 -> 256x16x16 -> 128x32x32 -> 64x64x64.
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            blocks.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Output stage: 3x128x128 with pixel values squashed into [0, 1].
        # (Author's note: use nn.Tanh() instead when inputs are normalized.)
        blocks.extend([
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ])
        self.deconv = nn.Sequential(*blocks)

    def forward(self, z):
        """Decode the latent batch ``z`` into images."""
        hidden = self.fc(z).view(z.size(0), 512, 8, 8)
        return self.deconv(hidden)


class Discriminator(nn.Module):
    """Convolutional critic producing one raw logit per 3x128x128 image.

    No sigmoid is applied at the end; the output is a logit.
    """

    def __init__(self):
        super().__init__()
        # First stage without batch norm: 3x128x128 -> 64x64x64.
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 128x32x32 -> 256x16x16 -> 512x8x8.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Classification head on the flattened feature map.
        layers.extend([
            nn.Flatten(),
            nn.Linear(512 * 8 * 8, 1),
        ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logit for every image in ``x``."""
        logits = self.net(x)
        return logits



## Обучение

In [8]:
# ----- Пример запуска -----
if __name__ == "__main__":
    from torchvision.datasets import ImageFolder

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model parameters
    latent_dim = 128
    batch_size = 64
    image_size = 128  # square 128x128 images are assumed

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])

    # ImageNet channel statistics; only used by the optional normalization
    # path (commented transform below and inv_normalize).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Alternative transform with normalization:
    # transform = transforms.Compose([
    #     transforms.Resize((image_size, image_size)),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=mean,  # ImageNet channel means
    #                         std=std)   # ImageNet channel standard deviations
    # ])

    # Inverse of Normalize(mean, std): x * std + mean.
    inv_normalize = T.Normalize(
        mean=[-m/s for m, s in zip(mean, std)],
        std=[1/s for s in std]
    )

    # dataset = ImageFolderDataset('/home/maksim/develops/python/develops_test/dataset/face/3/train/', transform=transform)
    dataset = ImageFolderDataset('/home/maksim/develops/python/develops_test/dataset_1/train_2/', transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    encoder = Encoder(latent_dim).to(device)
    decoder = Decoder(latent_dim).to(device)
    discriminator = Discriminator().to(device)

    # Resume from earlier checkpoints only when all of them exist, so the very
    # first run can start from scratch. map_location lets checkpoints saved on
    # a GPU load on a CPU-only machine.
    checkpoints = {
        'best_encoder.pth': encoder,
        'best_decoder.pth': decoder,
        'best_discriminator.pth': discriminator,
    }
    if all(os.path.exists(path) for path in checkpoints):
        for path, model in checkpoints.items():
            model.load_state_dict(torch.load(path, map_location=device))

    optim_enc_dec = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
    # NOTE(review): 1e-2 is 100x the encoder/decoder rate. In the recorded run
    # the discriminator loss sits flat at 0.1626 while the generator GAN loss
    # stays around 13 - consider lowering this learning rate.
    optim_disc = torch.optim.Adam(discriminator.parameters(), lr=1e-2)

    # save_image does not create missing directories.
    os.makedirs('photo', exist_ok=True)

    num_epochs = 150
    best_loss = float('inf')
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        avg_vae_loss, avg_gan_disc_loss, avg_gan_gen_loss = train_epoch(encoder, decoder, discriminator, dataloader, optim_enc_dec, optim_disc, device)

        # Checkpoint whenever the per-sample VAE loss improves.
        if avg_vae_loss < best_loss:
            best_loss = avg_vae_loss

            torch.save(encoder.state_dict(), 'best_encoder.pth')
            torch.save(decoder.state_dict(), 'best_decoder.pth')
            torch.save(discriminator.state_dict(), 'best_discriminator.pth')

            print(f"Сохранена лучшая модель с VAE loss {best_loss:.4f}")

        # Save a grid of originals (top row) and reconstructions (bottom row).
        # train_epoch switches the modules back to train mode next epoch.
        encoder.eval()
        decoder.eval()
        with torch.inference_mode():
            sample_imgs = next(iter(dataloader))
            sample_imgs = sample_imgs.to(device)
            mu, logvar = encoder(sample_imgs)
            z = reparameterize(mu, logvar)
            recon = decoder(z)
            # recon = inv_normalize(recon)  # only when normalization is used
            # sample_imgs = inv_normalize(sample_imgs)  # only when normalization is used

            comparison = torch.cat([sample_imgs[:8], recon[:8]])
            save_image(comparison.cpu(), f'photo/reconstruction_epoch_{epoch+1}.png', nrow=8)

Epoch 1/150
VAE Loss: 295.2565, Recon: 160.3926, KLD: 134.8639
Discriminator Loss: 0.1673, Generator Loss (GAN): 11.1835
Сохранена лучшая модель с VAE loss 295.2565
Epoch 2/150
VAE Loss: 295.0482, Recon: 160.1835, KLD: 134.8647
Discriminator Loss: 0.1626, Generator Loss (GAN): 12.7851
Сохранена лучшая модель с VAE loss 295.0482
Epoch 3/150
VAE Loss: 294.8541, Recon: 160.0889, KLD: 134.7651
Discriminator Loss: 0.1626, Generator Loss (GAN): 12.9411
Сохранена лучшая модель с VAE loss 294.8541
Epoch 4/150
VAE Loss: 295.0699, Recon: 160.2789, KLD: 134.7910
Discriminator Loss: 0.1626, Generator Loss (GAN): 13.4160
Epoch 5/150
VAE Loss: 295.2850, Recon: 160.3585, KLD: 134.9265
Discriminator Loss: 0.1626, Generator Loss (GAN): 13.2448
Epoch 6/150
VAE Loss: 295.2674, Recon: 160.4749, KLD: 134.7925
Discriminator Loss: 0.1626, Generator Loss (GAN): 13.1520
Epoch 7/150
VAE Loss: 295.0436, Recon: 160.1593, KLD: 134.8842
Discriminator Loss: 0.1626, Generator Loss (GAN): 13.0430
Epoch 8/150
VAE Loss:

# Использование модели

## Загрузка модели

In [9]:
# Resolve the device first so that checkpoints saved on a GPU can also be
# loaded on a CPU-only machine via map_location.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder(latent_dim=128)
decoder = Decoder(latent_dim=128)
encoder.load_state_dict(torch.load('best_encoder.pth', map_location=device))
decoder.load_state_dict(torch.load('best_decoder.pth', map_location=device))
# Inference only: freeze BatchNorm statistics.
encoder.eval()
decoder.eval()
encoder.to(device)
decoder.to(device)


Decoder(
  (fc): Linear(in_features=128, out_features=32768, bias=True)
  (deconv): Sequential(
    (0): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): Sigmoid()
  )
)

## Получение эмбеддингов

In [14]:
# Candidate input pair; the same two names are reassigned in the next cells.
path_img_1 = "/home/maksim/develops/python/develops_test/dataset/face/2/images/8000.jpg" 
path_img_2 = "/home/maksim/develops/python/develops_test/dataset/face/2/images/8009.jpg"

In [10]:


# Candidate input pair from the training folder (same names as the other cells).
path_img_1 = "/home/maksim/develops/python/develops_test/dataset/face/3/train/000002.jpg" 
path_img_2 = "/home/maksim/develops/python/develops_test/dataset/face/3/train/000008.jpg"

In [12]:
path_img_1 = "data/1.jpg"  # path to example image 1
path_img_2 = "data/2.jpg"  # path to example image 2

In [13]:
from PIL import Image
from torchvision import transforms

def crop_center_square(pil_img: Image.Image) -> Image.Image:
    """Return the centred square crop of ``pil_img``.

    The side of the square equals the shorter image dimension; the left/top
    offsets are computed with floor division.
    """
    w, h = pil_img.size
    edge = w if w < h else h
    x0, y0 = (w - edge) // 2, (h - edge) // 2
    return pil_img.crop((x0, y0, x0 + edge, y0 + edge))




transform = transforms.Compose([
    transforms.Resize((128, 128)),  # size the architecture was built for
    transforms.ToTensor(),
])


def _load_square_rgb(path):
    """Open ``path``, force 3-channel RGB and center-crop it to a square."""
    # The encoder's first convolution takes exactly 3 channels, so grayscale
    # or RGBA files must be converted. The context manager closes the file.
    with Image.open(path) as img:
        return crop_center_square(img.convert('RGB'))


square_img1 = _load_square_rgb(path_img_1)
square_img2 = _load_square_rgb(path_img_2)


img1 = transform(square_img1).unsqueeze(0).to(device)
img2 = transform(square_img2).unsqueeze(0).to(device)

with torch.no_grad():
    mu1, logvar1 = encoder(img1)
    z1 = mu1  # use mu only, for a deterministic interpolation
    mu2, logvar2 = encoder(img2)
    z2 = mu2

## Интерполяция

In [14]:
import numpy as np
import cv2
import imageio
import numpy as np

n_frames = 60  # number of frames in the video
alphas = np.linspace(0, 1, n_frames)

# Use separate names for the NumPy copies: overwriting z1/z2 with arrays makes
# this cell fail on a second run (ndarray has no .cpu()).
z1_np = z1.cpu().numpy()
z2_np = z2.cpu().numpy()
# Linear blend between the two latent codes, one entry per frame.
z_interp = np.stack([(1 - a) * z1_np + a * z2_np for a in alphas])
# Cast explicitly: depending on NumPy's promotion rules the blend can come out
# as float64, while the decoder is fed float32.
z_interp = torch.from_numpy(z_interp.astype(np.float32)).to(device)


## Генерация видео

In [23]:
with torch.no_grad():
    generated = decoder(z_interp.float()).cpu()



import imageio
from torchvision.utils import save_image

# None of the writers below creates the output directory.
os.makedirs('video', exist_ok=True)

# NCHW float images -> NHWC uint8 frames.
imgs_np = (generated.permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)

target_size = (256, 256)  # desired GIF size, (width, height) for cv2.resize

resized_frames = []
for img in imgs_np:
    # img is uint8, RGB
    img_resized = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
    resized_frames.append(img_resized)

resized_frames = np.stack(resized_frames)
imageio.mimsave('video/interpolation.gif', resized_frames, duration=0.1)

imageio.mimsave('video/interpolation.mp4', resized_frames, fps=10)


import cv2
import numpy as np

frames = []
for img_tensor in generated:
    img = img_tensor.cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
    img = (img * 255).astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # OpenCV writes BGR frames
    frames.append(img)

height, width, _ = frames[0].shape
# VideoWriter expects the frame size as (width, height); the swapped order
# only worked because the frames are square.
video = cv2.VideoWriter('video/morphing_video.avi', cv2.VideoWriter_fourcc(*'XVID'), 10, (width, height), isColor=True)

for frame in frames:
    video.write(frame)

video.release()






torch.Size([1, 3, 128, 128])
torch.Size([1, 3, 128, 128])
torch.Size([1, 3, 128, 128])
torch.Size([1, 3, 128, 128])


# Gradio

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os

# ----- Архитектура VAE-GAN -----

class Encoder(nn.Module):
    """Convolutional encoder returning the posterior mean and log-variance.

    Four stride-2 convolutions halve the spatial size each time; the linear
    heads expect a 512x8x8 feature map, i.e. 128x128 inputs.

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # First stage has no batch norm: 3x128x128 -> 64x64x64.
        stages = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining stages keep halving the resolution: 64 -> 32 -> 16 -> 8.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

        flat_features = 512 * 8 * 8
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` for the image batch ``x``."""
        features = self.conv(x)
        flat = features.view(x.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors to 3x128x128 images.

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 512 * 8 * 8)

        blocks = []
        # Upsampling stages: 512x8x8 -> 256x16x16 -> 128x32x32 -> 64x64x64.
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            blocks.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Output stage: 3x128x128 with pixel values squashed into [0, 1].
        # (Author's note: use nn.Tanh() instead when inputs are normalized.)
        blocks.extend([
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ])
        self.deconv = nn.Sequential(*blocks)

    def forward(self, z):
        """Decode the latent batch ``z`` into images."""
        hidden = self.fc(z).view(z.size(0), 512, 8, 8)
        return self.deconv(hidden)


class Discriminator(nn.Module):
    """Convolutional critic producing one raw logit per 3x128x128 image.

    No sigmoid is applied at the end; the output is a logit.
    """

    def __init__(self):
        super().__init__()
        # First stage without batch norm: 3x128x128 -> 64x64x64.
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 128x32x32 -> 256x16x16 -> 512x8x8.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Classification head on the flattened feature map.
        layers.extend([
            nn.Flatten(),
            nn.Linear(512 * 8 * 8, 1),
        ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logit for every image in ``x``."""
        logits = self.net(x)
        return logits



In [3]:
import gradio as gr
import torch
from PIL import Image
import numpy as np
import tempfile
import cv2
import os

# --- Model loading ---
# map_location lets the weights load on whichever device is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder(latent_dim=128)
decoder = Decoder(latent_dim=128)
encoder.load_state_dict(torch.load('data/encoder.pth', map_location=device))
decoder.load_state_dict(torch.load('data/decoder.pth', map_location=device))
encoder.eval()
decoder.eval()
encoder.to(device)
decoder.to(device)

# --- Transformations applied before feeding the model ---
# Per-channel statistics read by denormalize(); the matching normalization
# step in center_crop_and_preprocess() is commented out.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def center_crop_and_preprocess(img: Image.Image):
    """Turn a PIL image into a ``[1, 3, 128, 128]`` float tensor on ``device``.

    The image is converted to RGB, cropped to the centred square whose side
    is the shorter dimension, resized to 128x128 with Lanczos resampling and
    scaled to the ``[0, 1]`` range.

    Args:
        img: Input PIL image of any mode.

    Returns:
        Float32 tensor of shape ``[1, 3, 128, 128]`` on ``device``.
    """
    # Uploads may be grayscale, palette or RGBA; the encoder's first
    # convolution takes exactly three channels.
    img = img.convert('RGB')
    # Centred square along the shorter side, then resize.
    w, h = img.size
    side = min(w, h)
    left = (w - side) // 2
    top = (h - side) // 2
    img = img.crop((left, top, left + side, top + side))
    img = img.resize((128, 128), Image.Resampling.LANCZOS)
    img = np.array(img).astype(np.float32) / 255.0
    # Mean/std normalization is intentionally disabled:
    # img = (img - mean) / std
    img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)  # HWC -> 1xCxHxW
    return img.to(device)

def denormalize(img):
    """Undo per-channel mean/std normalization and clamp the result to [0, 1].

    Args:
        img: Tensor laid out as [B, C, H, W] with three channels.

    Returns:
        ``img * std + mean`` (module-level ``mean``/``std``), clamped to [0, 1].
    """
    channel_shape = (1, 3, 1, 1)
    offset = torch.tensor(mean, device=img.device).view(channel_shape)
    scale = torch.tensor(std, device=img.device).view(channel_shape)
    return (img * scale + offset).clamp(0, 1)

# --- Основная функция синтеза морф-видео ---
def morph_video(im1, im2, n_frames=60):
    """Render a latent-space morph between two images as an MP4 file.

    Both images are encoded, their latent means are blended linearly in
    ``n_frames`` steps, and every blend is decoded into one video frame.

    Args:
        im1: PIL image the morph starts from.
        im2: PIL image the morph ends at.
        n_frames: Number of interpolation steps, i.e. video frames.

    Returns:
        Path of the written video (``temp_videos/temp_video.mp4``).

    Raises:
        RuntimeError: If OpenCV cannot open the output file for writing.
    """
    # Preprocessing
    t1 = center_crop_and_preprocess(im1)
    t2 = center_crop_and_preprocess(im2)
    with torch.no_grad():
        # Only the posterior means are used; no sampling, so the result is
        # deterministic.
        mu1, _ = encoder(t1)
        mu2, _ = encoder(t2)
    z1 = mu1.cpu().numpy()
    z2 = mu2.cpu().numpy()
    alphas = np.linspace(0, 1, n_frames)
    z_interp = np.stack([(1 - a) * z1 + a * z2 for a in alphas])
    z_interp = torch.tensor(z_interp, dtype=torch.float32).to(device)
    with torch.no_grad():
        imgs = decoder(z_interp)
    # imgs = denormalize(imgs).cpu().numpy()  # only when normalization is used
    imgs = imgs.cpu().numpy()  # [N, 3, H, W]

    # Collect the video frames: CHW float RGB -> HWC uint8 BGR.
    frames = []
    for img in imgs:
        img = np.transpose(img, (1, 2, 0))
        img = (img * 255).astype(np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        frames.append(img)

    # Write the video into a local scratch directory.
    temp_dir = "temp_videos"
    os.makedirs(temp_dir, exist_ok=True)
    fname = os.path.join(temp_dir, "temp_video.mp4")

    print(fname)
    # Take the frame size from the decoded frames instead of hard-coding it;
    # VideoWriter expects (width, height).
    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(fname, fourcc, 10, (width, height), True)
    # A writer that failed to open silently drops every frame, so fail loudly.
    if not video.isOpened():
        raise RuntimeError(f"Could not open video writer for {fname}")
    try:
        for f in frames:
            video.write(f)
    finally:
        video.release()
    return fname

# --- Gradio interface setup with default images ---
default1 = "data/1.jpg"  # path to example image 1
default2 = "data/2.jpg"  # path to example image 2

def gradio_wrapper(img1, img2):
    """Gradio callback: build the morph video for the two uploaded images.

    Args:
        img1: First PIL image, or ``None`` when the input was cleared.
        img2: Second PIL image, or ``None`` when the input was cleared.

    Returns:
        Path of the generated video file.

    Raises:
        gr.Error: If either image is missing; Gradio shows it to the user.
    """
    # A cleared image component is passed as None, which would otherwise
    # surface as an opaque AttributeError inside the preprocessing.
    if img1 is None or img2 is None:
        raise gr.Error("Загрузите оба изображения.")
    return morph_video(img1, img2)




# Two-column layout: inputs and the button on the left, the video on the right.
with gr.Blocks() as demo:
    gr.Markdown("### Модель морфинга между двумя изображениями VAE-GAN")
    with gr.Row():
        with gr.Column(scale=1):  # left column - 1/3 of the width
            image1 = gr.Image(label="Изображение 1", value=default1, type="pil", show_label=True, elem_id="img1_small")
            image2 = gr.Image(label="Изображение 2", value=default2, type="pil", show_label=True, elem_id="img2_small")
            generate_btn = gr.Button("Создать переход")
        with gr.Column(scale=2):  # right column - 2/3 of the width
            output_video = gr.Video(label="Морф-видео", format="mp4", autoplay=True, height=512, width=512)

    
    # Clicking the button runs gradio_wrapper and shows the returned file.
    generate_btn.click(gradio_wrapper, inputs=[image1, image2], outputs=output_video)

demo.launch()




* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.




temp_videos/temp_video.mp4


