In [1]:
import os

# Snapshot two absolute paths while the working directory is still the one the
# kernel started in (normally the notebook's folder): that directory itself,
# and the Debug build output four levels up, which later cells chdir into.
file_path = os.path.abspath(os.curdir)
execution_path = os.path.abspath(
    os.path.join(os.pardir, os.pardir, os.pardir, os.pardir, "Binaries", "Debug")
)

In [None]:
import os
import sys

# Later cells chdir into the binaries directory, so first return to the folder
# captured in cell 1; the relative paths below are resolved against it.
os.chdir(file_path)

# Native build output (compiled extension modules) and the local Python package.
target_path = os.path.abspath("../../../../Binaries/Debug")
package_path = os.path.abspath("../python")

# Append each path only once so re-running this cell does not keep growing
# sys.path with duplicate entries.
for path in (target_path, package_path):
    if path in sys.path:
        print(f"{path} already on sys.path")
    else:
        sys.path.append(path)
        print(f"Added {path} to sys.path")

In [3]:
import glints.scratch_grid
import glints.renderer
import glints.test_utils as test_utils
import torch
import numpy as np
import pytest

In [None]:
import importlib

# Development helper: re-execute glints.scratch_grid so edits to that module take
# effect without restarting the kernel. On a fresh Restart & Run All this only
# repeats the import done a cell earlier.
# NOTE(review): objects created before a reload (e.g. an existing ScratchField)
# keep using the old module's class definitions; re-create them after reloading.
importlib.reload(glints.scratch_grid)

In [5]:
def linear_to_gamma(image, gamma=2.2):
    """Encode linear values with a simple power-law gamma curve.

    Computes ``image ** (1 / gamma)``. The default ``gamma=2.2`` is the common
    display-gamma approximation, not the exact piecewise sRGB transfer function.

    Args:
        image: scalar, NumPy array or torch tensor of linear values. Expected to
            be non-negative: a fractional power of a negative value is not real
            (NaN for float arrays/tensors).
        gamma: exponent of the encoding curve; must be non-zero.

    Returns:
        ``image ** (1 / gamma)``.
    """
    return image ** (1.0 / gamma)


def gamma_to_linear(image, gamma=2.2):
    """Decode power-law gamma-encoded values back to linear.

    Computes ``image ** gamma``, the inverse of a ``x ** (1 / gamma)``
    encoding. The default ``gamma=2.2`` is the common display-gamma
    approximation, not the exact piecewise sRGB transfer function.

    Args:
        image: scalar, NumPy array or torch tensor of gamma-encoded values;
            expected to be non-negative.
        gamma: decoding exponent.

    Returns:
        ``image ** gamma``.
    """
    return image**gamma


def render_and_save_field(
    field,
    resolution,
    filename,
    camera_position=(4.0, 0.1, 2.5),
    light_position=(4.0, -0.1, 2.5),
    fov_in_degrees=35,
    width=0.001,
    device="cuda",
):
    """Render ``field`` on the plane-board scene with a fresh renderer and save it.

    Builds a new ``glints.renderer.Renderer`` each call, so the notebook's
    global renderer is not modified.

    Args:
        field: scratch field passed to ``glints.scratch_grid.render_scratch_field``.
        resolution: render resolution; ``resolution[0] / resolution[1]`` is used
            as the aspect ratio (presumably ``[width, height]``).
        filename: output path handed to ``test_utils.save_image``.
        camera_position: camera position (converted to a float32 NumPy array).
        light_position: light position (converted to a torch tensor on ``device``).
        fov_in_degrees: field of view in degrees, converted to radians.
        width: value passed to ``Renderer.set_width`` (presumably scratch width).
        device: torch device for the light and width tensors.
    """
    renderer = glints.renderer.Renderer()
    vertices, indices = glints.renderer.plane_board_scene_vertices_and_indices()
    renderer.set_camera_position(np.asarray(camera_position, dtype=np.float32))
    # 0.1 / 1000.0 are presumably the near/far clip distances.
    renderer.set_perspective(
        np.pi * fov_in_degrees / 180.0, resolution[0] / resolution[1], 0.1, 1000.0
    )
    renderer.set_mesh(vertices, indices)
    renderer.set_light_position(torch.tensor(light_position, device=device))

    renderer.set_width(torch.tensor([width], device=device))

    # The sampled-cell mask is not needed for a plain render.
    image, _ = glints.scratch_grid.render_scratch_field(renderer, resolution, field)
    test_utils.save_image(image, resolution, filename)


import matplotlib.pyplot as plt


def optimize_field(
    field,
    renderer,
    resolution,
    target_image,
    loss_fn,
    regularization_loss_fn,
    regularizer,
    optimizer,
    warmup_steps=150,
    optimization_steps=500,
    regularization_target_ratio=0.1,
    max_regularization_steps=1000,
):
    """Fit a scratch field to a target image, interleaving regularization steps.

    Phase 1 (warm-up): ``warmup_steps`` steps of ``regularizer`` on the
    divergence + smoothness loss, calling ``field.fix_direction()`` after each.
    The regularization loss measured before the first step is the baseline.

    Phase 2 (fitting): ``optimization_steps`` iterations. Each iteration:

    - takes one ``optimizer`` step on ``loss_fn(render, gamma_to_linear(target)) * 1000``
      plus a density penalty (0.01 * mean field norm at the sampled cells);
    - calls ``field.fill_masked_holes`` on the sampled mask;
    - runs ``regularizer`` steps (at least one) until the regularization loss
      is no longer above ``regularization_target_ratio * baseline``, or
      ``max_regularization_steps`` steps have been taken.

    ``field.fix_direction()`` is called once more at the end.

    Args:
        field: scratch field exposing ``field`` (the optimized tensor),
            ``calc_divergence_smoothness``, ``fix_direction`` and
            ``fill_masked_holes``.
        renderer: renderer passed to ``glints.scratch_grid.render_scratch_field``.
        resolution: render resolution forwarded to the renderer.
        target_image: target image; passed through ``gamma_to_linear`` before
            comparison (presumably gamma-encoded - TODO confirm).
        loss_fn: image loss, called as ``loss_fn(rendered, target)``.
        regularization_loss_fn: called as ``fn(x, zeros_like(x))`` on the
            divergence and smoothness tensors.
        regularizer: optimizer used for regularization steps.
        optimizer: optimizer used for image-fitting steps.
        warmup_steps: number of phase-1 regularization steps.
        optimization_steps: number of phase-2 iterations.
        regularization_target_ratio: fraction of the baseline regularization
            loss to reach after each fitting step.
        max_regularization_steps: cap on regularizer steps per fitting
            iteration. Without it, an unreachable target (e.g. a baseline of 0,
            which requires the loss to reach exactly 0) loops forever.

    Raises:
        ValueError: if ``max_regularization_steps`` is smaller than 1.
    """
    if max_regularization_steps < 1:
        raise ValueError("max_regularization_steps must be at least 1")

    def calculate_regularization_loss():
        # Push both divergence and smoothness of the field towards zero.
        divergence, smoothness = field.calc_divergence_smoothness()
        loss_divergence = regularization_loss_fn(
            divergence, torch.zeros_like(divergence)
        )
        loss_smoothness = regularization_loss_fn(
            smoothness, torch.zeros_like(smoothness)
        )
        return loss_divergence + loss_smoothness

    # Phase 1: regularization-only warm-up.
    baseline_regularization_loss = None
    for _ in range(warmup_steps):
        regularizer.zero_grad()
        regularization_loss = calculate_regularization_loss()
        if baseline_regularization_loss is None:
            # Loss before any regularization step: reference level for phase 2.
            baseline_regularization_loss = regularization_loss.item()
        regularization_loss.backward()
        regularizer.step()
        field.fix_direction()

    if baseline_regularization_loss is None:
        # warmup_steps == 0: still need a reference level for phase 2.
        baseline_regularization_loss = calculate_regularization_loss().item()

    regularization_threshold = (
        baseline_regularization_loss * regularization_target_ratio
    )
    # The target never changes, so convert it to linear space once.
    target_linear = gamma_to_linear(target_image)

    # Phase 2: image fitting, re-regularizing after every step.
    for iteration in range(optimization_steps):
        optimizer.zero_grad()
        image, sampled_mask = glints.scratch_grid.render_scratch_field(
            renderer, resolution, field
        )
        loss_image = loss_fn(image, target_linear) * 1000
        density_loss = torch.mean(
            torch.norm(field.field[sampled_mask].reshape(-1, 2), dim=1) * 0.01
        )
        total_loss = loss_image + density_loss
        total_loss.backward()
        optimizer.step()
        field.fill_masked_holes(sampled_mask)

        # NOTE(review): unlike phase 1, these steps do not call
        # field.fix_direction(); kept as in the original - confirm intentional.
        regularization_steps = 0
        while True:
            regularizer.zero_grad()
            regularization_loss = calculate_regularization_loss()
            regularization_loss.backward()
            regularizer.step()
            regularization_steps += 1
            # Written as `not (>)` so a NaN loss stops the loop, as before.
            if not regularization_loss.item() > regularization_threshold:
                break
            if regularization_steps >= max_regularization_steps:
                break

        print(
            "iteration:",
            iteration,
            "regularization_loss",
            regularization_loss.item(),
            "density_loss",
            density_loss.item(),
            "loss_image",
            loss_image.item(),
            "total_loss",
            total_loss.item(),
            "regularization_steps",
            regularization_steps,
        )

    field.fix_direction()


def save_images(field, resolution, divergence, smoothness, eps=1e-12):
    """Write per-layer diagnostic EXR images for a scratch field.

    For each index ``i`` along the third axis of ``field.field``, saves:

    - ``divergence_{i}.exr``: divergence scaled by 1000;
    - ``smoothness_{i}.exr``: smoothness scaled by 100;
    - ``directions_{i}.exr``: unit direction per cell, padded with a zero
      third channel;
    - ``density_{i}.exr``: per-cell norm of the field vector;
    - ``field_{i}.exr``: first component of the field vector.

    Args:
        field: object whose ``field`` tensor is indexed as ``[:, :, i]`` and
            normed over its last axis.
        resolution: forwarded to ``test_utils.save_image``.
            NOTE(review): this is the render resolution, while these images are
            field-grid sized - confirm save_image handles that.
        divergence: tensor indexed as ``[:, :, i]``.
        smoothness: tensor indexed as ``[:, :, i]``.
        eps: lower bound on the density used when normalizing directions.
    """
    for layer in range(field.field.shape[2]):
        layer_field = field.field[:, :, layer]

        test_utils.save_image(
            1000 * divergence[:, :, layer], resolution, f"divergence_{layer}.exr"
        )
        test_utils.save_image(
            100 * smoothness[:, :, layer], resolution, f"smoothness_{layer}.exr"
        )

        density = torch.norm(layer_field, dim=2)
        # Empty cells have zero density; clamp the divisor so they become zero
        # vectors instead of NaN (0 / 0) in the saved direction image.
        directions = layer_field / density.clamp_min(eps).unsqueeze(2)
        directions = torch.cat(
            [directions, torch.zeros_like(directions[:, :, :1])], dim=2
        )

        test_utils.save_image(directions, resolution, f"directions_{layer}.exr")
        test_utils.save_image(density, resolution, f"density_{layer}.exr")
        test_utils.save_image(
            field.field[:, :, layer, :1], resolution, f"field_{layer}.exr"
        )

In [None]:
# Build the renderer and plane-board scene, create a fresh scratch field and fit
# it to the target texture. After the chdir, relative paths used from here on
# ("texture.png" and the output .exr/.pdf names) are presumably resolved against
# the binaries directory rather than the notebook folder.
os.chdir(execution_path)
r = glints.renderer.Renderer()

vertices, indices = glints.renderer.plane_board_scene_vertices_and_indices()
# NOTE(review): this camera differs from render_and_save_field's
# ([4.0, 0.1, 2.5]); the two scene setups are maintained separately.
camera_position_np = np.array([4.0, 0.0, 3.5], dtype=np.float32)
r.set_camera_position(camera_position_np)
fov_in_degrees = 35
# resolution[0] / resolution[1] is passed as the aspect ratio below, so this is
# presumably [width, height].
resolution = [768 * 2, 512 * 2]
# Field of view converted from degrees to radians; 0.1 and 1000.0 are presumably
# the near/far clip distances - confirm against Renderer.set_perspective.
r.set_perspective(
    np.pi * fov_in_degrees / 180.0, resolution[0] / resolution[1], 0.1, 1000.0
)
r.set_mesh(vertices, indices)
r.set_light_position(torch.tensor([0.0, 4.0, 4.5], device="cuda"))

# NOTE(review): meaning defined by Renderer.set_width (presumably scratch
# width); same value as in render_and_save_field.
r.set_width(torch.tensor([0.001], device="cuda"))

# NOTE(review): ScratchField(512, 1) - presumably grid resolution and number of
# layers (save_images iterates over field.field.shape[2]); confirm.
field = glints.scratch_grid.ScratchField(512, 1)
# Render and save the untrained field as a before/after reference.
image, sampled_mask = glints.scratch_grid.render_scratch_field(r, resolution, field)
test_utils.save_image(image, resolution, "scratch_field_initial.exr")
target_image = r.prepare_target("texture.png", resolution)

# Image loss (MSE) and regularization loss (Huber) consumed by optimize_field.
loss_fn = torch.nn.MSELoss()
regularization_loss_fn = torch.nn.HuberLoss()
# Two Adam instances over the same tensor with independent state: `regularizer`
# drives the divergence/smoothness terms, `optimizer` the image loss.
regularizer = torch.optim.Adam([field.field], lr=0.005)
optimizer = torch.optim.Adam([field.field], lr=0.03)

optimize_field(
    field,
    r,
    resolution,
    target_image,
    loss_fn,
    regularization_loss_fn,
    regularizer,
    optimizer,
)

In [7]:
# Save the target in linear space - the same transform optimize_field applies
# before comparing it with renders - for side-by-side inspection with the
# rendered .exr outputs.
test_utils.save_image(
    gamma_to_linear(target_image),
    resolution,
    "target_image.exr",
)

In [None]:
# Diagnostics for the optimized field: per-layer EXRs (via save_images), a final
# render, and an arrow plot of layer 0.
# NOTE(review): field.field still requires grad here (it is only detached in the
# next cell), so these calls may build autograd graphs; consider torch.no_grad()
# if calc_divergence_smoothness does not itself rely on autograd.
divergence, smoothness = field.calc_divergence_smoothness()
save_images(field, resolution, divergence, smoothness)

image, sampled_mask = glints.scratch_grid.render_scratch_field(r, resolution, field)
test_utils.save_image(image, resolution, "scratch_field.exr")

# First two components of layer 0, rotated 90 degrees over dims (0, 1) (the
# torch.rot90 default). NOTE(review): presumably to match image orientation in
# the plot - confirm.
directions = torch.rot90(field.field[:, :, 0, :2])


test_utils.plot_arrows(
    directions, "directions", spacing=16, scale=0.1, filename="directions.pdf"
)




In [9]:
# Optimization is finished: stop tracking gradients on the field tensor, then
# rebind it to a detached view so later cells do not build autograd graphs
# through it, and release unoccupied cached CUDA memory.
field.field.requires_grad_(False)
field.field = field.field.detach()
torch.cuda.empty_cache()

In [None]:
# Sweep the `density_held` argument of field.discretize_to_lines and record how
# many lines (rows of the returned tensor) each setting produces.
# NOTE(review): the second argument (0.6) is fixed; its meaning is defined by
# discretize_to_lines.
# NOTE: after the loop, `lines` holds the result for the LAST sweep value
# (density_held = 1.5); the following cells print and render that result.
density_held_values = np.linspace(0.5, 1.5, 9)
line_counts = []
for density_held in density_held_values:
    lines = field.discretize_to_lines(density_held, 0.6)
    line_counts.append(lines.shape[0])

fig, ax = plt.subplots()
ax.plot(density_held_values, line_counts)
ax.set_xlabel("Density held")
ax.set_ylabel("Line count")
ax.set_title("Line count vs. density held")
fig.savefig("line_count.pdf")
plt.show()

In [None]:
# `lines` is left over from the sweep cell: the discretization for its last
# density_held value (1.5 from np.linspace(0.5, 1.5, 9)).
print(lines.shape)



In [14]:
# Render the discretized `lines` (from the last sweep value) with the renderer
# switched to its "bspline" type, and save the image. Cached CUDA memory is
# released first, before this render.
torch.cuda.empty_cache()
r.set_type("bspline")
image, _ = r.render(resolution, lines)
test_utils.save_image(image, resolution, "lines.exr")