In [27]:
from PIL import Image
from julia.api import Julia
from dpmmpythonStreaming.dpmmwrapper import DPMMPython
from dpmmpythonStreaming.priors import niw
from utils import (
    gif_to_stream,
    stream_to_np_array,
    transpose_np_array,
    stream_to_gif,
    np_array_to_stream
)
import numpy as np
from typing import List
# Start the embedded Julia runtime (presumably the backend that the
# dpmmpythonStreaming wrapper drives — TODO confirm).
# NOTE(review): compiled_modules=False is PyJulia's documented workaround for
# statically linked Python interpreters; confirm it is still required here.
jl = Julia(compiled_modules=False)

## First load the gif into a stream of images

In [2]:
# Input GIF; the output path of the clustered GIF is derived from this string
# in the final cell.
gif_path = "./gifs/border_collie_running.gif"
# Decode the GIF into a list of PIL images, one per frame.
frames: List[Image.Image] = gif_to_stream(gif_path)
# resize if needed
# NOTE(review): resize_stream is not among the names imported from utils in the
# import cell; add it there before enabling the line below.
# frames = resize_stream(frames, (200, 150))
print(f"Loaded {len(frames)} frames from {gif_path} with size {frames[0].size}")

Loaded 8 frames from ./gifs/border_collie_running.gif with size (300, 213)


## Change data structure to fit the DPMMPython interface

In [3]:
# Stack all frames into one float32 array. The printed shape below is
# (8, 63900, 3), i.e. presumably (frames, pixels, colour channels).
data = stream_to_np_array(frames, dtype=np.float32)
print(f"Data shape: {data.shape}")
# Transpose every frame so the last two axes swap; the printed shape becomes
# (8, 3, 63900). Presumably the DPMM wrapper expects (dimension, samples) per
# frame — TODO confirm against the wrapper.
data = transpose_np_array(data)
# Kept as a separate name because shape[1] (the per-sample dimension) sizes the
# prior in a later cell.
shape = data.shape
print(f"Data shape after transpose of each frame: {shape}")

Data shape: (8, 63900, 3)
Data shape after transpose of each frame: (8, 3, 63900)


## Create an NIW prior suited to images, along with the other required parameters
###### Note: in the NIW definition, `nu` should be greater than the dimension of the data, `kappa` should be greater than 0, `mu` should be a vector with the same dimension as the data, and `phi` should be a square matrix whose side length equals the dimension of the data.

In [4]:
# NIW prior sized from the per-sample dimension shape[1] (3 for the data shown).
# Positional arguments are presumably (kappa, mu, nu, scale matrix) in that
# order — TODO confirm against the niw signature.
# NOTE(review): the third argument is the literal 3, which equals shape[1] for
# this data, whereas the markdown note above asks for nu to be greater than the
# data dimension — confirm which is intended.
prior = niw(1, np.zeros(shape[1]), 3, np.eye(shape[1]) * 0.5)
# Value passed as `alpha` when the model is initialised.
alpha = 10.0
# The first frame initialises the model; the remaining frames are streamed in
# one at a time afterwards.
head_frame = data[0]
tail_frames = data[1:]

# Accumulates one label array per frame, in frame order.
labels_list = []

## Initialize the model with the first frame

In [5]:
# Fit the initial model on the first frame only. The keyword values are kept
# exactly as before; their precise meaning is defined by DPMMPython.fit_init.
model = DPMMPython.fit_init(
    data=head_frame,
    alpha=alpha,
    prior=prior,
    verbose=True,
    burnout=5,
    gt=None,
    epsilon=0.0000001,
)
labels = np.array(DPMMPython.get_labels(model))
# Rebind instead of append: the model has just been re-initialised, so any
# labels collected earlier are stale. This keeps the cell idempotent when it is
# re-run, rather than leaving a duplicate first-frame entry that would misalign
# labels and frames downstream.
labels_list = [labels]

## Fit the model to the rest of the frames

In [58]:
# Debug inspection: show the μ attribute of the first local cluster's
# distribution (a 3-element float32 array in the saved output; presumably the
# cluster mean in colour space — TODO confirm).
model.group.local_clusters[0].cluster_params.cluster_params.distribution.μ

array([137.8177  , 121.021324, 100.37352 ], dtype=float32)

In [25]:
# Stream the remaining frames through the model one at a time and record the
# labels produced after each update.
# NOTE(review): iterations=1 and t=2 are carried over unchanged; confirm their
# meaning against DPMMPython.fit_partial.
tail_labels = []
for frame in tail_frames:
    model = DPMMPython.fit_partial(model=model, iterations=1, t=2, data=frame)
    labels = np.array(DPMMPython.get_labels(model))
    tail_labels.append(labels)

# Replace everything after the first-frame entry instead of appending, so
# re-running this cell cannot make labels_list longer than the number of
# frames (which would break the per-frame indexing of `data` later on).
# Note that `model` itself still carries its state across re-runs.
labels_list[1:] = tail_labels

In [26]:
# Sanity check: there should be one label array per frame and one label per
# pixel (the saved output shows (8, 63900)).
print (f"labels list shape: {np.array(labels_list).shape}")
# Peek at the labels assigned to the first frame.
print(labels_list[0])

labels list shape: (8, 63900)
[ 3 36 36 ...  1  1  1]


## Create centroids from the labels and replace each pixel with the centroid of its cluster

In [34]:
def create_clustered_data(data, labels) -> np.ndarray:
    """Replace every sample with the mean of the cluster it was assigned to.

    Args:
        data: NumPy array whose first axis indexes samples (called below with
            one row per pixel and one column per colour channel).
        labels: Array of cluster ids, one per sample along the first axis of
            ``data``.

    Returns:
        A new float64 array with the same shape as ``data`` in which each
        sample holds the centroid (mean along axis 0) of its cluster.
    """
    quantized = np.zeros(data.shape)
    for cluster_id in np.unique(labels):
        # Build the membership mask once and use it for both read and write.
        members = labels == cluster_id
        quantized[members] = data[members].mean(axis=0)
    return quantized

# Quantize each frame: transpose it back to one row per pixel and swap every
# pixel for the centroid of its cluster. Frames are looked up by label
# position, so labels_list must not be longer than data.
clustered_data_array = [
    create_clustered_data(data[frame_idx].T, frame_labels)
    for frame_idx, frame_labels in enumerate(labels_list)
]
clustered_data_stream = np.array(clustered_data_array)
print(f"Clustered data stream shape: {clustered_data_stream.shape}")

Clustered data stream shape: (8, 63900, 3)


## Reconstruct a gif from the clustered data

In [40]:
# PIL reports Image.size as (width, height); swap it to (height, width).
# Presumably np_array_to_stream reshapes each flat frame with NumPy's
# row-major (rows, columns) convention — TODO confirm against the helper.
size = (frames[0].size[1], frames[0].size[0])
# Convert the per-frame centroid arrays back to images, requesting uint8.
clustered_frames = np_array_to_stream(clustered_data_array, size=size, dtype=np.uint8)
# Write next to the input, e.g. "x.gif" -> "x_clustered.gif".
# NOTE(review): str.replace rewrites every ".gif" occurrence in the path, not
# just the extension; fine for the current path, fragile for others.
stream_to_gif(clustered_frames, gif_path.replace(".gif", "_clustered.gif"))