In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import cv2
from time import time, sleep

from utils import *
from custom_pca import custom_pca
from video_loader import VideoLoader

### Example of video transformation

In [None]:
# Baseline PCA round-trip on the sample clip: fit on a random subset of
# frames, push every chunk through transform -> inverse_transform, and time
# each stage separately.
model = custom_pca()
video = VideoLoader('data/sample20s.mp4', grayscale=True)

# Stage 1 -- fit on randomly drawn frames (0.6 is presumably the fraction of
# frames to sample; confirm against VideoLoader.get_random_frames).
t1 = time()
frames_rand = video.get_random_frames(0.6)
model.fit(frames_rand)
t2 = time()

# Stage 2 -- encode and immediately decode each chunk yielded by the loader,
# then stack the decoded chunks into one array.
reconstructed = np.vstack([
    model.inverse_transform(model.transform(frames), shape=(video.height, video.width))
    for frames in video
])
t3 = time()

# Stage 3 -- compare against the original frames (the timing includes
# re-reading them through get_all_frames and the print itself).
print(reconstruction_error(video.get_all_frames(), reconstructed))
t4 = time()

for label, elapsed in (
    ('Fitting time:', t2 - t1),
    ('Transform:', t3 - t2),
    ('Error calculation:', t4 - t3),
):
    print(label, elapsed)

### LDS framework

By default, a linear dynamical system (LDS) is an ARMA model:
$$x_{t+1} = Ax_t + Bv_t $$
$$y_t = \phi(x_t) + w_t$$

We will begin with a simplified, noise-free version, in which $B$ reduces to a constant bias term:
$$x_{t+1} = Ax_t + B $$
$$y_t = \phi(x_t)$$
Here the $y_t$ are the video frames, the $x_t$ are their low-dimensional representations (obtained with PCA), and $\phi$ is the PCA inverse transform.

In [40]:
# `torch` is imported explicitly: later cells call `torch.sqrt`,
# `torch.optim.SGD` and `torch.no_grad`, while `import torch.nn as nn` only
# binds the name `nn`. Without this line the notebook depends on `torch`
# leaking in from a star import or from leftover kernel state.
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hyper-parameters for the LDS experiment.
ncomp = 20        # number of PCA components, i.e. dimension of the latent state x_t
bs = 64           # chunk size used when splitting the latent sequence into mini-batches
num_epoch = 1000  # number of passes over the whole latent sequence

In [24]:
# Fit an `ncomp`-component PCA on a random subset of frames, project the whole
# clip into that latent space (`xs`), and decode it again (`reconstructed`) as
# the "no prediction" reference.
model = custom_pca(ncomp=ncomp)
video = VideoLoader('data/sample20s.mp4', grayscale=True)

frames_rand = video.get_random_frames(0.5)
model.fit(frames_rand)

# Latent trajectory: transform each chunk yielded by the loader and stack the
# results row-wise.
reduced = [model.transform(frames) for frames in video]
xs = np.vstack(reduced)

# Decode one latent row at a time; `row[np.newaxis, :]` restores the leading
# axis that iterating over `xs` strips off.
reconstructed = np.vstack([
    model.inverse_transform(row[np.newaxis, :], shape=(video.height, video.width))
    for row in xs
])

In [38]:
def criterion(output, gt):
    """One-step-ahead RMSE between predictions and the following ground-truth rows.

    Row ``i`` of ``output`` is compared with row ``i + 1`` of ``gt``: the last
    row of ``output`` and the first row of ``gt`` have no partner and are
    dropped.

    Args:
        output: tensor of predictions, indexed along its first axis.
        gt: tensor of ground-truth values with the same layout as ``output``.

    Returns:
        A 0-dim tensor holding the root of the mean squared difference over
        all remaining elements.
    """
    shifted_error = output[:-1] - gt[1:]
    return shifted_error.pow(2).mean().sqrt()

In [51]:
# Learn the LDS transition x_{t+1} = A x_t + B as a single linear layer
# (weight = A, bias = B), trained by SGD on the one-step-ahead RMSE.

# `xs` leaves the PCA cell as a NumPy array (result of np.vstack), which has no
# `.split` method and cannot be fed to `nn.Linear`. `torch.as_tensor` converts
# it, and simply returns the same object when the cell is re-run and `xs` is
# already a float32 tensor, so the cell works on a fresh kernel and stays
# idempotent.
xs = torch.as_tensor(xs, dtype=torch.float32)

lds = nn.Linear(ncomp, ncomp)
optimizer = torch.optim.SGD(lds.parameters(), lr=0.0001)

for i in range(num_epoch):
    # Each batch is a run of consecutive latent states, so `criterion` can pair
    # every prediction with the next state inside the same batch.
    for batch in xs.split(bs):
        if len(batch) < 2:
            # A single-frame tail batch has no (x_t, x_{t+1}) pair: its loss
            # would be the mean of an empty tensor, i.e. NaN.
            continue
        output = lds(batch)
        loss = criterion(output, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Report the loss over the full sequence every 50 epochs.
    if i % 50 == 0:
        with torch.no_grad():
            print(criterion(lds(xs), xs).item())

244.19300842285156
65.34095764160156
51.71570587158203
44.89241409301758
40.84038543701172
38.30612564086914
36.67134094238281
35.57298278808594
34.81026077270508
34.250606536865234
33.81930160522461
33.47224807739258
33.18277359008789
32.93437194824219
32.71657943725586
32.52256393432617
32.34769821166992
32.18877410888672
32.04341506958008
31.90989875793457


In [77]:
# Roll the learned dynamics forward from the first latent state, producing as
# many latent states as the original sequence has.
# Inference only: under `torch.no_grad()` no autograd graph is recorded, so
# the state `xt` does not drag a graph spanning every previous step along with
# it, and `.numpy()` can be called on the outputs directly.
with torch.no_grad():
    xt = xs[0]
    generated = [xt.numpy()]
    for _ in range(len(xs) - 1):
        xt = lds(xt)
        generated.append(xt.numpy())

# Decode every generated latent state back into a (height, width) frame.
ys = np.array([
    model.inverse_transform(latent).reshape(video.height, video.width)
    for latent in generated
])

In [75]:
# Display the frames in `ys`, i.e. the decoded output of the LDS roll-out.
# NOTE(review): `show_video` is not defined in this notebook -- presumably it
# comes from `from utils import *`; confirm.
show_video(ys)

In [79]:
# Compare both frame sequences against the original clip: the plain PCA
# round-trip and the frames decoded from the LDS roll-out.
print("Error of reconstruction:")
for label, frames_out in (
    ("Without prediction:", reconstructed),
    ("With prediction:", ys),
):
    print(label, reconstruction_error(video.get_all_frames(), frames_out))

Error of reconstruction:
Without prediction: 2.3006290616128737
With prediction: 8.962386536253646
