In [1]:
# Environment setup (order matters): move to the project directory, activate its
# environment, then load the packages used by the rest of the notebook.
import Pkg; 

# Switch to the directory two levels above this file's directory and activate the
# Project.toml found there.
cd(joinpath(@__DIR__, "../../"))
Pkg.activate("Project.toml")

using MorphoMol       # MorphoMol.Utilities.get_matrix_realization
using PyCall         # py"..." blocks that call oineus / diode
using JLD2           # @load of stored simulation results
using LinearAlgebra
using NearestNeighbors
using Rotations
using GLMakie        # Figure, LScene, mesh!, scatter!, Point3f

[32m[1m  Activating[22m[39m project at `~/Doktor/Code/MorphoMol/MorphoMolNotebooks`


In [126]:
"""
    get_boundary_filtration(points, n_atoms_per_mol)

Return the part of the alpha-shape filtration of `points` made of multichromatic cells,
i.e. cells whose vertices belong to at least two different molecules.

Vertex `v` (0-based on the Python side) belongs to molecule `v ÷ n_atoms_per_mol`.
The Python helper is (re)defined through PyCall on every call. The result is a
`PyObject` wrapping the filtration returned by oineus' `subfiltration`.
"""
function get_boundary_filtration(points, n_atoms_per_mol)
    py"""
    import oineus as oin
    import numpy as np
    import torch
    import diode

    def get_boundary_filtration(points, n_atoms_per_mol):
        def spans_several_molecules(cell):
            molecule_ids = {v // n_atoms_per_mol for v in cell.vertices}
            return len(molecule_ids) > 1

        coords = np.asarray(points)
        alpha_cells = [
            oin.Simplex_double(entry[0], entry[1])
            for entry in diode.fill_alpha_shapes(coords)
        ]
        whole = oin.Filtration_double(alpha_cells)
        return whole.subfiltration(spans_several_molecules)
    """
    return py"get_boundary_filtration"(points, n_atoms_per_mol)
end

"""
    get_boundary_filtration_with_adjusted_filtration_values(points, n_atoms_per_mol)

Build the alpha-shape filtration of `points` and keep only its multichromatic cells
(cells whose vertices belong to at least two molecules). Give every kept cell a new
filtration value: the sum of the pairwise distances between the barycenters of the
cell's vertices, grouped by molecule.

Vertex `v` (0-based on the Python side) belongs to molecule `v ÷ n_atoms_per_mol`.
Works for any number of molecules; the molecule ids inside a cell need not be
contiguous or start at 0. Returns a PyCall `PyObject` wrapping the
`oineus.Filtration_double` rebuilt from the re-valued cells.
"""
function get_boundary_filtration_with_adjusted_filtration_values(points, n_atoms_per_mol)
    py"""
    import oineus as oin
    import numpy as np
    import torch
    import diode

    def get_boundary_filtration_with_adjusted_filtration_values(points, n_atoms_per_mol):
        points = np.asarray(points)
        simplices = diode.fill_alpha_shapes(points)
        fil = oin.Filtration_double([oin.Simplex_double(s[0], s[1]) for s in simplices])
        def is_multi(sigma):
            return len(set(v // n_atoms_per_mol for v in sigma.vertices)) >= 2
        fil = fil.subfiltration(is_multi)

        def assign_filtration_value(cell):
            # Group vertex coordinates by molecule id. The ids present in one cell are
            # arbitrary (e.g. {0, 2} or {1, 3} once there are more than two molecules),
            # so they are used as dictionary keys, never as list indices.
            groups = {}
            for v in cell.vertices:
                groups.setdefault(v // n_atoms_per_mol, []).append(points[v])
            # One barycenter per molecule, in increasing molecule-id order.
            bcs = [sum(groups[m]) / len(groups[m]) for m in sorted(groups)]
            # Every unordered pair of distinct molecules contributes exactly once.
            filtration_value = sum(
                np.linalg.norm(bcs[i] - bcs[j])
                for i in range(len(bcs))
                for j in range(i + 1, len(bcs))
            )
            cell.value = filtration_value
            return cell

        altered_cells = [assign_filtration_value(cell) for cell in fil.cells()]
        fil = oin.Filtration_double(altered_cells)
        return fil
    """
    return py"get_boundary_filtration_with_adjusted_filtration_values"(
        points, n_atoms_per_mol)
end

get_boundary_filtration_with_adjusted_filtration_values (generic function with 1 method)

In [127]:
"""
    get_multichromatic_edges_of_tetrahedron(verts; split = 1206)

Return the edges `[u, v]` of the tetrahedron whose vertex ids are the first four
entries of `verts`, keeping only edges whose endpoints lie in different molecules.
The molecule of a vertex id `v` is `div(v, split)`. Edges are listed in the order
ab, ac, ad, bc, bd, cd.
"""
function get_multichromatic_edges_of_tetrahedron(verts; split = 1206)
    a, b, c, d = verts
    molecule(v) = div(v, split)
    all_edges = [[a, b], [a, c], [a, d], [b, c], [b, d], [c, d]]
    return filter(edge -> molecule(edge[1]) != molecule(edge[2]), all_edges)
end

"""
    get_barycenters_of_multichromatic_edges(edges, points)

Return the midpoints (as `Point3f`) of `edges`. Each edge holds 0-based vertex ids
`[u, v]`, and `points[u + 1]` is the coordinate of vertex `u`. When there are exactly
four edges, their centroid is appended as a fifth point.
"""
function get_barycenters_of_multichromatic_edges(edges, points)
    # Vertex ids come from Python (0-based), hence the +1 when indexing `points`.
    midpoint(edge) = Point3f(0.5 * (points[edge[1] + 1] + points[edge[2] + 1]))
    result = Point3f[midpoint(edge) for edge in edges]
    if length(result) == 4
        push!(result, Point3f(0.25 * sum(result)))
    end
    return result
end

"""
    get_faces_from_barycenters(barycenters)

Return the triangle faces, as rows of 1-based indices into `barycenters`, of the patch
drawn for one tetrahedron:

- 3 points: a single triangle.
- 5 points: the quadrilateral with vertex cycle 1 → 2 → 4 → 3 → 1, fanned into four
  triangles around point 5.

Throws an `ArgumentError` for any other number of points (previously this surfaced as
an `UndefVarError`).
"""
function get_faces_from_barycenters(barycenters)
    n = length(barycenters)
    if n == 3
        return [2 3 1]
    elseif n == 5
        # Each quad edge (1-2, 2-4, 4-3, 3-1) is joined to the centre point 5.
        return [2 4 5; 3 1 5; 4 3 5; 1 2 5]
    end
    throw(ArgumentError("expected 3 or 5 barycenters, got $n"))
end

get_faces_from_barycenters (generic function with 1 method)

In [128]:
# Load one stored simulation run and rebuild the atom coordinates of the state with the
# smallest `Es` entry (presumably the lowest energy). `@load` without variable names
# brings every stored variable into scope; `input` and `output` are the ones used here.
@load "../../Data/collected_simulation_results/rwm_wp_2_6r7m/23.jld2"
mindex = argmin(output["Es"])
points = MorphoMol.Utilities.get_matrix_realization(
    output["states"][mindex], input["template_centers"])
# Concatenate the coordinate blocks column-wise (reduce avoids splatting a runtime-length
# collection into hcat) and keep one coordinate view per atom.
points = collect(eachcol(reduce(hcat, points)));
n_atoms_per_mol = length(input["template_radii"])
# Integer molecule count; the atom count must be an exact multiple of the molecule size.
n_mol, n_leftover = divrem(length(points), n_atoms_per_mol)
n_leftover == 0 || throw(ArgumentError(
    "$(length(points)) atoms is not a multiple of $n_atoms_per_mol atoms per molecule"))
fil = get_boundary_filtration_with_adjusted_filtration_values(points, n_atoms_per_mol);

In [138]:
# Left panel: for every tetrahedron of the multichromatic filtration, draw the patch
# spanned by the midpoints of its multichromatic edges, coloured by the cell's
# filtration value. Right panel: all atoms, coloured by molecule.
f = Figure(fontsize = 7)
i_sc = LScene(f[1, 1])
cells = fil.cells()  # fetch once: every call crosses the Julia/Python boundary
# The same sqrt transform is applied to the colour range and to the per-cell colours,
# so both are on one scale.
min_v, max_v = extrema(sqrt(c.value) for c in cells)
for c in cells
    length(c.vertices) == 4 || continue
    mce = get_multichromatic_edges_of_tetrahedron(c.vertices; split = n_atoms_per_mol)
    bcs = get_barycenters_of_multichromatic_edges(mce, points)
    faces = get_faces_from_barycenters(bcs)
    cell_colors = fill(sqrt(c.value), length(bcs))
    mesh!(i_sc, bcs, faces;
        color = cell_colors, colorrange = (min_v, max_v), colormap = :viridis)
end

p_sc = LScene(f[1, 2])
points3f = Point3f.(points)
# Atoms are stored molecule by molecule, so each block of n_atoms_per_mol atoms shares
# one colour index.
colors = repeat(1:n_mol, inner = n_atoms_per_mol)
# NOTE(review): `conf_ms` is not defined in the cells shown here; confirm it is set
# elsewhere before running this cell.
scatter!(p_sc, points3f, markersize = conf_ms, color = colors, colormap = :rainbow)

display(f)

GLMakie.Screen(...)

In [70]:
# Set the filtration value of every cell returned by `fil.cells()` to 0.0.
# NOTE(review): if oineus' `cells()` returns copies, this loop does not modify `fil`
# itself; confirm against the oineus API before relying on it.
for cell in fil.cells()
    cell.value = 0.0
end