In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision

%matplotlib inline
import random
import matplotlib.pyplot as plt

import numpy as np # linear algebra
import struct
from array import array
from os.path  import join
import torch.optim as optim

def show_images(images, title_texts, cols=5):
    """Display images in a grid with optional per-image titles.

    Args:
        images: sequence of 2-D images accepted by ``plt.imshow``.
        title_texts: titles paired with ``images`` by position. Entries may be
            strings, numbers, or 0-d tensors / numpy scalars (unwrapped via
            ``.item()``). An empty string leaves that subplot untitled.
        cols: number of grid columns (default 5).

    Only ``min(len(images), len(title_texts))`` images are drawn, because the two
    sequences are zipped. The figure is shown before returning, so text printed
    between successive calls in one cell lines up with the right figure.
    """
    if len(images) == 0:
        return
    # Ceiling division gives exactly enough rows. `len // cols + 1` would add an
    # empty row whenever len is a multiple of cols.
    rows = (len(images) + cols - 1) // cols
    plt.figure(figsize=(20, 12))
    for index, (image, title_text) in enumerate(zip(images, title_texts), start=1):
        plt.subplot(rows, cols, index)
        plt.imshow(image, cmap=plt.cm.gray)
        # Unwrap scalar tensors / numpy scalars so the title shows the value
        # itself rather than a tensor repr.
        if hasattr(title_text, 'item'):
            title_text = title_text.item()
        if title_text != '':
            plt.title(str(title_text), fontsize=8)
    plt.show()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook touches, so that noise inputs, weight
# initialization and shuffling repeat across Restart & Run All.
# GPU kernels may still introduce some nondeterminism.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 32

# ToTensor() converts each digit to a float tensor scaled to [0, 1].
to_tensor = torchvision.transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST(root='data', train=True, transform=to_tensor, download=True)
test_dataset = torchvision.datasets.MNIST(root='data', train=False, transform=to_tensor, download=True)

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Preview the first few raw training digits together with their labels.
n_preview = 10
imgs_train = [train_dataset.data[i] for i in range(n_preview)]
# Plain ints make the titles show the digit value instead of a tensor repr.
labels_train = train_dataset.targets[:n_preview].tolist()

show_images(imgs_train, labels_train)

In [None]:
class Net(nn.Module):
    """Label-conditioned generator: maps a noise vector plus a digit label to a flat image.

    The label is added as a scalar offset to every input feature. A single fully
    connected layer then produces the output pixels.

    Args:
        n_pixels: number of input/output features (default 784 = 28 * 28).
        max_value: upper bound applied to the output (default 256.0).
            NOTE(review): this notebook's targets come from ``ToTensor()`` and lie
            in [0, 1], so the default bound sits far above the target range. It is
            kept at 256.0 to preserve existing behaviour.
    """

    def __init__(self, n_pixels=784, max_value=256.0):
        super().__init__()
        self.fc0 = nn.Linear(n_pixels, n_pixels)
        self.max_value = max_value

    def forward(self, x, label):
        """Generate images from noise ``x`` conditioned on ``label``.

        Args:
            x: float tensor of shape (B, n_pixels).
            label: B class indices (any shape that reshapes to (B, 1)).

        Returns:
            Tensor of shape (B, n_pixels) with values in [0, max_value].
        """
        # Make the label a column so it broadcasts across the feature axis. The
        # batch size comes from the inputs rather than the notebook-global
        # `batch_size`, so partial final batches and other batch sizes work.
        label = label.reshape(-1, 1).to(x.dtype)
        x = torch.relu(x + label)
        x = torch.relu(self.fc0(x))
        return torch.clip(x, 0.0, self.max_value)

# Build the generator directly on the chosen device (Module.to returns the module).
net = Net().to(device)

# Objective: mean squared error between generated and real pixel intensities.
criterion = nn.MSELoss()
# Adam with lr 0.01 and a small L2 weight decay. betas, eps and amsgrad stay at
# the PyTorch defaults: (0.9, 0.999), 1e-8 and False.
optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=0.00001)

In [None]:
import os  # local import: only needed to create the checkpoint directory

# Train the generator. The input is fresh Gaussian noise plus the digit label;
# the target is the real digit image.
net.train()
for epoch in range(10):

    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Size tensors from the actual batch rather than the global batch_size,
        # so a final partial batch (dataset size not divisible by batch_size) works.
        n_batch = inputs.shape[0]
        targets = inputs.float().reshape(n_batch, 784)
        x = torch.normal(0.0, 1.0, size=(n_batch, 784), device=device)

        optimizer.zero_grad()
        predicted = net(x, labels)
        loss = criterion(predicted, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Report the mean loss over the last 100 mini-batches.
        if i % 100 == 99:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

PATH = './models/handwritten_digit_creation_net.pth'
# torch.save does not create parent directories. Without this, the save fails
# with FileNotFoundError when ./models does not exist yet.
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(net.state_dict(), PATH)

print('Finished Training')

In [None]:
# Reload the trained weights into a fresh network and compare generated digits
# with real test digits of the same labels.
net = Net()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
net.load_state_dict(torch.load(PATH, map_location=device))
net.to(device)
net.eval()

# DataLoader iterators follow the Python iterator protocol. Use next();
# the legacy `.next()` method is not available on current PyTorch.
dataiter = iter(testloader)
inputs, labels = next(dataiter)

inputs_qualitative = inputs.to(device)
labels_qualitative = labels.to(device)
n_batch = inputs_qualitative.shape[0]
n_show = min(10, n_batch)

# The network maps one noise value per pixel (784 features), as in training.
# A width of 32 would fail with a shape mismatch inside the model.
x_qualitative = torch.normal(0.0, 1.0, size=(n_batch, 784), device=device)

with torch.no_grad():
    predicted_qualitative = net(x_qualitative, labels_qualitative)

# Targets were scaled to [0, 1] by ToTensor(), so rounding to int would binarize
# the output. Keep float intensities for a like-for-like visual comparison.
predicted_qualitative = predicted_qualitative.reshape(n_batch, 28, 28).cpu()

imgs_qualitative = [inputs_qualitative[i].squeeze(0).cpu() for i in range(n_show)]
created_qualitative = [predicted_qualitative[i] for i in range(n_show)]
# Plain ints so the titles show the digit value instead of a tensor repr.
titles_qualitative = labels_qualitative[:n_show].tolist()

# plt.show() after each grid keeps each printed heading next to its own figure.
print("Ground Truth")
show_images(imgs_qualitative, titles_qualitative)
plt.show()
print()
print("Created")
show_images(created_qualitative, titles_qualitative)
plt.show()