In [1]:
import os
import pickle
import time

import matplotlib
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import gridspec
from torch.autograd import Variable

In [2]:
# True when torch reports a usable CUDA device; the training code checks this
# flag before moving models and tensors to the GPU.
is_cuda = torch.cuda.is_available()

In [3]:
# ---- Reproducibility ----
seed = 10

# ---- Problem dimensions ----
# size of one flattened data sample X
X_dimension = 784
# size of the latent representation z
z_dimension = 10
# size of the label vector of a sample
y_dimension = 10
n_classes = 10

# ---- Training hyper-parameters ----
EPOCHS = 1000
TRAIN_BATCH_SIZE = 100
VALID_BATCH_SIZE = 1000
# width handed to the hidden nn.Linear layers of the networks
N = 1000
# small constant added to network outputs inside the loss expressions
TINY_ERROR = 1e-15

# ---- Environment ----
DATA_PATH = "/floyd/input/skripsi_datasets_2/"
cuda = torch.device('cuda')

# ---- Histories filled in while training ----
training_reconstruction_loss = []
training_generator_loss = []
training_discriminator_loss = []
training_generator_sample = []

In [4]:
class Encoder_net(nn.Module):
    """Fully connected encoder mapping a flattened sample to a latent code.

    Three linear layers: X_dimension -> N -> N -> z_dimension. The two hidden
    layers are each followed by dropout (p=0.5) and ReLU; the final layer is
    left linear.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(X_dimension, N)
        self.layer2 = nn.Linear(N, N)
        self.layer3 = nn.Linear(N, z_dimension)

    def forward(self, x):
        """Return the latent code for `x` (last dimension must be X_dimension)."""
        hidden = x
        for linear in (self.layer1, self.layer2):
            # Dropout is only active while the module is in training mode.
            dropped = F.dropout(linear(hidden), p=0.5, training=self.training)
            hidden = F.relu(dropped)

        return self.layer3(hidden)

class Decoder_net(nn.Module):
    """Fully connected decoder mapping a latent code back to data space.

    Three linear layers: z_dimension -> N -> N -> X_dimension, with dropout
    (p=0.5) after the first two and a sigmoid on the output, so every output
    value lies in (0, 1).
    """

    def __init__(self):
        super(Decoder_net, self).__init__()
        self.layer1 = nn.Linear(z_dimension, N)
        self.layer2 = nn.Linear(N, N)
        self.layer3 = nn.Linear(N, X_dimension)

    def forward(self, x):
        """Return the reconstruction for latent codes `x` (last dim z_dimension)."""
        x = F.dropout(self.layer1(x), p=0.5, training=self.training)
        x = F.relu(x)
        # NOTE(review): unlike the encoder there is no ReLU after layer2; kept
        # as-is so behaviour and existing checkpoints are unaffected.
        x = F.dropout(self.layer2(x), p=0.5, training=self.training)
        x = self.layer3(x)

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return torch.sigmoid(x)

class Discriminator_net_gauss(nn.Module):
    """Discriminator over latent codes.

    Three linear layers: z_dimension -> N -> N -> 1. The hidden layers use
    dropout (p=0.5) followed by ReLU, and the single output is squashed with a
    sigmoid, so each score lies in (0, 1).
    """

    def __init__(self):
        super(Discriminator_net_gauss, self).__init__()
        self.layer1 = nn.Linear(z_dimension, N)
        self.layer2 = nn.Linear(N, N)
        self.layer3 = nn.Linear(N, 1)

    def forward(self, x):
        """Return one score per latent code in `x` (last dim z_dimension)."""
        x = F.dropout(self.layer1(x), p=0.5, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.layer2(x), p=0.5, training=self.training)
        x = F.relu(x)

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return torch.sigmoid(self.layer3(x))

In [5]:
def _load_pickle(path):
    """Unpickle and return the object stored at `path`.

    The file is opened in a `with` block so the handle is closed as soon as
    loading finishes (the previous `pickle.load(open(...))` never closed it).
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


trainset_labeled = _load_pickle(DATA_PATH + "train_labeled.p")
trainset_unlabeled = _load_pickle(DATA_PATH + "train_unlabeled.p")
# Set -1 as labels for unlabeled data
# NOTE(review): 47000 is presumably the number of unlabeled samples and
# `_train_labels` is a private attribute of the pickled dataset class --
# confirm both against the dataset definition.
trainset_unlabeled._train_labels = torch.from_numpy(np.array([-1] * 47000))
validset = _load_pickle(DATA_PATH + "validation.p")

# Training loaders reshuffle every epoch and share the same batch size.
train_labeled_loader = torch.utils.data.DataLoader(trainset_labeled,
                                                   batch_size=TRAIN_BATCH_SIZE,
                                                   shuffle=True)

train_unlabeled_loader = torch.utils.data.DataLoader(trainset_unlabeled,
                                                     batch_size=TRAIN_BATCH_SIZE,
                                                     shuffle=True)

valid_loader = torch.utils.data.DataLoader(validset, batch_size=VALID_BATCH_SIZE, shuffle=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/subMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/subMNIST/raw/train-images-idx3-ubyte.gz to ../data/subMNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/subMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/subMNIST/raw/train-labels-idx1-ubyte.gz to ../data/subMNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/subMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/subMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/subMNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/subMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/subMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/subMNIST/raw
Processing...
Done!
3000
750


In [6]:
def train_one_epoch(decoder, encoder, discriminator_gauss, decoder_optimizer, encoder_optimizer, generator_optimizer, discriminator_optimizer, data_loader):
    """Run one pass over `data_loader`, updating all three networks per batch.

    Each batch goes through three phases:
      1. reconstruction: encoder + decoder are stepped on the MSE between the
         decoder output and the input;
      2. discriminator: `discriminator_gauss` is stepped on a squared-error
         loss pushing its score towards 1 for N(0, 1) samples and towards 0
         for encoder outputs;
      3. generator: the encoder is stepped (via `generator_optimizer`) on a
         squared-error loss pushing the discriminator score of its outputs
         towards 1.

    Args:
        decoder, encoder, discriminator_gauss: the networks to train.
        decoder_optimizer, encoder_optimizer: optimizers for the reconstruction phase.
        generator_optimizer: optimizer stepped in the generator phase.
        discriminator_optimizer: optimizer stepped in the discriminator phase.
        data_loader: iterable yielding `(X, target)` batches; `target` is unused.

    Returns:
        `(discriminator_loss, generator_loss, reconstruction_loss)` of the LAST
        batch as detached scalar tensors, or `(None, None, None)` when
        `data_loader` yields nothing.
    """
    encoder.train()
    decoder.train()
    discriminator_gauss.train()

    # Follow whatever device the encoder's parameters live on.
    device = next(encoder.parameters()).device
    # Built once; the criterion is stateless.
    mse_loss = torch.nn.MSELoss()

    discriminator_loss = None
    generator_loss = None
    reconstruction_loss = None

    for X, _ in data_loader:
        # Undo the (x - 0.1307) / 0.3081 scaling.
        # NOTE(review): assumes the dataset applies that normalisation -- confirm.
        X = X * 0.3081 + 0.1307
        # Flatten per sample using the actual batch size, so a smaller final
        # batch works (the deprecated Tensor.resize with a fixed size did not).
        X = X.reshape(X.size(0), -1).to(device)
        batch_size = X.size(0)

        decoder.zero_grad()
        encoder.zero_grad()
        discriminator_gauss.zero_grad()

        # reconstruction phase
        z_sample = encoder(X)
        X_sample = decoder(z_sample)
        reconstruction_loss = mse_loss(X_sample + TINY_ERROR, X + TINY_ERROR)

        reconstruction_loss.backward()
        decoder_optimizer.step()
        encoder_optimizer.step()

        decoder.zero_grad()
        encoder.zero_grad()
        discriminator_gauss.zero_grad()

        # regularization phase
        # Train Discriminator
        encoder.eval()
        # Sampled on the CPU and then moved, as before, so the CPU random
        # stream is the one consumed.
        z_real_gauss = torch.empty(batch_size, z_sample.size(1)).normal_(mean=0, std=1.0).to(device)

        # The encoder is not updated in this phase, so skip building its graph.
        with torch.no_grad():
            z_fake_gauss = encoder(X)

        discriminator_real_gauss = discriminator_gauss(z_real_gauss)
        discriminator_fake_gauss = discriminator_gauss(z_fake_gauss)

        discriminator_loss = 0.5 * (torch.mean((discriminator_real_gauss + TINY_ERROR - 1)**2) + torch.mean((discriminator_fake_gauss + TINY_ERROR)**2))

        discriminator_loss.backward()
        discriminator_optimizer.step()

        decoder.zero_grad()
        encoder.zero_grad()
        discriminator_gauss.zero_grad()

        # Train Generator
        encoder.train()
        z_fake_gauss = encoder(X)

        generator_fake_gauss = discriminator_gauss(z_fake_gauss)
        generator_loss = 0.5 * torch.mean((generator_fake_gauss + TINY_ERROR - 1)**2)

        generator_loss.backward()
        generator_optimizer.step()

    if reconstruction_loss is None:
        return discriminator_loss, generator_loss, reconstruction_loss

    # Detach so callers that keep these values do not keep autograd graphs alive.
    return discriminator_loss.detach(), generator_loss.detach(), reconstruction_loss.detach()







In [7]:
def _ensure_parent_directory(path):
    """Create the directory that will contain `path` if it does not exist yet."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)


def _save_latent_scatter(latent, labels, path):
    """Scatter the first two latent coordinates, coloured by class, into `path`.

    Args:
        latent: 2-D numpy array with one row per sample; columns 0 and 1 are drawn.
        labels: list with one class label per row of `latent`.
        path: image file to write.
    """
    classes = sorted(set(labels))
    palette = plt.cm.rainbow(np.linspace(0, 1, len(classes)))
    # Explicit label -> colour map so points and legend always agree, even if
    # a class is missing from the batch or labels are not exactly 0..k-1.
    colour_of = dict(zip(classes, palette))

    figure = plt.figure()
    axis = figure.add_subplot(111, aspect='equal')
    # Shrink the axes horizontally to leave room for the legend on the right.
    box = axis.get_position()
    axis.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    handles = [mpatches.Circle((0, 0), label=class_, color=colour_of[class_]) for class_ in classes]
    axis.legend(handles=handles, shadow=True, bbox_to_anchor=(1.05, 0.45), fancybox=True, loc='center left')
    axis.scatter(latent[:, 0], latent[:, 1], s=2, alpha=0.8, c=[colour_of[label] for label in labels])
    axis.set_xlim([-10, 10])
    axis.set_ylim([-30, 30])

    _ensure_parent_directory(path)
    figure.savefig(path)
    plt.close(figure)


def _save_reconstructions(originals, reconstructions, path):
    """Save a two-row figure: inputs on top, their reconstructions below.

    Args:
        originals, reconstructions: numpy arrays of shape (n, 28, 28).
        path: image file to write.
    """
    n_digits = len(originals)
    figure = plt.figure(figsize=(20, 4))

    for row, images in enumerate((originals, reconstructions)):
        for column in range(n_digits):
            axis = figure.add_subplot(2, n_digits, row * n_digits + column + 1)
            axis.imshow(images[column], cmap='gray')
            axis.get_xaxis().set_visible(False)
            axis.get_yaxis().set_visible(False)

    _ensure_parent_directory(path)
    figure.savefig(path)
    plt.close(figure)


def _save_decoder_samples(decoder, device, path, grid_size=10):
    """Decode random latent codes and save them as a grid_size x grid_size grid.

    Every latent coordinate is drawn independently from 10 evenly spaced
    values in [-3, 3]; the decoder output is shown as a 28x28 image.
    """
    axis_values = np.linspace(-3, 3, 10)
    figure = plt.figure()
    grid_spec = gridspec.GridSpec(grid_size, grid_size, hspace=0.05, wspace=0.05)

    for index in range(grid_size * grid_size):
        picks = np.random.randint(len(axis_values), size=z_dimension)
        latent_variable = torch.from_numpy(axis_values[picks]).float().reshape(1, z_dimension)
        # Use the models' device instead of an unconditional .cuda(), which
        # raised on CPU-only machines.
        latent_variable = latent_variable.to(device)

        with torch.no_grad():
            image = decoder(latent_variable).cpu().numpy().reshape(28, 28)

        axis = figure.add_subplot(grid_spec[index])
        axis.imshow(image, cmap='gray')
        axis.set_xticks([])
        axis.set_yticks([])
        axis.set_aspect('auto')

    _ensure_parent_directory(path)
    figure.savefig(path)
    plt.close(figure)


def train_model(train_labeled_loader, train_unlabeled_loader, valid_loader):
    """Train the adversarial autoencoder and return `(encoder, decoder)`.

    Trains for EPOCHS epochs on `train_unlabeled_loader` via `train_one_epoch`.
    Every 10th epoch the last-batch losses are printed and appended to the
    module-level `training_*_loss` lists. Every 20th epoch one batch from
    `valid_loader` is encoded and three diagnostic images are written: a
    latent-space scatter, input/reconstruction pairs, and decoder samples.

    Args:
        train_labeled_loader: accepted for interface compatibility; unused.
        train_unlabeled_loader: loader of `(X, target)` training batches.
        valid_loader: loader of `(X, target)` batches used for the diagnostics.

    Returns:
        The trained `(encoder, decoder)` modules, left in training mode.
    """
    torch.manual_seed(seed)
    device = cuda if is_cuda else torch.device('cpu')

    encoder = Encoder_net().to(device)
    decoder = Decoder_net().to(device)
    discriminator_gauss = Discriminator_net_gauss().to(device)

    # learning rates for optimization
    learning_rate_1 = 0.0001
    learning_rate_2 = 0.00005

    # optimization for decoder and encoder (reconstruction phase)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate_1)
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate_1)

    # optimization for the adversarial (regularization) phase
    generator_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate_2)
    discriminator_optimizer = optim.Adam(discriminator_gauss.parameters(), lr=learning_rate_2)

    for epoch in range(EPOCHS):
        start_time = time.time()
        discriminator_loss, generator_loss, reconstruction_loss = train_one_epoch(
            decoder, encoder, discriminator_gauss,
            decoder_optimizer, encoder_optimizer, generator_optimizer,
            discriminator_optimizer, train_unlabeled_loader)
        epoch_time = time.time() - start_time

        if epoch % 10 == 0:
            # Detach so the history lists never keep autograd graphs alive.
            training_reconstruction_loss.append(reconstruction_loss.detach())
            training_generator_loss.append(generator_loss.detach())
            training_discriminator_loss.append(discriminator_loss.detach())
            # {:.2f} prints seconds as e.g. 19.03 instead of 1.9e+01.
            print('Epoch-{}, Time-{:.2f}, Discriminator_loss-{:.4}, Generator_loss-{:.4}, reconstruction_loss-{:.4}'.format(epoch, epoch_time, discriminator_loss.item(), generator_loss.item(), reconstruction_loss.item()))

        if epoch % 20 == 0:
            encoder.eval()
            decoder.eval()
            discriminator_gauss.eval()

            # Only the first batch of the validation loader is used.
            X_test, y_test = next(iter(valid_loader))
            # NOTE(review): training batches are rescaled with
            # `X * 0.3081 + 0.1307` in train_one_epoch but this batch is not;
            # if validset shares that normalisation the encoder sees
            # differently scaled inputs here -- confirm.
            X_test = X_test.reshape(X_test.size(0), -1).to(device)
            target_list = y_test.tolist()
            n_digits = min(20, X_test.size(0))

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                encoded_X_test = encoder(X_test)
                decoded_X_test = decoder(encoded_X_test[:n_digits])

            encoded_X_test_cpu = encoded_X_test.cpu()
            training_generator_sample.append(encoded_X_test_cpu)

            _save_latent_scatter(encoded_X_test_cpu.numpy(), target_list,
                                 'latent_space_standard_aae_least/epoch_%d.png' % epoch)

            _save_reconstructions(X_test[:n_digits].cpu().numpy().reshape(-1, 28, 28),
                                  decoded_X_test.cpu().numpy().reshape(-1, 28, 28),
                                  'reconstruction_standard_aae_least/epoch_%d.png' % epoch)

            _save_decoder_samples(decoder, device,
                                  'sampling_standard_aae_least/epoch_%d.png' % epoch)

            encoder.train()
            decoder.train()
            discriminator_gauss.train()

    return encoder, decoder

In [8]:
# Run the full training loop (EPOCHS epochs) and keep the resulting encoder and
# decoder modules so they can be saved afterwards.
trained_encoder, trained_decoder = train_model(train_labeled_loader, train_unlabeled_loader, valid_loader)




Epoch-0, Time-2.2e+01, Discriminator_loss-0.2448, Generator_loss-0.3138, reconstruction_loss-0.05374
Epoch-10, Time-1.9e+01, Discriminator_loss-0.236, Generator_loss-0.1348, reconstruction_loss-0.02998
Epoch-20, Time-1.9e+01, Discriminator_loss-0.2354, Generator_loss-0.1492, reconstruction_loss-0.02731
Epoch-30, Time-1.8e+01, Discriminator_loss-0.2275, Generator_loss-0.135, reconstruction_loss-0.02785
Epoch-40, Time-1.9e+01, Discriminator_loss-0.2298, Generator_loss-0.1612, reconstruction_loss-0.02779
Epoch-50, Time-1.9e+01, Discriminator_loss-0.221, Generator_loss-0.1558, reconstruction_loss-0.02406
Epoch-60, Time-1.8e+01, Discriminator_loss-0.213, Generator_loss-0.1459, reconstruction_loss-0.02613
Epoch-70, Time-1.9e+01, Discriminator_loss-0.2024, Generator_loss-0.1487, reconstruction_loss-0.02359
Epoch-80, Time-1.9e+01, Discriminator_loss-0.1937, Generator_loss-0.1795, reconstruction_loss-0.02205
Epoch-90, Time-1.9e+01, Discriminator_loss-0.2036, Generator_loss-0.1675, reconstructio

In [9]:
# Persist only the learned parameters (state dicts) of the two trained networks.
file_location_encoder = "least-standard-encoder.pt"
file_location_decoder = "least-standard-decoder.pt"
for network, destination in (
    (trained_encoder, file_location_encoder),
    (trained_decoder, file_location_decoder),
):
    torch.save(network.state_dict(), destination)