In [1]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

In [2]:
torch.manual_seed(1)

<torch._C.Generator at 0x7f993a7bfa30>

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
train_set = torchvision.datasets.MNIST(root = './content',
                                       train = True,
                                       download = False,
                                       transform = transforms.ToTensor())

test_set = torchvision.datasets.MNIST(root='./content',
                                     train = False,
                                     download = False,
                                     transform = transforms.ToTensor())

In [5]:
train_set

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./content
    Split: Train
    StandardTransform
Transform: ToTensor()

In [6]:
train_data = DataLoader(train_set,
                       batch_size = 15,
                       shuffle = True,
                       num_workers = 4)
test_data = DataLoader(test_set,
                       batch_size = 15,
                       shuffle = True,
                       num_workers = 4)

In [7]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(28*28,256, bias=True),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256,128, bias=True),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128,64, bias=True),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Linear(64,3, bias=True))
        
        self.decoder = nn.Sequential(
            nn.Linear(3,64, bias=True),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Linear(64,128, bias=True),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128,256, bias=True),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256,28*28, bias=True))
        
    def forward(self, x):
        x = self.encoder(x)
#         x.apply(init_weight_function)
#         x.apply(init_batch_function)
        x = self.decoder(x)
#         x.apply(init_weight_function)
#         x.apply(init_batch_function)
        return x
        
    def init_weight_function(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
  
    def init_batch_function(self, n):
        if isinstance(n, nn.BatchNorm1d):
            torch.nn.init.xavier_uniform_(n.weight)


In [8]:
Autoencoder_model = Autoencoder()
optimizer = torch.optim.Adam(Autoencoder_model.parameters(),
                             lr=0.01,weight_decay=1e-5)
criterion = nn.MSELoss()
# criterion = nn.CrossEntropyLoss()

In [9]:
def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x

In [10]:
Autoencoder_model(torch.FloatTensor(15,784))

tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddmmBackward>)

In [12]:
epochs = 20
loss_data = []
collect_img = []

for i in range(epochs):
    for data in train_data:
        img_data, _ = data
        img_data = img_data.view(img_data.size(0),-1)
#         img_data = Variable(img_data)

        output = Autoencoder_model(img_data)
        loss = criterion(output, img_data)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print('epoch = [{}/{}], loss = {:.4f}'.format(i+1,epochs,loss.item()))
    loss_data.append(loss.item)
    pic = to_img(output.cpu().data)
    collect_img.append(pic)

epoch = [1/20], loss = 0.0421
epoch = [2/20], loss = 0.0526
epoch = [3/20], loss = 0.0517
epoch = [4/20], loss = 0.0513
epoch = [5/20], loss = 0.0464
epoch = [6/20], loss = 0.0454
epoch = [7/20], loss = 0.0498
epoch = [8/20], loss = 0.0520
epoch = [9/20], loss = 0.0526
epoch = [10/20], loss = 0.0410
epoch = [11/20], loss = 0.0524
epoch = [12/20], loss = 0.0552
epoch = [13/20], loss = 0.0524
epoch = [14/20], loss = 0.0480
epoch = [15/20], loss = 0.0486
epoch = [16/20], loss = 0.0470
epoch = [17/20], loss = 0.0452
epoch = [18/20], loss = 0.0439
epoch = [19/20], loss = 0.0580
epoch = [20/20], loss = 0.0492
