In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from os import listdir
import cv2
import random
import sys
sys.path.append("src")
import sindy_utils as sindy

# autoencoder architecture
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__() 
        self.encode = nn.Sequential(
            # encoder: N, 3, 404, 720
            nn.Conv2d(3, 16, 2), # N, 16, 403, 719
            nn.ReLU(),
            nn.Conv2d(16, 32, 2), # N, 32, 402, 718
            nn.ReLU(),
            nn.MaxPool2d((2,3), stride=(2,3)), # N, 32, 201, 239              -- pool --
            nn.Conv2d(32, 64, 4), # N, 64, 198, 236
            nn.ReLU(),
            nn.Conv2d(64, 96, 4), # N, 96, 195, 233
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2), # N, 96, 97, 116                       -- pool --
            nn.Conv2d(96, 128, 5), # N, 128, 93, 112
            nn.ReLU(),
            nn.Conv2d(128, 150, 5, stride=2, padding=1), # N, 150, 46, 55
            nn.ReLU(),
            nn.MaxPool2d(2,stride=2), # N, 150, 23, 27                        -- pool --
            nn.Conv2d(150, 200, 9, stride=2), # N, 200, 8, 10
            nn.ReLU()
        )
        
        self.fc1 = nn.Linear(200*8*10,params['z_dim'])
        # Note: nn.MaxPool2d -> use nn.MaxUnpool2d, or use different kernelsize, stride etc to compensate...
        # Input [-1, +1] -> use nn.Tanh    
        
        # note: encoder and decoder are not symmetric
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(200, 150, 4), # N, 150, 11, 13
            nn.ReLU(),
            nn.ConvTranspose2d(150, 128, 5, stride=(2,3), padding=(2,2), output_padding=(0,2)), # N, 128, 21, 39
            nn.ReLU(),
            nn.ConvTranspose2d(128, 96, 4, stride=2, padding=(1,0)), # N, 96, 42, 80
            nn.ReLU(),
            nn.ConvTranspose2d(96, 64, 8), # N, 64, 49, 87
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 8, stride=2, padding=(2,1), output_padding=(0,1)), # N, 32, 100, 179
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 5, stride=2, padding=1), # N, 16, 201, 359
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 5, stride=2, padding=1, output_padding=(1,1)), # N, 3, 404, 720
            nn.ReLU()
        )   
        
        self.fc2 = nn.Linear(params['z_dim'], 200*8*10)

    def forward(self, x, z, mode):
        '''
        x: input for encoder
        z: input for decoder
        mode: 
            'train' -> use encoded for decoder
            'test'  -> feed z in an get decoded
        
        '''
        if mode == 'train':
            encoded = self.encode(x)
            encoded = encoded.view(-1,200*8*10)
            encoded = self.fc1(encoded)

            decoded = self.fc2(encoded)
            decoded = decoded.view(-1,200,8,10)
            decoded = self.decode(decoded)
        else:
            encoded = torch.zeros(1)

            decoded = self.fc2(z)
            decoded = decoded.view(-1,200,8,10)
            decoded = self.decode(decoded)
        
        return encoded, decoded

    
def calculateSindy(z, Xi, poly_order, include_sine_param):
    z_new = z.detach().numpy()
    
    theta = torch.from_numpy(sindy.sindy_library(z_new, poly_order, include_sine=include_sine_param))
    
    dz_prediction = torch.matmul(theta, Xi).float()
    
    return dz_prediction

In [2]:
# loading model
path_folder = 'results/v5/'

to_load = path_folder+'Ae_4000epoch_bs16_lr1e-5_z2_sindt05_poly5.pt'
autoencoder = torch.load(to_load)
autoencoder = autoencoder.cpu()

# load a train data
path_folder_data = 'results/v5/data/'
train_data = torch.load(path_folder_data + 'train_data.pt')
print('train data: ', len(train_data), len(train_data[0]), len(train_data[0][0]), len(train_data[0][0][0]), len(train_data[0][0][0][0]))
print('train data reading done!')

## load a validation data
#validation_data = torch.load(path_folder_data + 'validation_data.pt')
#print('validation data: ', len(validation_data), len(validation_data[0]), len(validation_data[0][0]), len(validation_data[0][0][0]), len(validation_data[0][0][0][0]))
#print('validation data reading done!')
#
## loading test data
#test_data = torch.load(path_folder_data + 'test_data.pt')
#print('test data: ', len(test_data), len(test_data[0]), len(test_data[0][0]), len(test_data[0][0][0]), len(test_data[0][0][0][0]))
#print('test data reading done!')


train data:  30 16 3 404 720
train data reading done!


In [3]:
def matrixToNorm(x, offset=0, factor=0.95):
    x = (x - x.min() + offset) / x.max() * factor
    return x

poly_order = 4
include_sine_param = False
threshold_sindy = 1000
until = 5                      # choose the number of prediction stepts, 1 step are number of batch_size frames



In [4]:
# use more than 16 frames to get Xi
def constructXi(data, zDim):
    '''
    input: data as a list with shape [len batch_size RGB hight width]
    return: Xi
    
    '''
    # processs the data
    z_tensor = torch.empty((0, zDim))
    data_len = len(data)
    for i in range(data_len):
        z_tensor_tmp, _ = autoencoder(train_data[i], 0, mode='train')
        z_tensor = torch.cat((z_tensor, z_tensor_tmp), 0)
        if i % 5 == 0:
            print(i, z_tensor.shape)
        del z_tensor_tmp

    print(z_tensor.shape)
    
    dz_tensor = z_tensor[2:data_len]
    z_tensor = z_tensor[1:data_len-1]
    
    # calculate sindy and Xi for the data
    z = z_tensor.cpu().detach().numpy()
    dz = dz_tensor.cpu().detach().numpy()

    Theta = torch.from_numpy(sindy.sindy_library(z, poly_order, include_sine=include_sine_param))
    Xi = torch.from_numpy(sindy.sindy_fit(Theta, dz, threshold_sindy))
    
    return Xi

In [5]:
Xi = constructXi(train_data, zDim=2)

print(Xi)

print(z)
print(dz)
print(dz_predict)
print(dz_predict2)

print(recon2_pred_tensor[1].shape, len(recon2_pred_tensor))

0 torch.Size([16, 2])
5 torch.Size([96, 2])
10 torch.Size([176, 2])
15 torch.Size([256, 2])
20 torch.Size([336, 2])
25 torch.Size([416, 2])


RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:76] data. DefaultCPUAllocator: not enough memory: you tried to allocate 591126528 bytes.

In [None]:
# get a Xi for only one picture, idea: mean of all Xi's with threshold
z_tensor, recon_tensor = autoencoder(train_data[0], 0, mode='train')
dz_tensor, recon1_tensor = autoencoder(train_data[1], 0, mode='train')

recon_tensor = matrixToNorm(recon_tensor)
print('max recon_data', recon_tensor.cpu().detach().numpy().max())


z = z_tensor.cpu().detach().numpy()
dz = dz_tensor.cpu().detach().numpy()

Theta = torch.from_numpy(sindy.sindy_library(z, poly_order, include_sine=include_sine_param))
Xi = torch.from_numpy(sindy.sindy_fit(Theta, dz, threshold_sindy))
dz_predict = torch.matmul(Theta, Xi).float()
_, recon1_pred_tensor = autoencoder(0, dz_predict, mode='test')
print('max of z_tensor', z_tensor.cpu().detach().numpy().max())
print('max of dz_tensor', dz_tensor.cpu().detach().numpy().max())
print('max of dz_predict', dz_predict.cpu().detach().numpy().max())

# plot autoencoder result

for nbImag in range(len(train_data[0])):
    plt.figure()
    plt.subplot(1,2,1)
    plt.imshow(train_data[0][nbImag].permute(1,2,0).detach().numpy())
    plt.subplot(1,2,2)
    plt.imshow(recon_tensor[nbImag].permute(1,2,0).detach().numpy())

print('One step')
recon1_tensor = matrixToNorm(recon1_tensor)
recon1_pred_tensor = matrixToNorm(recon1_pred_tensor)
# plot sindy result
for nbImag in range(len(drecon_tensor)):
    plt.figure()
    plt.subplot(1,2,1)
    plt.imshow(recon1_tensor[nbImag].permute(1,2,0).detach().numpy())
    plt.subplot(1,2,2)
    plt.imshow(recon1_pred_tensor[nbImag].permute(1,2,0).detach().numpy())

# new calculations
dz_tensor, recon2_tensor = autoencoder(train_data[2], 0, mode='train')

Theta2 = torch.from_numpy(sindy.sindy_library(dz_predict, poly_order, include_sine=include_sine_param))
dz_predict2 = torch.matmul(Theta2, Xi).float()
_, recon2_pred_tensor = autoencoder(0, dz_predict2, mode='test')

print('Two steps with the same Xi')
recon2_tensor = matrixToNorm(recon2_tensor)
for i in range(len(recon2_pred_tensor)):
    recon2_pred_tensor[i] = matrixToNorm(recon2_pred_tensor[i])
for nbImag in range(len(drecon_tensor)):
    plt.figure()
    plt.subplot(1,2,1)
    plt.imshow(recon2_tensor[nbImag].permute(1,2,0).detach().numpy())
    plt.subplot(1,2,2)
    plt.imshow(recon2_pred_tensor[nbImag].permute(1,2,0).detach().numpy())

## test model
#def test(data):
#    video_reconstruction = []
#    # predict videos
#    for vid_nbr in range(0, len(test_idxOfNewVideo)):
#        print('video number', vid_nbr)
#        # first step encode first batch
#        img = data[test_idxOfNewVideo[vid_nbr]]
#        encode_tensor, recon_tensor = autoencoder(img, 0, mode='train')
#        # too big
#        if vid_nbr == 2:
#            break
#        
#        # predict the future using only sindy model, new video starts always at position vid_nbr * until
#        for i in range(0, until):
#            #print('pred', i)
#            video_reconstruction.append(recon_tensor)
#            dz_tensor = calculateSindy(encode_tensor, Xi, poly_order, include_sine_param)
#            encode_tensor = dz_tensor
#            _, recon_tensor = autoencoder(0, dz_tensor, mode='test')
#            
#    return video_reconstruction
#
#
## test for different models or epoch number, TODO
#video_output = test(train_debug)
#print('prediction done!')
#
##del test_data


In [None]:
# first picture comparision
print(len(video_output))
#for i in range(0, len(video_output)):
plt.figure()
plt.subplot(1,2,1)
plt.imshow(train_debug[0][0].permute(1,2,0).detach().numpy())
plt.subplot(1,2,2)
plt.imshow(video_output[0][0].permute(1,2,0).detach().numpy())

plt.figure()
plt.subplot(1,2,1)
plt.imshow(train_debug[1][0].permute(1,2,0).detach().numpy())
plt.subplot(1,2,2)
plt.imshow(video_output[1][0].permute(1,2,0).detach().numpy())

plt.figure()
plt.subplot(1,2,1)
plt.imshow(train_debug[2][0].permute(1,2,0).detach().numpy())
plt.subplot(1,2,2)
plt.imshow(video_output[2][0].permute(1,2,0).detach().numpy())

plt.figure()
plt.subplot(1,2,1)
plt.imshow(train_debug[3][0].permute(1,2,0).detach().numpy())
plt.subplot(1,2,2)
plt.imshow(video_output[3][0].permute(1,2,0).detach().numpy())


In [None]:
# make videos
frame_width = len(video_output[0][0][0][0])
frame_height = len(video_output[0][0][0])
fps = 25.0

# write different videos
#forcc = cv2.VideoWriter_fourcc('M','J','P','G')
forcc = cv2.VideoWriter_fourcc('D','I','V','3')
#forcc = cv2.VideoWriter_fourcc('F','M','P','4')
out1 = cv2.VideoWriter('figures/run1_lre-5_z5_poly4/videoTest.avi', forcc, 1, (frame_width,frame_height))
#out2 = cv2.VideoWriter('video2.mov',cv2.VideoWriter_fourcc('M','J','P','G'), fps, (frame_width,frame_height))
#out3 = cv2.VideoWriter('video3.mov',cv2.VideoWriter_fourcc('M','J','P','G'), fps, (frame_width,frame_height))
    
print('output video', len(video_output), len(video_output[0]), len(video_output[0][0]), len(video_output[0][0][0]), len(video_output[0][0][0][0]))

for vid_nbr in range(0, len(test_idxOfNewVideo)):
    if vid_nbr == 1:
        break
    # undo batch structure
    videoProcessing = []
    count = -1
    for img in range(0, len(video_output)*len(video_output[0])):
        imgIn_batch = img % batch_size
        # new batch
        if imgIn_batch == 0:
            count += 1
        img_toAppend = video_output[count][imgIn_batch]
        videoProcessing.append(img_toAppend)
        
    #del video_output
    print('video currently procession', len(videoProcessing), len(videoProcessing[0]), len(videoProcessing[0][0]), len(videoProcessing[0][0][0]))
    
    for img in range(0,len(videoProcessing)):
        frame_local = np.transpose(videoProcessing[img].detach().numpy(), [1,2,0])
        frame_local = cv2.cvtColor(frame_local, cv2.COLOR_RGB2BGR)
        # print(frame_local) --> seems unstable, not a number and doesn't save it as a video
        out1.write(frame_local.shape)
        # show video
        cv2.imshow('Frame',frame_local)

        # Press Q on keyboard to  exit
        if cv2.waitKey(25) & 0xFF == ord('q') and img >= 10:
            break

    # When everything done, release the video capture and video write objects
    out1.release()

    # Closes all the frames
    cv2.destroyAllWindows()

print('finished prediction video output!')