<a href="https://colab.research.google.com/github/M1croZavr/hackaton/blob/master/Image_reference_based_synthesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Клонирование репозитория и установка необходимых пакетов

In [None]:
# Clone the project repository; the StarGAN v2 model code is imported from the `hackaton` package in the next cells.
!git clone https://github.com/M1croZavr/hackaton.git

In [None]:
# Install munch — presumably a dependency of the cloned StarGAN v2 code (TODO confirm).
!pip install munch

In [None]:
import torch
import torchvision
from torchvision import transforms
import os
import gdown
import pathlib
import zipfile
from matplotlib import pyplot as plt
from hackaton.StarGAN.stargan_v2.core import model

import warnings
# Hide every warning to keep the notebook output uncluttered.
warnings.filterwarnings('ignore')

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

## Загрузка обученных весов моделей и тестовых данных

In [None]:
# Files fetched from Google Drive as (download URL, local file name):
# first the zip archive with the test images, then the trained model parameters.
downloads = (
    ('https://drive.google.com/uc?id=1Edn1eCVe_9_cacf13unumD1gRATcFD-j', 'test_images.zip'),
    ('https://drive.google.com/uc?id=1UF3wDFE30JWRE1Zidas0VTpYaJEckmRH', 'GAN_weights'),
)
for url, output in downloads:
    gdown.download(url, output, quiet=False)

In [None]:
# Unpack the downloaded test-image archive into the current working directory.
with zipfile.ZipFile('./test_images.zip', mode='r') as test_archive:
    test_archive.extractall(path='./')

In [None]:
# Checkpoint file written by the download cell (a single file, despite the "_dir" name).
weights_dir = pathlib.Path('.') / 'GAN_weights'
# Image folder presumably unpacked from test_images.zip — confirm against the archive contents.
test_images_dir = pathlib.Path('.') / 'validation_images'

In [None]:
# Build the three StarGAN v2 networks used for inference and move each to DEVICE.
# The hyperparameters must agree with the checkpoint; otherwise the (strict)
# load_state_dict calls below raise an error.
# nn.Module.to moves parameters in place and returns the module itself.
generator = model.Generator(img_size=256, style_dim=64, max_conv_dim=512, w_hpf=0).to(DEVICE)
mapper = model.MappingNetwork(latent_dim=16, style_dim=64, num_domains=3).to(DEVICE)
encoder = model.StyleEncoder(img_size=256, style_dim=64, num_domains=3, max_conv_dim=512).to(DEVICE)

# NOTE(review): torch.load unpickles the file, which here comes from a remote
# Google Drive link — load it only if that source is trusted (weights_only=True
# is an option on PyTorch versions that support it).
module_dict = torch.load(weights_dir, map_location=torch.device(DEVICE))
# Report which top-level entries the checkpoint contains ("Loaded models: ...").
print(f'Загруженные модели: {module_dict.keys()}')

# Checkpoint entry name -> network that receives those weights.
for checkpoint_key, network in (
    ('generator', generator),
    ('mapping_network', mapper),
    ('style_encoder', encoder),
):
    network.load_state_dict(module_dict[checkpoint_key])

## Синтез патологий на изображениях тестовой выборки: reference-based и latent-guided подходы

In [None]:
# Preprocessing pipeline: resize to 256x256, convert to a tensor in [0, 1],
# then map every channel to [-1, 1] via (v - 0.5) / 0.5.
resize = transforms.Resize([256, 256])
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform = transforms.Compose([resize, to_tensor, normalize])

# ImageFolder labels each image by the sub-directory it lives in.
dataset = torchvision.datasets.ImageFolder(test_images_dir, transform=transform)

In [None]:
# Pick a source image (to be translated) and a reference image (whose style is
# transferred) from the test set; the indices are arbitrary examples.
x, y = dataset[112]
x_ref, y_ref = dataset[42]

# `transform` normalizes images to [-1, 1], while imshow expects float RGB in
# [0, 1] and clips anything outside it (its warning is silenced by
# filterwarnings('ignore')). Undo the normalization out-of-place for display
# only: x and x_ref stay normalized because the networks consume them later.
plt.figure(figsize=(12, 9))
plt.subplot(1, 2, 1)
plt.title(f'Source image {dataset.classes[y]}')
plt.imshow(((x + 1) / 2).clamp(0, 1).permute(1, 2, 0))
plt.axis(False)

plt.subplot(1, 2, 2)
plt.title(f'Reference image {dataset.classes[y_ref]}')
plt.imshow(((x_ref + 1) / 2).clamp(0, 1).permute(1, 2, 0))
plt.axis(False);

In [None]:
# Add a batch dimension to both images and wrap the integer class labels into
# 1-element int64 tensors. Labels are placed on DEVICE together with the images,
# so every network input lives on the same device as the network weights.
# torch.tensor replaces the legacy torch.LongTensor type constructor.
x = x.unsqueeze(dim=0).to(DEVICE)
y = torch.tensor([y], dtype=torch.long, device=DEVICE)
x_ref = x_ref.unsqueeze(dim=0).to(DEVICE)
y_ref = torch.tensor([y_ref], dtype=torch.long, device=DEVICE)

In [None]:
# Reference-guided synthesis: the style encoder computes a style code from the
# reference image and its label, and the generator renders x with that style.
# eval() only toggles module flags, so it can be set before entering inference mode.
encoder.eval()
generator.eval()
with torch.inference_mode():
    style_code = encoder(x_ref, y_ref)
    generated_image = generator(x, style_code)

### Синтезированное изображение патологии

In [None]:
# Assumes the generator output is on the same [-1, 1] scale as its normalized
# inputs (TODO confirm). Map it back to [0, 1] because imshow clips float RGB
# data outside that range. The ops are out-of-place, which is required for
# tensors created under torch.inference_mode.
plt.figure(figsize=(9, 6))
plt.imshow(((generated_image.cpu().squeeze(dim=0) + 1) / 2).clamp(0, 1).permute(1, 2, 0))
plt.title('Image-to-image translation from referenced image')
plt.axis(False);

In [None]:
# Latent-guided synthesis: sample a random latent vector, let the mapping network
# turn it into a style code for label y_ref, then render x with that style.
mapper.eval()
generator.eval()
with torch.inference_mode():
    # 16 matches the latent_dim the mapping network was constructed with;
    # sampled on the CPU generator, then moved to DEVICE.
    latent_code = torch.randn(1, 16).to(DEVICE)
    mapped_style_code = mapper(latent_code, y_ref)
    generated_image = generator(x, mapped_style_code)

In [None]:
# Assumes the generator output is on the same [-1, 1] scale as its normalized
# inputs (TODO confirm). Map it back to [0, 1] because imshow clips float RGB
# data outside that range. The ops are out-of-place, which is required for
# tensors created under torch.inference_mode.
plt.figure(figsize=(9, 6))
plt.imshow(((generated_image.cpu().squeeze(dim=0) + 1) / 2).clamp(0, 1).permute(1, 2, 0))
# The style was sampled for label y_ref, so name that class instead of a fixed one.
plt.title(f'Image-to-image translation from {dataset.classes[y_ref.item()]} random latent code')
plt.axis(False);