In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Make newly created float tensors live on the GPU when CUDA is usable,
# otherwise fall back to ordinary CPU float tensors.
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
    torch.set_default_tensor_type(torch.FloatTensor)
# Report which of the two branches was taken.
print(torch.cuda.is_available())

True


In [9]:
class Encoder(nn.Module):
    """Convolutional encoder that compresses 28x28 images into 60 features.

    The input is reshaped to ``(N, 1, 28, 28)``. With the unpadded,
    stride-1 convolutions and the two 2x2 max-pools used here the spatial
    size shrinks 28 -> 24 -> 12 -> 8 -> 4 -> 1, so the result has shape
    ``(N, 60, 1, 1)``.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        # Convolution stack: 1 -> 6 -> 16 -> 60 channels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=60, kernel_size=4)

        # One batch-norm layer per convolution output width.
        self.bn1 = nn.BatchNorm2d(num_features=6)
        self.bn2 = nn.BatchNorm2d(num_features=16)
        # bn3 is registered (and therefore part of the parameters and the
        # state dict) but forward() does not apply it.
        self.bn3 = nn.BatchNorm2d(num_features=60)

    def forward(self, x):
        """Encode ``x`` (any tensor viewable as ``(-1, 1, 28, 28)``)."""
        images = x.view(-1, 1, 28, 28)
        # Stage 1: conv -> 2x2 max-pool -> ReLU -> batch-norm.
        stage1 = self.bn1(F.relu(F.max_pool2d(self.conv1(images), 2)))
        # Stage 2: same pattern with the second convolution.
        stage2 = self.bn2(F.relu(F.max_pool2d(self.conv2(stage1), 2)))
        # Stage 3: final convolution with ReLU only (no pooling, no norm).
        return F.relu(self.conv3(stage2))
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a 60-channel code to an image.

    Every layer is an unpadded, stride-1 ``ConvTranspose2d``, so each one
    grows the spatial size by ``kernel_size - 1``: for a ``(N, 60, 1, 1)``
    input the sizes are 1 -> 4 -> 8 -> 12 -> 24 -> 28 and the output has
    shape ``(N, 1, 28, 28)`` with values in ``(0, 1)``.

    Fixes relative to the previous revision:

    * The last layer no longer passes through ReLU before the sigmoid.
      ``sigmoid(relu(v))`` is always >= 0.5, which made it impossible to
      reconstruct dark pixels.
    * Each normalised layer now owns its own ``BatchNorm2d``. Previously
      ``bn1`` and ``bn2`` were each applied after two different layers,
      forcing unrelated activations to share one affine transform and one
      set of running statistics. ``bn1`` / ``bn2`` keep their names and
      sizes; ``bn1b`` / ``bn2b`` are new, so checkpoints written by the
      old class need ``load_state_dict(..., strict=False)``.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(60, 16, kernel_size=(4, 4))
        self.deconv2 = nn.ConvTranspose2d(16, 16, kernel_size=(5, 5))
        self.deconv3 = nn.ConvTranspose2d(16, 6, kernel_size=(5, 5))
        self.deconv4 = nn.ConvTranspose2d(6, 6, kernel_size=(13, 13))
        self.deconv5 = nn.ConvTranspose2d(6, 1, kernel_size=(5, 5))

        self.bn1 = nn.BatchNorm2d(16)   # after deconv1
        self.bn2 = nn.BatchNorm2d(6)    # after deconv3
        self.bn1b = nn.BatchNorm2d(16)  # after deconv2
        self.bn2b = nn.BatchNorm2d(6)   # after deconv4

    def forward(self, x):
        """Decode ``x`` of shape ``(N, 60, H, W)`` into per-pixel probabilities."""
        x = self.bn1(F.relu(self.deconv1(x)))
        x = self.bn1b(F.relu(self.deconv2(x)))
        x = self.bn2(F.relu(self.deconv3(x)))
        x = self.bn2b(F.relu(self.deconv4(x)))
        # Feed raw logits to the sigmoid so the full (0, 1) range is reachable.
        return torch.sigmoid(self.deconv5(x))

In [10]:
def assess(model, data, labels):
    """Return the fraction of samples whose arg-max prediction equals its label.

    Args:
        model: callable applied to one sample at a time; its output is
            reduced with ``torch.argmax`` over all elements.
        data: iterable of samples.
        labels: iterable of expected class indices, paired with ``data``
            via ``zip`` (extra items in the longer iterable are ignored).

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` when there are no
        samples (previously this raised ``ZeroDivisionError``).
    """
    correct, total = 0, 0
    # Evaluation needs no gradients; skipping them avoids building an
    # autograd graph for every sample.
    with torch.no_grad():
        for x, y in zip(data, labels):
            if torch.argmax(model(x)) == y:
                correct += 1
            total += 1
    if total == 0:
        return 0.0
    return correct / float(total)

In [11]:
# Read the MNIST byte dumps. Each file is interpreted as a flat run of
# unsigned bytes; images are reshaped to (num_images, 28, 28).
# NOTE(review): this assumes the files contain no header bytes -- confirm
# against whatever produced the MNIST/ directory.
# np.float32 replaces the `np.float` alias, which was removed in NumPy 1.24.
# Pixel values are integers in 0..255, so float32 represents them exactly.
data_train = np.fromfile("MNIST/images_train", dtype=np.ubyte).reshape(-1, 28, 28).astype(np.float32)
labels_train = np.fromfile("MNIST/labels_train", dtype=np.ubyte)
data_test = np.fromfile("MNIST/images_test", dtype=np.ubyte).reshape(-1, 28, 28).astype(np.float32)
labels_test = np.fromfile("MNIST/labels_test", dtype=np.ubyte)

# Fail early if images and labels do not line up one-to-one.
assert len(data_train) == len(labels_train), "train images/labels length mismatch"
assert len(data_test) == len(labels_test), "test images/labels length mismatch"

# Shuffle the training set, applying the same permutation to images and labels.
indices = np.random.permutation(len(labels_train))
data_train = data_train[indices]
labels_train = labels_train[indices]

# Convert to tensors; pixels are scaled to [0, 1], labels become (N, 1) int64.
data_train = torch.tensor(data_train, dtype=torch.float) / 255
data_test = torch.tensor(data_test, dtype=torch.float) / 255
labels_train = torch.tensor(labels_train, dtype=torch.long).unsqueeze(1)
labels_test = torch.tensor(labels_test, dtype=torch.long).unsqueeze(1)
#plt.imshow(data_train[34], cmap='gray')

In [12]:
# Instantiate both halves of the autoencoder (encoder first, then decoder).
encoder, decoder = Encoder(), Decoder()
# A single Adam optimizer updates the parameters of both networks.
trainable_params = list(encoder.parameters())
trainable_params.extend(decoder.parameters())
optimizer = optim.Adam(trainable_params, lr=1e-3)

In [13]:
# Train the autoencoder one image at a time with a binary cross-entropy
# reconstruction loss, showing a test-set reconstruction every 1000 steps
# and saving both state dicts at the end of every epoch.
epochs = 10
# BCELoss holds no state, so build it once instead of on every step.
criterion = nn.BCELoss()
encoder.train()
decoder.train()
for epoch in range(epochs):
    for count, x in enumerate(data_train):
        optimizer.zero_grad()
        z = encoder(x)
        y = decoder(z)
        loss = criterion(y.view(-1, 28, 28), x.view(-1, 28, 28))
        loss.backward()
        optimizer.step()
        if count % 1000 == 0:
            # Print a plain number rather than the tensor repr.
            print("epoch {} step {}: loss {:.4f}".format(epoch, count, loss.item()))

            # Reconstruct one random test image in eval mode. no_grad keeps
            # this preview from building an autograd graph.
            index = np.random.randint(len(data_test))
            encoder.eval()
            decoder.eval()
            with torch.no_grad():
                preview = decoder(encoder(data_test[index]))
            encoder.train()
            decoder.train()

            y_vis = preview.cpu().numpy().reshape(28, 28)
            x_vis = data_test[index].cpu().numpy()

            # Original on the left, reconstruction on the right.
            fig = plt.figure(figsize=(28, 28))
            fig.add_subplot(1, 2, 1)
            plt.imshow(x_vis, cmap='gray')
            fig.add_subplot(1, 2, 2)
            plt.imshow(y_vis, cmap='gray')
            plt.show()
            # Release the figure so repeated previews do not pile up in memory.
            plt.close(fig)
    # Checkpoint after every epoch (overwrites the previous files).
    torch.save(encoder.state_dict(), "encoder")
    torch.save(decoder.state_dict(), "decoder")

torch.Size([1, 16, 8, 8])




RuntimeError: invalid argument 2: size '[-1 x 28 x 28]' is invalid for input with 1024 elements at /opt/conda/conda-bld/pytorch_1532579245307/work/aten/src/TH/THStorage.cpp:80