In [None]:
import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.lines import Line2D

from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms


# If the kernel was started inside notebooks/, step up one directory
# (presumably the repository root) so that project imports such as
# `notebooks.notebook_helpers` and relative paths like "dummy_ldm_config.yaml"
# resolve from there.
if os.path.split(os.getcwd())[1] == "notebooks":
    os.chdir(os.pardir)

from pollen_datasets.poleno.registry import register_condition_fn
from pollen_datasets.poleno import HolographyImageFolder, PairwiseHolographyImageFolder
from notebooks.notebook_helpers import plot_unit_sphere_with_orientation_arrows, orientation_to_basis, plot_dual_camera_planes

In [None]:
def load(config_file):
    """Load and parse a YAML configuration file.

    Args:
        config_file: Path to the YAML file.

    Returns:
        The parsed YAML document. In this notebook it is expected to be a
        mapping with "dataset" and "conditioning" sections. An empty file
        yields ``None``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    with open(config_file, "r", encoding="utf-8") as stream:
        # Parse errors are deliberately not caught. Printing and returning None
        # would only surface later as a confusing TypeError at the first
        # config["..."] lookup, far from the malformed file.
        return yaml.safe_load(stream)
    
# Load the experiment configuration and pull out the sections used below.
config = load("dummy_ldm_config.yaml")
dataset_config = config["dataset"]
condition_config = config["conditioning"]


# Image transformations: PIL -> tensor, optional square resize, then
# Normalize(0.5, 0.5), which maps [0, 1] inputs to [-1, 1].
transforms_list = [transforms.ToTensor()]

if dataset_config.get("img_interpolation"):
    transforms_list.append(
        transforms.Resize(
            (dataset_config["img_interpolation"],) * 2,
            interpolation=transforms.InterpolationMode.BILINEAR,
        )
    )

transforms_list.append(
    transforms.Normalize(
        mean=[0.5] * dataset_config["img_channels"],
        std=[0.5] * dataset_config["img_channels"],
    )
)

transform = transforms.Compose(transforms_list)

In [None]:
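# One-off preprocessing utility (disabled): flatten an RGBA PNG onto a white
# background and overwrite the original file in place.
# Note: Pillow's PNG writer ignores `quality=`; it only affects lossy formats.
# from PIL import Image

# png = Image.open("./data/dummy_images/test_img_1.png").convert("RGBA")
# png.load() # required for png.split()

# background = Image.new("RGB", png.size, (255, 255, 255))
# background.paste(png, mask=png.split()[3]) # 3 is the alpha channel

# background.save("./data/dummy_images/test_img_1.png", 'PNG', quality=80)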

In [None]:
@register_condition_fn("relative_viewpoint_rotation")
def relative_viewpoint_rotation(val0, val1, meta, r=0.55):
    """
    Zero-1-to-3 style encoding of the relative viewpoint between the two
    orthogonal camera images of a pair, after a global rotation of the pair.

    Before rotation, image1 is treated as bordering image0 on the left, i.e. a
    direction vector of (-1, 0). That vector is rotated by
    ``meta["rotation_deg"]`` and classified as horizontal (left/right) or
    vertical (top/bottom):

        rotation   direction   Δθ      φ
        0°         left        0       -π/2
        90°        top         -π/2    0
        180°       right       0       +π/2
        270°       bottom      +π/2    0

    Args:
        val0, val1: Per-image condition values. They are unused here and
            presumably required by the condition-function calling convention.
        meta: Mapping with an optional "rotation_deg" entry (degrees, default
            0). It is intended to be a multiple of 90.
        r: Constant last component written into both encodings.

    Returns:
        (out0, out1): two float32 arrays of the form
        ``[Δθ, sin(φ), cos(φ), r]`` (angles in radians). ``out0`` is always
        the reference encoding ``[0, 0, 1, r]``. ``out1`` encodes the rotated
        relative direction.
    """
    # No print here: this function runs once per dataset sample, so a debug
    # print would flood the notebook output.
    rot_deg = meta.get("rotation_deg", 0) % 360
    rot_rad = np.deg2rad(rot_deg)

    # Initial relative direction of image1: left -> vector (-1, 0).
    v = np.array([-1.0, 0.0], dtype=np.float32)

    # Standard 2D rotation matrix.
    R = np.array([
        [np.cos(rot_rad), -np.sin(rot_rad)],
        [np.sin(rot_rad),  np.cos(rot_rad)]
    ], dtype=np.float32)

    x, y = R @ v

    # At multiples of 90°, one component is ±1 and the other is tiny
    # floating-point residue. Comparing magnitudes is therefore insensitive to
    # that residue.
    if abs(x) > abs(y):
        # left / right bordering
        phi = np.sign(x) * (np.pi / 2)   # -π/2 (left) or +π/2 (right)
        theta = 0.0                      # Δθ = 0 for horizontal
    else:
        # top / bottom bordering
        phi = 0.0                        # φ = 0 for vertical
        theta = np.sign(y) * (np.pi / 2) # -π/2 (top) or +π/2 (bottom)

    # image0 -> always the reference encoding
    out0 = np.array([0.0, np.sin(0.0), np.cos(0.0), r], dtype=np.float32)

    # image1 -> encoding of the rotated relative direction
    out1 = np.array([theta, np.sin(phi), np.cos(phi), r], dtype=np.float32)

    return out0, out1


In [None]:
from pollen_datasets.poleno.transforms import SwapRotate180, RotatePairKx90
from notebooks.notebook_helpers import plot_unit_sphere_with_orientation_arrows, plot_dual_camera_planes

# Dataset
# Pairwise dataset: each item holds two camera views of one sample.
# RotatePairKx90(p=1) is applied to every pair; it presumably rotates both
# views together by a multiple of 90° (confirm in pollen_datasets).
dataset = PairwiseHolographyImageFolder(
    root=dataset_config["root"],
    labels=dataset_config["labels_train"],
    dataset_cfg=dataset_config,
    cond_cfg=condition_config,
    transform=transform,
    pair_transform=RotatePairKx90(p=1),
    verbose=True,
)

In [None]:
# Pull the first pair from the dataset. Each item unpacks into
# (images, condition dicts, file paths), with one entry per view.
loader = iter(dataset)
(img0, img1), (cond0, cond1), (filepath0, filepath1) = next(loader)

# Per-view metadata attached to the condition dicts.
meta0, meta1 = cond0["meta"], cond1["meta"]

def unnormalize(t, mean=0.5, std=0.5):
    """Invert ``transforms.Normalize`` for the given mean and std.

    With the defaults (mean=std=0.5) this maps [-1, 1] back to [0, 1].
    Works on scalars, numpy arrays and torch tensors via plain arithmetic.
    """
    return mean + std * t

# Undo the Normalize(0.5, 0.5) scaling and convert CHW tensors to HWC numpy
# arrays, the layout matplotlib's imshow expects.
img0 = unnormalize(img0).numpy().transpose(1, 2, 0)
img1 = unnormalize(img1).numpy().transpose(1, 2, 0)

# Show both views side by side with labels. Single-channel images come out as
# (H, W, 1). Drop that axis and render in gray with a fixed [0, 1] range, so
# they are not shown with matplotlib's default colormap and per-image autoscaling.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, img, title in zip(axes, (img0, img1), ("image 0", "image 1")):
    if img.shape[-1] == 1:
        ax.imshow(img[..., 0], cmap="gray", vmin=0.0, vmax=1.0)
    else:
        ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
# Reference encodings [Δθ, sin(φ), cos(φ), r], matching the cases in
# relative_viewpoint_rotation (default r = 0.55):
#   image 0 (reference):       (0,      sin(0),     cos(0),     0.55)
#   image 1, left (default):   (0,      sin(-π/2),  cos(-π/2),  0.55)
#   image 1, right:            (0,      sin(π/2),   cos(π/2),   0.55)
#   image 1, bottom:           (π/2,    sin(0),     cos(0),     0.55)
#   image 1, top:              (-π/2,   sin(0),     cos(0),     0.55)

plot_dual_camera_planes(
    img0,
    img1,
    orientation0=cond0["rotation"],
    orientation1=cond1["rotation"],
    stride=3,
    invert_z=False,
)