In [2]:

import os
import argparse
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
# import pdb
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

batch_size=128 #'input batch size for training (default: 64)')
epochs=19#'number of epochs to train (default: 2)')
seed=1
log_interval=10 #'how many batches to wait before logging training status')


train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
        transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

lam = 1e-4

class CAE(nn.Module):
    def __init__(self):
        super(CAE, self).__init__()

        self.fc1 = nn.Linear(784, 256, bias = False) # Encoder
        self.fc2 = nn.Linear(256, 784, bias = False) # Decoder

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()


    def encoder(self, x):
        h1 = self.relu(self.fc1(x.view(-1, 784)))
        return h1

    def decoder(self,z):
        h2 = self.sigmoid(self.fc2(z))
        return h2

    def forward(self, x):
            h1 = self.encoder(x)
            h2 = self.decoder(h1)
            return h1, h2

        # Writing data in a grid to check the quality and progress
    def samples_write(self, x, epoch):
        _, samples = self.forward(x)
        #pdb.set_trace()
        samples = samples.data.cpu().numpy()[:16]
        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)
        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
        if not os.path.exists('out/'):
            os.makedirs('out/')
        plt.savefig('out/{}.png'.format(str(epoch).zfill(3)), bbox_inches='tight')
        #self.c += 1
        plt.close(fig)


mse_loss = nn.BCELoss(size_average = False)

def loss_function(W, x, recons_x, h,  lam=0.0001):
    """Compute the Contractive AutoEncoder Loss
    Evalutes the CAE loss, which is composed as the summation of a Mean
    Squared Error and the weighted l2-norm of the Jacobian of the hidden
    units with respect to the inputs.
    See reference below for an in-depth discussion:
      #1: http://wiseodd.github.io/techblog/2016/12/05/contractive-autoencoder
    Args:
        `W` (FloatTensor): (N_hidden x N), where N_hidden and N are the
          dimensions of the hidden units and input respectively.
        `x` (Variable): the input to the network, with dims (N_batch x N)
        recons_x (Variable): the reconstruction of the input, with dims
          N_batch x N.
        `h` (Variable): the hidden units of the network, with dims
          batch_size x N_hidden
        `lam` (float): the weight given to the jacobian regulariser term
    Returns:
        Variable: the (scalar) CAE loss
    """
    mse = mse_loss(recons_x, x)
    # Since: W is shape of N_hidden x N. So, we do not need to transpose it as
    # opposed to #1
    dh = h * (1 - h) # Hadamard product produces size N_batch x N_hidden
    # Sum through the input dimension to improve efficiency, as suggested in #1
    w_sum = torch.sum(Variable(W)**2, dim=1)
    # unsqueeze to avoid issues with torch.mv
    w_sum = w_sum.unsqueeze(1) # shape N_hidden x 1
    contractive_loss = torch.sum(torch.mm(dh**2, w_sum), 0)
    return mse + contractive_loss.mul_( lam)


model = CAE()
optimizer = optim.Adam(model.parameters(), lr = 0.0001)


def train(epoch):
    model.train()
    train_loss = 0

    for idx, (data, _) in enumerate(train_loader):
        data = Variable(data)

        optimizer.zero_grad()

        hidden_representation, recons_x = model(data)

        # Get the weights
        # model.state_dict().keys()
        # change the key by seeing the keys manually.
        # (In future I will try to make it automatic)
        W = model.state_dict()['fc1.weight']
        loss = loss_function(W, data.view(-1, 784), recons_x,
                             hidden_representation,  0.0001)

        loss.backward()
        train_loss += loss.data[0]
        optimizer.step()

        if idx % log_interval == 0:
            print('Train epoch: {} [{}/{}({:.0f}%)]\t Loss: {:.6f}'.format(
                  epoch, idx*len(data), len(train_loader.dataset),
                  100*idx/len(train_loader),
                  loss.data[0]/len(data)))


    print('====> Epoch: {} Average loss: {:.4f}'.format(
         epoch, train_loss / len(train_loader.dataset)))
    model.samples_write(data,epoch)

for epoch in range(epochs):
    train(epoch)

====> Epoch: 0 Average loss: 239.4672
====> Epoch: 1 Average loss: 144.7671
====> Epoch: 2 Average loss: 122.6848
====> Epoch: 3 Average loss: 108.7309
====> Epoch: 4 Average loss: 98.6613
====> Epoch: 5 Average loss: 91.0449
====> Epoch: 6 Average loss: 85.2243
====> Epoch: 7 Average loss: 80.6985
====> Epoch: 8 Average loss: 77.1137
====> Epoch: 9 Average loss: 74.2029
====> Epoch: 10 Average loss: 71.8013
====> Epoch: 11 Average loss: 69.7963
====> Epoch: 12 Average loss: 68.0868
====> Epoch: 13 Average loss: 66.6253
====> Epoch: 14 Average loss: 65.3576
====> Epoch: 15 Average loss: 64.2546
====> Epoch: 16 Average loss: 63.2894
====> Epoch: 17 Average loss: 62.4394
====> Epoch: 18 Average loss: 61.6908


In [3]:
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Model's state_dict:
fc1.weight 	 torch.Size([256, 784])
fc2.weight 	 torch.Size([784, 256])
Optimizer's state_dict:
state 	 {0: {'step': 8911, 'exp_avg': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), 'exp_avg_sq': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}, 1: {'step': 8911, 'exp_avg': tensor([[0.0016, 0.0011, 0.0009,  ..., 0.0017, 0.0012, 0.0016],
        [0.0011, 0.0008, 0.0006,  ..., 0.0011, 0.0008, 0.0011],
        [0.0011, 0.0010, 0.0007,  ..., 0.0012, 0.0010, 0.0013],
        ...,
        [0.0016, 0.0014, 0.0010,  ..., 0.0018, 0.0015, 0.0021],
        [0.0010, 0.0008

In [4]:
torch.save(model.state_dict(), "/home/crns/Desktop/PFE/PFE/Code/Self-Net/ICLR 2019/models/CAE.pt")