In [1]:
import torch
import pickle
import time
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib.patches as mpatches
from skimage import exposure
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision import utils

In [2]:
is_cuda = torch.cuda.is_available()

In [3]:
seed = 10

n_classes = 10
#dimension of z or latent representation
z_dimension = 3
#dimension of X or data
X_dimension = 784
#dimension of label of data
y_dimension = 10

TRAIN_BATCH_SIZE = 100
VALID_BATCH_SIZE = 10000
EPOCHS = 350
N = 1000
TINY_ERROR = 1e-15
DATA_PATH = "/floyd/input/skripsi_datasets_2/"
cuda = torch.device('cuda')

training_reconstruction_loss = []
training_generator_loss = []
training_discriminator_loss = []
training_generator_sample = []

In [4]:
train_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(
        "../../home/Data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(28), transforms.ToTensor()]
        ),
    ),
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
)

valid_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(
        "../../home/Data/mnist",
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(28), transforms.ToTensor()]
        ),
    ),
    batch_size=VALID_BATCH_SIZE,
    shuffle=True,
)


class Convolutional_Net(nn.Module):
    def __init__(self):
        super(Convolutional_Net, self).__init__()
        self.convolutional1 = nn.Conv2d(1, 20, 5, 1)
        self.convolutional2 = nn.Conv2d(20, 50, 5 , 1)
        self.linear1 = nn.Linear(4*4*50, 500)
        self.linear2 = nn.Linear(500, 10)
    
    def forward(self, x):
        x = F.relu(self.convolutional1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.convolutional2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return F.log_softmax(x, dim=1)

with torch.no_grad():
    classifier = Convolutional_Net()
    classifier.load_state_dict(torch.load("fashionmnist.pt"))
    classifier.cuda(cuda)
    classifier.eval()

In [5]:
class Encoder_net(nn.Module):
    def __init__(self):
        super(Encoder_net, self).__init__()
        self.layer1 = nn.Linear(X_dimension, N)
        self.layer2 = nn.Linear(N, N)
        self.layer3 = nn.Linear(N, z_dimension)

    def forward(self, x):
        x = F.dropout(self.layer1(x), p=0.5, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.layer2(x), p=0.5, training=self.training)
        x = F.relu(x)
        x = self.layer3(x)

        return x


# Decoder
class Decoder_net(nn.Module):
    def __init__(self):
        super(Decoder_net, self).__init__()
        self.layer1 = nn.Linear(z_dimension + n_classes, N)
        self.layer2 = nn.Linear(N, N)
        self.layer3 = nn.Linear(N, X_dimension)

    def forward(self, x):
        x = self.layer1(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(x)
        x = self.layer2(x)
        x = F.dropout(x, p=0.5, training=self.training)
        s = F.relu(x)
        x = self.layer3(x)
        return F.sigmoid(x)

class Discriminator_net_gauss(nn.Module):
    def __init__(self):
        super(Discriminator_net_gauss, self).__init__()
        self.layer1 = nn.Linear(z_dimension, N)
        self.layer2 = nn.Linear(N, N)
        self.layer3 = nn.Linear(N, 1)

    def forward(self, x):
        x = F.dropout(self.layer1(x), p=0.5, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.layer2(x), p=0.5, training=self.training)
        x = F.relu(x)

        return self.layer3(x)

    
def free_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = True

def frozen_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = False

In [6]:
Tensor = torch.cuda.FloatTensor
def train_one_epoch(decoder, encoder, discriminator_gauss, decoder_optimizer, encoder_optimizer, generator_optimizer, discriminator_optimizer, data_loader):
    #'''
    encoder.train()
    decoder.train()
    discriminator_gauss.train()

    for X, target in data_loader:
        #X = X * 0.3081 + 0.1307
        X.resize_(TRAIN_BATCH_SIZE, X_dimension)
        X, target = Variable(X), Variable(target)
        if cuda:
            X, target = X.cuda(cuda), target.cuda(cuda)

        # Init gradients
        decoder.zero_grad()
        encoder.zero_grad()
        discriminator_gauss.zero_grad()


        z_gauss = encoder(X)
        
        category = np.array(target.data.tolist())
        category = np.eye(n_classes)[category].astype('float32')
        category = torch.from_numpy(category)
        z_category = Variable(category)
        
        if cuda:
            z_category = z_category.cuda(cuda)

        z_sample = torch.cat((z_category, z_gauss), 1)

        X_sample = decoder(z_sample)
        compared_with_original = X.resize(TRAIN_BATCH_SIZE, X_dimension)
        mse_loss = torch.nn.MSELoss()
        reconstruction_loss = 0.4 * mse_loss(X_sample + TINY_ERROR, compared_with_original + TINY_ERROR)
        
        reconstruction_loss.backward()
        decoder_optimizer.step()
        encoder_optimizer.step()

        decoder.zero_grad()
        encoder.zero_grad()
        discriminator_gauss.zero_grad()

        # Discriminator
        encoder.eval()
        z_real_gauss = Variable(torch.randn(TRAIN_BATCH_SIZE, z_dimension) * 5.)
        if cuda:
            z_real_gauss = z_real_gauss.cuda(cuda)

        z_fake_gauss = encoder(X)

        discriminator_real_gauss = discriminator_gauss(z_real_gauss)
        discriminator_fake_gauss = discriminator_gauss(z_fake_gauss)

        discriminator_loss = 0.5 * (torch.mean((discriminator_real_gauss + TINY_ERROR - 1)**2) + torch.mean((discriminator_fake_gauss + TINY_ERROR)**2))

        discriminator_loss.backward()
        discriminator_optimizer.step()

        decoder.zero_grad()
        encoder.zero_grad()
        discriminator_gauss.zero_grad()

        # Generator
        encoder = encoder.train()
        z_fake_gauss = encoder(X)

        generator_fake_gauss = discriminator_gauss(z_fake_gauss)
        generator_loss = 0.5 * torch.mean((generator_fake_gauss + TINY_ERROR - 1)**2)

        generator_loss.backward()
        generator_optimizer.step()

        decoder.zero_grad()
        encoder.zero_grad()
        discriminator_gauss.zero_grad()
    #'''
    '''
    encoder.train()
    decoder.train()
    discriminator_gauss.train()
    
    reconstruction_loss = None
    real_loss = None
    fake_loss = None
    generator_loss = None
    #reconstruction_loss = None
    
    for X, target in train_loader:
        encoder.zero_grad()
        decoder.zero_grad()
        discriminator_gauss.zero_grad()
        
        valid = Variable(Tensor(TRAIN_BATCH_SIZE, 1).fill_(0.9), requires_grad=False)
        fake = Variable(Tensor(TRAIN_BATCH_SIZE, 1).fill_(0.1), requires_grad=False)
        
        X = X.resize(TRAIN_BATCH_SIZE, 784)
        X, target = Variable(X), Variable(target)
        
        if is_cuda:
            X, target = X.cuda(cuda), target.cuda(cuda)
              
        
        mse_loss = torch.nn.MSELoss()
        
        frozen_params(decoder)
        frozen_params(encoder)
        free_params(discriminator_gauss)
        
        z_real = torch.randn(TRAIN_BATCH_SIZE, z_dimension) * 1.0
        z_real = z_real.cuda(cuda)
        real_value = discriminator_gauss(z_real)
        
        z_fake = encoder(X)
                
        fake_value = discriminator_gauss(z_fake)
        
        real_loss = 0.01 * mse_loss(real_value + TINY_ERROR, valid)
        fake_loss = 0.01 * mse_loss(fake_value + TINY_ERROR, fake)
        
        real_loss.backward()
        fake_loss.backward()
        
        discriminator_optimizer.step()
        
        free_params(decoder)
        free_params(encoder)
        frozen_params(discriminator_gauss)
        
        z = encoder(X)
        
        
        category = np.array(target.data.tolist())
        category = np.eye(n_classes)[category].astype('float32')
        category = torch.from_numpy(category)
        z_category = Variable(category)
        
        if cuda:
            z_category = z_category.cuda(cuda)

        #print(z_category.size())
        #z_category = z_category.resize(TRAIN_BATCH_SIZE, 1, n_classes)
        z_with_label = torch.cat((z_category, z), 1)
        
        
        
        x_hat = decoder(z_with_label)
        
        z_2 = encoder(Variable(X.data))
        z_2_dis = discriminator_gauss(z_2)
        
        x_hat_resize = x_hat.view(-1, 784)
        X_resize = X.view(-1, 784)
        
        reconstruction_loss = 0.25 * mse_loss(x_hat_resize + TINY_ERROR, X_resize)
        generator_loss = 0.01 * mse_loss(z_2_dis + TINY_ERROR, valid)
        
        reconstruction_loss.backward()
        generator_loss.backward()
        
        encoder_optimizer.step()
        decoder_optimizer.step()
    '''
    return discriminator_loss, generator_loss, reconstruction_loss
    #return real_loss + fake_loss, generator_loss, reconstruction_loss

In [7]:
def output_label(label):
    output_mapping = {
                 0: "Kaos",
                 1: "Cln Pnjg",
                 2: "Pullover",
                 3: "Gaun",
                 4: "Mantel", 
                 5: "Sandal", 
                 6: "Kemeja",
                 7: "Sneaker",
                 8: "Tas",
                 9: "Ankle Boot"
                 }
    input = (label.item() if type(label) == torch.Tensor else label)
    return output_mapping[input]

def train_model(train_loader, valid_loader):
    torch.manual_seed(10)

    if cuda:
        encoder = Encoder_net().cuda(cuda)
        decoder = Decoder_net().cuda(cuda)
        discriminator_gauss = Discriminator_net_gauss().cuda(cuda)
    else:
        encoder = Encoder_net()
        decoder = Decoder_net()
        discriminator_gauss = Discriminator_net_gauss()

    #learning rates for optimization
    learning_rate_1 = 0.0002
    learning_rate_2 = 0.0002

    #optimization for decoder and encoder
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate_1)
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate_1)

    generator_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate_1)
    discriminator_optimizer = optim.Adam(discriminator_gauss.parameters(), lr=learning_rate_1)

    for epoch in range(EPOCHS):
        start_time = time.time()
        discriminator_loss, generator_loss, reconstruction_loss = train_one_epoch(decoder, encoder, discriminator_gauss, 
                                                                              decoder_optimizer, encoder_optimizer, generator_optimizer, 
                                                                              discriminator_optimizer, train_loader)
        
        epoch_time = time.time() - start_time
        if epoch % 1 == 0:
            training_reconstruction_loss.append(reconstruction_loss)
            training_generator_loss.append(generator_loss)
            training_discriminator_loss.append(discriminator_loss)
            print('Epoch-{}, Time-{:.2}, Discriminator_loss-{:.4}, Generator_loss-{:.4}, reconstruction_loss-{:.4}'.format(epoch, epoch_time, discriminator_loss.item(), generator_loss.item(), reconstruction_loss.item()))
    
        
        if epoch % 1 == 0:
            with torch.no_grad():
                encoder = encoder.eval()
                decoder = decoder.eval()
                discriminator_gauss = discriminator_gauss.eval()

                X_test = None
                y_test = None

                for X, target in valid_loader:
                    X_test = X
                    y_test = target
                    break

                if is_cuda:
                    X_test = X_test.cuda(cuda)

                X_test = X_test.resize(VALID_BATCH_SIZE, X_dimension) 

                list_y_test = []
                for item in y_test:
                    list_y_test.append(item.item())


                encoded_X_test = encoder(X_test)
                training_generator_sample.append(encoded_X_test)
                target_list = list_y_test

                '''
                figure = plt.figure()
                set_classes = set(target_list)
                color_map = plt.cm.rainbow(np.linspace(0, 1, len(set_classes)))
                axis = plt.subplot(111, aspect='equal')
                box = axis.get_position()
                axis.set_position([box.x0, box.y0, box.width * 0.8, box.height])
                handles = [mpatches.Circle((0, 0), label=class_, color=color_map[i]) for i, class_ in enumerate(set_classes)]
                axis.legend(handles=handles, shadow=True, bbox_to_anchor=(1.05, 0.45), fancybox=True, loc='center left')
                kwargs = {'alpha': 0.8, 'c': [color_map[i] for i in target_list]}
                encoded_X_test_cpu = encoded_X_test.cpu()
                plt.scatter(encoded_X_test_cpu[:, 0].detach().numpy(), encoded_X_test_cpu[:, 1].detach().numpy(), s = 2, **kwargs)
                axis.set_xlim([-20, 20])
                axis.set_ylim([-20, 20])

                plt.savefig('2_latent_space_supervised_aae_least/epoch_%d.png' % epoch)
                plt.close('all')
                '''

                n_digits = 20
                #decoded_X_test = decoder(encoder(X_test[:n_digits]))

                category_test = np.array(y_test.numpy().data.tolist())
                category_test = np.eye(n_classes)[category_test].astype('float32')
                category_test = torch.from_numpy(category_test)
                
                z_category_test = Variable(category_test[:n_digits])
                encoded_X_test = encoder(X_test[:n_digits])

                if is_cuda:
                    z_category_test = z_category_test.cuda(cuda)

                encoded_X_test = torch.cat((z_category_test, encoded_X_test), 1)
                decoded_X_test = decoder(encoded_X_test)

                resized_decoded_X_test = decoded_X_test.resize(n_digits, 1, 28, 28)
                resized_decoded_X_test = resized_decoded_X_test.cuda(cuda)
                label_decoded = classifier(resized_decoded_X_test)
                label_decoded = label_decoded.argmax(dim=1, keepdim=True)
                label_decoded = torch.flatten(label_decoded)

                original_X = X_test[:n_digits]
                resized_original_X = original_X.resize(n_digits, 1, 28, 28)
                resized_original_X = resized_original_X.cuda(cuda)
                target_original_X = classifier(resized_original_X)
                target_original_X = target_original_X.argmax(dim=1, keepdim=True)
                target_original_X = torch.flatten(target_original_X)

                decoded_label_cpu = label_decoded.cpu().detach().numpy()
                decoded_target_original_X = target_original_X.cpu().detach().numpy()
                decoded_X_test_cpu = decoded_X_test.cpu()
                decoded_X_test_cpu = np.reshape(decoded_X_test_cpu.detach().numpy(), [-1, 28, 28]) * 255
                figure = plt.figure(figsize=(20, 4))

                for i in range (n_digits):
                    axis = plt.subplot(2, n_digits, i + 1)
                    axis.set_title(output_label(decoded_target_original_X[i]))
                    X_test_cpu = X_test.cpu()
                    plt.imshow(X_test_cpu[i].reshape(28, 28).detach().numpy())
                    plt.gray()
                    axis.get_xaxis().set_visible(False)
                    axis.get_yaxis().set_visible(False)

                    axis = plt.subplot(2, n_digits, i + 1 + n_digits)
                    axis.set_title(output_label(decoded_label_cpu[i]))
                    plt.imshow(decoded_X_test_cpu[i])
                    plt.gray()
                    axis.get_xaxis().set_visible(False)
                    axis.get_yaxis().set_visible(False)

                plt.savefig('3_reconstruction_supervised_aae_least/recon_%d.png' % epoch)
                plt.close('all')
                
                
                z = torch.randn(20, z_dimension) * 5
                z = z.float().cuda(cuda)

                target = torch.randint(10, (20, 1))
                target = target.flatten()

                one_hot_target = torch.zeros(20, 10)
                one_hot_target[torch.arange(20), target] = 1
                one_hot_target = one_hot_target.cuda(cuda)

                z_target = torch.cat([one_hot_target, z], dim=1)
                z_target = z_target.cuda(cuda)

                recon_z = decoder(z_target)
                recon_z = recon_z.resize(20, 1, 28, 28)
                
                recon = recon_z.cpu()
                recon = np.reshape(recon.detach().numpy(), [-1, 28, 28]) * 255
                
                figure = plt.figure(figsize=(20, 4))

                for i in range (20):
                    axis = plt.subplot(2, n_digits, i + 1)
                    plt.imshow(recon[i].reshape(28, 28))
                    plt.gray()
                    axis.get_xaxis().set_visible(False)
                    axis.get_yaxis().set_visible(False)
                
                plt.savefig('3_sampling_supervised_aae_least/epoch_%d.png' % epoch)
                plt.close()
                

                '''
                z_sampling = [np.linspace(-5, 5, 10) for i in range (10)]

                n_x, n_y = 10, 10
                random_input = np.random.randn(10, z_dimension)
                sample_y = np.identity(10)
                plt.subplot()
                grid_spec = gridspec.GridSpec(n_x, n_y, hspace=0.05, wspace=0.05)
                i = 0
                for r in random_input:
                    for t in sample_y:
                        r = np.reshape(r, (1, z_dimension))
                        t = np.reshape(t, (1, n_classes))
                        input_decoder = np.concatenate((t, r), 1)
                        input_decoder = input_decoder.astype('float32')
                        input_decoder = torch.from_numpy(input_decoder).float()
                        input_decoder = input_decoder.cuda(cuda)

                        decoded_X = decoder(input_decoder)
                        decoded_X_cpu = decoded_X.cpu().detach().numpy()

                        axis = plt.subplot(grid_spec[i])
                        i += 1
                        image = np.array(decoded_X_cpu.tolist()).reshape(28, 28)
                        axis.imshow(image, cmap='gray')
                        axis.set_xticks([])
                        axis.set_yticks([])
                        axis.set_aspect('auto')

                plt.savefig('3_sampling_supervised_aae_least/epoch_%d.png' % epoch)
                plt.close()
                '''
                
                encoder = encoder.train()
                decoder = decoder.train()
                discriminator_gauss = discriminator_gauss.train()

    return encoder, decoder

In [8]:
trained_encoder, trained_decoder = train_model(train_loader, valid_loader)



Epoch-0, Time-1.2e+01, Discriminator_loss-0.2586, Generator_loss-0.2065, reconstruction_loss-0.01649
Epoch-1, Time-1.2e+01, Discriminator_loss-0.2623, Generator_loss-0.177, reconstruction_loss-0.01534
Epoch-2, Time-1.2e+01, Discriminator_loss-0.2535, Generator_loss-0.1508, reconstruction_loss-0.01274
Epoch-3, Time-1.2e+01, Discriminator_loss-0.2455, Generator_loss-0.1454, reconstruction_loss-0.01402
Epoch-4, Time-1.1e+01, Discriminator_loss-0.2544, Generator_loss-0.1349, reconstruction_loss-0.0118
Epoch-5, Time-1.2e+01, Discriminator_loss-0.2527, Generator_loss-0.1341, reconstruction_loss-0.01199
Epoch-6, Time-1.2e+01, Discriminator_loss-0.2515, Generator_loss-0.135, reconstruction_loss-0.01107
Epoch-7, Time-1.2e+01, Discriminator_loss-0.2526, Generator_loss-0.1274, reconstruction_loss-0.01165
Epoch-8, Time-1.2e+01, Discriminator_loss-0.2447, Generator_loss-0.1249, reconstruction_loss-0.01001
Epoch-9, Time-1.2e+01, Discriminator_loss-0.2505, Generator_loss-0.1245, reconstruction_loss-0

Epoch-82, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.01053
Epoch-83, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008502
Epoch-84, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008912
Epoch-85, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.009363
Epoch-86, Time-1.4e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.009621
Epoch-87, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.009551
Epoch-88, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.00901
Epoch-89, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.009206
Epoch-90, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008644
Epoch-91, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.009232
Ep

Epoch-164, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.1254, reconstruction_loss-0.009762
Epoch-165, Time-1.2e+01, Discriminator_loss-0.2501, Generator_loss-0.125, reconstruction_loss-0.009343
Epoch-166, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008569
Epoch-167, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008621
Epoch-168, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.009991
Epoch-169, Time-1.2e+01, Discriminator_loss-0.2501, Generator_loss-0.1251, reconstruction_loss-0.009464
Epoch-170, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.009058
Epoch-171, Time-1.3e+01, Discriminator_loss-0.25, Generator_loss-0.1251, reconstruction_loss-0.009701
Epoch-172, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008964
Epoch-173, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstructi

Epoch-245, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008582
Epoch-246, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.009524
Epoch-247, Time-1.2e+01, Discriminator_loss-0.2499, Generator_loss-0.125, reconstruction_loss-0.009936
Epoch-248, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.009316
Epoch-249, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008764
Epoch-250, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.01035
Epoch-251, Time-1.2e+01, Discriminator_loss-0.2498, Generator_loss-0.125, reconstruction_loss-0.008075
Epoch-252, Time-1.3e+01, Discriminator_loss-0.25, Generator_loss-0.1251, reconstruction_loss-0.008888
Epoch-253, Time-1.3e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008297
Epoch-254, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_

Epoch-326, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.01016
Epoch-327, Time-1.2e+01, Discriminator_loss-0.2501, Generator_loss-0.1249, reconstruction_loss-0.008984
Epoch-328, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008871
Epoch-329, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.009388
Epoch-330, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.125, reconstruction_loss-0.008704
Epoch-331, Time-1.2e+01, Discriminator_loss-0.2503, Generator_loss-0.1257, reconstruction_loss-0.009391
Epoch-332, Time-1.2e+01, Discriminator_loss-0.2499, Generator_loss-0.125, reconstruction_loss-0.008692
Epoch-333, Time-1.2e+01, Discriminator_loss-0.2499, Generator_loss-0.125, reconstruction_loss-0.009258
Epoch-334, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.1249, reconstruction_loss-0.009388
Epoch-335, Time-1.2e+01, Discriminator_loss-0.25, Generator_loss-0.1249, reconstr

In [9]:
file_location_encoder = "3-least-supervised-encoder.pt"
file_location_decoder = "3-least-supervised-decoder.pt"
torch.save(trained_encoder.state_dict(), file_location_encoder)
torch.save(trained_decoder.state_dict(), file_location_decoder)