In [1]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
import random
from PIL import Image, ImageFilter
import io

import gaussian_mixture


def dump_to_pickle(data, file_path: str):
    try:
        with open(file_path, "wb") as file:
            pickle.dump(data, file)
        print(f"Data successfully dumped to {file_path}")
    except Exception as e:
        print(f"An error occurred while dumping data: {e}")


def load_from_pickle(file_path: str):
    try:
        with open(file_path, "rb") as file:
            data = pickle.load(file)
        print(f"Data successfully loaded from {file_path}")
        return data
    except Exception as e:
        print(f"An error occurred while loading data: {e}")
        return None


img_frs_seq = load_from_pickle("img_frs_seq.pkl")[:10]
msk_frs_seq = load_from_pickle("msk_frs_seq.pkl")[:10]
memo_flows_fwd = load_from_pickle("memo_flows_fwd.pkl")[:10]
memo_flows_bwd = load_from_pickle("memo_flows_bwd.pkl")[:10]

Data successfully loaded from img_frs_seq.pkl
Data successfully loaded from msk_frs_seq.pkl
Data successfully loaded from memo_flows_fwd.pkl
Data successfully loaded from memo_flows_bwd.pkl


In [3]:
print(msk_frs_seq[0].shape)
msk_frs_seq[0]

(540, 960)


array([[  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0],
       ...,
       [255, 255, 255, ...,   0,   0,   0],
       [255, 255, 255, ...,   0,   0,   0],
       [255, 255, 255, ...,   0,   0,   0]], dtype=uint8)

In [4]:
print(memo_flows_fwd[0].shape)
memo_flows_fwd[0]

(540, 960, 2)


array([[[-4.132888  , -0.46109635],
        [-4.1293077 , -0.470209  ],
        [-4.127161  , -0.47790512],
        ...,
        [-4.3802576 , -0.28641346],
        [-4.368211  , -0.2896058 ],
        [-4.3526516 , -0.29223448]],

       [[-4.125264  , -0.46568415],
        [-4.1210237 , -0.47342286],
        [-4.118477  , -0.48078793],
        ...,
        [-4.3693104 , -0.2865148 ],
        [-4.3581553 , -0.2894902 ],
        [-4.3402147 , -0.2917064 ]],

       [[-4.1255813 , -0.46747857],
        [-4.119181  , -0.4753293 ],
        [-4.1161637 , -0.48072743],
        ...,
        [-4.354258  , -0.28574163],
        [-4.3425856 , -0.28841752],
        [-4.325902  , -0.28950787]],

       ...,

       [[-3.9279518 , -0.39873332],
        [-3.928284  , -0.39801285],
        [-3.9285238 , -0.3969612 ],
        ...,
        [-3.8778856 , -0.68054926],
        [-3.8709984 , -0.6824064 ],
        [-3.8612819 , -0.6832982 ]],

       [[-3.9284027 , -0.39916244],
        [-3.9290943 , -0.39

In [2]:
output_from, output_to = gaussian_mixture.generate(
    msk_frs_seq[0], memo_flows_fwd[0], 10, 10, 20, 0
)

In [3]:
output_from[0]

(175.03506469726562, 267.6099853515625)

In [4]:
print(output_to[:10])

[(170.95513916015625, 267.1158447265625), (186.9873046875, 257.1708679199219), (15.52493667602539, 331.8918762207031), (191.7922821044922, 276.1449279785156), (61.568416595458984, 312.13775634765625), (158.15711975097656, 279.3189392089844), (146.9554901123047, 281.4184875488281), (27.887954711914062, 343.26544189453125), (169.16452026367188, 278.46148681640625), (186.54273986816406, 292.6495361328125)]


In [5]:
len(output_to)

1321

In [23]:
def generate_gaussian_plot(
    output: list[tuple[float, float]],
    height: int,
    width: int,
    colors,
    radius: float = 10.0,
    save_png=False,
    save_path="output.png",
    show=False,
    blur_radius=2.0,
):
    fig = plt.figure(figsize=(width / 100, height / 100), dpi=100)
    ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
    fig.add_axes(ax)

    ax.set_xlim(0, width)
    ax.set_ylim(0, height)
    ax.set_aspect("equal")

    ax.axis("off")

    xs, ys = zip(*output)
    ys = [height - y for y in ys]  # Flip y-coordinates

    ax.scatter(xs, ys, s=np.pi * radius**1.8, c=colors, marker="o", edgecolor="none")

    fig.canvas.draw()

    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)

    buf.seek(0)

    image_pil = Image.open(buf)
    image_pil = image_pil.convert("RGB")

    image_blurred = image_pil.filter(ImageFilter.GaussianBlur(blur_radius))

    if save_png:
        image_blurred.save(save_path, format="png")
    image_blurred = np.array(image_blurred)

    buf.close()

    if show:
        plt.imshow(image_blurred)
    else:
        plt.close(fig)
    return image_blurred

In [27]:
random.seed(0)
random_colors = np.random.rand(len(output_from), 3)

In [29]:
image_rgb = generate_gaussian_plot(
    output_from,
    540,
    960,
    random_colors,
    radius=10,
    save_png=True,
    save_path="cool.png",
    blur_radius=3.0,
)

In [30]:
image_rgb = generate_gaussian_plot(
    output_to,
    540,
    960,
    random_colors,
    radius=10,
    save_png=True,
    save_path="cool1.png",
    blur_radius=3.0,
)