# Middle Frame Prediction

## 0) Preparaion

In [None]:
import os

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

if not os.path.exists('./mlp_img'):
    os.mkdir('./mlp_img')

## 1) Autoencoder Architecture

In [2]:
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True), nn.Linear(64, 12), nn.ReLU(True), nn.Linear(12, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True), nn.Linear(128, 28 * 28), nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

## 2) Loading Dataset

In [13]:
img_transform = transforms.Compose([
    transforms.ToTensor(),
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#     transforms.Grayscale(num_output_channels=1),
])


image_folder = './data'

# dataset = ImageFolder('./data', trans..=...)
dataset = MNIST(image_folder, download=True, transform=img_transform)


batch_size = 128
data_iter = DataLoader(dataset, batch_size=batch_size, shuffle=True)

## 3) Network Training

In [18]:
from sklearn.model_selection import train_test_split
X_train, X_test = train_test_split(dataset, test_size=0.33, random_state=1)

In [22]:
data_iter_train = DataLoader(X_train, batch_size = batch_size, shuffle = True)
data_iter_test = DataLoader(X_test, batch_size = batch_size, shuffle = True)

In [4]:
model = autoencoder()#.cuda()


loss_func = nn.MSELoss()

learning_rate = 1e-3
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)

In [12]:
def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x


num_epochs = 100

for epoch in range(num_epochs):
    total_tr_loss = 0 
    for data in data_iter_train:
        img, _ = data
        
        input = img.view(img.size(0), -1)#.to(device)#.cuda()
        output = model(img)
        
        loss = loss_func(output, input)
        
        total_tr_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    total_te_loss = 0
    with torch.no_grad():
        for data in data_iter_test:
            img_t, _ = data
        
            input = img.view(img_t.size(0), -1)#.to(device)#.cuda()
            output = model(img_t)
        
            loss = loss_func(output, input)
            total_te_loss += loss.item()
        
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, num_epochs, loss.item()))
    
    if epoch % 10 == 0:
        pic = to_img(output.cpu().data)
        save_image(pic, './mlp_img/image_{}.png'.format(epoch))

RuntimeError: size mismatch, m1: [3584 x 28], m2: [784 x 128] at C:\w\1\s\tmp_conda_3.7_055457\conda\conda-bld\pytorch_1565416617654\work\aten\src\TH/generic/THTensorMath.cpp:752

## 4) Loading And Saving

##### Save

In [5]:
torch.save(model.state_dict(), './sim_autoencoder.pth')

##### Load

In [6]:
# torch.save(model, './sim_autoencoder.pth')
model.load_state_dict(torch.load('./sim_autoencoder.pth', map_location='cpu'))

<All keys matched successfully>

## Applying to Images

In [7]:
import numpy as np
import matplotlib.pyplot as plt

In [8]:
def forshow(img):
    return transforms.ToPILImage(mode='L')(img.reshape(28,28))

def encode(img_from_dataset):
    norm_img = img_from_dataset.reshape(28*28)
    return model.encoder(norm_img)

def decode(encoding):
    dec_img = model.decoder(encoding)
    return dec_img

def midFrame(img1, img2):
    return decode((encode(img1) + encode(img2))/2)

### Tests

In [9]:
im1 = dataset[0][0]
im2 = dataset[6][0]

forshow(im1).show()
forshow(decode(encode(im1))).show()



In [10]:
mid1 = midFrame(im1, im2)

mid2 = midFrame(mid1,im2)

mid3 = midFrame(im1, mid1)
for i in [im1, mid3, mid1, mid2, im2]:
    forshow(i).show()

In [11]:
img, _ = dataset[0]

my_img = model(img.reshape(28*28))
my_img = my_img.reshape(1,28,28)

results = transforms.ToPILImage(mode='L')(my_img)
results.show()

#output[0]
#image = output.cpu().numpy()[0]
#image = np.transpose(image, (1,2,0))
#plt.matshow(image)
#plt.show()

In [7]:
img, _ = dataset[0]
results = transforms.ToPILImage(mode='L')(img.reshape(1,28,28)) 
results.show()

img2 = model(img.reshape(28*28))
results2 = transforms.ToPILImage(mode='L')(img2.reshape(1,28,28)) 
results2.show()



In [63]:
#show middle
img, _ = dataset[0]
img2, _ = dataset[1]
enc1 = model.encoder(img.reshape(28*28))
enc2 = model.encoder(img2.reshape(28*28))

In [68]:
toPILimg(img).show()
toPILimg(img2).show()
result = (enc1+enc2)/2
decode(result).show()

In [7]:
img5_2, _ = dataset[11]
img5_1, _ = dataset[0]
toPILimg(img5_1).show()
toPILimg(img5_2).show()
midFrame(img5_1, img5_2).show()