<a href="https://colab.research.google.com/github/VyatkinAlexey/Noiseless/blob/master/Autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.utils.data as torch_data
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
%matplotlib inline

In [0]:
train_data = pd.read_csv('sample_data/mnist_train_small.csv', header=None)
test_data = pd.read_csv('sample_data/mnist_test.csv', header=None)

X_train = train_data[train_data[0].isin([0,1])].drop(0, axis=1)
X_test = test_data.drop(0, axis=1)
y_train = train_data[train_data[0].isin([0,1])][0]
y_test = test_data[0]

In [106]:
class mnist(torch_data.Dataset):
    def __init__(self, X, y):
        super(mnist, self).__init__()
        self.X = torch.FloatTensor(np.array(X/255))
        self.y = torch.FloatTensor(np.array(y))


    def __len__(self):
        return self.X.shape[0]


    def __getitem__(self, idx):
        return ([self.X[idx], self.y[idx]])

train_dset = mnist(X_train, y_train) 
test_dset = mnist(X_test, y_test) 

print(len(train_dset[5][0]))

784


In [0]:
class denoisingAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(denoisingAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, y=None):
        """
        Take a pair of (data,labels) (clean, noise, semi-noisy)sample:
        shape ((batch_size, x_dim), (batch_size, 1)
        """
        if self.training:
            x = x.view(1,1,28,28)
            x = self.encoder(x)
            if y==0:
                x[0][100:128] = torch.zeros(28)
            elif y==1:
                x[0][:100] = torch.zeros(100)
            x = self.decoder(x)
            x = x.view(784)
        return x

In [0]:
class Unflatten(nn.Module):
    def forward(self, inp):
        return inp.reshape(inp.shape[0],inp.shape[1],1,1)



encoder = nn.Sequential(
    nn.Conv2d(1, 8, 5), # bs, 8, 24, 24
    nn.Sigmoid(),
    nn.MaxPool2d(2,2), # bs, 8, 12, 12
    nn.Conv2d(8, 16, 5), # bs, 16, 8, 8
    nn.Sigmoid(),
    nn.MaxPool2d(2), # bs, 16, 4, 4
    nn.Flatten(),
    nn.Linear(16*4*4, 128),
) 

decoder = nn.Sequential(
    nn.Linear(128, 16*4*4),
    Unflatten(),
    nn.ConvTranspose2d(16*4*4, 64, 1),
    nn.Sigmoid(),
    nn.ConvTranspose2d(64, 32, 5, stride=2),
    nn.Sigmoid(),
    nn.ConvTranspose2d(32, 16, 5, stride=2),
    nn.Sigmoid(),
    nn.ConvTranspose2d(16, 1, 4, stride=2)
)

In [0]:
device = torch.device('cuda')
net = denoisingAE(encoder, decoder)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

net.to(device);

In [0]:
def train_epoch(net, data_train, criterion, optimizer):
    net.train()
    loss = 0
    for sample in data_train:
        x, y = sample
        x = x.to(device)
        y = y.to(device)
        x_hat = net.forward(x, y).to(device)
        sample_loss = criterion(x, x_hat)
        loss += sample_loss
        sample_loss.backward()
        optimizer.step()

    return loss/len(data_train)

In [0]:
for i in range(100):
    print(f'epoch: {i}, loss: {train_epoch(net, train_dset, criterion, optimizer)}')

epoch: 0, loss: 1.8757492303848267
epoch: 1, loss: 3.748345136642456
epoch: 2, loss: 5.048675537109375
epoch: 3, loss: 5.489572525024414
epoch: 4, loss: 5.916383266448975
epoch: 5, loss: 6.498894691467285
epoch: 6, loss: 6.940377235412598
epoch: 7, loss: 6.840693473815918
epoch: 8, loss: 6.925296306610107
epoch: 9, loss: 7.064377784729004
epoch: 10, loss: 7.172752380371094
epoch: 11, loss: 7.259344577789307
epoch: 12, loss: 7.31589412689209
epoch: 13, loss: 7.361939907073975
epoch: 14, loss: 7.387976169586182
epoch: 15, loss: 7.40094518661499
epoch: 16, loss: 7.400724411010742
epoch: 17, loss: 7.381535053253174
epoch: 18, loss: 7.365080833435059
epoch: 19, loss: 7.328773498535156
epoch: 20, loss: 7.316483497619629
epoch: 21, loss: 7.289459705352783


In [0]:
example = train_dset[0]
ex_x, ex_y = example
ex_x = ex_x.view(1,1,28,28).to(device)
enc_x = encoder(ex_x)

In [0]:
enc_x[0][:100] = torch.zeros(100)