In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from utils import mnist, plot_graphs, plot_mnist
import numpy as np

%matplotlib inline

In [2]:
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
use_cuda = True
device = torch.device("cuda" if use_cuda else "cpu")

In [4]:
mnist_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
           ])
train_loader, valid_loader, test_loader = mnist(valid=10000, transform=mnist_transform, batch_size=50)

In [5]:
class Encoder(nn.Module):
    def __init__(self, latent_size=10):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(28*28, latent_size)
    
    def forward(self, x):
        x = self.fc1(x)
        return x
    
class Decoder(nn.Module):
    def __init__(self, latent_size=10):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_size, 28*28)
    
    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        return x

In [6]:
class Net(nn.Module):
    def __init__(self, latent_size=10, loss_fn=F.mse_loss, lr=1e-4, l2=0.):
        super(Net, self).__init__()
        self.latent_size = latent_size
        self.E = Encoder(latent_size)
        self.D = Decoder(latent_size)
        self.loss_fn = loss_fn
        self._rho_loss = None
        self._loss = None
        self.optim = optim.Adam(self.parameters(), lr=lr, weight_decay=l2)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        h = self.E(x)
        self.data_rho = h
        out = self.D(h)
        return out
    
    def decode(self, h):
        with torch.no_grad():
            return self.D(h)
    
    def rho_loss(self, rho, size_average=True):
#         input_ = torch.softmax(self.data_rho, 1)
        L = torch.sum(-torch.softmax(self.data_rho, 1))*10000
#         target_ = torch.argmax(input_, 1)
#         L = F.nll_loss(input_, target_)
#         L = torch.abs(self.data_rho)
#         L = torch.sum(torch.abs(input_))
#         L = self.data_rho*self.data_rho
    
        if size_average:
            self._rho_loss = L.mean()
        else:
            self._rho_loss = L.sum()
        return self._rho_loss
    
    def loss(self, x, target, **kwargs):
        target = target.view(-1, 28*28)
        self._loss = self.loss_fn(x, target, **kwargs)
        return self._loss

In [7]:
models = {'16': Net(16).to(device), '32': Net(32).to(device), '64': Net(64).to(device)}
# models = {'64': Net(64).to(device)}
rho = 0.05
train_log = {k: [] for k in models}
test_log = {k: [] for k in models}

In [8]:
def train(epoch, models, log=None):
    train_size = len(train_loader.sampler)
    for batch_idx, (data, _) in enumerate(train_loader):
        for model in models.values():
            model.optim.zero_grad()
            output = model(data.to(device))
            rho_loss = model.rho_loss(rho, True)
            loss = model.loss(output, data.to(device)) + rho_loss
            loss.backward()
            model.optim.step()
            
        if batch_idx % 200 == 0:
            line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses '.format(
                epoch, batch_idx * len(data), train_size, 100. * batch_idx / len(train_loader))
            losses = ' '.join(['{}: {:.6f}'.format(k, m._loss.item()) for k, m in models.items()])
            print(line + losses)
            
    else:
        batch_idx += 1
        line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses '.format(
            epoch, batch_idx * len(data), train_size, 100. * batch_idx / len(train_loader))
        losses = ' '.join(['{}: {:.6f}'.format(k, m._loss.item()) for k, m in models.items()])
        if log is not None:
            for k in models:
                log[k].append((models[k]._loss, models[k]._rho_loss))
        print(line + losses)

In [9]:
avg_lambda = lambda l: 'loss: {:.4f}'.format(l)
rho_lambda = lambda p: 'rho_loss: {}'.format(p)
data_rho_lambda = lambda q: 'data_rho: {}'.format(q)
line = lambda i, l, p: '{}: '.format(i) + avg_lambda(l) + '\t' + rho_lambda(p)
line_extra = lambda i, l, p, q: '{}: '.format(i) + avg_lambda(l) + '\t' + rho_lambda(p) + '\n' + data_rho_lambda(q)
    
def test(models, loader, log=None):
    test_size = len(loader.sampler)


    test_loss = {k: 0. for k in models}
    rho_loss = {k: 0. for k in models}
    with torch.no_grad():
        for data, _ in loader:
            output = {k: m(data.to(device)) for k, m in models.items()}
            for k, m in models.items():
                test_loss[k] += m.loss(output[k], data.to(device), reduction='sum').item() # sum up batch loss
                rho_loss[k] += m.rho_loss(rho, size_average=False).item()
    
    for k in models:
        test_loss[k] /= (test_size * 784)
        rho_loss[k] /= (test_size * models[k].latent_size)
        if log is not None:
            log[k].append((test_loss[k], rho_loss[k], models[k].data_rho))
    
    lines = '\n'.join([line(k, test_loss[k], rho_loss[k]) for k in models]) + '\n'
    report = 'Test set:\n' + lines        
    print(report)

In [10]:
for epoch in range(1, 101):
    for model in models.values():
        model.train()
    train(epoch, models, train_log)
    for model in models.values():
        model.eval()
    test(models, valid_loader, test_log)

Test set:
16: loss: 0.2538	rho_loss: -625.0000037109374
32: loss: 0.2111	rho_loss: -312.5000018554687
64: loss: 0.1385	rho_loss: -156.24999995117187

Test set:
16: loss: 0.2278	rho_loss: -625.0000017578125
32: loss: 0.1611	rho_loss: -312.4999857421875
64: loss: 0.0954	rho_loss: -156.24999951171876

Test set:
16: loss: 0.2007	rho_loss: -625.0000068359375
32: loss: 0.1336	rho_loss: -312.49999873046875
64: loss: 0.0712	rho_loss: -156.2499994140625

Test set:
16: loss: 0.1821	rho_loss: -624.9999978515625
32: loss: 0.1180	rho_loss: -312.50000615234376
64: loss: 0.0588	rho_loss: -156.250000390625

Test set:
16: loss: 0.1676	rho_loss: -625.0000119140625
32: loss: 0.1082	rho_loss: -312.49999248046873
64: loss: 0.0506	rho_loss: -156.25000034179686

Test set:
16: loss: 0.1570	rho_loss: -625.000005078125
32: loss: 0.1011	rho_loss: -312.49999501953124
64: loss: 0.0449	rho_loss: -156.25000092773436

Test set:
16: loss: 0.1488	rho_loss: -625.000005859375
32: loss: 0.0954	rho_loss: -312.4999966796875

Test set:
16: loss: 0.1208	rho_loss: -624.9999962890626
32: loss: 0.0720	rho_loss: -312.4999958984375
64: loss: 0.0271	rho_loss: -156.24999990234375

Test set:
16: loss: 0.1185	rho_loss: -625.000003125
32: loss: 0.0699	rho_loss: -312.4999958984375
64: loss: 0.0262	rho_loss: -156.25000024414064

Test set:
16: loss: 0.1150	rho_loss: -624.999998828125
32: loss: 0.0679	rho_loss: -312.49999404296875
64: loss: 0.0254	rho_loss: -156.2499998046875

Test set:
16: loss: 0.1126	rho_loss: -625.00000625
32: loss: 0.0660	rho_loss: -312.4999984375
64: loss: 0.0246	rho_loss: -156.2499990234375

Test set:
16: loss: 0.1108	rho_loss: -625.0000005859375
32: loss: 0.0644	rho_loss: -312.4999969726563
64: loss: 0.0240	rho_loss: -156.24999975585936

Test set:
16: loss: 0.1092	rho_loss: -624.99999453125
32: loss: 0.0629	rho_loss: -312.4999983398437
64: loss: 0.0234	rho_loss: -156.24999951171876

Test set:
16: loss: 0.1075	rho_loss: -624.9999951171875
32: loss: 0.0615	rho_loss: -312.499997265625
64: loss: 0.022

Test set:
16: loss: 0.1005	rho_loss: -624.9999947265625
32: loss: 0.0547	rho_loss: -312.499997265625
64: loss: 0.0201	rho_loss: -156.24999975585936

Test set:
16: loss: 0.0998	rho_loss: -624.9999921875
32: loss: 0.0540	rho_loss: -312.50000107421874
64: loss: 0.0199	rho_loss: -156.24999956054688

Test set:
16: loss: 0.0991	rho_loss: -624.999989453125
32: loss: 0.0534	rho_loss: -312.5000009765625
64: loss: 0.0196	rho_loss: -156.25

Test set:
16: loss: 0.0983	rho_loss: -624.9999927734375
32: loss: 0.0528	rho_loss: -312.5
64: loss: 0.0194	rho_loss: -156.24999975585936



KeyboardInterrupt: 

In [None]:
test_log['64']
losses = np.array([[it[0],it[1]] for it in test_log['64']])
N = losses.shape[0]
plt.figure(1), plt.plot(range(N), losses[:,0])
plt.figure(2), plt.plot(range(N), losses[:,1])

In [None]:
mod = '64'
mod_ = int(mod)
data, _ = next(iter(test_loader))
output = models[mod](data.to(device))
to_plot = output.view(-1, 1, 28, 28).clamp(0, 1).data.cpu().numpy()
decoded = models[mod].decode(torch.eye(mod_).to(device))
dec_to_plot = ((decoded.view(-1, 1, 28, 28)+1)*0.5).clamp(0, 1).data.cpu().numpy()
decoded_neg = models[mod].decode(-torch.eye(mod_).to(device))
dec_neg_to_plot = ((decoded_neg.view(-1, 1, 28, 28)+1)*0.5).clamp(0, 1).data.cpu().numpy()
with torch.no_grad():
    encoded = models[mod].E(data.view(-1, 28*28).to(device))
    print((torch.abs(encoded) > 0.2).sum(1))
    encoded[torch.abs(encoded) < 0.2] = 0.
    decoded_f = models[mod].decode(encoded)
    f_to_plot = ((decoded_f.view(-1, 1, 28, 28)+1)*0.5).clamp(0, 1).data.cpu().numpy()

In [None]:
mod = '64'
mod_ = int(mod)
data, _ = next(iter(test_loader))
with torch.no_grad():
    encoded = models[mod].E(data.view(-1, 28*28).to(device))
#     print(encoded.size())
    plt.figure(1), sns.barplot(np.arange(mod_), encoded.mean(0), palette="rocket")
    plt.figure(2), sns.distplot(encoded.reshape([-1,1]).cpu().numpy())
#     plt.figure(3), sns.violinplot(encoded.reshape([-1,1]).cpu().numpy())

In [None]:
plot_mnist(data.data.numpy(), (5, 10))
plot_mnist(to_plot, (5, 10))
# plot_mnist(f_to_plot, (5, 10))
plot_mnist(dec_to_plot, (8, 8))
plot_mnist(dec_neg_to_plot, (8, 8))