In [None]:
import random
from PIL import Image
import numpy as np
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.datasets as dset
from torchvision import transforms
from configuration_generator import RectangleConfigurationGenerator
import matplotlib.pyplot as plt

In [None]:
# ATTRS = dict(min_w=10, max_w=50, min_h=10, max_h=50, min_x0 = 0, min_y0 = 0, padding = 5, canvas_size = 28)

In [None]:
# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) so that data
# generation, weight initialisation and DataLoader shuffling are reproducible
# under Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

gen = RectangleConfigurationGenerator()
# NOTE(review): `b` / `d` look like the aligned ("beautified") and distorted
# configurations of one pair, judging by later naming — confirm against the generator.
b, d = gen.generate_training_pairs()

In [None]:
# Inspect the first element of `b` (NOTE(review): presumably the aligned configuration — confirm).
b[0]

In [None]:
# Inspect the first element of `d` (NOTE(review): presumably the distorted configuration — confirm).
d[0]

In [None]:
# Create the training samples. Each item is later unpacked as
# (beautified, distorted, *rest).
N_TRAINING_SAMPLES = 200
training_dataset = gen.generate_training_data(N_TRAINING_SAMPLES)

In [None]:
# Convert the boolean data of the aligned and the corrupted rectangles into
# 0/1 bitmaps.

# Map the strings 'True'/'False' to '1'/'0'; any other value passes through unchanged.
def boolstr_to_floatstr(v):
    """Translate the string 'True' to '1' and 'False' to '0'.

    Every other value (including real booleans and numbers) is handed back
    untouched, so the caller can cast the result with ``float``.
    """
    return '1' if v == 'True' else (
        '0' if v == 'False' else v
    )

def preprocess(sample, canvas_size=28):
    """Convert one generated sample into a float32 bitmap tensor.

    Parameters
    ----------
    sample : array-like
        Canvas values. 'True'/'False' strings are mapped to 1/0 via
        `boolstr_to_floatstr`; every other value must be castable to float.
        Must contain ``canvas_size ** 2`` values.
    canvas_size : int, default 28
        Side length of the square canvas.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (1, canvas_size, canvas_size).
    """
    # otypes=[object] stops np.vectorize from inferring the output dtype from the
    # first result alone: a first '1' would otherwise fix a 1-character string
    # dtype and silently truncate longer values such as '0.5'.
    bitmap = np.vectorize(boolstr_to_floatstr, otypes=[object])(sample).astype(float)
    # Build the tensor directly from the ndarray (converting a list of ndarray
    # rows is slow); .float() keeps the float32 dtype torch.Tensor produced.
    return torch.from_numpy(bitmap).float().view(1, canvas_size, canvas_size)


# Pre-process the first (aligned) and second (distorted) entry of every sample
# into bitmap tensors, then stack each set into one batch tensor.
beautified = [preprocess(aligned) for aligned, *_ in training_dataset]
distorted = [preprocess(corrupted) for _, corrupted, *_ in training_dataset]
beautified_img = torch.stack(beautified)
distorted_img = torch.stack(distorted)

In [None]:
# Hyperparameters
epoch = 200            # number of passes over the training set
batch_size = 100
learning_rate = 0.001

# Input pipeline: the autoencoder currently trains on the aligned images only
# (identity reconstruction); `distorted_img` is not fed in yet.
dataset = beautified_img
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
def min_max_normalization(tensor, min_value, max_value):
    """Linearly rescale `tensor` so its values span [min_value, max_value].

    Parameters
    ----------
    tensor : torch.Tensor
        Non-empty input tensor.
    min_value, max_value : number
        Target range: the smallest input maps to `min_value`, the largest to
        `max_value`.

    Returns
    -------
    torch.Tensor
        The rescaled tensor. A constant input has no range to stretch and is
        mapped entirely to `min_value` instead of producing NaNs (0 / 0).
    """
    shifted = tensor - tensor.min()
    value_range = shifted.max()
    if value_range == 0:
        return shifted + min_value
    return shifted / value_range * (max_value - min_value) + min_value
    
def tensor_round(tensor):
    """Return a copy of `tensor` with each element rounded to the nearest
    integer value (``torch.round`` semantics, ties to even)."""
    rounded = torch.round(tensor)
    return rounded

# Image -> tensor pipeline: ToTensor, rescale into [0, 1] with
# min_max_normalization, then round every pixel.
# NOTE(review): not referenced by the other cells of this notebook.
img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda t: min_max_normalization(t, 0, 1)),
        transforms.Lambda(lambda t: tensor_round(t)),
    ]
)

In [None]:
# Encoder
# torch.nn.Conv2d(in_channels, out_channels, kernel_size,
#                 stride = 1, padding = 0, dilation = 1,
#                 groups = 1, bias = True)
# batch x 1 x 28 x 28 -> batch x (256 * 7 * 7)

class Encoder(nn.Module):
    """Convolutional encoder mapping (N, 1, 28, 28) images to (N, 256 * 7 * 7) codes.

    Two 2x2 max-pools reduce 28x28 to 7x7 and the last convolution has 256
    channels; the resulting feature map is flattened per sample.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),     # batch x 32 x 28 x 28
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, padding=1),    # batch x 32 x 28 x 28
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, 3, padding=1),    # batch x 64 x 28 x 28
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 64, 3, padding=1),    # batch x 64 x 28 x 28
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, 2),                 # batch x 64 x 14 x 14
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),   # batch x 128 x 14 x 14
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3, padding=1),  # batch x 128 x 14 x 14
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, 2),                 # batch x 128 x 7 x 7
            nn.Conv2d(128, 256, 3, padding=1),  # batch x 256 x 7 x 7
            nn.ReLU(),
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        # Flatten with the actual batch size of `x` rather than the global
        # `batch_size`, so a smaller last batch or a single image also works.
        return out.view(out.size(0), -1)

encoder = Encoder()

# Decoder
# torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
#                          stride=1, padding=0, output_padding=0,
#                          groups=1, bias=True)
# ConvTranspose2d: output_size = (input_size - 1)*stride + kernel_size - 2*padding + output_padding
# Conv2d:          output_size = (input_size + 2*padding - kernel_size)/stride + 1
# batch x (256 * 7 * 7) -> batch x 1 x 28 x 28

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping (N, 256 * 7 * 7) codes to (N, 1, 28, 28) images.

    Two stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28; the
    final ReLU keeps the output non-negative.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),  # batch x 128 x 14 x 14
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 128, 3, 1, 1),     # batch x 128 x 14 x 14
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, 3, 1, 1),      # batch x 64 x 14 x 14
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 64, 3, 1, 1),       # batch x 64 x 14 x 14
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, 1, 1),       # batch x 32 x 14 x 14
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 32, 3, 1, 1),       # batch x 32 x 14 x 14
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 1, 3, 2, 1, 1),     # batch x 1 x 28 x 28
            nn.ReLU(),
        )

    def forward(self, x):
        # Restore the 256 x 7 x 7 feature map using the batch size of `x`
        # itself rather than the global `batch_size`.
        out = x.view(x.size(0), 256, 7, 7)
        out = self.layer1(out)
        return self.layer2(out)

decoder = Decoder()

In [None]:
# Sanity-check the autoencoder's output shape on one training batch.
# Tensors are passed directly (torch.autograd.Variable is deprecated), and
# no_grad avoids building an autograd graph for this inspection-only pass.
image = next(iter(train_loader))
with torch.no_grad():
    output = decoder(encoder(image))
print(output.size())

In [None]:
# Loss and optimizer.
# The decoder output is compared pixel-wise with the target image, so the
# reconstruction objective is mean squared error. One Adam optimizer updates
# the encoder and the decoder together, so both parameter sets are merged
# into a single list.
parameters = [*encoder.parameters(), *decoder.parameters()]
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(parameters, lr=learning_rate)

In [None]:
#train encoder, decoder
# save and load model

#try:
#    encoder, decoder = torch.load('./model/denoising_autoencoder.pkl')
#    print("\n--------model restored--------\n")
#except:
#    print("\n--------model not restored--------\n")
#    pass

# Train encoder + decoder to reconstruct their input (identity mapping on the
# aligned images; the distorted images are not used yet).
encoder.train()   # BatchNorm must use batch statistics while training
decoder.train()
epoch_losses = []  # mean reconstruction loss of every epoch, for inspection
for _ in range(epoch):
    running_loss = 0.0
    for image in train_loader:
        optimizer.zero_grad()
        output = decoder(encoder(image))
        loss = loss_func(output, image)
        loss.backward()
        optimizer.step()
        # MSELoss returns a mean; weight it by the batch size so the epoch
        # average stays exact when the last batch is smaller.
        running_loss += loss.item() * image.size(0)
    epoch_losses.append(running_loss / len(train_loader.dataset))

In [None]:
# Visual check: plot each selected input followed by the model's reconstruction.
# The `output` left over from training is the last *shuffled* batch, so
# `output[i]` does not correspond to `beautified_img[i]`. Instead, reconstruct
# the whole set in `batch_size` chunks so `reconstructed[i]` matches sample i.
encoder.eval()     # use BatchNorm running statistics for inference
decoder.eval()
with torch.no_grad():
    reconstructed = torch.cat(
        [decoder(encoder(chunk)) for chunk in beautified_img.split(batch_size)]
    )
encoder.train()    # restore training mode for any later training run
decoder.train()

for i in range(85, 99):
    input_img = beautified_img[i]   # identity task: the input is also the target
    output_img = reconstructed[i]

    inp = input_img.numpy()
    out = output_img.numpy()

    plt.imshow(inp[0], cmap='gray')
    plt.show()

    plt.imshow(out[0], cmap='gray')
    plt.show()

In [None]:
# Collect inputs and reconstructions of samples 85..98 for `show_images`.
# The training loop's `output` is a shuffled batch that does not line up with
# `beautified_img`, so reconstructions are recomputed (in `batch_size` chunks)
# in the same order as `beautified_img`.
encoder.eval()     # use BatchNorm running statistics for inference
decoder.eval()
with torch.no_grad():
    reconstructed = torch.cat(
        [decoder(encoder(chunk)) for chunk in beautified_img.split(batch_size)]
    )
encoder.train()    # restore training mode for any later training run
decoder.train()

sample_indices = range(85, 99)
inp = [beautified_img[i].numpy() for i in sample_indices]  # identity task: input == target
out = [reconstructed[i].numpy() for i in sample_indices]
print(np.shape(inp[10][0]))

In [None]:
def show_images(images, cols = 1, titles = None, inches_per_image = 2.0):
    """Display a list of images in a grid.

    Parameters
    ---------
    images: List of np.arrays compatible with plt.imshow. A leading singleton
            channel axis (shape 1 x H x W) is dropped; 2-D images are drawn
            with a gray colormap.

    cols: Number of columns in figure (number of rows is
          set to np.ceil(n_images/float(cols))).

    titles: List of titles corresponding to each image. Must have
            the same length as images.

    inches_per_image: Width and height, in inches, of each grid cell.
    """
    assert (titles is None) or (len(images) == len(titles))
    n_images = len(images)
    if n_images == 0:
        return
    if titles is None:
        titles = ['Image (%d)' % i for i in range(1, n_images + 1)]
    n_cols = int(cols)
    n_rows = int(np.ceil(n_images / float(n_cols)))
    fig = plt.figure(figsize=(inches_per_image * n_cols, inches_per_image * n_rows))
    for n, (image, title) in enumerate(zip(images, titles)):
        # add_subplot takes (nrows, ncols, index) and requires integers.
        ax = fig.add_subplot(n_rows, n_cols, n + 1)
        image = np.asarray(image)
        if image.ndim == 3 and image.shape[0] == 1:
            image = image[0]   # (1, H, W) -> (H, W)
        ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
        ax.set_title(title)
    plt.show()

# One row per call: inputs first, then their reconstructions.
show_images(inp, cols = len(inp), titles = None)
show_images(out, cols = len(out), titles = None)