# Middle Frame Prediction

## 0) Preparaion

In [2]:
import os

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter

if not os.path.exists('./mlp_img'):
    os.mkdir('./mlp_img')

## 1) Autoencoder Architecture

In [15]:
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True), nn.Linear(64, 12), nn.ReLU(True), nn.Linear(12, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True), nn.Linear(128, 28 * 28), nn.Sigmoid())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

## 1) VAE Architecture

In [25]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 12)
        self.fc41 = nn.Linear(12, 3)
        self.fc42 = nn.Linear(12, 3)
        self.fc5 = nn.Linear(3, 12)
        self.fc6 = nn.Linear(12, 64)
        self.fc7 = nn.Linear(64, 128)
        self.fc8 = nn.Linear(128, 28 * 28)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        h2 = self.relu(self.fc2(h1))
        h3 = self.relu(self.fc3(h2))
        return self.fc41(h3), self.fc42(h3)
    
    def reparameterize(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = Variable(std.data.new(std.size()).normal_())
            return eps.mul(std).add_(mu)
        else:
            return mu
        
    def decode(self, z):
        h4 = self.relu(self.fc5(z))
        h5 = self.relu(self.fc6(h4))
        h6 = self.relu(self.fc7(h5))
        return self.sigmoid(self.fc8(h6))
    
    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

## 2) Loading Dataset

In [16]:
img_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#     transforms.Grayscale(num_output_channels=1),
])


image_folder = './data'

# dataset = ImageFolder('./data', trans..=...)
dataset = MNIST(image_folder, download=True, transform=img_transform)


batch_size = 128
data_iter = DataLoader(dataset, batch_size=batch_size, shuffle=True)

## 3) Network Training

In [17]:
from sklearn.model_selection import train_test_split
X_train, X_test = train_test_split(dataset, test_size=0.33, random_state=1)

In [18]:
data_iter_train = DataLoader(X_train, batch_size = batch_size, shuffle = True)
data_iter_test = DataLoader(X_test, batch_size = batch_size, shuffle = True)

In [32]:
model = autoencoder()#.cuda()
model2 = VAE()

loss_func = nn.MSELoss()

learning_rate = 1e-3
#optimizer = torch.optim.Adam(
#    model.parameters(), lr=learning_rate, weight_decay=1e-5)
optimizer = torch.optim.Adam(
    model2.parameters(), lr=learning_rate, weight_decay=1e-5)

In [33]:
writer = SummaryWriter()

testdataset = [dataset[1][0].reshape(1,28,28),
               dataset[3][0].reshape(1,28,28),
               dataset[5][0].reshape(1,28,28),
               dataset[7][0].reshape(1,28,28),
               dataset[2][0].reshape(1,28,28),
               dataset[0][0].reshape(1,28,28),
               dataset[13][0].reshape(1,28,28),
               dataset[38][0].reshape(1,28,28),
               dataset[17][0].reshape(1,28,28),
               dataset[4][0].reshape(1,28,28)]

#(dataset[1]) #0
#(dataset[3]) #1
#(dataset[5]) #2
#(dataset[7]) #3
#(dataset[2]) #4
#(dataset[0]) #5
#(dataset[13]) #6
#(dataset[38]) #7
#(dataset[17]) #8
#(dataset[4]) #9

temp = torchvision.utils.make_grid(testdataset, nrow=5)
temp = temp[0]
#print(temp.shape)
writer.add_image("Test Dataset", temp.reshape(1,62,152))

def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x


num_epochs = 200

for epoch in range(num_epochs):
    
    total_tr_loss = 0
    n_tr = 0
    total_te_loss = 0
    n_te = 0
    
    for data in data_iter_train:
        img, _ = data
        
        input = img.view(img.size(0), -1)#.to(device)#.cuda()
        #output = model(input)
        output = model2(input)[0]
        
        loss = loss_func(output, input)
        total_tr_loss += loss.item()
        n_tr += 1
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
            
    with torch.no_grad():
        for data in data_iter_test:
            img_t, _ = data
        
            input = img_t.view(img_t.size(0), -1)#.to(device)#.cuda()
            #output = model(input)
            output = model2(input)[0]
        
            loss = loss_func(output, input)
            total_te_loss += loss.item()
            n_te += 1
    
    mean_te_loss = total_te_loss/n_te
    mean_tr_loss = total_tr_loss/n_tr
    
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, num_epochs, mean_te_loss))
    
    if epoch % 10 == 0:
        pic = to_img(output.cpu().data)
        save_image(pic, './mlp_img/image_{}.png'.format(epoch))
    
    temp = torchvision.utils.make_grid([input[0].reshape(1,28,28),
                    input[1].reshape(1,28,28),
                    input[2].reshape(1,28,28),
                    input[3].reshape(1,28,28),
                    input[4].reshape(1,28,28),
                    input[5].reshape(1,28,28),
                    input[7].reshape(1,28,28),
                    input[8].reshape(1,28,28)])
    
    temp = temp[0]
    
    writer.add_image("Training data", temp.reshape(1,32,242), epoch)
    
    temp = torchvision.utils.make_grid([output[0].reshape(1,28,28),
                    output[1].reshape(1,28,28),
                    output[2].reshape(1,28,28),
                    output[3].reshape(1,28,28),
                    output[4].reshape(1,28,28),
                    output[5].reshape(1,28,28),
                    output[7].reshape(1,28,28),
                    output[8].reshape(1,28,28)])
    
    temp = temp[0]
    test_output = []
    
    for data in testdataset:
        img = data
        
        input = img.view(img.size(0), -1)#.to(device)#.cuda()
        #test_output.append(model(input)[0])
        test_output.append(model2(input)[0])
        
    temp2 = torchvision.utils.make_grid([test_output[0].reshape(1,28,28),
                    test_output[1].reshape(1,28,28),
                    test_output[2].reshape(1,28,28),
                    test_output[3].reshape(1,28,28),
                    test_output[4].reshape(1,28,28),
                    test_output[5].reshape(1,28,28),
                    test_output[6].reshape(1,28,28),
                    test_output[7].reshape(1,28,28),
                    test_output[8].reshape(1,28,28),
                    test_output[9].reshape(1,28,28)], nrow=5)

    writer.add_image("Model Test Dataset", temp2[0].reshape(1,62,152), epoch)
    writer.add_image("Output", temp.reshape(1,32,242), epoch)
    writer.add_scalar("Loss/Train", mean_tr_loss, epoch)
    writer.add_scalar("Loss/Test", mean_te_loss, epoch)
    writer.add_scalars("Loss", {"Train": mean_tr_loss, "Test": mean_te_loss}, epoch)

epoch [1/200], loss:0.0549
epoch [2/200], loss:0.0494
epoch [3/200], loss:0.0462
epoch [4/200], loss:0.0432
epoch [5/200], loss:0.0416
epoch [6/200], loss:0.0407
epoch [7/200], loss:0.0394
epoch [8/200], loss:0.0387
epoch [9/200], loss:0.0381
epoch [10/200], loss:0.0376
epoch [11/200], loss:0.0373
epoch [12/200], loss:0.0370
epoch [13/200], loss:0.0366
epoch [14/200], loss:0.0363
epoch [15/200], loss:0.0360
epoch [16/200], loss:0.0360
epoch [17/200], loss:0.0359
epoch [18/200], loss:0.0356
epoch [19/200], loss:0.0355
epoch [20/200], loss:0.0355
epoch [21/200], loss:0.0354
epoch [22/200], loss:0.0351
epoch [23/200], loss:0.0351
epoch [24/200], loss:0.0348
epoch [25/200], loss:0.0349
epoch [26/200], loss:0.0346
epoch [27/200], loss:0.0347
epoch [28/200], loss:0.0344
epoch [29/200], loss:0.0344
epoch [30/200], loss:0.0344
epoch [31/200], loss:0.0341
epoch [32/200], loss:0.0344
epoch [33/200], loss:0.0340
epoch [34/200], loss:0.0342
epoch [35/200], loss:0.0341
epoch [36/200], loss:0.0338
e

KeyboardInterrupt: 

## 4) Loading And Saving

##### Save

In [5]:
torch.save(model.state_dict(), './sim_autoencoder.pth')

##### Load

In [6]:
# torch.save(model, './sim_autoencoder.pth')
model.load_state_dict(torch.load('./sim_autoencoder.pth', map_location='cpu'))

<All keys matched successfully>

## Applying to Images

In [14]:
import numpy as np
import matplotlib.pyplot as plt

In [17]:
def forshow(img):
    return transforms.ToPILImage(mode='L')(img.reshape(28,28))

def encode(img_from_dataset):
    norm_img = img_from_dataset.reshape(28*28)
    return model.encoder(norm_img)

def decode(encoding):
    dec_img = model.decoder(encoding)
    return dec_img

def midFrame(img1, img2):
    return decode((encode(img1) + encode(img2))/2)

### Tests

In [19]:
im1 = dataset[0][0]
im2 = dataset[6][0]

forshow(im1).show()
forshow(decode(encode(im1))).show()



In [21]:
mid1 = midFrame(im1, im2)

mid2 = midFrame(mid1,im2)

mid3 = midFrame(im1, mid1)
for i in [im1, mid3, mid1, mid2, im2]:
    forshow(i).show()

In [48]:
img, _ = dataset[5]

my_img = model(img.reshape(28*28))
my_img = my_img.reshape(1,28,28)

results = transforms.ToPILImage(mode='L')(my_img)
results.show()

#output[0]
#image = output.cpu().numpy()[0]
#image = np.transpose(image, (1,2,0))
#plt.matshow(image)
#plt.show()

In [49]:
img, _ = dataset[5]
results = transforms.ToPILImage(mode='L')(img.reshape(1,28,28)) 
results.show()

img2 = model(img.reshape(28*28))
results2 = transforms.ToPILImage(mode='L')(img2.reshape(1,28,28)) 
results2.show()

In [63]:
#show middle
img, _ = dataset[0]
img2, _ = dataset[1]
enc1 = model.encoder(img.reshape(28*28))
enc2 = model.encoder(img2.reshape(28*28))

In [68]:
toPILimg(img).show()
toPILimg(img2).show()
result = (enc1+enc2)/2
decode(result).show()

In [20]:
img5_2, _ = dataset[11]
img5_1, _ = dataset[0]
toPILimg(img5_1).show()
toPILimg(img5_2).show()
midFrame(img5_1, img5_2).show()

NameError: name 'toPILimg' is not defined