In [None]:
%pip install lpips -q

import os
import sys
sys.path.append('/Users/Serebryakova/Desktop/utils')

import torch
from lpips import LPIPS
from torchvision import transforms
from munch import Munch
from tqdm.auto import trange
import matplotlib.pyplot as plt
import numpy as np
from celeba import CelebADataset

In [None]:
compute_device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

# Seed the RNGs the notebook uses so preview sampling, weight init and
# shuffling are reproducible across Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

params = Munch()
params.image_size = 256
params.batch = 32
params.domains = 40   # number of target domains / attributes
params.lr = 0.0002
params.latent = 64

# Scale pixels to [-1, 1] (per-channel mean 0.5, std 0.5). The generator ends in
# Tanh, so real images must share that range for the L1 reconstruction loss to
# be reachable; LPIPS also expects inputs in [-1, 1].
data_transforms = transforms.Compose([
    transforms.Resize(params.image_size),
    transforms.CenterCrop(params.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Dataset location; override with the CELEBA_ROOT environment variable.
CELEBA_ROOT = os.environ.get('CELEBA_ROOT', '/Users/Serebryakova/Desktop/celeba')
face_data = CelebADataset(root_dir=CELEBA_ROOT, transform=data_transforms)

# NOTE(review): worker processes are disabled on CUDA and enabled elsewhere,
# which looks inverted relative to the usual setup — confirm this is intentional.
loader_workers = 0 if compute_device.type == 'cuda' else 2
# Pin host memory only when batches are copied to a CUDA device.
pin_memory = compute_device.type == 'cuda'

data_loader = torch.utils.data.DataLoader(
    face_data,
    batch_size=params.batch,
    num_workers=loader_workers,
    pin_memory=pin_memory,
    shuffle=True
)

# Preview 9 distinct random faces.
fig, axes = plt.subplots(3, 3, figsize=(15, 15))
preview_ids = np.random.choice(len(face_data), size=9, replace=False)
for ax, sample_id in zip(axes.flat, preview_ids):
    sample, _ = face_data[sample_id]
    # Per-image min-max rescale to [0, 1], for display only (guard constant images).
    sample = (sample - sample.min()) / (sample.max() - sample.min()).clamp_min(1e-8)
    ax.imshow(sample.permute(1, 2, 0).cpu().numpy())
    ax.axis('off')
plt.show()

In [None]:
class ResidualLayer(torch.nn.Module):
    """Residual block: (conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm) + identity.

    Channel count and spatial size are preserved (3x3 kernels, padding 1,
    ``channels`` in and out), so the skip connection is a plain addition.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # Same-size 3x3 convolution with matching in/out channels.
            return torch.nn.Conv2d(channels, channels, 3, padding=1)

        # Registered as `core` so parameter names stay core.0.* / core.3.*.
        self.core = torch.nn.Sequential(
            conv3x3(),
            torch.nn.InstanceNorm2d(channels),
            torch.nn.ReLU(),
            conv3x3(),
            torch.nn.InstanceNorm2d(channels),
        )

    def forward(self, input):
        residual = self.core(input)
        return input + residual

class ImageGenerator(torch.nn.Module):
    """Attribute-conditioned image-to-image generator.

    The target attribute vector is tiled over every pixel and concatenated to
    the image as extra input channels. The network then applies:
    - a 7x7 stem;
    - two stride-2 downsampling convs;
    - six ``ResidualLayer`` blocks;
    - two stride-2 transposed convs;
    - a 7x7 projection back to ``channels`` followed by Tanh, so outputs lie in [-1, 1].

    For H and W divisible by 4, the output keeps the input's spatial size.

    Args:
        channels: number of image channels (input and output).
        attributes: length of the attribute vector passed to ``forward``.
        base: channel width after the stem; doubled at each downsampling step.
    """

    def __init__(self, channels=3, attributes=40, base=64):
        super().__init__()

        def conv_norm_relu(conv, out_channels):
            # Conv -> InstanceNorm -> ReLU as a list, so it splices into one flat Sequential.
            return [conv, torch.nn.InstanceNorm2d(out_channels), torch.nn.ReLU()]

        layers = conv_norm_relu(
            torch.nn.Conv2d(channels + attributes, base, 7, padding=3), base)

        width = base
        for _ in range(2):  # encoder: halve H/W, double channels
            layers += conv_norm_relu(torch.nn.Conv2d(width, width * 2, 4, 2, 1), width * 2)
            width *= 2

        layers += [ResidualLayer(width) for _ in range(6)]

        for _ in range(2):  # decoder: double H/W, halve channels
            layers += conv_norm_relu(
                torch.nn.ConvTranspose2d(width, width // 2, 4, 2, 1), width // 2)
            width //= 2

        layers += [torch.nn.Conv2d(width, channels, 7, padding=3), torch.nn.Tanh()]
        # One flat Sequential named `layers`, so state_dict keys match layers.<idx>.*.
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x, a):
        h, w = x.shape[2], x.shape[3]
        # Broadcast each sample's attribute vector into constant per-pixel planes.
        attr_planes = a.view(a.size(0), -1, 1, 1).expand(-1, -1, h, w)
        return self.layers(torch.cat([x, attr_planes], dim=1))

class ImageDiscriminator(torch.nn.Module):
    """Convolutional critic with a realness head and an attribute-classification head.

    Six 4x4 stride-2 convs (LeakyReLU 0.01) shrink H and W by 64 while doubling
    channels after the first layer. Two 3x3 heads read the final feature map:
    - ``validator`` gives one realness score per spatial location, flattened to (N, locations);
    - ``classifier`` gives per-attribute logits averaged over space, shape (N, attributes).

    Args:
        channels: number of input image channels.
        attributes: number of attribute logits produced by the classifier head.
        base: channel width of the first conv layer.
    """

    def __init__(self, channels=3, attributes=40, base=64):
        super().__init__()
        widths = [base * 2 ** i for i in range(6)]  # base, 2*base, ..., 32*base

        layers = [torch.nn.Conv2d(channels, widths[0], 4, 2, 1), torch.nn.LeakyReLU(0.01)]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [torch.nn.Conv2d(c_in, c_out, 4, 2, 1), torch.nn.LeakyReLU(0.01)]
        self.net = torch.nn.Sequential(*layers)

        top = widths[-1]
        self.validator = torch.nn.Conv2d(top, 1, 3, 1, 1)
        self.classifier = torch.nn.Conv2d(top, attributes, 3, 1, 1)

    def forward(self, x):
        h = self.net(x)
        return self.validator(h).flatten(1), self.classifier(h).mean(dim=(2, 3))

# Build the models for the configured number of attribute domains.
models = Munch()
models.generator = ImageGenerator(attributes=params.domains).to(compute_device)
models.discriminator = ImageDiscriminator(attributes=params.domains).to(compute_device)

NUM_EPOCHS = 5
LAMBDA_REC = 10  # weight of the cycle-reconstruction term in the generator loss

adv_loss = torch.nn.MSELoss()
cls_loss = torch.nn.BCEWithLogitsLoss()
rec_loss = torch.nn.L1Loss()

gen_optim = torch.optim.Adam(models.generator.parameters(), lr=params.lr, betas=(0.5, 0.999))
disc_optim = torch.optim.Adam(models.discriminator.parameters(), lr=params.lr, betas=(0.5, 0.999))

# Explicit train mode so re-running this cell after an eval cell behaves the same.
models.generator.train()
models.discriminator.train()

for epoch in trange(NUM_EPOCHS, desc='epochs'):
    for batch, labels in data_loader:
        batch = batch.to(compute_device)
        # BCEWithLogitsLoss needs float targets; assumes 0/1 attribute labels — TODO confirm CelebADataset.
        labels = labels.float().to(compute_device)
        # Target domains: the batch's own attribute vectors, shuffled across samples.
        shuffled = labels[torch.randperm(labels.size(0))]

        # Discriminator step: real -> 1, fake -> 0, and classify the real images' attributes.
        disc_optim.zero_grad()
        real_valid, real_cls = models.discriminator(batch)
        loss_real = adv_loss(real_valid, torch.ones_like(real_valid))
        loss_cls = cls_loss(real_cls, labels)

        generated = models.generator(batch, shuffled)
        # detach(): the discriminator update must not backpropagate into the generator.
        fake_valid, _ = models.discriminator(generated.detach())
        loss_fake = adv_loss(fake_valid, torch.zeros_like(fake_valid))

        total_disc = loss_real + loss_fake + loss_cls
        total_disc.backward()
        disc_optim.step()

        # Generator step. Gradients also reach the discriminator's .grad here;
        # they are discarded by disc_optim.zero_grad() at the next iteration.
        gen_optim.zero_grad()
        fake_valid, fake_cls = models.discriminator(generated)
        loss_gen = adv_loss(fake_valid, torch.ones_like(fake_valid))
        loss_cls_gen = cls_loss(fake_cls, shuffled)

        # Cycle consistency: translating back to the original labels should recover the input.
        reconstructed = models.generator(generated, labels)
        loss_rec = rec_loss(reconstructed, batch)

        total_gen = loss_gen + loss_cls_gen + LAMBDA_REC * loss_rec
        total_gen.backward()
        gen_optim.step()

# Measure how much of each source face survives attribute translation: LPIPS
# distance between an input and its translation to shuffled target labels (the
# same target-sampling scheme the training loop uses). Lower = closer to input.
NUM_EVAL_BATCHES = 100
LPIPS_THRESHOLD = 1.3

# LPIPS expects both images scaled to [-1, 1]. The generator's Tanh output is in
# that range; the inputs are only if the dataset normalization maps them there.
perceptual_loss = LPIPS().to(compute_device)
models.generator.eval()


def _endless_batches(loader):
    """Yield batches from `loader` indefinitely, starting a new pass when one runs out.

    Raises:
        ValueError: if a full pass over `loader` yields no batches.
    """
    while True:
        produced = False
        for item in loader:
            produced = True
            yield item
        if not produced:
            raise ValueError("data loader yielded no batches")


eval_batches = _endless_batches(data_loader)

test_results = []
with torch.no_grad():
    for _ in trange(NUM_EVAL_BATCHES):
        origin, origin_labels = next(eval_batches)
        origin = origin.to(compute_device)
        origin_labels = origin_labels.float().to(compute_device)
        target_labels = origin_labels[torch.randperm(origin_labels.size(0))]
        synthetic = models.generator(origin, target_labels)
        # LPIPS yields one distance per sample; average them into one score per batch.
        test_results.append(perceptual_loss(synthetic, origin).mean().item())

mean_lpips = float(np.mean(test_results))
print("Average LPIPS:", mean_lpips)
assert mean_lpips < LPIPS_THRESHOLD, f"mean LPIPS {mean_lpips:.3f} >= {LPIPS_THRESHOLD}"

# Show one translated face from the last evaluation batch; Tanh output mapped to [0, 1].
result = synthetic
preview = ((result[0] + 1) / 2).clamp(0, 1)
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(preview.permute(1, 2, 0).cpu().numpy())
ax.axis('off')
plt.show()

In [None]:
def attribute_morph(image, start_attr, end_attr, steps=7):
    """Plot generator outputs while linearly blending between two attribute vectors.

    Frame i uses ``start_attr * (1 - t_i) + end_attr * t_i``, with ``t_i``
    evenly spaced in [0, 1].

    Args:
        image: a single image tensor of shape (C, H, W), preprocessed like the
            training data.
        start_attr: attribute vector for the first frame (one entry per domain).
        end_attr: attribute vector for the last frame; same length as ``start_attr``.
        steps: number of frames, including both endpoints. Must be >= 1.

    Raises:
        ValueError: if ``steps`` < 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    device = compute_device
    repeated = image.unsqueeze(0).repeat(steps, 1, 1, 1).to(device)
    # One blend weight per frame, as a column so it broadcasts across attributes.
    alpha = torch.linspace(0, 1, steps, device=device).view(-1, 1)
    # Attributes may arrive as integer CPU tensors straight from the dataset;
    # cast and move them so the blend runs on the same device as `alpha`.
    start = start_attr.float().to(device).view(1, -1)
    end = end_attr.float().to(device).view(1, -1)
    mixed = start * (1 - alpha) + end * alpha

    with torch.no_grad():
        outputs = models.generator(repeated, mixed)
    # The generator ends in Tanh ([-1, 1]); map to [0, 1] for imshow.
    frames = ((outputs + 1) / 2).clamp(0, 1)

    # squeeze=False keeps a 2-D Axes array, so steps == 1 also works.
    fig, axs = plt.subplots(1, steps, figsize=(20, 8), squeeze=False)
    for ax, frame in zip(axs[0], frames):
        ax.imshow(frame.permute(1, 2, 0).cpu().numpy())
        ax.axis('off')
    plt.show()

# Morph the first face from its own attributes to a copy with attribute 0 flipped.
sample_image, sample_attr = face_data[0]
attr_a = sample_attr.clone()
attr_b = attr_a.clone()
attr_b[0] = 1 - attr_a[0]  # assumes binary 0/1 attribute labels — TODO confirm
attribute_morph(sample_image, attr_a, attr_b, steps=10)