In [1]:
import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
from tqdm.auto import tqdm
import glob

In [2]:
# Adapted from https://github.com/drorsimon/image_barycenters

# Seed both RNGs up front so weight initialisation and noise sampling repeat.
torch.manual_seed(0)
np.random.seed(0)

# Enable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

# Use CUDA device index 2 when CUDA is present, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = "cuda:2" if use_gpu else "cpu"

batch_size = 64

In [3]:
class myCelebDataset(Dataset):
    """Unlabelled dataset of the 3-channel images found in ``./<root_name>/``.

    Every item is one image, resized to 64x64 and converted to a tensor.

    Args:
        root_name: Directory (relative to the working directory) whose
            ``*.<image_type>`` files are collected.
        normalize: If true, ``ToTensor`` is followed by a per-channel
            ``Normalize`` with mean 0.5 and std 0.5; otherwise only
            ``ToTensor`` is applied.
        image_type: File extension, without the dot, used in the glob pattern.
    """

    def __init__(self, root_name: str = "img_align_celeba", normalize = True, image_type="jpg"):
        # Sort the matches so that an index refers to the same file on every
        # run; glob itself returns files in arbitrary order.
        self.img_paths = []
        for path in sorted(glob.glob("./"+root_name+"/*."+image_type)):
            # Image.open is lazy, so counting bands does not decode the pixel
            # data; the context manager closes the file handle again.
            # Only 3-band images are kept (e.g. greyscale files are skipped).
            with Image.open(path) as img_PIL:
                if len(img_PIL.getbands()) == 3:
                    self.img_paths.append(path)

        # Placeholders; nothing in this class fills them in.
        self.mean_rgb = None
        self.std_rgb = None

        steps = [transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]))
        self.transforms = transforms.Compose(steps)

    def __getitem__(self, index):
        """Load the image at ``index``, resize it to 64x64 and apply the transforms."""
        img_path = self.img_paths[index]
        with Image.open(img_path) as img_PIL:
            # 64x64 matches the setting of
            # https://github.com/drorsimon/image_barycenters/blob/master/generate_h5.py
            # LANCZOS is the filter that older Pillow also exposed as
            # Image.ANTIALIAS; that alias no longer exists in Pillow >= 10.
            img_resized = img_PIL.resize((64, 64), Image.LANCZOS)
        return self.transforms(img_resized)

    def __len__(self):
        """Number of images that passed the 3-channel filter."""
        return len(self.img_paths)

In [None]:
def _seed_worker(worker_id):
    """Seed NumPy inside one DataLoader worker, with a distinct fixed seed per worker."""
    np.random.seed(1234 + worker_id)


trainset = myCelebDataset()
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    # Do not ask for more worker processes than the machine has CPUs.
    num_workers=min(72, os.cpu_count() or 1),
    # Pass the function itself: the loader calls it with the worker id in each
    # worker process. Writing `np.random.seed(1234)` here would run the seeding
    # once in the main process and hand the loader its return value, None.
    worker_init_fn=_seed_worker,
)


In [None]:
class Generator(torch.nn.Module):
    """Transposed-convolution generator producing 3-channel images in [-1, 1].

    The first hidden layer uses kernel 4 / stride 1 / no padding; every later
    hidden layer and the output layer use kernel 4 / stride 2 / padding 1.
    Hidden layers are followed by BatchNorm2d and ReLU, the output by Tanh.
    With a ``(N, latent_dim, 1, 1)`` input and four hidden layers the result
    is ``(N, 3, 64, 64)``.

    Args:
        latent_dim: Channel count of the latent input tensor.
        num_filters: Non-empty sequence with the output channels of each
            hidden layer, in order.

    Raises:
        ValueError: If ``num_filters`` is empty.
    """

    def __init__(self, latent_dim, num_filters):
        super(Generator, self).__init__()
        if len(num_filters) == 0:
            raise ValueError("num_filters must contain at least one entry")

        # Hidden layers. Sub-module names (deconvN / bnN / actN) are part of
        # the saved state_dict keys and must stay as they are.
        self.hidden_layer = torch.nn.Sequential()
        in_channels = latent_dim
        for idx, out_channels in enumerate(num_filters, start=1):
            if idx == 1:
                deconv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=1, padding=0)
            else:
                deconv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
            self._init_weights(deconv)

            self.hidden_layer.add_module('deconv%d' % idx, deconv)
            self.hidden_layer.add_module('bn%d' % idx, torch.nn.BatchNorm2d(out_channels))
            self.hidden_layer.add_module('act%d' % idx, torch.nn.ReLU())
            in_channels = out_channels

        # Output layer: last hidden width -> 3 image channels, squashed by Tanh.
        self.output_layer = torch.nn.Sequential()
        out = torch.nn.ConvTranspose2d(in_channels, 3, kernel_size=4, stride=2, padding=1)
        self._init_weights(out)
        self.output_layer.add_module('out', out)
        self.output_layer.add_module('act', torch.nn.Tanh())

    @staticmethod
    def _init_weights(layer):
        """Initialise weights from N(0, 0.02) and set the bias to zero."""
        torch.nn.init.normal_(layer.weight, mean=0.0, std=0.02)
        torch.nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        """Run the hidden stack and then the output layer on ``x``."""
        x = self.hidden_layer(x)
        out = self.output_layer(x)
        return out

# Discriminator model
class Discriminator(torch.nn.Module):
    """Convolutional discriminator mapping 3-channel images to a probability.

    Every hidden layer is a kernel 4 / stride 2 / padding 1 convolution
    followed by LeakyReLU(0.2); all but the first also have BatchNorm2d.
    The output layer is a kernel 4 / stride 1 / no padding convolution to one
    channel followed by Sigmoid. With a ``(N, 3, 64, 64)`` input and four
    hidden layers the result is ``(N, 1, 1, 1)``.

    Args:
        num_filters: Non-empty sequence with the output channels of each
            hidden layer, in order.

    Raises:
        ValueError: If ``num_filters`` is empty.
    """

    def __init__(self, num_filters):
        super(Discriminator, self).__init__()
        if len(num_filters) == 0:
            raise ValueError("num_filters must contain at least one entry")

        # Hidden layers. Sub-module names (convN / bnN / actN) are part of the
        # saved state_dict keys and must stay as they are.
        self.hidden_layer = torch.nn.Sequential()
        in_channels = 3
        for idx, out_channels in enumerate(num_filters, start=1):
            conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
            self._init_weights(conv)
            self.hidden_layer.add_module('conv%d' % idx, conv)

            # No batch normalisation directly on the first layer.
            if idx > 1:
                self.hidden_layer.add_module('bn%d' % idx, torch.nn.BatchNorm2d(out_channels))

            self.hidden_layer.add_module('act%d' % idx, torch.nn.LeakyReLU(0.2))
            in_channels = out_channels

        # Output layer: last hidden width -> single channel, then Sigmoid.
        self.output_layer = torch.nn.Sequential()
        out = torch.nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=0)
        self._init_weights(out)
        self.output_layer.add_module('out', out)
        self.output_layer.add_module('act', torch.nn.Sigmoid())

    @staticmethod
    def _init_weights(layer):
        """Initialise weights from N(0, 0.02) and set the bias to zero."""
        torch.nn.init.normal_(layer.weight, mean=0.0, std=0.02)
        torch.nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        """Run the hidden stack and then the output layer on ``x``."""
        x = self.hidden_layer(x)
        out = self.output_layer(x)
        return out

# Encoder model
class Encoder(torch.nn.Module):
    """Convolutional encoder mapping 3-channel images to a latent tensor.

    The hidden stack has the same layout as the discriminator's: kernel 4 /
    stride 2 / padding 1 convolutions with LeakyReLU(0.2), and BatchNorm2d on
    all but the first. The output layer is a kernel 4 / stride 1 / no padding
    convolution to ``latent_dim`` channels followed by BatchNorm2d, with no
    non-linearity. With a ``(N, 3, 64, 64)`` input and four hidden layers the
    result is ``(N, latent_dim, 1, 1)``.

    Args:
        num_filters: Non-empty sequence with the output channels of each
            hidden layer, in order.
        latent_dim: Channel count of the produced latent tensor.

    Raises:
        ValueError: If ``num_filters`` is empty.
    """

    def __init__(self, num_filters, latent_dim):
        super(Encoder, self).__init__()
        if len(num_filters) == 0:
            raise ValueError("num_filters must contain at least one entry")

        # Hidden layers. Sub-module names (convN / bnN / actN) are part of the
        # saved state_dict keys and must stay as they are.
        self.hidden_layer = torch.nn.Sequential()
        in_channels = 3
        for idx, out_channels in enumerate(num_filters, start=1):
            conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
            self._init_weights(conv)
            self.hidden_layer.add_module('conv%d' % idx, conv)

            # No batch normalisation directly on the first layer.
            if idx > 1:
                self.hidden_layer.add_module('bn%d' % idx, torch.nn.BatchNorm2d(out_channels))

            self.hidden_layer.add_module('act%d' % idx, torch.nn.LeakyReLU(0.2))
            in_channels = out_channels

        # Output layer: last hidden width -> latent_dim channels.
        self.output_layer = torch.nn.Sequential()
        out = torch.nn.Conv2d(in_channels, latent_dim, kernel_size=4, stride=1, padding=0)
        self._init_weights(out)
        self.output_layer.add_module('out', out)
        # The latent code is batch-normalised rather than passed through an activation.
        self.output_layer.add_module('bn_out', torch.nn.BatchNorm2d(latent_dim))

    @staticmethod
    def _init_weights(layer):
        """Initialise weights from N(0, 0.02) and set the bias to zero."""
        torch.nn.init.normal_(layer.weight, mean=0.0, std=0.02)
        torch.nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        """Run the hidden stack and then the output layer on ``x``."""
        x = self.hidden_layer(x)
        out = self.output_layer(x)
        return out

In [None]:
from torchvision.utils import make_grid
%matplotlib inline
def show(img, fig_size=(12,8)):
    """Plot a 3-D image tensor on a new matplotlib figure.

    Args:
        img: Tensor supporting ``.numpy()``; its first axis is moved to the
            end before plotting, so it is presumably laid out as (C, H, W).
        fig_size: Figure size forwarded to ``plt.figure`` (drawn at 100 dpi).
    """
    pixels = img.numpy()
    figure = plt.figure(figsize=fig_size, dpi=100)
    axes = figure.add_subplot(1, 1, 1)
    # Move the leading axis last, the layout imshow is given here.
    channels_last = np.transpose(pixels, (1, 2, 0))
    axes.imshow(channels_last, interpolation='nearest')

In [None]:
# Build the three networks and put every one of them on the training device.
# The encoder uses the same filter widths as the discriminator, and the same
# latent size (100) as the generator.
G = Generator(latent_dim=100, num_filters=[1024, 512, 256, 128]).to(device)
D = Discriminator(num_filters=[128, 256, 512, 1024]).to(device)
E = Encoder(num_filters=[128, 256, 512, 1024], latent_dim=100).to(device)

In [None]:
import torchvision.models as models

# Pretrained AlexNet used purely as a fixed feature extractor: the encoder
# fine-tuning stage compares its `features` output on real and reconstructed
# images. Moving to `device` is a no-op when the device is the CPU.
alexnet = models.alexnet(pretrained=True).to(device)
alexnet.eval()
# Its weights are never trained.
for param in alexnet.parameters():
    param.requires_grad_(False)

In [None]:
# Ensure all three networks are on the training device before optimisation.
G, D, E = (net.to(device) for net in (G, D, E))

In [None]:
from torch import optim
# ---- GAN training configuration ----

num_epochs = 20

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = torch.nn.BCELoss()

# Optimizers
G_optimizer = optim.Adam(G.parameters(), lr=1e-4, weight_decay=1e-5)
D_optimizer = optim.Adam(D.parameters(), lr=1e-4, weight_decay=1e-5)

# Schedulers
# NOTE(review): every milestone (25, 50, 75) lies beyond num_epochs = 20, so
# with one scheduler step per epoch the learning rate is never decayed.
# Confirm whether that is intended.
G_scheduler = optim.lr_scheduler.MultiStepLR(G_optimizer, milestones=[25,50,75])
D_scheduler = optim.lr_scheduler.MultiStepLR(D_optimizer, milestones=[25,50,75])

# Per-epoch average losses
D_avg_losses = []
G_avg_losses = []

# Checkpoint directory; create it now so the torch.save calls of the training
# cells do not fail on a fresh checkout.
save_dir = "BaryCenter"
os.makedirs(save_dir, exist_ok=True)

# Fixed latent vectors, drawn once, so the sample grid shown after each epoch
# is comparable across epochs.
num_test_samples = 16
latent_dim = 100
fixed_noise = torch.randn(num_test_samples, latent_dim, 1, 1).to(device)

In [None]:
# Checkpoints are written after every epoch; make sure the directory exists.
os.makedirs(save_dir, exist_ok=True)

for epoch in range(num_epochs):
    D_epoch_losses = []
    G_epoch_losses = []
    trange = tqdm(trainloader)
    for i, images in enumerate(trange):
        mini_batch = images.size()[0]
        x = images.to(device)

        y_real = torch.ones(mini_batch, device=device)
        y_fake = torch.zeros(mini_batch, device=device)

        # ---- Discriminator ----
        # view(-1) keeps one entry per sample even for a batch of size 1,
        # where squeeze() would return a 0-d tensor that no longer matches
        # the shape of the label tensor.
        D_real_decision = D(x).view(-1)
        D_real_loss = criterion(D_real_decision, y_real)

        z = torch.randn(mini_batch, latent_dim, 1, 1).to(device)
        # Detached: the discriminator loss does not need gradients w.r.t. G.
        generated_images = G(z).detach()

        D_fake_decision = D(generated_images).view(-1)
        D_fake_loss = criterion(D_fake_decision, y_fake)

        # The loss is computed on every batch (it is logged below) ...
        D_loss = D_real_loss + D_fake_loss
        D.zero_grad()
        # ... but the discriminator is only updated on every second batch.
        if i % 2 == 0:
            D_loss.backward()
            D_optimizer.step()

        # ---- Generator ----
        z = torch.randn(mini_batch, latent_dim, 1, 1).to(device)
        generated_images = G(z)

        # The generator is rewarded when D labels its samples as real.
        D_fake_decision = D(generated_images).view(-1)
        G_loss = criterion(D_fake_decision, y_real)

        D.zero_grad()
        G.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # ---- Logging ----
        D_loss_value = D_loss.item()
        G_loss_value = G_loss.item()
        D_epoch_losses.append(D_loss_value)
        G_epoch_losses.append(G_loss_value)

        trange.set_description('Epoch [%d/%d], Step [%d/%d], D_loss: %.4f, G_loss: %.4f'
            % (epoch+1, num_epochs, i+1, len(trainloader), D_loss_value, G_loss_value))

    D_avg_loss = torch.mean(torch.FloatTensor(D_epoch_losses)).item()
    G_avg_loss = torch.mean(torch.FloatTensor(G_epoch_losses)).item()
    D_avg_losses.append(D_avg_loss)
    G_avg_losses.append(G_avg_loss)

    # Sample from the fixed latent vectors in evaluation mode, then switch
    # the generator back to training mode.
    G.eval()
    with torch.no_grad():
        generated_images = G(fixed_noise)
    G.train()

    # Map the Tanh range [-1, 1] to [0, 1] before plotting; imshow clips
    # float data outside [0, 1].
    show(make_grid(generated_images.cpu() / 2 + 0.5, nrow=4))

    # Save models
    torch.save(G.state_dict(), os.path.join(save_dir,'generator'))
    torch.save(D.state_dict(), os.path.join(save_dir,'discriminator'))

    # Decrease learning-rate
    G_scheduler.step()
    D_scheduler.step()

In [None]:
# ---- Encoder pre-training on generated samples: setup ----

# The generator only synthesises (z, G(z)) training pairs from here on.
# Evaluation mode makes its BatchNorm layers use, and stop updating, their
# running statistics, the same mode used for sampling and for the later
# fine-tuning stage.
G.eval()
for param in G.parameters():
    param.requires_grad = False

# Only the encoder is optimised.
E.train()

# The encoder is trained to regress the latent vector that produced an image.
criterion = torch.nn.MSELoss()

# Optimizer
E_optimizer = optim.Adam(E.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-5)

# Per-epoch average losses
E_avg_losses = []

In [None]:
# Checkpoints are written after every epoch; make sure the directory exists.
os.makedirs(save_dir, exist_ok=True)

# This stage trains on generated images only. It needs the batch sizes that
# `trainloader` would yield, not the images, so the sizes are computed
# directly rather than reading and decoding the whole dataset each epoch.
# Assumes the loader uses a fixed batch_size and keeps the last partial batch.
steps_per_epoch = len(trainloader)
num_samples = len(trainloader.dataset)
loader_batch_size = trainloader.batch_size

for epoch in range(num_epochs):
    E_losses = []

    trange = tqdm(range(steps_per_epoch))
    for i in trange:
        # Full batches, followed by whatever remains for the last one.
        mini_batch = min(loader_batch_size, num_samples - i * loader_batch_size)

        # Draw latent vectors and let the frozen generator render them.
        z = torch.randn(mini_batch, latent_dim, 1, 1).to(device)
        with torch.no_grad():
            x = G(z)

        # The encoder should recover z from the rendered image.
        out_latent = E(x)
        E_loss = criterion(out_latent, z)

        # Back propagation
        E.zero_grad()
        E_loss.backward()
        E_optimizer.step()

        # Logging
        E_loss_value = E_loss.item()
        E_losses.append(E_loss_value)

        trange.set_description('Epoch [%d/%d], Step [%d/%d], E_loss: %.4f'
            % (epoch+1, num_epochs, i+1, steps_per_epoch, E_loss_value))

    E_avg_loss = torch.mean(torch.FloatTensor(E_losses)).item()

    # avg loss values for plot
    E_avg_losses.append(E_avg_loss)

    # Save models
    torch.save(E.state_dict(), os.path.join(save_dir,'encoder'))

In [None]:
# ---- Encoder fine-tuning on real images: setup ----

# AlexNet and the generator stay fixed: evaluation mode, no gradients for
# their weights.
alexnet.eval()
G.eval()
for param in list(alexnet.parameters()) + list(G.parameters()):
    param.requires_grad_(False)

# The encoder is the only network being optimised.
E.train()

# Mean squared error, applied both to pixels and to AlexNet features.
criterion = torch.nn.MSELoss()

# Fresh optimizer for this stage.
E_optimizer = optim.Adam(E.parameters(), lr=1e-4, betas=(0.5, 0.999), weight_decay=1e-5)

# Per-epoch average losses of this stage.
E_avg_losses = []

In [None]:
def alexnet_norm(x):
    """Standardise an image batch with fixed per-channel mean and std.

    Args:
        x: Tensor with values in [0, 1] that broadcasts against shape
            (1, 3, 1, 1), i.e. with 3 channels on axis 1.

    Returns:
        ``(x - mean) / std`` with mean ``[0.485, 0.456, 0.406]`` and std
        ``[0.229, 0.224, 0.225]``, in the dtype and on the device of ``x``.

    Raises:
        AssertionError: If any value of ``x`` lies outside [0, 1].
    """
    # Both bounds have to hold; checking them with `or` would accept any
    # input that satisfies just one of the two.
    assert x.min() >= 0 and x.max() <= 1, f"Alexnet received input outside of range [0,1]: {x.min(),x.max()}"
    mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).type_as(x)
    std = torch.tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).type_as(x)
    return (x - mean) / std
def denorm(x):
    """Halve ``x`` and add 0.5, which maps the range [-1, 1] onto [0, 1]."""
    halved = x / 2
    return halved + 0.5

In [None]:
from torch.nn import functional as F
def interpolate(x):
    """Upsample a batch by a factor of 4 with bilinear interpolation."""
    # align_corners=False is what the call uses when the argument is omitted;
    # it is spelled out to make the choice explicit.
    return F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)


def get_features(x):
    """AlexNet ``features`` output for a batch of images given in [-1, 1].

    The batch is mapped to [0, 1], upsampled 4x, standardised with
    ``alexnet_norm`` and passed through ``alexnet.features``.
    """
    return alexnet.features(alexnet_norm(interpolate(denorm(x))))


# Checkpoints are written after every epoch; make sure the directory exists.
os.makedirs(save_dir, exist_ok=True)

for epoch in range(num_epochs):
    E_losses = []

    # minibatch training
    trange = tqdm(trainloader)
    for i, images in enumerate(trange):
        # Real images from the dataset
        x = images.to(device)

        # Reconstruct each image through encoder and frozen generator.
        out_images = G(E(x))

        # Pixel-space error plus a lightly weighted AlexNet-feature error.
        pixel_loss = criterion(out_images, x)
        feature_loss = criterion(get_features(out_images), get_features(x))
        E_loss = pixel_loss + 0.002*feature_loss

        # Backprop
        E.zero_grad()
        E_loss.backward()
        E_optimizer.step()

        # Logging
        E_loss_value = E_loss.item()
        E_losses.append(E_loss_value)

        trange.set_description('Epoch [%d/%d], Step [%d/%d], E_loss: %.4f'
            % (epoch+1, num_epochs, i+1, len(trainloader), E_loss_value))

    E_avg_loss = torch.mean(torch.FloatTensor(E_losses)).item()

    # avg loss values for plot
    E_avg_losses.append(E_avg_loss)

    # Save models
    torch.save(E.state_dict(), os.path.join(save_dir,'encoder'))