In [None]:
%matplotlib notebook
from matplotlib import pyplot as plt
from scipy.spatial.distance import pdist, squareform

import imageio
import cv2
import numpy as np
from PIL import Image, ImageSequence, ImageChops
from io import BytesIO
from imageio.core import CannotReadFrameError

from ipywidgets import interact, interactive, fixed
import ipywidgets as widgets

import os
import sqlite3
from tempfile import TemporaryDirectory

In [None]:
# Path of the source video/gif/webp to process; set this before running the Load cell.
video_to_load = r"your_vid_here.mp4"

# Load video/gif

## First set up the database

In [None]:
# Scratch directory for the working database. Keep this reference alive: a
# TemporaryDirectory deletes its directory when the object is garbage-collected.
test = TemporaryDirectory()

In [None]:
# Full path of the working database (note it down for the crash-recovery cells).
os.path.join(test.name, "temp.db")

In [None]:
# Create an empty `image` table in the scratch database, replacing any old one.
# Columns: v = encoded frame bytes, f = frame number, n = edit level (each
# undoable operation writes a new level). UNIQUE(f, n) ON CONFLICT REPLACE means
# writing the same (frame, level) twice keeps only the newest row.
conn = sqlite3.connect(os.path.join(test.name, "temp.db"))
with conn:
    for statement in (
        "DROP TABLE IF EXISTS image",
        """
        CREATE TABLE IF NOT EXISTS image(v BLOB NOT NULL, -- video image frames
                                         f INTEGER NOT NULL, -- frame number
                                         n INTEGER NOT NULL DEFAULT 0, -- undoable actions
                                         UNIQUE(f, n) ON CONFLICT REPLACE
                                        )
        """,
        "CREATE INDEX IF NOT EXISTS frame ON image(f)",
    ):
        conn.execute(statement)

In [None]:
def to_db(image: np.ndarray, frame: int, edits: int=0):
    """Store one frame as a JPEG blob in the ``image`` table.

    ``image`` has its R and B channels swapped (COLOR_RGBA2BGR) before encoding,
    so RGB(A) input ends up stored in OpenCV's BGR order; any alpha channel is
    dropped. ``frame`` goes in column ``f`` and ``edits`` (edit level) in
    column ``n``. An existing row with the same (f, n) is replaced (see the
    table's ON CONFLICT REPLACE). The write is committed immediately.
    """
    # JPEG at OpenCV's default quality (95) is sufficient for this workflow.
    # If you want better, pass quality parameters to imencode, or switch formats.
    ok, encoded = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGBA2BGR))
    if not ok:
        raise ValueError(f"Could not JPEG-encode frame {frame}")
    # ndarray.tostring() is deprecated (removed in NumPy 2.x); tobytes() is the same bytes.
    with conn:
        conn.execute("INSERT INTO image VALUES (?, ?, ?)", (encoded.tobytes(), frame, edits))
        
def to_raw(image: np.ndarray, frame: int, edits: int):
    """Store ``image`` losslessly as ``.npy`` bytes at (frame, edits); commits immediately."""
    buffer = BytesIO()
    np.save(buffer, image)
    payload = buffer.getvalue()
    with conn:
        conn.execute("INSERT INTO image VALUES (?, ?, ?)", (payload, frame, edits))

In [None]:
def from_db(frame: int, edits=None) -> bytes:
    """Return the stored encoded bytes for ``frame`` at edit level ``edits``.

    ``edits=None`` selects the newest edit level in the whole table
    (``MAX(n)``), not the newest level of this particular frame.

    Raises:
        KeyError: no row exists for (frame, edit level).
    """
    # Values are bound as SQL parameters instead of being formatted into the query text.
    if edits is None:
        row = conn.execute("SELECT v FROM image WHERE f=? and n=(SELECT MAX(n) FROM image)",
                           (frame,)).fetchone()
    else:
        row = conn.execute("SELECT v FROM image WHERE f=? and n=?",
                           (frame, edits)).fetchone()
    if row is None:
        raise KeyError(f"No stored image for frame {frame} at edit level {edits!r}")
    return row[0]

def from_raw(frame: int, edits: int) -> np.ndarray:
    """Load an array previously stored by ``to_raw`` at (frame, edits)."""
    (blob,) = conn.execute("SELECT v FROM image WHERE f=? and n=?",
                           (frame, edits)).fetchone()
    return np.load(BytesIO(blob))

def slice_db(frames=None, edits=None):
    """Return ``(v, f, n)`` rows ordered by frame number.

    Args:
        frames: optional ``(start, stop)`` half-open range of frame numbers;
            None selects all frames.
        edits: edit level to read; None selects the newest level in the
            table (``MAX(n)``).
    """
    clauses, params = [], []
    if frames is not None:
        # The original body referenced an undefined name `frame` here.
        clauses.append("f>=? and f<?")
        params.extend((frames[0], frames[1]))
    if edits is None:
        clauses.append("n=(SELECT MAX(n) FROM image)")
    else:
        clauses.append("n=?")
        params.append(edits)
    query = "SELECT v, f, n FROM image WHERE " + " and ".join(clauses) + " ORDER BY f"
    return conn.execute(query, params).fetchall()

def as_array(byte: bytes) -> np.ndarray:
    """Decode stored image bytes and swap R/B back (undoing the swap done by ``to_db``)."""
    raw = np.frombuffer(byte, np.uint8)
    decoded = cv2.imdecode(raw, cv2.IMREAD_ANYCOLOR)
    return cv2.cvtColor(decoded, cv2.COLOR_RGBA2BGR)

## Database recovery after a crash (optional: skip unless you are reopening an old database)

In [None]:
# Path of a temp.db left behind by a crashed session (see the path printed earlier).
database_to_recover = r'old_path_here'

In [None]:
# Reopen a database left behind by a crashed session. Check that it exists first,
# because sqlite3.connect silently creates a new, empty database for a missing path
# (and this cell would then replace the working `conn` with it).
if not os.path.exists(database_to_recover):
    raise FileNotFoundError(f"No database at {database_to_recover!r}; set database_to_recover to the old temp.db path.")
conn = sqlite3.connect(database_to_recover)

In [None]:
# Fallback frame rate; the Load cell overwrites it when the source reports one.
fps = 30

## Load

In [None]:
# Read every frame of `video_to_load` into the database at edit level 0 and
# set `fps` from the file's metadata where it is available.
_path_lower = video_to_load.lower()
frames_read = 0
try:
    # Look at the file *extension*, i.e. the end of the path (case-insensitive).
    if _path_lower.endswith(".webp"):
        reader = ImageSequence.Iterator(Image.open(video_to_load))
        fps = 30 # REVIEW: frame rate is not read from WebP files here; adjust by hand if needed.
    else:
        if _path_lower.endswith(".gif"):
            reader = imageio.get_reader(video_to_load)
        else:
            reader = imageio.get_reader(video_to_load, 'ffmpeg')
        meta = reader.get_meta_data()
        if 'fps' in meta:
            fps = meta['fps']
        elif meta.get('duration'):
            # GIF metadata typically has a per-frame 'duration' in ms instead of 'fps'.
            fps = 1000 / meta['duration']
    for i, im in enumerate(reader):
        to_db(np.array(im), i)
        frames_read = i + 1
        print(i, end='                      \r')
    print("100%                      ")
except CannotReadFrameError:
    # Only imageio readers raise this. get_length() may report inf for unknown lengths.
    print("Read", frames_read, "of", reader.get_length(), "frames. This may be enough for what we want.")

In [None]:
# Number of distinct frames loaded, shown together with the source frame rate.
(orig_len,) = conn.execute("SELECT COUNT(DISTINCT f) FROM image").fetchone()
orig_len, fps

# Inspect for duplicates (optional)

In [None]:
# A few of my videos got artificially upsampled with duplicate frames. No idea why.
# It's pretty uniform, so this is a sufficient fix: delete (at every edit level)
# each frame whose stored bytes are identical to the frame just before it.
source = [f for f, in conn.execute("SELECT DISTINCT f FROM image ORDER BY f").fetchall()]
remove = []
if source:
    prv = from_db(source[0])
    # Start at the second frame: comparing the first frame with itself would
    # always flag it as a duplicate and delete it.
    for i in source[1:]:
        nxt = from_db(i)
        if prv == nxt:
            remove.append(i)
        prv = nxt
with conn:
    conn.executemany("DELETE FROM image WHERE f=?", [(i,) for i in remove])

In [None]:
# Update the fps for when we write it back: the clip still covers the same time,
# so the frame rate scales with the fraction of frames that were kept.
# NOTE(review): run this only once; a second run rescales the already-updated fps.
(deduplicated_len,) = conn.execute("SELECT COUNT(DISTINCT f) FROM image").fetchone()
deduplicated_fps = deduplicated_len*fps/orig_len
orig_fps = fps
fps = deduplicated_fps
deduplicated_len, deduplicated_fps

# Trim video to relevant frames
I recommend keeping a few extra frames after the section in which you want to search for a loop.

In [None]:
# Preview the clip and choose the frame range to keep.
# `source` lists the frame numbers of the newest edit, in frame order.
source = [f for f, in conn.execute("SELECT f FROM image WHERE n=(SELECT MAX(n) FROM image) ORDER BY f").fetchall()]
show = widgets.Image(value=from_db(source[0]))
play = widgets.Play(
    interval=int(1000/fps),  # Play.interval is in milliseconds
    value=0,
    min=0,
    max=len(source)-1,
    step=1,
    description="Press play",
    disabled=False
)
slider = widgets.IntSlider(min=0, max=len(source)-1)
widgets.jslink((play, 'value'), (slider, 'value'))
def load_frame(change):
    """Display the frame at the new slider position."""
    show.value = from_db(source[change['new']])
slider.observe(load_frame, names='value')
control = widgets.HBox([play, slider])
display(show)
display(control)
start = widgets.IntSlider(min=0, max=len(source)-1, step=1, value=0)
end = widgets.IntSlider(min=0, max=len(source)-1, step=1, value=len(source)-1)
time_step = widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=1/fps)
@interact(start=start, end=end, time_step=time_step)
def f(start, end, time_step):
    """Limit playback to slider positions [start, end]; time_step is seconds per frame."""
    play.min = start
    play.max = end
    play.interval = int(time_step*1000)

In [None]:
# Half-open frame-number range [first kept, last kept + 1) used by the trim cell below.
[source[play.min], source[play.max] + 1]

In [None]:
# Current playback bounds as slider positions (indices into `source`).
(play.min, play.max)

In [None]:
# Trim: copy frames source[play.min]..source[play.max] of the newest edit into a
# new edit level (MAX(n) + 1); the older level stays in the table for undo.
with conn:
    conn.execute(
        """
        INSERT INTO image(v, f, n)
        SELECT v, f, n+1 FROM image
        WHERE f>=? and f<? and n=(SELECT MAX(n) FROM image)
    """,
        (source[play.min], source[play.max] + 1),
    )

# Crop (optional)

In [None]:
# Get 100 frames evenly throughout the most recent edit of the video.
# Average them together to get the general image gist to display.
# Draw average with a crop-to box that has widget controls.
# Apply. (Kills widgets)
#TODO!

# After getting the trim indices right above, run the cells below to look for loops

In [None]:
# Frame numbers of the newest edit, in frame order. Row/column k of the distance
# matrices below corresponds to source[k], so the order must be explicit.
source = [f for f, in conn.execute("SELECT f FROM image WHERE n=(SELECT MAX(n) FROM image) ORDER BY f").fetchall()]

In [None]:
# Cache small thumbnails of the current frames at edit level -1 for the loop search.
# First remove thumbnails of frames that are no longer in the newest edit.
conn.execute("DELETE FROM image WHERE n=-1 AND f NOT IN (SELECT f FROM image WHERE n=(SELECT MAX(n) FROM image))")
downsize_shape = (100, 100)
for blob, frame_no, _ in slice_db():
    print(frame_no, end='\r')
    thumbnail = Image.open(BytesIO(blob)).resize(downsize_shape)
    to_db(np.array(thumbnail), frame_no, -1)

In [None]:
# Load the thumbnails in frame order so index k lines up with source[k].
images_downsized = [as_array(im) for im, in conn.execute("SELECT v FROM image WHERE n=-1 ORDER BY f").fetchall()]
assert len(images_downsized) == len(source), "thumbnails out of sync with source; re-run the downsize cell"
# Condensed pairwise (Euclidean, pdist's default) distances between flattened thumbnails.
differences = pdist(np.reshape(images_downsized, [len(images_downsized), -1]))

In [None]:
# Attempt to line up two frames worth of matching.
# Named `pair_score` (not `test`): rebinding `test` would drop the last reference
# to the TemporaryDirectory, whose cleanup deletes the directory holding the
# open working database.
pair_score = squareform(differences)
# NumPy resolves the overlapping in-place multiply as if the right side were copied first.
pair_score[1:, 1:] *= pair_score[:-1, :-1]
pair_score[0, :] = 0
pair_score[:, 0] = 0
# Suppress local non-minima
result = pair_score.copy()
result[:-1][pair_score[:-1] > pair_score[1:]] = 0
result[1:][pair_score[:-1] < pair_score[1:]] = 0
result[:,:-1][pair_score[:,:-1] > pair_score[:,1:]] = 0
result[:,1:][pair_score[:,:-1] < pair_score[:,1:]] = 0
plt.figure()
plt.imshow(result)

In [None]:
# How many candidate loop points survived the non-minimum suppression.
np.count_nonzero(result)

In [None]:
# Trim suggestions to be between the below parameters (Set them how you want it)
max_gif_len = 100 # max: trimmed length
min_gif_len = 10 # min: 1 (not suggested)
frame_idx = np.arange(len(source))
loop_len = frame_idx[np.newaxis, :] - frame_idx[:, np.newaxis]  # loop_len[i, j] == j - i
mask = (loop_len > min_gif_len) & (loop_len < max_gif_len)
result_trimmed = result.copy()*mask
plt.imshow(result_trimmed)

In [None]:
# (OPTIONAL)
# This weights the loop by how much change is in the inbetween frames: candidate
# (i, j) is divided by 1 + the largest distance from frame i to any frame i..j.
# NOTE(review): divides result_trimmed in place; re-running compounds the weighting.
n_frames = len(source)
mask = np.triu(np.ones((n_frames, n_frames), dtype=bool))  # mask[i, j] == (j >= i)
experimental = squareform(differences) * mask
experimental = 1 + np.maximum.accumulate(experimental, axis=1)
result_trimmed /= experimental
plt.imshow(result_trimmed)

In [None]:
# Pull out the best match and try it: among the surviving (nonzero) candidates,
# take the smallest score and show its (start index, end index).
nonzeros = np.nonzero(result_trimmed)
local_mins = result_trimmed[nonzeros]
minimum_change = local_mins.argmin()
tuple(axis[minimum_change] for axis in nonzeros)

In [None]:
# All surviving candidates as (start indices, end indices) arrays.
nonzeros

# Denoise

## Spatial denoising (probably run first):

In [None]:
# The spatial filter reads edit level `edit_current` and writes `edit_next`.
(edit_current,) = conn.execute("SELECT MAX(n) FROM image").fetchone()
edit_next = edit_current + 1
# Frame numbers of the current edit, in explicit frame order.
source = [f for f, in conn.execute("SELECT f FROM image WHERE n=? ORDER BY f", (edit_current,)).fetchall()]

In [None]:
# Interactive preview of the spatial bilateral filter on the newest edit.
source = [f for f, in conn.execute("SELECT f FROM image WHERE n=(SELECT MAX(n) FROM image) ORDER BY f").fetchall()]
show = widgets.Image(value=from_db(source[0]))
play = widgets.Play(
    interval=int(1000/fps),  # Play.interval is in milliseconds
    value=0,
    min=0,
    max=len(source)-1,
    step=1,
    description="Press play",
    disabled=False
)
slider = widgets.IntSlider(min=0, max=len(source)-1)
widgets.jslink((play, 'value'), (slider, 'value'))
control = widgets.HBox([play, slider])
display(show)
display(control)
filterRadius = widgets.IntSlider(min=0, max=100, step=1, value=2)
sigmaColor = widgets.IntSlider(min=0, max=255, step=1, value=75)
sigmaDist = widgets.IntSlider(min=0, max=255, step=1, value=75)
time_step = widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=1/fps)
def load_frame(change):
    """Show the frame under the slider with the current bilateral settings applied."""
    im = cv2.bilateralFilter(as_array(from_db(source[slider.value], edit_current)),
                             filterRadius.value*2+1, sigmaColor.value, sigmaDist.value)
    # Reverse the channel order before encoding (undoes as_array's R/B swap).
    show.value = cv2.imencode('.jpg', im.astype(np.uint8)[...,::-1])[1].tobytes()
@interact(filterRadius=filterRadius, sigmaColor=sigmaColor, sigmaDist=sigmaDist, time_step=time_step)
def f(filterRadius, sigmaColor, sigmaDist, time_step):
    """Re-render the preview whenever a filter setting changes."""
    play.interval = int(time_step*1000)
    load_frame(None)
slider.observe(load_frame, names='value')

In [None]:
# https://en.wikipedia.org/wiki/Bilateral_filter#Definition
# Apply the bilateral filter chosen above to every frame of `edit_current`,
# storing the results at `edit_next`.
diameter = filterRadius.value*2+1
print(diameter, sigmaColor.value, sigmaDist.value)
for frame_no in source:
    print("Denoising frame:", frame_no, end='\r')
    original = as_array(from_db(frame_no, edit_current))
    denoised = cv2.bilateralFilter(original, diameter, sigmaColor.value, sigmaDist.value)
    to_db(denoised, frame_no, edit_next)

### This gives a reasonable value for sigmaColor, for both the spatial filter above and the temporal filter below,
but it requires running the spatial filter first, so you may want to undo that afterwards (Undo is below the Temporal filter on this page).

In [None]:
# Per-pixel absolute change between the newest edit level and the one before it.
(compare_edit,) = conn.execute("SELECT MAX(n) FROM image").fetchone()
compare_to = compare_edit - 1
source = [f for f, in conn.execute("SELECT f FROM image WHERE n=?", [compare_edit]).fetchall()]
def _frame_change(frame_no):
    """Absolute per-pixel difference of one frame between the two edit levels."""
    newer = as_array(from_db(frame_no, compare_edit)).astype(float)
    older = as_array(from_db(frame_no, compare_to)).astype(float)
    return np.abs(newer - older)
dffs = np.array([_frame_change(frame_no) for frame_no in source])

In [None]:
# Mean per-pixel change plus five standard deviations.
noise_thresh = dffs.mean() + 5 * dffs.std()

In [None]:
# Suggested sigmaColor value.
noise_thresh

## Temporal filter (probably run second):

In [None]:
# The temporal filter reads edit level `edit_current` and writes `edit_next`.
(edit_current,) = conn.execute("SELECT MAX(n) FROM image").fetchone()
edit_next = edit_current + 1
# Frame numbers in explicit frame order: the temporal filter blends neighbours.
source = [f for f, in conn.execute("SELECT f FROM image WHERE n=? ORDER BY f", (edit_current,)).fetchall()]

In [None]:
# Interactive preview of the temporal filter; frame order matters because each
# frame is blended with its neighbours (the clip wraps around at the ends).
source = [f for f, in conn.execute("SELECT f FROM image WHERE n=(SELECT MAX(n) FROM image) ORDER BY f").fetchall()]
show = widgets.Image(value=from_db(source[0]))
play = widgets.Play(
    interval=int(1000/fps),  # Play.interval is in milliseconds
    value=0,
    min=0,
    max=len(source)-1,
    step=1,
    description="Press play",
    disabled=False
)
slider = widgets.IntSlider(min=0, max=len(source)-1)
widgets.jslink((play, 'value'), (slider, 'value'))
control = widgets.HBox([play, slider])
display(show)
display(control)
# Both sigmas are squared into a denominator below, so 0 would divide by zero.
sigmaColor2 = widgets.IntSlider(min=1, max=255, step=1, value=75)
sigmaTime = widgets.IntSlider(min=1, max=255, step=1, value=75)
time_step = widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=1/fps)
def _temporal_blend(prv, cur, nxt):
    """Blend `cur` with `prv`/`nxt` (centre weight 1).

    Each neighbour's per-pixel weight is
    exp(-1/(2*sigmaTime^2) - (neighbour - cur)^2 / (2*sigmaColor2^2)).
    """
    cur_f = cur.astype(float)
    time_term = 1/(2*sigmaTime.value**2)
    color_scale = 2*sigmaColor2.value**2
    weight_next = np.exp(-time_term - np.square(nxt.astype(float)-cur_f)/color_scale)
    weight_prev = np.exp(-time_term - np.square(prv.astype(float)-cur_f)/color_scale)
    return (nxt*weight_next + cur + prv*weight_prev) / (1 + weight_next + weight_prev)
def load_frame(change):
    """Show the temporally-filtered frame under the slider."""
    n_frames = len(source)
    prv = as_array(from_db(source[(slider.value-1) % n_frames], edit_current))
    cur = as_array(from_db(source[slider.value], edit_current))
    nxt = as_array(from_db(source[(slider.value+1) % n_frames], edit_current))
    out = _temporal_blend(prv, cur, nxt)
    show.value = cv2.imencode(".jpg", out[...,::-1])[1].tobytes()
@interact(sigmaColor2=sigmaColor2, sigmaTime=sigmaTime, time_step=time_step)
def f(sigmaColor2, sigmaTime, time_step):
    """Re-render the preview whenever a filter setting changes."""
    play.interval = int(time_step*1000)
    load_frame(None)
slider.observe(load_frame, names='value')

In [None]:
# https://en.wikipedia.org/wiki/Bilateral_filter#Definition
# Single dimension version, with time being the dimension: every frame of
# `edit_current` is blended with its previous and next frame and stored at
# `edit_next`. The clip is treated as a loop: the first frame's previous
# neighbour is the last frame, and the last frame's next neighbour is the first.
if not source:
    raise ValueError("No frames at the current edit level to filter.")
if sigmaTime.value <= 0 or sigmaColor2.value <= 0:
    raise ValueError("sigmaTime and sigmaColor2 must be positive.")

def _blend_with_neighbours(prv, cur, nxt):
    """Temporal bilateral blend of `cur` with `prv`/`nxt` (centre weight 1).

    Each neighbour's per-pixel weight is
    exp(-1/(2*sigmaTime^2) - (neighbour - cur)^2 / (2*sigmaColor2^2)).
    """
    cur_f = cur.astype(float)
    time_term = 1/(2*sigmaTime.value**2)
    color_scale = 2*sigmaColor2.value**2
    weight_next = np.exp(-time_term - np.square(nxt.astype(float) - cur_f)/color_scale)
    weight_prev = np.exp(-time_term - np.square(prv.astype(float) - cur_f)/color_scale)
    return (nxt*weight_next + cur + prv*weight_prev) / (1 + weight_next + weight_prev)

def _load(frame_no):
    """Decode one frame of the current edit level."""
    return as_array(from_db(frame_no, edit_current))

show = widgets.Image(value=from_db(source[0], edit_current))
display(show)
n_frames = len(source)
first = _load(source[0])
prv = _load(source[-1]) if n_frames > 1 else first
cur = first
# Sliding window over the clip; only three decoded frames (plus `first`) are held.
for idx, frame_no in enumerate(source):
    print(frame_no, end='\r')
    # At the end, wrap around to the first frame instead of decoding it again.
    nxt = first if idx + 1 == n_frames else _load(source[idx + 1])
    out = _blend_with_neighbours(prv, cur, nxt)
    show.value = cv2.imencode(".jpg", out[..., ::-1])[1].tobytes()
    to_db(out.astype(np.uint8), frame_no, edit_next)
    prv, cur = cur, nxt

# Undo

In [None]:
# Undo: delete every frame at the newest edit level. Nothing is committed here,
# so the rollback cell below can restore it until the next commit (any `with conn:`
# block, e.g. inside to_db, commits).
conn.execute("DELETE FROM image WHERE n=(SELECT MAX(n) FROM image)")

In [None]:
# Check after an undo: (newest edit level, number of frames stored at that level).
undo_state = conn.execute("SELECT MAX(n), COUNT(f) FROM image WHERE n=(SELECT MAX(n) FROM image)")
undo_state.fetchone()

## Redo: this can undo the undo if you run it before doing anything else
It rolls back everything since the last commit, so it restores the state from before all of the undos you ran in that period; it cannot redo selectively, one step at a time.

In [None]:
# Discard every uncommitted change since the last commit (restores undone edits).
conn.rollback()

# Examine results

In [None]:
# Play back the newest edit to examine the result.
source = [f for f, in conn.execute("SELECT f FROM image WHERE n=(SELECT MAX(n) FROM image) ORDER BY f").fetchall()]
show = widgets.Image(value=from_db(source[0]))
play = widgets.Play(
    interval=int(1000/fps),  # Play.interval is in milliseconds
    value=0,
    min=0,
    max=len(source)-1,
    step=1,
    description="Press play",
    disabled=False
)
slider = widgets.IntSlider(min=0, max=len(source)-1)
widgets.jslink((play, 'value'), (slider, 'value'))
def load_frame(change):
    """Display the frame at the new slider position."""
    show.value = from_db(source[change['new']])
slider.observe(load_frame, names='value')
control = widgets.HBox([play, slider])
display(show)
display(control)
start = widgets.IntSlider(min=0, max=len(source)-1, step=1, value=0)
end = widgets.IntSlider(min=0, max=len(source)-1, step=1, value=len(source)-1)
time_step = widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=1/fps)
@interact(start=start, end=end, time_step=time_step)
def f(start, end, time_step):
    """Limit playback to slider positions [start, end]; time_step is seconds per frame."""
    play.min = start
    play.max = end
    play.interval = int(time_step*1000)

# Save to disk

## First, a master copy in mp4 format. Can be used with GifV.

In [None]:
# Write the newest edit level to loop.mp4 in frame order. The context manager
# closes (and finalizes) the file even if a frame fails.
with imageio.get_writer('loop.mp4', fps=fps) as writer:
    for f, in conn.execute("SELECT f FROM image WHERE n=(SELECT MAX(n) FROM image) ORDER BY f").fetchall():
        # Read the same (newest) edit level the frame list came from; reading
        # edit 0 here would throw away later edits such as denoising.
        writer.append_data(as_array(from_db(f)))

## Also, WebP, because it's better than gif in pretty much every way.

In [None]:
# Gifs kind of suck. I suggest webp if this returns True, which means this
# Pillow build reports the "webp_anim" (animated WebP) feature.
from PIL import features
webp_anim_supported = features.check("webp_anim")
webp_anim_supported

In [None]:
# Save the newest edit level as an animated WebP that loops forever (loop=0),
# with frames in explicit frame order.
source = [f for f, in conn.execute("SELECT f FROM image WHERE n=(SELECT MAX(n) FROM image) ORDER BY f").fetchall()]
frames = (Image.open(BytesIO(from_db(im))).convert('RGBA') for im in source)
first = next(frames)
x,y = first.size
# Paste the first frame onto a fresh transparent canvas, using its own alpha as mask.
tosave = Image.new('RGBA', (x,y))
tosave.paste(first, (0,0,x,y), first)
# duration is milliseconds per frame.
tosave.save('output2.webp', save_all=True, append_images=frames, duration=int(1000/fps), loop=0)

In [None]:
# Save for web. Play with quality vs resize here.
resize = 25 #%
quality = 80
source = [f for f, in conn.execute("SELECT f FROM image WHERE n=(SELECT MAX(n) FROM image) ORDER BY f").fetchall()]
x, y = Image.open(BytesIO(from_db(source[0]))).convert('RGBA').size
# Scale to `resize` percent (floored to whole pixels).
x, y = int(x // (100/resize)), int(y // (100/resize))
# Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter.
frames = (Image.open(BytesIO(from_db(im))).convert('RGBA').resize((x,y), Image.LANCZOS) for im in source)
first = next(frames)
tosave = Image.new('RGBA', (x,y))
tosave.paste(first, (0,0,x,y), first)
tosave.save(f'output-{resize}-{quality}.webp', save_all=True, append_images=frames, quality=quality, duration=int(1000/fps), loop=0)

# If you really want a gif, you'll have to do it this way:

In [None]:
# Write the newest edit level to loop.gif in frame order.
# (The original loop iterated over an undefined name `source3`.)
with imageio.get_writer('loop.gif', fps=int(fps)) as writer:
    for blob, in conn.execute("SELECT v FROM image WHERE n=(SELECT MAX(n) FROM image) ORDER BY f").fetchall():
        writer.append_data(as_array(blob))

# At a minimum, you should save an MP4
MP4s are smaller and less lossy than GIFs, and you can use the cell below to create a GIF from one.

In [None]:
# Standalone: convert an existing loop.mp4 into loop.gif (needs only imageio and numpy).
import imageio
import numpy as np

# Both reader and writer are closed even if a frame fails. Frames come back as
# ndarrays, so there is no PIL-style `.size` tuple to unpack (the old
# `x,y = frame.size` line raised TypeError).
with imageio.get_reader("loop.mp4", 'ffmpeg') as reader, \
        imageio.get_writer('loop.gif', fps=reader.get_meta_data()['fps']) as writer:
    print("Frames processed:", 0, end='\r')
    for index, frame in enumerate(reader):
        writer.append_data(np.array(frame))
        print("Frames processed:", index+1, end='\r')