In [1]:
import numpy as np
import torch
from torch import optim
from tqdm import tqdm
from pprint import pprint
from torchvision import datasets, transforms
%matplotlib notebook
import matplotlib.pyplot as plt
import os
import models
from utils.load_model import save_checkpoint, load_checkpoint
import argparse
import skimage.io
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def get_mnist_data(device, reshape=True):
    def my_transform(x):
        if reshape:
            return x.to(device).reshape(-1)
        else:
            return x.to(device)
    preprocess = transforms.Compose([transforms.ToTensor(),my_transform])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST("data", train=True, download=True, transform=preprocess),
        batch_size=100,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST("data", train=False, download=True, transform=preprocess),
        batch_size=100,
        shuffle=True,
    )

    return train_loader, test_loader

def noisy_soft_labels(labels):
    noisy = torch.bernoulli(0.9*labels+0.05)
    noisy_soft = torch.where(noisy==0, torch.rand_like(noisy)*0.3, torch.rand_like(noisy)*0.7+0.5)
    return noisy_soft

In [3]:
class Args:
   def __init__(self, **kwargs):
      self.__dict__.update(kwargs)
args = Args(ae_mode = 'conv',
train_ae = 1,
ae_epoch = 5,
gen_mode = "w",
train_gen = 0,
gen_epoch = 40,
save_images = 0)
train_loader, test_loader = get_mnist_data(device, reshape=(args.ae_mode=="ae"))
num_epochs1 = args.train_ae
num_epochs2 = args.train_gen
x_dim = 784
z_dim = 10
z0_dim = 10
save_path = 'checkpoints'

In [4]:
if args.ae_mode=="ae":
    ae = models.AutoEncoder(x_dim, z_dim, n_units=[300,300]).to(device)
if args.ae_mode=="conv":
    ae = models.ConvAutoEncoder(in_channels=1, image_size=(28,28), z_dim=z_dim, activation=None).to(device)
# if args.gen_mode=="n" or args.gen_mode=="w":
generator = models.Encoder(z0_dim, z_dim, [300,300]).to(device)
# if args.gen_mode=="convn" or args.gen_mode=="convw":
#     generator = models.ConvGenerator(z0_dim, ae.z_dim)
discriminator = models.Discriminator(z_dim, [20,20]).to(device)
g_optimizer = optim.Adam(generator.parameters(), lr=1e-2)
d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-3)
ed_optimizer = optim.Adam(ae.parameters(), lr=1e-3)

In [5]:
if args.train_ae:
# First train encoder and decoder
    print("Training Encoder-Decoder .............")
    for e in range(1, num_epochs1 + 1):
        for batch in train_loader:
            ae.train()
            x_batch, y_batch = batch
            batch_size = x_batch.size()[0]
            labels = torch.eye(10)[y_batch.cpu()].to(device).float()

            ## Train encoder-decoder
            ## min -E_{q(z|x)} log(p(x|z))
            ed_optimizer.zero_grad()
            z = ae.encode(x_batch)
#             z += torch.randn_like(z)*
            x_out = ae.decode(z)
            ed_loss = torch.nn.BCEWithLogitsLoss(reduction='none')(input=x_out, target=x_batch).sum(-1).mean()
            ed_loss.backward()
            ed_optimizer.step()

        with torch.no_grad():
            x_batch = next(iter(test_loader))[0]
            ae.eval()
            z = ae.encode(x_batch)
            x_out = ae.decode(z)
            images = torch.round(torch.sigmoid(x_out).cpu().detach())
            test_loss = torch.nn.BCEWithLogitsLoss(reduction='none')(input=x_out, target=x_batch).sum(-1).mean()
        images_tiled = np.reshape(
            np.transpose(np.reshape(images, (10, 10, 28, 28)), (0, 2, 1, 3)),
            (280, 280),
        )
        plt.imsave("images/mnist-ae/{}{}.png".format(args.ae_mode,e), images_tiled, cmap="gray")
        print(
            "Epoch {} : E-D train loss = {:.2e} test loss = {:.2e}".format(
                e, ed_loss, test_loss
            )
        )

        if e%5==0:
            checkpoint_dict = {
                'epoch':e,
                'autoencoder':ae.state_dict(),
                'ed_optimizer':ed_optimizer.state_dict()
            }
            fname = f'enc-dec_{args.ae_mode}{e}'
            save_checkpoint(checkpoint_dict, save_path, fname)
    for k in range(z_dim):
        plt.figure()
        plt.hist(z[:,k].cpu().numpy(), bins=100)
        plt.title(f'z_{k} histogram')
        plt.savefig(f'plots/{args.ae_mode}_z{k}.png')
else:
    fname = 'enc-dec_'+args.ae_mode+str(args.ae_epoch)
    enc_dec = load_checkpoint(save_path, fname, device)
    ae.load_state_dict(enc_dec['autoencoder'])

Training Encoder-Decoder .............
Epoch 1 : E-D train loss = 4.20e+00 test loss = 4.10e+00


TclError: no display name and no $DISPLAY environment variable

In [6]:
plt.figure()

TclError: no display name and no $DISPLAY environment variable