In [None]:
import os

# Source video, and the directory every result is written under
# (per-frame overlays go to <out_dir>/attn/, the encoded video into out_dir itself).
vin, out_dir = 'out.mp4', 'results/'


In [None]:
from feature_extraction import PatchFeatureGenerator   
    
# Patch-feature extractor backed by the 'dinov2_vits14' model.
pfg = PatchFeatureGenerator(
    'dinov2_vits14',
)
# Side length of one model patch; presumably 14 px for this backbone — TODO confirm.
patch_size = pfg.model.patch_size

In [None]:
from video_processing import VideoManager

# Number of "kernel" frames requested from the VideoManager; vm.getkernelframe()
# returns the frames whose patch features are used to fit the feature filter below.
kernel_frame_size = 8
vm = VideoManager(vin,
                  out_dir,
                  kernel_frame_size)

In [None]:
from matplotlib import pyplot as plt

# Quick visual check: the first four kernel frames side by side.
kernel_frame = vm.getkernelframe()
fig = plt.figure(figsize=(10, 8))
for idx in range(4):
    ax = fig.add_subplot(1, 4, idx + 1)
    ax.imshow(kernel_frame[idx])
    ax.axis('off')
plt.show()
plt.close(fig)

In [None]:
# transform

import torchvision.transforms as tt
resolution = 518  # side length (px) every frame is resized to before feature extraction

# Derive the patch grid from the backbone's own patch size (read from the model above)
# instead of a hard-coded 14, and fail loudly if the frame does not tile into whole patches.
assert resolution % patch_size == 0, (
    f"resolution {resolution} must be a multiple of the model patch size {patch_size}"
)
patch_len = resolution // patch_size  # patches per image side -> patch_len**2 tokens per frame

img2tensor = tt.Compose([
    tt.ToTensor(),  # uint8 [0, 255] (PIL / HWC ndarray) -> CHW float [0.0, 1.0]
    tt.Resize((resolution, resolution)),
    tt.Normalize(mean=0.5, std=0.2),  # [0.0, 1.0] -> [-2.5, 2.5]
])

# Inverse of the normalisation in img2tensor, back to a PIL image for display.
tensor2img = tt.Compose([
    tt.Normalize(mean=-2.5, std=5),  # [-2.5, 2.5] -> [0.0, 1.0]
    tt.ToPILImage(),
])

In [None]:
from feature_extraction import FeatureFilter
import numpy as np


# Patch features for every kernel frame, extracted in one batch.
kernel_frame = vm.getkernelframe()
kernel_tensor = list(map(img2tensor, kernel_frame))
patch_feature = pfg.batch_run(kernel_tensor)
print(patch_feature.shape)

In [None]:
# Flatten (frame, token, feature) into one row per token so the filter layers can be fit on all tokens at once.
X = np.array(patch_feature)
n_kernel, n_token_per_frame, n_feat = X.shape
print(n_kernel, n_token_per_frame, n_feat)
# Every later cell reshapes one frame's tokens into a (patch_len, patch_len) grid, so fail
# here with a clear message if the token count does not match (e.g. an extra CLS token).
assert n_token_per_frame == patch_len ** 2, (
    f"expected {patch_len ** 2} patch tokens per frame, got {n_token_per_frame}"
)
X = X.reshape((-1, n_feat))
print(X.shape)

In [None]:
# Start from an empty filter; layers are stacked onto it with ff.addLayer(...) in the cells below.
ff = FeatureFilter()


In [None]:
# First layer: a single component fit on all kernel-frame tokens.
lvl1_pca = 1
ff.addLayer(X, lvl1_pca)
mm, zz = ff.getFeature(X)
print(mm.shape, zz.shape)
# Scatter the filter output back onto the full token list; tokens not selected by `mm` stay 0.
frames_feat = np.zeros((n_kernel * n_token_per_frame, lvl1_pca))
frames_feat[mm] = zz
# Keep the component axis (as every other cell does): the previous 2-D reshape only
# worked because lvl1_pca == 1 and raised as soon as it was changed.
frames_feat = frames_feat.reshape((n_kernel, n_token_per_frame, lvl1_pca))
kernel_frame = vm.getkernelframe()
plt.imshow(kernel_frame[0], extent=(0, resolution, resolution, 0))
plt.imshow(frames_feat[0].reshape((patch_len, patch_len, lvl1_pca)), extent=(0, resolution, resolution, 0), alpha=0.5)
plt.axis('off')
plt.colorbar()
plt.show()
plt.close()

In [None]:
# Restrict layer 0 with one (low, high) pair per component (presumably tokens inside the
# range are kept), then redraw the overlay on the first kernel frame.
ff.setLayerThreshold(0, [(.35, .6)])
mm, zz = ff.getFeature(X)
print(mm.shape, zz.shape)

flat_feat = np.zeros((n_kernel * n_token_per_frame, lvl1_pca))
flat_feat[mm] = zz  # tokens not selected by mm stay 0
frames_feat = flat_feat.reshape((n_kernel, n_token_per_frame, lvl1_pca))

kernel_frame = vm.getkernelframe()
extent = (0, resolution, resolution, 0)
plt.imshow(kernel_frame[0], extent=extent)
plt.imshow(frames_feat[0].reshape((patch_len, patch_len, lvl1_pca)), extent=extent, alpha=0.5)
plt.axis('off')
plt.colorbar()
plt.show()
plt.close()

In [None]:
# Overlay the layer-1 feature on every kernel frame, 4 panels per row.
# The row count follows the number of kernel frames: a fixed 4x4 grid has only
# 16 slots and plt.subplot raised once kernel_frame_size exceeded 16.
n_cols = 4
n_rows = max(1, (len(kernel_frame) + n_cols - 1) // n_cols)
plt.figure(figsize=(10, 2 * n_rows))  # 2 in per row: same panel size as the old 4x4 / (10, 8) layout
extent = (0, resolution, resolution, 0)
for i, frame in enumerate(kernel_frame):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(frame, extent=extent)
    plt.imshow(frames_feat[i].reshape((patch_len, patch_len, lvl1_pca)), extent=extent, alpha=0.5)
    plt.axis('off')
plt.show()
plt.close()

In [None]:
# Second layer with 3 components; a (H, W, 3) array is drawn by imshow as an RGB image.
lvl2_pca = 3
ff.addLayer(X, lvl2_pca)
mm, zz = ff.getFeature(X)
print(mm.shape, zz.shape)

flat_feat = np.zeros((n_kernel * n_token_per_frame, lvl2_pca))
flat_feat[mm] = zz  # tokens not selected by mm stay 0
frames_feat = flat_feat.reshape((n_kernel, n_token_per_frame, lvl2_pca))

kernel_frame = vm.getkernelframe()
extent = (0, resolution, resolution, 0)
plt.imshow(kernel_frame[0], extent=extent)
plt.imshow(frames_feat[0].reshape((patch_len, patch_len, lvl2_pca)), extent=extent, alpha=0.5)
plt.axis('off')
plt.show()
plt.close()

In [None]:
# Restrict layer 1 with one (low, high) pair per component (3 components).
ff.setLayerThreshold(1, [(.0, .6), (.1, .9), (.6, 1.)])
mm, zz = ff.getFeature(X)
print(mm.shape, zz.shape)
# Scatter the filter output back onto the full token list; tokens not selected by `mm` stay 0.
frames_feat = np.zeros((n_kernel * n_token_per_frame, lvl2_pca))
frames_feat[mm] = zz
frames_feat = frames_feat.reshape((n_kernel, n_token_per_frame, lvl2_pca))
kernel_frame = vm.getkernelframe()
plt.imshow(kernel_frame[0], extent=(0, resolution, resolution, 0))
# With 3 components this overlay is drawn as RGB, so no colormap applies. A colorbar
# here would be meaningless and was removed (the layer-2 fit cell above omits it too).
plt.imshow(frames_feat[0].reshape((patch_len, patch_len, lvl2_pca)), extent=(0, resolution, resolution, 0), alpha=0.5)
plt.axis('off')
plt.show()
plt.close()

In [None]:
# Overlay the layer-2 (RGB) feature on every kernel frame, 4 panels per row.
# The row count follows the number of kernel frames: a fixed 4x4 grid has only
# 16 slots and plt.subplot raised once kernel_frame_size exceeded 16.
n_cols = 4
n_rows = max(1, (len(kernel_frame) + n_cols - 1) // n_cols)
plt.figure(figsize=(10, 2 * n_rows))  # 2 in per row: same panel size as the old 4x4 / (10, 8) layout
extent = (0, resolution, resolution, 0)
for i, frame in enumerate(kernel_frame):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(frame, extent=extent)
    plt.imshow(frames_feat[i].reshape((patch_len, patch_len, lvl2_pca)), extent=extent, alpha=0.5)
    plt.axis('off')
plt.show()
plt.close()

In [None]:
# Third layer: back to a single component, drawn over the first kernel frame.
lvl3_pca = 1
ff.addLayer(X, lvl3_pca)
mm, zz = ff.getFeature(X)
print(mm.shape, zz.shape)

flat_feat = np.zeros((n_kernel * n_token_per_frame, lvl3_pca))
flat_feat[mm] = zz  # tokens not selected by mm stay 0
frames_feat = flat_feat.reshape((n_kernel, n_token_per_frame, lvl3_pca))

kernel_frame = vm.getkernelframe()
extent = (0, resolution, resolution, 0)
plt.imshow(kernel_frame[0], extent=extent)
plt.imshow(frames_feat[0].reshape((patch_len, patch_len, lvl3_pca)), extent=extent, alpha=0.5)
plt.axis('off')
plt.show()
plt.close()

In [None]:
# Restrict layer 2 (one component) to a single (low, high) range and redraw the overlay.
ff.setLayerThreshold(2, [(.35, .55)])
mm, zz = ff.getFeature(X)
print(mm.shape, zz.shape)

flat_feat = np.zeros((n_kernel * n_token_per_frame, lvl3_pca))
flat_feat[mm] = zz  # tokens not selected by mm stay 0
frames_feat = flat_feat.reshape((n_kernel, n_token_per_frame, lvl3_pca))

kernel_frame = vm.getkernelframe()
extent = (0, resolution, resolution, 0)
plt.imshow(kernel_frame[0], extent=extent)
plt.imshow(frames_feat[0].reshape((patch_len, patch_len, lvl3_pca)), extent=extent, alpha=0.5)
plt.axis('off')
plt.colorbar()
plt.show()
plt.close()

In [None]:
# Overlay the layer-3 feature on every kernel frame, 4 panels per row.
# The row count follows the number of kernel frames: a fixed 4x4 grid has only
# 16 slots and plt.subplot raised once kernel_frame_size exceeded 16.
n_cols = 4
n_rows = max(1, (len(kernel_frame) + n_cols - 1) // n_cols)
plt.figure(figsize=(10, 2 * n_rows))  # 2 in per row: same panel size as the old 4x4 / (10, 8) layout
extent = (0, resolution, resolution, 0)
for i, frame in enumerate(kernel_frame):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(frame, extent=extent)
    plt.imshow(frames_feat[i].reshape((patch_len, patch_len, lvl3_pca)), extent=extent, alpha=0.5)
    plt.axis('off')
plt.show()
plt.close()

In [None]:
# Fourth layer with 5 components; its thresholds are set, and the result plotted, in the next cell.
lvl4_pca = 5
ff.addLayer(
    X,
    lvl4_pca,
)

In [None]:
# Open every layer-3 range to the full [0, 1] interval (one pair per component,
# lvl4_pca == 5) and show a single component of the result.
channel = 3
ff.setLayerThreshold(3, [(.0, 1.)] * 5)
mm, zz = ff.getFeature(X)
print(mm.shape, zz.shape)

flat_feat = np.zeros((n_kernel * n_token_per_frame, lvl4_pca))
flat_feat[mm] = zz  # tokens not selected by mm stay 0
frames_feat = flat_feat.reshape((n_kernel, n_token_per_frame, lvl4_pca))

kernel_frame = vm.getkernelframe()
extent = (0, resolution, resolution, 0)
plt.imshow(kernel_frame[0], extent=extent)
plt.imshow(frames_feat[0].reshape((patch_len, patch_len, lvl4_pca))[:, :, channel], extent=extent, alpha=0.5)
plt.axis('off')
plt.colorbar()
plt.show()
plt.close()

In [None]:
# Overlay component 3 of the layer-4 feature on every kernel frame, 4 panels per row.
# The row count follows the number of kernel frames: a fixed 4x4 grid has only
# 16 slots and plt.subplot raised once kernel_frame_size exceeded 16.
n_cols = 4
n_rows = max(1, (len(kernel_frame) + n_cols - 1) // n_cols)
plt.figure(figsize=(10, 2 * n_rows))  # 2 in per row: same panel size as the old 4x4 / (10, 8) layout
extent = (0, resolution, resolution, 0)
for i, frame in enumerate(kernel_frame):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(frame, extent=extent)
    plt.imshow(frames_feat[i].reshape((patch_len, patch_len, lvl4_pca))[:, :, 3], extent=extent, alpha=0.5)
    plt.axis('off')
plt.show()
plt.close()

In [None]:
import glob
import os
from tqdm import tqdm

class Watcher:
    """Render a feature-filter overlay for every frame of a video and save each as a JPEG.

    For each frame the notebook's ``img2tensor`` transform is applied, patch features
    are extracted with ``patchFeatureGenerator.single_run``, passed through
    ``featureFilter.getFeature`` and scattered into a ``(patch_len, patch_len, n_pca)``
    grid. One component of that grid is drawn over the frame and saved.

    NOTE: relies on the notebook globals ``img2tensor``, ``resolution`` and ``patch_len``.

    Args:
        videoManager: exposes ``full_frame`` (its length is the frame count) and ``getframe(i)``.
        patchFeatureGenerator: exposes ``model.num_features`` and ``single_run(tensor)``.
        featureFilter: exposes ``blk`` (layers with an ``n_pca`` attribute) and
            ``getFeature(X) -> (mm, zz)``.
        export: output directory; overlays are written to ``<export>/attn/``.
        channel: index of the last layer's component to draw (default 3, the component
            visualised in the layer-4 notebook cells).
    """

    def __init__(self, videoManager, patchFeatureGenerator, featureFilter, export, channel=3):
        self.frames = videoManager
        self.pfg = patchFeatureGenerator
        self.ff = featureFilter
        self.export = export
        # os.path.join works whether or not `export` ends with a separator (plain
        # concatenation gave e.g. 'resultsattn/'); the trailing '' keeps the trailing
        # separator the original 'attn/' value had.
        self.attn_dir = os.path.join(export, 'attn', '')
        self.num_features = self.pfg.model.num_features
        self.channel = channel

    def run(self):
        """Write one ``attn-XXXX.jpg`` overlay (zero-padded frame index) per frame into ``self.attn_dir``.

        Raises:
            AssertionError: if the feature filter has no layers, or ``self.channel`` is
                out of range for the last layer's number of components.
        """
        # Validate before indexing self.ff.blk[-1]: on an empty filter that access used to
        # raise a bare IndexError before the intended message could be reported.
        assert len(self.ff.blk) != 0, "no filter in featureFilter."
        last_dim_pca = self.ff.blk[-1].n_pca
        assert -last_dim_pca <= self.channel < last_dim_pca, (
            f"channel {self.channel} out of range for a last layer with {last_dim_pca} components."
        )

        os.makedirs(self.export, exist_ok=True)
        os.makedirs(self.attn_dir, exist_ok=True)
        extent = (0, resolution, resolution, 0)

        for i in tqdm(range(len(self.frames.full_frame))):
            frame = self.frames.getframe(i)
            patch_feature = self.pfg.single_run(img2tensor(frame))
            X = patch_feature.reshape((-1, self.num_features))
            mm, zz = self.ff.getFeature(X)

            # One row per token of *this* frame, not the notebook-global n_token_per_frame.
            # Tokens not selected by `mm` stay 0.
            frames_feat = np.zeros((X.shape[0], last_dim_pca))
            frames_feat[mm] = zz
            overlay = frames_feat.reshape((patch_len, patch_len, last_dim_pca))[:, :, self.channel]

            # Explicit figure handle so exactly this figure is closed on every iteration.
            fig, ax = plt.subplots()
            ax.imshow(frame, extent=extent)
            ax.imshow(overlay, extent=extent, alpha=0.3, cmap="inferno")
            ax.axis('off')
            fig.savefig(os.path.join(self.attn_dir, 'attn-{:04d}.jpg'.format(i)), bbox_inches='tight')
            plt.close(fig)



wa = Watcher(videoManager=vm, patchFeatureGenerator=pfg, featureFilter=ff, export=out_dir)
wa.run()  # writes one attn-XXXX.jpg overlay per video frame into wa.attn_dir

In [None]:
import cv2
from PIL import Image
video_format = 'mp4'
# Frame rate of the generated video.
# NOTE(review): hard-coded; presumably it should match the source video's fps — confirm.
fps = 80

# FourCC codec tag per container. Lower-case 'mp4v' is the tag OpenCV's FFmpeg backend
# accepts for .mp4; upper-case 'MP4V' makes it warn "tag ... is not supported" and fall back to 'mp4v'.
FOURCC = {
    "mp4": cv2.VideoWriter_fourcc(*"mp4v"),
    "avi": cv2.VideoWriter_fourcc(*"XVID"),
}


def generate_video_from_images(inp: str, out: str):
    """Encode every ``*.jpg`` in ``inp`` (sorted by filename) into ``<out>/video.<video_format>``.

    Uses the notebook-level ``video_format``, ``FOURCC`` and ``fps`` settings. The video
    frame size is taken from the first image. Any frame with a different size is resized
    to match, because ``cv2.VideoWriter`` silently drops frames whose size differs from
    the size it was opened with.

    Args:
        inp: directory containing the frame images.
        out: existing directory the video file is written into.

    Raises:
        FileNotFoundError: if ``inp`` contains no ``*.jpg`` files.
        RuntimeError: if OpenCV cannot open the video writer.
    """
    attention_images_list = sorted(glob.glob(os.path.join(inp, "*.jpg")))
    if not attention_images_list:
        raise FileNotFoundError(f"No .jpg frames found in {inp!r}")

    # Get size of the first image
    with Image.open(attention_images_list[0]) as img:
        size = (img.width, img.height)

    print(f"Generating video {size} to {out}")

    video_path = os.path.join(out, "video." + video_format)
    writer = cv2.VideoWriter(video_path, FOURCC[video_format], fps, size)
    if not writer.isOpened():
        raise RuntimeError(f"cv2.VideoWriter could not open {video_path!r}")

    try:
        # Stream frames one at a time, including the first image (previously skipped
        # by iterating over [1:]), instead of buffering the whole video in memory.
        for filename in tqdm(attention_images_list):
            with Image.open(filename) as img:
                frame = cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2BGR)
            if (frame.shape[1], frame.shape[0]) != size:
                frame = cv2.resize(frame, size)  # dsize is (width, height)
            writer.write(frame)
    finally:
        writer.release()
    print("Done")

# Stitch the per-frame overlays into <out_dir>/video.<video_format>.
generate_video_from_images(inp=wa.attn_dir, out=out_dir)