# Adversarial Autoencoder

## Project Info
This notebook is an implementation of Adversarial Autoencoders (AAE) trained on MNIST.


# Header


 ## Import Necessary Packages

In [0]:
#Basic Level Libs
import math
import os 
import itertools

#Mid Level Libs
import numpy as np
import pickle

#FrameWorks
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import save_image


#Other
import matplotlib.pyplot as plt

## Mount Google Drive

In [0]:
# Mount Google Drive at /content/drive; this needs the `google.colab` package.
from google.colab import drive
drive.mount('/content/drive')

## Config

In [0]:
# True when a CUDA device is usable; drives device placement below.
cuda = torch.cuda.is_available()
# NOTE(review): `seed` is not referenced in the visible code
# (generate_model seeds torch with the literal 10).
seed = 10

# Extra DataLoader options, only applied when running on the GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
n_classes = 10          # not referenced in the visible code
z_dim = 2               # size of the latent code produced by the encoder
X_dim = 784             # length of a flattened image (sample_image views it as 28x28)
y_dim = 10              # not referenced in the visible code
train_batch_size = 100  # batch size of the labeled/unlabeled training loaders
valid_batch_size = 100  # batch size of the validation loader
N = 1000                # width of the hidden linear layers of all three networks
# NOTE(review): `epochs` is not referenced in the visible code
# (generate_model loops over range(700)).
epochs = 500
# Float tensor constructor matching the available device.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# Dataset 

In [0]:
class subMNIST(MNIST):
    """MNIST dataset whose reported length is overridden.

    ``k`` is the length reported while ``self.train`` is truthy; otherwise the
    length is fixed at 10000. ``self.train`` is presumably set by the base
    ``MNIST`` constructor - verify against the installed torchvision.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, k=3000):
        super().__init__(root, train, transform, target_transform, download)
        # Placeholders that the data-preparation code may overwrite later.
        self.test_dataa = None
        self.test_labelsa = None
        self.k = k

    def __len__(self):
        """Return ``k`` for the training split and 10000 otherwise."""
        return self.k if self.train else 10000

# MNIST images are single-channel, so Normalize gets exactly one mean and one
# std entry; a three-entry tuple does not match the 1xHxW tensors that
# ToTensor produces and fails once a sample is actually fetched.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
trainset = subMNIST(root='../data', train=True,
                    download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)


# Sanity check: reported dataset length and resulting number of batches.
print(len(trainset))
print(len(trainloader))

In [0]:
# Single-channel normalisation constants used for every split below.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

trainset_original = datasets.MNIST('/content/', train=True, download=True,
                                   transform=transform)

# Raw pixel / label arrays, extracted once and reused for all three splits.
trainset_np = trainset_original.train_data.numpy()
trainset_label_np = trainset_original.train_labels.numpy()
train_label_list = trainset_label_np

# Per digit: the first 300 occurrences go to the labeled split and the next
# 1000 occurrences go to the validation split.
train_label_index = []
valid_label_index = []
for i in range(10):
    label_index = np.where(train_label_list == i)[0]
    label_subindex = list(label_index[:300])
    valid_subindex = list(label_index[300: 1000 + 300])
    train_label_index += label_subindex
    valid_label_index += valid_subindex

# ---- labeled training split -------------------------------------------------
train_data_sub = torch.from_numpy(trainset_np[train_label_index])
train_labels_sub = torch.from_numpy(trainset_label_np[train_label_index])

trainset_new = subMNIST(root='/content/', train=True, download=True, transform=transform, k=3000)
# NOTE(review): the attributes with a trailing "a" are not read anywhere in
# the visible code; presumably the base MNIST class ignores them, so this
# dataset likely still serves the stock MNIST samples - confirm.
trainset_new.train_dataa = None
trainset_new.train_labelsa = train_labels_sub.clone()

# `with` closes the file so the pickle is fully flushed to disk.
with open("/content/train_labeled.p", "wb") as handle:
    pickle.dump(trainset_new, handle)

# ---- validation split -------------------------------------------------------
validset_np = trainset_np
validset_label_np = trainset_label_np
valid_data_sub = torch.from_numpy(validset_np[valid_label_index])
valid_labels_sub = torch.from_numpy(validset_label_np[valid_label_index])

validset = subMNIST(root='/content/data', train=False, download=True, transform=transform, k=10000)
validset.test_dataa = valid_data_sub.clone()
validset.test_labelsa = valid_labels_sub.clone()

with open("/content/validation.p", "wb") as handle:
    pickle.dump(validset, handle)

# ---- unlabeled training split -----------------------------------------------
# Every index that is in neither of the two splits above. A set makes each
# membership test O(1); scanning the two index lists for all 60000 candidates
# was quadratic.
used_index = set(map(int, train_label_index)) | set(map(int, valid_label_index))
train_unlabel_index = [i for i in range(60000) if i not in used_index]

train_data_sub_unl = torch.from_numpy(trainset_np[train_unlabel_index])
train_labels_sub_unl = torch.from_numpy(trainset_label_np[train_unlabel_index])

trainset_new_unl = subMNIST(root='/content/data', train=True, download=True, transform=transform, k=47000)
trainset_new_unl.train_dataa = train_data_sub_unl.clone()
trainset_new_unl.train_labelsa = None      # Unlabeled

with open("/content/train_unlabeled.p", "wb") as handle:
    pickle.dump(trainset_new_unl, handle)

In [0]:
def load_data(data_path='/content/'):
    """Load the three pickled datasets and wrap each in a DataLoader.

    Args:
        data_path: Prefix prepended verbatim to the pickle file names, so a
            directory must end with a path separator.

    Returns:
        Tuple ``(train_labeled_loader, train_unlabeled_loader, valid_loader)``.
        All three loaders shuffle; the two training loaders use
        ``train_batch_size`` and the module-level ``kwargs``, the validation
        loader uses ``valid_batch_size``.
    """
    print('loading data!')

    def _unpickle(name):
        # The context manager closes the file as soon as the object is read.
        # NOTE: unpickling can execute arbitrary code - only load files that
        # were written by this notebook.
        with open(data_path + name, "rb") as handle:
            return pickle.load(handle)

    trainset_labeled = _unpickle("train_labeled.p")
    trainset_unlabeled = _unpickle("train_unlabeled.p")
    # Set -1 as labels for unlabeled data; sized from the dataset itself
    # instead of a hard-coded 47000.
    # NOTE(review): `train_labelsaaa` is not read anywhere in the visible code.
    trainset_unlabeled.train_labelsaaa = torch.full(
        (len(trainset_unlabeled),), -1, dtype=torch.int64)
    validset = _unpickle("validation.p")

    train_labeled_loader = torch.utils.data.DataLoader(trainset_labeled,
                                                       batch_size=train_batch_size,
                                                       shuffle=True, **kwargs)

    train_unlabeled_loader = torch.utils.data.DataLoader(trainset_unlabeled,
                                                         batch_size=train_batch_size,
                                                         shuffle=True, **kwargs)

    valid_loader = torch.utils.data.DataLoader(validset, batch_size=valid_batch_size, shuffle=True)

    return train_labeled_loader, train_unlabeled_loader, valid_loader


In [0]:
# Build the three loaders from the pickles under the default '/content/' path.
train_labeled_loader, train_unlabeled_loader, valid_loader = load_data()

## Displayer

In [0]:
def imshow(img):
    """Show a channels-first image array with matplotlib.

    The axes are reordered from (0, 1, 2) to (1, 2, 0) and length-one axes are
    dropped before the array is handed to ``plt.imshow``.
    """
    channels_last = np.transpose(img, (1, 2, 0))
    plt.imshow(channels_last.squeeze())

In [0]:
def displaying_data(dataiter):
    """Plot up to 20 images of the next batch in a 2 x 10 grid, titled by label.

    Args:
        dataiter: Iterator yielding ``(images, labels)`` batches, where
            ``images`` is a tensor indexed by sample first.
    """
    # obtain one batch of training images; the builtin next() works for any
    # iterator, whereas a `.next()` method is not part of the iterator protocol
    images, labels = next(dataiter)
    images = images.numpy()  # convert images to numpy for display

    n_rows = 2
    n_cols = 20 // n_rows  # integer division: add_subplot needs an int column count
    n_show = min(20, len(images))  # do not index past a short batch

    # plot the images in the batch, along with the corresponding labels
    fig = plt.figure(figsize=(25, 4))
    for idx in range(n_show):
        ax = fig.add_subplot(n_rows, n_cols, idx + 1, xticks=[], yticks=[])
        # imshow() transposes three axes, so it is given the full
        # channels-first sample; the former first attempt with a 2-D slice
        # always failed and was silently retried through a bare `except`.
        imshow(images[idx])
        ax.set_title(str(int(labels[idx])))

In [0]:
# Preview one batch drawn from the labeled training loader.
iterdata = iter(train_labeled_loader)
displaying_data(iterdata)

## Image saver

In [0]:
base_path = '/content'
#!rm -rf '/content/drive/My Drive/images/'
# makedirs(exist_ok=True) creates missing parents too and has no gap between
# the existence check and the creation.
os.makedirs(os.path.join(base_path, 'images'), exist_ok=True)

def sample_image(encoder, epoch):
    """Saves a grid of generated digits

    A 20 x 20 lattice of 2-D latent points covering [-10, 9] on both axes is
    pushed through ``encoder`` and written to ``<base_path>/images/<epoch>.png``.

    Args:
        encoder: Module mapping latent points to 784-value images. Despite
            the parameter name, generate_model passes the decoder ``P`` here.
        epoch: Integer used as the output file name.
    """
    # Sample noise: the first coordinate varies fastest, so every image row of
    # the saved grid (nrow=20) shares one value of the second coordinate.
    z = np.array([[num, num2] for num2 in range(-10, 10) for num in range(-10, 10)])

    # Put the lattice on whatever device the network lives on instead of
    # calling .cuda() unconditionally, which fails on CPU-only machines.
    first_param = next(encoder.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    z = torch.as_tensor(z, dtype=torch.float32, device=device)

    # Sample deterministically: dropout off and no autograd bookkeeping. The
    # previous train/eval mode is restored afterwards.
    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad():
            gen_imgs = encoder(z).view(-1, 1, 28, 28)
    finally:
        encoder.train(was_training)

    save_image(gen_imgs, os.path.join(base_path, 'images', "%d.png" % epoch),
               nrow=20, normalize=True)

# Model

## Encoder

In [0]:
class Encoder(nn.Module):
    """Fully connected encoder: flattened image -> latent code.

    Layer sizes come from the module-level constants ``X_dim`` (input),
    ``N`` (hidden width) and ``z_dim`` (code size).
    """

    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(X_dim, N)
        self.lin2 = nn.Linear(N, N)
        # Gaussian code (z)
        self.lin3gauss = nn.Linear(N, z_dim)

    def forward(self, x):
        """Return the latent code of ``x``; the output layer has no activation."""
        hidden = x
        # Both hidden layers share one pattern: linear -> dropout(0.2) -> ReLU.
        for layer in (self.lin1, self.lin2):
            dropped = F.dropout(layer(hidden), p=0.2, training=self.training)
            hidden = F.relu(dropped)
        return self.lin3gauss(hidden)

## Decoder

In [0]:
class Decoder(nn.Module):
    """Fully connected decoder: latent code -> flattened image in (0, 1).

    Layer sizes come from the module-level constants ``z_dim`` (input),
    ``N`` (hidden width) and ``X_dim`` (output).
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, X_dim)

    def forward(self, x):
        """Decode latent codes ``x`` into sigmoid-activated pixel values."""
        x = self.lin1(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.dropout(x, p=0.2, training=self.training)
        # NOTE(review): unlike after lin1, no ReLU follows lin2. Left as is so
        # that previously saved weights keep producing the same outputs -
        # confirm whether the omission is intended.
        x = self.lin3(x)
        # torch.sigmoid replaces the deprecated F.sigmoid; same result.
        return torch.sigmoid(x)


## Discriminator

In [0]:
class Discriminator(nn.Module):
    """Fully connected discriminator over latent codes.

    Maps a ``z_dim``-sized code to one sigmoid output per sample; the hidden
    width is the module-level constant ``N``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, 1)

    def forward(self, x):
        """Return a value in (0, 1) for every latent code in ``x``."""
        x = F.dropout(self.lin1(x), p=0.2, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.2, training=self.training)
        x = F.relu(x)

        # torch.sigmoid replaces the deprecated F.sigmoid; same result.
        return torch.sigmoid(self.lin3(x))


# Workers

## Utils

In [0]:
def save_model(model, filename):
    """Announce the save and write ``model``'s state dict to ``filename``."""
    print('Best model so far, saving it...')
    weights = model.state_dict()
    torch.save(weights, filename)

In [0]:
def report_loss(epoch, D_loss_gauss, G_loss, recon_loss):
    '''
    Print one line with the epoch number and the three losses, each shown
    with four significant digits.
    '''
    d_value = D_loss_gauss.data
    g_value = G_loss.data
    r_value = recon_loss.data
    print(f'Epoch-{epoch}; D_loss_gauss: {d_value:.4}; G_loss: {g_value:.4}; recon_loss: {r_value:.4}')

## Train

In [0]:
# Plain-Python replacement for the `!mkdir` shell magic: it does not fail when
# the cell is re-run and the directory already exists.
os.makedirs("/content/graph", exist_ok=True)


def create_latent(Q, loader, e):
    '''
    Creates the latent representation for the samples in loader and saves a
    scatter plot of its first two dimensions to /content/graph/<e>.png.

    Q is switched to eval mode and left in that mode.

    return:
        z_values: numpy array with the latent representations
        labels: the labels corresponding to the latent representations
    '''
    Q.eval()
    labels = []
    z_batches = []

    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        for X, target in loader:
            # Undo the (0.1307, 0.3081) normalisation applied by the dataset
            # transform, as the training loop does.
            X = X * 0.3081 + 0.1307
            # Flatten using the real batch size; resize_ with the training
            # batch size misreads memory whenever the two sizes differ.
            X = X.reshape(X.size(0), X_dim)
            if cuda:
                X = X.cuda()
            labels.extend(target.tolist())
            z_batches.append(np.array(Q(X).tolist()))

    labels = np.array(labels)
    # One concatenation at the end instead of re-copying the growing array
    # after every batch; an empty loader yields an empty (0, z_dim) array.
    z_values = np.concatenate(z_batches) if z_batches else np.empty((0, z_dim))

    # Draw on a fresh figure and close it, so that successive calls do not
    # pile their points onto one shared figure. The IPython magic that used
    # to sit here was dropped: it is not valid Python inside a function and
    # savefig does not depend on the inline backend.
    fig, ax = plt.subplots()
    ax.scatter(z_values[:, 0], z_values[:, 1], c=labels)
    fig.savefig('/content/graph/{}.png'.format(e))
    plt.close(fig)

    return z_values, labels

In [0]:
def train(P, Q, D_gauss, P_decoder, Q_encoder, Q_generator, D_gauss_solver, data_loader):
    '''
    Train procedure for one epoch.

    Each batch goes through three updates:
      1. reconstruction - P and Q are updated on the binary cross entropy
         between P(Q(X)) and X;
      2. discriminator  - D_gauss learns to separate samples of a Gaussian
         prior (standard normal scaled by 5) from the codes Q(X);
      3. generator      - Q is updated so that D_gauss rates its codes as
         prior samples.

    Args:
        P, Q, D_gauss: decoder, encoder and latent discriminator modules.
        P_decoder, Q_encoder: optimizers stepped in the reconstruction phase.
        Q_generator: optimizer stepped in the generator phase.
        D_gauss_solver: optimizer stepped in the discriminator phase.
        data_loader: iterable of (X, target) batches; target is ignored.

    Returns:
        (D_loss, G_loss, recon_loss) of the last batch as detached tensors.

    Raises:
        ValueError: if data_loader yields no batch.
    '''
    TINY = 1e-15
    # Set the networks in train mode (apply dropout when needed)
    Q.train()
    P.train()
    D_gauss.train()

    # Work on whatever device the encoder lives on.
    first_param = next(Q.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    D_loss = G_loss = recon_loss = None
    for X, _ in data_loader:

        # Load batch and normalize samples to be between 0 and 1
        X = X.to(device)
        X = X * 0.3081 + 0.1307
        # Flatten with the actual batch size, so a final partial batch is
        # handled correctly instead of being resized to a fixed size.
        X = X.reshape(X.size(0), -1)

        # Init gradients
        P.zero_grad()
        Q.zero_grad()
        D_gauss.zero_grad()

        #######################
        # Reconstruction phase
        #######################
        z_sample = Q(X)
        X_sample = P(z_sample)
        recon_loss = F.binary_cross_entropy(X_sample + TINY, X + TINY)

        recon_loss.backward()
        P_decoder.step()
        Q_encoder.step()

        P.zero_grad()
        Q.zero_grad()
        D_gauss.zero_grad()

        #######################
        # Regularization phase
        #######################
        # Discriminator
        Q.eval()
        # The encoder is not updated in this step, so its codes are computed
        # without building a graph.
        with torch.no_grad():
            z_fake_gauss = Q(X)
        # Prior samples with the same shape/device as the codes, drawn from
        # torch's RNG so that torch.manual_seed makes runs reproducible.
        z_real_gauss = torch.randn_like(z_fake_gauss) * 5.

        D_real_gauss = D_gauss(z_real_gauss)
        D_fake_gauss = D_gauss(z_fake_gauss)

        D_loss = -torch.mean(torch.log(D_real_gauss + TINY) + torch.log(1 - D_fake_gauss + TINY))

        D_loss.backward()
        D_gauss_solver.step()

        P.zero_grad()
        Q.zero_grad()
        D_gauss.zero_grad()

        # Generator
        Q.train()
        z_fake_gauss = Q(X)

        D_fake_gauss = D_gauss(z_fake_gauss)
        G_loss = -torch.mean(torch.log(D_fake_gauss + TINY))

        G_loss.backward()
        Q_generator.step()

        P.zero_grad()
        Q.zero_grad()
        D_gauss.zero_grad()

    if recon_loss is None:
        raise ValueError('data_loader did not yield any batch')

    # Detach so callers do not keep the last batch's autograd graphs alive.
    return D_loss.detach(), G_loss.detach(), recon_loss.detach()


def generate_model(train_labeled_loader, train_unlabeled_loader, valid_loader, n_epochs=700):
    """Build the three networks and train them on the unlabeled loader.

    Before every epoch a grid of decoded samples and a latent-space scatter
    plot of the validation set are written to disk. Whenever the generator
    loss of an epoch's last batch is the lowest seen so far, the decoder and
    encoder weights are saved to /content/P.pth and /content/Q.pth.

    Args:
        train_labeled_loader: Not used by this function.
        train_unlabeled_loader: Batches used for training.
        valid_loader: Batches used for the latent-space plot.
        n_epochs: Number of epochs to run (default 700, the former
            hard-coded value).

    Returns:
        Tuple ``(Q, P)``: the trained encoder and decoder.
    """
    # float('inf') instead of np.Inf, which NumPy 2.0 no longer provides.
    best_loss = float('inf')
    torch.manual_seed(10)
    # Kept from the original: the decoder is also published as a module-level
    # name, so it stays reachable if training is interrupted.
    global P

    # One code path for both devices; the former CPU branch called undefined
    # Q_net / P_net / D_net_gauss constructors.
    device = torch.device('cuda' if cuda else 'cpu')
    Q = Encoder().to(device)
    P = Decoder().to(device)
    D_gauss = Discriminator().to(device)

    # Set learning rates
    gen_lr = 0.0001
    reg_lr = 0.00005

    # Set optimizators
    P_decoder = optim.Adam(P.parameters(), lr=gen_lr)
    Q_encoder = optim.Adam(Q.parameters(), lr=gen_lr)

    Q_generator = optim.Adam(Q.parameters(), lr=reg_lr)
    D_gauss_solver = optim.Adam(D_gauss.parameters(), lr=reg_lr)

    for epoch in range(n_epochs):
        sample_image(P, epoch)
        create_latent(Q, valid_loader, epoch)

        D_loss_gauss, G_loss, recon_loss = train(P, Q, D_gauss, P_decoder, Q_encoder,
                                                 Q_generator,
                                                 D_gauss_solver,
                                                 train_unlabeled_loader)
        report_loss(epoch, D_loss_gauss, G_loss, recon_loss)

        # Track the best loss as a plain float so no tensor (and its graph)
        # is kept alive between epochs.
        current_loss = float(G_loss)
        if best_loss > current_loss:
            torch.save(P.state_dict(), '/content/P.pth')
            torch.save(Q.state_dict(), '/content/Q.pth')
            print("\rsaved", end='')
            best_loss = current_loss
    return Q, P

# Entry point: reload the pickled datasets, then run the whole training loop.
if __name__ == '__main__':
    train_labeled_loader, train_unlabeled_loader, valid_loader = load_data()    
    Q, P = generate_model(train_labeled_loader, train_unlabeled_loader, valid_loader)

# Test

In [0]:
# Fetch one labeled batch; the builtin next() works for any iterator, whereas
# a `.next()` method is not part of the iterator protocol.
dataiter = iter(train_labeled_loader)
data = next(dataiter)

In [0]:
# Use the GPU when present and fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('orginal')
image = data[0][0]
plt.imshow(image[0])
plt.show()
save_image(image,'orginal.png')

print('image after encoding and decoding')
Q = Encoder().to(device)
Q.load_state_dict(torch.load("/content/Q.pth", map_location=device))
Q.eval()  # switch dropout off for inference
P = Decoder().to(device)
P.load_state_dict(torch.load("/content/P.pth", map_location=device))
P.eval()
with torch.no_grad():
    # Same preprocessing as in training: undo the (0.1307, 0.3081)
    # normalisation, then flatten to 784 values.
    encode = Q((image * 0.3081 + 0.1307).view(-1, 784).to(device))
    # The reconstruction was never computed before being plotted.
    decode = P(encode)
plt.imshow(decode.cpu().view(28,28).numpy())
plt.show()
save_image(decode.cpu().view(28,28),'encodeddecoded.png')

print('generated from distribution')
with torch.no_grad():
    decode = P(torch.tensor([0., -10.], device=device))
plt.imshow(decode.cpu().view(28,28).numpy())
save_image(decode.cpu().view(28,28), 'dist.png')
plt.show()


# GIF

In [0]:
import glob
import imageio

In [0]:
import shutil

with imageio.get_writer('AAE_GAN.gif', mode='I') as writer:

    files = os.listdir('/content/images')
    # Keep only "<epoch>.png" entries so that stray files cannot break int(),
    # then order by epoch and take every third snapshot.
    new_files = sorted(int(f[:-4]) for f in files
                       if f.endswith('.png') and f[:-4].isdigit())
    new_files = new_files[::3]
    filenames = ['/content/images/' + str(f)+'.png' for f in new_files]

    # Frames are picked on a square-root schedule: a snapshot is written only
    # when round(2 * sqrt(i)) exceeds the rounded value of the last one kept.
    last = -1
    for i, filename in enumerate(filenames):
        frame = 2*(i**0.5)
        if round(frame) <= round(last):
            continue
        last = frame
        image = imageio.imread(filename)
        writer.append_data(image)

    # Append the final snapshot once more; skipped when there are no files.
    if filenames:
        image = imageio.imread(filenames[-1])
        writer.append_data(image)

# this is a hack to display the gif inside the notebook
# (copies the GIF written above; the old command copied 'aaegan.gif', a file
# this notebook never creates)
shutil.copyfile('AAE_GAN.gif', 'AAE_GAN.gif.png')

In [0]:
import shutil

with imageio.get_writer('AAE_GAN_graph.gif', mode='I') as writer:

    files = os.listdir('/content/graph')
    # Keep only "<epoch>.png" entries so that stray files cannot break int(),
    # then order by epoch and take every third plot.
    new_files = sorted(int(f[:-4]) for f in files
                       if f.endswith('.png') and f[:-4].isdigit())
    new_files = new_files[::3]
    filenames = ['/content/graph/' + str(f)+'.png' for f in new_files]

    # Frames are picked on a square-root schedule: a plot is written only
    # when round(2 * sqrt(i)) exceeds the rounded value of the last one kept.
    last = -1
    for i, filename in enumerate(filenames):
        frame = 2*(i**0.5)
        if round(frame) <= round(last):
            continue
        last = frame
        image = imageio.imread(filename)
        writer.append_data(image)

    # Append the final plot once more; skipped when there are no files.
    if filenames:
        image = imageio.imread(filenames[-1])
        writer.append_data(image)

# this is a hack to display the gif inside the notebook
# (copies the GIF written above; the old command copied 'aaegan.gif', a file
# this notebook never creates)
shutil.copyfile('AAE_GAN_graph.gif', 'AAE_GAN_graph.gif.png')