# Middle Frame Prediction

## 0) Preparation

In [1]:
import os

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from sklearn.model_selection import train_test_split

from torch.utils.tensorboard import SummaryWriter

#from torch.utils.tensorboard import SummaryWriter

# Create the image output directory. exist_ok=True avoids the race between
# checking for the directory and creating it, and raises if the path exists
# but is not a directory instead of silently continuing.
os.makedirs('./mlp_img', exist_ok=True)

## 1) Autoencoder Architecture

In [2]:
#Linear model
class autoencoder_lin(nn.Module):
    """Fully connected autoencoder with a 3-dimensional bottleneck.

    The encoder first views its input as rows of 28*28 values and then
    narrows 784 -> 128 -> 64 -> 12 -> 3 through ReLU-separated linear layers.
    The decoder mirrors that path back to 784 values, applies Tanh, and views
    the result as (-1, 1, 28, 28).

    ``View`` is defined further down in this file, so it has to exist by the
    time this class is instantiated.
    """

    def __init__(self):
        super().__init__()

        encoder_layers = [
            View((-1, 28 * 28)),
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12),
            nn.ReLU(True),
            nn.Linear(12, 3),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 28 * 28),
            nn.Tanh(),
            View((-1, 1, 28, 28)),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck and decode it back."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [3]:
#Convolutional model
class autoencoder_conv(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images.

    The shape comments on each layer give the activation size for a
    (b, 1, 28, 28) input; the encoder output is (b, 8, 2, 2) and the decoder
    ends in Tanh with a (b, 1, 28, 28) output.
    """

    def __init__(self):
        super().__init__()

        downsampling = [
            nn.Conv2d(1, 16, kernel_size=3, stride=3, padding=1),  # b, 16, 10, 10
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # b, 16, 5, 5
            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),  # b, 8, 3, 3
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),  # b, 8, 2, 2
        ]
        self.encoder = nn.Sequential(*downsampling)

        upsampling = [
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2),  # b, 16, 5, 5
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, kernel_size=5, stride=3, padding=1),  # b, 8, 15, 15
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2, padding=1),  # b, 1, 28, 28
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*upsampling)

    def forward(self, x):
        """Encode ``x`` and decode it back to image shape."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [4]:
class View(nn.Module):
    """Module wrapper around ``Tensor.view`` for use inside ``nn.Sequential``.

    Args:
        shape: Target shape, unpacked into ``x.view`` on every forward call.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed with the configured shape."""
        target_shape = self.shape
        return x.view(*target_shape)

## 2) Loading Dataset

In [5]:
# ToTensor is the only transform applied to each image.
img_transform = transforms.Compose([transforms.ToTensor()])

image_folder = './data'
batch_size = 128

# dataset = ImageFolder('./data', trans..=...)
# MNIST is downloaded into image_folder if it is not already there.
dataset = MNIST(root=image_folder, download=True, transform=img_transform)

# Shuffled loader over the full, unsplit dataset.
data_iter = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

### Splitting test and train

In [6]:

# Hold out 33% of the samples for testing.
# random_split returns lazy Subset views over `dataset`, so samples are only
# loaded and transformed when a DataLoader asks for them, instead of being
# materialised up front into Python lists. The generator is seeded so the
# split is reproducible.
n_test = int(round(0.33 * len(dataset)))
X_train, X_test = torch.utils.data.random_split(
    dataset,
    [len(dataset) - n_test, n_test],
    generator=torch.Generator().manual_seed(1),
)

In [7]:
# One shuffled loader per split, both using the shared batch_size.
data_iter_train = DataLoader(X_train, batch_size = batch_size, shuffle = True)
data_iter_test = DataLoader(X_test, batch_size = batch_size, shuffle = True)

# Packed as (train, test); train() unpacks the tuple in this order.
train_and_test = (data_iter_train, data_iter_test)

## 3) Network Training

In [8]:
# One instance of each architecture. Both stay on the CPU: the .cuda() call
# on the convolutional model is commented out.
model_conv = autoencoder_conv()#.cuda()
model_lin = autoencoder_lin()

#### For visualizing

In [9]:
# TensorBoard writer shared by the visualisation helpers.
writer = SummaryWriter()

# Fixed probe images, picked by dataset index and reshaped to (1, 28, 28);
# their reconstructions are logged after every epoch.
testdataset = [
    dataset[idx][0].reshape(1, 28, 28)
    for idx in (1, 3, 5, 7, 2, 0, 13, 38, 17, 4)
]

In [10]:
def create_grid_from(data, size, n_vis):
    """Build an image grid from the first ``n_vis`` entries of ``data``.

    Args:
        data: Indexable collection of image tensors (a batch tensor or a
            list of tensors).
        size: Shape each entry is reshaped to before being placed in the
            grid, e.g. ``(1, 28, 28)``.
        n_vis: Maximum number of entries to include. If ``data`` holds fewer
            entries, all of them are used instead of raising IndexError.

    Returns:
        The tensor produced by ``torchvision.utils.make_grid``.
    """
    count = min(n_vis, len(data))
    grid = [data[j].reshape(*size) for j in range(count)]
    return torchvision.utils.make_grid(grid)

In [11]:
def visualize(input, output, testdataset, losses, epoch):
    mean_te_loss, mean_tr_loss = losses
    temp = create_grid_from(input, (1,28,28), 8)
    
    writer.add_image("Training data", temp, epoch)
    
    temp = create_grid_from(output, (1,28,28), 8)

    test_output = []
    for img in testdataset:
        im = img.view(1,1,28,28)
        test_output.append(model_conv(im))
        
    temp2 = create_grid_from(test_output, (1,28,28), 10)
    
    writer.add_image("Model Test Dataset", temp2, epoch)
    writer.add_image("Output", temp, epoch)
    writer.add_scalar("Loss/Train", mean_tr_loss, epoch)
    writer.add_scalar("Loss/Test", mean_te_loss, epoch)

#### Actual training

In [12]:
def train(model, train_and_test, num_epochs = 100, learning_rate = 1e-3):
    
    data_iter_train, data_iter_test = train_and_test

    
    loss_func = nn.MSELoss()

    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=1e-5)
        

    for epoch in range(num_epochs):
        
        n_tr = 0
        total_tr_loss = 0
        for batch in data_iter_train:
            imgs, _ = batch

            input = imgs
            output = model(input)
                    
            loss = loss_func(output, input)
        
            total_tr_loss += loss.item()
            n_tr += 1
        
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        n_te = 0
        total_te_loss = 0
        with torch.no_grad():
            for batch in data_iter_test:
                imgs, _ = batch
                
                input = imgs
                output = model(imgs)
        
                loss = loss_func(output, input)
                total_te_loss += loss.item()
                n_te += 1
        
        mean_te_loss = total_te_loss/n_te
        mean_tr_loss = total_tr_loss/n_tr
        print('epoch [{}/{}], loss:{:.4f}'
              .format(epoch + 1, num_epochs, total_te_loss))
        
        losses = (mean_te_loss, mean_tr_loss)
        
        visualize(input, output, testdataset, losses, epoch)

In [13]:
# Train the convolutional autoencoder for a single epoch.
train(model_conv, train_and_test, num_epochs = 1)

epoch [1/1], loss:10.2357


In [14]:
# Train the linear autoencoder for a single epoch.
train(model_lin, train_and_test, num_epochs = 1)

epoch [1/1], loss:8.3469


## 4) Loading And Saving

##### Save

In [16]:
# Persist only the convolutional model's parameters (its state_dict), not the
# module object itself.
torch.save(model_conv.state_dict(), './sim_autoencoder.pth')

##### Load

In [None]:
# torch.save(model, './sim_autoencoder.pth')
# Load the saved parameters onto the CPU.
# NOTE(review): no `model` variable is defined in the cells above (only
# model_conv and model_lin), so this raises NameError as written. The file at
# this path is written from model_conv.state_dict(), while the cells below
# feed `model` flat 28*28 vectors -- confirm which architecture is intended
# and instantiate it before loading.
model.load_state_dict(torch.load('./sim_autoencoder.pth', map_location='cpu'))

## Applying to Images (old)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def forshow(img):
    """Return ``img`` reshaped to 28x28 as a mode-'L' PIL image."""
    to_pil = transforms.ToPILImage(mode='L')
    plane = img.reshape(28, 28)
    return to_pil(plane)

def encode(img_from_dataset):
    """Flatten the image to 28*28 values and run the global model's encoder."""
    return model.encoder(img_from_dataset.reshape(28*28))

def decode(encoding):
    """Run the global model's decoder on a latent ``encoding``."""
    return model.decoder(encoding)

def midFrame(img1, img2):
    """Decode the average of the latent codes of ``img1`` and ``img2``."""
    code_first = encode(img1)
    code_second = encode(img2)
    midpoint = (code_first + code_second) / 2
    return decode(midpoint)

### Tests

In [None]:
# Two dataset images (index 0 and index 6); labels are discarded.
im1 = dataset[0][0]
im2 = dataset[6][0]

# Show the first image next to its encode/decode round trip.
forshow(im1).show()
forshow(decode(encode(im1))).show()



In [None]:
# mid1 lies between im1 and im2 in latent space; mid2 and mid3 are the
# midpoints between mid1 and each endpoint.
mid1 = midFrame(im1, im2)
mid2 = midFrame(mid1, im2)
mid3 = midFrame(im1, mid1)

# Display the sequence in order from im1 to im2.
for i in (im1, mid3, mid1, mid2, im2):
    forshow(i).show()

In [None]:
img, _ = dataset[0]

# Run the flattened image through the model and restore (1, 28, 28).
my_img = model(img.reshape(28*28)).reshape(1,28,28)

results = transforms.ToPILImage(mode='L')(my_img)
results.show()

#output[0]
#image = output.cpu().numpy()[0]
#image = np.transpose(image, (1,2,0))
#plt.matshow(image)
#plt.show()

In [None]:
# Show the original image ...
img, _ = dataset[0]
results = transforms.ToPILImage(mode='L')(img.reshape(1,28,28)) 
results.show()

# ... followed by the model's reconstruction of it.
img2 = model(img.reshape(28*28))
results2 = transforms.ToPILImage(mode='L')(img2.reshape(1,28,28)) 
results2.show()



In [None]:
#show middle
img, _ = dataset[0]
img2, _ = dataset[1]
enc1 = model.encoder(img.reshape(28*28))
enc2 = model.encoder(img2.reshape(28*28))

In [None]:
# Show both source images, then the decoding of their averaged latent codes.
# forshow() is the PIL conversion helper defined in this notebook; the
# decoder returns a tensor, so it is converted before .show() is called.
forshow(img).show()
forshow(img2).show()
result = (enc1+enc2)/2
forshow(decode(result)).show()

In [None]:
# Two dataset images (index 0 and index 11) and their latent-space midpoint.
img5_2, _ = dataset[11]
img5_1, _ = dataset[0]

# forshow() converts tensors to PIL images; midFrame() returns a tensor, so
# it is converted as well before .show() is called.
forshow(img5_1).show()
forshow(img5_2).show()
forshow(midFrame(img5_1, img5_2)).show()